self_example/Spider/Chapter08_验证码的识别/深度学习识别图形验证码/model.py

55 lines
1.6 KiB
Python

#-*- encoding:utf-8 -*-
'''
@Author : dingjiawen
@Date : 2023/12/12 13:36
@Usage :
@Desc : CNN model for recognizing image captchas
'''
# -*- coding: UTF-8 -*-
import torch.nn as nn
import setting
# CNN model (3 conv layers + 2 fully connected layers)
class CNN(nn.Module):
    """Convolutional network for captcha recognition.

    Architecture: three conv blocks
    (Conv2d 3x3, padding 1 -> BatchNorm2d -> Dropout -> ReLU -> MaxPool2d(2))
    followed by two fully connected stages:
    ``fc`` (flattened features -> 1024, Dropout, ReLU) and
    ``rfc`` (1024 -> num_outputs, no activation).

    All constructor arguments are optional and keyword-only. When omitted,
    the image size and output size are read from the project ``setting``
    module, so ``CNN()`` builds the same network as before.

    Attribute names (``layer1``..``layer3``, ``fc``, ``rfc``) and the order of
    modules inside each ``nn.Sequential`` are kept stable. They determine the
    ``state_dict`` keys of saved checkpoints (e.g. ``layer1.0.weight``,
    ``fc.0.weight``).
    """

    def __init__(self, *, image_width=None, image_height=None,
                 num_outputs=None, in_channels=1, dropout=0.5):
        """Build the network.

        Args:
            image_width: Input image width in pixels.
                Defaults to ``setting.IMAGE_WIDTH``.
            image_height: Input image height in pixels.
                Defaults to ``setting.IMAGE_HEIGHT``.
            num_outputs: Length of the output vector. Defaults to
                ``setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN``
                (presumably one score per (character position, charset
                symbol) pair -- confirm against the training loss).
            in_channels: Channels of the input image (1 = single-channel,
                the original hard-coded value).
            dropout: Dropout probability used in every block (originally 0.5).

        Raises:
            ValueError: If width or height is below 8 pixels. Three 2x
                max-pools would then shrink the feature map to zero.
        """
        super().__init__()
        if image_width is None:
            image_width = setting.IMAGE_WIDTH
        if image_height is None:
            image_height = setting.IMAGE_HEIGHT
        if num_outputs is None:
            num_outputs = setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN

        # The padded 3x3 convs keep H and W unchanged. Each MaxPool2d(2)
        # halves them, rounding down, so after three pools the map is
        # (H // 8) x (W // 8).
        pooled_w = image_width // 8
        pooled_h = image_height // 8
        if pooled_w < 1 or pooled_h < 1:
            raise ValueError(
                f"image size {image_width}x{image_height} is too small: "
                f"both sides must be at least 8 pixels"
            )

        # Modules are created in the original order so that seeded weight
        # initialization is reproducible across versions.
        self.layer1 = self._conv_block(in_channels, 32, dropout)
        self.layer2 = self._conv_block(32, 64, dropout)
        self.layer3 = self._conv_block(64, 64, dropout)
        self.fc = nn.Sequential(
            # 64 channels leave layer3.
            nn.Linear(pooled_w * pooled_h * 64, 1024),
            nn.Dropout(dropout),
            nn.ReLU(),
        )
        self.rfc = nn.Sequential(
            nn.Linear(1024, num_outputs),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch, dropout):
        """Return Conv2d(3x3, pad 1) -> BatchNorm2d -> Dropout -> ReLU -> MaxPool2d(2)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Return raw (un-normalized) output scores, shape (N, num_outputs).

        Args:
            x: image batch, assumed (N, in_channels, image_height,
                image_width) -- TODO confirm the data loader's layout. The
                spatial size must match the one used at construction, since
                ``fc``'s input width is fixed from it.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # flatten(1) reshapes, copying only when needed. view() would raise
        # on non-contiguous activations (e.g. channels_last) and on an empty
        # batch.
        out = out.flatten(1)
        out = self.fc(out)
        return self.rfc(out)