11.模型保存和加载
方法
模型保存和加载有两套方法,分别为
方法一
- torch.save(vgg16,”vgg_1.pth”) 直接保存模型结构和模型参数,内存占用比较大
- model = torch.load(“vgg_1.pth”)
方法二
- torch.save(vgg16.state_dict(),”vgg_2.pth”) 以键值对的形式只保存模型参数,节省内存
- vgg16.load_state_dict(torch.load(“vgg_2.pth”)) #加载时需要先知道模型结构,再和模型参数匹配
一、以vgg模型保存和加载为例
1 | #!/usr/bin/env python |
1 | #!/usr/bin/env python |
二、以自己定义的模型保存和加载为例
为方便管理,将自己的模型单独放置在一个文件中,其他地方需要时,只需要导入即可。
1 | #!/usr/bin/env python |
1 | #!/usr/bin/env python |
1 | #!/usr/bin/env python |
输出:
1 | D:\anaconda\envs\Gpu-Pytorch\python.exe D:/Pytorch_learn/model_load_my_model.py |
加载自己的模型时,要从文件中提前导入自己的模型类,并进行实例化。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 咋的个人博客!