#-*- encoding:utf-8 -*-

'''
@Author : dingjiawen
@Date   : 2023/12/12 13:36
@Usage  :
@Desc   :
'''
# -*- coding: UTF-8 -*-
import torch.nn as nn
import setting


def _conv_stage(in_channels, out_channels):
    """Build one convolutional stage of the network.

    The stage is, in order: 3x3 convolution (padding 1, so the spatial size
    is kept), batch normalisation, dropout, ReLU and a 2x2 max-pool that
    halves height and width.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.Dropout(0.5),  # drop 50% of the activations during training
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


# CNN model: three conv stages followed by two fully connected layers.
class CNN(nn.Module):
    """Convolutional network sized from the constants in ``setting``.

    The first convolution takes a single input channel. The fully connected
    head is sized for inputs of ``setting.IMAGE_HEIGHT`` by
    ``setting.IMAGE_WIDTH`` and produces
    ``setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN`` raw outputs per sample
    (presumably one score per character class per captcha position --
    confirm against the training code).
    """

    def __init__(self):
        super().__init__()

        # Channel progression: 1 -> 32 -> 64 -> 64.
        self.layer1 = _conv_stage(1, 32)
        self.layer2 = _conv_stage(32, 64)
        self.layer3 = _conv_stage(64, 64)

        # Each of the three stages max-pools by 2, so both spatial dimensions
        # shrink by a factor of 8; the last stage emits 64 channels.
        flat_features = (
            (setting.IMAGE_WIDTH // 8) * (setting.IMAGE_HEIGHT // 8) * 64
        )
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.Dropout(0.5),  # drop 50% of the activations during training
            nn.ReLU(),
        )

        # setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN is the length of the
        # label dictionary (the size of the output vector).
        self.rfc = nn.Sequential(
            nn.Linear(1024, setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN),
        )

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``self.rfc``.

        ``x`` is passed through the three conv stages, flattened to one row
        per sample, then through ``self.fc`` and ``self.rfc``.
        """
        features = x
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)

        # Flatten everything except the batch dimension.
        flattened = features.view(features.size(0), -1)

        hidden = self.fc(flattened)
        return self.rfc(hidden)