""" @Author: miykah @Email: miykah@163.com @FileName: test.py @DateTime: 2022/7/9 14:15 """ import numpy as np import torch from torch.utils.data import DataLoader from DDS_GADAN.load_data import get_dataset from DDS_GADAN.load_data import Nor_Dataset from DDS_GADAN.model import Extractor, LabelClassifier from sklearn.metrics import confusion_matrix from sklearn.manifold import TSNE import seaborn as sns import matplotlib.pyplot as plt def load_data(tar_condition): if torch.cuda.is_available(): root_dir = 'D:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\' else: root_dir = 'E:\\DataSet\\DDS_data\\5cls_6\\processed_data\\800EachCls_shot10\\' test_data = np.load(root_dir + tar_condition + '\\val\\' + 'data.npy') test_label = np.load(root_dir + tar_condition + '\\val\\' + 'label.npy') return test_data, test_label def tsne_2d_generate(cls, data, labels, pic_title): tsne2D = TSNE(n_components=2, verbose=2, perplexity=30).fit_transform(data) x, y = tsne2D[:, 0], tsne2D[:, 1] pic = plt.figure() ax1 = pic.add_subplot() ax1.scatter(x, y, c=labels, cmap=plt.cm.get_cmap("jet", cls)) # 9为9种颜色,因为标签有9类 plt.title(pic_title) plt.show() def tsne_2d_generate1(cls, data, labels, pic_title): parameters = {'figure.dpi': 600, 'figure.figsize': (4, 3), 'savefig.dpi': 600, 'xtick.direction': 'in', 'ytick.direction': 'in', 'xtick.labelsize': 10, 'ytick.labelsize': 10, 'legend.fontsize': 11.3, } plt.rcParams.update(parameters) plt.rc('font', family='Times New Roman') # 全局字体样式 tsne2D = TSNE(n_components=2, verbose=2, perplexity=30, random_state=3407, init='random', learning_rate=200).fit_transform(data) tsne2D_min, tsne2D_max = tsne2D.min(0), tsne2D.max(0) tsne2D_final = (tsne2D - tsne2D_min) / (tsne2D_max - tsne2D_min) s1, s2 = tsne2D_final[:1000, :], tsne2D_final[1000:, :] pic = plt.figure() # ax1 = pic.add_subplot() plt.scatter(s1[:, 0], s1[:, 1], c=labels[:1000], cmap=plt.cm.get_cmap("jet", cls), marker='o', alpha=0.3) # 9为9种颜色,因为标签有9类 plt.scatter(s2[:, 0], s2[:, 1], c=labels[1000:], cmap=plt.cm.get_cmap("jet", cls), marker='x', alpha=0.3) # 9为9种颜色,因为标签有9类 plt.title(pic_title, fontsize=10) # plt.xticks([]) # plt.yticks([]) plt.show() def plot_confusion_matrix_accuracy(cls, true_labels, predict_labels): # # 画混淆矩阵 # confusion = confusion_matrix(true_labels, predict_labels) # # confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis] # plt.figure(figsize=(6.4,6.4), dpi=100) # sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens") # # sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues") # indices = range(len(confusion)) # classes = ['N', 'IF', 'OF', 'TRC', 'TSP'] # # for i in range(cls): # # classes.append(str(i)) # # 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表 # # plt.xticks(indices, classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜 # # plt.yticks(indices, classes, rotation=45) # plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜 # plt.yticks([index + 0.5 for index in indices], classes, rotation=45) # plt.ylabel('Actual label') # plt.xlabel('Predicted label') # plt.title('confusion matrix') # plt.show() sns.set(font_scale=1.5) parameters = {'figure.dpi': 600, 'figure.figsize': (5, 5), 'savefig.dpi': 600, 'xtick.direction': 'in', 'ytick.direction': 'in', 'xtick.labelsize': 20, 'ytick.labelsize': 20, 'legend.fontsize': 11.3, } plt.rcParams.update(parameters) plt.figure() plt.rc('font', family='Times New Roman') # 全局字体样式 # 画混淆矩阵 confusion = confusion_matrix(true_labels, predict_labels) # confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis] plt.figure() # sns.heatmap(confusion, annot=True, fmt="d", cmap="Greens") sns.heatmap(confusion, annot=True, fmt="d", cmap="Blues", vmax=100, vmin=0, cbar=None, square=True) indices = range(len(confusion)) classes = ['N', 'IF', 'OF', 'TRC', 'TSP'] # for i in range(cls): # classes.append(str(i)) # 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表 # plt.xticks(indices, classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜 # plt.yticks(indices, classes, rotation=45) plt.xticks([index + 0.5 for index in indices], classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜 plt.yticks([index + 0.5 for index in indices], classes, rotation=45) plt.ylabel('Actual label', fontsize=20) plt.xlabel('Predicted label', fontsize=20) # plt.tight_layout() plt.show() def test(cls, tar_condition, G_params_path, LC_params_path): test_data, test_label = load_data(tar_condition) test_dataset = Nor_Dataset(test_data, test_label) batch_size = len(test_dataset) test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True) # 加载网络 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") G = Extractor().to(device) LC = LabelClassifier(cls_num=cls).to(device) G.load_state_dict( torch.load(G_params_path, map_location=device) ) LC.load_state_dict( torch.load(LC_params_path, map_location=device) ) # print(net) # params_num = sum(param.numel() for param in net.parameters_bak()) # print('参数数量:{}'.format(params_num)) test_acc = 0.0 G.eval() LC.eval() with torch.no_grad(): for batch_idx, (data, label) in enumerate(test_loader): data, label = data.to(device), label.to(device) _, feature = G(data, data) _, _, _, output = LC(feature, feature) test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy()) predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1) labels = label.cpu().numpy() predictions = output.cpu().detach().numpy() tsne_2d_generate(cls, predictions, labels, "output of neural network") plot_confusion_matrix_accuracy(cls, labels, predict_labels) print("测试集大小为{}, 成功{},准确率为{:.6f}".format(test_dataset.__len__(), test_acc, test_acc / test_dataset.__len__())) def test1(cls, src_condition, tar_condition, G_params_path, LC_params_path): if torch.cuda.is_available(): root_dir = 'D:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' else: root_dir = 'E:\\DataSet\\DDS_data\\5cls_tsne_2\\processed_data\\800EachCls_shot10\\' src_data = np.load(root_dir + src_condition + '\\src\\' + 'data.npy') src_label = np.load(root_dir + src_condition + '\\src\\' + 'label.npy') tar_data = np.load(root_dir + tar_condition + '\\tar1\\' + 'data.npy') tar_label = np.load(root_dir + tar_condition + '\\tar1\\' + 'label.npy') src_dataset = Nor_Dataset(src_data, src_label) tar_dataset = Nor_Dataset(tar_data, tar_label) src_loader = DataLoader(dataset=src_dataset, batch_size=1000, shuffle=False, drop_last=True) tar_loader = DataLoader(dataset=tar_dataset, batch_size=1000, shuffle=False, drop_last=True) # 加载网络 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") G = Extractor().to(device) LC = LabelClassifier(cls_num=cls).to(device) G.load_state_dict( torch.load(G_params_path, map_location=device) ) LC.load_state_dict( torch.load(LC_params_path, map_location=device) ) # print(net) # params_num = sum(param.numel() for param in net.parameters_bak()) # print('参数数量:{}'.format(params_num)) test_acc = 0.0 G.eval() LC.eval() with torch.no_grad(): for (src_batch_idx, (src_data, src_label)), (tar_batch_idx, (tar_data, tar_label)) in zip(enumerate(src_loader), enumerate(tar_loader)): src_data, src_label = src_data.to(device), src_label.to(device) tar_data, tar_label = tar_data.to(device), tar_label.to(device) data = torch.concat((src_data, tar_data), dim=0) label = torch.concat((src_label, tar_label), dim=0) _, feature = G(data, data) _, _, fc1, output = LC(feature, feature) test_acc += np.sum(np.argmax(output.cpu().detach().numpy(), axis=1) == label.cpu().numpy()) predict_labels = np.argmax(output.cpu().detach().numpy(), axis=1) labels = label.cpu().numpy() outputs = output.cpu().detach().numpy() fc1_outputs = fc1.cpu().detach().numpy() break tsne_2d_generate1(cls, fc1_outputs, labels, "GADAN") plot_confusion_matrix_accuracy(cls, labels, predict_labels) print("准确率为{:.6f}".format(test_acc / (src_dataset.__len__() + tar_dataset.__len__()))) if __name__ == '__main__': # pass test(5, 'B', 'parameters/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl', 'parameters/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl') # test1(5, 'A', 'B', 'parameters_bak/A_to_B/G/G_shot10_epoch200_lr0.002_miu0.5.pkl', # 'parameters_bak/A_to_B/LC/LC_shot10_epoch200_lr0.002_miu0.5.pkl')