67 lines
1.7 KiB
Python
67 lines
1.7 KiB
Python
# -*- encoding:utf-8 -*-
|
|
|
|
'''
|
|
@Author : dingjiawen
|
|
@Date : 2023/12/12 13:40
|
|
@Usage : run this file directly as a script (training starts in main())
|
|
@Desc : Train the CNN, evaluate it after every epoch, and save the latest and best weights under ./output
|
|
'''
|
|
|
|
# -*- coding: UTF-8 -*-
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.autograd import Variable
|
|
import dataset
|
|
from model import CNN
|
|
from evaluate import main as evaluate
|
|
import os
|
|
import os.path
|
|
|
|
# Training hyper-parameters.
num_epochs = 30
# NOTE(review): batch_size is not referenced in this file; presumably another
# module (e.g. the data loader) reads it -- confirm before changing or removing.
batch_size = 100
learning_rate = 0.001

# Directory that receives model.pkl and best_model.pkl.
output = './output'
# exist_ok=True replaces the racy "exists(...) or makedirs(...)" check-then-create
# idiom: it is a no-op when the directory already exists.
os.makedirs(output, exist_ok=True)
|
|
|
|
|
|
def main():
    """Train the CNN for ``num_epochs`` epochs and checkpoint it.

    Each epoch makes one pass over ``dataset.get_train_data_loader()``,
    minimising ``nn.MultiLabelSoftMarginLoss`` with Adam at ``learning_rate``.
    After every epoch the weights are written to ``<output>/model.pkl`` and
    scored with ``evaluate``. Whenever that score beats the best one seen so
    far, the weights are also written to ``<output>/best_model.pkl``. When
    training ends, the final weights are saved to ``<output>/model.pkl``.

    Raises:
        RuntimeError: if the training data loader yields no batches.
    """
    cnn = CNN()
    cnn.train()

    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
    max_eval_acc = -1

    train_dataloader = dataset.get_train_data_loader()
    # Both checkpoint paths are loop-invariant, so build them once.
    model_path = os.path.join(output, "model.pkl")
    best_model_path = os.path.join(output, "best_model.pkl")

    for epoch in range(num_epochs):
        i, loss = -1, None
        for i, (images, labels) in enumerate(train_dataloader):
            # Since PyTorch 0.4 tensors track gradients themselves (this file
            # already relies on ``loss.item()``), so the legacy ``Variable``
            # wrapper is a no-op. Only the float cast of the labels remains.
            predict_labels = cnn(images)
            loss = criterion(predict_labels, labels.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 10 == 0:
                print("epoch:", epoch, "step:", i, "loss:", loss.item())

        if loss is None:
            # Without this guard the summary print below would fail with an
            # UnboundLocalError on ``i``/``loss``.
            raise RuntimeError("training data loader yielded no batches")

        print("epoch:", epoch, "step:", i, "loss:", loss.item())
        torch.save(cnn.state_dict(), model_path)
        print("save model")

        eval_acc = evaluate(model_path)
        if eval_acc > max_eval_acc:
            # Record the new best score so best_model.pkl is only overwritten
            # on a genuine improvement.
            max_eval_acc = eval_acc
            torch.save(cnn.state_dict(), best_model_path)
            print("save best model")

    torch.save(cnn.state_dict(), model_path)
    print("save last model")
|
|
|
|
|
|
if __name__ == '__main__':
    # Start training only when this file is executed directly, not on import.
    main()
|