位置: IT常识 - 正文

深度学习中模型计算量(FLOPs)和参数量(Params)的理解以及四种计算方法总结

编辑:rootadmin
深度学习中模型计算量(FLOPs)和参数量(Params)的理解以及四种计算方法总结

推荐整理分享深度学习中模型计算量(FLOPs)和参数量(Params)的理解以及四种计算方法总结,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

接下来要分别概述以下内容:

1 首先什么是参数量,什么是计算量

2 如何计算 参数量,如何统计 计算量

3 换算参数量,把他换算成我们常用的单位,比如:mb

4 对于各个经典网络,论述他们是计算量大还是参数量,有什么好处

5 计算量,参数量分别对显存,芯片提出什么要求,我们又是怎么权衡   

深度学习中模型参数量和计算量的理解与计算1 首先什么是计算量,什么是参数量2 如何计算:参数量,计算量3 对于换算计算量4 对于各个经典网络:5 计算量与参数量对于硬件要求6 计算量(FLOPs)和参数量(Params)6.1 第一种方法:thop第一步:安装模块第二步:计算6.2 第二种方法:ptflops6.3 第三种方法:pytorch_model_summary6.4 第四种方法:参数总量和可训练参数总量7 输入数据对模型的参数量和计算量的影响参考资料1 首先什么是计算量,什么是参数量

计算量对应我们之前的时间复杂度,参数量对应于我们之前的空间复杂度,这么说就很明显了

也就是计算量要看网络执行时间的长短,参数量要看占用显存的量

2 如何计算:参数量,计算量

(1)针对于卷积层的   其中上面的公式是计算时间复杂度(计算量),而下面的公式是计算空间复杂度(参数量)

对于卷积层:

参数量就是

(kernel*kernel) *channel_input*channel_outputkernel*kernel 就是 weight * weight其中kernel*kernel = 1个feature的参数量

计算量就是

(kernel*kernel*map*map) *channel_input*channel_outputkernel*kernel 就是weight*weightmap*map是下个featuremap的大小,也就是上个weight*weight到底做了多少次运算其中kernel*kernel*map*map= 1个feature的计算量

(2)针对于池化层:

无参数

(3)针对于全连接层:

参数量=计算量=weight_in*weight_out3 对于换算计算量

一般一个参数是值一个float,也就是4个字节

1kb=1024字节

4 对于各个经典网络:

(1)换算

深度学习中模型计算量(FLOPs)和参数量(Params)的理解以及四种计算方法总结

以alexnet为例:

参数量:6000万

设每个参数都是float,也就是一个参数是4字节,

总的字节数是24000万字节

24000万字节= 24000万/1024/1024=228mb

(2)为什么模型之间差距这么大

这个关乎于模型的设计了,其中模型里面最费参数的就是全连接层,这个可以看alex和vgg,

alex,vgg有很多fc(全连接层)

resnet就一个fc

inceptionv1(googlenet)也是就一个fc

(3)计算量

densenet其实这个模型不大,也就是参数量不大,因为就1个fc

但是他的计算量确实很大,因为每一次都把上一个feature加进来,所以计算量真的很大

5 计算量与参数量对于硬件要求

计算量,参数量对于硬件的要求是不同的

计算量的要求是在于芯片的floaps(指的是gpu的运算能力)

参数量取决于显存大小

6 计算量(FLOPs)和参数量(Params)6.1 第一种方法:thop

计算量: FLOPs,FLOP时指浮点运算次数,s是指秒,即每秒浮点运算次数的意思,考量一个网络模型的计算量的标准。

参数量: Params,是指网络模型中需要训练的参数总数。

第一步:安装模块pip install thop第二步:计算# -- coding: utf-8 --import torchimport torchvisionfrom thop import profile# Modelprint('==> Building model..')model = torchvision.models.alexnet(pretrained=False)dummy_input = torch.randn(1, 3, 224, 224)flops, params = profile(model, (dummy_input,))print('flops: ', flops, 'params: ', params)print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))

结果

==> Building model..[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.[WARN] Cannot find rule for <class 'torchvision.models.alexnet.AlexNet'>. Treat it as zero Macs and zero Params.flops: 714691904.0 params: 61100840.0flops: 714.69 M, params: 61.10 M

注意:

输入input的第一维度是批量(batch size),批量的大小不回影响参数量, 计算量是batch_size=1的倍数profile(net, (inputs,))的 (inputs,)中必须加上逗号,否者会报错6.2 第二种方法:ptflops# -- coding: utf-8 --import torchvisionfrom ptflops import get_model_complexity_infomodel = torchvision.models.alexnet(pretrained=False)flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)print('flops: ', flops, 'params: ', params)

结果

AlexNet( 61.101 M, 100.000% Params, 0.716 GMac, 100.000% MACs, (features): Sequential( 2.47 M, 4.042% Params, 0.657 GMac, 91.804% MACs, (0): Conv2d(0.023 M, 0.038% Params, 0.07 GMac, 9.848% MACs, 3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)) (1): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.027% MACs, inplace=True) (2): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.027% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) (3): Conv2d(0.307 M, 0.503% Params, 0.224 GMac, 31.316% MACs, 64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (4): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.020% MACs, inplace=True) (5): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.020% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) (6): Conv2d(0.664 M, 1.087% Params, 0.112 GMac, 15.681% MACs, 192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (7): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.009% MACs, inplace=True) (8): Conv2d(0.885 M, 1.448% Params, 0.15 GMac, 20.902% MACs, 384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (9): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.006% MACs, inplace=True) (10): Conv2d(0.59 M, 0.966% Params, 0.1 GMac, 13.936% MACs, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.006% MACs, inplace=True) (12): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.006% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) ) (avgpool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.001% MACs, output_size=(6, 6)) (classifier): Sequential( 58.631 M, 95.958% Params, 0.059 GMac, 8.195% MACs, (0): Dropout(0.0 M, 0.000% Params, 0.0 GMac, 0.000% MACs, p=0.5, inplace=False) (1): Linear(37.753 M, 61.788% Params, 0.038 GMac, 5.276% MACs, in_features=9216, out_features=4096, bias=True) (2): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.001% MACs, inplace=True) (3): Dropout(0.0 M, 0.000% Params, 0.0 GMac, 0.000% MACs, p=0.5, inplace=False) (4): Linear(16.781 M, 27.465% Params, 0.017 GMac, 2.345% MACs, in_features=4096, out_features=4096, bias=True) (5): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.001% MACs, inplace=True) (6): Linear(4.097 M, 6.705% Params, 0.004 GMac, 0.573% MACs, in_features=4096, out_features=1000, bias=True) ))flops: 0.72 GMac params: 61.1 M6.3 第三种方法:pytorch_model_summaryimport torchimport torchvisionfrom pytorch_model_summary import summary# Modelprint('==> Building model..')model = torchvision.models.alexnet(pretrained=False)dummy_input = torch.randn(1, 3, 224, 224)print(summary(model, dummy_input, show_input=False, show_hierarchical=False))

结果

==> Building model..----------------------------------------------------------------------------- Layer (type) Output Shape Param # Tr. Param #============================================================================= Conv2d-1 [1, 64, 55, 55] 23,296 23,296 ReLU-2 [1, 64, 55, 55] 0 0 MaxPool2d-3 [1, 64, 27, 27] 0 0 Conv2d-4 [1, 192, 27, 27] 307,392 307,392 ReLU-5 [1, 192, 27, 27] 0 0 MaxPool2d-6 [1, 192, 13, 13] 0 0 Conv2d-7 [1, 384, 13, 13] 663,936 663,936 ReLU-8 [1, 384, 13, 13] 0 0 Conv2d-9 [1, 256, 13, 13] 884,992 884,992 ReLU-10 [1, 256, 13, 13] 0 0 Conv2d-11 [1, 256, 13, 13] 590,080 590,080 ReLU-12 [1, 256, 13, 13] 0 0 MaxPool2d-13 [1, 256, 6, 6] 0 0 AdaptiveAvgPool2d-14 [1, 256, 6, 6] 0 0 Dropout-15 [1, 9216] 0 0 Linear-16 [1, 4096] 37,752,832 37,752,832 ReLU-17 [1, 4096] 0 0 Dropout-18 [1, 4096] 0 0 Linear-19 [1, 4096] 16,781,312 16,781,312 ReLU-20 [1, 4096] 0 0 Linear-21 [1, 1000] 4,097,000 4,097,000=============================================================================Total params: 61,100,840Trainable params: 61,100,840Non-trainable params: 0-----------------------------------------------------------------------------6.4 第四种方法:参数总量和可训练参数总量import torchimport torchvisionfrom pytorch_model_summary import summary# Modelprint('==> Building model..')model = torchvision.models.alexnet(pretrained=False)pytorch_total_params = sum(p.numel() for p in model.parameters())trainable_pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)print('Total - ', pytorch_total_params)print('Trainable - ', trainable_pytorch_total_params)

结果

==> Building model..Total - 61100840Trainable - 611008407 输入数据对模型的参数量和计算量的影响# -- coding: utf-8 --import torchimport torchvisionfrom thop import profile# Modelprint('==> Building model..')model = torchvision.models.alexnet(pretrained=False)dummy_input = torch.randn(1, 3, 224, 224)flops, params = profile(model, (dummy_input,))print('flops: ', flops, 'params: ', params)print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))输入数据:(1, 3, 224, 224),一张224*224的RGB图像flops: 714691904.0 params: 61100840.0flops: 714.69 M, params: 61.10 M输入数据:(1, 3, 512, 512),一张512*512的RGB图像flops: 3710034752.0 params: 61100840.0flops: 3710.03 M params: 61.10 M输入数据:(8, 3, 224, 224),八张224*224的RGB图像flops: 5717535232.0 params: 61100840.0flops: 5717.54 M params: 61.10 M输入数据计算量(flops)参数量(params)(1, 3, 224, 224)714.69 M61.10 M(1, 3, 512, 512)3710.03 M61.10 M(8, 3, 224, 224)5717.54 M61.10 M参考资料https://www.cnblogs.com/lllcccddd/p/10671879.htmlhttps://blog.csdn.net/Caesar6666/article/details/109842379
本文链接地址:https://www.jiuchutong.com/zhishi/282703.html 转载请保留说明!

上一篇:WIN10电脑C盘满了如何清理空间(win10电脑c盘满了怎么转移到d盘)

下一篇:imonnt.exe进程是什么文件 是病毒程序吗 imonnt进程查询(进程mmc.exe)

  • qq怎么设置好友互动标识(qq怎么设置好友生日提醒)

    qq怎么设置好友互动标识(qq怎么设置好友生日提醒)

  • 抖音可以删掉共同关系吗(抖音共获得赞怎么删)

    抖音可以删掉共同关系吗(抖音共获得赞怎么删)

  • 华为p40pro后盖是什么材质(华为p40pro后盖是否原装)

    华为p40pro后盖是什么材质(华为p40pro后盖是否原装)

  • 微信被限制收款是为什么(微信被限制收款功能怎么解除)

    微信被限制收款是为什么(微信被限制收款功能怎么解除)

  • 可以打电话的平板有哪些(可以打电话的平板多少钱)

    可以打电话的平板有哪些(可以打电话的平板多少钱)

  • 快手怎么删除已购订单(快手怎么删除已发视频)

    快手怎么删除已购订单(快手怎么删除已发视频)

  • 微信视频开美颜怎么开(微信视频开美颜对方看到的是美颜效果吗)

    微信视频开美颜怎么开(微信视频开美颜对方看到的是美颜效果吗)

  • 小米笔记本连wifi不能上网(小米笔记本连wifi总显示不能用)

    小米笔记本连wifi不能上网(小米笔记本连wifi总显示不能用)

  • 华为专用铃声叫什么歌(华为专用铃声叫什么歌中文)

    华为专用铃声叫什么歌(华为专用铃声叫什么歌中文)

  • 华为数据线通用吗(华为的数据线通用吗)

    华为数据线通用吗(华为的数据线通用吗)

  • 哔哩哔哩怎么升级lv1(哔哩哔哩怎么升级lv2)

    哔哩哔哩怎么升级lv1(哔哩哔哩怎么升级lv2)

  • qq达人显示0天是隐身吗(qq达人一直是0天)

    qq达人显示0天是隐身吗(qq达人一直是0天)

  • aqmtl00是什么型号手机

    aqmtl00是什么型号手机

  • 数据库安全包括哪两个方面(数据库安全包括两方面问题,即)

    数据库安全包括哪两个方面(数据库安全包括两方面问题,即)

  • 苹果平板可以用普通耳机吗(苹果平板可以用鼠标吗)

    苹果平板可以用普通耳机吗(苹果平板可以用鼠标吗)

  • 如何解除手机通话限制(如何解除手机通话录音)

    如何解除手机通话限制(如何解除手机通话录音)

  • 华为mate20怎么切后台(华为mate20怎么切换桌面滑动效果)

    华为mate20怎么切后台(华为mate20怎么切换桌面滑动效果)

  • 拼多多实名认证在哪看(拼多多实名认证可以认证几个号)

    拼多多实名认证在哪看(拼多多实名认证可以认证几个号)

  • 苹果13.1.2信号不好(苹果13信号不好?)

    苹果13.1.2信号不好(苹果13信号不好?)

  • 苹果11屏幕敲起来为什么像空的(苹果11屏幕敲起来空空的)

    苹果11屏幕敲起来为什么像空的(苹果11屏幕敲起来空空的)

  • 笔记本电脑睡眠后黑屏打不开(笔记本电脑睡眠不了是什么原因)

    笔记本电脑睡眠后黑屏打不开(笔记本电脑睡眠不了是什么原因)

  • qq怎么扫一扫一百元(qq扫一扫在哪里扫)

    qq怎么扫一扫一百元(qq扫一扫在哪里扫)

  • 手动双面打印怎么放纸(手动双面打印怎么操作)

    手动双面打印怎么放纸(手动双面打印怎么操作)

  • 网易云拉黑能私信么(网易云拉黑对方能发消息吗)

    网易云拉黑能私信么(网易云拉黑对方能发消息吗)

  • qq安全扫描失败禁止下载该文件(qq安全扫描失败无法下载怎么解决2022)

    qq安全扫描失败禁止下载该文件(qq安全扫描失败无法下载怎么解决2022)

  • 红掌的养殖方法和注意事项(图文)(红掌的养殖方法和注意事项)

    红掌的养殖方法和注意事项(图文)(红掌的养殖方法和注意事项)

  • 一般纳税人抵扣小规模期间的专票怎么解决
  • 已认证的专票发票在哪里
  • 固定资产报废电脑
  • 退休职工怎么填写单位吗
  • 收回个人社会保险费是否可以冲红管理费用
  • 公司现金支付管理办法
  • 企业前期开办费没有发票怎么入账
  • 开具发票时如何选择对应的商品分类编码?
  • 价外补贴需要交增值税吗
  • 房地产企业销售额排名
  • 海关免税设备清单
  • 国有独资企业董事会应当在每年
  • 销售产生的增值税
  • 员工领取产假工资怎么算
  • 付款小于发票金额的原因
  • 地税三方协议是什么意思
  • 农产品收购发票使用范围
  • 防伪税控技术维护费是进项还是销项
  • 房租怎么开票
  • 个人可以到税务局来取消办税人员信息吗?
  • 全资子公司合并抵消
  • 作废发票如何管理
  • 金蝶kis迷你版操作手册
  • 小规模纳税人应交增值税科目设置
  • 自产农产品销售怎么做账
  • 两处拿工资的缴税问题
  • 因为质量问题
  • 允许扣除的土地价款怎么计算例题
  • 去税局代开开专用发票需要带什么证件?
  • 应收账款核销如何做账
  • 收到销售折让销售怎么做
  • 住房补贴缴纳比例是多少
  • php常见面试问题
  • typecho插件开发教程
  • php 文件操作
  • php魔术方法的讲解与使用
  • 公司试乘试驾车管理
  • 政策性搬迁的会计处理
  • 关于已开发票收到部分款项风险温馨提示
  • php readdir函数
  • 最高像素的镜头是多少
  • 交纳印花税
  • 灰狼算法的改进
  • 印度泰姬陵建筑
  • 联邦学习攻击与防御综述
  • php第三方支付
  • php对象是值传递还是引用传递
  • 每年结息一次,到期一次还本是单利
  • 如何用织梦在本地搭建网站
  • 项目资金支付
  • mysql5.7压缩包安装配置教程
  • python 添加列表
  • 残疾人就业保障金是什么意思啊
  • 税法中对差旅费的处理
  • 现金日记账里
  • 购入不需要安装的固定资产会计科目
  • 收到对公打款认证怎么入账
  • 怎么理解什么是生命
  • 企业实收资本怎么计算
  • 计提坏账准备需要哪些资料
  • sqlserver2005附加数据库错误1827
  • win8系统之家官网
  • vista win
  • 苹果mac怎样
  • administrator帐户已锁定
  • jQuery实现别踩白块儿网页版小游戏
  • linux tcp keepalive
  • dos命令/s
  • 分享一下什么
  • vue框架写淘宝购物车
  • linux shell 中 2>&1的含义
  • 批处理实现语音报警
  • js按下回车键时提交
  • android xmlns
  • 湖南省电子国税
  • 重庆税筹公司
  • 欧美 房产税
  • 内蒙古国家税务总局电子税务局官网
  • 怎么查询小米手机位置
  • 2021北京餐饮业发展趋势报告
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设