孤注一掷——基于文心Ernie-3.0大模型的影评情感分析
模型训练可以分为以下几步
- 数据集准备
- 数据集加载
- 模型定义及实例化
- 损失函数定义
- 参数更新方式定义
- 模型必要参数设置
- 模型训练逻辑及一些提示信息、可视化编写
- 开始训练模型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File :train.py
@IDE :PyCharm
@Author :咋
@Date :2023/7/13 15:38
"""
import torch
import torchvision
from torch.utils import data
from Mymodel import MyModule
from tensorboardX import SummaryWriter
# 0.设置参数
batch_size = 64
lr = 0.01
epochs = 20
savetime = 5
# 1.准备数据集
train_data = torchvision.datasets.CIFAR10("CIFAR10",train=True,transform=torchvision.transforms.ToTensor(),
download=True)
test_data = torchvision.datasets.CIFAR10("CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),
download=True)
# 2.加载数据集
train_loader = data.DataLoader(dataset=train_data,batch_size=batch_size)
test_loader = data.DataLoader(dataset=test_data,batch_size=batch_size)
train_len = len(train_loader)
test_len = len(test_loader)
print("训练集长度为:{}".format(train_len))
print("测试集长度为:{}".format(test_len))
# 3.定义网络结构,实例化模型
model = MyModule()
# 4.定义损失函数
loss_F = torch.nn.CrossEntropyLoss()
# 5.定义参数更新方式
optim = torch.optim.SGD(model.parameters(),lr=lr)
# 6.开始训练和评估
model.train()
write = SummaryWriter("log_6")
for epoch in range(epochs):
print("-------开始第{}轮训练,总共{}轮-------".format(epoch,epochs))
train_loss = 0
for train_time,train_item in enumerate(train_loader):
train_img,train_label = train_item
# 前向传播
result = model(train_img)
# 计算损失
loss = loss_F(result,train_label)
# 反向传播更新参数
optim.zero_grad()
loss.backward()
optim.step()
train_loss += loss
model.eval()
test_loss = 0
right =0
with torch.no_grad():
for test_time,test_item in enumerate(test_loader):
test_img,test_label = test_item
test_result = model(test_img)
# 获取测试集上的loss
loss = loss_F(test_result,test_label)
test_loss += loss
# 获取测试集上的准确率
right += (test_result.argmax(1) == test_label).sum()
accuracy = right/test_len
write.add_scalar("train_loss",train_loss,epoch)
write.add_scalar("test_loss",test_loss,epoch)
write.add_scalar("accuracy",accuracy,epoch)
print("训练集上的损失为:{}".format(train_loss))
print("测试集上的损失为:{}".format(test_loss))
print("测试集上的准确率为:{}".format(accuracy))
if epoch % savetime ==0:
torch.save(model,"./model/model{}.pth".format(epoch))
print("训练完成!")tensorboard可视化:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91D:\anaconda\envs\Gpu-Pytorch\python.exe D:/Pytorch_learn/train.py
D:\anaconda\envs\Gpu-Pytorch\lib\site-packages\torchvision\io\image.py:11: UserWarning: Failed to load image Python extension: Could not find module 'D:\anaconda\envs\Gpu-Pytorch\Lib\site-packages\torchvision\image.pyd' (or one of its dependencies). Try using the full path with constructor syntax.
warn(f"Failed to load image Python extension: {e}")
Files already downloaded and verified
Files already downloaded and verified
训练集长度为:782
测试集长度为:157
-------开始第0轮训练,总共20轮-------
训练集上的损失为:1692.0927734375
测试集上的损失为:318.7899475097656
测试集上的准确率为:17.171974182128906
-------开始第1轮训练,总共20轮-------
训练集上的损失为:1462.9697265625
测试集上的损失为:302.61761474609375
测试集上的准确率为:19.770700454711914
-------开始第2轮训练,总共20轮-------
训练集上的损失为:1305.275390625
测试集上的损失为:266.6638488769531
测试集上的准确率为:23.980892181396484
-------开始第3轮训练,总共20轮-------
训练集上的损失为:1211.9080810546875
测试集上的损失为:264.86248779296875
测试集上的准确率为:24.78343963623047
-------开始第4轮训练,总共20轮-------
训练集上的损失为:1148.6689453125
测试集上的损失为:258.8547058105469
测试集上的准确率为:25.93630599975586
-------开始第5轮训练,总共20轮-------
训练集上的损失为:1094.395751953125
测试集上的损失为:245.91709899902344
测试集上的准确率为:27.719745635986328
-------开始第6轮训练,总共20轮-------
训练集上的损失为:1043.458984375
测试集上的损失为:231.48187255859375
测试集上的准确率为:29.585987091064453
-------开始第7轮训练,总共20轮-------
训练集上的损失为:993.9429321289062
测试集上的损失为:216.33401489257812
测试集上的准确率为:32.070064544677734
-------开始第8轮训练,总共20轮-------
训练集上的损失为:947.3882446289062
测试集上的损失为:203.79718017578125
测试集上的准确率为:34.248409271240234
-------开始第9轮训练,总共20轮-------
训练集上的损失为:905.7017822265625
测试集上的损失为:194.99539184570312
测试集上的准确率为:35.828025817871094
-------开始第10轮训练,总共20轮-------
训练集上的损失为:868.5162963867188
测试集上的损失为:187.56068420410156
测试集上的准确率为:37.08917236328125
-------开始第11轮训练,总共20轮-------
训练集上的损失为:835.2941284179688
测试集上的损失为:181.5903778076172
测试集上的准确率为:37.840763092041016
-------开始第12轮训练,总共20轮-------
训练集上的损失为:805.2085571289062
测试集上的损失为:177.99606323242188
测试集上的准确率为:38.445858001708984
-------开始第13轮训练,总共20轮-------
训练集上的损失为:777.8590698242188
测试集上的损失为:175.87742614746094
测试集上的准确率为:38.904457092285156
-------开始第14轮训练,总共20轮-------
训练集上的损失为:752.6644287109375
测试集上的损失为:174.36672973632812
测试集上的准确率为:39.25477600097656
-------开始第15轮训练,总共20轮-------
训练集上的损失为:729.43798828125
测试集上的损失为:173.59202575683594
测试集上的准确率为:39.48407745361328
-------开始第16轮训练,总共20轮-------
训练集上的损失为:707.3736572265625
测试集上的损失为:173.00869750976562
测试集上的准确率为:39.56687927246094
-------开始第17轮训练,总共20轮-------
训练集上的损失为:686.6292114257812
测试集上的损失为:172.45094299316406
测试集上的准确率为:39.8216552734375
-------开始第18轮训练,总共20轮-------
训练集上的损失为:667.0374755859375
测试集上的损失为:172.76763916015625
测试集上的准确率为:39.77070236206055
-------开始第19轮训练,总共20轮-------
训练集上的损失为:648.382568359375
测试集上的损失为:172.2664794921875
测试集上的准确率为:40.08917236328125
训练完成!
进程已结束,退出代码0
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 咋的个人博客!