广告
返回顶部
首页 > 资讯 > 后端开发 > Python >PyTorch加载模型model.load_state_dict()问题及解决
  • 231
分享到

PyTorch加载模型model.load_state_dict()问题及解决

PyTorch加载模型model.load_state_dict()PyTorch模型 2023-02-03 15:02:37 231人浏览 八月长安

Python 官方文档:入门教程 => 点击学习

摘要

目录PyTorch加载模型model.load_state_dict()问题1. 对load的模型创建新的字典2. 直接用空白''代替'module.'

PyTorch加载模型model.load_state_dict()问题

希望将训练好的模型加载到新的网络上。

如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题。

Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不对应。

表明了加载过程中,期望获得的key值为feature...,而不是module.features....。

这是由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的。

You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.

解决上面的问题有三个办法: 

1. 对load的模型创建新的字典

去掉不需要的key值"module".

# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt')  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
    new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。 
# load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。

2. 直接用空白''代替'module.'

model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})
 
# 相当于用''代替'module.'。
#直接使得需要的键名等于期望的键名。

3. 最简单的方法

加载模型之后,接着将模型DataParallel,此时就可以load_state_dict。

如果有多个GPU,将模型并行化,用DataParallel来操作。

这个过程会将key值加一个"module. ***"。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
    print(k) #只打印key值,不打印具体参数。

4. 总结

从出错显示的问题就可以看出,key值不匹配,因此可以选择多种方法,将模型参数加载进去。

这个方法通常会在load_state_dict过程中遇到。将训练好的一个网络参数,移植到另外一个网络上面,继续训练。

或者将训练好的网络checkpoint加载进模型,再次进行训练。可以打印出model state_dict来看出两者的差别。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
    print(k) #只打印key值,不打印具体参数。

features.0.0.weight   
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked

model = VGGNet()
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load weights to resume from checkpoint。
# print('**************************************')
# 这个方法能够直接打印出你保存的checkpoint的键和值。
for k,v in checkpoint.items():
    print(k) 
print("*****************************************")
 

输出结果为:

module.features.0.0.weight",

"module.features.0.1.weight",

"module.features.0.1.bias

可以看出不匹配,模型的参数中,key值不同,多了module。

PS: 追加

在移植参数的过程中,对于出现 .total_ops和.total_params结尾的参数,可参考以下代码:

from collections import OrderedDict
checkpoint = torch.load(
    pretrained_model_file_path,
    map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    if not k.endswith('total_ops') and not k.endswith('total_params'):
        name = k[7:]
        new_state_dict[name] = v

最后

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: PyTorch加载模型model.load_state_dict()问题及解决

本文链接: https://www.lsjlt.com/news/194142.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • PyTorch加载模型model.load_state_dict()问题及解决
    目录PyTorch加载模型model.load_state_dict()问题1. 对load的模型创建新的字典2. 直接用空白''代替'module.'...
    99+
    2023-02-03
    PyTorch加载模型 model.load_state_dict() PyTorch模型
  • pytorch模型保存与加载问题怎么解决
    这篇“pytorch模型保存与加载问题怎么解决”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“pytorch模型保存与加载问题...
    99+
    2023-07-04
  • pytorch模型保存与加载中的一些问题实战记录
    目录前言一、torch中模型保存和加载的方式1、模型参数和模型结构保存和加载2、只保存模型的参数和加载——这种方式比较安全,但是比较稍微麻烦一点点二、torc...
    99+
    2022-11-11
  • PyTorch模型保存与加载实例详解
    目录一个简单的例子保存/加载 state_dict(推荐)保存/加载整个模型保存加载用于推理的常规Checkpoint/或继续训练保存多个模型到一个文件使用其他模型来预热当前模型跨设...
    99+
    2022-11-10
  • pytorch加载预训练模型与自己模型不匹配的解决方案
    pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。 两个有序字典找不同 模型的参数和pth文件的参数...
    99+
    2022-11-12
  • 在pytorch中复制模型时出现问题如何解决
    在pytorch中复制模型时出现问题如何解决?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。直接使用model2=model1会出现当更新model2时,model1的权重也...
    99+
    2023-06-06
  • pytorch网络模型构建场景的问题如何解决
    今天小编给大家分享一下pytorch网络模型构建场景的问题如何解决的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。网络模型构建...
    99+
    2023-07-05
  • 模型的保存加载、模型微调、GPU使用及Pytorch常见报错
    序列化与反序列化 序列化就是说内存中的某一个对象保存到硬盘当中,以二进制序列的形式存储下来,这就是一个序列化的过程。 而反序列化,就是将硬盘中存储的二进制的数,反序列化到内存当中,得到一个相应的对象,这样就可以再次使用这个模型了。 序列化和...
    99+
    2023-08-30
    pytorch 人工智能 python
  • 解决JPA @OneToMany及懒加载无效的问题
    目录JPA @OneToMany及懒加载无效@OneToMany小结一下吧实现JPA的懒加载和无外键例如转换时使用JPA @OneToMany及懒加载无效 @OneToOne @Ma...
    99+
    2022-11-12
  • pytorch模型的保存加载与续训练详解
    目录前面模型保存与加载方式1方式2方式3总结前面 最近,看到不少小伙伴问pytorch如何保存和加载模型,其实这部分pytorch官网介绍的也是很清楚的,感兴趣的点击了解详情 但是肯...
    99+
    2022-11-13
    pytorch模型保存加载训练 pytorch 模型训练
  • torch.hub.load 加载本地模型(已解决)
    背景 运行网上的项目,有时会卡住或者超时,原因是 torch.hub.load 默认会去网上找模型,有时会因为网络问题而报错 解决方法 不让 torch.hub.load 联网下载模型,改为 torc...
    99+
    2023-09-06
    python 深度学习 开发语言
  • vue中图片加载不出来的问题及解决
    目录一、项目打包完成后,打开整体空白1、路径问题原因解决办法2、vue-router的history模式打包后界面空白二、assets目录下图片加载不出来vue-cli的assets...
    99+
    2022-11-13
  • Linux下如何解决IPV6模块加载失败问题
    这篇文章主要为大家展示了“Linux下如何解决IPV6模块加载失败问题”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Linux下如何解决IPV6模块加载失败问题”这篇文章吧。同事一个SUSE L...
    99+
    2023-06-27
  • Hibernate Lazy加载问题怎么解决
    这篇文章主要讲解了“Hibernate Lazy加载问题怎么解决”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Hibernate Lazy加载问题怎么解决”吧!Hbm文件<bean i...
    99+
    2023-06-03
  • pytorch部署到jupyter中的问题及解决方案
    目录pytorch部署到jupyter中两种解决方案pytorch部署到jupyter中 在安装Aconda的同时,会将jupyter notebook一起安装,不过这里的jupy...
    99+
    2022-11-11
  • Android webView加载数据时内存溢出问题及解决
    目录Android webView加载数据时内存溢出Android内存问题 (内存溢出 内存泄漏 内存抖动)总结Android webView加载数据时内存溢出 今天使用webVie...
    99+
    2022-12-08
    Android webView webView加载数据 webView内存溢出
  • PyTorch深度学习模型的保存和加载流程详解
    一、模型参数的保存和加载  torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经...
    99+
    2022-11-12
  • 如何解决预加载InstantClick的问题
    这篇文章主要介绍如何解决预加载InstantClick的问题,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!在改造的时候代码高亮没法执行,准确的说是只执行一次,第二次就不执行了。所以发...
    99+
    2022-10-19
  • PyTorch 编写代码遇到的问题及解决方案
    PyTorch编写代码遇到的问题 错误提示:no module named xxx xxx为自定义文件夹的名字 因为搜索不到,所以将当前路径加入到包的搜索目录 解决方法: i...
    99+
    2022-11-12
  • laravel容器延迟加载及auth扩展问题怎么解决
    今天小编给大家分享一下laravel容器延迟加载及auth扩展问题怎么解决的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。发现...
    99+
    2023-07-04
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作