# -*- encoding:utf-8 -*-
'''
@Author  : dingjiawen
@Date    : 2023/12/12 13:40
@Usage   : Run directly as a script; it imports the sibling modules
           ``dataset``, ``model`` and ``evaluate``.
@Desc    : Train the ``CNN`` from ``model`` with a multi-label soft-margin
           loss, saving the latest weights to ``output/model.pkl`` after each
           epoch and the best-evaluated weights to ``output/best_model.pkl``.
'''
import os
import os.path

import torch
import torch.nn as nn
from torch.autograd import Variable  # kept for compatibility; unused since tensors are autograd-aware (PyTorch >= 0.4)

import dataset
from model import CNN
from evaluate import main as evaluate

num_epochs = 30
# NOTE(review): batch_size is not referenced in this script; presumably the
# batch size is configured inside ``dataset`` — confirm before changing it.
batch_size = 100
learning_rate = 0.001
output = './output'

# Create the checkpoint directory at import time (same side effect as before).
os.makedirs(output, exist_ok=True)


def main():
    """Train the CNN for ``num_epochs`` epochs, checkpointing along the way.

    After every epoch the current weights are saved to ``<output>/model.pkl``
    and scored with ``evaluate``. Whenever that score beats every previous
    epoch's score, the weights are also saved to ``<output>/best_model.pkl``.
    The final weights are written to ``<output>/model.pkl`` once training ends.

    Raises:
        RuntimeError: if the training data loader yields no batches.
    """
    cnn = CNN()
    cnn.train()
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
    # Best evaluation score seen so far. Starting at -1 makes the first epoch
    # count as an improvement (assumes evaluate() returns a non-negative
    # accuracy — TODO confirm).
    max_eval_acc = -1
    model_path = os.path.join(output, "model.pkl")
    best_model_path = os.path.join(output, "best_model.pkl")

    train_dataloader = dataset.get_train_data_loader()
    for epoch in range(num_epochs):
        step, loss = None, None
        for step, (images, labels) in enumerate(train_dataloader):
            # Plain tensors participate in autograd directly (wrapping them in
            # Variable is a no-op since PyTorch 0.4); the loss wants float targets.
            labels = labels.float()
            predict_labels = cnn(images)
            loss = criterion(predict_labels, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (step + 1) % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())

        if loss is None:
            # Previously this surfaced as a confusing NameError on `i`/`loss`.
            raise RuntimeError("training data loader yielded no batches")
        # Always report the epoch's last batch, even if its index is not a multiple of 10.
        print("epoch:", epoch, "step:", step, "loss:", loss.item())

        torch.save(cnn.state_dict(), model_path)
        print("save model")
        eval_acc = evaluate(model_path)
        if eval_acc > max_eval_acc:
            # Remember the new best so later, worse epochs don't overwrite
            # best_model.pkl (the original never updated max_eval_acc).
            max_eval_acc = eval_acc
            torch.save(cnn.state_dict(), best_model_path)
            print("save best model")

    torch.save(cnn.state_dict(), model_path)
    print("save last model")


if __name__ == '__main__':
    main()