55 lines
1.6 KiB
Python
55 lines
1.6 KiB
Python
#-*- encoding:utf-8 -*-
|
|
|
|
'''
|
|
@Author : dingjiawen
|
|
@Date : 2023/12/12 13:36
|
|
@Usage :
|
|
@Desc :
|
|
'''
|
|
|
|
# -*- coding: UTF-8 -*-
|
|
import torch.nn as nn
|
|
import setting
|
|
|
|
|
|
# CNN model (3 conv layers)
|
|
class CNN(nn.Module):
    """Three-stage convolutional network with a two-layer fully connected head.

    Each conv stage is Conv3x3(padding=1) -> BatchNorm -> Dropout -> ReLU ->
    MaxPool2d(2), so the spatial size is floor-halved three times. The pooled
    feature map is flattened and passed through a 1024-unit hidden layer
    (``fc``), then projected to ``num_outputs`` logits (``rfc``).

    Input is expected as ``(N, in_channels, image_height, image_width)``. The
    first Linear layer is sized from ``image_width // 8 * image_height // 8 * 64``,
    so other spatial sizes will fail at ``fc``.

    All constructor arguments are keyword-only and optional. When omitted,
    the values come from the project ``setting`` module exactly as before, so
    ``CNN()`` builds the same network as the original implementation.

    Args:
        in_channels: Number of input image channels (default 1).
        image_width: Input width in pixels; defaults to ``setting.IMAGE_WIDTH``.
        image_height: Input height in pixels; defaults to ``setting.IMAGE_HEIGHT``.
        num_outputs: Size of the output layer; defaults to
            ``setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN`` (the length of
            the full label set -- presumably characters-per-captcha times
            charset size; confirm against the loss/decoder).
        dropout: Dropout probability used in every stage (default 0.5).
    """

    def __init__(self, *, in_channels=1, image_width=None, image_height=None,
                 num_outputs=None, dropout=0.5):
        super().__init__()
        # Fall back to the project settings lazily, so callers that pass every
        # value explicitly do not depend on the `setting` module.
        if image_width is None:
            image_width = setting.IMAGE_WIDTH
        if image_height is None:
            image_height = setting.IMAGE_HEIGHT
        if num_outputs is None:
            num_outputs = setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN

        # Attribute names, registration order and the layer order inside each
        # Sequential match the original, so existing state_dict checkpoints
        # (keys like "layer1.0.weight", "fc.0.weight", "rfc.0.weight") still load.
        self.layer1 = self._conv_block(in_channels, 32, dropout)
        self.layer2 = self._conv_block(32, 64, dropout)
        self.layer3 = self._conv_block(64, 64, dropout)

        # Divide by 8 because MaxPool2d(2) is applied three times; each pool
        # floors, and floor(floor(floor(s/2)/2)/2) == s // 8 for s >= 0.
        flat_features = (image_width // 8) * (image_height // 8) * 64
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.Dropout(dropout),  # drop `dropout` fraction of activations
            nn.ReLU(),
        )
        self.rfc = nn.Sequential(
            nn.Linear(1024, num_outputs),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch, dropout):
        """Return one Conv3x3 -> BatchNorm -> Dropout -> ReLU -> MaxPool(2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.Dropout(dropout),  # drop `dropout` fraction of activations
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Map a batch of images to raw (un-normalized) output scores.

        Args:
            x: Tensor of shape ``(N, in_channels, image_height, image_width)``.

        Returns:
            Tensor of shape ``(N, num_outputs)``.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        # Flatten everything except the batch dimension.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = self.rfc(out)
        return out