位置: IT常识 - 正文

pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch)

编辑:rootadmin
pytorch 笔记:torch.distributions 概率分布相关(更新中) 1 包介绍

推荐整理分享pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch1.5对应的torchvision,pytorch with torch.nograd,pytorch torch,pytorch torch.load,pytorch torchscript,pytorch torch,pytorch torchvision,pytorch torchvision,内容如对您有帮助,希望把文章链接给更多的朋友!

        torch.distributions包包含可参数化的概率分布和采样函数。 这允许构建用于优化的随机计算图和随机梯度估计器。

        不可能通过随机样本直接反向传播。 但是,有两种主要方法可以创建可以反向传播的代理函数。

这些是

评分函数估计量 score function estimato似然比估计量 likelihood ratio estimatorREINFORCE路径导数估计量 pathwise derivative estimator

REINFORCE 通常被视为强化学习中策略梯度方法的基础,

路径导数估计器常见于变分自编码器的重新参数化技巧中。

        虽然评分函数只需要样本 f(x)的值,但路径导数需要导数 f'(x)。、

1.1 REINFORCE

        我们以reinforce 为例:

        当概率密度函数关于其参数可微时,我们只需要 sample() 和 log_prob() 来实现 REINFORCE:

        

        其中θ是参数,α是学习率,r是奖励,是在状态s的时候,根据策略使用动作a的概率

        (这个也就是policy gradient)

强化学习笔记:Policy-based Approach_UQI-LIUWJ的博客-CSDN博客

         在实践中,我们会从网络的输出中采样一个动作,在一个环境中应用这个动作,然后使用 log_prob 构造一个等效的损失函数。

         对于分类策略,实现 REINFORCE 的代码如下:(这只是一个示意代码,跑不起来的)

probs = policy_network(state)#在状态state的时候,各个action的概率m = Categorical(probs)#分类概率action = m.sample()#采样一个actionnext_state, reward = env.step(action)#这里为了简化考虑,一个episode只有一个actionloss = -m.log_prob(action) * reward#m.log_prob(action) 就是 logp#reward就是前面的r#这里用负号是因为强化学习是梯度上升loss.backward()  2 包所涉及的类2.1 伯努利分布

torch.distributions.bernoulli.Bernoulli( probs=None, logits=None, validate_args=None)

        创建由 probs 或 logits(但不是两者同时)参数化的伯努利分布。

        样本是二进制的(0 或 1)。 它们取值 1 的概率为 p,取值 0 的概率为 1 - p。

2.1.1 参数probs (Number,Tensor) 采样概率logits (Number,Tensor) 采样的对数几率2.1.2 函数 & 属性sample()

采样,默认采样一个值

还可以按照shape 采样

entropy()

计算熵

enumerate_support()

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean

均值

probs, logits两个输入的参数param_shape

参数的形状

variance

方差

2.2 贝塔分布torch.distributions.beta.Beta( concentration1, concentration0, validate_args=None)

由concentration 1 (α)和concentration 0 (β)参数化的 Beta 分布。

 2.2.1 函数采样

默认是采样一个值,也可以设置采样的维数

entropy

计算熵

rsample(sample_shape)pytorch 笔记:torch.distributions 概率分布相关(更新中)(pytorch torch)

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

注:生成Beta分布的时候,两个参数必须至少有一个是Tensor,否则rsample效果失效

mean,variance

均值 & 方差

 2.3 Chi2 分布torch.distributions.chi2.Chi2( df, validate_args=None)

 它只有sample一个函数 

2.4 连续伯努利

参数和伯努利很类似

torch.distributions.continuous_bernoulli.ContinuousBernoulli( probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)

请注意,与伯努利不同,这里的“probs”不对应于伯努利的“probs”,这里的“logits”不对应于伯努利的“logits”,但由于与伯努利的相似性,使用了相同的名称。 

2.4.1 函数sample还是采样cdf

返回以 value 计算的累积概率密度函数。

icdf

返回以 value 计算的逆累积密度/质量函数。

entropy

还是计算熵

rsample

如果分布参数是批处理的,则生成一个 sample_shape 形状的重新参数化样本或 sample_shape 形状的重新参数化样本批次。

和前面Beta分布类似,只有创建时参数为Tensor,才会有rsample效果

mean,variance均值 方差 2.5 二项分布torch.distributions.binomial.Binomial( total_count=1, probs=None, logits=None, validate_args=None)

 

         创建由 total_count 和 probs 或 logits(但不是两者)参数化的二项分布。 total_count 必须可以用 probs/logits 广播。

2.5.1 函数&参数sample

采样

 

100被广播到0,0.2,0.8,1 所以每次相当于是四个二项分布

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

mean,variance

均值,方差

2.6  分类分布torch.distributions.categorical.Categorical( probs=None, logits=None, validate_args=None)

 样本是来{0,...,K−1} 的整数,其中 K 是 probs.size(-1)。

2.6.1 函数sample采样entropy

enumerate_support

返回包含离散分布支持的所有值的张量。 结果将在维度 0 上枚举

2.6.2 注意:

创建分类分布时候的Tensor中元素的和可以不是1,最后归一化到1即可

import torchimport mathm=torch.distributions.Categorical(torch.Tensor([1,2,4]))m.enumerate_support()#tensor([0, 1, 2])m.probs#tensor([0.1429, 0.2857, 0.5714])3 log_probs

很多分类都有这样一个函数log_probs,我们就统一说一下

假设m是一个torch的分类,那么m.log_prob(action)相当于

probs.log()[0][action.item()].unsqueeze(0)

(对这个action的概率添加log操作) 

import torchimport mathm=torch.distributions.Categorical(torch.Tensor([1,2,4]))m.enumerate_support()#tensor([0, 1, 2])a=m.sample()a#tensor(2)m.probs#tensor([0.1429, 0.2857, 0.5714])m.probs.log()#tensor([-1.9459, -1.2528, -0.5596])m.log_prob(a)#tensor(-0.5596)m.probs.log()[a.item()]#tensor(-0.5596)
本文链接地址:https://www.jiuchutong.com/zhishi/297674.html 转载请保留说明!

上一篇:Vue的安装及使用教程【超详细图文教程】(vue的安装步骤)

下一篇:Segment Anything Model (SAM)——卷起来了,那个号称分割一切的CV大模型他来了(segment anything model模型 需要的配置)

  • 工商年报已报网上还查不出来
  • 月末结存材料的实际成本例题
  • 材料按实际成本计价时发出成本的计算方法有
  • 银行本票与银行本票存款的区别
  • 计提工资附什么单据
  • 费用先付款后收到发票做账
  • 电子承兑到期怎么操作流程
  • 新政府会计制度下属于负债类科目的是
  • 物流公司主营业务范围
  • 委托加工物资属于在产品吗
  • 制造企业需要设哪些部门
  • 税务机关如何防范关联企业涉税风险问题
  • 跨年销售收入退回增值税处理
  • 工程结算审核程序
  • 所得税汇算清缴分录怎么做
  • 厂区折旧
  • 企业所得税费用税率
  • 小规模纳税人转成一般纳税人条件
  • 企业被列为风险纳税人税控开票会显示什么
  • 填报企业年报
  • 广告业公司成立时的资金如何记账?
  • win7网络无连接
  • 华为mate50耳机孔和充电口一样吗
  • 鸿蒙大文件夹怎么设置透明度
  • 电脑怎么备份系统win7
  • 流动资产和非流动资产占比多少合适
  • 高新企业研发费用占销售收入的比例
  • 应收账款转让分录
  • 公司买了一辆二手汽车,怎么入账
  • linux 引导
  • linux创建一个文件并写入内容
  • 社会保险费缓缴政策
  • php精度丢失
  • 营业利润期末余额怎么算
  • yii框架教程
  • 计提房产税会计分录怎么做账
  • 个人所得税完整证明
  • 成功解决用英语怎么说
  • 企业转让固定资产增值税税率
  • sqlserver阻止保存要求重新
  • MicrosoftSQLserver2014可以卸载吗
  • sql server 2008使用教程
  • 跨年费用账务处理
  • 收到证券公司信息
  • 应付账款暂估会计处理
  • 固定资产的计提折旧方法有哪些
  • 将本月发生的制造费用在甲、乙产品之间
  • 临时账户过期了怎么办
  • 固定资产改建支出的扣除规定
  • 财政登记证取消了吗
  • 金税盘可以申请发票吗
  • 小规模纳税人怎么开增值税专用发票
  • 贴发票要按时间顺序吗
  • 代理记账许可证查询
  • 明细账的作用
  • sql Server 触发器的when的用法
  • mysql8.0无法启动
  • windowsxp电脑开机
  • win8硬盘重装
  • window10节电模式怎么关闭
  • ubuntu server initramfs
  • linux设置用户名和密码
  • mac如何用u盘安装win10
  • win10总是弹窗广告
  • directx?
  • macbookair扫描文件怎么弄
  • 手机游戏开发工具app
  • 又拍云cdn配置
  • javascript如何学
  • 如何用dos修复引导
  • unity游戏开发的技术
  • 安卓绘制图表
  • New AssetBundle build system in Unity 5.0
  • jquery示例
  • 如何理解计算消费税时的(1
  • 上海增值税怎么报税流程
  • 南京税务局几点下班?
  • 个人绩效考核税务局
  • 发票真伪查询国税官网12366
  • 营业外收入缴纳哪些税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设