欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 家装 > 解决 TypeError: Expected state_dict to be dict-like, got <class ‘*‘>.

解决 TypeError: Expected state_dict to be dict-like, got <class ‘*‘>.

2025/5/25 13:38:55 来源:https://blog.csdn.net/weixin_42426841/article/details/142638624  浏览:    关键词:解决 TypeError: Expected state_dict to be dict-like, got <class ‘*‘>.

这是一个简洁的错误复现和解决文章

文章目录

    • 错误原因
    • 错误重现
    • 正确加载演示
    • 拓展阅读

错误原因

一般是因为混合使用不同的保存和加载方式,问题出在你用 load_state_dict() 去加载别人使用torch.save(model) 保存的整个模型。

错误重现

下面我们来复现它,看是不是和你的操作一致:

  1. 错误地保存整个 model 而不是其 state_dict
    import torch
    import torch.nn as nn# 定义一个线性模型进行演示
    class LinearModel(nn.Module):def __init__(self, input_size, output_size):super(LinearModel, self).__init__()self.linear = nn.Linear(input_size, output_size)def forward(self, x):return self.linear(x)# 创建模型实例
    model = LinearModel(input_size=10, output_size=1)# 打印模型结构
    print("Model:", model)# 保存模型的 state_dict
    torch.save(model.state_dict(), './linear_model_state_dict.pth')
    
  2. 加载时传入 model 对象:
    # 创建一个新的模型实例
    new_model = LinearModel(input_size=10, output_size=1)# 加载 state_dict 到新模型
    new_model.load_state_dict(torch.load('./linear_model_state_dict.pth'))# 打印加载后的新模型结构
    print("Model loaded with state_dict:", new_model)
    
    输出
    Error: Expected state_dict to be dict-like, got <class '__main__.LinearModel'>.
    

正确加载演示

下面是两种保存和加载的方法,任选其一即可。

import torch
import torch.nn as nn# 定义一个线性模型
class LinearModel(nn.Module):def __init__(self, input_size, output_size):super(LinearModel, self).__init__()self.linear = nn.Linear(input_size, output_size)def forward(self, x):return self.linear(x)# 创建模型实例
model = LinearModel(input_size=10, output_size=1)
print("Model:", model)# 方法 1:保存和加载 state_dict
# 保存模型的 state_dict
torch.save(model.state_dict(), './linear_model_state_dict.pth')# 创建一个新的模型实例
new_model = LinearModel(input_size=10, output_size=1)# 加载 state_dict 到新模型
new_model.load_state_dict(torch.load('./linear_model_state_dict.pth'))# 方法 2:保存和加载整个模型
# 保存整个模型
torch.save(model, './linear_model.pth')# 加载整个模型
loaded_model = torch.load('./linear_model.pth')

拓展阅读

PyTorch 模型保存与加载的三种常用方式

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词