226 lines
9.6 KiB
Python
226 lines
9.6 KiB
Python
"""
|
||
@Author: miykah
|
||
@Email: miykah@163.com
|
||
@FileName: test.py
|
||
@DateTime: 2022/7/9 14:15
|
||
"""
|
||
|
||
import numpy as np
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
from DDS_GADAN.load_data import get_dataset
|
||
from DDS_GADAN.load_data import Nor_Dataset
|
||
from DDS_GADAN.model import Extractor, LabelClassifier
|
||
from sklearn.metrics import confusion_matrix
|
||
from sklearn.manifold import TSNE
|
||
import seaborn as sns
|
||
import matplotlib.pyplot as plt
|
||
|
||
def load_data(tar_condition, root_dir=None):
    """Load the validation split of one working condition.

    Reads ``<root_dir>/<tar_condition>/val/data.npy`` and
    ``<root_dir>/<tar_condition>/val/label.npy`` with ``np.load``.

    Args:
        tar_condition: name of the condition sub-directory (e.g. 'B').
        root_dir: dataset root directory. When omitted, the original
            hard-coded Windows location is used: drive D: when CUDA is
            available, drive E: otherwise (presumably two different machines).

    Returns:
        Tuple ``(test_data, test_label)`` of the arrays loaded from disk.
    """
    import os  # local import: only this helper needs path joining

    if root_dir is None:
        if torch.cuda.is_available():
            root_dir = 'D:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'
        else:
            root_dir = 'E:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\'

    # os.path.join yields the same path as the old '\\' concatenation on
    # Windows, and also works with a caller-supplied root on other platforms.
    val_dir = os.path.join(root_dir, tar_condition, 'val')
    test_data = np.load(os.path.join(val_dir, 'data.npy'))
    test_label = np.load(os.path.join(val_dir, 'label.npy'))

    return test_data, test_label
|
||
|
||
def tsne_2d_generate(cls, data, labels, pic_title):
    """Embed ``data`` in 2-D with t-SNE and show a scatter plot coloured by class.

    Args:
        cls: number of classes; the "jet" colormap is resampled to ``cls``
            discrete colours and the colour range is pinned to 0..cls-1.
        data: samples to embed, one row per sample (passed to
            ``TSNE.fit_transform``).
        labels: class index of each row of ``data`` (expected in 0..cls-1).
        pic_title: title of the figure.
    """
    embedding = TSNE(n_components=2, verbose=2, perplexity=30).fit_transform(data)

    fig = plt.figure()
    ax = fig.add_subplot()
    # plt.get_cmap replaces plt.cm.get_cmap (matplotlib.cm.get_cmap), which
    # was removed in matplotlib 3.9. vmin/vmax pin class k to the same colour
    # even when some classes do not occur in ``labels``.
    ax.scatter(embedding[:, 0], embedding[:, 1], c=labels,
               cmap=plt.get_cmap("jet", cls), vmin=0, vmax=cls - 1)
    ax.set_title(pic_title)
    plt.show()
|
||
|
||
def tsne_2d_generate1(cls, data, labels, pic_title, n_src=1000):
    """Embed two stacked sample groups with t-SNE and plot them on one figure.

    Rows ``[:n_src]`` of ``data``/``labels`` are drawn with circle markers and
    the remaining rows with cross markers. Both groups share one
    class-to-colour mapping. The embedding is min-max scaled to [0, 1] per
    axis before plotting.

    Args:
        cls: number of classes; the "jet" colormap is resampled to ``cls``
            discrete colours and the colour range is pinned to 0..cls-1.
        data: samples to embed, one row per sample.
        labels: class index of each row of ``data`` (expected in 0..cls-1).
        pic_title: figure title.
        n_src: number of leading rows forming the first (circle-marker) group.
            Defaults to 1000, the value previously hard-coded here.

    Note:
        Updates the global matplotlib rcParams (dpi, figure size, tick style,
        font family), which also affects figures created afterwards.
    """
    parameters = {'figure.dpi': 600,
                  'figure.figsize': (4, 3),
                  'savefig.dpi': 600,
                  'xtick.direction': 'in',
                  'ytick.direction': 'in',
                  'xtick.labelsize': 10,
                  'ytick.labelsize': 10,
                  'legend.fontsize': 11.3,
                  }
    plt.rcParams.update(parameters)
    plt.rc('font', family='Times New Roman')  # global font family

    tsne2D = TSNE(n_components=2, verbose=2, perplexity=30, random_state=3407,
                  init='random', learning_rate=200).fit_transform(data)

    # Min-max scale each axis to [0, 1]; a zero span (all points equal on an
    # axis) is replaced by 1 so the division cannot produce NaNs.
    lo, hi = tsne2D.min(0), tsne2D.max(0)
    span = np.where(hi > lo, hi - lo, 1.0)
    tsne2D_final = (tsne2D - lo) / span

    labels = np.asarray(labels)
    s1, s2 = tsne2D_final[:n_src, :], tsne2D_final[n_src:, :]
    l1, l2 = labels[:n_src], labels[n_src:]

    # Each scatter call normalizes its colours independently unless the range
    # is pinned; fixing vmin/vmax keeps a class on the same colour in both
    # groups even when one group lacks some classes.
    colour_kw = dict(cmap=plt.get_cmap("jet", cls), vmin=0, vmax=cls - 1)

    plt.figure()
    plt.scatter(s1[:, 0], s1[:, 1], c=l1, marker='o', alpha=0.3, **colour_kw)
    plt.scatter(s2[:, 0], s2[:, 1], c=l2, marker='x', alpha=0.3, **colour_kw)
    plt.title(pic_title, fontsize=10)
    plt.show()
|
||
|
||
def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels, class_names=None):
    """Draw a count-annotated confusion-matrix heatmap and show it.

    Args:
        cls: number of classes. The matrix is always ``cls`` x ``cls`` over
            classes 0..cls-1, so the tick labels stay aligned even when a class
            never occurs in either label array.
        true_labels: ground-truth class index per sample.
        predict_labels: predicted class index per sample.
        class_names: tick labels, one per class. Defaults to
            ['N', 'IF', 'OF', 'TRC', 'TSP'] when ``cls == 5`` (the previously
            hard-coded names), otherwise to '0'..'cls-1'.

    Raises:
        ValueError: if ``class_names`` does not contain exactly ``cls`` entries.

    Note:
        Calls ``sns.set`` and updates the global matplotlib rcParams.
    """
    if class_names is None:
        if cls == 5:
            class_names = ['N', 'IF', 'OF', 'TRC', 'TSP']
        else:
            class_names = [str(i) for i in range(cls)]
    if len(class_names) != cls:
        raise ValueError("expected {} class names, got {}".format(cls, len(class_names)))

    sns.set(font_scale=1.5)
    parameters = {'figure.dpi': 600,
                  'figure.figsize': (5, 5),
                  'savefig.dpi': 600,
                  'xtick.direction': 'in',
                  'ytick.direction': 'in',
                  'xtick.labelsize': 20,
                  'ytick.labelsize': 20,
                  'legend.fontsize': 11.3,
                  }
    plt.rcParams.update(parameters)
    plt.rc('font', family='Times New Roman')  # global font family

    # labels= fixes the size and row/column order to 0..cls-1; without it the
    # matrix shrinks when a class is missing and no longer matches the ticks.
    confusion = confusion_matrix(true_labels, predict_labels, labels=list(range(cls)))

    # Exactly one figure: an earlier extra plt.figure() left an empty window.
    plt.figure()
    # The colour scale is fixed to 0..100, so counts above 100 render fully saturated.
    sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True)

    # Heatmap cells are centred on half-integer coordinates.
    ticks = [index + 0.5 for index in range(cls)]
    plt.xticks(ticks, class_names, rotation=45)
    plt.yticks(ticks, class_names, rotation=45)
    plt.ylabel('Actual label', fontsize=20)
    plt.xlabel('Predicted label', fontsize=20)
    plt.show()
|
||
|
||
|
||
def test(cls, tar_condition, G_params_path, LC_params_path):
    """Evaluate a trained Extractor + LabelClassifier on a condition's validation split.

    Loads the split via ``load_data`` and restores both networks from their
    state-dict files. It then runs inference over every sample, shows a t-SNE
    plot of the classifier outputs and a confusion matrix, and prints the
    dataset size, number of correct predictions and accuracy.

    Args:
        cls: number of classes of the label classifier.
        tar_condition: condition sub-directory to evaluate (e.g. 'B').
        G_params_path: path of the saved Extractor state dict.
        LC_params_path: path of the saved LabelClassifier state dict.

    Raises:
        ValueError: if the validation split is empty.
    """
    test_data, test_label = load_data(tar_condition)
    test_dataset = Nor_Dataset(test_data, test_label)
    n_samples = len(test_dataset)
    if n_samples == 0:
        raise ValueError("no validation samples for condition {!r}".format(tar_condition))

    # Whole split in one batch; evaluation does not need shuffling.
    test_loader = DataLoader(dataset=test_dataset, batch_size=n_samples, shuffle=False)

    # Load the networks.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    G = Extractor().to(device)
    LC = LabelClassifier(cls_num=cls).to(device)
    G.load_state_dict(torch.load(G_params_path, map_location=device))
    LC.load_state_dict(torch.load(LC_params_path, map_location=device))
    G.eval()
    LC.eval()

    # Gather results from every batch so the plots and the accuracy cover the
    # whole split, not just whatever the last batch happened to be.
    batch_outputs, batch_labels = [], []
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)
            # The same tensor is fed to both inputs of G and LC at test time.
            _, feature = G(data, data)
            _, _, _, output = LC(feature, feature)
            batch_outputs.append(output.cpu().numpy())
            batch_labels.append(label.cpu().numpy())

    predictions = np.concatenate(batch_outputs, axis=0)
    labels = np.concatenate(batch_labels, axis=0)
    predict_labels = np.argmax(predictions, axis=1)
    test_acc = int(np.sum(predict_labels == labels))

    tsne_2d_generate(cls, predictions, labels, "output of neural network")
    plot_confusion_matrix_accuracy(cls, labels, predict_labels)

    print("测试集大小为{}, 成功{},准确率为{:.6f}".format(n_samples, test_acc, test_acc / n_samples))
|
||
|
||
def test1(cls, src_condition, tar_condition, G_params_path, LC_params_path, root_dir=None):
    """Visualise source vs. target features of a trained model and report accuracy.

    Loads ``<root>/<src_condition>/src`` and ``<root>/<tar_condition>/tar1``.
    It runs the first 1000 samples of each through the Extractor +
    LabelClassifier, then plots a t-SNE of the classifier's ``fc1`` outputs
    (source rows first, then target rows) and a confusion matrix. Finally it
    prints the accuracy over exactly those evaluated samples.

    Args:
        cls: number of classes of the label classifier.
        src_condition: condition sub-directory holding the ``src`` split.
        tar_condition: condition sub-directory holding the ``tar1`` split.
        G_params_path: path of the saved Extractor state dict.
        LC_params_path: path of the saved LabelClassifier state dict.
        root_dir: dataset root directory. When omitted, the original
            hard-coded Windows location is used (drive D: with CUDA, else E:).

    Raises:
        ValueError: if either split holds fewer than 1000 samples.
    """
    import os  # local import: only needed for path joining here

    if root_dir is None:
        if torch.cuda.is_available():
            root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'
        else:
            root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\'

    src_dir = os.path.join(root_dir, src_condition, 'src')
    tar_dir = os.path.join(root_dir, tar_condition, 'tar1')
    src_dataset = Nor_Dataset(np.load(os.path.join(src_dir, 'data.npy')),
                              np.load(os.path.join(src_dir, 'label.npy')))
    tar_dataset = Nor_Dataset(np.load(os.path.join(tar_dir, 'data.npy')),
                              np.load(os.path.join(tar_dir, 'label.npy')))

    # tsne_2d_generate1 treats the first 1000 rows as the source group, so the
    # batch size must stay 1000.
    batch_size = 1000
    src_loader = DataLoader(dataset=src_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    tar_loader = DataLoader(dataset=tar_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    # Load the networks.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    G = Extractor().to(device)
    LC = LabelClassifier(cls_num=cls).to(device)
    G.load_state_dict(torch.load(G_params_path, map_location=device))
    LC.load_state_dict(torch.load(LC_params_path, map_location=device))
    G.eval()
    LC.eval()

    # Only the first full batch of each domain is evaluated. With drop_last a
    # split shorter than batch_size yields no batch at all.
    try:
        (src_x, src_y), (tar_x, tar_y) = next(zip(src_loader, tar_loader))
    except StopIteration:
        raise ValueError(
            "test1 needs at least {} samples in both the source and target splits".format(batch_size)
        ) from None

    with torch.no_grad():
        data = torch.cat((src_x, tar_x), dim=0).to(device)
        label = torch.cat((src_y, tar_y), dim=0)
        # The same tensor is fed to both inputs of G and LC at test time.
        _, feature = G(data, data)
        _, _, fc1, output = LC(feature, feature)

    labels = label.cpu().numpy()
    predict_labels = np.argmax(output.cpu().numpy(), axis=1)
    fc1_outputs = fc1.cpu().numpy()
    test_acc = int(np.sum(predict_labels == labels))

    tsne_2d_generate1(cls, fc1_outputs, labels, "GADAN")
    plot_confusion_matrix_accuracy(cls, labels, predict_labels)

    # Divide by the samples actually evaluated, not the full dataset sizes.
    print("准确率为{:.6f}".format(test_acc / len(labels)))
|
||
|
||
if __name__ == '__main__':
    # Evaluate the checkpoints stored under parameters/A_to_B on condition 'B'.
    test(
        cls=5,
        tar_condition='B',
        G_params_path='parameters/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
        LC_params_path='parameters/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl',
    )

    # Alternative run: source 'A' vs target 'B' t-SNE with the backup checkpoints.
    # test1(5, 'A', 'B', 'parameters_bak/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl',
    #       'parameters_bak/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')