位置: IT常识 - 正文

Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

编辑:rootadmin
Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型


1--动态输入和静态输入

推荐整理分享Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch基础,pytorch技巧,pytorch技巧,pytorch基础,pytorch零基础入门,pytorch 快速入门,pytorch基础教程,pytorch基础教程,内容如对您有帮助,希望把文章链接给更多的朋友!

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2--Pytorch APIPytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入input_name = 'input'output_name = 'output'torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})3--完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torchimport torch.nn as nnclass Model_Net(nn.Module): def __init__(self): super(Model_Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, data): data = self.layer1(data) return dataif __name__ == "__main__": # 设置输入参数 Batch_size = 8 Channel = 3 Height = 256 Width = 256 input_data = torch.rand((Batch_size, Channel, Height, Width)) # 实例化模型 model = Model_Net() # 导出为静态输入 input_name = 'input' output_name = 'output' torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name]) # 导出为动态输入 torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netronnetron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型import numpy as npimport onnximport onnxruntimeif __name__ == "__main__": input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32) input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32) # 导入 Onnx 模型 Onnx_file = "./Dynamics_InputNet.onnx" Model = onnx.load(Onnx_file) onnx.checker.check_model(Model) # 验证Onnx模型是否准确 # 使用 onnxruntime 推理 model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) input_name = model.get_inputs()[0].name output_name = model.get_outputs()[0].name output1 = model.run([output_name], {input_name:input_data1}) output2 = model.run([output_name], {input_name:input_data2}) print('output1.shape: ', np.squeeze(np.array(output1), 0).shape) print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。

本文链接地址:https://www.jiuchutong.com/zhishi/295970.html 转载请保留说明!

上一篇:vue中使用wangeditor富文本编辑器(vue中使用require报错)

下一篇:图文详解vue.js devtools插件使用方法(图文详解一本通)

  • 证书挂靠要申报个人所得税年度汇算吗
  • 汽车购置税怎么在手机上缴费
  • 员工买东西自己垫付的钱怎么做账
  • 利润表的上期金额是指全年吗
  • 印花税可以申报以前年度吗
  • 增值税普通发票和专用发票有什么区别
  • 所得税不计提直接缴纳,年末一次性计提
  • 行政运行经费包括项目支出吗
  • 代开专票可以开13个点吗
  • 纳税总额和实际上缴税费总额
  • 小企业会计准则2023电子版
  • 用友怎么设置工龄工资
  • 加工费的计提工资账务处理
  • 带薪缺勤会计处理
  • 车间领用材料应填制什么凭证
  • 外聘人员劳务费入什么科目
  • 上个月没有结账可以做下个月的账吗
  • 挂账留底税额如何抵扣?
  • 投标保证金以现金的形式缴纳,能以现金的形式退回吗
  • 股权转让时其他股东不配合怎么转让
  • 建筑公司收到工程服务发票怎么做会计分录
  • 收了公司的款项不拿回公司属于什么行为
  • 企业间借贷利息规定
  • 进项大于销项的分录怎么写
  • 预缴时弥补的以前年度亏损是会计亏损吗?
  • 停车管理费什么时候交
  • 一般纳税人有进项无销项
  • 审计调账后企业怎么处理
  • 进口报关单保费
  • win10 20h2更新时间久
  • 合并报表的范围
  • 发票的概念
  • pniopcac.exe是什么进程
  • php数组函数大全
  • 捐赠支出汇算清缴需要调增吗
  • 购买加油卡如何开发票
  • 非上市公司股票期权个人所得税
  • win11升级正式版
  • php自定义函数的语法格式
  • 十四届智能车规则
  • php jquery
  • php yii
  • website
  • 商贸公司库存怎么盘点准确一点儿
  • 残次品生产成本计算
  • 帝国cms首页调用显示标题图片代码
  • 织梦如何开启会员功能
  • 研发人员旅游能计入研发费用吗
  • 印花税申报怎么填
  • 个税 收入
  • 进项税转出金额怎么算
  • 汽车租赁费怎么做分录
  • 抵账协议上可以签字吗
  • 学校接受捐赠收入要交企业所得税吗
  • 上级补助收入科目
  • 股东向公司借款多久必须归还
  • 工会经费可以在以后年度扣除吗
  • 企业取得租车发票
  • 客户手续费率
  • 买相机送肩带吗
  • 去年的物业费今年收到了可以确认收入吗
  • 房产税什么时候开始征收2023
  • vmware 错误
  • linux中使用grep命令显示包含特殊字符的行
  • windows 8.1将“计算机”(This PC)更名为“此电脑”
  • player文件怎么打开
  • javascript数组的方法
  • node实战
  • 安卓布局优化
  • unityai寻路
  • web应用程序开源框架
  • javascript的dom
  • 手把手教你把币从交易所提到钱包
  • html折叠
  • 江苏国税电子税局
  • 增值税发票税控开票软件客服
  • 落实落地是什么意思
  • 福州市税务局领导班子成员名单
  • 企业补缴公积金 归集额增加
  • 地方税务机关税率是多少
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设