# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date   : 2023/12/12 15:03
@Usage  : python <this script>
@Desc   : Load the trained captcha CNN, run it over the prediction data
          loader and print every prediction plus the overall tally.
'''
# -*- coding: UTF-8 -*-
import numpy as np
import torch
from torch.autograd import Variable  # kept for compatibility; no longer needed for inference
# from visdom import Visdom  # pip install Visdom

import setting
import dataset
from model import CNN
import encoding as encode


def _decode_prediction(scores, num_chars=4):
    """Turn one flat row of model scores into the predicted captcha string.

    The row is treated as ``num_chars`` consecutive chunks of
    ``setting.ALL_CHAR_SET_LEN`` scores; the argmax of each chunk indexes
    into ``setting.ALL_CHAR_SET``.

    Args:
        scores: 1-D score tensor for a single sample (assumed to have at
            least ``num_chars * setting.ALL_CHAR_SET_LEN`` entries).
        num_chars: number of characters to decode.

    Returns:
        The predicted string, one character per chunk.
    """
    width = setting.ALL_CHAR_SET_LEN
    chars = []
    for pos in range(num_chars):
        chunk = scores[pos * width:(pos + 1) * width]
        chars.append(setting.ALL_CHAR_SET[np.argmax(chunk.detach().numpy())])
    return "".join(chars)


def main(model_path='./output/best_model.pkl', num_chars=4):
    """Evaluate the saved CNN on the prediction set and print the results.

    Args:
        model_path: path of the saved ``state_dict``, loaded onto the CPU.
        num_chars: characters per captcha. The original script decoded
            exactly 4.

    Prints ``predict:<p>,actual:<a>`` for every sample, then
    ``total:<n>,correct:<k>``, where ``total`` counts samples rather than
    batches.
    """
    cnn = CNN()
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    cnn.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    cnn.eval()
    print("load cnn net.")

    predict_dataloader = dataset.get_predict_data_loader()

    # vis = Visdom()
    total = 0
    correct = 0
    # Inference only: disable autograd so no computation graph is kept around.
    with torch.no_grad():
        for images, labels in predict_dataloader:
            outputs = cnn(images)
            # Score every sample in the batch, not just the first one.
            for scores, label in zip(outputs, labels):
                predicted = _decode_prediction(scores, num_chars)
                actual_label = encode.decode(label.numpy())
                total += 1
                if predicted == actual_label:
                    correct += 1
                print(f"predict:{predicted},actual:{actual_label}")
                # vis.images(images, opts=dict(caption=predicted))

    print(f"total:{total},correct:{correct}")
    # Tested: final accuracy is about 74%, total:3000, correct:2233


if __name__ == '__main__':
    main()