87 lines
2.0 KiB
Python
87 lines
2.0 KiB
Python
# -*- encoding:utf-8 -*-
|
|
|
|
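'''
@Author : dingjiawen
@Date : 2023/10/12 16:14
@Usage : Run this file as a script; it loads a 2-D .npy array, smooths it and plots the result.
@Desc : Mean +/- k*sigma threshold bands (getThreshold) and EWMA smoothing (EWMA);
        the script prints the share of samples outside the bands before and after smoothing.
'''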
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
|
|
|
|
def getThreshold(data, cigma_num=1.5):
    """Compute a constant upper/lower sigma band for a 1-D signal.

    The band is ``mean(data) +/- cigma_num * std(data)``, where ``std`` is the
    population standard deviation (NumPy's default, ``ddof=0``). Each bound is
    repeated so that it has the same length as ``data``.

    Args:
        data: 1-D array-like of numeric samples (list, tuple or ndarray).
        cigma_num: Half-width of the band, in standard deviations.

    Returns:
        A tuple ``(upper, lower)`` of 1-D arrays with ``len(data)`` entries
        each. Both are read-only broadcast views of a single scalar value.

    Raises:
        ValueError: If ``data`` is not one-dimensional.
    """
    samples = np.asarray(data)
    if samples.ndim != 1:
        raise ValueError(
            "getThreshold expects 1-D data, got %d dimension(s)" % samples.ndim
        )
    (length,) = samples.shape

    center = np.mean(samples)
    spread = np.std(samples)

    # Local names deliberately avoid shadowing the built-in max()/min().
    upper = np.broadcast_to(center + cigma_num * spread, (length,))
    lower = np.broadcast_to(center - cigma_num * spread, (length,))
    return upper, lower
|
|
|
|
|
|
def EWMA(data, span=5):
    """Smooth ``data`` with an exponentially weighted moving average.

    ``data`` is wrapped in a ``pandas.DataFrame`` and smoothed column by
    column with ``DataFrame.ewm(span=span).mean()``; all other ``ewm``
    options keep their pandas defaults.

    Args:
        data: Anything accepted by the ``pandas.DataFrame`` constructor,
            e.g. a list or a 1-D / 2-D NumPy array.
        span: Span of the exponential window passed to ``ewm``. Defaults to
            5, the value that was previously hard-coded.

    Returns:
        A ``pandas.DataFrame`` holding the smoothed values.
    """
    return pd.DataFrame(data).ewm(span=span).mean()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: load a 2-D array, build a per-row sigma band,
    # smooth the flattened signal with EWMA, then report and plot how many
    # samples fall outside the band before and after smoothing.

    # Raw string: the path contains backslashes that are not valid escape
    # sequences; the runtime value is unchanged.
    data = np.load(r"E:\self_example\TensorFlow_eaxmple\Model_train_test/2012轴承数据集预测挑战\data\HI_DATA\Bearing1_1.npy")
    rows, cols = data.shape

    # One (upper, lower) band per row; collect first and concatenate once
    # instead of growing the arrays inside the loop.
    bands = [getThreshold(row) for row in data]
    maxlist = np.concatenate([upper for upper, _ in bands], axis=0)
    minlist = np.concatenate([lower for _, lower in bands], axis=0)

    # Flatten row by row so sample i lines up with maxlist[i] / minlist[i].
    origin_data = data.reshape([rows * cols, 1])
    # data = np.array([0.5, 5, 0.8, 4.0, 10.0, -0.1, -0.3, 0, 0.5, 6.5])
    # ravel() (rather than squeeze) keeps the result 1-D even for one sample.
    data = EWMA(origin_data).values.ravel()

    # A sample counts as an outlier when it lies strictly outside the band.
    raw = origin_data[:, 0]
    origin_count = np.count_nonzero((raw < minlist) | (raw > maxlist))
    count = np.count_nonzero((data < minlist) | (data > maxlist))

    print("原始劣质率:", origin_count / len(data) * 100, "%")
    print("修复后劣质率:", count / len(data) * 100, "%")

    plt.plot(origin_data, color='blue', label='Original data')
    plt.plot(data, color='green', label='After data repair')
    plt.plot(maxlist, color='red', label='upper Threshold')
    plt.plot(minlist, color='red', label='lower Threshold')
    # Without legend() the labels passed above are never displayed.
    plt.legend()

    plt.show()
|