# File-viewer residue from extraction (not part of the program): 299 lines, 13 KiB, Python.
"""
|
||
@Author: miykah
|
||
@Email: miykah@163.com
|
||
@FileName: read_data_cwru.py
|
||
@DateTime: 2022/7/20 15:58
|
||
"""
|
||
import math
|
||
import os
|
||
import scipy.io as scio
|
||
import numpy as np
|
||
import PIL.Image as Image
|
||
import matplotlib.pyplot as plt
|
||
import torch
|
||
|
||
def dat_to_numpy(samples_num, sample_length):
    """Cut the raw DDS ``.dat`` recordings into overlapping windows and save them as ``.npy``.

    For every working condition ('A', 'B', 'C') and every entry returned by
    ``os.listdir`` on that condition's root directory, four consecutively
    numbered ``.dat`` files are read as float32, elements
    ``[204800 * 2, 204800 * 3)`` of each file are kept, and the pieces are
    concatenated into one 1-D signal.  Windows of ``sample_length`` points are
    cut from that signal at a fixed stride, shuffled with
    ``np.random.permutation`` and saved as ``<entry name>.npy`` in the
    condition's ``raw_data`` directory with shape
    ``(n_windows, 1, sample_length)``.

    Args:
        samples_num: Target number of windows per class, used to derive the
            stride.  The number of windows actually cut is
            ``floor((204800 * 4 - sample_length) / stride + 1)``, which can be
            larger than ``samples_num`` because the stride is rounded down.
        sample_length: Number of points per window.

    Raises:
        ZeroDivisionError: If ``samples_num`` is 1 or the derived stride is 0.
    """
    # 4 files are read per class and 204800 points are kept from each of them.
    total_points = 204800 * 4
    # Sampling stride derived from the wanted number of windows and the window length.
    stride = math.floor((total_points - sample_length) / (samples_num - 1))
    # The window count does not depend on the class, so compute it once.
    n_windows = math.floor((total_points - sample_length) / stride + 1)

    # First (inclusive) and last (exclusive) .dat file number read per condition.
    # Alternative ranges that were left commented out in the original script:
    #   A: 13-17, 25-29, 9-13;  B: 17-21, 29-33;  C: 21-25, 33-37, 25-29.
    file_index_ranges = {'A': (1, 5), 'B': (5, 9), 'C': (9, 13)}

    for condition, (start_idx, end_idx) in file_index_ranges.items():
        # NOTE(review): CUDA availability is used only to pick the drive letter;
        # presumably it tells two machines apart - confirm.
        drive = 'D:' if torch.cuda.is_available() else 'E:'
        root_dir = drive + '\\DataSet\\DDS_data\\5cls_11\\' + condition
        save_dir = drive + '\\DataSet\\DDS_data\\5cls_11\\raw_data\\' + condition

        dir_names = os.listdir(root_dir)
        print(dir_names)
        # Fault types noted by the original author as example directory names
        # (all "parallel gearbox ..., constant speed"): bearing inner-ring fault,
        # bearing outer-ring fault, bearing rolling-element fault, gear
        # eccentricity fault, gear broken-tooth fault, gear missing-tooth fault,
        # gear surface-wear fault, gear tooth-root crack fault.
        os.makedirs(save_dir, exist_ok=True)

        for dir_name in dir_names:
            pieces = []
            for file_idx in range(start_idx, end_idx):
                # File numbers are zero-padded to four digits (0001 ... 0012).
                path = (root_dir + '\\' + dir_name + '\\dds测试故障库4.6#'
                        + '{:04d}'.format(file_idx) + '.dat')
                pieces.append(np.fromfile(path, dtype='float32')[204800 * 2: 204800 * 3])

            signal = np.array(pieces).reshape(-1)
            windows = [
                signal[k * stride: k * stride + sample_length]
                for k in range(n_windows)
            ]
            data_one_cls = np.array(windows).reshape(-1, 1, sample_length)

            # Shuffle the windows.
            shuffle_ix = np.random.permutation(np.arange(len(data_one_cls)))
            data_one_cls = data_one_cls[shuffle_ix]
            print(data_one_cls.shape)
            np.save(save_dir + '\\' + dir_name + '.npy', data_one_cls)
|
||
|
||
|
||
def draw_signal_img(condition):
    """Plot the first saved sample of up to four class files in a 2x2 grid and show it.

    Each ``.npy`` file listed in the condition's ``raw_data`` directory is
    loaded, its first sample (index 0) is flattened and drawn in its own
    subplot titled '(a)' ... '(d)'.

    Args:
        condition: Name of the sub-directory under ``raw_data`` to read from.
    """
    # NOTE(review): CUDA availability is used only to pick the drive letter;
    # presumably it tells two machines apart - confirm.
    drive = 'D:' if torch.cuda.is_available() else 'E:'
    root_dir = drive + '\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
    file_names = os.listdir(root_dir)

    plt.figure(figsize=(12, 6), dpi=200)
    plt.rc('font', family='Times New Roman')
    # Tick marks point inwards; set before any subplot is created.
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'

    clses = ['(a)', '(b)', '(c)', '(d)']
    # One label per tick position: the original passed 4 labels for 5 ticks,
    # which current matplotlib rejects with ValueError.  The last tick (2048)
    # keeps an empty label, as older matplotlib versions rendered it.
    tick_positions = [0, 640, 1280, 1920, 2048]
    tick_labels = ['0', '0.05', '0.1', '0.15', '']

    for panel, (file_name, cls) in enumerate(zip(file_names, clses), start=1):
        data = np.load(root_dir + '\\' + file_name)[0].reshape(-1)
        print(data.shape, file_name)

        plt.subplot(2, 2, panel)
        # Must follow plt.subplot so that it styles this panel, not the previous axes.
        plt.tick_params(top=True, right=True, which='both')  # ticks on top and right too
        plt.title(cls, fontsize=20)
        plt.xlabel("Time (s)", fontsize=20)
        plt.ylabel("Amplitude (m/s²)", fontsize=20)
        # NOTE(review): this limit is in seconds while the ticks below sit at
        # sample indices (0-2048) - confirm which x-range is intended.
        plt.xlim(0, 0.16)
        plt.ylim(-3, 3)
        plt.yticks(np.array([-2, 0, 2]), fontsize=20)
        plt.xticks(tick_positions, tick_labels, fontsize=20)
        plt.plot(data)
    plt.show()
|
||
|
||
def draw_signal_img2():
    """Plot one sample per class for conditions 'A', 'B' and 'C' in a 3x5 grid and show it.

    For each condition, up to five ``.npy`` files from its ``raw_data``
    directory are loaded; the sample at index 3 of each file is flattened and
    drawn in the next free subplot, titled '(a)' ... '(e)'.  Subplots are
    filled with one running counter across all conditions.
    """
    conditions = ['A', 'B', 'C']
    plt.figure(figsize=(12, 6), dpi=600)
    # NOTE(review): this text is drawn before any subplot exists, so it lands on
    # an implicitly created axes - looks like a leftover, confirm before removing.
    plt.text(x=1, y=3, s='A')

    plt.rc('font', family='Times New Roman')
    # Tick marks point inwards; set before the subplots are created.
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'

    clses = ['(a)', '(b)', '(c)', '(d)', '(e)']
    # One label per tick position: the original passed 4 labels for 5 ticks,
    # which current matplotlib rejects with ValueError.  The last tick (2048)
    # keeps an empty label, as older matplotlib versions rendered it.
    tick_positions = [0, 640, 1280, 1920, 2048]
    tick_labels = ['0', '0.05', '0.1', '0.15', '']

    panel = 1
    for condition in conditions:
        # NOTE(review): CUDA availability is used only to pick the drive letter;
        # presumably it tells two machines apart - confirm.
        drive = 'D:' if torch.cuda.is_available() else 'E:'
        root_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
        file_names = os.listdir(root_dir)

        for file_name, cls in zip(file_names, clses):
            data = np.load(root_dir + '\\' + file_name)[3].reshape(-1)
            print(data.shape, file_name)

            plt.subplot(3, 5, panel)
            # Must follow plt.subplot so that it styles this panel, not the previous axes.
            plt.tick_params(top=True, right=True, which='both')  # ticks on top and right too
            plt.title(cls, fontsize=15)
            plt.xlabel("Time (s)", fontsize=15)
            plt.ylabel("Amplitude (m/s²)", fontsize=15)
            # NOTE(review): this limit is in seconds while the ticks below sit at
            # sample indices (0-2048) - confirm which x-range is intended.
            plt.xlim(0, 0.16)
            plt.ylim(-5, 5)
            plt.yticks(np.array([-3, 0, 3]), fontsize=15)
            plt.xticks(tick_positions, tick_labels, fontsize=15)
            plt.plot(data)
            panel += 1
    plt.show()
|
||
|
||
'''得到源域数据和标签'''
|
||
def get_src_data(src_condition, shot):
    """Build the labelled source-domain set and save it as ``data.npy`` / ``label.npy``.

    The first five files listed in the condition's ``raw_data`` directory are
    treated as classes 0-4 (in ``os.listdir`` order).  Rows ``0 .. 600`` of
    every class are split into ``600 // shot`` consecutive groups of ``shot``
    rows; the output interleaves them group by group (group 0 of class 0,
    group 0 of class 1, ..., then group 1 of class 0, ...).  Data is reshaped
    to ``(-1, 1, 2048)`` and labels are int64.

    Args:
        src_condition: Name of the sub-directory under ``raw_data`` to read.
        shot: Number of consecutive rows taken per class in each group.
    """
    # NOTE(review): CUDA availability is used only to pick the drive letter;
    # presumably it tells two machines apart - confirm.
    drive = 'D:' if torch.cuda.is_available() else 'E:'
    root_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + src_condition
    # The output directory name contains "shot10" regardless of the ``shot`` argument.
    save_dir = (drive + '\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
                + src_condition + '\\src')

    # NOTE(review): class labels follow os.listdir order, which Python documents
    # as arbitrary - confirm it is stable on the target machine.
    file_names = os.listdir(root_dir)

    # Load every class file once instead of once per group.
    class_arrays = [np.load(root_dir + '\\' + file_name) for file_name in file_names[:5]]

    datas = []
    labels = []
    for i in range(600 // shot):
        for cls, data in enumerate(class_arrays):
            datas.append(data[shot * i: shot * (i + 1), :, :])
            labels.append(np.array([cls] * shot, dtype='int64'))

    datas = np.array(datas).reshape(-1, 1, 2048)
    labels = np.array(labels).reshape(-1)
    print(datas.shape, labels.shape)

    os.makedirs(save_dir, exist_ok=True)
    np.save(save_dir + '\\' + 'data.npy', datas)
    np.save(save_dir + '\\' + 'label.npy', labels)
|
||
|
||
'''得到没有标签的目标域数据'''
|
||
def get_tar_data(tar_condition):
    """Collect the unlabelled target-domain set and save it as ``data.npy``.

    Rows ``0 .. 600`` of every ``.npy`` file listed in the condition's
    ``raw_data`` directory are stacked and reshaped to ``(-1, 1, 2048)``.

    Args:
        tar_condition: Name of the sub-directory under ``raw_data`` to read.
    """
    # The drive letter depends on whether CUDA is available on this machine.
    if not torch.cuda.is_available():
        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
        save_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'
    else:
        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
        save_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' + tar_condition + '\\tar'

    # First 600 rows of each class file, in directory-listing order.
    head_rows = [
        np.load(root_dir + '\\' + entry)[0: 600, :, :]
        for entry in os.listdir(root_dir)
    ]
    stacked = np.array(head_rows).reshape(-1, 1, 2048)
    print("datas.shape: {}".format(stacked.shape))

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    np.save(save_dir + '\\' + 'data.npy', stacked)
|
||
|
||
def get_tar_data2(tar_condition, shot):
    """Build a labelled target-domain set and save it as ``data.npy`` / ``label.npy``.

    The first five files listed in the condition's ``raw_data`` directory are
    treated as classes 0-4 (in ``os.listdir`` order).  Rows ``0 .. 600`` of
    every class are split into ``600 // shot`` consecutive groups of ``shot``
    rows; the output interleaves them group by group (group 0 of class 0,
    group 0 of class 1, ..., then group 1 of class 0, ...).  Data is reshaped
    to ``(-1, 1, 2048)`` and labels are int64.  Output goes to the ``tar1``
    directory of the condition.

    Args:
        tar_condition: Name of the sub-directory under ``raw_data`` to read.
        shot: Number of consecutive rows taken per class in each group.
    """
    # NOTE(review): CUDA availability is used only to pick the drive letter;
    # presumably it tells two machines apart - confirm.
    drive = 'D:' if torch.cuda.is_available() else 'E:'
    root_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
    # The output directory name contains "shot10" regardless of the ``shot`` argument.
    save_dir = (drive + '\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
                + tar_condition + '\\tar1')

    # NOTE(review): class labels follow os.listdir order, which Python documents
    # as arbitrary - confirm it is stable on the target machine.
    file_names = os.listdir(root_dir)

    # Load every class file once instead of once per group.
    class_arrays = [np.load(root_dir + '\\' + file_name) for file_name in file_names[:5]]

    datas = []
    labels = []
    for i in range(600 // shot):
        for cls, data in enumerate(class_arrays):
            datas.append(data[shot * i: shot * (i + 1), :, :])
            labels.append(np.array([cls] * shot, dtype='int64'))

    datas = np.array(datas).reshape(-1, 1, 2048)
    labels = np.array(labels).reshape(-1)
    print(datas.shape, labels.shape)

    os.makedirs(save_dir, exist_ok=True)
    np.save(save_dir + '\\' + 'data.npy', datas)
    np.save(save_dir + '\\' + 'label.npy', labels)
|
||
|
||
'''得到验证集(每个类别100个样本)'''
|
||
def get_val_data(tar_condition):
    """Build the validation set (rows 600-699 of each class) and save data and labels.

    The first five files listed in the condition's ``raw_data`` directory are
    treated as classes 0-4 (in ``os.listdir`` order).  Rows ``600 .. 700`` of
    each file are taken, reshaped to ``(-1, 1, 2048)`` and saved as
    ``data.npy``; one int64 label per taken row is saved as ``label.npy``.

    Args:
        tar_condition: Name of the sub-directory under ``raw_data`` to read.
    """
    # NOTE(review): CUDA availability is used only to pick the drive letter;
    # presumably it tells two machines apart - confirm.
    drive = 'D:' if torch.cuda.is_available() else 'E:'
    root_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
    save_dir = (drive + '\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
                + tar_condition + '\\val')

    # NOTE(review): class labels follow os.listdir order, which Python documents
    # as arbitrary - confirm it is stable on the target machine.
    file_names = os.listdir(root_dir)

    datas = []
    labels = []
    for cls, file_name in enumerate(file_names[:5]):
        data = np.load(root_dir + '\\' + file_name)
        data_cut = data[600: 700, :, :]
        # One label per row actually sliced (100 when the file has >= 700 rows),
        # so data and labels cannot drift apart on shorter files.
        label_cut = np.array([cls] * len(data_cut), dtype='int64')
        datas.append(data_cut)
        labels.append(label_cut)

    datas = np.array(datas).reshape(-1, 1, 2048)
    labels = np.array(labels).reshape(-1)
    print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))

    os.makedirs(save_dir, exist_ok=True)
    np.save(save_dir + '\\' + 'data.npy', datas)
    np.save(save_dir + '\\' + 'label.npy', labels)
|
||
|
||
'''得到测试集(100个样本)'''
|
||
def get_test_data(tar_condition):
    """Build the test set (rows 700 onward of each class) and save data and labels.

    The first five files listed in the condition's ``raw_data`` directory are
    treated as classes 0-4 (in ``os.listdir`` order).  All rows from index 700
    to the end of each file are taken, reshaped to ``(-1, 1, 2048)`` and saved
    as ``data.npy``; one int64 label per taken row is saved as ``label.npy``.

    Args:
        tar_condition: Name of the sub-directory under ``raw_data`` to read.
    """
    # NOTE(review): CUDA availability is used only to pick the drive letter;
    # presumably it tells two machines apart - confirm.
    drive = 'D:' if torch.cuda.is_available() else 'E:'
    root_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2\\raw_data\\' + tar_condition
    save_dir = (drive + '\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
                + tar_condition + '\\test')

    # NOTE(review): class labels follow os.listdir order, which Python documents
    # as arbitrary - confirm it is stable on the target machine.
    file_names = os.listdir(root_dir)

    datas = []
    labels = []
    for cls, file_name in enumerate(file_names[:5]):
        data = np.load(root_dir + '\\' + file_name)
        data_cut = data[700:, :, :]
        # The slice is open-ended, so label exactly the rows that were taken
        # (100 for an 800-row file) instead of assuming a fixed count of 100.
        label_cut = np.array([cls] * len(data_cut), dtype='int64')
        datas.append(data_cut)
        labels.append(label_cut)

    datas = np.array(datas).reshape(-1, 1, 2048)
    labels = np.array(labels).reshape(-1)
    print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))

    os.makedirs(save_dir, exist_ok=True)
    np.save(save_dir + '\\' + 'data.npy', datas)
    np.save(save_dir + '\\' + 'label.npy', labels)
|
||
|
||
|
||
if __name__ == '__main__':
    # Steps that are currently disabled and were run by hand when needed:
    #   dat_to_numpy(samples_num=800, sample_length=2048)
    #   draw_signal_img('C')
    #   draw_signal_img2()
    working_conditions = ('A', 'B', 'C')

    # Unlabelled target data, then validation and test splits, for every condition.
    for build_split in (get_tar_data, get_val_data, get_test_data):
        for working_condition in working_conditions:
            build_split(working_condition)

    # Labelled source data and labelled target data, 10 rows per class per group.
    for build_labelled in (get_src_data, get_tar_data2):
        for working_condition in working_conditions:
            build_labelled(working_condition, shot=10)
|