位置: IT常识 - 正文

Pytorch中的grid_sample算子功能解析(pytorch中的数据类型)

编辑:rootadmin
Pytorch中的grid_sample算子功能解析

推荐整理分享Pytorch中的grid_sample算子功能解析(pytorch中的数据类型),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch中的loss函数,pytorch中的forward函数,pytorch中的数据类型,pytorch中的tensor,pytorch中的view函数,pytorch中的张量,pytorch中的loss函数,pytorch中的tensor,内容如对您有帮助,希望把文章链接给更多的朋友!

         pytorch中的grid_sample是一种特殊的采样算法。

调用接口为:

torch.nn.functional.grid_sample(input,grid,mode='bilinear',padding_mode='zeros',align_corners=None)。

         input参数是输入特征图tensor,也就是特征图,可以是四维或者五维张量,以四维形式为例(N,C,Hin,Win),N可以理解为Batch_size,C可以理解为通道数,Hin和Win也就是特征图高和宽。

         grid包含输出特征图特征图的格网大小以及每个格网对应到输入特征图的采样点位,对应四维input,其张量形式为(N,Hout,Wout,2),其中最后一维大小必须为2,如果输入为五维张量,那么最后一维大小必须为3。为什么最后一维必须为2或者3?因为grid的最后一个维度实际上代表一个坐标(x,y)或者(xy,z),对应到输入特征图的二维或三维特征图的坐标维度,xy取值范围一般为[-1,1],该范围映射到输入特征图的全图。

         mode为选择采样方法,有三种内插算法可选,分别是'bilinear'双线性差值、'nearest'最邻近插值、'bicubic' 双三次插值。

Pytorch中的grid_sample算子功能解析(pytorch中的数据类型)

         padding_mode为填充模式,即当(x,y)取值超过输入特征图采样范围,返回一个特定值,有'zeros' 、 'border' 、 'reflection'三种可选,一般用zero。

         align_corners为bool类型,指设定特征图坐标与特征值对应方式,设定为TRUE时,特征值位于像素中心。

         要理解grid_sample是如何工作的,最好就是进行简单的复现。假设输入shape为(N,C,H,W),grid的shape设定为(N,H,W,2),以双线性差值为例进行处理。首先根据input和grid设定,输出特征图tensor的shape为(N,C,H,W),输出特征图上每一个cell上的值由grid最后一维(x,y)确定。那么如何计算输出tensor上每一个点的值?首先,通过(x,y)找到输入特征图上的采样位置,由于xy取值范围为[-1,1],为了便于计算,先将xy取值范围调整为[0,1]。通过(w-1)*(x+1)/2、(wh-1)*(y+1)/2将xy映射为输入特征图的具体坐标位置。将xy映射到特征图实际坐标后,取该坐标附近四个角点特征值,通过四个特征值坐标与采样点坐标相对关系进行双线性插值,得到采样点的值。

注意:xy映射后的坐标可能是输入特征图上任意位置。假设输出特征图上(2,2)坐标位置上的值采样位置可能为输入特征图上(3,4)位置,xy越小越靠近输入特征图左上角,越大则越靠近右下角。

         基于上面的思路,可以进行一个简单的自定义实现。根据指定shape生成input和grid,使用pytorch中的grid_sample算子生成output。之后取grid中的第一个位置中的xy,根据xy从input中通过双线性插值计算出output第一个位置的值。

import torchimport numpy as npdef grid_sample(input, grid): N, C, H_in, W_in = input.shape N, H_out, W_out, _ = grid.shape output = np.random.random((N,C,H,W)) for i in range(N): for j in range(C): for k in range(H_out): for l in range(W_out): param = [0.0, 0.0] param[0] = (W_in - 1) * (grid[i][k][l][0] + 1) / 2 param[1] = (H_in - 1) * (grid[i][k][l][1] + 1) / 2 x0 = int(param[0]) x1 = x0 + 1 y0 = int(param[1]) y1 = y0 + 1 param[0] -= x0 param[1] -= y0 left_top = input[i][j][y0][x0] * (1 - param[0]) * (1 - param[1]) left_bottom = input[i][j][y1][x0] * (1 - param[0]) * param[1] right_top = input[i][j][y0][x1] * param[0] * (1 - param[1]) right_bottom = input[i][j][y1][x1] * param[0] * param[1] result = left_bottom + left_top + right_bottom + right_top output[i][j][k][l] = result return outputN, C, H, W = 1, 1, 4, 4input = np.random.random((N,C,H,W))grid = np.random.random((N,H,W,2))out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)output = torch.nn.functional.grid_sample(input,grid,mode='bilinear', padding_mode='zeros',align_corners=True)print(f'grid_sample输出结果:\n{output}')

运行结果:

         从输出结果上看,与pytorch基本一致,由于仅仅做简单验证,这里没有对超出[-1,1]范围的xy值做处理,只能处理四维input,五维input的实现思路与这里基本一致。

        考虑到(x,y)取值范围可能越界,pytorch中的padding_mode设置就是对(x,y)落在输入特征图外边缘情况进行处理,一般设置'zero',也就是对靠近输入特征图范围以外的采样点进行0填充,如果不进行处理显然会造成索引越界。要解决(x,y)越界问题,可以进行如下修改:

import torchimport numpy as npdef grid_sample(input, grid): N, C, H_in, W_in = input.shape N, H_out, W_out, _ = grid.shape output = np.random.random((N, C, H_out, W_out)) for i in range(N): for j in range(C): for k in range(H_out): for l in range(W_out): x, y = grid[i][k][l][0], grid[i][k][l][1] param = [0.0, 0.0] param[0] = (W_in - 1) * (x + 1) / 2 param[1] = (H_in - 1) * (y + 1) / 2 x1 = int(param[0] + 1) x0 = x1 - 1 y1 = int(param[1] + 1) y0 = y1 - 1 param[0] = abs(param[0] - x0) param[1] = abs(param[1] - y0) left_top_value, left_bottom_value, right_top_value, right_bottom_value = 0, 0, 0, 0 if 0 <= x0 < W_in and 0 <= y0 < H_in: left_top_value = input[i][j][y0][x0] if 0 <= x1 < W_in and 0 <= y0 < H_in: right_top_value = input[i][j][y0][x1] if 0 <= x0 < W_in and 0 <= y1 < H_in: left_bottom_value = input[i][j][y1][x0] if 0 <= x1 < W_in and 0 <= y1 < H_in: right_bottom_value = input[i][j][y1][x1] left_top = left_top_value * (1 - param[0]) * (1 - param[1]) left_bottom = left_bottom_value * (1 - param[0]) * param[1] right_top = right_top_value * param[0] * (1 - param[1]) right_bottom = right_bottom_value * param[0] * param[1] result = left_bottom + left_top + right_bottom + right_top output[i][j][k][l] = result return outputN, C, H_in, W_in = 1, 1, 4, 4H_out, W_out = 4, 4input = np.random.random((N, C, H_in, W_in))grid = np.random.random((N, H_out, W_out, 2))grid[0][0][0] = [-1.2, 1.3]out = grid_sample(input, grid)print(f'自定义实现输出结果:\n{out}')input = torch.from_numpy(input)grid = torch.from_numpy(grid)output = torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True)print(f'grid_sample输出结果:\n{output}')

     测试结果:

   

本文链接地址:https://www.jiuchutong.com/zhishi/295209.html 转载请保留说明!

上一篇:成功解决:npm 版本不支持node.js。【 npm v9.1.2 does not support Node.js v16.6.0.】(成功解决冲突的能力英语)

下一篇:Vue 中 forEach() 的使用(vue foreach is not a function)

  • 苹果8跟xs大小对比(苹果8跟xs大小对比图)

    苹果8跟xs大小对比(苹果8跟xs大小对比图)

  • 快手养号一般养多久就可以发作品了(快手养号一般养多久就有金币)

    快手养号一般养多久就可以发作品了(快手养号一般养多久就有金币)

  • 抢票可以两个人同时帮另一个人抢吗(抢票两个人一起抢还是一个人抢好)

    抢票可以两个人同时帮另一个人抢吗(抢票两个人一起抢还是一个人抢好)

  • 滴滴白金会员是什么等级(滴滴白金会员有啥好处)

    滴滴白金会员是什么等级(滴滴白金会员有啥好处)

  • 华为售后贴膜免费吗(华为售后贴膜免费贴几年)

    华为售后贴膜免费吗(华为售后贴膜免费贴几年)

  • 怎样删除发现公众号(怎样删除微信里的发现公众号)

    怎样删除发现公众号(怎样删除微信里的发现公众号)

  • 如何查看网线是几类线(如何查看网线是不是千兆)

    如何查看网线是几类线(如何查看网线是不是千兆)

  • 1660ti和1650ti差距大吗(1660ti 和1650)

    1660ti和1650ti差距大吗(1660ti 和1650)

  • 小米wifi放大器连不上(小米wifi放大器连不上路由器)

    小米wifi放大器连不上(小米wifi放大器连不上路由器)

  • 群管理员是干什么的(群管理员负责管什么)

    群管理员是干什么的(群管理员负责管什么)

  • 主频越高运算速度越快对吗(主频越高计算机运算速度越快对吗)

    主频越高运算速度越快对吗(主频越高计算机运算速度越快对吗)

  • 索尼a7m2是触摸屏吗(索尼a7m3触摸屏怎么使用)

    索尼a7m2是触摸屏吗(索尼a7m3触摸屏怎么使用)

  • iphone7支持多少w快充(iphone 7最高支持多少瓦)

    iphone7支持多少w快充(iphone 7最高支持多少瓦)

  • 网易云文件夹叫什么(网易云文件夹叫什么名字)

    网易云文件夹叫什么(网易云文件夹叫什么名字)

  • 荣耀v20怎么调24时间(荣耀v20怎么调字体大小)

    荣耀v20怎么调24时间(荣耀v20怎么调字体大小)

  • 一加7pro游戏模式怎么设置(一加7pro游戏模式怎么关闭)

    一加7pro游戏模式怎么设置(一加7pro游戏模式怎么关闭)

  • 苹果7p诊断用量在哪(苹果7p测量仪在哪里)

    苹果7p诊断用量在哪(苹果7p测量仪在哪里)

  • 爱奇艺取消自动续费后还是会员吗(爱奇艺取消自动续费支付宝)

    爱奇艺取消自动续费后还是会员吗(爱奇艺取消自动续费支付宝)

  • 教师资格证怎么考?容易吗?(教师资格证怎么补办)

    教师资格证怎么考?容易吗?(教师资格证怎么补办)

  • 手机网络密钥怎么查(手机网络密钥在哪里)

    手机网络密钥怎么查(手机网络密钥在哪里)

  • Linux中文本处理命令sed的使用示例分享(linux常见的文本编辑工具有哪些)

    Linux中文本处理命令sed的使用示例分享(linux常见的文本编辑工具有哪些)

  • [uniapp] 跨页面传值 uni.$emit 和 uni.$on 的使用方法 以及遇到的坑(uniapp跨域解决方案)

    [uniapp] 跨页面传值 uni.$emit 和 uni.$on 的使用方法 以及遇到的坑(uniapp跨域解决方案)

  • 如何在天猫平台抢茅台
  • 支持疫情防控捐赠语言
  • 销售不动产税率9%还是5%
  • 股票印花税怎么交
  • 工业土地摊销年限最新规定
  • 房租押金不退还怎么处理
  • 个人提供建筑安装劳务如何缴纳个人所得税
  • 未使用固定资产计提折旧计入
  • 企业对企业分红要缴纳什么税免税分红
  • 福利费专票进项可以抵扣吗
  • 营业外收支计入哪里
  • 单位买绿植可以报销吗
  • 总公司汇总缴纳所得税升为一般纳税人分公司受影响吗
  • 报表的应付款太大怎么调?
  • 申请国家知识产权的条件
  • 发票已抵扣是什么意思
  • 股权变更印花税双方都要交吗
  • 可以给农村信用社的存折转账吗
  • 税收奖励需要纳税吗
  • 代扣代缴个人所得税现金流计入哪里
  • 建筑挂靠管理费用如何账务处理?
  • 结余资金财政收回如何做账
  • 可转换债券存在的问题
  • 编制会计报表利润表
  • 学校收取食堂管理费
  • 劳务费个税计算方式
  • 最新气象报告
  • award bios设置详解
  • 搜索特定
  • 母公司投资子公司怎么做账
  • 苹果电脑怎么打顿号
  • ghost后分区没有了
  • 固定资产出租需交什么税
  • PHP:class_parents()的用法_spl函数
  • 税务申报逾期罚款不交
  • 虚开发票的管理办法是什么?
  • 财政补贴是解决什么问题的
  • Ubuntu18.04安装cuda10.2
  • php设计模式六大原则
  • 公司对公账户没有流水怎么办
  • 代扣代缴个税手续费返还文件
  • 收据的种类有哪些
  • sqlserver数据库事务
  • 医院执行政府会计制度操作指南 .pdf
  • 劳务公司已开票怎么入账
  • 收到自然人税务申报短信
  • 内账会计的岗位职责
  • 合并报表抵消分录的基本原理
  • 小规模纳税人销售不动产适用税率
  • 厂房出租自用各种费用
  • 私人出租房子发圈文案
  • 过路费怎么抵扣进项税额报表怎么填
  • 广告收入计入哪个科目
  • 应付账款周转率越大越好还是越小越好?
  • 提供劳务方式是什么意思
  • 事业单位开办费与注册资本的关系
  • 补交增值税如何入账
  • 代理记账公司都是假账么
  • 个体工商户的公章丢了怎么办
  • Windows Server 2008下高效域管理体验
  • mac chrome浏览器插件
  • 在unix系统中采用的页面置换
  • soapui安装与配置
  • 文本文件模式
  • win8电脑设置
  • win10打不出字解决办法
  • msdev.exe是什么
  • w10自启
  • win8 休眠
  • shell脚本中执行echo卡住
  • python生成密钥
  • 浏览器兼容的方法
  • css hacks
  • jq cookie
  • android天气预报开发极简
  • android 轮播
  • 增值税已申报但是忘清卡
  • 个人绩效考核税务局
  • 我国现行税率分
  • 按季申报印花税怎么申报
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设