方法

模型保存和加载有两套方法,分别为
方法一

  1. torch.save(vgg16,”vgg_1.pth”) 直接保存模型结构和模型参数,内存占用比较大
  2. model = torch.load(“vgg_1.pth”)

方法二

  1. torch.save(vgg16.state_dict(),”vgg_2.pth”) 以键值对的形式只保存模型参数,节省内存
  2. vgg16.load_state_dict(torch.load(“vgg_2.pth”)) #加载时需要先知道模型结构,再和模型参数匹配

一、以vgg模型保存和加载为例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File :model_save.py
@IDE :PyCharm
@Author :咋
@Date :2023/7/13 13:19
"""
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 模型保存方式1 保存模型结构+模型参数
torch.save(vgg16,"vgg_1.pth")

# 模型保存方式2 以字典的形式保存模型参数
torch.save(vgg16.state_dict(),"vgg_2.pth")

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File :model_load.py
@IDE :PyCharm
@Author :咋
@Date :2023/7/13 13:30
"""
import torch
import torchvision


'''
第一种方式保存模型时,加载模型
model = torch.load("vgg_1.pth")
print(model)
'''

# 第二种方式保存模型时,加载模型
# 首先要获取模型结构
vgg16 = torchvision.models.vgg16(pretrained=False)
# 讲模型结构和读取的键值对型的模型参数匹配在一起
vgg16.load_state_dict(torch.load("vgg_2.pth"))
print(vgg16)

二、以自己定义的模型保存和加载为例

为方便管理,将自己的模型单独放置在一个文件中,其他地方需要时,只需要导入即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File :Mymodel.py
@IDE :PyCharm
@Author :咋
@Date :2023/7/13 13:24
"""

import torch
from torch.utils.data import DataLoader,Dataset
from torch.nn import Module,Conv2d,MaxPool2d,Sequential,Flatten,Linear
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision


class MyModule(Module):
def __init__(self):
super(MyModule, self).__init__()
self.model = Sequential(
Conv2d(3,32,5,padding=2),
MaxPool2d(2),
Conv2d(32,32,5,padding=2),
MaxPool2d(2),
Conv2d(32,64,5,padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024,64),
Linear(64,10),
)

def forward(self,x):
x = self.model(x)
return x
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File :model_save_my_model.py
@IDE :PyCharm
@Author :咋
@Date :2023/7/13 13:23
"""
from Mymodel import MyModule
import torch


model = MyModule()
# 第一种方式保存
torch.save(model,"model1.pth")
# 第二种方式保存
torch.save(model.state_dict(),"model2.pth")
print("保存完毕")

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn
@File :model_load_my_model.py
@IDE :PyCharm
@Author :咋
@Date :2023/7/13 13:50
"""
import torch
import torchvision
from Mymodel import *
'''
第一种方式加载模型
model = torch.load("model1.pth")
print(model)
'''

# 第二种方式加载模型
model = MyModule()
model.load_state_dict(torch.load("model2.pth"))
print(model)

输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
D:\anaconda\envs\Gpu-Pytorch\python.exe D:/Pytorch_learn/model_load_my_model.py
D:\anaconda\envs\Gpu-Pytorch\lib\site-packages\torchvision\io\image.py:11: UserWarning: Failed to load image Python extension: Could not find module 'D:\anaconda\envs\Gpu-Pytorch\Lib\site-packages\torchvision\image.pyd' (or one of its dependencies). Try using the full path with constructor syntax.
warn(f"Failed to load image Python extension: {e}")
MyModule(
(model): Sequential(
(0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Flatten(start_dim=1, end_dim=-1)
(7): Linear(in_features=1024, out_features=64, bias=True)
(8): Linear(in_features=64, out_features=10, bias=True)
)
)

进程已结束,退出代码0

加载自己的模型时,要从文件中提前导入自己的模型类,并进行实例化。