模型训练可以分为以下几步

  1. 数据集准备
  2. 数据集加载
  3. 模型定义及实例化
  4. 损失函数定义
  5. 参数更新方式定义
  6. 模型必要参数设置
  7. 模型训练逻辑及一些提示信息、可视化编写
  8. 开始训练模型
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    #!/usr/bin/env python
    # -*- coding: UTF-8 -*-
    """
    @Project :Pytorch_learn
    @File :train.py
    @IDE :PyCharm
    @Author :咋
    @Date :2023/7/13 15:38
    """
    import torch
    import torchvision
    from torch.utils import data
    from Mymodel import MyModule
    from tensorboardX import SummaryWriter
    # 0.设置参数
    batch_size = 64
    lr = 0.01
    epochs = 20
    savetime = 5
    # 1.准备数据集
    train_data = torchvision.datasets.CIFAR10("CIFAR10",train=True,transform=torchvision.transforms.ToTensor(),
    download=True)
    test_data = torchvision.datasets.CIFAR10("CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),
    download=True)

    # 2.加载数据集
    train_loader = data.DataLoader(dataset=train_data,batch_size=batch_size)
    test_loader = data.DataLoader(dataset=test_data,batch_size=batch_size)
    train_len = len(train_loader)
    test_len = len(test_loader)
    print("训练集长度为:{}".format(train_len))
    print("测试集长度为:{}".format(test_len))
    # 3.定义网络结构,实例化模型
    model = MyModule()

    # 4.定义损失函数
    loss_F = torch.nn.CrossEntropyLoss()

    # 5.定义参数更新方式
    optim = torch.optim.SGD(model.parameters(),lr=lr)

    # 6.开始训练和评估
    model.train()
    write = SummaryWriter("log_6")
    for epoch in range(epochs):
    print("-------开始第{}轮训练,总共{}轮-------".format(epoch,epochs))
    train_loss = 0
    for train_time,train_item in enumerate(train_loader):
    train_img,train_label = train_item
    # 前向传播
    result = model(train_img)
    # 计算损失
    loss = loss_F(result,train_label)
    # 反向传播更新参数
    optim.zero_grad()
    loss.backward()
    optim.step()
    train_loss += loss
    model.eval()
    test_loss = 0
    right =0
    with torch.no_grad():
    for test_time,test_item in enumerate(test_loader):
    test_img,test_label = test_item
    test_result = model(test_img)
    # 获取测试集上的loss
    loss = loss_F(test_result,test_label)
    test_loss += loss
    # 获取测试集上的准确率
    right += (test_result.argmax(1) == test_label).sum()
    accuracy = right/test_len
    write.add_scalar("train_loss",train_loss,epoch)
    write.add_scalar("test_loss",test_loss,epoch)
    write.add_scalar("accuracy",accuracy,epoch)
    print("训练集上的损失为:{}".format(train_loss))
    print("测试集上的损失为:{}".format(test_loss))
    print("测试集上的准确率为:{}".format(accuracy))
    if epoch % savetime ==0:
    torch.save(model,"./model/model{}.pth".format(epoch))

    print("训练完成!")
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    D:\anaconda\envs\Gpu-Pytorch\python.exe D:/Pytorch_learn/train.py
    D:\anaconda\envs\Gpu-Pytorch\lib\site-packages\torchvision\io\image.py:11: UserWarning: Failed to load image Python extension: Could not find module 'D:\anaconda\envs\Gpu-Pytorch\Lib\site-packages\torchvision\image.pyd' (or one of its dependencies). Try using the full path with constructor syntax.
    warn(f"Failed to load image Python extension: {e}")
    Files already downloaded and verified
    Files already downloaded and verified
    训练集长度为:782
    测试集长度为:157
    -------开始第0轮训练,总共20轮-------
    训练集上的损失为:1692.0927734375
    测试集上的损失为:318.7899475097656
    测试集上的准确率为:17.171974182128906
    -------开始第1轮训练,总共20轮-------
    训练集上的损失为:1462.9697265625
    测试集上的损失为:302.61761474609375
    测试集上的准确率为:19.770700454711914
    -------开始第2轮训练,总共20轮-------
    训练集上的损失为:1305.275390625
    测试集上的损失为:266.6638488769531
    测试集上的准确率为:23.980892181396484
    -------开始第3轮训练,总共20轮-------
    训练集上的损失为:1211.9080810546875
    测试集上的损失为:264.86248779296875
    测试集上的准确率为:24.78343963623047
    -------开始第4轮训练,总共20轮-------
    训练集上的损失为:1148.6689453125
    测试集上的损失为:258.8547058105469
    测试集上的准确率为:25.93630599975586
    -------开始第5轮训练,总共20轮-------
    训练集上的损失为:1094.395751953125
    测试集上的损失为:245.91709899902344
    测试集上的准确率为:27.719745635986328
    -------开始第6轮训练,总共20轮-------
    训练集上的损失为:1043.458984375
    测试集上的损失为:231.48187255859375
    测试集上的准确率为:29.585987091064453
    -------开始第7轮训练,总共20轮-------
    训练集上的损失为:993.9429321289062
    测试集上的损失为:216.33401489257812
    测试集上的准确率为:32.070064544677734
    -------开始第8轮训练,总共20轮-------
    训练集上的损失为:947.3882446289062
    测试集上的损失为:203.79718017578125
    测试集上的准确率为:34.248409271240234
    -------开始第9轮训练,总共20轮-------
    训练集上的损失为:905.7017822265625
    测试集上的损失为:194.99539184570312
    测试集上的准确率为:35.828025817871094
    -------开始第10轮训练,总共20轮-------
    训练集上的损失为:868.5162963867188
    测试集上的损失为:187.56068420410156
    测试集上的准确率为:37.08917236328125
    -------开始第11轮训练,总共20轮-------
    训练集上的损失为:835.2941284179688
    测试集上的损失为:181.5903778076172
    测试集上的准确率为:37.840763092041016
    -------开始第12轮训练,总共20轮-------
    训练集上的损失为:805.2085571289062
    测试集上的损失为:177.99606323242188
    测试集上的准确率为:38.445858001708984
    -------开始第13轮训练,总共20轮-------
    训练集上的损失为:777.8590698242188
    测试集上的损失为:175.87742614746094
    测试集上的准确率为:38.904457092285156
    -------开始第14轮训练,总共20轮-------
    训练集上的损失为:752.6644287109375
    测试集上的损失为:174.36672973632812
    测试集上的准确率为:39.25477600097656
    -------开始第15轮训练,总共20轮-------
    训练集上的损失为:729.43798828125
    测试集上的损失为:173.59202575683594
    测试集上的准确率为:39.48407745361328
    -------开始第16轮训练,总共20轮-------
    训练集上的损失为:707.3736572265625
    测试集上的损失为:173.00869750976562
    测试集上的准确率为:39.56687927246094
    -------开始第17轮训练,总共20轮-------
    训练集上的损失为:686.6292114257812
    测试集上的损失为:172.45094299316406
    测试集上的准确率为:39.8216552734375
    -------开始第18轮训练,总共20轮-------
    训练集上的损失为:667.0374755859375
    测试集上的损失为:172.76763916015625
    测试集上的准确率为:39.77070236206055
    -------开始第19轮训练,总共20轮-------
    训练集上的损失为:648.382568359375
    测试集上的损失为:172.2664794921875
    测试集上的准确率为:40.08917236328125
    训练完成!

    进程已结束,退出代码0

    tensorboard可视化:
    image.pngimage.png
    image.png