"""Convert raw DDS ``.dat`` recordings to ``.npy`` samples and build dataset splits.

@Author: miykah
@Email: miykah@163.com
@FileName: read_data_cwru.py
@DateTime: 2022/7/20 15:58
"""
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import scipy.io as scio
import torch

# Number of float32 points kept from every raw ``.dat`` recording.
_POINTS_PER_FILE = 204800
# Number of recordings concatenated for one fault class of one condition.
_FILES_PER_CONDITION = 4
# Recording indices (start inclusive, end exclusive) read for each condition.
# Alternative ranges previously tried by the author:
#   A: 13-17, 25-29, 9-13   B: 17-21, 29-33   C: 21-25, 33-37, 25-29
_CONDITION_FILE_RANGE = {'A': (1, 5), 'B': (5, 9), 'C': (9, 13)}

# Dataset folder and layout used by the split builders (get_*_data).
_SPLIT_DATASET = '5cls_tsne_2'
_SPLIT_SAVE_FOLDER = '800EachCls_shot10'
_NUM_CLASSES = 5
_SAMPLE_LENGTH = 2048
# Per class: rows [0, 600) are train/target, [600, 700) validation, [700, ...) test.
_TRAIN_SAMPLES_PER_CLASS = 600
_VAL_END = 700

# X axis of the signal plots: tick positions are sample indices, tick labels are
# the corresponding times. The tick at 2048 is intentionally left unlabelled.
_TIME_TICK_POSITIONS = (0, 640, 1280, 1920, 2048)
_TIME_TICK_LABELS = ('0', '0.05', '0.1', '0.15', '')


def _dds_path(*parts):
    """Join ``parts`` below the DDS data root.

    The root lives on drive ``D:`` when CUDA is available and on ``E:``
    otherwise (the two machines this script is run on).
    """
    if torch.cuda.is_available():
        root = 'D:\\DataSet\\DDS_data'
    else:
        root = 'E:\\DataSet\\DDS_data'
    return os.path.join(root, *parts)


def _sorted_listing(directory):
    """Return the entries of ``directory`` in sorted order.

    ``os.listdir`` returns entries in arbitrary order; sorting keeps the
    file-to-class-label and file-to-subplot mapping reproducible.
    """
    return sorted(os.listdir(directory))


def _load_raw_classes(condition, limit=None):
    """Load the per-class ``.npy`` arrays of ``condition`` (at most ``limit``)."""
    raw_dir = _dds_path(_SPLIT_DATASET, 'raw_data', condition)
    file_names = _sorted_listing(raw_dir)[:limit]
    return [np.load(os.path.join(raw_dir, file_name)) for file_name in file_names]


def _split_dir(condition, split_name):
    """Return the output directory of split ``split_name`` for ``condition``."""
    return _dds_path(_SPLIT_DATASET, 'processed_data', _SPLIT_SAVE_FOLDER,
                     condition, split_name)


def _save_split(save_dir, datas, labels=None):
    """Write ``data.npy`` (and ``label.npy`` when labels are given) to ``save_dir``."""
    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, 'data.npy'), datas)
    if labels is not None:
        np.save(os.path.join(save_dir, 'label.npy'), labels)


def _shot_blocks(condition, shot):
    """Collect the first 600 samples of each class in blocks of ``shot``.

    The result is ordered block by block; inside a block the classes follow
    each other, each contributing ``shot`` consecutive samples.

    Returns:
        ``(datas, labels)`` with ``datas`` reshaped to ``(-1, 1, 2048)`` and
        ``labels`` a flat ``int64`` array.

    Raises:
        ValueError: if ``shot`` is smaller than 1.
    """
    if shot < 1:
        raise ValueError('shot must be >= 1, got {}'.format(shot))
    # Load every class file once instead of once per block.
    class_arrays = _load_raw_classes(condition, limit=_NUM_CLASSES)
    datas = []
    labels = []
    for block in range(_TRAIN_SAMPLES_PER_CLASS // shot):
        for cls, data in enumerate(class_arrays):
            data_cut = data[shot * block: shot * (block + 1), :, :]
            datas.append(data_cut)
            labels.append(np.full(len(data_cut), cls, dtype='int64'))
    datas = np.array(datas).reshape(-1, 1, _SAMPLE_LENGTH)
    labels = np.array(labels).reshape(-1)
    return datas, labels


def _labelled_rows(condition, rows):
    """Take the row slice ``rows`` from every class and label it by class index."""
    datas = []
    labels = []
    for cls, data in enumerate(_load_raw_classes(condition, limit=_NUM_CLASSES)):
        data_cut = data[rows, :, :]
        datas.append(data_cut)
        # One label per sample actually taken, so data and labels stay aligned.
        labels.append(np.full(len(data_cut), cls, dtype='int64'))
    datas = np.array(datas).reshape(-1, 1, _SAMPLE_LENGTH)
    labels = np.array(labels).reshape(-1)
    return datas, labels


def _plot_signal_panel(grid, position, title, signal, font_size, y_limit, y_ticks):
    """Draw one signal into subplot ``position`` of a ``grid`` (rows, cols) layout."""
    plt.subplot(grid[0], grid[1], position)
    # Must run after plt.subplot so the ticks on the top and right sides are
    # enabled on this panel rather than on whichever axes was current before.
    plt.tick_params(top=True, right=True, which='both')
    plt.title(title, fontsize=font_size)
    plt.xlabel("Time (s)", fontsize=font_size)
    plt.ylabel("Amplitude (m/s²)", fontsize=font_size)
    # The signal is plotted against its sample index, so the limits are given
    # in samples; the tick labels below show the matching times.
    plt.xlim(_TIME_TICK_POSITIONS[0], _TIME_TICK_POSITIONS[-1])
    plt.ylim(-y_limit, y_limit)
    plt.yticks(np.array(y_ticks), fontsize=font_size)
    plt.xticks(np.array(_TIME_TICK_POSITIONS), list(_TIME_TICK_LABELS),
               fontsize=font_size)
    plt.plot(signal)


def dat_to_numpy(samples_num, sample_length):
    """Cut the raw ``.dat`` recordings into overlapping samples saved as ``.npy``.

    For each condition (A, B, C) and each fault-class folder, the third block
    of 204800 float32 points of four recordings is concatenated, cut into
    windows of ``sample_length`` points, shuffled and saved as
    ``raw_data/<condition>/<class folder>.npy`` with shape
    ``(n, 1, sample_length)``. The window stride is derived from the number
    of samples wanted per class and the sample length.

    Args:
        samples_num: number of samples wanted per class (at least 2).
        sample_length: number of points per sample.

    Raises:
        ValueError: if the arguments do not give a stride of at least 1, or a
            recording is shorter than expected.
    """
    total_points = _POINTS_PER_FILE * _FILES_PER_CONDITION
    if samples_num < 2:
        raise ValueError('samples_num must be >= 2, got {}'.format(samples_num))
    stride = math.floor((total_points - sample_length) / (samples_num - 1))
    if stride < 1:
        raise ValueError(
            'samples_num={} and sample_length={} do not fit into {} points'.format(
                samples_num, sample_length, total_points))
    window_count = math.floor((total_points - sample_length) / stride + 1)

    for condition in ('A', 'B', 'C'):
        root_dir = _dds_path('5cls_11', condition)
        save_dir = _dds_path('5cls_11', 'raw_data', condition)
        start_idx, end_idx = _CONDITION_FILE_RANGE[condition]

        # Fault-class folders, e.g. parallel-gearbox bearing inner race /
        # outer race / rolling element faults and gear eccentricity / broken
        # tooth / missing tooth / surface wear / tooth root crack faults,
        # all recorded at constant speed.
        dir_names = _sorted_listing(root_dir)
        print(dir_names)
        os.makedirs(save_dir, exist_ok=True)

        for dir_name in dir_names:
            segments = []
            for file_idx in range(start_idx, end_idx):
                file_name = 'dds测试故障库4.6#' + str(file_idx).zfill(4) + '.dat'
                path = os.path.join(root_dir, dir_name, file_name)
                recording = np.fromfile(path, dtype='float32')
                segments.append(recording[_POINTS_PER_FILE * 2: _POINTS_PER_FILE * 3])
            signal = np.concatenate(segments)
            if signal.size != total_points:
                raise ValueError(
                    'expected {} points for {!r}, got {}'.format(
                        total_points, dir_name, signal.size))

            windows = [signal[k * stride: k * stride + sample_length]
                       for k in range(window_count)]
            data_one_cls = np.array(windows).reshape(-1, 1, sample_length)
            # Shuffle the samples.
            shuffle_ix = np.random.permutation(np.arange(len(data_one_cls)))
            data_one_cls = data_one_cls[shuffle_ix]
            print(data_one_cls.shape)
            np.save(os.path.join(save_dir, dir_name + '.npy'), data_one_cls)


def draw_signal_img(condition):
    """Plot the first sample of up to four classes of ``condition`` in a 2x2 grid."""
    root_dir = _dds_path('4cls_speed30', 'raw_data', condition)
    file_names = _sorted_listing(root_dir)
    plt.figure(figsize=(12, 6), dpi=200)
    plt.rc('font', family='Times New Roman')
    # Point the tick marks of both axes inwards; set before the panels are created.
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    titles = ['(a)', '(b)', '(c)', '(d)']
    for position, (file_name, cls) in enumerate(zip(file_names, titles), start=1):
        data = np.load(os.path.join(root_dir, file_name))[0].reshape(-1)
        print(data.shape, file_name)
        _plot_signal_panel((2, 2), position, cls, data,
                           font_size=20, y_limit=3, y_ticks=[-2, 0, 2])
    plt.show()


def draw_signal_img2():
    """Plot the fourth sample of up to five classes for conditions A, B and C.

    Panels are filled one after another into a 3x5 grid.
    """
    plt.figure(figsize=(12, 6), dpi=600)
    # NOTE(review): this call implicitly creates an axes before the subplots
    # are added; it looks like a leftover - confirm whether it is still wanted.
    plt.text(x=1, y=3, s='A')
    plt.rc('font', family='Times New Roman')
    # Point the tick marks of both axes inwards; set before the panels are created.
    plt.rcParams['xtick.direction'] = 'in'
    plt.rcParams['ytick.direction'] = 'in'
    titles = ['(a)', '(b)', '(c)', '(d)', '(e)']
    position = 1
    for condition in ('A', 'B', 'C'):
        root_dir = _dds_path('5cls_tsne', 'raw_data', condition)
        file_names = _sorted_listing(root_dir)
        for file_name, cls in zip(file_names, titles):
            data = np.load(os.path.join(root_dir, file_name))[3].reshape(-1)
            print(data.shape, file_name)
            _plot_signal_panel((3, 5), position, cls, data,
                               font_size=15, y_limit=5, y_ticks=[-3, 0, 3])
            position += 1
    plt.show()


def get_src_data(src_condition, shot):
    """Build the labelled source-domain data of ``src_condition``.

    Saves ``data.npy`` and ``label.npy`` to the ``src`` split directory; see
    ``_shot_blocks`` for the sample ordering.
    """
    datas, labels = _shot_blocks(src_condition, shot)
    print(datas.shape, labels.shape)
    _save_split(_split_dir(src_condition, 'src'), datas, labels)


def get_tar_data(tar_condition):
    """Build the unlabelled target-domain data of ``tar_condition``.

    Takes the first 600 samples of every class file and saves them as
    ``data.npy`` in the ``tar`` split directory.
    """
    datas = [data[0:_TRAIN_SAMPLES_PER_CLASS, :, :]
             for data in _load_raw_classes(tar_condition)]
    datas = np.array(datas).reshape(-1, 1, _SAMPLE_LENGTH)
    print("datas.shape: {}".format(datas.shape))
    _save_split(_split_dir(tar_condition, 'tar'), datas)


def get_tar_data2(tar_condition, shot):
    """Build labelled target-domain data of ``tar_condition``.

    Same layout as ``get_src_data`` but saved to the ``tar1`` split directory.
    """
    datas, labels = _shot_blocks(tar_condition, shot)
    print(datas.shape, labels.shape)
    _save_split(_split_dir(tar_condition, 'tar1'), datas, labels)


def get_val_data(tar_condition):
    """Build the validation set: rows 600-699 of every class (100 per class)."""
    datas, labels = _labelled_rows(
        tar_condition, slice(_TRAIN_SAMPLES_PER_CLASS, _VAL_END))
    print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
    _save_split(_split_dir(tar_condition, 'val'), datas, labels)


def get_test_data(tar_condition):
    """Build the test set: rows 700 onwards of every class.

    That is 100 samples per class when each class file holds 800 samples.
    """
    datas, labels = _labelled_rows(tar_condition, slice(_VAL_END, None))
    print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))
    _save_split(_split_dir(tar_condition, 'test'), datas, labels)


if __name__ == '__main__':
    # Earlier pipeline stages, run manually when needed:
    # dat_to_numpy(samples_num=800, sample_length=2048)
    # draw_signal_img('C')
    # draw_signal_img2()
    all_conditions = ('A', 'B', 'C')
    for cond in all_conditions:
        get_tar_data(cond)
    for cond in all_conditions:
        get_val_data(cond)
    for cond in all_conditions:
        get_test_data(cond)
    for cond in all_conditions:
        get_src_data(cond, shot=10)
    for cond in all_conditions:
        get_tar_data2(cond, shot=10)