位置: IT常识 - 正文

网络模型的参数量和FLOPs的计算 Pytorch(网络模型参数方法)

编辑:rootadmin
网络模型的参数量和FLOPs的计算 Pytorch

目录

1、torchstat 

2、thop

3、fvcore 

4、flops_counter

5、自定义统计函数


推荐整理分享网络模型的参数量和FLOPs的计算 Pytorch(网络模型参数方法),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:网络模型的参数量是一层不变的吗,网络模型的参数设置,网络模型的参数是什么,网络模型参数量如何计算,网络模型的参数量,网络模型的参数是什么,网络模型的参数量是一层不变的吗,网络模型的参数有哪些,内容如对您有帮助,希望把文章链接给更多的朋友!

FLOPS和FLOPs的区别:

FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。

在介绍torchstat包和thop包之前,先总结一下:

torchstat包可以统计卷积神经网络和全连接神经网络的参数和计算量。thop包可以统计统计卷积神经网络、全连接神经网络以及循环神经网络的参数和计算量,程序示例等详见下文。1、torchstat pip install torchstat -i https://pypi.tuna.tsinghua.edu.cn/simple

在实际操作中,我们可以调用torchstat包,帮助我们统计模型的parameters和FLOPs。如果不修改这个包里面的一些代码,那么这个包只适用于输入为3通道的图像的模型。

import torchimport torch.nn as nnfrom torchstat import statclass Simple(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False) self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return xmodel = Simple()stat(model, (3, 244, 244)) # 统计模型的参数量和FLOPs,(3,244,244)是输入图像的size

 如果把torchstat包中的一行程序进行一点点改动,那么这个包可以用来统计全连接神经网络的参数量和计算量。当然手动计算全连接神经网络的参数量和计算量也很快 =_= 。进入torchstat源代码之后,如下图所示,注释掉圈红的地方,就可以用torchstat包统计全连接神经网络的参数量和计算量了。

网络模型的参数量和FLOPs的计算 Pytorch(网络模型参数方法)

2、thoppip install thop -i https://pypi.tuna.tsinghua.edu.cn/simpleimport torchimport torch.nn as nnfrom thop import profileclass Simple(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(10, 10) def forward(self, x): x = self.fc1(x) return xnet = Simple()input = torch.randn(1, 10) # batchsize=1, 输入向量长度为10macs, params = profile(net, inputs=(input, ))print(' FLOPs: ', macs*2) # 一般来讲,FLOPs是macs的两倍print('params: ', params)3、fvcore pip install fvcore -i https://pypi.tuna.tsinghua.edu.cn/simple

用它比较好

import torchfrom torchvision.models import resnet50from fvcore.nn import FlopCountAnalysis, parameter_count_table# 创建resnet50网络model = resnet50(num_classes=1000)# 创建输入网络的tensortensor = (torch.rand(1, 3, 224, 224),)# 分析FLOPsflops = FlopCountAnalysis(model, tensor)print("FLOPs: ", flops.total())# 分析parametersprint(parameter_count_table(model))

 终端输出结果如下,FLOPs为4089184256,模型参数数量约为25.6M(这里的参数数量和我自己计算的有些出入,主要是在BN模块中,这里只计算了beta和gamma两个训练参数,没有统计moving_mean和moving_var两个参数),具体可以看下我在官方提的issue。 通过终端打印的信息我们可以发现在计算FLOPs时并没有包含BN层,池化层还有普通的add操作(我发现计算FLOPs时并没有统一的规定,在github上看的计算FLOPs项目基本每个都不同,但计算出来的结果大同小异)。

注意:在使用fvcore模块计算模型的flops时,遇到了问题,记录一下解决方案。首先是在jit_analysis.py的589行出错。经过调试发现,op_counts.values()的类型是int32,但是计算要求的类型只能是int、float、np.float64和np.int64,因此需要手动进行强制转换。修改如下:

4、flops_counterpip install ptflops -i https://pypi.tuna.tsinghua.edu.cn/simple

用它也很好,结果和fvcore一样

from ptflops import get_model_complexity_infomacs, params = get_model_complexity_info(model, (112, 9, 9), as_strings=True, print_per_layer_stat=True, verbose=True)print('{:<30} {:<8}'.format('Computational complexity: ', macs))print('{:<30} {:<8}'.format('Number of parameters: ', params))

5、自定义统计函数import torchimport numpy as npdef calc_flops(model, input): def conv_hook(self, input, output): batch_size, input_channels, input_height, input_width = input[0].size() output_channels, output_height, output_width = output[0].size() kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * ( 2 if multiply_adds else 1) bias_ops = 1 if self.bias is not None else 0 params = output_channels * (kernel_ops + bias_ops) flops = batch_size * params * output_height * output_width list_conv.append(flops) def linear_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 num_steps = input[0].size(0) weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) bias_ops = self.bias.nelement() if self.bias is not None else 0 flops = batch_size * (weight_ops + bias_ops) flops *= num_steps list_linear.append(flops) def fsmn_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 weight_ops = self.filter.nelement() * (2 if multiply_adds else 1) num_steps = input[0].size(0) flops = num_steps * weight_ops flops *= batch_size list_fsmn.append(flops) def gru_cell(input_size, hidden_size, bias=True): total_ops = 0 # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ state_ops = (hidden_size + input_size) * hidden_size + hidden_size if bias: state_ops += hidden_size * 2 total_ops += state_ops * 2 # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ total_ops += (hidden_size + input_size) * hidden_size + hidden_size if bias: total_ops += hidden_size * 2 # r hadamard : r * (~) total_ops += hidden_size # h' = (1 - z) * n + z * h # hadamard hadamard add total_ops += hidden_size * 3 return total_ops def gru_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 if self.batch_first: batch_size = input[0].size(0) num_steps = input[0].size(1) else: batch_size = input[0].size(1) num_steps = input[0].size(0) total_ops = 0 bias = self.bias input_size = self.input_size hidden_size = self.hidden_size num_layers = self.num_layers total_ops = 0 total_ops += gru_cell(input_size, hidden_size, bias) for i in range(num_layers - 1): total_ops += gru_cell(hidden_size, hidden_size, bias) total_ops *= batch_size total_ops *= num_steps list_lstm.append(total_ops) def lstm_cell(input_size, hidden_size, bias): total_ops = 0 state_ops = (input_size + hidden_size) * hidden_size + hidden_size if bias: state_ops += hidden_size * 2 total_ops += state_ops * 4 total_ops += hidden_size * 3 total_ops += hidden_size return total_ops def lstm_hook(self, input, output): batch_size = input[0].size(0) if input[0].dim() == 2 else 1 if self.batch_first: batch_size = input[0].size(0) num_steps = input[0].size(1) else: batch_size = input[0].size(1) num_steps = input[0].size(0) total_ops = 0 bias = self.bias input_size = self.input_size hidden_size = self.hidden_size num_layers = self.num_layers total_ops = 0 total_ops += lstm_cell(input_size, hidden_size, bias) for i in range(num_layers - 1): total_ops += lstm_cell(hidden_size, hidden_size, bias) total_ops *= batch_size total_ops *= num_steps list_lstm.append(total_ops) def bn_hook(self, input, output): list_bn.append(input[0].nelement()) def relu_hook(self, input, output): list_relu.append(input[0].nelement()) def pooling_hook(self, input, output): batch_size, input_channels, input_height, input_width = input[0].size() output_channels, output_height, output_width = output[0].size() kernel_ops = self.kernel_size * self.kernel_size bias_ops = 0 params = output_channels * (kernel_ops + bias_ops) flops = batch_size * params * output_height * output_width list_pooling.append(flops) def foo(net): childrens = list(net.children()) if not childrens: print(net) if isinstance(net, torch.nn.Conv2d) or isinstance(net, torch.nn.ConvTranspose2d): net.register_forward_hook(conv_hook) # print('conv_hook_ready') if isinstance(net, torch.nn.Linear): net.register_forward_hook(linear_hook) # print('linear_hook_ready') if isinstance(net, torch.nn.BatchNorm2d): net.register_forward_hook(bn_hook) # print('batch_norm_hook_ready') if isinstance(net, torch.nn.ReLU) or isinstance(net, torch.nn.PReLU): net.register_forward_hook(relu_hook) # print('relu_hook_ready') if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): net.register_forward_hook(pooling_hook) # print('pooling_hook_ready') if isinstance(net, torch.nn.LSTM): net.register_forward_hook(lstm_hook) # print('lstm_hook_ready') if isinstance(net, torch.nn.GRU): net.register_forward_hook(gru_hook) # if isinstance(net, FSMNZQ): # net.register_forward_hook(fsmn_hook) # print('fsmn_hook_ready') return for c in childrens: foo(c) multiply_adds = False list_conv, list_bn, list_relu, list_linear, list_pooling, list_lstm, list_fsmn = [], [], [], [], [], [], [] foo(model) _ = model(input) total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum( list_lstm) + sum(list_fsmn)) fsmn_flops = (sum(list_fsmn) + sum(list_linear)) lstm_flops = sum(list_lstm) model_parameters = filter(lambda p: p.requires_grad, model.parameters()) params = sum([np.prod(p.size()) for p in model_parameters]) print('The network has {} params.'.format(params)) print(total_flops, fsmn_flops, lstm_flops) print(' + Number of FLOPs: %.2f M' % (total_flops / 1000 ** 2)) return total_flopsif __name__ == '__main__': from torchvision.models import resnet18 model = resnet18(num_classes=1000) imput_size = torch.rand((1,3,224,224)) calc_flops(model, imput_size)

本文链接地址:https://www.jiuchutong.com/zhishi/299382.html 转载请保留说明!

上一篇:c++STL急急急(c++stl详解)

下一篇:40个web前端实战项目,练完即可就业,从入门到进阶,基础到框架,html_css【附视频+源码】(web前端视频教程全套)

  • 口罩面容id怎么设置(口罩face id)

    口罩面容id怎么设置(口罩face id)

  • 网易云一起听会显示在播放列表嘛(网易云一起听会被别人看见吗)

    网易云一起听会显示在播放列表嘛(网易云一起听会被别人看见吗)

  • 红米root权限怎么开启(红米root权限怎么开启MIUI13)

    红米root权限怎么开启(红米root权限怎么开启MIUI13)

  • 美团外卖怎么开会员(美团外卖怎么开通先吃后付)

    美团外卖怎么开会员(美团外卖怎么开通先吃后付)

  • 阵列cad快捷键(阵列cad快捷键2014)

    阵列cad快捷键(阵列cad快捷键2014)

  • qq语音麦克风白色是打开吗(qq麦克风是白色状态静音吗)

    qq语音麦克风白色是打开吗(qq麦克风是白色状态静音吗)

  • 耳机插在充电口无反应(耳机插在充电口没声音)

    耳机插在充电口无反应(耳机插在充电口没声音)

  • 路由器上光纤显示红色是什么意思(路由器上光纤显示红色怎么办)

    路由器上光纤显示红色是什么意思(路由器上光纤显示红色怎么办)

  • 电脑打不上数字怎么回事(为什么电脑打不上数字)

    电脑打不上数字怎么回事(为什么电脑打不上数字)

  • 联通物联卡lte怎么改成4g(联通物联卡lte怎么开热点)

    联通物联卡lte怎么改成4g(联通物联卡lte怎么开热点)

  • 内部服务器错误怎么办(内部服务器错误是网站问题还是电脑问题)

    内部服务器错误怎么办(内部服务器错误是网站问题还是电脑问题)

  • v1936a是什么手机型号(v2048a是什么手机)

    v1936a是什么手机型号(v2048a是什么手机)

  • iphone11合约机是什么意思(苹果11合约机能买吗)

    iphone11合约机是什么意思(苹果11合约机能买吗)

  • oppo手机录屏怎么录声音(oppo手机录屏怎么结束)

    oppo手机录屏怎么录声音(oppo手机录屏怎么结束)

  • 苹果手机怎样关闭浮点(苹果手机怎样关闭后运行程序功能)

    苹果手机怎样关闭浮点(苹果手机怎样关闭后运行程序功能)

  • 熊猫烧香是勒索病毒吗

    熊猫烧香是勒索病毒吗

  • ios13出了吗(苹果出13了吗)

    ios13出了吗(苹果出13了吗)

  • vivox27息屏时钟怎么关闭

    vivox27息屏时钟怎么关闭

  • 微博支付在哪里(微博的微博支付在哪)

    微博支付在哪里(微博的微博支付在哪)

  • 韩版s10和国行的区别(韩版三星s10跟国行区别很大么)

    韩版s10和国行的区别(韩版三星s10跟国行区别很大么)

  • 蚂蚁森林怎么加好友(蚂蚁森林怎么加不了好友)

    蚂蚁森林怎么加好友(蚂蚁森林怎么加不了好友)

  • 水印相机如何录制视频(水印相机如何录音)

    水印相机如何录制视频(水印相机如何录音)

  • jetcar.exe - jetcar是什么进程 有什么作用

    jetcar.exe - jetcar是什么进程 有什么作用

  • 瑞吉外卖项目:编辑员工信息与公共字段自动填充(瑞吉外卖项目简历)

    瑞吉外卖项目:编辑员工信息与公共字段自动填充(瑞吉外卖项目简历)

  • IP协议+以太网协议(ip和以太网的区别)

    IP协议+以太网协议(ip和以太网的区别)

  • 其他权益工具账务处理内容
  • 增值税发票是真发票,但是平台查验不到
  • 分派现金股利需要缴税吗
  • 所得税汇算清缴后发现有误怎么办
  • 现金股利和现金利润的区别
  • 政府补贴营业外收入所得税汇算清缴需要调增吗
  • 工程项目结算方式有哪几种
  • 销售折扣怎么开
  • 可供出售金融资产和交易性金融资产
  • 我国现行资源税的课税范围不包括
  • 增值税申报和开票不一致怎么做账
  • 税收返还怎么做会计分录
  • 汽车固定资产清理账务处理
  • 企业为职工社保补缴怎么办理
  • 一般纳税人进项税额转出会计分录
  • 财税[2012]15
  • 申报文件解密失败怎么办?
  • 2018年终奖个人所得税计算器公式
  • 单位代扣代缴个人社保
  • 保险公司赔偿计入营业外收入
  • 广告公司怎样
  • 个体户进项发票多开出发票少怎么办
  • 资产负债率70%说明长期偿债能力
  • 土地增值税清算后补缴税款如何帐务处理
  • 工程质保金扣除
  • 政府奖励如何记账
  • 应收账款坏账收回会计处理
  • 小规模增值税免税额
  • 10万以下免征增值税 文件
  • 车辆保险返点计算器
  • 个体工商户申报流程图
  • 机器用油怎么做成的
  • 美团代收是什么意思
  • 电脑游戏没法玩怎么办
  • 如何使用php
  • 成本核算的基本程序是什么
  • win11怎么清理电脑垃圾
  • PHP:pg_fetch_result()的用法_PostgreSQL函数
  • 处置抵债资产的增值税计入
  • 企业购入软件会计分录
  • 无法支付的货款如何处理
  • php 无限级分类
  • 圣米歇尔山法语介
  • 国内旅客运输服务发票
  • vue property decorator
  • 李牧其人
  • 模型参数是什么意思
  • 买配件组装成产品算生产吗
  • 餐饮店库存盘点表
  • 股东出资不足需要赔偿吗
  • 增值税申报系统登录密码
  • sql实例命名规则
  • 刷pos机的如何记会计分录
  • 收回已冲销的应收账款会计分录
  • 房产税中出租房产原值怎么算
  • sql server干嘛的
  • SQL server 2008中的数据库能否只包含数据文件
  • 新租赁准则承租人租金用什么科目
  • 购买方收到的违约金
  • 一般纳税人金税盘280怎么做账
  • 税控盘有什么作用
  • 原材料不足
  • 民间非营利组织包括哪些单位
  • 应收账款和预收账款都是企业的流动资产
  • 福利部门的福利有哪些
  • 商业企业购进商品的分录
  • win9怎么升级win10
  • win10预览体验三个选项
  • url是什么格式的文件怎么打开
  • winxp不能正常启动
  • jquery左滑切换
  • nodejs文件上传服务器
  • unity3D游戏开发
  • windows运行bat文件命令
  • expressjs中文
  • jquery div innerhtml
  • jquery教程chm
  • 成都网上税务局
  • 税务局追缴社保流程及办理期限
  • 从哪里可以免费听歌
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设