重庆分公司,新征程启航

为企业提供网站建设、域名注册、服务器等服务

pytorch加载模型遇到的问题怎么解决

这篇文章主要讲解了“pytorch加载模型遇到的问题怎么解决”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“pytorch加载模型遇到的问题怎么解决”吧!

我们提供的服务有:网站设计、成都网站设计、微信公众号开发、网站优化、网站认证、顺平ssl等。为近千家企事业单位解决了网站和推广的问题。提供周到的售前咨询和贴心的售后服务,是有科学管理、有技术的顺平网站制作公司

1. 查看网络参数

pretrained_dict1 = torch.load(model_path2, map_location='cpu')['state_dict']#预训练文件后缀是.tarpretrained_dict2 = torch.load(model_path3)#预训练文件后缀是.pth#1.查看预训练网络参数for key ,value in pretrained_dict1.items():#pretrained_dict1,pretrained_dict2就是上面的东西count+=1print(key)print(count)#2.查看model的网络参数for key ,value in model.state_dict.items():print(key,value)

2. 加载模型遇到的两大问题

1. 模型的键不匹配

以下两代码,解决了键不匹配问题,一个是删除键的某一部分,一是添加键的某一部分

例:
下面的错误是因为模型的model.state_dict().items()的键是conv1.weight,预训练的键是module.conv1.weight,导致不匹配。所以下面的代码是让module. 去掉
pytorch加载模型遇到的问题怎么解决

1.删除键的头部
pretrained_dict = {
   
   
   k.replace('module.', ''): v for k, v in pretrained_dict2.items()}

当然有时候自己model的键需要改进,如下

2.补齐键的头部
checkpoint={
   
   
   'module.'+k:v for k,v in pretrained_dict.items()}

2. 预训练模型和自己的model长度不一样

# 删除pretrained_dict.items()中model所没有的东西model_dict = model.state_dict()pretrained_dict = {
   
   
   k: v for k, v in pretrained_dict.items() if k in model_dict}  # 只保留预训练模型中,自己建的model有的参数model_dict.update(pretrained_dict)  # 将预训练的值,更新到自己模型的dict中model.load_state_dict(model_dict)  # model加载dict中的数据,更新网络的初始值

3. 通过查看加载参数,看是否加载成功

for value1 ,value2 in zip(checkpoint.items(), model.state_dict().items()):print(value1,value2)

如下所示,model的参数和预训练的参数是一样的
pytorch加载模型遇到的问题怎么解决

4. 案例

(这里处理的只是针对本人的model加载的情况,要想正确加载,还需遵守上面3步)

    def load_param(self, model_path):#这里的self就是modelmodel_dict = self.state_dict()pretrained_dict = torch.load(model_path)#这里model_path的后缀是.pth可直接读取# pretrained_dict = {k.replace('module.', ''): v for k, v in#                    pretrained_dict.items()}  # 因为pretrained_dict得到module.conv1.weight,但是自己建的model无module,只是conv1.weight,所以改写下pretrained_dict = {
   
   
   k: v for k, v in pretrained_dict.items() if k in model_dict}  # 只保留预训练模型中,自己建的model有的参数model_dict.update(pretrained_dict)  # 将预训练的值,更新到自己模型的dict中self.load_state_dict(model_dict)  # model加载dict中的数据,更新网络的初始值

感谢各位的阅读,以上就是“pytorch加载模型遇到的问题怎么解决”的内容了,经过本文的学习后,相信大家对pytorch加载模型遇到的问题怎么解决这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是创新互联,小编将为大家推送更多相关知识点的文章,欢迎关注!


网站题目:pytorch加载模型遇到的问题怎么解决
新闻来源:http://cqcxhl.cn/article/jspgoi.html

其他资讯

在线咨询
服务热线
服务热线:028-86922220
TOP