self_example/Spider/Chapter08_验证码的识别/深度学习识别图形验证码/predict.py

54 lines
1.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/12/12 15:03
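@Usage : python predict.py
@Desc : Loads the trained CNN weights from ./output/best_model.pkl and prints the
        captcha prediction accuracy over the prediction data loader.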
'''
# -*- coding: UTF-8 -*-
import numpy as np
import torch
from torch.autograd import Variable
# from visdom import Visdom # pip install Visdom
import setting
import dataset
from model import CNN
import encoding as encode
def main():
    """Evaluate the trained captcha CNN on the prediction set and print accuracy.

    Loads the weights stored in ``./output/best_model.pkl`` onto the CPU, runs
    the model over every batch yielded by ``dataset.get_predict_data_loader()``,
    turns each output row into a 4-character string and compares it with the
    label decoded by ``encode.decode``.  One ``predict:...,actual:...`` line is
    printed per sample, followed by a ``total:...,correct:...`` summary.

    Every sample of every batch is scored, and ``total`` is the number of
    samples scored.  With a batch size of 1 this equals the number of batches.
    """
    # Number of characters per captcha: each output row is read as this many
    # consecutive blocks of ``setting.ALL_CHAR_SET_LEN`` scores.
    captcha_len = 4
    charset = setting.ALL_CHAR_SET
    charset_len = setting.ALL_CHAR_SET_LEN

    cnn = CNN()
    cnn.eval()
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider ``weights_only=True`` where the installed
    # PyTorch version supports it).
    state_dict = torch.load('./output/best_model.pkl', map_location=torch.device('cpu'))
    cnn.load_state_dict(state_dict)
    print("load cnn net.")

    predict_dataloader = dataset.get_predict_data_loader()
    # Optional Visdom visualisation is disabled, as in the original script:
    # vis = Visdom(); vis.images(images, opts=dict(caption=predicted))
    total = 0
    correct = 0
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        for images, labels in predict_dataloader:
            batch_scores = cnn(images).numpy()
            for sample_scores, label in zip(batch_scores, labels):
                actual_label = encode.decode(label.numpy())
                # Pick the highest-scoring character in each per-position block.
                predicted = "".join(
                    charset[np.argmax(sample_scores[pos * charset_len:(pos + 1) * charset_len])]
                    for pos in range(captcha_len)
                )
                total += 1
                if predicted == actual_label:
                    correct += 1
                print(f"predict:{predicted},actual:{actual_label}")
    print(f"total:{total},correct:{correct}")
    # In the author's test run the final accuracy was about 74%
    # (total:3000, correct:2233).
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()