54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
# -*- encoding:utf-8 -*-
|
||
|
||
'''
|
||
@Author : dingjiawen
|
||
@Date : 2023/12/12 15:03
|
||
@Usage :
|
||
@Desc : Evaluate the trained CNN captcha model on the prediction data set and print its accuracy.
|
||
'''
|
||
|
||
# -*- coding: UTF-8 -*-
|
||
import numpy as np
|
||
import torch
|
||
from torch.autograd import Variable
|
||
# from visdom import Visdom # pip install Visdom
|
||
import setting
|
||
import dataset
|
||
from model import CNN
|
||
import encoding as encode
|
||
|
||
|
||
def main(model_path='./output/best_model.pkl', num_chars=4):
    """Evaluate the trained CNN captcha model on the prediction data set.

    Loads the saved weights onto the CPU and runs every sample produced by
    ``dataset.get_predict_data_loader()`` through the network. It prints the
    predicted and actual text for each sample. At the end it prints the number
    of samples evaluated and how many were predicted exactly right.

    Args:
        model_path: Path of the saved ``state_dict`` to load.
        num_chars: Number of characters per captcha. The network output is
            read as ``num_chars`` consecutive groups of
            ``setting.ALL_CHAR_SET_LEN`` scores, one group per character.
    """
    cnn = CNN()
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    cnn.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    # Put the network in inference mode (matters for dropout/batch-norm layers, if any).
    cnn.eval()
    print("load cnn net.")

    predict_dataloader = dataset.get_predict_data_loader()
    char_set = setting.ALL_CHAR_SET
    char_set_len = setting.ALL_CHAR_SET_LEN

    total = 0
    correct = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for images, labels in predict_dataloader:
            predict_label = cnn(images)
            # Take the first num_chars * char_set_len scores of each row and split
            # them into one score group per character position:
            # (batch, num_chars, char_set_len).
            scores = predict_label[:, :num_chars * char_set_len].reshape(-1, num_chars, char_set_len)
            char_indices = np.argmax(scores.cpu().numpy(), axis=2)

            # Score every sample in the batch, not only the first one.
            for indices, label in zip(char_indices, labels):
                predicted = ''.join(char_set[k] for k in indices)
                actual_label = encode.decode(label.numpy())
                total += 1
                if predicted == actual_label:
                    correct += 1
                print(f"predict:{predicted},actual:{actual_label}")

    # Count samples actually evaluated; len(dataloader) is the number of batches.
    print(f"total:{total},correct:{correct}")
    # Author's note: in testing, final accuracy was about 74% (total:3000, correct:2233).
|
||
|
||
|
||
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
|