"""
@Author: miykah
@Email: miykah@163.com
@FileName: load_data.py
@DateTime: 2022/7/20 16:40
"""
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
# Plain (labeled or unlabeled) Dataset class.
class Nor_Dataset(Dataset):
    """Wrap an array of samples (and optionally labels) as tensors.

    ``__getitem__`` yields ``(data, label)`` pairs when labels were
    supplied, otherwise just the data tensor.
    """

    def __init__(self, datas, labels=None):
        # Convert once up front; indexing then returns tensor views.
        self.datas = torch.tensor(datas)
        self.labels = torch.tensor(labels) if labels is not None else None

    def __getitem__(self, index):
        sample = self.datas[index]
        if self.labels is None:
            return sample
        return sample, self.labels[index]

    def __len__(self):
        return len(self.datas)
# Dataset class for unlabeled target-domain data.
class Tar_U_Dataset(Dataset):
    """Unlabeled target Dataset that yields two noisy views per sample.

    ``__getitem__`` returns ``(data, data_bar, data_bar2)``, where the
    two ``*_bar`` tensors are independent copies of ``data`` with
    element-wise Gaussian noise (mu=0, sigma=0.1) added.

    NOTE(review): noise is only added along row 0 (``data[0, :]``) —
    presumably samples are single-channel with shape (1, L); confirm
    this is intended for multi-channel inputs.
    """

    def __init__(self, datas):
        self.datas = torch.tensor(datas)

    def __getitem__(self, index):
        data = self.datas[index]
        data_bar = data.clone().detach()
        data_bar2 = data.clone().detach()
        mu = 0
        sigma = 0.1
        # Add noise to get data' and data''. The original zipped two
        # identical ranges (the indices were always equal); a single
        # index suffices. The per-element draw order (bar, then bar2)
        # is preserved so seeded results are unchanged.
        for i in range(data_bar[0].shape[0]):
            data_bar[0, i] += random.gauss(mu, sigma)
            data_bar2[0, i] += random.gauss(mu, sigma)
        return data, data_bar, data_bar2

    def __len__(self):
        return len(self.datas)
def draw_signal_img(data, data_bar, data_bar2):
    """Plot a signal and its two noisy views in three stacked subplots."""
    plt.figure(figsize=(12, 6), dpi=100)
    # CJK-capable font fallbacks so Chinese text renders; keep the
    # minus-sign glyph working alongside them.
    plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
    plt.rcParams['axes.unicode_minus'] = False
    for row, signal in enumerate((data, data_bar, data_bar2), start=1):
        plt.subplot(3, 1, row)
        plt.plot(signal, 'b')
    plt.show()
def get_dataset(src_condition, tar_condition):
    """Load the four .npy splits from disk and wrap them as Datasets.

    Returns ``(src_dataset, tar_dataset, val_dataset, test_dataset)``;
    the target dataset is unlabeled.
    """
    # The data lives on different drives depending on the machine;
    # CUDA availability is used here as a proxy for which machine
    # this is running on.
    if torch.cuda.is_available():
        root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
    else:
        root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'

    # Labeled source split.
    src_dataset = Nor_Dataset(
        np.load(root_dir + src_condition + '\\src\\' + 'data.npy'),
        np.load(root_dir + src_condition + '\\src\\' + 'label.npy'),
    )
    # Unlabeled target split.
    tar_dataset = Nor_Dataset(
        np.load(root_dir + tar_condition + '\\tar\\' + 'data.npy'),
    )
    # Labeled validation and test splits from the target condition.
    val_dataset = Nor_Dataset(
        np.load(root_dir + tar_condition + '\\val\\' + 'data.npy'),
        np.load(root_dir + tar_condition + '\\val\\' + 'label.npy'),
    )
    test_dataset = Nor_Dataset(
        np.load(root_dir + tar_condition + '\\test\\' + 'data.npy'),
        np.load(root_dir + tar_condition + '\\test\\' + 'label.npy'),
    )

    return src_dataset, tar_dataset, val_dataset, test_dataset
if __name__ == '__main__':
    # Smoke test: load the labeled source split from disk and print
    # one sample and its label from a single full-size batch.
    src_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\data.npy")
    src_label = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\label.npy")
    src_dataset = Nor_Dataset(datas=src_data, labels=src_label)
    src_loader = DataLoader(dataset=src_dataset, batch_size=len(src_dataset), shuffle=False)

    for batch_idx, (data, label) in enumerate(src_loader):
        # NOTE(review): indexes sample 10 — assumes the split holds at
        # least 11 samples.
        print(data[10][0])
        print(label[10])
        if batch_idx == 0:
            break