欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 明星 > Pytorch学习--神经网络--网络模型的保存与读取

Pytorch学习--神经网络--网络模型的保存与读取

2025/5/6 14:58:33 来源:https://blog.csdn.net/weixin_68930974/article/details/143608048  浏览:    关键词:Pytorch学习--神经网络--网络模型的保存与读取

一、网络模型的保存与读取方式1

方法讲解

在这里插入图片描述
在这里插入图片描述

保存模型

import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")

读取模型

import torch
model = torch.load("save_method1.pth")
print(model)

输出:在这里插入图片描述

比较坑人的点

使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B的方式来解决),这里举一个例子来说明

保存模型

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear#保存模型和参数class Mary(nn.Module):def __init__(self):super(Mary,self).__init__()self.model1 = nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")

读取模型

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearmodel = torch.load("save_method1_question.pth")print(model)

报错如下

在这里插入图片描述
说明我们还要把 Mary 这个框架复制到读取模型的.py文件中

重新更正后的读取模型代码

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearclass Mary(nn.Module):def __init__(self):super(Mary,self).__init__()self.model1 = nn.Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x = self.model1(x)return xmodel = torch.load("save_method1_question.pth")print(model)
或者
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary   #这里仅举一个例子model = torch.load("save_method1_question.pth")print(model)

二、网络模型的保存与读取方式2

保存模型参数

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearvgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")

读取模型参数

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linearvgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词