# self_example/pytorch_example/example/load_data.py
#
# 111 lines, 3.9 KiB, Python (repository-viewer header: Raw / Blame / History)
#
# Viewer notice: this file contains ambiguous Unicode characters.
# It contains Unicode characters that might be confused with other characters.
# If you think that this is intentional, you can safely ignore this warning.
# (In the viewer, the Escape button reveals them.)

"""
@Author: miykah
@Email: miykah@163.com
@FileName: load_data.py
@DateTime: 2022/7/20 16:40
"""
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import random
class Nor_Dataset(Dataset):
    """Plain dataset wrapping samples and, optionally, their labels.

    Args:
        datas: Array-like of samples; copied into a tensor with
            ``torch.tensor``.
        labels: Optional array-like of labels, one per sample. When omitted,
            indexing yields only the sample.

    Raises:
        ValueError: If ``labels`` is given but its length differs from the
            length of ``datas``.
    """

    def __init__(self, datas, labels=None):
        self.datas = torch.tensor(datas)
        self.labels = None if labels is None else torch.tensor(labels)
        # Fail early: a mismatch would otherwise surface only at index time
        # (IndexError) or silently leave trailing labels unused.
        if self.labels is not None and len(self.labels) != len(self.datas):
            raise ValueError(
                f"datas and labels must have the same length, "
                f"got {len(self.datas)} and {len(self.labels)}"
            )

    def __getitem__(self, index):
        """Return ``(data, label)`` if labels exist, otherwise just ``data``."""
        data = self.datas[index]
        if self.labels is None:
            return data
        return data, self.labels[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.datas)
class Tar_U_Dataset(Dataset):
    """Dataset for unlabeled target data that also yields two noisy copies.

    Args:
        datas: Array-like of samples; copied into a tensor with
            ``torch.tensor``. Each sample must be at least 2-D, because noise
            is added along ``sample[0, :]``.
        mu: Mean of the Gaussian noise (default ``0``).
        sigma: Standard deviation of the Gaussian noise (default ``0.1``).
    """

    def __init__(self, datas, mu=0, sigma=0.1):
        self.datas = torch.tensor(datas)
        self.mu = mu
        self.sigma = sigma

    def __getitem__(self, index):
        """Return ``(data, data_bar, data_bar2)`` for the sample at ``index``.

        ``data`` is the stored sample; ``data_bar`` and ``data_bar2`` are
        independent copies whose row 0 has Gaussian noise added element-wise.
        Other rows of the copies are left equal to ``data``.
        """
        data = self.datas[index]
        data_bar = data.clone().detach()
        data_bar2 = data.clone().detach()
        # Add noise to the unlabeled target data to obtain data' and data''.
        # Draws alternate between the two copies for each position, so a
        # seeded ``random`` produces the same values as the original loop.
        for i in range(data_bar[0].shape[0]):
            data_bar[0, i] += random.gauss(self.mu, self.sigma)
            data_bar2[0, i] += random.gauss(self.mu, self.sigma)
        return data, data_bar, data_bar2

    def __len__(self):
        """Return the number of samples."""
        return len(self.datas)
def draw_signal_img(data, data_bar, data_bar2):
    """Plot a signal and its two noisy variants as three stacked subplots.

    Args:
        data: The original signal, in any form ``plt.plot`` accepts.
        data_bar: First noisy variant of the signal.
        data_bar2: Second noisy variant of the signal.

    Opens a new 12x6 inch figure at 100 dpi, draws each input in blue in its
    own row, and blocks in ``plt.show()``. Note that it also updates the
    global ``plt.rcParams`` font settings.
    """
    plt.figure(figsize=(12, 6), dpi=100)
    # Font fallback list; presumably chosen so CJK text can be rendered.
    plt.rcParams['font.family'] = ['Arial Unicode MS', 'Microsoft YaHei', 'SimHei', 'sans-serif']
    # Render minus signs with the ASCII hyphen instead of the Unicode minus.
    plt.rcParams['axes.unicode_minus'] = False
    for row, signal in enumerate((data, data_bar, data_bar2), start=1):
        plt.subplot(3, 1, row)
        plt.plot(signal, 'b')
    plt.show()
def get_dataset(src_condition, tar_condition, root_dir=None):
    """Load the source, target, validation and test datasets from ``.npy`` files.

    Files are read from ``<root_dir>/<condition>/<split>/{data,label}.npy``.
    The ``src`` split is taken from ``src_condition``; the ``tar``, ``val`` and
    ``test`` splits are taken from ``tar_condition``.

    Args:
        src_condition: Sub-directory name of the source condition.
        tar_condition: Sub-directory name of the target condition.
        root_dir: Optional directory containing the condition folders. When
            omitted, the original hard-coded locations are used: the ``D:``
            path if CUDA is available, otherwise the ``E:`` path.

    Returns:
        Tuple ``(src_dataset, tar_dataset, val_dataset, test_dataset)`` of
        ``Nor_Dataset`` objects. ``tar_dataset`` carries no labels.
    """
    import os  # local import keeps this block self-contained

    if root_dir is None:
        # Machine-specific defaults kept for backward compatibility.
        if torch.cuda.is_available():
            root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
        else:
            root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'

    def _load(condition, split, name):
        # Load one array from <root_dir>/<condition>/<split>/<name>.
        return np.load(os.path.join(root_dir, condition, split, name))

    src_data = _load(src_condition, 'src', 'data.npy')
    src_label = _load(src_condition, 'src', 'label.npy')
    tar_data = _load(tar_condition, 'tar', 'data.npy')
    val_data = _load(tar_condition, 'val', 'data.npy')
    val_label = _load(tar_condition, 'val', 'label.npy')
    test_data = _load(tar_condition, 'test', 'data.npy')
    test_label = _load(tar_condition, 'test', 'label.npy')

    src_dataset = Nor_Dataset(src_data, src_label)
    tar_dataset = Nor_Dataset(tar_data)  # target split is loaded without labels
    val_dataset = Nor_Dataset(val_data, val_label)
    test_dataset = Nor_Dataset(test_data, test_label)
    return src_dataset, tar_dataset, val_dataset, test_dataset
if __name__ == '__main__':
    # Manual smoke test: load the source split of condition "A" and print one
    # sample/label pair from it.
    #
    # To inspect the unlabeled target split instead, load
    # "E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\tar\\data.npy", wrap it in
    # Nor_Dataset(datas=tar_data) and fetch a batch the same way; each batch
    # is then a bare data tensor without labels.
    src_data = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\data.npy")
    src_label = np.load("E:\\DataSet\\DDS_data_CDAC\\processed_data\\A\\src\\label.npy")
    src_dataset = Nor_Dataset(datas=src_data, labels=src_label)
    src_loader = DataLoader(dataset=src_dataset, batch_size=len(src_dataset), shuffle=False)

    # batch_size equals the dataset length, so the first batch holds every
    # sample; fetch just that batch instead of looping and breaking.
    data, label = next(iter(src_loader))
    print(data[10][0])
    print(label[10])