self_example/pytorch_example/example/test.py

226 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
@Author: miykah
@Email: miykah@163.com
@FileName: test.py
@DateTime: 2022/7/9 14:15
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader

from DDS_GADAN.load_data import get_dataset
from DDS_GADAN.load_data import Nor_Dataset
from DDS_GADAN.model import Extractor, LabelClassifier
def load_data(tar_condition, root_dir=None):
    """Load the ``val`` split of one target condition.

    Args:
        tar_condition: Name of the condition sub-directory (e.g. ``'B'``).
        root_dir: Dataset root containing one sub-directory per condition.
            When omitted, a hard-coded default is used: the ``D:`` path if
            CUDA is available, otherwise the ``E:`` path.

    Returns:
        ``(test_data, test_label)`` as returned by ``np.load`` for
        ``<root_dir>/<tar_condition>/val/data.npy`` and ``label.npy``.
    """
    if root_dir is None:
        # NOTE(review): CUDA availability presumably stands in for "which
        # workstation is this"; pass root_dir explicitly anywhere else.
        if torch.cuda.is_available():
            root_dir = 'D:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
        else:
            root_dir = 'E:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
    # os.path.join instead of hard-coded '\\' so a custom root works on any OS;
    # on Windows this yields exactly the same paths as before.
    val_dir = os.path.join(root_dir, tar_condition, 'val')
    test_data = np.load(os.path.join(val_dir, 'data.npy'))
    test_label = np.load(os.path.join(val_dir, 'label.npy'))
    return test_data, test_label
def tsne_2d_generate(cls, data, labels, pic_title):
    """Embed ``data`` in 2-D with t-SNE and show a scatter plot coloured by class.

    Args:
        cls: Number of classes; the ``jet`` colormap is discretised into this
            many colours.
        data: Array-like of shape ``(n_samples, n_features)`` passed to
            ``TSNE.fit_transform``.
        labels: Integer class id per row of ``data``.
        pic_title: Figure title.
    """
    embedded = TSNE(n_components=2, verbose=2, perplexity=30).fit_transform(data)
    fig = plt.figure()
    ax = fig.add_subplot()
    # One colour per class. Pinning vmin/vmax maps class k to the k-th colour
    # even when some classes are missing from ``labels``.
    # plt.get_cmap replaces plt.cm.get_cmap, removed in matplotlib 3.9.
    ax.scatter(embedded[:, 0], embedded[:, 1], c=labels,
               cmap=plt.get_cmap("jet", cls), vmin=0, vmax=cls - 1)
    plt.title(pic_title)
    plt.show()
def tsne_2d_generate1(cls, data, labels, pic_title, n_src=1000):
    """Publication-style t-SNE scatter distinguishing two groups of samples.

    The first ``n_src`` rows of ``data``/``labels`` are drawn as circles, the
    remaining rows as crosses; colour encodes the class id. The 2-D embedding
    is min-max scaled to ``[0, 1]`` per axis before plotting.

    Note: this updates the global matplotlib ``rcParams`` (DPI, figure size,
    tick style, Times New Roman font), which also affects later figures.

    Args:
        cls: Number of classes (number of colours in the colormap).
        data: Samples to embed, one row per sample.
        labels: Class id per row, aligned with ``data``.
        pic_title: Figure title.
        n_src: Number of leading rows drawn as circles. Defaults to 1000, the
            previously hard-coded split.
    """
    parameters = {'figure.dpi': 600,
                  'figure.figsize': (4, 3),
                  'savefig.dpi': 600,
                  'xtick.direction': 'in',
                  'ytick.direction': 'in',
                  'xtick.labelsize': 10,
                  'ytick.labelsize': 10,
                  'legend.fontsize': 11.3,
                  }
    plt.rcParams.update(parameters)
    plt.rc('font', family='Times New Roman')  # global font family
    tsne2D = TSNE(n_components=2, verbose=2, perplexity=30, random_state=3407,
                  init='random', learning_rate=200).fit_transform(data)
    # Min-max scale each axis to [0, 1]; a zero-range axis maps to 0 instead
    # of producing NaNs from a division by zero.
    lo, hi = tsne2D.min(0), tsne2D.max(0)
    span = np.where(hi > lo, hi - lo, 1.0)
    scaled = (tsne2D - lo) / span
    labels = np.asarray(labels)
    plt.figure()
    # Shared colormap and vmin/vmax so a class gets the same colour in both
    # scatter calls (each call would otherwise normalise its own label range).
    # plt.get_cmap replaces plt.cm.get_cmap, removed in matplotlib 3.9.
    common = dict(cmap=plt.get_cmap("jet", cls), vmin=0, vmax=cls - 1, alpha=0.3)
    plt.scatter(scaled[:n_src, 0], scaled[:n_src, 1], c=labels[:n_src], marker='o', **common)
    plt.scatter(scaled[n_src:, 0], scaled[n_src:, 1], c=labels[n_src:], marker='x', **common)
    plt.title(pic_title, fontsize=10)
    plt.show()
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels, class_names=None, vmax=100):
    """Draw the confusion matrix of ``predict_labels`` against ``true_labels``.

    Note: this calls ``sns.set(font_scale=1.5)`` and updates the global
    matplotlib ``rcParams``, which also affects later figures.

    Args:
        cls: Number of classes. The matrix is always ``cls x cls`` (classes
            ``0..cls-1``), so rows and columns line up with ``class_names``
            even if a class never occurs in either label array.
        true_labels: Ground-truth class ids.
        predict_labels: Predicted class ids, aligned with ``true_labels``.
        class_names: Tick labels, one per class. Defaults to
            ``['N', 'IF', 'OF', 'TRC', 'TSP']`` when ``cls == 5``, otherwise
            to ``'0'`` .. ``str(cls - 1)``.
        vmax: Upper end of the colour scale; counts above it saturate.
    """
    if class_names is None:
        if cls == 5:
            class_names = ['N', 'IF', 'OF', 'TRC', 'TSP']
        else:
            class_names = [str(i) for i in range(cls)]
    sns.set(font_scale=1.5)
    parameters = {'figure.dpi': 600,
                  'figure.figsize': (5, 5),
                  'savefig.dpi': 600,
                  'xtick.direction': 'in',
                  'ytick.direction': 'in',
                  'xtick.labelsize': 20,
                  'ytick.labelsize': 20,
                  'legend.fontsize': 11.3,
                  }
    plt.rcParams.update(parameters)
    plt.rc('font', family='Times New Roman')  # global font family
    # Fix the label set so the matrix shape does not depend on which classes
    # happen to appear in this particular sample.
    confusion = confusion_matrix(true_labels, predict_labels, labels=list(range(cls)))
    plt.figure()  # exactly one figure for the heatmap
    sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=vmax, vmin=0,
                cbar=False, square=True)
    # Heatmap cells span [i, i+1); +0.5 centres each tick label on its cell.
    ticks = [i + 0.5 for i in range(cls)]
    plt.xticks(ticks, class_names, rotation=45)
    plt.yticks(ticks, class_names, rotation=45)
    plt.ylabel('Actual label', fontsize=20)
    plt.xlabel('Predicted label', fontsize=20)
    plt.show()
def test(cls, tar_condition, G_params_path, LC_params_path):
    """Evaluate a trained Extractor + LabelClassifier on one target condition.

    Loads the split returned by ``load_data(tar_condition)`` and restores both
    networks from their state-dict files. It then classifies every sample,
    shows a t-SNE plot of the classifier outputs and a confusion matrix, and
    prints the number of correct predictions and the accuracy.

    Args:
        cls: Number of classes (``cls_num`` of ``LabelClassifier``).
        tar_condition: Condition sub-directory to evaluate on.
        G_params_path: Path of the ``Extractor`` state dict.
        LC_params_path: Path of the ``LabelClassifier`` state dict.

    Raises:
        ValueError: if the loaded split is empty.
    """
    test_data, test_label = load_data(tar_condition)
    test_dataset = Nor_Dataset(test_data, test_label)
    n_samples = len(test_dataset)
    if n_samples == 0:
        raise ValueError('empty test set for condition {!r}'.format(tar_condition))
    # Whole split in a single batch; evaluation is order-independent, so no shuffling.
    test_loader = DataLoader(dataset=test_dataset, batch_size=n_samples, shuffle=False)

    # Load the networks.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    G = Extractor().to(device)
    LC = LabelClassifier(cls_num=cls).to(device)
    G.load_state_dict(torch.load(G_params_path, map_location=device))
    LC.load_state_dict(torch.load(LC_params_path, map_location=device))
    G.eval()
    LC.eval()

    output_batches, label_batches = [], []
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)
            # NOTE(review): both networks are fed the same tensor twice,
            # presumably a (source, target) signature reused for inference.
            _, feature = G(data, data)
            _, _, _, output = LC(feature, feature)
            output_batches.append(output.cpu().numpy())
            label_batches.append(label.cpu().numpy())

    # Gather every batch so plots and accuracy cover the whole split.
    predictions = np.concatenate(output_batches)
    labels = np.concatenate(label_batches)
    predict_labels = np.argmax(predictions, axis=1)
    test_acc = int(np.sum(predict_labels == labels))  # number of correct predictions

    tsne_2d_generate(cls, predictions, labels, "output of neural network")
    plot_confusion_matrix_accuracy(cls, labels, predict_labels)
    print("测试集大小为{}, 成功{},准确率为{:.6f}".format(n_samples, test_acc, test_acc / n_samples))
def test1(cls, src_condition, tar_condition, G_params_path, LC_params_path, root_dir=None):
    """Visualise source vs. target representations of a trained model with t-SNE.

    Takes the first 1000 samples of the ``src`` split of ``src_condition`` and
    the first 1000 of the ``tar1`` split of ``tar_condition``. These go through
    the restored Extractor + LabelClassifier. The function then shows a t-SNE
    plot of the classifier's third output (``fc1`` here; source rows first) and
    a confusion matrix, and prints the accuracy over those 2000 samples.

    Args:
        cls: Number of classes (``cls_num`` of ``LabelClassifier``).
        src_condition: Condition sub-directory for the source split.
        tar_condition: Condition sub-directory for the target split.
        G_params_path: Path of the ``Extractor`` state dict.
        LC_params_path: Path of the ``LabelClassifier`` state dict.
        root_dir: Dataset root. When omitted, a hard-coded default is used:
            the ``D:`` path if CUDA is available, otherwise the ``E:`` path.

    Raises:
        ValueError: if either split has fewer than 1000 samples.
    """
    # tsne_2d_generate1 treats the first 1000 rows as source, so one batch of
    # exactly this size is taken from each domain.
    n_per_domain = 1000
    if root_dir is None:
        if torch.cuda.is_available():
            root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
        else:
            root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'

    def _load(condition, split):
        # Load data.npy / label.npy from <root_dir>/<condition>/<split>/.
        split_dir = os.path.join(root_dir, condition, split)
        return (np.load(os.path.join(split_dir, 'data.npy')),
                np.load(os.path.join(split_dir, 'label.npy')))

    src_dataset = Nor_Dataset(*_load(src_condition, 'src'))
    tar_dataset = Nor_Dataset(*_load(tar_condition, 'tar1'))
    # With drop_last=True a smaller split yields no batch at all.
    for name, dataset in (('source', src_dataset), ('target', tar_dataset)):
        if len(dataset) < n_per_domain:
            raise ValueError('test1 needs at least {} {} samples, got {}'.format(
                n_per_domain, name, len(dataset)))
    src_loader = DataLoader(dataset=src_dataset, batch_size=n_per_domain, shuffle=False, drop_last=True)
    tar_loader = DataLoader(dataset=tar_dataset, batch_size=n_per_domain, shuffle=False, drop_last=True)

    # Load the networks.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    G = Extractor().to(device)
    LC = LabelClassifier(cls_num=cls).to(device)
    G.load_state_dict(torch.load(G_params_path, map_location=device))
    LC.load_state_dict(torch.load(LC_params_path, map_location=device))
    G.eval()
    LC.eval()

    with torch.no_grad():
        # Only the first batch of each domain is evaluated.
        src_data, src_label = next(iter(src_loader))
        tar_data, tar_label = next(iter(tar_loader))
        data = torch.cat((src_data, tar_data), dim=0).to(device)
        label = torch.cat((src_label, tar_label), dim=0)
        _, feature = G(data, data)
        _, _, fc1, output = LC(feature, feature)

    labels = label.cpu().numpy()
    predict_labels = np.argmax(output.cpu().numpy(), axis=1)
    fc1_outputs = fc1.cpu().numpy()
    test_acc = np.sum(predict_labels == labels)

    tsne_2d_generate1(cls, fc1_outputs, labels, "GADAN")
    plot_confusion_matrix_accuracy(cls, labels, predict_labels)
    # Divide by the number of samples actually evaluated, not the dataset sizes.
    print("准确率为{:.6f}".format(test_acc / len(labels)))
if __name__ == '__main__':
    # Evaluate the A_to_B checkpoints on condition 'B' with 5 classes.
    g_checkpoint = 'parameters/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl'
    lc_checkpoint = 'parameters/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl'
    test(5, 'B', g_checkpoint, lc_checkpoint)
    # Alternative run: source/target t-SNE with checkpoints under parameters_bak/:
    # test1(5, 'A', 'B', 'parameters_bak/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
    #       'parameters_bak/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')