📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏: 【强化学习】- 【单智能体强化学习】(3)---《基础在线算法:Sarsa算法》
Sarsa算法是一种强化学习(Reinforcement Learning, RL)的经典算法,属于时序差分(Temporal Difference, TD)方法。它是一种基于策略的学习算法,用于解决马尔可夫决策过程(Markov Decision Process, MDP)中的问题。
简单来说,Sarsa的目标是通过不断地交互,学习如何从当前状态选择最优动作,从而获得最大的累积奖励。
Sarsa的核心是估计状态-动作值函数(Q函数),然后根据这个函数选择动作。该值函数
表示在状态
下采取动作
所能获得的期望回报。
Sarsa算法的名字来源于它的更新过程涉及的五元组:
,
,
,
,
Sarsa使用以下公式来更新
值:
:当前状态
:当前动作
:当前奖励
:下一状态
:下一动作
:学习率,控制更新的步幅
:折扣因子,衡量未来奖励的重要性
值为任意值(通常为0)。
、折扣因子
。
根据策略(如
-贪婪策略)选择动作
。
,观察到奖励
和下一个状态
。
中,根据策略选择下一动作
。
:
。
值的更新,逐渐改善选择动作的策略。
特点 | Sarsa | Q-Learning |
---|---|---|
策略类型 | 基于当前策略(on-policy) | 基于最优策略(off-policy) |
更新公式中的动作 | 使用实际选择的动作 | 使用最优动作 |
行为特点 | 更安全、探索性强 | 更快逼近最优,但可能冒险 |
直观理解:
关于on-policy和off-policy的区别,下面这篇文章进行了较为详细的描述: 【MARL】深入理解多智能体近端策略优化(MAPPO)算法与调参
项目代码我已经放入GitCode里面,可以通过下面链接跳转:🔥 【强化学习】---Sarsa算法 后续相关单智能体强化学习算法也会不断在【强化学习】项目里更新,如果该项目对你有所帮助,请帮我点一个星星✨✨✨✨✨,鼓励分享,十分感谢!!! 若是下面代码复现困难或者有问题,也欢迎评论区留言。
"""《Sarsa算法实现》
时间:2024.12
作者:不去幼儿园
"""
import numpy as np # 导入numpy库,用于数组和矩阵运算
import pandas as pd # 导入pandas库,用于数据处理和创建数据表
import matplotlib.pyplot as plt # 导入matplotlib库,用于绘图
import time # 导入time库,用于控制程序暂停时间
# 定义强化学习的一些超参数
ALPHA = 0.1 # 学习率,控制更新Q值的幅度
GAMMA = 0.95 # 折扣因子,控制未来奖励的重要性
EPSILION = 0.9 # epsilon值,用于ε-贪婪策略,控制探索与利用的权衡
N_STATE = 6 # 状态的数量,表示状态空间的大小
ACTIONS = ['left', 'right'] # 可能的动作列表,表示智能体的可选行为
MAX_EPISODES = 200 # 最大的训练轮次,表示最大实验次数
FRESH_TIME = 0.1 # 控制环境更新的时间间隔,用于显示训练过程
# 定义Q表的构建函数
def build_q_table(n_state, actions):
# 创建一个Q表,行代表状态,列代表动作,初始时所有Q值为0
q_table = pd.DataFrame(
np.zeros((n_state, len(actions))), # 初始化一个全0的表格,大小为状态数x动作数
np.arange(n_state), # 状态的索引
actions # 动作的名称
)
return q_table # 返回初始化后的Q表
# 定义选择动作的函数
def choose_action(state, q_table):
# epslion - greedy策略
state_action = q_table.loc[state, :] # 获取当前状态下所有可能动作的Q值
if np.random.uniform() > EPSILION or (state_action == 0).all(): # 如果随机数大于epsilon或所有动作Q值为0
action_name = np.random.choice(ACTIONS) # 选择一个随机动作(探索)
else:
action_name = state_action.idxmax() # 否则选择Q值最大的动作(利用)
return action_name # 返回选择的动作
# 定义环境反馈的函数
def get_env_feedback(state, action):
# 根据当前状态和动作来返回下一个状态和奖励
if action == 'right': # 如果选择向右移动
if state == N_STATE - 2: # 如果当前状态是倒数第二个状态
next_state = 'terminal' # 到达终止状态
reward = 1 # 奖励为1
else:
next_state = state + 1 # 否则状态加1
reward = -0.5 # 奖励为-0.5
else: # 如果选择向左移动
if state == 0: # 如果当前状态是最左边的状态
next_state = 0 # 保持在原地
else:
next_state = state - 1 # 否则状态减1
reward = -0.5 # 奖励为-0.5
return next_state, reward # 返回下一个状态和奖励
# 定义环境更新的函数
def update_env(state, episode, step_counter):
# 生成一个表示环境的字符串,'-'表示空地,'T'表示终止状态
env = ['-'] * (N_STATE - 1) + ['T']
if state == 'terminal': # 如果到达终止状态
print("Episode {}, the total step is {}".format(episode + 1, step_counter)) # 打印当前回合和步骤
final_env = ['-'] * (N_STATE - 1) + ['T'] # 环境没有变化
return True, step_counter # 终止回合,返回True
else:
env[state] = '*' # 将当前状态位置标记为'*'
env = ''.join(env) # 将环境列表转化为字符串
print(env) # 打印当前环境的状态
time.sleep(FRESH_TIME) # 暂停程序FRESH_TIME秒,模拟环境变化的延迟
return False, step_counter # 没有到达终止状态,返回False
# 定义SARSA学习算法的函数
def sarsa_learning():
q_table = build_q_table(N_STATE, ACTIONS) # 创建一个Q表
step_counter_times = [] # 用于记录每个回合的步骤数
for episode in range(MAX_EPISODES): # 进行最大回合数的学习
state = 0 # 初始状态设为0
is_terminal = False # 初始状态不是终止状态
step_counter = 0 # 初始步骤计数为0
update_env(state, episode, step_counter) # 更新环境并显示
while not is_terminal: # 当未到达终止状态时继续学习
action = choose_action(state, q_table) # 根据当前状态选择动作
next_state, reward = get_env_feedback(state, action) # 获取环境反馈(下一个状态和奖励)
if next_state != 'terminal': # 如果不是终止状态
next_action = choose_action(next_state, q_table) # 选择下一个状态的动作(SARSA更新方法)
else:
next_action = action # 如果是终止状态,动作不再改变
next_q = q_table.loc[state, action] # 获取当前Q值
if next_state == 'terminal': # 如果到达终止状态
is_terminal = True # 设置为终止状态
q_target = reward # 目标Q值为奖励
else:
delta = reward + GAMMA * q_table.loc[next_state, next_action] - q_table.loc[state, action] # SARSA更新公式
q_table.loc[state, action] += ALPHA * delta # 更新Q表中的值
state = next_state # 更新当前状态为下一个状态
is_terminal, steps = update_env(state, episode, step_counter + 1) # 更新环境并检查是否终止
step_counter += 1 # 增加步骤计数
if is_terminal: # 如果到达终止状态,记录步骤数
step_counter_times.append(steps)
# 主函数入口
if __name__ == '__main__':
q_table, step_counter_times = sarsa_learning() # 执行SARSA学习
print("Q table\n{}\n".format(q_table)) # 打印最终的Q表
print('end') # 输出训练结束
plt.plot(step_counter_times, 'g-') # 绘制每回合的步骤数变化曲线
plt.ylabel("steps") # 设置y轴标签
plt.title("Sarsa Algorithm") # 设置图标题
plt.show() # 显示图形
print("The step_counter_times is {}".format(step_counter_times)) # 打印每回合的步骤数
# 环境配置
Python 3.11.5
torch 2.1.0
torchvision 0.16.0
gym 0.26.2
Sarsa算法是强化学习领域的基石之一,其优点在于:
但在实际应用中,Sarsa的收敛速度较慢,需要良好的超参数调整。