前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >强化学习之Sarsa

强化学习之Sarsa

作者头像
CristianoC
发布2020-05-31 15:20:03
6760
发布2020-05-31 15:20:03
举报
文章被收录于专栏:计算机视觉漫谈

在强化学习中,Sarsa和Q-Learning很类似,本次内容将会基于之前所讲的Q-Learning的内容。

目录

  • 算法简介
  • 更新准则
  • 探险者上天堂实战

算法简介

Sarsa决策部分和Q-Learning一模一样,都是采用Q表的方式进行决策,所以我们会在Q表中挑选values比较大的动作实施在环境中来换取奖赏。但是Sarsa的更新是不一样的

更新准则

和上次一样用小学生写作业为例子,我们会经历写作业的状态s1,然后再挑选一个带来最大潜在奖励的动作a2,这样我们就到达了继续写作业的状态s2,而在这一步没如果你用的是Q-Learning,你会观察一下在s2上选取哪一个动作会带来最大的奖赏reward来更新,但是在真正要做决定的时候却不一定会选取到那个带来最大reward的动作,Q-Learning这一步只是估计了接下来的value。而Sarsa在s2这一步估计的动作就是他接下来要做的动作。所以Q(s1,a2)现实的计算值我们也会改动,去掉了maxQ,取而代之的是在S2上我们实实在在选取的a2的Q值。最后像Q-Learning一样,求出现实和估计的差距并更新Q表里的Q(s1,a2)

上图就是Sarsa更新的公式。我们可以看到和Q-Learning的不同之处:

  • 他在当前的state中已经想好了state对应的action,而且想好了下一个state_和下一个action_(Q-learning还没有想好下一个action_
  • 更新Q(s,a)的时候基于的是下一个Q(s_,a_)(Q-learning基于的是maxQ(s_)

这种不同之处使得Sarsa相对于Q-learning显得比较的”胆小“。原因在于

  • Q-learning在更新的时候始终都是选择maxQ最大化,因为这个maxQ变得贪婪,不考虑其他非maxQ的结果。我们可以理解成Q-learning是一种贪婪,大胆,勇敢的算法,对于错误,死亡并不在乎。而Sarsa是一种保守的算法,他在乎每一步的决策,对于错误和死亡比较敏感,这可以在可视化部分看出他们的不同。两种算法都有他们的好处,比如在实际中,如果你比较在乎机器的损害,那么用一种保守的算法,在训练中就可以有效地减少损坏的次数。
  • 从另一个角度想,Q-learning更新使用maxQ,而Sarsa却要看a_的值,而a_的值需要看greedy的脸色,如果greedy=1那么a_就是maxQ,与Q—Learning在greedy=1无差别。greedy值越小,Sarsa越不坚决(选择Q表中大的那个),而是会根据np.random.choice随机选择一个方向,同时也正是因Sarsa多了一项探索的概率,所以才是的Sarsa容易偏离终点,从视觉上看Sarsa有时显得很纠结。正因如此,Sarsa其实在某些程度上显得他很勇敢,因为Sarsa比Q-Learning更有探索精神,也正是这份精神使得Sarsa对终点的渴望不那么果决,饥渴成都要看greedy的脸色,更具多面性。

探险者上天堂实战

背景

黄色是天堂(reward=1),黑色是地狱(reward=-1)。我们的目标就是让探险者经过自己的多次入“地狱”,最终学会入“天堂”

主模块

首先我们先import两个模块,maze_env是我们游戏虚拟环境模块,是用python自带的GUI模块tkinter来编写,具体细节不多赘述,完整代码会放在最后。RL_brain这个模块是RL的大脑部分,稍后会提及。

代码语言:javascript
复制
1from maze_env import Maze
2from RL_brain import SarsaTable

下面就是我们的更新部分代码

代码语言:javascript
复制
 1def update():
 2    for episode in range(100):
 3        # 初始化环境
 4        observation = env.reset()
 5
 6        # Sarsa根据state观测选择行为
 7        action = RL.choose_action(str(observation))
 8
 9        while True:
10            # 刷新环境
11            env.render()
12
13            # 在环境中采取行为,获得下一个state_(observation_),reward,和终止信号
14            observation_, reward, done = env.step(action)
15
16            # 根据下一个state(observation_)选取下一个action_
17            action_ = RL.choose_action(str(observation_))
18
19            #从(s, a, r, s, a)中学习,更新Q_table的参数
20            RL.learn(str(observation), action, reward, str(observation_), action_)
21
22            # 将下一个的observation_和action_当成对应下一步的参数
23            observation = observation_
24            action = action_
25
26            if done:
27                break
28
29    # end of game
30    print('game over')
31    env.destroy()
32
33if __name__ == "__main__":
34    #定义环境enc和RL方式
35    env = Maze()
36    RL = SarsaTable(actions=list(range(env.n_actions)))
37    env.after(100, update)
38    env.mainloop()

RL_brain模块

我们定义一个父类classRL,然后SarsaTable作为父类的衍生。

代码语言:javascript
复制
 1import numpy as np
 2import pandas as pd
 3
 4
 5class RL:
 6    #初始化参数
 7    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
 8        self.actions = actions  # 行为列表
 9        self.lr = learning_rate #学习率
10        self.gamma = reward_decay  #奖励衰减度
11        self.epsilon = e_greedy #贪婪度
12        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64) #初始化q_table
13
14    #选择行为
15    def choose_action(self, observation):
16        self.check_state_exist(observation) #检验state是否在q_table中出现
17        # 贪婪模式
18        if np.random.uniform() < self.epsilon:
19            state_action = self.q_table.loc[observation, :]
20            # 同一个state,可能会有多个相同的Q action value,所以我们乱序一下
21            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
22        else:
23            # 非贪婪模式随机选择action
24            action = np.random.choice(self.actions)
25        return action
26
27    #学习更新参数
28    def learn(self, s, a, r, s_):
29        self.check_state_exist(s_)#同样先检验一下q_table中是否存在S_
30        q_predict = self.q_table.loc[s, a]
31        if s_ != 'terminal':
32            #下个状态不是终止
33            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
34        else:
35            q_target = r
36        #更新参数
37        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
38
39    #检验state是否存在
40    def check_state_exist(self, state):
41        if state not in self.q_table.index:
42            # 如果不存在就插入一组全0数据,当做state的所有action的初始values
43            self.q_table = self.q_table.append(
44                pd.Series(
45                    [0]*len(self.actions),
46                    index=self.q_table.columns,
47                    name=state,
48                )
49            )

然后我们编写SarsaTablelearn也就是更新功能就完成了。

代码语言:javascript
复制
 1class SarsaTable(RL):
 2
 3    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
 4        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
 5
 6    def learn(self, s, a, r, s_, a_):
 7        self.check_state_exist(s_)
 8        q_predict = self.q_table.loc[s, a]
 9        if s_ != 'terminal':
10            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # q_target基于选好的a_而不是Q(s_)的最大值
11        else:
12            q_target = r
13        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新q_table

最后探险者就可以很轻松的上天堂了!

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-07-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 计算机视觉漫谈 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 目录
  • 算法简介
  • 更新准则
  • 探险者上天堂实战
  • 背景
  • 主模块
  • RL_brain模块
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档