"""
@Author: miykah
@Email: miykah@163.com
@FileName: load_data.py
@DateTime: 2022/7/20 16:40
"""
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader


class Nor_Dataset(Dataset):
    """Standard dataset over pre-loaded arrays, optionally paired with labels.

    Args:
        datas: array-like of samples; copied into a tensor with ``torch.tensor``.
        labels: optional array-like of labels aligned with ``datas``. When
            omitted, items are returned without a label.
    """

    def __init__(self, datas, labels=None):
        self.datas = torch.tensor(datas)
        self.labels = torch.tensor(labels) if labels is not None else None

    def __getitem__(self, index):
        """Return ``(data, label)`` when labels were given, otherwise ``data``."""
        data = self.datas[index]
        if self.labels is None:
            return data
        return data, self.labels[index]

    def __len__(self):
        return len(self.datas)


class Tar_U_Dataset(Dataset):
    """Dataset for unlabeled target data yielding two noisy views per sample.

    Each item is ``(data, data_bar, data_bar2)``: the original sample and two
    independently perturbed copies (data' and data''). Gaussian noise drawn
    from Python's ``random`` module is added to row 0 of each copy only.

    Args:
        datas: array-like of samples. Each sample must be at least 2-D (it is
            indexed as ``sample[0, ...]``); presumably (channels, length) —
            TODO confirm against the data files.
        mu: mean of the additive Gaussian noise (default 0).
        sigma: standard deviation of the additive Gaussian noise (default 0.1).
    """

    def __init__(self, datas, mu=0, sigma=0.1):
        self.datas = torch.tensor(datas)
        self.mu = mu
        self.sigma = sigma

    def __getitem__(self, index):
        data = self.datas[index]
        data_bar = data.clone().detach()
        data_bar2 = data.clone().detach()
        # Add noise to the unlabeled target sample to obtain data' and data''.
        length = data_bar[0].shape[0]
        # Draw in interleaved order (data', data'', data', ...), i.e. consume
        # the random stream in the same order an element-by-element loop would.
        draws = [random.gauss(self.mu, self.sigma) for _ in range(2 * length)]
        data_bar[0] += self._noise_for(draws[0::2], data_bar[0])
        data_bar2[0] += self._noise_for(draws[1::2], data_bar2[0])
        return data, data_bar, data_bar2

    @staticmethod
    def _noise_for(values, target):
        """Wrap noise values as a tensor broadcastable along target's first axis."""
        # float64 keeps random.gauss's full precision; adding it in place into
        # an integer tensor still raises, just as scalar float addition did.
        noise = torch.tensor(values, dtype=torch.float64)
        # One noise value per index of target's first axis, shared across any
        # trailing axes.
        return noise.view(-1, *([1] * (target.dim() - 1)))

    def __len__(self):
        return len(self.datas)


def draw_signal_img(data, data_bar, data_bar2):
    """Plot a signal and its two noisy views stacked in a 3x1 grid, then show it."""
    # Configure fonts before building the figure so every artist uses them;
    # CJK-capable fonts are listed and unicode_minus is disabled so minus signs
    # render with those fonts.
    plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
    plt.rcParams['axes.unicode_minus'] = False
    plt.figure(figsize=(12, 6), dpi=100)
    for row, signal in enumerate((data, data_bar, data_bar2), start=1):
        plt.subplot(3, 1, row)
        plt.plot(signal, 'b')
    plt.show()


def get_dataset(src_condition, tar_condition, root_dir=None):
    """Load the source, target, validation and test splits as datasets.

    Files are read from ``<root_dir>/<condition>/<split>/``:

    * source data and labels from ``src_condition/src``;
    * unlabeled target data from ``tar_condition/tar``;
    * validation and test data with labels from ``tar_condition/val`` and
      ``tar_condition/test``.

    Args:
        src_condition: sub-directory name of the source-domain condition.
        tar_condition: sub-directory name of the target-domain condition.
        root_dir: dataset root directory. When None, the original
            machine-specific Windows paths are used (D: drive when CUDA is
            available, E: drive otherwise).

    Returns:
        Tuple ``(src_dataset, tar_dataset, val_dataset, test_dataset)`` of
        ``Nor_Dataset`` instances; ``tar_dataset`` carries no labels.
    """
    if root_dir is None:
        # Hard-wired defaults kept for backward compatibility.
        if torch.cuda.is_available():
            root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
        else:
            root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'

    def load(condition, split, filename):
        # os.path.join rather than literal '\\' so a POSIX root_dir also works;
        # on Windows the resulting path is identical to the old concatenation.
        return np.load(os.path.join(root_dir, condition, split, filename))

    src_data = load(src_condition, 'src', 'data.npy')
    src_label = load(src_condition, 'src', 'label.npy')
    tar_data = load(tar_condition, 'tar', 'data.npy')
    val_data = load(tar_condition, 'val', 'data.npy')
    val_label = load(tar_condition, 'val', 'label.npy')
    test_data = load(tar_condition, 'test', 'data.npy')
    test_label = load(tar_condition, 'test', 'label.npy')

    src_dataset = Nor_Dataset(src_data, src_label)
    tar_dataset = Nor_Dataset(tar_data)
    val_dataset = Nor_Dataset(val_data, val_label)
    test_dataset = Nor_Dataset(test_data, test_label)
    return src_dataset, tar_dataset, val_dataset, test_dataset


if __name__ == '__main__':
    # Smoke test: load the source split of condition A and print one sample.
    src_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\data.npy")
    src_label = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\label.npy")
    src_dataset = Nor_Dataset(datas=src_data, labels=src_label)
    src_loader = DataLoader(dataset=src_dataset, batch_size=len(src_dataset), shuffle=False)
    # batch_size == len(dataset), so the first batch is the entire split.
    data, label = next(iter(src_loader))
    print(data[10][0])
    print(label[10])