位置: IT常识 - 正文

pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层)

编辑:rootadmin
pytorch对网络层的增,删, 改, 修改预训练模型结构 #下载模型参数model.load_state_dict(torch.load('model.pth'))#再加载网络的参数torch.load('model.pth')是获得网络参数

推荐整理分享pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch怎么训练网络,pytorch 网络,pytorch网络搭建,pytorch定义网络,pytorch cnn网络,pytorch bp网络,pytorch输出网络结构,pytorch cnn网络,内容如对您有帮助,希望把文章链接给更多的朋友!

1.我们使用vgg11网络做示例, 看一下网络结构:

加载本地的模型:

vgg16 = models.vgg16(pretrained=False)#打印出预训练模型的参数vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))

加载库中的模型

import torchimport torch.nn as nnfrom torchvision import modelsnet = models.vgg11(pretrained=True)print(net)

1)(1). 在网络中添加一层:

net网络是一个树型结构, net下面有三个结点,分别是(features, avgpoll, classifier), 我们先在features结点添加一层’lastlayer'层

net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))print(net) 1)(2). 在classifier结点添加一个线性层:net.classifier.add_module('Linear', nn.Linear(1000, 10))print(net)

2)(1)修改网络中的某一层(features 结点举例):net.features[8] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))print(net)

 2)(2)修改网络中的某一层(classifier结点举例):net.classifier[6] = nn.Linear(1000, 5)print(net)pytorch对网络层的增,删, 改, 修改预训练模型结构(pytorch自定义网络层)

注意: 这里我尝试对Linear这一层进行更新, 但是Linear名字是字符串, 提取不出来,所以应该在之前添加网络时候, 名字不要取字符串, 否则会报错  ‘  'str' object cannot be interpreted as an integer’。 

 3)(1) 网络层的删除(features举例) classifier结点的操作相同。

直接使用nn.Sequential()对改层设置为空即可

net.features[13] = nn.Sequential()print(net)

 4)冻结网络中某些层 (直接使该层的requires_grad = False)即可, 这样在反向传播的时候,不会更新该层的参数#冻结指定层的预训练参数:net.feature[26].weight.requires_grad = False5). 第二种对网络结构的操作方法:net.features = nn.Sequential(*list(net.features.children())[:-4])

可以看到后面4层被去除了, 就是说可以使用列表切片的方法来删除网络层

net.classifier 对应 net.classifier.children()

net.features 对应 net.features.children()

  1. 先加载网络结构

自己的模型, model的类要有定义才可以, 如果在其他.py文件中,可以导入文件,然后用文件中的类实例化对象。model = torch.load(PATH)

 2.再加载网络参数

#下载模型参数

model.load_state_dict(torch.load('model.pth'))#再加载网络的参数torch.load('model.pth')是获得网络参数
本文链接地址:https://www.jiuchutong.com/zhishi/299377.html 转载请保留说明!

上一篇:vue 项目兼容 IE 浏览器(vue项目兼容ie9以上浏览器)

下一篇:【CSS】CSS 背景设置 ⑨ ( 背景半透明设置 )(css背景图)

  • 安装服务费税率是多少
  • 餐饮行业的成本率在多少才正常
  • 如何做好存货管理,从哪方面去做
  • 小规模纳税人个税是月报还是季报
  • 购固定资产需要交什么税
  • 固定资产怎么确定是否减值
  • 什么是进口增值税率
  • 失业保险费返还后是给单位还是给员工
  • 怎么查自己是否
  • 企业所得税年度申报
  • 代办汽车过户手续
  • 辅导期一般纳税人什么意思
  • 转让土地使用权会计分录怎么做
  • 销售退回所得税差异怎么处理
  • 公司预支了然后来报销的帐怎么做?
  • 招标代理专家费由谁支付
  • 销货成本销货成本是什么类账户
  • 固定资产清理会计处理例题
  • 差额税和增值税怎么算
  • 解除劳动关系补偿标准
  • 工会经费的计税依据包括单位社保吗
  • windows10如何更改时间
  • 公司投资理财产品
  • 欠款利息收入如何入账
  • 暂估入库的处理方式有哪三种
  • 增值税进项发票不够抵扣怎么办
  • 上市公司股票如何套现
  • 单位办事人员
  • 房屋租赁进项税
  • 华硕win10笔记本如何恢复出厂设置
  • cortana小娜可以卸载吗
  • 购买性支出和转移性支出的本质区别
  • sdhc 速度
  • 最好卖的游戏机排行榜
  • php echo语句
  • 外币存款业务
  • elipse左侧菜单栏显示
  • php 二进制转十六进制
  • javascript对象有哪些
  • php中input的用法
  • 季报弥补亏损,财报怎么填
  • 补交之前年度税款怎么调账
  • 没有发票以及收款怎么办
  • 资产减值损失属于什么科目借贷方向
  • 在什么情况下要切除子宫
  • 企业所得税季报资产总额季初季末
  • 房产税是按不含增值税计提吗
  • SqlServer 2005 T-SQL Query 学习笔记(2)
  • 服务费的增值税可以抵扣吗
  • 幼儿园财务科目明细表
  • 其他收益会计科目核算什么
  • 已认证发票退回说明模板
  • 供应商奖惩制度具体办法
  • 美国支票上的收款人地址不对怎么办
  • sql语句的基本语法
  • windows8crazy error
  • Fedora Core 5.0 安装教程,菜鸟图文教程(linux text)
  • mac画图的app叫什么
  • virtualbox虚拟机菜单找不到了
  • win8怎么打开管理员命令提示符
  • win8网络连接受限怎么处理
  • linux内核命名
  • win10系统应用和功能中不能卸载
  • win10怎么将桌面图标变小
  • linux列操作
  • 滤镜调试
  • Python 正则表达式实现计算器功能
  • 数据结构分析时间复杂度
  • jquery的动画效果
  • jquery左右选择框
  • jquery td
  • Eclipse ctrl+shift+r
  • 请问在javascript程序中
  • js键盘事件有哪些?各自的作用如何
  • 收藏一些不常用的图片
  • 2020年保安证取消了吗
  • 工行网银如何申请发票
  • 如何办理委托银行卡业务
  • 深圳税务局实名注册
  • 税务变更表
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设