位置: IT常识 - 正文

对Transformer中Add&Norm层的理解(transformer中的参数)

编辑:rootadmin
对Transformer中Add&Norm层的理解 对Add&Norm层的理解Add操作Norm操作Add操作

推荐整理分享对Transformer中Add&Norm层的理解(transformer中的参数),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:transform方法会对产生的标量值进行,transform方法会对产生的标量值进行,transformer add norm,transformer中的参数,transformer中的参数,transformer方法,transformer中的参数,transformer中的参数,内容如对您有帮助,希望把文章链接给更多的朋友!

首先我们还是先来回顾一下Transformer的结构:Transformer结构主要分为两大部分,一是Encoder层结构,另一个则是Decoder层结构,Encoder 的输入由 Input Embedding 和 Positional Embedding 求和输入Multi-Head-Attention,再通过Feed Forward进行输出。

由下图可以看出:在Encoder层和Decoder层中都用到了Add&Norm操作,即残差连接和层归一化操作。 什么是残差连接呢?残差连接就是把网络的输入和输出相加,即网络的输出为F(x)+x,在网络结构比较深的时候,网络梯度反向传播更新参数时,容易造成梯度消失的问题,但是如果每层的输出都加上一个x的时候,就变成了F(x)+x,对x求导结果为1,所以就相当于每一层求导时都加上了一个常数项‘1’,有效解决了梯度消失问题。

Norm操作

首先要明白Norm做了一件什么事,从刚开始接触Transformer开始,我认为所谓的Norm就是BatchNorm,但是有一天我看到了这篇文章,才明白了Norm是什么。

假设我们输入的词向量的形状是(2,3,4),2为批次(batch),3为句子长度,4为词向量的维度,生成以下数据:

[[w11, w12, w13, w14], [w21, w22, w23, w24], [w31, w32, w33, w34][w41, w42, w43, w44], [w51, w52, w53, w54], [w61, w62, w63, w64]]

如果是在做BatchNorm(BN)的话,其计算过程如下:BN1=(w11+w12+w13+w14+w41+ w42+w43+w44)/8,同理会得到BN2和BN3,最终得到[BN1,BN2,BN3] 3个mean

对Transformer中Add&Norm层的理解(transformer中的参数)

如果是在做LayerNorm(LN)的话,则会进如下计算:LN1=(w11+w12+w13+w14+w21+ w22+w23+w24+w31+w32+w33+w34)/12,同理会得到LN2,最终得到[LN1,LN2]两个mean

如果是在做InstanceNorm(IN)的话,则会进如下计算:IN1=(w11+w12+w13+w14)/4,同理会得到IN2,IN3,IN4,IN5,IN6,六个mean,[[IN1,IN2,IN3],[IN4,IN5,IN6]] 下图完美的揭示了,这几种Norm 接下来我们来看一下Transformer中的Norm:首先生成[2,3,4]形状的数据,使用原始的编码方式进行编码:

import torchfrom torch.nn import InstanceNorm2drandom_seed = 123torch.manual_seed(random_seed)batch_size, seq_size, dim = 2, 3, 4embedding = torch.randn(batch_size, seq_size, dim)layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)print("y: ", layer_norm(embedding))

输出:

y: tensor([[[ 1.5524, 0.0155, -0.3596, -1.2083], [ 0.5851, 1.3263, -0.7660, -1.1453], [ 0.2864, 0.0185, 1.2388, -1.5437]], [[ 1.1119, -0.3988, 0.7275, -1.4406], [-0.4144, -1.1914, 0.0548, 1.5510], [ 0.3914, -0.5591, 1.4105, -1.2428]]])

接下来手动去进行一下编码:

eps: float = 0.00001mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)print("mean: ", mean.shape)print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))mean: torch.Size([2, 3, 1])y_custom: tensor([[[ 1.1505, 0.5212, -0.1262, -1.5455], [-0.6586, -0.2132, -0.8173, 1.6890], [ 0.6000, 1.2080, -0.3813, -1.4267]], [[-0.0861, 1.0145, -1.5895, 0.6610], [ 0.8724, 0.9047, -1.5371, -0.2400], [ 0.1507, 0.5268, 0.9785, -1.6560]]])

可以发现和LayerNorm的结果是一样的,也就是说明Norm是对d_model进行的Norm,会给我们[batch,sqe_length]形状的平均值。 加下来进行batch_norm,

layer_norm = torch.nn.LayerNorm([seq_size,dim], elementwise_affine = False)eps: float = 0.00001mean = torch.mean(embedding[:, :, :], dim=(-2,-1), keepdim=True)var = torch.square(embedding[:, :, :] - mean).mean(dim=(-2,-1), keepdim=True)print("mean: ", mean.shape)print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

输出:

mean: torch.Size([2, 1, 1])y_custom: tensor([[[ 1.1822, 0.4419, -0.3196, -1.9889], [-0.6677, -0.2537, -0.8151, 1.5143], [ 0.7174, 1.2147, -0.0852, -0.9403]], [[-0.0138, 1.5666, -2.1726, 1.0590], [ 0.6646, 0.6852, -0.8706, -0.0442], [-0.1163, 0.1389, 0.4454, -1.3423]]])

可以看到BN的计算的mean形状为[2, 1, 1],并且Norm结果也和上面的两个不一样,这就充分说明了Norm是在对最后一个维度求平均。 那么什么又是Instancenorm呢?接下来再来实现一下instancenorm

instance_norm = InstanceNorm2d(3, affine=False)output = instance_norm(embedding.reshape(2,3,4,1)) #InstanceNorm2D需要(N,C,H,W)的shape作为输入layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)print(layer_norm(embedding))

输出:

tensor([[[ 1.1505, 0.5212, -0.1262, -1.5455], [-0.6586, -0.2132, -0.8173, 1.6890], [ 0.6000, 1.2080, -0.3813, -1.4267]], [[-0.0861, 1.0145, -1.5895, 0.6610], [ 0.8724, 0.9047, -1.5371, -0.2400], [ 0.1507, 0.5268, 0.9785, -1.6560]]])

可以看出无论是layernorm还是instancenorm,还是我们手动去求平均计算其Norm,结果都是一样的,由此我们可以得出一个结论:Layernorm实际上是在做Instancenorm!

如果喜欢文章请点个赞,笔者也是一个刚入门Transformer的小白,一起学习,共同努力。

本文链接地址:https://www.jiuchutong.com/zhishi/296105.html 转载请保留说明!

上一篇:图像融合、Transformer、扩散模型(图像融合名词解释)

下一篇:Vue | Vue.js 全家桶 Pinia状态管理(vue全家桶的app项目代码)

  • vivox70pro有hifi吗(vivox70pro有耳机孔吗)

    vivox70pro有hifi吗(vivox70pro有耳机孔吗)

  • 小红书怎么改头像(小红书怎么改头像和名字图片)

    小红书怎么改头像(小红书怎么改头像和名字图片)

  • 华为畅享9plus电池耐用吗(华为畅享9plus电池多少钱一块)

    华为畅享9plus电池耐用吗(华为畅享9plus电池多少钱一块)

  • 在淘宝网购流程的步骤是什么(淘宝网购流程图片)

    在淘宝网购流程的步骤是什么(淘宝网购流程图片)

  • vivo y9s这款手机是闪充吗(vivo-y9s)

    vivo y9s这款手机是闪充吗(vivo-y9s)

  • 功率1800w一小时耗电多少(额定功率1800w一小时用多少电)

    功率1800w一小时耗电多少(额定功率1800w一小时用多少电)

  • 淘宝买家守则与规则在哪(淘宝买家规则)

    淘宝买家守则与规则在哪(淘宝买家规则)

  • 苹果外音喇叭声音小(苹果手机外音喇叭不响了怎么修)

    苹果外音喇叭声音小(苹果手机外音喇叭不响了怎么修)

  • 闪电转usb连接线什么意思(闪电转usb连接线 苹果连电视)

    闪电转usb连接线什么意思(闪电转usb连接线 苹果连电视)

  • 微信身体传感器要开吗(微信身体传感器权限怎么开?)

    微信身体传感器要开吗(微信身体传感器权限怎么开?)

  • 自己充值的抖币怎么提现(自己充值的抖币能提现吗)

    自己充值的抖币怎么提现(自己充值的抖币能提现吗)

  • 开个人热点耗电吗(开个人热点耗电量大吗)

    开个人热点耗电吗(开个人热点耗电量大吗)

  • 如何在微信中举报他人(如何在微信中举报)

    如何在微信中举报他人(如何在微信中举报)

  • vivox30是三星屏幕吗(vivox30是lcd屏吗)

    vivox30是三星屏幕吗(vivox30是lcd屏吗)

  • 表格电话号码为什么有E(表格上电话号为13888e+10)

    表格电话号码为什么有E(表格上电话号为13888e+10)

  • 微信发出去的图片怎么撤回(微信发出去的图片如何销毁)

    微信发出去的图片怎么撤回(微信发出去的图片如何销毁)

  • 手机网易云音乐下载的歌曲在哪个文件夹(手机网易云音乐怎么在车上播放)

    手机网易云音乐下载的歌曲在哪个文件夹(手机网易云音乐怎么在车上播放)

  • 抖音短视频怎么输入密码(抖音短视频怎么拍)

    抖音短视频怎么输入密码(抖音短视频怎么拍)

  • oppo手机电话号码怎么保存到卡上(oppo手机电话号码导入sim卡)

    oppo手机电话号码怎么保存到卡上(oppo手机电话号码导入sim卡)

  • xs支持多少瓦快充(xs最多支持多少瓦充电)

    xs支持多少瓦快充(xs最多支持多少瓦充电)

  • 美柚到底怎么用的(美柚怎么了)

    美柚到底怎么用的(美柚怎么了)

  • 苹果11有啥功能(苹果11有什么功能是我们不知道的)

    苹果11有啥功能(苹果11有什么功能是我们不知道的)

  • 手机hdb什么意思(手机上hdb是什么意思)

    手机hdb什么意思(手机上hdb是什么意思)

  • 小米8有语音唤醒吗(小米8有语音唤醒功能)

    小米8有语音唤醒吗(小米8有语音唤醒功能)

  • 5g和4g有何不同(5g和4g的区别大吗)

    5g和4g有何不同(5g和4g的区别大吗)

  • mysql表的设计规范(mysql表设计原则)

    mysql表的设计规范(mysql表设计原则)

  • 帝国CMS功能解密之字段处理函数详解(帝国cms破解授权)

    帝国CMS功能解密之字段处理函数详解(帝国cms破解授权)

  • 纳税人期末存货怎么结转
  • 个人所得税速算扣除数表
  • 机动车发票税率怎么算
  • 税务师考试2023年考试时间
  • 小规模定额征收是怎样
  • 报关单保费000/0.1/1
  • 收到小微企业补助会计分录
  • 应税销售行为的购买方为消费者个人的可以开专票吗
  • 兼职业务拿提成合法吗
  • 支付买方佣金
  • 企业合并所得税筹划
  • 拿到发票后如何处理
  • 年中股东红利分录怎么写
  • 地质勘察费用应由谁支付
  • 本年累计应交税费需要加上年初数吗
  • 采购退货退款怎么做账
  • 高速公路通行费发票怎么开
  • 公司委托法人代收款
  • 固定资产加速折旧最新政策2023
  • 非专利技术属于无形资产吗?
  • 统一社会信用代码证
  • 国际货运运费的计算基础
  • 应收账款的内容包括
  • 股东投资如何做账务处理
  • php中的事务使用是什么
  • PHP:Memcached::isPristine()的用法_Memcached类
  • 企业无偿提供劳务
  • phpscanf
  • 企业接受外单位投入的材料一批,应编制()
  • 水资源税收费标准
  • 商品入库进项税额怎么算
  • 前端埋点sdk
  • 收到车险发票含增值税吗
  • 浅谈如何培养孩子的注意力
  • 什么是进项票什么是成本票
  • axure简单教程
  • ChatGPT等大模型的模型量化:平滑量化法
  • php源码封装
  • 古腾堡中文官网
  • 房地产企业开发的已出租的房屋属于投资性房地产吗
  • 企业重组的特殊性税务处理例题
  • 个人劳务费用
  • 甲公司购入一台不需要安装
  • 信用减值损失在贷方表示什么
  • sql server数据库连接端口1434
  • 委托银行贷款利息发票谁提供
  • 还未摊销的房租怎么入账
  • 外币报表折算差额可以转损益吗
  • 以前年度社保计提出错了怎么调整
  • 工资汇算清缴前发
  • 房屋租赁费计入什么会计科目
  • 服装公司的会计怎么做账
  • 业务招待费属于管理费用吗
  • 应收帐款坏帐会计分录怎么处理
  • 咨询服务公司的经营范围
  • 客运服务费发票计入什么科目
  • 产品成本核算方法受那些因素影响
  • 怎么验证触发器的执行
  • SQL中exists的使用方法
  • vmware虚拟机步骤
  • centos6.5设置网络
  • 事件查看器中"TermService" 服务的性能库问题处理
  • windows如何删除本地用户
  • centos cpu 内存
  • mac os固件下载
  • win10预览版好吗
  • c#使用mongodb
  • Cocos2d-x3.3 Physics物理引擎模块解决了刚体穿透问题
  • Extjs中使用extend(js继承) 的代码
  • nodejs阿里云
  • android打包v1v2
  • cmd命令如何进入d盘
  • 嗌中怎么读
  • jquery中判断某个类是否存在的方法
  • python urljoin
  • python面向对象编程心得体会
  • 北京税务跨区迁移不予受理,原因是什么
  • 收到补税点的分录
  • 广东省电子税务局app下载手机版
  • 非盈利org
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设