位置: IT常识 - 正文

对Transformer中Add&Norm层的理解(transformer中的参数)

编辑:rootadmin
对Transformer中Add&Norm层的理解 对Add&Norm层的理解Add操作Norm操作Add操作

推荐整理分享对Transformer中Add&Norm层的理解(transformer中的参数),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:transform方法会对产生的标量值进行,transform方法会对产生的标量值进行,transformer add norm,transformer中的参数,transformer中的参数,transformer方法,transformer中的参数,transformer中的参数,内容如对您有帮助,希望把文章链接给更多的朋友!

首先我们还是先来回顾一下Transformer的结构:Transformer结构主要分为两大部分,一是Encoder层结构,另一个则是Decoder层结构,Encoder 的输入由 Input Embedding 和 Positional Embedding 求和输入Multi-Head-Attention,再通过Feed Forward进行输出。

由下图可以看出:在Encoder层和Decoder层中都用到了Add&Norm操作,即残差连接和层归一化操作。 什么是残差连接呢?残差连接就是把网络的输入和输出相加,即网络的输出为F(x)+x,在网络结构比较深的时候,网络梯度反向传播更新参数时,容易造成梯度消失的问题,但是如果每层的输出都加上一个x的时候,就变成了F(x)+x,对x求导结果为1,所以就相当于每一层求导时都加上了一个常数项‘1’,有效解决了梯度消失问题。

Norm操作

首先要明白Norm做了一件什么事,从刚开始接触Transformer开始,我认为所谓的Norm就是BatchNorm,但是有一天我看到了这篇文章,才明白了Norm是什么。

假设我们输入的词向量的形状是(2,3,4),2为批次(batch),3为句子长度,4为词向量的维度,生成以下数据:

[[w11, w12, w13, w14], [w21, w22, w23, w24], [w31, w32, w33, w34][w41, w42, w43, w44], [w51, w52, w53, w54], [w61, w62, w63, w64]]

如果是在做BatchNorm(BN)的话,其计算过程如下:BN1=(w11+w12+w13+w14+w41+ w42+w43+w44)/8,同理会得到BN2和BN3,最终得到[BN1,BN2,BN3] 3个mean

对Transformer中Add&Norm层的理解(transformer中的参数)

如果是在做LayerNorm(LN)的话,则会进如下计算:LN1=(w11+w12+w13+w14+w21+ w22+w23+w24+w31+w32+w33+w34)/12,同理会得到LN2,最终得到[LN1,LN2]两个mean

如果是在做InstanceNorm(IN)的话,则会进如下计算:IN1=(w11+w12+w13+w14)/4,同理会得到IN2,IN3,IN4,IN5,IN6,六个mean,[[IN1,IN2,IN3],[IN4,IN5,IN6]] 下图完美的揭示了,这几种Norm 接下来我们来看一下Transformer中的Norm:首先生成[2,3,4]形状的数据,使用原始的编码方式进行编码:

import torchfrom torch.nn import InstanceNorm2drandom_seed = 123torch.manual_seed(random_seed)batch_size, seq_size, dim = 2, 3, 4embedding = torch.randn(batch_size, seq_size, dim)layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)print("y: ", layer_norm(embedding))

输出:

y: tensor([[[ 1.5524, 0.0155, -0.3596, -1.2083], [ 0.5851, 1.3263, -0.7660, -1.1453], [ 0.2864, 0.0185, 1.2388, -1.5437]], [[ 1.1119, -0.3988, 0.7275, -1.4406], [-0.4144, -1.1914, 0.0548, 1.5510], [ 0.3914, -0.5591, 1.4105, -1.2428]]])

接下来手动去进行一下编码:

eps: float = 0.00001mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)print("mean: ", mean.shape)print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))mean: torch.Size([2, 3, 1])y_custom: tensor([[[ 1.1505, 0.5212, -0.1262, -1.5455], [-0.6586, -0.2132, -0.8173, 1.6890], [ 0.6000, 1.2080, -0.3813, -1.4267]], [[-0.0861, 1.0145, -1.5895, 0.6610], [ 0.8724, 0.9047, -1.5371, -0.2400], [ 0.1507, 0.5268, 0.9785, -1.6560]]])

可以发现和LayerNorm的结果是一样的,也就是说明Norm是对d_model进行的Norm,会给我们[batch,sqe_length]形状的平均值。 加下来进行batch_norm,

layer_norm = torch.nn.LayerNorm([seq_size,dim], elementwise_affine = False)eps: float = 0.00001mean = torch.mean(embedding[:, :, :], dim=(-2,-1), keepdim=True)var = torch.square(embedding[:, :, :] - mean).mean(dim=(-2,-1), keepdim=True)print("mean: ", mean.shape)print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

输出:

mean: torch.Size([2, 1, 1])y_custom: tensor([[[ 1.1822, 0.4419, -0.3196, -1.9889], [-0.6677, -0.2537, -0.8151, 1.5143], [ 0.7174, 1.2147, -0.0852, -0.9403]], [[-0.0138, 1.5666, -2.1726, 1.0590], [ 0.6646, 0.6852, -0.8706, -0.0442], [-0.1163, 0.1389, 0.4454, -1.3423]]])

可以看到BN的计算的mean形状为[2, 1, 1],并且Norm结果也和上面的两个不一样,这就充分说明了Norm是在对最后一个维度求平均。 那么什么又是Instancenorm呢?接下来再来实现一下instancenorm

instance_norm = InstanceNorm2d(3, affine=False)output = instance_norm(embedding.reshape(2,3,4,1)) #InstanceNorm2D需要(N,C,H,W)的shape作为输入layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)print(layer_norm(embedding))

输出:

tensor([[[ 1.1505, 0.5212, -0.1262, -1.5455], [-0.6586, -0.2132, -0.8173, 1.6890], [ 0.6000, 1.2080, -0.3813, -1.4267]], [[-0.0861, 1.0145, -1.5895, 0.6610], [ 0.8724, 0.9047, -1.5371, -0.2400], [ 0.1507, 0.5268, 0.9785, -1.6560]]])

可以看出无论是layernorm还是instancenorm,还是我们手动去求平均计算其Norm,结果都是一样的,由此我们可以得出一个结论:Layernorm实际上是在做Instancenorm!

如果喜欢文章请点个赞,笔者也是一个刚入门Transformer的小白,一起学习,共同努力。

本文链接地址:https://www.jiuchutong.com/zhishi/296105.html 转载请保留说明!

上一篇:图像融合、Transformer、扩散模型(图像融合名词解释)

下一篇:Vue | Vue.js 全家桶 Pinia状态管理(vue全家桶的app项目代码)

  • xboxgamebar什么意思(xbox game bar是啥)

    xboxgamebar什么意思(xbox game bar是啥)

  • 华为p30微信视频能美颜吗(华为P30微信视频怎么开美颜?)

    华为p30微信视频能美颜吗(华为P30微信视频怎么开美颜?)

  • 无推广中单元是什么意思(无推广中单元被驳回怎么解决)

    无推广中单元是什么意思(无推广中单元被驳回怎么解决)

  • 机械键盘换轴必须焊吗(给机械键盘换轴)

    机械键盘换轴必须焊吗(给机械键盘换轴)

  • 三星手机回收站在哪里打开(三星手机回收站怎么找到)

    三星手机回收站在哪里打开(三星手机回收站怎么找到)

  • 荣耀10x支持wifi6吗(荣耀10x支持多少瓦快充)

    荣耀10x支持wifi6吗(荣耀10x支持多少瓦快充)

  • 在word中打开文档的作用是什么(在word中打开文件菜单的键盘方法是按住)

    在word中打开文档的作用是什么(在word中打开文件菜单的键盘方法是按住)

  • 淘宝预计送达时间准吗(淘宝预计送达时间是怎么算的)

    淘宝预计送达时间准吗(淘宝预计送达时间是怎么算的)

  • ipad刷新率在哪调(ipad刷新率在哪里设置)

    ipad刷新率在哪调(ipad刷新率在哪里设置)

  • 微信怎么不出现在通讯录(微信怎么不出现在别人的通讯录里)

    微信怎么不出现在通讯录(微信怎么不出现在别人的通讯录里)

  • 计算器上ac是开关键吗(计算器的ac是什么)

    计算器上ac是开关键吗(计算器的ac是什么)

  • 企鹅影院和腾讯视频一样吗(企鹅影院和腾讯会员通用吗)

    企鹅影院和腾讯视频一样吗(企鹅影院和腾讯会员通用吗)

  • 定位为什么显示离线(定位为什么显示黑色)

    定位为什么显示离线(定位为什么显示黑色)

  • 怎么查自己名下的手机号(怎么查自己名下有没有房产)

    怎么查自己名下的手机号(怎么查自己名下有没有房产)

  • 美团错误删除订单恢复(美团订单删除不了该订单不可删除)

    美团错误删除订单恢复(美团订单删除不了该订单不可删除)

  • 可以查到别人微信聊天记录吗(可以查到别人微信里的好友吗)

    可以查到别人微信聊天记录吗(可以查到别人微信里的好友吗)

  • 拼多多黑号怎么解除(拼多多黑号怎么黑回来)

    拼多多黑号怎么解除(拼多多黑号怎么黑回来)

  • wps2019智能工具箱在哪(wps2016智能工具箱在哪)

    wps2019智能工具箱在哪(wps2016智能工具箱在哪)

  • 手机tar格式怎么打开(tar后缀文件手机怎么打开)

    手机tar格式怎么打开(tar后缀文件手机怎么打开)

  • 三星s9支持人脸识别支付吗(三星s9人脸识别怎么关闭)

    三星s9支持人脸识别支付吗(三星s9人脸识别怎么关闭)

  • 手机怎么设置定位追踪(手机怎么设置定时打电话)

    手机怎么设置定位追踪(手机怎么设置定时打电话)

  • 最新Win1021H1专业版永久免费激活码推荐 含激活工具(window10专业版2021)

    最新Win1021H1专业版永久免费激活码推荐 含激活工具(window10专业版2021)

  • khooker.exe是什么进程 有什么用 khooker进程查询(kcleaner.exe是什么)

    khooker.exe是什么进程 有什么用 khooker进程查询(kcleaner.exe是什么)

  • WEB网页设计期末作业个人主页——基于HTML+CSS制作个人简介网站(web网页设计期末作业猫眼电影首页)

    WEB网页设计期末作业个人主页——基于HTML+CSS制作个人简介网站(web网页设计期末作业猫眼电影首页)

  • 小规模纳税人代收水电费税率
  • 发票验旧是验旧已开发票还是未开发票
  • 环保科技属于什么行业类别
  • 应付股利一直挂账怎么办
  • 机场工作人员的家属票
  • 资金周转率计算公式期初占用资金
  • 跨年度的银行未入账如何处理
  • 产权交易所怎么赚钱
  • 一般纳税人企业所得税政策最新2023税率
  • 预缴增值税预缴的城建税怎么申报
  • 预付款发票可以入费用吗
  • 不涉及税收
  • 公司购买理财产品的收益计入什么科目
  • 小规模纳税人可以抵扣增值税专用发票吗
  • 收到认证费用计入什么科目
  • 工会筹备金怎么报
  • 房地产企业预缴增值税什么时候结转
  • 个体户注销麻烦还是公司注销麻烦
  • 建筑业预估成本怎么算
  • 代扣税款手续费管理办法
  • 增值税收入和所得税收入不一致怎么办
  • 结转存货跌价准备冲减主营业务成本
  • 高新企业认定 研发委外费用
  • linux命令执行成功后会返回什么
  • 在线上网测试
  • windows 11预览版
  • 对公账户转库存现金对方科目怎么填
  • 税费缴纳比例
  • 转账支票购买办公用品会计
  • 现金折扣发生销售退回
  • 融资性售后回租和融资租赁的区别
  • 专利年费的滞纳金
  • 免税需要什么条件
  • 公司处理固定资产车辆怎么开发票
  • 傅里叶级数狄利克雷判别法
  • 【6G 新技术】6G数据面介绍
  • php7 数组
  • 消费者如何鉴别美的乐享三代风管机
  • js相关知识
  • opengl 帧率
  • 装修公司管理费是什么
  • 单位多缴个人社保证明
  • 已经红冲的发票显示正常
  • python locator
  • 进项未认证但已开票怎么办
  • 专用红字发票如何开具
  • 税前扣除的意思
  • 新成立公司如何报税
  • 差旅费退回怎么做账
  • 政府补助会计处理方法由总额法变为净额法
  • 转账支付水电费
  • 支付技术研究开发费
  • 销售货物的运费的税率怎么算
  • 关停企业的国家规定
  • 银行存款利息的结算方式
  • 次年发放的奖金怎么入账
  • ubuntu系统管理
  • executor进程
  • 清除桌面应用软件
  • windows手动启动服务
  • centos6启动服务的命令
  • win7系统自带网卡吗?
  • fedora23安装
  • win8 更新
  • windows 10预览版
  • win10系统安全中心在哪
  • windows7宽带连接断开怎么办
  • js与css有什么区别
  • unity开发的小游戏
  • nodejs调试指南
  • 3种不同的播种方法
  • 从零开始学什么
  • shell中的-n
  • NGUI学习笔记汇总
  • javascript入门教程
  • 如何用javascript
  • 装饰装修公司需要什么
  • 北京930末班车时间表
  • 西安国家税务局丁雁现任命职务
  • 云南省昆明市官渡区矣六街道
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设