self_example/pytorch_example/example/read_data.py

299 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
@Author: miykah
@Email: miykah@163.com
@FileName: read_data_cwru.py
@DateTime: 2022/7/20 15:58
"""
import math
import os
import scipy.io as scio
import numpy as np
import PIL.Image as Image
import matplotlib.pyplot as plt
import torch
def dat_to_numpy(samples_num, sample_length, base_dir=None):
    """Convert raw DDS ``.dat`` recordings into per-class arrays of samples.

    For every working condition (``A``, ``B``, ``C``) and every class folder
    found beneath it, four recordings are read as ``float32``, the third
    204800-point segment of each recording is kept, and the four segments are
    joined into one signal. Fixed-length windows are then cut from that signal
    at a regular stride, shuffled, and saved as ``<class folder>.npy`` with
    shape ``(n, 1, sample_length)``.

    The stride is the signal span divided by ``samples_num - 1``, rounded
    down. Because of the rounding, ``n`` is at least ``samples_num`` and can be
    slightly larger.

    Args:
        samples_num: Number of windows wanted per class; must be >= 1.
        sample_length: Number of points in each window; must be >= 1.
        base_dir: Dataset folder containing the ``A``/``B``/``C`` condition
            folders. Output goes to ``raw_data/<condition>`` beneath it.
            Defaults to the ``DataSet/DDS_data/5cls_11`` folder on drive
            ``D:`` when CUDA is available and on drive ``E:`` otherwise.

    Raises:
        ValueError: If the arguments are not positive, if a signal is shorter
            than ``sample_length``, or if so many windows are requested that
            the stride would be zero.
    """
    if samples_num < 1:
        raise ValueError('samples_num must be >= 1, got {}'.format(samples_num))
    if sample_length < 1:
        raise ValueError('sample_length must be >= 1, got {}'.format(sample_length))

    segment_points = 204800
    # Recording numbers [start, end) used for each working condition.
    # Alternative ranges the author had left commented out:
    #   A: 13-17, 25-29, 9-13   B: 17-21, 29-33   C: 21-25, 33-37, 25-29
    file_ranges = {'A': (1, 5), 'B': (5, 9), 'C': (9, 13)}

    if base_dir is None:
        drive = 'D:' if torch.cuda.is_available() else 'E:'
        base_dir = drive + '\\DataSet\\DDS_data\\5cls_11'

    for condition in ('A', 'B', 'C'):
        root_dir = os.path.join(base_dir, condition)
        save_dir = os.path.join(base_dir, 'raw_data', condition)
        start_idx, end_idx = file_ranges[condition]

        # Each entry is a class folder. The original note lists the fault types
        # (parallel gearbox, constant speed): bearing inner race, bearing outer
        # race, rolling element, gear eccentricity, broken tooth, missing tooth,
        # surface wear, tooth-root crack.
        dir_names = os.listdir(root_dir)
        print(dir_names)
        os.makedirs(save_dir, exist_ok=True)

        for dir_name in dir_names:
            segments = []
            for i in range(start_idx, end_idx):
                # Recording numbers are zero-padded to four digits.
                file_name = 'dds测试故障库4.6#{:04d}.dat'.format(i)
                path = os.path.join(root_dir, dir_name, file_name)
                recording = np.fromfile(path, dtype='float32')
                segments.append(recording[segment_points * 2: segment_points * 3])
            signal = np.concatenate(segments)

            # Work from the length actually read instead of an assumed total,
            # so the window count always matches the data.
            span = signal.size - sample_length
            if span < 0:
                raise ValueError(
                    'signal for {!r} has {} points, shorter than sample_length={}'
                    .format(dir_name, signal.size, sample_length))
            if samples_num == 1:
                stride = span + 1  # yields exactly one window
            else:
                stride = span // (samples_num - 1)
            if stride < 1:
                raise ValueError(
                    'samples_num={} is too large for a signal of {} points'
                    .format(samples_num, signal.size))
            window_count = span // stride + 1

            # Gather all windows at once: row k covers
            # [k * stride, k * stride + sample_length).
            starts = np.arange(window_count) * stride
            windows = signal[starts[:, None] + np.arange(sample_length)]
            data_one_cls = windows.reshape(-1, 1, sample_length)

            # Shuffle the windows so neighbouring (overlapping) ones are not
            # adjacent in the saved file.
            data_one_cls = data_one_cls[np.random.permutation(len(data_one_cls))]
            print(data_one_cls.shape)
            np.save(os.path.join(save_dir, dir_name + '.npy'), data_one_cls)
def draw_signal_img(condition):
    """Plot the first sample of up to four classes of one working condition.

    Each class file in the condition's ``raw_data`` folder is loaded, its
    first sample is flattened, and the result is drawn in one panel of a
    2 x 2 grid titled ``(a)`` to ``(d)``. Files are taken in sorted name order
    so the panel assignment is the same on every run.

    Args:
        condition: Working-condition folder name, e.g. ``'A'``.
    """
    drive = 'D:' if torch.cuda.is_available() else 'E:'
    root_dir = drive + '\\DataSet\\DDS_data\\4cls_speed30\\raw_data\\' + condition
    file_names = sorted(os.listdir(root_dir))

    plt.figure(figsize=(12, 6), dpi=200)
    # plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
    # plt.rcParams['axes.unicode_minus'] = False
    plt.rc('font', family='Times New Roman')
    # Tick direction is read when an axes is created, so set it once up front.
    plt.rcParams['xtick.direction'] = 'in'  # x-axis ticks point inwards
    plt.rcParams['ytick.direction'] = 'in'  # y-axis ticks point inwards

    # Data is plotted against sample index; the labels show the matching time
    # (640 samples per 0.05 s). The last tick only marks the end of the
    # window and carries an empty label, so positions and labels pair up.
    tick_positions = [0, 640, 1280, 1920, 2048]
    tick_labels = ['0', '0.05', '0.1', '0.15', '']

    clses = ['(a)', '(b)', '(c)', '(d)']
    for i, (file_name, cls) in enumerate(zip(file_names, clses)):
        data = np.load(root_dir + '\\' + file_name)[0].reshape(-1)
        print(data.shape, file_name)
        # Create the panel first so the styling below applies to it.
        plt.subplot(2, 2, i + 1)
        plt.tick_params(top=True, right=True, which='both')  # ticks on top and right too
        plt.title(cls, fontsize=20)
        plt.xlabel("Time (s)", fontsize=20)
        plt.ylabel("Amplitude (m/s²)", fontsize=20)
        plt.ylim(-3, 3)
        plt.yticks(np.array([-2, 0, 2]), fontsize=20)
        plt.xticks(tick_positions, tick_labels, fontsize=20)
        # Limits are in sample-index units, matching the plotted data.
        plt.xlim(0, 2048)
        plt.plot(data)
    plt.show()
def draw_signal_img2():
    """Plot one sample per class for working conditions A, B and C.

    For each condition, up to five class files are loaded in sorted name
    order and the sample at index 3 of each is drawn in the next free panel
    of a 3 x 5 grid. Panels are filled with a running counter, and titles
    ``(a)`` to ``(e)`` restart for every condition.
    """
    conditions = ['A', 'B', 'C']
    clses = ['(a)', '(b)', '(c)', '(d)', '(e)']
    # Data is plotted against sample index; the labels show the matching time
    # (640 samples per 0.05 s). The last tick only marks the end of the
    # window and carries an empty label, so positions and labels pair up.
    tick_positions = [0, 640, 1280, 1920, 2048]
    tick_labels = ['0', '0.05', '0.1', '0.15', '']

    plt.figure(figsize=(12, 6), dpi=600)
    # NOTE(review): this text is drawn on an implicit full-figure axes created
    # before any subplot exists — confirm whether it is still wanted.
    plt.text(x=1, y=3, s='A')

    # plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
    # plt.rcParams['axes.unicode_minus'] = False
    plt.rc('font', family='Times New Roman')
    # Tick direction is read when an axes is created, so set it once up front.
    plt.rcParams['xtick.direction'] = 'in'  # x-axis ticks point inwards
    plt.rcParams['ytick.direction'] = 'in'  # y-axis ticks point inwards

    drive = 'D:' if torch.cuda.is_available() else 'E:'
    i = 1
    for condition in conditions:
        root_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne\\raw_data\\' + condition
        file_names = sorted(os.listdir(root_dir))
        for file_name, cls in zip(file_names, clses):
            data = np.load(root_dir + '\\' + file_name)[3].reshape(-1)
            print(data.shape, file_name)
            # Create the panel first so the styling below applies to it.
            plt.subplot(3, 5, i)
            plt.tick_params(top=True, right=True, which='both')  # ticks on top and right too
            plt.title(cls, fontsize=15)
            plt.xlabel("Time (s)", fontsize=15)
            plt.ylabel("Amplitude (m/s²)", fontsize=15)
            plt.ylim(-5, 5)
            plt.yticks(np.array([-3, 0, 3]), fontsize=15)
            plt.xticks(tick_positions, tick_labels, fontsize=15)
            # Limits are in sample-index units, matching the plotted data.
            plt.xlim(0, 2048)
            plt.plot(data)
            i += 1
    plt.show()
def get_src_data(src_condition, shot, root_dir=None, save_dir=None):
    """Build and save the labelled source-domain data set for one condition.

    The first 600 samples of each class file are cut into consecutive chunks
    of ``shot`` samples. The output holds chunk 0 of every class, then chunk 1
    of every class, and so on. A class label is the position of its file in
    the sorted directory listing; only the first five files are used.
    ``data.npy`` and ``label.npy`` are written to ``save_dir``.

    Args:
        src_condition: Working-condition folder name, e.g. ``'A'``.
        shot: Samples per class in each chunk; must be between 1 and 600.
        root_dir: Folder holding one ``.npy`` file per class, each shaped
            ``(samples, 1, length)``. Defaults to the condition's ``raw_data``
            folder of the ``5cls_tsne_2`` data set on drive ``D:`` when CUDA
            is available and on drive ``E:`` otherwise.
        save_dir: Output folder, created if missing. Defaults to the
            condition's ``src`` folder under ``processed_data``; note that the
            default folder name contains ``shot10`` whatever ``shot`` is.

    Raises:
        ValueError: If ``shot`` is out of range or ``root_dir`` is empty.
    """
    if not 0 < shot <= 600:
        raise ValueError('shot must be between 1 and 600, got {}'.format(shot))

    if root_dir is None or save_dir is None:
        drive = 'D:' if torch.cuda.is_available() else 'E:'
        base_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2'
        if root_dir is None:
            root_dir = base_dir + '\\raw_data\\' + src_condition
        if save_dir is None:
            save_dir = (base_dir + '\\processed_data\\800EachCls_shot10\\'
                        + src_condition + '\\src')

    # os.listdir returns names in arbitrary order; sort so that a class always
    # receives the same label.
    file_names = sorted(os.listdir(root_dir))[:5]
    if not file_names:
        raise ValueError('no class files found in {}'.format(root_dir))

    # Load every class file once instead of once per chunk.
    per_class = [np.load(os.path.join(root_dir, name))[:600, :, :]
                 for name in file_names]

    datas = []
    labels = []
    for i in range(600 // shot):
        for cls, data in enumerate(per_class):
            data_cut = data[shot * i: shot * (i + 1)]
            datas.append(data_cut)
            # One label per row actually taken, so data and labels stay aligned
            # even when a file holds fewer than 600 samples.
            labels.append(np.full(len(data_cut), cls, dtype='int64'))

    # Concatenating keeps each sample's own length instead of assuming 2048.
    datas = np.concatenate(datas, axis=0)
    labels = np.concatenate(labels)
    print(datas.shape, labels.shape)

    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, 'data.npy'), datas)
    np.save(os.path.join(save_dir, 'label.npy'), labels)
def get_tar_data(tar_condition, root_dir=None, save_dir=None):
    """Build and save the unlabelled target-domain data set for one condition.

    The first 600 samples of every class file are stacked, class after class
    in sorted file-name order, and written to ``data.npy`` in ``save_dir``.
    No label file is produced.

    Args:
        tar_condition: Working-condition folder name, e.g. ``'A'``.
        root_dir: Folder holding one ``.npy`` file per class, each shaped
            ``(samples, 1, length)``. Defaults to the condition's ``raw_data``
            folder of the ``5cls_tsne_2`` data set on drive ``D:`` when CUDA
            is available and on drive ``E:`` otherwise.
        save_dir: Output folder, created if missing. Defaults to the
            condition's ``tar`` folder under ``processed_data``.

    Raises:
        ValueError: If ``root_dir`` contains no files.
    """
    if root_dir is None or save_dir is None:
        drive = 'D:' if torch.cuda.is_available() else 'E:'
        base_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2'
        if root_dir is None:
            root_dir = base_dir + '\\raw_data\\' + tar_condition
        if save_dir is None:
            save_dir = (base_dir + '\\processed_data\\800EachCls_shot10\\'
                        + tar_condition + '\\tar')

    # os.listdir returns names in arbitrary order; sort so the class order in
    # the output is the same on every run.
    file_names = sorted(os.listdir(root_dir))
    if not file_names:
        raise ValueError('no class files found in {}'.format(root_dir))

    datas = [np.load(os.path.join(root_dir, file_name))[0: 600, :, :]
             for file_name in file_names]
    # Concatenating keeps each sample's own length instead of assuming 2048.
    datas = np.concatenate(datas, axis=0)
    print("datas.shape: {}".format(datas.shape))

    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, 'data.npy'), datas)
def get_tar_data2(tar_condition, shot, root_dir=None, save_dir=None):
    """Build and save a labelled, chunk-interleaved target-domain data set.

    The first 600 samples of each class file are cut into consecutive chunks
    of ``shot`` samples. The output holds chunk 0 of every class, then chunk 1
    of every class, and so on. A class label is the position of its file in
    the sorted directory listing; only the first five files are used.
    ``data.npy`` and ``label.npy`` are written to ``save_dir``.

    Args:
        tar_condition: Working-condition folder name, e.g. ``'A'``.
        shot: Samples per class in each chunk; must be between 1 and 600.
        root_dir: Folder holding one ``.npy`` file per class, each shaped
            ``(samples, 1, length)``. Defaults to the condition's ``raw_data``
            folder of the ``5cls_tsne_2`` data set on drive ``D:`` when CUDA
            is available and on drive ``E:`` otherwise.
        save_dir: Output folder, created if missing. Defaults to the
            condition's ``tar1`` folder under ``processed_data``; note that
            the default folder name contains ``shot10`` whatever ``shot`` is.

    Raises:
        ValueError: If ``shot`` is out of range or ``root_dir`` is empty.
    """
    if not 0 < shot <= 600:
        raise ValueError('shot must be between 1 and 600, got {}'.format(shot))

    if root_dir is None or save_dir is None:
        drive = 'D:' if torch.cuda.is_available() else 'E:'
        base_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2'
        if root_dir is None:
            root_dir = base_dir + '\\raw_data\\' + tar_condition
        if save_dir is None:
            save_dir = (base_dir + '\\processed_data\\800EachCls_shot10\\'
                        + tar_condition + '\\tar1')

    # os.listdir returns names in arbitrary order; sort so that a class always
    # receives the same label.
    file_names = sorted(os.listdir(root_dir))[:5]
    if not file_names:
        raise ValueError('no class files found in {}'.format(root_dir))

    # Load every class file once instead of once per chunk.
    per_class = [np.load(os.path.join(root_dir, name))[:600, :, :]
                 for name in file_names]

    datas = []
    labels = []
    for i in range(600 // shot):
        for cls, data in enumerate(per_class):
            data_cut = data[shot * i: shot * (i + 1)]
            datas.append(data_cut)
            # One label per row actually taken, so data and labels stay aligned
            # even when a file holds fewer than 600 samples.
            labels.append(np.full(len(data_cut), cls, dtype='int64'))

    # Concatenating keeps each sample's own length instead of assuming 2048.
    datas = np.concatenate(datas, axis=0)
    labels = np.concatenate(labels)
    print(datas.shape, labels.shape)

    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, 'data.npy'), datas)
    np.save(os.path.join(save_dir, 'label.npy'), labels)
def get_val_data(tar_condition, root_dir=None, save_dir=None):
    """Build and save the validation set: samples 600-699 of each class.

    Rows ``600:700`` of each class file are stacked class after class. A
    class label is the position of its file in the sorted directory listing;
    only the first five files are used. ``data.npy`` and ``label.npy`` are
    written to ``save_dir``.

    Args:
        tar_condition: Working-condition folder name, e.g. ``'A'``.
        root_dir: Folder holding one ``.npy`` file per class, each shaped
            ``(samples, 1, length)``. Defaults to the condition's ``raw_data``
            folder of the ``5cls_tsne_2`` data set on drive ``D:`` when CUDA
            is available and on drive ``E:`` otherwise.
        save_dir: Output folder, created if missing. Defaults to the
            condition's ``val`` folder under ``processed_data``.

    Raises:
        ValueError: If ``root_dir`` contains no files.
    """
    if root_dir is None or save_dir is None:
        drive = 'D:' if torch.cuda.is_available() else 'E:'
        base_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2'
        if root_dir is None:
            root_dir = base_dir + '\\raw_data\\' + tar_condition
        if save_dir is None:
            save_dir = (base_dir + '\\processed_data\\800EachCls_shot10\\'
                        + tar_condition + '\\val')

    # os.listdir returns names in arbitrary order; sort so that a class always
    # receives the same label.
    file_names = sorted(os.listdir(root_dir))[:5]
    if not file_names:
        raise ValueError('no class files found in {}'.format(root_dir))

    datas = []
    labels = []
    for cls, file_name in enumerate(file_names):
        data_cut = np.load(os.path.join(root_dir, file_name))[600: 700, :, :]
        datas.append(data_cut)
        # One label per row actually taken (normally 100), so data and labels
        # stay aligned even when a file holds fewer than 700 samples.
        labels.append(np.full(len(data_cut), cls, dtype='int64'))

    # Concatenating keeps each sample's own length instead of assuming 2048.
    datas = np.concatenate(datas, axis=0)
    labels = np.concatenate(labels)
    print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))

    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, 'data.npy'), datas)
    np.save(os.path.join(save_dir, 'label.npy'), labels)
def get_test_data(tar_condition, root_dir=None, save_dir=None):
    """Build and save the test set: every sample from index 700 onwards.

    Rows ``700:`` of each class file are stacked class after class (100 rows
    per class when a file holds 800 samples). A class label is the position
    of its file in the sorted directory listing; only the first five files
    are used. ``data.npy`` and ``label.npy`` are written to ``save_dir``.

    Args:
        tar_condition: Working-condition folder name, e.g. ``'A'``.
        root_dir: Folder holding one ``.npy`` file per class, each shaped
            ``(samples, 1, length)``. Defaults to the condition's ``raw_data``
            folder of the ``5cls_tsne_2`` data set on drive ``D:`` when CUDA
            is available and on drive ``E:`` otherwise.
        save_dir: Output folder, created if missing. Defaults to the
            condition's ``test`` folder under ``processed_data``.

    Raises:
        ValueError: If ``root_dir`` contains no files.
    """
    if root_dir is None or save_dir is None:
        drive = 'D:' if torch.cuda.is_available() else 'E:'
        base_dir = drive + '\\DataSet\\DDS_data\\5cls_tsne_2'
        if root_dir is None:
            root_dir = base_dir + '\\raw_data\\' + tar_condition
        if save_dir is None:
            save_dir = (base_dir + '\\processed_data\\800EachCls_shot10\\'
                        + tar_condition + '\\test')

    # os.listdir returns names in arbitrary order; sort so that a class always
    # receives the same label.
    file_names = sorted(os.listdir(root_dir))[:5]
    if not file_names:
        raise ValueError('no class files found in {}'.format(root_dir))

    datas = []
    labels = []
    for cls, file_name in enumerate(file_names):
        data_cut = np.load(os.path.join(root_dir, file_name))[700:, :, :]
        datas.append(data_cut)
        # The slice is open-ended, so count the rows instead of assuming 100;
        # otherwise data and labels disagree for files without 800 samples.
        labels.append(np.full(len(data_cut), cls, dtype='int64'))

    # Concatenating keeps each sample's own length instead of assuming 2048.
    datas = np.concatenate(datas, axis=0)
    labels = np.concatenate(labels)
    print("datas.shape: {}, labels.shape: {}".format(datas.shape, labels.shape))

    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, 'data.npy'), datas)
    np.save(os.path.join(save_dir, 'label.npy'), labels)
if __name__ == '__main__':
    # Earlier pipeline stages, disabled by the author:
    # dat_to_numpy(samples_num=800, sample_length=2048)
    # draw_signal_img('C')
    # draw_signal_img2()

    working_conditions = ('A', 'B', 'C')

    # Builders that only need the working condition; each one is run for all
    # conditions before the next builder starts.
    for build in (get_tar_data, get_val_data, get_test_data):
        for working_condition in working_conditions:
            build(working_condition)

    # Builders that also take the chunk size.
    for build in (get_src_data, get_tar_data2):
        for working_condition in working_conditions:
            build(working_condition, shot=10)