self_example/pytorch_example/RUL/otherIdea/dctLSTM/loadData.py

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/11/10 15:21
@Usage :
@Desc : Load the dataset
'''
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

'''Plain Dataset wrapper: yields (data, label) pairs, or data alone when no labels are given.'''
class Nor_Dataset(Dataset):
    def __init__(self, datas, labels=None):
        self.datas = torch.tensor(datas)
        if labels is not None:
            self.labels = torch.tensor(labels)
        else:
            self.labels = None

    def __getitem__(self, index):
        data = self.datas[index]
        if self.labels is not None:
            label = self.labels[index]
            return data, label
        return data

    def __len__(self):
        return len(self.datas)
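
# A minimal usage sketch (not part of the original file): Nor_Dataset fed
# through a DataLoader. The random shapes below are illustrative assumptions,
# not the project's real data.
def _demo_nor_dataset():
    xs = np.random.rand(8, 30, 20).astype(np.float32)  # (sample, filter_num, dims)
    ys = np.random.rand(8).astype(np.float32)          # one target per sample
    loader = DataLoader(Nor_Dataset(xs, ys), batch_size=4, shuffle=False)
    for batch_x, batch_y in loader:
        print(batch_x.shape, batch_y.shape)  # torch.Size([4, 30, 20]) torch.Size([4])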

def standardization(data):
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    return (data - mu) / sigma


def normalization(data):
    _range = np.max(data) - np.min(data)
    return (data - np.min(data)) / _range
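
# A hedged worked example: on [1., 2., 3.], normalization (min-max) gives
# [0., 0.5, 1.] and standardization (zero mean, unit std) gives roughly
# [-1.2247, 0., 1.2247]. Quick self-check:
def _demo_scaling():
    data = np.array([1.0, 2.0, 3.0])
    print(normalization(data))    # [0.  0.5 1. ]
    print(standardization(data))  # approx [-1.2247  0.  1.2247]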

# filter_num: number of LSTM cells (window length); dims: feature dimension; if_norm: whether to normalize
def getData(filter_num, dims, if_norm: bool = False):
    # Load the data
    HI_merge_data_origin = np.load("../../dataset/HI_merge_data.npy")
    # plt.plot(HI_merge_data[0:1250, 1])
    # Drop the leading points where the degradation trend is not yet apparent
    HI_merge_data = HI_merge_data_origin[0:1250, 1]
    # Optionally normalize
    if if_norm:
        HI_merge_data = normalization(HI_merge_data)
    # plt.plot(HI_merge_data)
    # plt.show()
    (total_dims,) = HI_merge_data.shape
    # Split into overlapping windows (sliding-window sampling)
    predict_data = np.empty(shape=[total_dims - filter_num, filter_num])
    # Overlapping sampling yields the time steps and the number of training samples
    for dim in range(total_dims - filter_num):
        predict_data[dim] = HI_merge_data[dim:dim + filter_num]
    train_label = predict_data[dims:, :]
    train_label_single = HI_merge_data[dims + filter_num - 1:-1]
    # Overlap-sample again to build the per-step feature dimension
    '''train_data.shape: (sample, filter_num) -> (sample, filter_num, dims)'''
    train_data = np.empty(shape=[dims, total_dims - filter_num - dims, filter_num])
    for dim in range(dims):
        train_data[dim] = predict_data[dim:total_dims - filter_num - dims + dim, :]
    # Transpose into the desired layout: (dims, sample, filter_num) -> (sample, filter_num, dims)
    train_data = np.transpose(train_data, [1, 2, 0])
    # TODO: resolve the issue that `query` cannot be serialized when saving the model
    total_data = HI_merge_data
    print("total_data.shape:", total_data.shape)
    print("train_data.shape:", train_data.shape)  # (1200, 30, 20) for filter_num=30, dims=20
    print("train_label.shape:", train_label.shape)  # (1200, 30)
    print("train_label_single.shape:", train_label_single.shape)  # (1200,)
    # Returns: raw series; training data; training labels (sequence target); training labels (single-point target)
    return total_data, train_data, train_label, train_label_single
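
# A self-contained sketch of the double sliding-window logic above, run on a
# synthetic ramp instead of HI_merge_data.npy (an assumption so the demo needs
# no data file). With total_dims=100, filter_num=30, dims=20 the shapes are a
# scaled-down match for the real 1250-point pipeline.
def _demo_windowing(total_dims=100, filter_num=30, dims=20):
    series = np.linspace(0.0, 1.0, total_dims)
    # First pass: overlapping windows of length filter_num
    windows = np.empty((total_dims - filter_num, filter_num))
    for i in range(total_dims - filter_num):
        windows[i] = series[i:i + filter_num]
    # Second pass: stack dims shifted copies, then transpose to (sample, filter_num, dims)
    data = np.empty((dims, total_dims - filter_num - dims, filter_num))
    for d in range(dims):
        data[d] = windows[d:total_dims - filter_num - dims + d, :]
    data = np.transpose(data, [1, 2, 0])
    print(data.shape)                              # (50, 30, 20)
    print(windows[dims:, :].shape)                 # (50, 30)
    print(series[dims + filter_num - 1:-1].shape)  # (50,)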

# Hold out the last predict_num samples as the validation set
def splitValData(data, label, label_single, predict_num=50):
    sample, hidden, feature = data.shape
    train_data = data[:sample - predict_num, :, :]
    val_data = data[sample - predict_num:, :, :]
    train_label = label[:sample - predict_num, :]
    val_label = label[sample - predict_num:, :]
    train_label_single = label_single[:sample - predict_num]
    val_label_single = label_single[sample - predict_num:]
    return train_data, val_data, train_label, val_label, train_label_single, val_label_single
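
# Quick sanity check (sketch, not from the original file): with 1200 samples
# and the default predict_num=50, the split keeps 1150 samples for training
# and the final 50 for validation.
def _demo_split():
    data = np.zeros((1200, 30, 20))
    label = np.zeros((1200, 30))
    single = np.zeros((1200,))
    tr, va, tr_l, va_l, tr_s, va_s = splitValData(data, label, single, predict_num=50)
    print(tr.shape, va.shape)  # (1150, 30, 20) (50, 30, 20)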

def getTotalData(hidden_num, feature, is_single=True, is_norm=False):
    total_data, train_data, train_label, train_label_single = getData(hidden_num, feature, is_norm)
    if is_single:
        total_dataset = Nor_Dataset(train_data, train_label_single)
    else:
        total_dataset = Nor_Dataset(train_data, train_label)
    return total_data, total_dataset

# hidden_num: number of LSTM cells; feature: number of channels; predict_num: how many points to predict; is_norm: whether to normalize
def getDataset(hidden_num, feature, predict_num, is_single=True, is_norm=False):
    total_data, train_data, train_label, train_label_single = getData(hidden_num, feature, is_norm)
    # Split into training and validation (test) sets by the number of points to predict
    train_data, val_data, train_label, val_label, train_label_single, val_label_single = splitValData(
        train_data, train_label, train_label_single, predict_num=predict_num)
    if is_single:
        train_dataset = Nor_Dataset(train_data, train_label_single)
        val_dataset = Nor_Dataset(val_data, val_label_single)
    else:
        train_dataset = Nor_Dataset(train_data, train_label)
        val_dataset = Nor_Dataset(val_data, val_label)
    return train_dataset, val_dataset
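
# Hedged end-to-end sketch: how a caller might drive getDataset and wrap the
# results in DataLoaders. The window/feature/predict sizes are assumptions
# consistent with the shapes printed in getData, and running this requires
# ../../dataset/HI_merge_data.npy to exist.
if __name__ == "__main__":
    train_dataset, val_dataset = getDataset(hidden_num=30, feature=20,
                                            predict_num=50, is_single=True,
                                            is_norm=True)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    for x, y in train_loader:
        print(x.shape, y.shape)  # e.g. torch.Size([32, 30, 20]) torch.Size([32])
        break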