位置: IT常识 - 正文

Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例)

编辑:rootadmin
原力计划Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例) 目录1 什么是nn.Module?2 从一个例子说起3 nn.Module主要方法4 自定义网络一般步骤1 什么是nn.Module?

推荐整理分享Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

在实际应用过程中,经典网络结构(如卷积神经网络)往往不能满足我们的需求,因而大多数时候都需要自定义模型,比如:多输入多输出(MIMO)、多分支模型、跨层连接模型等。nn.Module就是Pytorch中用于自定义模型的核心方法。在Pytorch中,自定义层、自定义块、自定义模型,都是通过继承nn.Module类完成的。

nn.Module的定义如下

class Module(object): def __init__(self): def forward(self, *input): def __call__(self, *input, **kwargs): def parameters(self, recurse=True): def named_parameters(self, prefix='', recurse=True): def children(self): def named_children(self): def modules(self): def named_modules(self, memo=None, prefix=''): def train(self, mode=True): def eval(self): def zero_grad(self):...

注意:自定义网络需要继承nn.Module类,并重点实现上面的构造函数__init__构造函数和forward()这两个方法。

2 从一个例子说起

下面是一个自定义感知机的实例

# 感知机class Perception(nn.Module): def __init__(self, inDim, hidDim, outDim): super(Perception, self).__init__() self.perception = nn.Sequential( nn.Linear(inDim, hidDim), nn.Sigmoid(), nn.Linear(hidDim, outDim), nn.Sigmoid() ) def forward(self, x): return self.perception(x)

测试模块

perception = Perception(5,20,10)print(perception(torch.Tensor([1,2,3,4,5]))) # 自动调用forward()前向传播

其中nn.Sequential()可以序列化封装若干个相连的组件,在希望快速搭建模型且无需考虑中间过程的情形下,推荐使用nn.Sequential()进行局部模块化。

Pytorch深度学习实战3-6:详解网络骨架模块nn.Module(附实例)

从上面的实例可以看出:

一般把网络中的特定结构(如全连接层、卷积层等)以序列的形式放在构造函数__init__()中将模型自定义的各个层的连接关系和数据通路设计放在forward()函数中,以实现模型功能并保证数据结构正常不具有可学习参数的层(如ReLU、dropout、BatchNormanation层等)可并入__init__()内部的某个层,或在forward()函数中进行层间连接

库nn.functional同样提供了大量网络模块和组件,与nn.Module类不同在于其更偏向底层——nn.Module封装了对学习参数的维护,更注重模型结构;nn.functional需要手动指定参数和结构,例如下面线性模型Linear的核心源码,其前向过程仍然调用了底层的nn.functional实现。

class Linear(Module): def __init__(self, in_features: int, out_features: int) -> None: super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) self.bias = Parameter(torch.Tensor(out_features)) def forward(self, input: Tensor) -> Tensor: return F.linear(input, self.weight, self.bias)

一般在设计通过已有nn.Module无法组装的网络结构时,可以调用底层的nn.functional实现;或是存在无需优化学习参数的结构(如损失函数、激活函数等),可以调用nn.functional(即作为单纯函数使用)避免实例化nn.Module,轻量化网络

# 使用nn.Module需要实例化后调用lossFunc = nn.CrossEntropyLoss()loss = lossFunc(output, label)# 使用nn.functional则只作为函数即可loss = F.cross_entropy(output, label)3 nn.Module主要方法

nn.Module的主要属性与方法列举如表所示。

序号属性/方法含义1forward()模型前向传播2train()训练模式3eval()评估模式4named_parameters()返回模型各可学习参数的名称和参数组成的列表5parameters()返回模型各可学习参数组成的列表6children()返回一个迭代器,其中每个元素是Sequential序列类型,可以使用下标索引来进一步获取每一个Sequenrial里面的具体层,比如conv层、dense层等7named_children()返回一个迭代器,其中每个元素是一个二元组,第一元是名称,第二元是该名称对应的层或Sequential序列4 自定义网络一般步骤

自定义网络一般步骤总结如下:

自定义一个继承自Module的类实现构造函数_init__,在其中参数化网络层,比如卷积神经网络的卷积核大小、池化层尺寸,全连接网络的输入输出大小等;实现前向传播forward()接口,定义网络的连接情况或其他运算方式(如向量拼接、向量变维、数据处理等)

下面再给出一个卷积神经网络的实例加深理解

class CNN(nn.Module): def __init__(self): super().__init__() self.convPoolLayer_1 = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5), nn.MaxPool2d(kernel_size=2), nn.ReLU() ) self.convPoolLayer_2 = nn.Sequential( nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5), nn.MaxPool2d(kernel_size=2), nn.ReLU() ) self.fcLayer = nn.Linear(320, 10) def __str__(self) -> str: return "cnn_model" def forward(self, x): batchSize = x.size(0) x = self.convPoolLayer_1(x) x = self.convPoolLayer_2(x) x = x.reshape(batchSize, -1) x = self.fcLayer(x) return x


🔥 更多精彩专栏:

《ROS从入门到精通》《Pytorch深度学习实战》《机器学习强基计划》《运动规划实战精讲》…

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文链接地址:https://www.jiuchutong.com/zhishi/297615.html 转载请保留说明!

上一篇:【Vant Weapp】van-tab 标签页(vant weapp官方文档)

下一篇:Django--基于Python的Web应用框架(django pycharm)

  • 华为荣耀9xpro与荣耀20s对比(华为荣耀9xpro与9x区别)

    华为荣耀9xpro与荣耀20s对比(华为荣耀9xpro与9x区别)

  • 荣耀9x支持多少万快充(荣耀9x支持多少hz刷新率)

    荣耀9x支持多少万快充(荣耀9x支持多少hz刷新率)

  • 手机电话显示未在上网络注册(手机电话显示未知什么意思)

    手机电话显示未在上网络注册(手机电话显示未知什么意思)

  • qq消息免打扰自动取消(qq消息免打扰)

    qq消息免打扰自动取消(qq消息免打扰)

  • 苹果11qq音乐怎么设置桌面歌词(苹果11qq音乐怎么不被中断)

    苹果11qq音乐怎么设置桌面歌词(苹果11qq音乐怎么不被中断)

  • 小米手环4支付宝扫不上(小米手环4支付宝)

    小米手环4支付宝扫不上(小米手环4支付宝)

  • ipodclassic怎么区分几代(辨别ipod机型)

    ipodclassic怎么区分几代(辨别ipod机型)

  • 小米10有没有光学防抖(小米10有没有光感功能)

    小米10有没有光学防抖(小米10有没有光感功能)

  • word正文一般用几号(文档中一般正文用什么字体几号呢)

    word正文一般用几号(文档中一般正文用什么字体几号呢)

  • 怎么把手机内照片传到QQ里(怎么把手机内照片导出)

    怎么把手机内照片传到QQ里(怎么把手机内照片导出)

  • whql支持要不要开

    whql支持要不要开

  • 移动手机号pin码怎么查(移动手机号pin码忘记怎么办)

    移动手机号pin码怎么查(移动手机号pin码忘记怎么办)

  • oppo哪款支持5g网络(oppo支持5g网络的手机有哪些?售价?)

    oppo哪款支持5g网络(oppo支持5g网络的手机有哪些?售价?)

  • htm是什么文件扩展名(htm扩展名)

    htm是什么文件扩展名(htm扩展名)

  • 华为Nova5Pro屏幕产商(华为nova5pro屏幕刷新率是多少)

    华为Nova5Pro屏幕产商(华为nova5pro屏幕刷新率是多少)

  • 手机桌面图标如何放大(手机桌面图标如何设置)

    手机桌面图标如何放大(手机桌面图标如何设置)

  • 小红书被限流需要多久恢复(小红书被限流会怎样)

    小红书被限流需要多久恢复(小红书被限流会怎样)

  • 数据的存储结构包括哪四种(数据的存储结构分为哪两类)

    数据的存储结构包括哪四种(数据的存储结构分为哪两类)

  • 苹果手机拷贝的内容在哪里(苹果手机拷贝的内容怎么粘贴)

    苹果手机拷贝的内容在哪里(苹果手机拷贝的内容怎么粘贴)

  • 收钱码姓名能隐藏吗(微信收钱码如何隐藏姓名)

    收钱码姓名能隐藏吗(微信收钱码如何隐藏姓名)

  • 网易云有听歌识曲吗(网易云听歌识曲记录在哪里)

    网易云有听歌识曲吗(网易云听歌识曲记录在哪里)

  • iphone7怎么打开多任务管理界面(iphone7怎么打开高清通话)

    iphone7怎么打开多任务管理界面(iphone7怎么打开高清通话)

  •  oppo怎么拉黑电话(oppo怎么拉黑电话号码)

    oppo怎么拉黑电话(oppo怎么拉黑电话号码)

  • videoleap怎么调秒数(videoleap怎么调清晰)

    videoleap怎么调秒数(videoleap怎么调清晰)

  • 办公楼出租价格怎么算
  • 减免的企业所得税需要计入应交税费吗
  • 增值税电子普通发票怎么作废
  • 第四季度报表和年度报表一样吗
  • 个人为什么不能寄活鱼
  • 酒销售账务处理
  • 工资3700扣多少社保钱
  • 金税盘开完票后怎么报税一下
  • 企业所得税税率多少
  • 免征增值税企业进项税怎么处理
  • 企业缴纳的社保
  • 合作社 注销
  • 子公司之间可以相互交易吗
  • 医院减免医药费后还可以报保险吗
  • 发票已认证未抵扣怎么办
  • 费用一定计入当期损益吗
  • 营业收入包括其收入吗
  • 药店可以开具专票吗
  • 所有的罚款都不能税前扣除吗
  • 工会开发票有税号吗?
  • 物流公司的保险服务属于什么费用
  • 车辆购置税通过应交税费吗
  • 远程清卡失败怎么办
  • 差旅费中的车票可以抵扣进项税吗
  • 成品油生产企业身份归类管理办法
  • 本期预收的货款属于
  • 现金流量表的编制依据
  • 白银及其制品出自哪里
  • 公司注销登记提交材料规范
  • iphone7如何设置输入法
  • 汇率调整怎么做分录
  • thinkphp框架介绍
  • PHP开发之归档格式
  • 存货发生了减值怎么处理
  • autorun.exe
  • 享受小型微利企业税收优惠的条件
  • batch size 大小
  • php如何删除数组元素
  • 冰川湾国家公园的冰川不止有白色一种
  • 个人所得税通过扣缴义务人申报
  • 公司的内账
  • docker_practice
  • 应交税费是借增还是贷增?
  • linux中搭建web服务器
  • 专项附加扣除中住房租金扣除所指的工作城市范围包括
  • 质保金挂账是否需要发票
  • 增值税无票收入负数预警值
  • 租金收入怎么做分录
  • 委托加工物资的消费税计入成本吗
  • 设备维修费可以抵扣进项税吗
  • 预付款在会计里属于什么
  • 过路费计入差旅费还是车辆
  • 上期留抵税额怎么在账上提现
  • 商业承兑汇票如何开具
  • 抵扣联和发票联的区别
  • mysql清空表内容
  • 图形工具的作用
  • solaris安装教程
  • xp系统的本地连接
  • Ubuntu Eclipse MyEclipse 添加GBK支持 不乱码
  • ubuntu命令行打开火狐浏览器
  • windows用户注册
  • w8系统怎么用
  • software protection延迟启动
  • win8.1中文版下载
  • linux电子邮件
  • win7插u盘电脑没反应怎么回事
  • css优化提高性能的方法有哪些
  • javascript entries
  • 完美解决怠速抖动加油就平稳
  • android零基础入门教程
  • android fragmentation
  • perl常用函数
  • python如何用pi
  • js页面滚动到指定位置
  • 北京亦庄开发区属于哪个区
  • 资源税的税目有7个,其中不包括
  • 2021广东农村医保多少钱一年
  • 个人年收入超过多少不能退税
  • 黄金增值税管理难点
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设