位置: IT常识 - 正文

网络模型的参数量和FLOPs的计算 Pytorch(网络模型参数方法)

编辑:rootadmin
网络模型的参数量和FLOPs的计算 Pytorch

目录

1、torchstat 

2、thop

3、fvcore 

4、flops_counter

5、自定义统计函数


推荐整理分享网络模型的参数量和FLOPs的计算 Pytorch(网络模型参数方法),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:网络模型的参数量是一层不变的吗,网络模型的参数设置,网络模型的参数是什么,网络模型参数量如何计算,网络模型的参数量,网络模型的参数是什么,网络模型的参数量是一层不变的吗,网络模型的参数有哪些,内容如对您有帮助,希望把文章链接给更多的朋友!

FLOPS和FLOPs的区别:

FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。

在介绍torchstat包和thop包之前,先总结一下:

torchstat包可以统计卷积神经网络和全连接神经网络的参数和计算量。thop包可以统计统计卷积神经网络、全连接神经网络以及循环神经网络的参数和计算量,程序示例等详见下文。1、torchstat pip install torchstat -i https://pypi.tuna.tsinghua.edu.cn/simple

在实际操作中,我们可以调用torchstat包,帮助我们统计模型的parameters和FLOPs。如果不修改这个包里面的一些代码,那么这个包只适用于输入为3通道的图像的模型。

import torchimport torch.nn as nnfrom torchstat import statclass Simple(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False) self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return xmodel = Simple()stat(model, (3, 244, 244)) # 统计模型的参数量和FLOPs,(3,244,244)是输入图像的size

 如果把torchstat包中的一行程序进行一点点改动,那么这个包可以用来统计全连接神经网络的参数量和计算量。当然手动计算全连接神经网络的参数量和计算量也很快 =_= 。进入torchstat源代码之后,如下图所示,注释掉圈红的地方,就可以用torchstat包统计全连接神经网络的参数量和计算量了。

网络模型的参数量和FLOPs的计算 Pytorch(网络模型参数方法)

2、thoppip install thop -i https://pypi.tuna.tsinghua.edu.cn/simpleimport torchimport torch.nn as nnfrom thop import profileclass Simple(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(10, 10) def forward(self, x): x = self.fc1(x) return xnet = Simple()input = torch.randn(1, 10) # batchsize=1, 输入向量长度为10macs, params = profile(net, inputs=(input, ))print(' FLOPs: ', macs*2) # 一般来讲,FLOPs是macs的两倍print('params: ', params)3、fvcore pip install fvcore -i https://pypi.tuna.tsinghua.edu.cn/simple

用它比较好

import torchfrom torchvision.models import resnet50from fvcore.nn import FlopCountAnalysis, parameter_count_table# 创建resnet50网络model = resnet50(num_classes=1000)# 创建输入网络的tensortensor = (torch.rand(1, 3, 224, 224),)# 分析FLOPsflops = FlopCountAnalysis(model, tensor)print("FLOPs: ", flops.total())# 分析parametersprint(parameter_count_table(model))

 终端输出结果如下,FLOPs为4089184256,模型参数数量约为25.6M(这里的参数数量和我自己计算的有些出入,主要是在BN模块中,这里只计算了beta和gamma两个训练参数,没有统计moving_mean和moving_var两个参数),具体可以看下我在官方提的issue。 通过终端打印的信息我们可以发现在计算FLOPs时并没有包含BN层,池化层还有普通的add操作(我发现计算FLOPs时并没有统一的规定,在github上看的计算FLOPs项目基本每个都不同,但计算出来的结果大同小异)。

注意:在使用fvcore模块计算模型的flops时,遇到了问题,记录一下解决方案。首先是在jit_analysis.py的589行出错。经过调试发现,op_counts.values()的类型是int32,但是计算要求的类型只能是int、float、np.float64和np.int64,因此需要手动进行强制转换。修改如下:

4、flops_counterpip install ptflops -i https://pypi.tuna.tsinghua.edu.cn/simple

用它也很好,结果和fvcore一样

from ptflops import get_model_complexity_infomacs, params = get_model_complexity_info(model, (112, 9, 9), as_strings=True, print_per_layer_stat=True, verbose=True)print('{:<30} {:<8}'.format('Computational complexity: ', macs))print('{:<30} {:<8}'.format('Number of parameters: ', params))

5、自定义统计函数import torchimport numpy as npdef calc_flops(model, input): def conv_hook(self, input, output): batch_size, input_channels, input_height, input_width = input[0].size() output_channels, output_height, output_width = output[0].size() kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * ( 2 if multiply_adds else 1) bias_ops = 1 if self.bias is not None else 0 params = output_channels * (kernel_ops + bias_ops) flops = batch_size * params * output_height * output_width list_conv.append(flops) def linear_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 num_steps = input[0].size(0) weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) bias_ops = self.bias.nelement() if self.bias is not None else 0 flops = batch_size * (weight_ops + bias_ops) flops *= num_steps list_linear.append(flops) def fsmn_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 weight_ops = self.filter.nelement() * (2 if multiply_adds else 1) num_steps = input[0].size(0) flops = num_steps * weight_ops flops *= batch_size list_fsmn.append(flops) def gru_cell(input_size, hidden_size, bias=True): total_ops = 0 # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ state_ops = (hidden_size + input_size) * hidden_size + hidden_size if bias: state_ops += hidden_size * 2 total_ops += state_ops * 2 # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ total_ops += (hidden_size + input_size) * hidden_size + hidden_size if bias: total_ops += hidden_size * 2 # r hadamard : r * (~) total_ops += hidden_size # h' = (1 - z) * n + z * h # hadamard hadamard add total_ops += hidden_size * 3 return total_ops def gru_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 if self.batch_first: batch_size = input[0].size(0) num_steps = input[0].size(1) else: batch_size = input[0].size(1) num_steps = input[0].size(0) total_ops = 0 bias = self.bias input_size = self.input_size hidden_size = self.hidden_size num_layers = self.num_layers total_ops = 0 total_ops += gru_cell(input_size, hidden_size, bias) for i in range(num_layers - 1): total_ops += gru_cell(hidden_size, hidden_size, bias) total_ops *= batch_size total_ops *= num_steps list_lstm.append(total_ops) def lstm_cell(input_size, hidden_size, bias): total_ops = 0 state_ops = (input_size + hidden_size) * hidden_size + hidden_size if bias: state_ops += hidden_size * 2 total_ops += state_ops * 4 total_ops += hidden_size * 3 total_ops += hidden_size return total_ops def lstm_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 if self.batch_first: batch_size = input[0].size(0) num_steps = input[0].size(1) else: batch_size = input[0].size(1) num_steps = input[0].size(0) total_ops = 0 bias = self.bias input_size = self.input_size hidden_size = self.hidden_size num_layers = self.num_layers total_ops = 0 total_ops += lstm_cell(input_size, hidden_size, bias) for i in range(num_layers - 1): total_ops += lstm_cell(hidden_size, hidden_size, bias) total_ops *= batch_size total_ops *= num_steps list_lstm.append(total_ops) def bn_hook(self, input, output): list_bn.append(input[0].nelement()) def relu_hook(self, input, output): list_relu.append(input[0].nelement()) def pooling_hook(self, input, output): batch_size, input_channels, input_height, input_width = input[0].size() output_channels, output_height, output_width = output[0].size() kernel_ops = self.kernel_size * self.kernel_size bias_ops = 0 params = output_channels * (kernel_ops + bias_ops) flops = batch_size * params * output_height * output_width list_pooling.append(flops) def foo(net): childrens = list(net.children()) if not childrens: print(net) if isinstance(net, torch.nn.Conv2d) or isinstance(net, torch.nn.ConvTranspose2d): net.register_forward_hook(conv_hook) # print('conv_hook_ready') if isinstance(net, torch.nn.Linear): net.register_forward_hook(linear_hook) # print('linear_hook_ready') if isinstance(net, torch.nn.BatchNorm2d): net.register_forward_hook(bn_hook) # print('batch_norm_hook_ready') if isinstance(net, torch.nn.ReLU) or isinstance(net, torch.nn.PReLU): net.register_forward_hook(relu_hook) # print('relu_hook_ready') if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): net.register_forward_hook(pooling_hook) # print('pooling_hook_ready') if isinstance(net, torch.nn.LSTM): net.register_forward_hook(lstm_hook) # print('lstm_hook_ready') if isinstance(net, torch.nn.GRU): net.register_forward_hook(gru_hook) # if isinstance(net, FSMNZQ): # net.register_forward_hook(fsmn_hook) # print('fsmn_hook_ready') return for c in childrens: foo(c) multiply_adds = False list_conv, list_bn, list_relu, list_linear, list_pooling, list_lstm, list_fsmn = [], [], [], [], [], [], [] foo(model) _ = model(input) total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum( list_lstm) + sum(list_fsmn)) fsmn_flops = (sum(list_fsmn) + sum(list_linear)) lstm_flops = sum(list_lstm) model_parameters = filter(lambda p: p.requires_grad, model.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) print('The network has {} params.'.format(params)) print(total_flops, fsmn_flops, lstm_flops) print(' + Number of FLOPs: %.2f M' % (total_flops / 1000 ** 2)) return total_flopsif __name__ == '__main__': from torchvision.models import resnet18 model = resnet18(num_classes=1000) imput_size = torch.rand((1,3,224,224)) calc_flops(model, imput_size)

本文链接地址:https://www.jiuchutong.com/zhishi/299382.html 转载请保留说明!

上一篇:c++STL急急急(c++stl详解)

下一篇:40个web前端实战项目,练完即可就业,从入门到进阶,基础到框架,html_css【附视频+源码】(web前端视频教程全套)

  • 沃租乐l稻壳怎么解除(沃租乐l稻壳怎么解除它订单)

    沃租乐l稻壳怎么解除(沃租乐l稻壳怎么解除它订单)

  • 微博黑名单在哪里看(微博黑名单在哪里)

    微博黑名单在哪里看(微博黑名单在哪里)

  • 0xc0000034不能开机(0xc0000034不能开机win10)

    0xc0000034不能开机(0xc0000034不能开机win10)

  • 苹果耳机三代充电线怎么用(苹果耳机三代充电仓丢了可以配的到吗)

    苹果耳机三代充电线怎么用(苹果耳机三代充电仓丢了可以配的到吗)

  • 荣耀9x是几g手机(荣耀9x是4g手机还是5g手机)

    荣耀9x是几g手机(荣耀9x是4g手机还是5g手机)

  • 华为基带是自己的吗(华为基带是什么牌子)

    华为基带是自己的吗(华为基带是什么牌子)

  • 拦截电话怎么恢复正常(拦截电话怎么恢复正常安卓)

    拦截电话怎么恢复正常(拦截电话怎么恢复正常安卓)

  • 淘宝宝贝下架后再上架有影响吗(淘宝宝贝下架后销量还在吗)

    淘宝宝贝下架后再上架有影响吗(淘宝宝贝下架后销量还在吗)

  • 拼多多的618是什么意思(拼多多618商品是正品吗)

    拼多多的618是什么意思(拼多多618商品是正品吗)

  • QQ关联对方为什么都是撤回(qq关联为什么对方能看到我的我看不到他的)

    QQ关联对方为什么都是撤回(qq关联为什么对方能看到我的我看不到他的)

  • 苹果手机换了id怎么还是以前的ID?(苹果手机换了id照片怎么恢复)

    苹果手机换了id怎么还是以前的ID?(苹果手机换了id照片怎么恢复)

  • 公众号白名单什么意思(公众号白名单什么时候可以转载)

    公众号白名单什么意思(公众号白名单什么时候可以转载)

  • 高德地图收藏在哪里找(高德地图收藏在哪?)

    高德地图收藏在哪里找(高德地图收藏在哪?)

  • 在word2010中无法实现的操作是什么(在world2010中无法实现的操作是)

    在word2010中无法实现的操作是什么(在world2010中无法实现的操作是)

  • vivo如何将软件移到内存卡(vivo如何将软件分享给微信好友)

    vivo如何将软件移到内存卡(vivo如何将软件分享给微信好友)

  • 如何找回手机wps文档(如何找回手机删除的照片和视频)

    如何找回手机wps文档(如何找回手机删除的照片和视频)

  • 注册qq的邀请码怎么弄(注册qq的邀请码怎么填)

    注册qq的邀请码怎么弄(注册qq的邀请码怎么填)

  • 苹果x怎么省电设置方法(苹果x怎么省电不发热)

    苹果x怎么省电设置方法(苹果x怎么省电不发热)

  • 原深感摄像头出现问题(原深感摄像头出现问题面容id不可用)

    原深感摄像头出现问题(原深感摄像头出现问题面容id不可用)

  • 魅族新智能冻结3.0s干嘛的(魅族智能冻结3.0怎样关闭)

    魅族新智能冻结3.0s干嘛的(魅族智能冻结3.0怎样关闭)

  • win7停留在启动管理器(win7停留在启动管理器进不去,按F8黑屏)

    win7停留在启动管理器(win7停留在启动管理器进不去,按F8黑屏)

  • 存储卡a1和a2的区别是什么(存储卡a2什么意思)

    存储卡a1和a2的区别是什么(存储卡a2什么意思)

  • 苹果自带录屏在哪(苹果自带录屏在录制和平精英是怎么隐藏键位)

    苹果自带录屏在哪(苹果自带录屏在录制和平精英是怎么隐藏键位)

  • 华为畅享9s电池容量(华为畅享9s电池能用多久)

    华为畅享9s电池容量(华为畅享9s电池能用多久)

  • 电脑打开蓝屏英文字母(电脑出现蓝屏英文选项要怎么处理)

    电脑打开蓝屏英文字母(电脑出现蓝屏英文选项要怎么处理)

  • PHPCMS v9 怎么换模板?(phpcms v9安装教程)

    PHPCMS v9 怎么换模板?(phpcms v9安装教程)

  • 个税一般劳务报酬所得如何申报
  • 车船税是每个月交还是每年交
  • 税负率的计算方法有哪些
  • 首套房契税税率是多少?
  • 领的增值税专用发票如何录入电脑
  • 三免三减半如何申报企业所得税
  • 啤酒消费税在那里征收
  • 计提职工非货币福利怎么算
  • 收到服务费发票摘要怎么写
  • 消费取得普通发票怎么开
  • 公司注销固定资产清理需要开票吗
  • 营改增之后还有营业税吗
  • 股东分红个人所得税怎么申报
  • 预缴税款的税率
  • 注册资本越大越有实力
  • 服务型企业管理体系
  • 资产处置损益在企业所得税汇算时如何填列
  • 完税凭证号是几位数
  • 新办企业的开办费用应计入( )
  • 个人买卖二手房增值税
  • 支付境外咨询费代扣代缴增值税
  • 收到的支票背书怎么写
  • 买办公用品花了100元如何做会计分录
  • 进项发票没有收到,销项已开出,成本如何结转
  • 固定资产已入库款项已付次月开发票何时记提折旧
  • 外国企业代表处企业所得税
  • 库存商品转结
  • 汇算清缴补交的所得税会计分录
  • 个人股权转让如何申报个人所得税
  • win10永久激活码神key一周内
  • 8款应用
  • win10高级功能
  • ctl.start
  • 交易性金融资产包括哪些项目
  • 最小的成像传感器
  • 失控发票成本转出怎么做账
  • php高级程序招聘
  • 消费税购置税价格一样
  • dmsetup remove_all 会不会清掉数据
  • tsar命令 收集服务器系统信息
  • mac m1 安装windows
  • 应付职工薪酬如何确认
  • 非金融企业之间借款
  • 什么是关联企业?关联企业之间业务往来
  • 不确认递延所得税资产的特殊情况举例
  • 非应税项目是有哪些项目
  • mysql客户端程序的功能是什么
  • macos添加用户
  • 出现什么情况企业不能持续经营
  • 一般纳税人销售旧货税率
  • 职工薪酬实际发生额忘记填会有风险吗
  • 先前收取的包装费用
  • 个贷系统平账专户怎么做账
  • 增值税加计抵减企业所得税如何处理
  • 计提工资的核算流程
  • 支付货款订金入什么科目
  • 清理固定资产是什么意思
  • 合并报表的内部投资抵消
  • 如何办理公司注册地址变更
  • 营业外支出增加说明了什么
  • 企业清算的顺序
  • 深入浅出意思
  • 备份还原工具怎么用
  • 创建mysql数据库指定字符集
  • 目前默认系统%1
  • windows server 2016 百度网盘下载
  • Linux(CentOS)用split命令分割文件的方法
  • winadserv.exe - winadserv是什么进程
  • win8.1流畅吗
  • unityrpg
  • linux如何启动tomcat
  • importem
  • python%i
  • win7支持快速启动吗
  • 详解金球奖之争
  • ug编程代码意思
  • python traits
  • 电子发票怎么汇总清卡
  • 浙江公务员冬令时上班时间
  • 深圳国税网上申报流程图
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设