在强化学习中,Sarsa和Q-Learning很类似,本次内容将会基于之前所讲的Q-Learning的内容。
Sarsa决策部分和Q-Learning一模一样,都是采用Q表的方式进行决策,所以我们会在Q表中挑选values比较大的动作实施在环境中来换取奖赏。但是Sarsa的更新是不一样的
和上次一样用小学生写作业为例子,我们会经历写作业的状态s1
,然后再挑选一个带来最大潜在奖励的动作a2
,这样我们就到达了继续写作业的状态s2
,而在这一步没如果你用的是Q-Learning,你会观察一下在s2
上选取哪一个动作会带来最大的奖赏reward
来更新,但是在真正要做决定的时候却不一定会选取到那个带来最大reward
的动作,Q-Learning这一步只是估计了接下来的value。而Sarsa在s2这一步估计的动作就是他接下来要做的动作。所以Q(s1,a2)
现实的计算值我们也会改动,去掉了maxQ
,取而代之的是在S2上我们实实在在选取的a2的Q值。最后像Q-Learning一样,求出现实和估计的差距并更新Q表里的Q(s1,a2)
。
上图就是Sarsa更新的公式。我们可以看到和Q-Learning的不同之处:
state
中已经想好了state
对应的action
,而且想好了下一个state_
和下一个action_
(Q-learning还没有想好下一个action_
)Q(s,a)
的时候基于的是下一个Q(s_,a_)
(Q-learning基于的是maxQ(s_)
)这种不同之处使得Sarsa相对于Q-learning显得比较的”胆小“。原因在于
maxQ
最大化,因为这个maxQ
变得贪婪,不考虑其他非maxQ
的结果。我们可以理解成Q-learning是一种贪婪,大胆,勇敢的算法,对于错误,死亡并不在乎。而Sarsa是一种保守的算法,他在乎每一步的决策,对于错误和死亡比较敏感,这可以在可视化部分看出他们的不同。两种算法都有他们的好处,比如在实际中,如果你比较在乎机器的损害,那么用一种保守的算法,在训练中就可以有效地减少损坏的次数。maxQ
,而Sarsa却要看a_
的值,而a_
的值需要看greedy
的脸色,如果greedy=1
那么a_
就是maxQ
,与Q—Learning在greedy=1
无差别。greedy
值越小,Sarsa越不坚决(选择Q表中大的那个),而是会根据np.random.choice随机选择一个方向,同时也正是因Sarsa多了一项探索的概率,所以才是的Sarsa容易偏离终点,从视觉上看Sarsa有时显得很纠结。正因如此,Sarsa其实在某些程度上显得他很勇敢,因为Sarsa比Q-Learning更有探索精神,也正是这份精神使得Sarsa对终点的渴望不那么果决,饥渴成都要看greedy
的脸色,更具多面性。黄色是天堂(reward=1),黑色是地狱(reward=-1)。我们的目标就是让探险者经过自己的多次入“地狱”,最终学会入“天堂”
首先我们先import两个模块,maze_env
是我们游戏虚拟环境模块,是用python自带的GUI模块tkinter
来编写,具体细节不多赘述,完整代码会放在最后。RL_brain
这个模块是RL的大脑部分,稍后会提及。
1from maze_env import Maze
2from RL_brain import SarsaTable
下面就是我们的更新部分代码
1def update():
2 for episode in range(100):
3 # 初始化环境
4 observation = env.reset()
5
6 # Sarsa根据state观测选择行为
7 action = RL.choose_action(str(observation))
8
9 while True:
10 # 刷新环境
11 env.render()
12
13 # 在环境中采取行为,获得下一个state_(observation_),reward,和终止信号
14 observation_, reward, done = env.step(action)
15
16 # 根据下一个state(observation_)选取下一个action_
17 action_ = RL.choose_action(str(observation_))
18
19 #从(s, a, r, s, a)中学习,更新Q_table的参数
20 RL.learn(str(observation), action, reward, str(observation_), action_)
21
22 # 将下一个的observation_和action_当成对应下一步的参数
23 observation = observation_
24 action = action_
25
26 if done:
27 break
28
29 # end of game
30 print('game over')
31 env.destroy()
32
33if __name__ == "__main__":
34 #定义环境enc和RL方式
35 env = Maze()
36 RL = SarsaTable(actions=list(range(env.n_actions)))
37 env.after(100, update)
38 env.mainloop()
我们定义一个父类classRL
,然后SarsaTable
作为父类的衍生。
1import numpy as np
2import pandas as pd
3
4
5class RL:
6 #初始化参数
7 def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
8 self.actions = actions # 行为列表
9 self.lr = learning_rate #学习率
10 self.gamma = reward_decay #奖励衰减度
11 self.epsilon = e_greedy #贪婪度
12 self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64) #初始化q_table
13
14 #选择行为
15 def choose_action(self, observation):
16 self.check_state_exist(observation) #检验state是否在q_table中出现
17 # 贪婪模式
18 if np.random.uniform() < self.epsilon:
19 state_action = self.q_table.loc[observation, :]
20 # 同一个state,可能会有多个相同的Q action value,所以我们乱序一下
21 action = np.random.choice(state_action[state_action == np.max(state_action)].index)
22 else:
23 # 非贪婪模式随机选择action
24 action = np.random.choice(self.actions)
25 return action
26
27 #学习更新参数
28 def learn(self, s, a, r, s_):
29 self.check_state_exist(s_)#同样先检验一下q_table中是否存在S_
30 q_predict = self.q_table.loc[s, a]
31 if s_ != 'terminal':
32 #下个状态不是终止
33 q_target = r + self.gamma * self.q_table.loc[s_, :].max()
34 else:
35 q_target = r
36 #更新参数
37 self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
38
39 #检验state是否存在
40 def check_state_exist(self, state):
41 if state not in self.q_table.index:
42 # 如果不存在就插入一组全0数据,当做state的所有action的初始values
43 self.q_table = self.q_table.append(
44 pd.Series(
45 [0]*len(self.actions),
46 index=self.q_table.columns,
47 name=state,
48 )
49 )
然后我们编写SarsaTable
中learn
也就是更新功能就完成了。
1class SarsaTable(RL):
2
3 def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
4 super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
5
6 def learn(self, s, a, r, s_, a_):
7 self.check_state_exist(s_)
8 q_predict = self.q_table.loc[s, a]
9 if s_ != 'terminal':
10 q_target = r + self.gamma * self.q_table.loc[s_, a_] # q_target基于选好的a_而不是Q(s_)的最大值
11 else:
12 q_target = r
13 self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # 更新q_table
最后探险者就可以很轻松的上天堂了!