位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 电脑清理流氓软件(电脑清理)(电脑清理流氓软件怎么弄)

    电脑清理流氓软件(电脑清理)(电脑清理流氓软件怎么弄)

  • 荣耀平板v7pro机身尺寸(荣耀平板v7 polo)

    荣耀平板v7pro机身尺寸(荣耀平板v7 polo)

  • 荣耀手机两张卡怎么切换流量(荣耀手机两张卡怎么关闭一张卡)

    荣耀手机两张卡怎么切换流量(荣耀手机两张卡怎么关闭一张卡)

  • 文字转换为表格行数怎么设置(word将文字转换为表格)

    文字转换为表格行数怎么设置(word将文字转换为表格)

  • 为什么qq字符突然不显示了(qq字符为什么会消失)

    为什么qq字符突然不显示了(qq字符为什么会消失)

  • 微店申请退款卖家拒绝(微店申请退款卖家超时)

    微店申请退款卖家拒绝(微店申请退款卖家超时)

  • 相对地址和绝对地址的区别(相对地址和绝对地址的概念)

    相对地址和绝对地址的区别(相对地址和绝对地址的概念)

  • 公众号群发能被谁看到(公众号群发被永久限制怎么解除)

    公众号群发能被谁看到(公众号群发被永久限制怎么解除)

  • 苹果x漏液只能换屏幕吗(iphone x漏液还能用多久)

    苹果x漏液只能换屏幕吗(iphone x漏液还能用多久)

  • ip20防护等级是什么意思(ip20防护等级是12.5 还是12mm)

    ip20防护等级是什么意思(ip20防护等级是12.5 还是12mm)

  • 苹果11怎么改不了微信铃声(苹果11怎么改不了密码)

    苹果11怎么改不了微信铃声(苹果11怎么改不了密码)

  • 5700xt黑屏无信号

    5700xt黑屏无信号

  • 怎么改闲鱼会员名(闲鱼网怎么改会员名)

    怎么改闲鱼会员名(闲鱼网怎么改会员名)

  • 华为手机怎么关闭健康使用手机(华为手机怎么关空调)

    华为手机怎么关闭健康使用手机(华为手机怎么关空调)

  • iphone11突然黑屏关机(iphone11突然黑屏但是能听到声音)

    iphone11突然黑屏关机(iphone11突然黑屏但是能听到声音)

  • win7启动修复无法自动修复此计算机怎么办(win7启动修复无法检测到问题)

    win7启动修复无法自动修复此计算机怎么办(win7启动修复无法检测到问题)

  • wlan常见的拓扑结构有哪些(wlan的网络拓扑)

    wlan常见的拓扑结构有哪些(wlan的网络拓扑)

  • 手机明明有内存为什么内存不足(手机明明有内存为什么下载不了软件)

    手机明明有内存为什么内存不足(手机明明有内存为什么下载不了软件)

  • 华为p30分屏设置(华为p30分屏怎样打开)

    华为p30分屏设置(华为p30分屏怎样打开)

  • 计算机病毒是一种什么(计算机病毒是一个在计算机内部或者在计算机)

    计算机病毒是一种什么(计算机病毒是一个在计算机内部或者在计算机)

  • 红米应用怎么移到sd卡(红米应用怎么移动)

    红米应用怎么移到sd卡(红米应用怎么移动)

  • 荣耀20有面部识别吗(荣耀20有面部解锁)

    荣耀20有面部识别吗(荣耀20有面部解锁)

  • oppo手机为什么不能下载东西怎么办(oppo手机为什么突然开不了机)

    oppo手机为什么不能下载东西怎么办(oppo手机为什么突然开不了机)

  • 绝对音量功能有什么用(绝对音量功能需要开吗)

    绝对音量功能有什么用(绝对音量功能需要开吗)

  • vivoz1和z3x有什么区别(vivo系列z1x和z3哪个好)

    vivoz1和z3x有什么区别(vivo系列z1x和z3哪个好)

  • vivoz3有指纹解锁吗(vivoz3指纹在哪)

    vivoz3有指纹解锁吗(vivoz3指纹在哪)

  • 手机wps分享怎么不变成链接(手机wps分享怎么分享)

    手机wps分享怎么不变成链接(手机wps分享怎么分享)

  • 外籍人员可以在中国工作吗
  • 非税收入专用申报表
  • 返佣账务处理
  • 给退休工人发工资怎么入账
  • 房地产土地增值税优惠政策
  • 企业发票入账冲销流程
  • 现金流量表借款还了流入和流出可以抵消吗
  • 营改增要交增值税吗
  • 预警税负率表
  • 公司向员工借款合法吗
  • 保险合同有啥用
  • 哪些进项税不能加计抵减
  • 电话费发票个人抬头可以税前扣除
  • 员工转入子公司怎么做账
  • 会计报表附表属于会计报表内容吗
  • 境外承包工程出口货物能否办理退税?
  • 生产型企业进出口初申报流程
  • 发放的工资比计提的多怎么办
  • 车辆租赁费交的是什么税
  • 支付固定资产运杂费计入什么科目
  • u盘的内存卡怎么装
  • 设备的折旧率是什么意思
  • thinkphp获取数据库数据
  • win10远程桌面连接不成功
  • 在php中,字符串有哪些表示形式
  • 购入工程物资用于建设厂房,购入后直接领用至工程项目
  • 高温补贴需要缴纳社会保险费吗
  • 不确认收入要结转成本吗
  • php 二维数组
  • 怎么用html做一个收藏夹
  • 46 个非常有用的成语
  • PHP+Mysql+Ajax实现淘宝客服或阿里旺旺聊天功能(前台页面)
  • php遍历对象
  • php扫二维码
  • PHP+MySql+jQuery实现的"顶"和"踩"投票功能
  • php验证系统
  • 小型微利企业减按25%计算应纳税所得额
  • 为什么没缴税
  • python中列表的索引用法
  • 会影响当期损益的科目有
  • 年终奖的个税税率
  • 月收入一万该怎么说
  • 处置长期股权投资其他综合收益结转
  • access日期时间格式怎么修改
  • 出口免抵额需要加交付地方附加税吗
  • 商品流通企业税费按征收对象可分为
  • 职工教育经费产生的差异
  • 企业关联业务往来情况怎么申报
  • 跨年度冲红字发票账务处理
  • 土地出让金抵减增值税
  • 银行承兑汇票进行贴现的会计分录
  • 360天认证期是什么时候发布的
  • 收入的确认条件包括
  • 根据《增值税暂行条例》的规定,适用9
  • 公司现金支票取钱需要带什么资料
  • safari 快捷键
  • win7系统关机很慢什么原因
  • xp系统操作全程图解
  • linux操作系统的安装
  • 查看win8.1版本
  • centos pptpd
  • linux中的vi编辑器一般有哪三个模式
  • cocos2dx怎么用啊
  • javascript函数的作用
  • extjs 为某个事件设置拦截器
  • nodejs示例
  • perl如何使用
  • node.js速成
  • nodejs定义数组
  • shell实现自动ssh
  • JavaScript数据类型
  • javascript下拉列表怎么做
  • 电子发票怎么汇总清卡
  • 保险是不是跟车走
  • 韩国快递关税
  • 出租车开的发票如何查询校验码?
  • 税务机关代收工会经费手续费
  • 无锡市国税
  • 南京税务举报
  • 广东省深圳市地图最新版
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设