位置: IT常识 - 正文

深度强化学习-DQN算法原理与代码

编辑:rootadmin
深度强化学习-DQN算法原理与代码

推荐整理分享深度强化学习-DQN算法原理与代码,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

DQN算法是DeepMind团队提出的一种深度强化学习算法,在许多电动游戏中达到人类玩家甚至超越人类玩家的水准,本文就带领大家了解一下这个算法,论文和代码的链接见下方。

论文:Human-level control through deep reinforcement learning | Nature

代码:https://github.com/indigoLovee/DQN

喜欢的话可以点个star呢。

1 DQN算法简介

Q-learning算法采用一个Q-tabel来记录每个状态下的动作值,当状态空间或动作空间较大时,需要的存储空间也会较大。如果状态空间或动作空间连续,则该算法无法使用。因此,Q-learning算法只能用于解决离散低维状态空间和动作空间类问题。DQN算法的核心就是用一个人工神经网络来代替Q-tabel,即动作价值函数。网络的输入为状态信息,输出为每个动作的价值,因此DQN算法可以用来解决连续状态空间和离散动作空间问题,无法解决连续动作空间类问题。针对连续动作空间类问题,后面blog会慢慢介绍。

2 DQN算法原理

DQN算法是一种off-policy算法,当同时出现异策、自益和函数近似时,无法保证收敛性,容易出现训练不稳定或训练困难等问题。针对这些问题,研究人员主要从以下两个方面进行了改进。

(1)经验回放:将经验(当前状态、动作、即时奖励、下个状态、回合状态)存放在经验池中,并按照一定的规则采样。

(2)目标网络:修改网络的更新方式,例如不把刚学习到的网络权重马上用于后续的自益过程。

2.1 经验回放

经验回放就是一种让经验概率分布变得稳定的技术,可以提高训练的稳定性。经验回放主要有“存储”和“回放”两大关键步骤:

存储:将经验以形式存储在经验池中。

回放:按照某种规则从经验池中采样一条或多条经验数据。

从存储的角度来看,经验回放可以分为集中式回放和分布式回放:

集中式回放:智能体在一个环境中运行,把经验统一存储在经验池中。

分布式回放:多个智能体同时在多个环境中运行,并将经验统一存储在经验池中。由于多个智能体同时生成经验,所以能够使用更多资源的同时更快地收集经验。

从采样的角度来看,经验回放可以分为均匀回放和优先回放:

均匀回放:等概率从经验池中采样经验。

优先回放:为经验池中每条经验指定一个优先级,在采样经验时更倾向于选择优先级更高的经验。一般的做法是,如果某条经验(例如经验)的优先级为,那么选取该经验的概率为:

优先回放可以具体参照这篇论文:优先经验回放

深度强化学习-DQN算法原理与代码

经验回放的优点:

1.在训练Q网络时,可以打破数据之间的相关性,使得数据满足独立同分布,从而减小参数更新的方差,提高收敛速度。

2.能够重复使用经验,数据利用率高,对于数据获取困难的情况尤其有用。

经验回放的缺点:

无法应用于回合更新和多步学习算法。但是将经验回放应用于Q学习,就规避了这个缺点。

代码中采用集中式均匀回放,具体如下:

import numpy as npclass ReplayBuffer: def __init__(self, state_dim, action_dim, max_size, batch_size): self.mem_size = max_size self.batch_size = batch_size self.mem_cnt = 0 self.state_memory = np.zeros((self.mem_size, state_dim)) self.action_memory = np.zeros((self.mem_size, )) self.reward_memory = np.zeros((self.mem_size, )) self.next_state_memory = np.zeros((self.mem_size, state_dim)) self.terminal_memory = np.zeros((self.mem_size, ), dtype=np.bool) def store_transition(self, state, action, reward, state_, done): mem_idx = self.mem_cnt % self.mem_size self.state_memory[mem_idx] = state self.action_memory[mem_idx] = action self.reward_memory[mem_idx] = reward self.next_state_memory[mem_idx] = state_ self.terminal_memory[mem_idx] = done self.mem_cnt += 1 def sample_buffer(self): mem_len = min(self.mem_size, self.mem_cnt) batch = np.random.choice(mem_len, self.batch_size, replace=True) states = self.state_memory[batch] actions = self.action_memory[batch] rewards = self.reward_memory[batch] states_ = self.next_state_memory[batch] terminals = self.terminal_memory[batch] return states, actions, rewards, states_, terminals def ready(self): return self.mem_cnt > self.batch_size2.2 目标网络

对于基于自益的Q学习,动作价值估计和权重有关。当权重变化时,动作价值的估计也会发生变化。在学习的过程中,动作价值试图追逐一个变化的回报,容易出现不稳定的情况。

目标网络是在原有的神经网络之外重新搭建一个结构完全相同的网络。原先的网络称为评估网络,新构建的网络称为目标网络。在学习过程中,使用目标网络进行自益得到回报的评估值,作为学习目标。在更新过程中,只更新评估网络的权重,而不更新目标网络的权重。这样,更新权重时针对的目标不会在每次迭代都发生变化,是一个固定的目标。在更新一定次数后,再将评估网络的权重复制给目标网络,进而进行下一批更新,这样目标网络也能得到更新。由于在目标网络没有变化的一段时间内回报的估计是相对固定的,因此目标网络的引入增加了学习的稳定性。

目标网络的更新方式:

上述在一段时间内固定目标网络,一定次数后将评估网络权重复制给目标网络的更新方式为硬更新(hard update),即

其中表示目标网络权重,表示评估网络权重。

另外一种常用的更新方式为软更新(soft update),即引入一个学习率,将旧的目标网络参数和新的评估网络参数直接做加权平均后的值赋值给目标网络

学习率

3 DQN算法伪代码

DQN算法的实现代码为:

import torch as Timport torch.nn as nnimport torch.optim as optimimport torch.nn.functional as Fimport numpy as npfrom buffer import ReplayBufferdevice = T.device("cuda:0" if T.cuda.is_available() else "cpu")class DeepQNetwork(nn.Module): def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim): super(DeepQNetwork, self).__init__() self.fc1 = nn.Linear(state_dim, fc1_dim) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.q = nn.Linear(fc2_dim, action_dim) self.optimizer = optim.Adam(self.parameters(), lr=alpha) self.to(device) def forward(self, state): x = T.relu(self.fc1(state)) x = T.relu(self.fc2(x)) q = self.q(x) return q def save_checkpoint(self, checkpoint_file): T.save(self.state_dict(), checkpoint_file, _use_new_zipfile_serialization=False) def load_checkpoint(self, checkpoint_file): self.load_state_dict(T.load(checkpoint_file))class DQN: def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim, ckpt_dir, gamma=0.99, tau=0.005, epsilon=1.0, eps_end=0.01, eps_dec=5e-4, max_size=1000000, batch_size=256): self.tau = tau self.gamma = gamma self.epsilon = epsilon self.eps_min = eps_end self.eps_dec = eps_dec self.batch_size = batch_size self.action_space = [i for i in range(action_dim)] self.checkpoint_dir = ckpt_dir self.q_eval = DeepQNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim, fc1_dim=fc1_dim, fc2_dim=fc2_dim) self.q_target = DeepQNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim, fc1_dim=fc1_dim, fc2_dim=fc2_dim) self.memory = ReplayBuffer(state_dim=state_dim, action_dim=action_dim, max_size=max_size, batch_size=batch_size) self.update_network_parameters(tau=1.0) def update_network_parameters(self, tau=None): if tau is None: tau = self.tau for q_target_params, q_eval_params in zip(self.q_target.parameters(), self.q_eval.parameters()): q_target_params.data.copy_(tau * q_eval_params + (1 - tau) * q_target_params) def remember(self, state, action, reward, state_, done): self.memory.store_transition(state, action, reward, state_, done) def choose_action(self, observation, isTrain=True): state = T.tensor([observation], dtype=T.float).to(device) actions = self.q_eval.forward(state) action = T.argmax(actions).item() if (np.random.random() < self.epsilon) and isTrain: action = np.random.choice(self.action_space) return action def learn(self): if not self.memory.ready(): return states, actions, rewards, next_states, terminals = self.memory.sample_buffer() batch_idx = np.arange(self.batch_size) states_tensor = T.tensor(states, dtype=T.float).to(device) rewards_tensor = T.tensor(rewards, dtype=T.float).to(device) next_states_tensor = T.tensor(next_states, dtype=T.float).to(device) terminals_tensor = T.tensor(terminals).to(device) with T.no_grad(): q_ = self.q_target.forward(next_states_tensor) q_[terminals_tensor] = 0.0 target = rewards_tensor + self.gamma * T.max(q_, dim=-1)[0] q = self.q_eval.forward(states_tensor)[batch_idx, actions] loss = F.mse_loss(q, target.detach()) self.q_eval.optimizer.zero_grad() loss.backward() self.q_eval.optimizer.step() self.update_network_parameters() self.epsilon = self.epsilon - self.eps_dec if self.epsilon > self.eps_min else self.eps_min def save_models(self, episode): self.q_eval.save_checkpoint(self.checkpoint_dir + 'Q_eval/DQN_q_eval_{}.pth'.format(episode)) print('Saving Q_eval network successfully!') self.q_target.save_checkpoint(self.checkpoint_dir + 'Q_target/DQN_Q_target_{}.pth'.format(episode)) print('Saving Q_target network successfully!') def load_models(self, episode): self.q_eval.load_checkpoint(self.checkpoint_dir + 'Q_eval/DQN_q_eval_{}.pth'.format(episode)) print('Loading Q_eval network successfully!') self.q_target.load_checkpoint(self.checkpoint_dir + 'Q_target/DQN_Q_target_{}.pth'.format(episode)) print('Loading Q_target network successfully!')

算法仿真环境是在gym库中的LunarLander-v2环境,因此需要先配置好gym库。进入Aanconda中对应的Python环境中,执行下面的指令

pip install gym

但是,这样安装的gym库只包括少量的内置环境,如算法环境、简单文字游戏环境和经典控制环境,无法使用LunarLander-v2。因此还要安装一些其他依赖项,具体可以参照这篇blog:AttributeError: module ‘gym.envs.box2d‘ has no attribute ‘LunarLander‘ 解决办法

训练脚本如下:

import gymimport numpy as npimport argparsefrom DQN import DQNfrom utils import plot_learning_curve, create_directoryparser = argparse.ArgumentParser()parser.add_argument('--max_episodes', type=int, default=500)parser.add_argument('--ckpt_dir', type=str, default='./checkpoints/DQN/')parser.add_argument('--reward_path', type=str, default='./output_images/avg_reward.png')parser.add_argument('--epsilon_path', type=str, default='./output_images/epsilon.png')args = parser.parse_args()def main(): env = gym.make('LunarLander-v2') agent = DQN(alpha=0.0003, state_dim=env.observation_space.shape[0], action_dim=env.action_space.n, fc1_dim=256, fc2_dim=256, ckpt_dir=args.ckpt_dir, gamma=0.99, tau=0.005, epsilon=1.0, eps_end=0.05, eps_dec=5e-4, max_size=1000000, batch_size=256) create_directory(args.ckpt_dir, sub_dirs=['Q_eval', 'Q_target']) total_rewards, avg_rewards, eps_history = [], [], [] for episode in range(args.max_episodes): total_reward = 0 done = False observation = env.reset() while not done: action = agent.choose_action(observation, isTrain=True) observation_, reward, done, info = env.step(action) agent.remember(observation, action, reward, observation_, done) agent.learn() total_reward += reward observation = observation_ total_rewards.append(total_reward) avg_reward = np.mean(total_rewards[-100:]) avg_rewards.append(avg_reward) eps_history.append(agent.epsilon) print('EP:{} reward:{} avg_reward:{} epsilon:{}'. format(episode + 1, total_reward, avg_reward, agent.epsilon)) if (episode + 1) % 50 == 0: agent.save_models(episode + 1) episodes = [i for i in range(args.max_episodes)] plot_learning_curve(episodes, avg_rewards, 'Reward', 'reward', args.reward_path) plot_learning_curve(episodes, eps_history, 'Epsilon', 'epsilon', args.epsilon_path)if __name__ == '__main__': main()

训练时还会用到画图函数和创建文件夹函数,我将他们另外放在一个utils.py脚本中,具体代码如下:

import osimport matplotlib.pyplot as pltdef plot_learning_curve(episodes, records, title, ylabel, figure_file): plt.figure() plt.plot(episodes, records, linestyle='-', color='r') plt.title(title) plt.xlabel('episode') plt.ylabel(ylabel) plt.show() plt.savefig(figure_file)def create_directory(path: str, sub_dirs: list): for sub_dir in sub_dirs: if os.path.exists(path + sub_dir): print(path + sub_dir + ' is already exist!') else: os.makedirs(path + sub_dir, exist_ok=True) print(path + sub_dir + ' create successfully!')

仿真结果如下图所示:

通过平均奖励曲线可以看出,大概迭代到400步左右时算法趋于收敛。

本文链接地址:https://www.jiuchutong.com/zhishi/299923.html 转载请保留说明!

上一篇:作用域和作用域链(作用域和作用域链的理解)

下一篇:21世纪20年代的ConvNet——ConvNeXt(21世纪20年代的中国)

  • 抖音怎么置顶关注(抖音怎么置顶关注的人评论)

    抖音怎么置顶关注(抖音怎么置顶关注的人评论)

  • 快手上怎么找到关注的人(快手上怎么找到微信好友)

    快手上怎么找到关注的人(快手上怎么找到微信好友)

  • iPhone11镜头要不要贴膜(苹果11镜头需要保护吗)

    iPhone11镜头要不要贴膜(苹果11镜头需要保护吗)

  • meat30和meat30pro的区别(华为meat30pro和meat30epro有什么区别)

    meat30和meat30pro的区别(华为meat30pro和meat30epro有什么区别)

  • 手机视频压缩怎么弄(手机视频压缩怎么压缩)

    手机视频压缩怎么弄(手机视频压缩怎么压缩)

  • oppo和vivo是一家吗(oppo和vivo是一家公司吗步步高)

    oppo和vivo是一家吗(oppo和vivo是一家公司吗步步高)

  • 怎样把几张照片合成一个文件(怎样把几张照片拼在一张)

    怎样把几张照片合成一个文件(怎样把几张照片拼在一张)

  • 微信红包能发多少(微信红包能发多少钱最多)

    微信红包能发多少(微信红包能发多少钱最多)

  • qq看点图片加载不出来怎么办(qq看点图片显示不全)

    qq看点图片加载不出来怎么办(qq看点图片显示不全)

  • 手机看快手卡顿怎么解决(手机看快手卡顿怎么解决华为)

    手机看快手卡顿怎么解决(手机看快手卡顿怎么解决华为)

  • 删除温控会有什么后果(删除温控会有什么影响)

    删除温控会有什么后果(删除温控会有什么影响)

  • 看qq精选照片有记录吗(qq精选照片有记录吗)

    看qq精选照片有记录吗(qq精选照片有记录吗)

  • 小米8是安卓几的版本(小米8是安卓吗)

    小米8是安卓几的版本(小米8是安卓吗)

  • 华为p40pro是3d人脸识别吗(p40pro有3d)

    华为p40pro是3d人脸识别吗(p40pro有3d)

  • 微机是什么意思(微机课是学什么的)

    微机是什么意思(微机课是学什么的)

  • 爱奇艺会员可以在银河奇异果用吗(爱奇艺会员可以投屏到电视上么)

    爱奇艺会员可以在银河奇异果用吗(爱奇艺会员可以投屏到电视上么)

  • 华为p9有红外线遥控吗(华为p9有红外线遥控器吗)

    华为p9有红外线遥控吗(华为p9有红外线遥控器吗)

  • 宾馆订后砍五折啥意思(酒店5折砍价)

    宾馆订后砍五折啥意思(酒店5折砍价)

  • 淘宝红包会自动退回吗(淘宝红包会自动到支付宝吗)

    淘宝红包会自动退回吗(淘宝红包会自动到支付宝吗)

  • 天眼查的数据哪来的(天眼查的数据哪里找)

    天眼查的数据哪来的(天眼查的数据哪里找)

  • iphone11分辨率(iphone11分辨率怎么调)

    iphone11分辨率(iphone11分辨率怎么调)

  • 64g手机2年后够用吗

    64g手机2年后够用吗

  • p30pro防水级别(p30pro防水性能怎么样)

    p30pro防水级别(p30pro防水性能怎么样)

  • airpods一代二代区分(airpods一代二代套子一样吗)

    airpods一代二代区分(airpods一代二代套子一样吗)

  • 大叶绣球花上的一对日本树蛙,日本滋贺 (© Mitsuhiko Imamori/Minden)(绣球花的叶子出现了斑点,这是怎么了?)

    大叶绣球花上的一对日本树蛙,日本滋贺 (© Mitsuhiko Imamori/Minden)(绣球花的叶子出现了斑点,这是怎么了?)

  • pwdhash命令  密码哈希生成器(passwd -s命令)

    pwdhash命令 密码哈希生成器(passwd -s命令)

  • Nginx环境搭建及前端部署教程(Windows版)(nginx怎么搭建)

    Nginx环境搭建及前端部署教程(Windows版)(nginx怎么搭建)

  • 产权转移书据印花税政策
  • 减免税款的会计分录在什么时候处理
  • 税务师考试考几门几年考完
  • 车辆购置税会计核算
  • 银行对账单放前面还是放后面
  • 交强险怎么报销流程
  • 工程施工科目核算内容
  • 红字发票需要认证吗之前的发票还有用吗
  • 物业公司收款一般多久
  • 水电费的进项税额能抵扣吗
  • 固定资产评估减值后如何入账
  • 企业国有资产无偿划转办法
  • 联营返点收入账务处理
  • 异地经营需要办什么税务手续?
  • 建筑企业如何申请高新技术企业
  • 增值税留抵税额抵减欠税
  • 可以给农村信用社的存折转账吗
  • 个体经营户如何开电子发票
  • 税控系统技术维护费抵扣如何填报
  • 离职补偿金个税计算器2022
  • 职工因公出差伙食补助标准
  • 合伙人退伙怎么处理
  • 其他应付款坏账怎么处理
  • 垫付的医药费怎么理赔
  • 公司的资本成本取决于投资人的必要报酬率
  • 制造业增值税加计抵减
  • 筹建期的餐饮费会计分录
  • iphone6按键功能介绍
  • 国家统计局一套表平台网址
  • 车辆维修费可以抵扣进项吗
  • 进项税额认证了也就是抵扣了吗?
  • 从劳务市场雇人受伤了怎么办?
  • 鸿蒙2.0 更新
  • 广告费和业务宣传费15%还是30%
  • 跨年销货退回账务处理
  • linux中常用的文件类型有哪些如何区分
  • zmweb.exe是什么进程
  • PHP:imagesetstyle()的用法_GD库图像处理函数
  • laravel框架关键技术解析
  • php内置数组
  • 实际退税能退多少
  • 业务招待费的账务处理金额
  • web网页设计期末作业猫眼电影首页
  • 现金日记账每月都做本年累计数吗?
  • 冲减预提成本分录
  • css 入门
  • 租金算营业成本还是管理费用
  • 企业的对公账户怎么办理
  • mysql表设计原则
  • 增值税退税流程怎么操作
  • 个体工商户能享受4050政策吗
  • 营业收入的构成分析包括
  • 总分类账户余额表怎么做账
  • 未分配利润转增股本要交税吗怎么交
  • 其他应付款怎么处理
  • 残保金需要计提吗怎样做分录
  • 新会计准则印花税需要计提吗
  • 环境检测费账务处理
  • 在建工程的
  • 网银费用及回单怎么查
  • 车辆违章有几种处理方法
  • xp系统好怪啊
  • centos.repo
  • centos怎么调出终端
  • 把winpe安装至系统盘
  • window10的cmd命令
  • centos发送http请求
  • Win RT 8.1 Update 3怎么提前更新安装使用?
  • win8系统怎么样
  • win8.1使用
  • javascript语言基础
  • Android自定义对话框
  • angular1
  • 批处理之家官网
  • 怎么检测python
  • javascript选项
  • js domcontentloaded
  • spring mvc jsp
  • 国家税务局涉税信息公开
  • 企业所得税报错了税款扣了可以改吗
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设