位置: IT常识 - 正文

【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

编辑:rootadmin
【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

推荐整理分享【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

新年新气象,兄弟们新年快乐。撒花!!!

之前我们的项目已经讲过了常见的4种深度学习任务(当然还有一些没有接触到的,例如GAN和今年大红的Transformer),今天这个blog我们就来谈谈一谈常见的损失函数。损失函数的更新也是非常的快,各位大佬的想法也是层出不穷,我们站在巨人的肩膀上,就可以看的更远,走的更远。

1.BCELoss

BCELoss又叫二分类交叉熵损失,顾名思义,它是用来做二分类的损失函数,我们先来看看BCELoss的公式。

其中pt---模型预测值,target---标签值, w---权重值,一般是1

上面这个公式是单个样本的,当一个batch有N个样本时

这么说是不是显得很苍白无力,所以我们来一个例子吧,我们先创建一个pt和target

torch.manual_seed(0)pt = torch.rand(2, 3)target = torch.tensor([[0., 0., 1.], [1., 0., 0.]])print(pt)print(target)

pt我用的随机数代替的,target一般是0或者1,我们print一下,看看目前的数值是多少

这里的torch.rand(2, 3)中的2代表2个样本,3代表每个样本是一个1*3的向量

好了,我们来挨个计算:

pt的第一行第一列的值是 0.4963,它对应的标签target的第一行第一列的值是0,所以求根据刚才的公式L(pt,target) = -w*(target * ln(pt) + (1-target) * ln(1-pt)),w一般取1

L = -1 * (0*ln(0.4963)+1*ln(1-0.4963)) = -ln(1-0.4963) = 0.685774426230532 ≈ 0.6857

pt的第一行第二列的值是 0.7682,它对应的标签target的第一行第一列的值是0

L = -1 * (0*ln(0.7682)+1*ln(1-0.7682)) = -ln(1-0.7682) = 1.46188034807648  ≈ 1.4620

pt的第一行第三列的值是 0.0885,它对应的标签target的第一行第一列的值是1

L = -1 * (1*ln(0.0885)+0*ln(1-0.0885)) = -ln(0.0885) = -2.424752726968253  ≈ 2.4250

接下去,我就不算了,留个兄弟们来算,我们用代码来验证一下算对了没有吧

def com(x, y): loss = -(y * torch.log(x) + (1 - y) * torch.log(1 - x)) return loss losss = com(pt, target) print(losss)

此时x就是pt,y也就是target,值得注意的是torch.log = ln,它不是真的log,看看计算结果吧

看第一行,和我们刚刚的计算结果完全吻合,确实是这么算的,没跑了

别忘了,同时每一个样本也要求一下平均值

第一个样本的平均值是 (0.6857 + 1.4620 + 2.4250)/ 3 = 1.524233333333333333333

第二个样本的平均值是 (2.0247 + 0.3673 + 1.0053)/ 3 = 1.132433333333333333333

【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

根据公式:

 所以loss = (1.524233333333333333333 + 1.132433333333333333333)/ 2 ≈1.328333

 上代码看看是不是这么回事吧

torch.manual_seed(0) pt = torch.rand(2, 3) target = torch.tensor([[0., 0., 1.], [1., 0., 0.]]) print('pt:',pt) print('target:',target) def com(x, y): loss = -(y * torch.log(x) + (1 - y) * torch.log(1 - x)) return loss losss = com(pt, target) print(losss) losss = torch.mean(com(pt, target)) print('总loss:',losss)

看看结果

 不错,一模一样,算对了。但是你肯定有疑问了,你这是你自己手算的,代码也是你自己写的,你只能证明你的计算和你的代码是对上了,怎么证明真正的和BCELoss对上了,那我们请出Pytorch的nn.BCELoss来看看结果吧

torch.manual_seed(0)pt = torch.rand(2, 3)target = torch.tensor([[0., 0., 1.], [1., 0., 0.]])print('pt:',pt)print('target:',target)loss = nn.BCELoss()print('pytorch loss:',loss(pt, target))

怎么样,我是不是算对了。

值得注意的是,在用BCELoss的时候,要记得先经过一个sigmoid或者softmax,以保证pt是0-1之间的。当然了,pytorch不可能想不到这个啊,所以它还提供了一个函数nn.BCEWithLogitsLoss()他会自动进行sigmoid操作。棒棒的!

2.带权重的BCELoss

先看看BCELoss的公式,w就是所谓的权重

 torch.nn.BCELoss()中,其实提供了一个weight的参数

我们要保持weight的形状和维度与target一致就可以了。

于是我手写一个带权重BCELoss,上代码

class BCE_WITH_WEIGHT(torch.nn.Module): def __init__(self, alpha=0.25, reduction='mean'): super(BCE_WITH_WEIGHT, self).__init__() self.alpha = alpha self.reduction = reduction def forward(self, predict, target): pt = predict loss = -((1-self.alpha) * target * torch.log(pt+1e-5) + self.alpha * (1 - target) * torch.log(1 - pt+1e-5)) if self.reduction == 'mean': loss = torch.mean(loss) elif self.reduction == 'sum': loss = torch.sum(loss) return loss

 核心带代码是

loss = -((1-self.alpha) * target * torch.log(pt+1e-5) + self.alpha * (1 - target) * torch.log(1 - pt+1e-5))

alpha就是权重了,一般很多时候,正负样本是不平衡的,如果不加入权重,网络训练的时候,训练的关注的重点就跑到了样本多的那一类样本上去,对样本少的就不公平了,所以为了维护世界和平,贯彻爱与真实的邪恶,可爱又迷人的反派角色,带权重的损失函数就出现了。

大家可以看到,我在有一个地方是torch.log(pt+1e-5),1e-5的意思就是10的-5次方,为什么要加入1e-5,这个跟ln函数有关系,因为ln(0) = -无穷大,这样损失就爆炸了,训练就会出错误,所以默认就把它加上了。

3.BCE版本的Focal_Loss

FocalLoss的公式

此时的pt就是刚刚的那个pt了,此时的pt就是刚刚我们的BCEloss的结果了 

先上代码看看吧

class BCEFocalLoss(torch.nn.Module): def __init__(self, gamma=2, alpha=0.25, reduction='mean'): super(BCEFocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha self.reduction = reduction def forward(self, predict, target): pt = predict loss = - ((1 - self.alpha) * ((1 - pt+1e-5) ** self.gamma) * (target * torch.log(pt+1e-5)) + self.alpha * ( (pt++1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt+1e-5))) if self.reduction == 'mean': loss = torch.mean(loss) elif self.reduction == 'sum': loss = torch.sum(loss) return loss

核心代码:

loss = - ((1 - self.alpha) * ((1 - pt+1e-5) ** self.gamma) * (target * torch.log(pt+1e-5)) + self.alpha * ( (pt+1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt+1e-5)))

Focalloss的目前不仅是为了控制样本不平衡的现象,还有个作用就是,让网络着重训练难样本。

好了,BCE讲的差不多了,讲的不对的地方,欢迎大家指出。

至此,敬礼,salute!!!

老规矩,上咩咩狗

本文链接地址:https://www.jiuchutong.com/zhishi/299220.html 转载请保留说明!

上一篇:element-ui动态表单和验证(elementui动态表单数据回显)

下一篇:【HTML+CSS】实现网页的导航栏和下拉菜单(html cssjs)

  • ppt的占位符在哪里(ppt2016占位符在哪里)

    ppt的占位符在哪里(ppt2016占位符在哪里)

  • 怎么把电脑图标变小(怎么把电脑图标隐藏)

    怎么把电脑图标变小(怎么把电脑图标隐藏)

  • 微信面对面建群无法获取位置信息(微信面对面建群怎么设置管理员)

    微信面对面建群无法获取位置信息(微信面对面建群怎么设置管理员)

  • 7plus更新ios12好用吗(ios 12正式版苹果7p更新)

    7plus更新ios12好用吗(ios 12正式版苹果7p更新)

  • 手机为什么老显示请检查网络设置(手机为什么老显示关机重启)

    手机为什么老显示请检查网络设置(手机为什么老显示关机重启)

  • oppoa5微信视频通话怎么美颜

    oppoa5微信视频通话怎么美颜

  • 苹果ld邮件是什么意思(苹果id的邮件格式怎么弄)

    苹果ld邮件是什么意思(苹果id的邮件格式怎么弄)

  • 拼多多上评价了为什么不显示呢(拼多多做评价)

    拼多多上评价了为什么不显示呢(拼多多做评价)

  • 卡贴机有信号但激活sim无效(卡贴机有信号但没4g)

    卡贴机有信号但激活sim无效(卡贴机有信号但没4g)

  • vivox20微信自带美颜吗(vivo手机微信2怎么弄出来)

    vivox20微信自带美颜吗(vivo手机微信2怎么弄出来)

  • 笔记本机械硬盘和台式一样吗(笔记本机械硬盘突然消失只剩固态)

    笔记本机械硬盘和台式一样吗(笔记本机械硬盘突然消失只剩固态)

  • 字符型常量的表达形式(字符型常量用什么定义)

    字符型常量的表达形式(字符型常量用什么定义)

  • word添加页码怎么设置(word添加页码怎么都是1)

    word添加页码怎么设置(word添加页码怎么都是1)

  • pr添加字幕如何移动(pr添加字幕如何保持一个大小)

    pr添加字幕如何移动(pr添加字幕如何保持一个大小)

  • 哈罗单车能远程开锁吗(哈罗单车能远程解锁吗)

    哈罗单车能远程开锁吗(哈罗单车能远程解锁吗)

  • vivox23有没有空调遥控器(vivox23有空调遥控功能吗)

    vivox23有没有空调遥控器(vivox23有空调遥控功能吗)

  • ios掌阅怎么导入本地(苹果掌阅如何导入)

    ios掌阅怎么导入本地(苹果掌阅如何导入)

  • 清空聊天记录能恢复吗(清空聊天记录能释放内存吗)

    清空聊天记录能恢复吗(清空聊天记录能释放内存吗)

  • 苹果xr摄像头用保护吗(苹果xr摄像头用什么玻璃)

    苹果xr摄像头用保护吗(苹果xr摄像头用什么玻璃)

  • 怎样恢复网络状况不佳(红枣树任妙音)

    怎样恢复网络状况不佳(红枣树任妙音)

  • 挖孔屏是什么意思(挖孔屏啥意思)

    挖孔屏是什么意思(挖孔屏啥意思)

  • 微信自动收红包怎么设置(微信自动收红包功能是怎么弄的)

    微信自动收红包怎么设置(微信自动收红包功能是怎么弄的)

  • 土豆视频如何投影电视(土豆视频如何投屏)

    土豆视频如何投影电视(土豆视频如何投屏)

  • 电脑看视频黑屏但有声音(为什么b站电脑看视频黑屏)

    电脑看视频黑屏但有声音(为什么b站电脑看视频黑屏)

  • 斯科默岛白玉草丛中的海鹦,威尔士彭布罗克郡 (© Ross Hoddinott/Minden Pictures)

    斯科默岛白玉草丛中的海鹦,威尔士彭布罗克郡 (© Ross Hoddinott/Minden Pictures)

  • 小规模纳税人增值税起征点
  • 上月少计提的个税本月怎么调整
  • 递延所得税资产怎么计算
  • 什么是本期应纳税所得额
  • 代垫水电费增值税
  • 定期定额征收如何办理税费认定
  • 报废固定资产产生的净损益属于利得吗
  • 预支差旅费属于什么凭证
  • 远期汇票分为哪几种
  • 商业保险可以抵扣增值税吗
  • 增值税免税和即税的区别
  • 挂账留底税额如何抵扣?
  • 建筑公司收取的管理费如何入账
  • 接受固定资产投资的企业,应该按照投资合同
  • 业务宣传及广告费超比例
  • 工资低于3000要申报吗
  • 增值税抵扣联是什么意思
  • 一般纳税人所得税率是多少
  • 购车时服务费用怎么算
  • 公司代扣的社保怎么做分录
  • 所得税申报表中利润总额是怎样算出来的
  • 固定资产清理的借方
  • 什么是销售利润率和成本利润率
  • 代开专票交的城建税怎么申报附加税
  • 建筑工程账务处理是在哪个阶段
  • 合并报表六大抵消分录通俗理解
  • win11预览版选哪个
  • 生产车间报销费用
  • 以太网默认网关怎么查看
  • 金融企业计提资产减值准备是根据会计核算的
  • php写一个函数,算出两个文件的相对路径
  • 损益类科目怎么结转
  • 资产评估增值是什么意思
  • vue3响应式丢失
  • async/await原理
  • 小规模纳税人出租不动产免征增值税
  • Vite4+Pinia2+vue-router4+ElmentPlus搭建Vue3项目(组件、图标等按需引入)[保姆级]
  • lsmod命令结果详解
  • 公允价值变动损益
  • 公司与公司往来账表格怎么制作
  • 社保退回的款怎么继承
  • 增值税申报系统登录密码
  • 补开上年发票的税务处理要怎么做?
  • mysql 使用索引
  • 建筑工程分包案例
  • 进口关税的计算是以什么为基础
  • 哪些情况需要开具无违法犯罪证明
  • 前期物业管理阶段的工作有哪些
  • 财务风险有什么类别
  • 印花税如何计提缴纳
  • 已开票未收款如何销往来账
  • 股东转公户的钱叫什么
  • 一次性付款的优势
  • ubuntu 16.04.6安装教程
  • centos bz
  • mac系统文件名
  • mac怎么修改图片格式jpg
  • win7安装ubuntu20.10
  • macos使用方法
  • nfs网络安装
  • windows7的开机启动项在哪里
  • Android游戏开发教程
  • 批处理常用命令总结
  • jquery mobile app案例
  • python的基本数值类型
  • 有关javascript的书
  • shell错误日志输出
  • bootstrap需要学多久
  • 工商与税务合并了吗
  • 企业科研经费管理制度
  • 开票信息电子版怎么做
  • 福建省税务局 电子
  • 关于地税代收工会经费工作实施办法
  • 中国的消费税率是多少
  • 税务局文化建设实施方案
  • 普惠性和非普惠的区别
  • 税务迁出需要哪些手续2020年
  • 为什么有的企业在企查查上查不到
  • 江苏电子口岸卡邮寄大概需要多久
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设