
Scenario: reproducing SAC on continuous control tasks (Pendulum, HalfCheetah, Walker, etc.). Training does learn something, but returns are noisy and slow to rise, and α (the entropy coefficient) alternately blows up and shrinks to nearly zero. Post-mortems usually turn up the same three culprits: an uncorrected tanh-squash log_prob, a missing gradient detach in the adaptive-α update, and no twin-Q minimum with target bootstrapping.
Below are a minimal reproduction experiment (a short CPU run is enough to verify) and an end-to-end fix template.
Save the following as sac_tanh_logprob_check.py; it runs on CPU. Watch the numerical difference between the two implementations.
# sac_tanh_logprob_check.py
import torch, torch.nn.functional as F, math

torch.manual_seed(0)

def stable_tanh_logdet_jacobian(u):
    # Numerically stable form of log(1 - tanh(u)^2): 2*(log(2) - u - softplus(-2u))
    return (2*(math.log(2) - u - F.softplus(-2*u))).sum(dim=-1)

def sample_squashed_gauss(mu, log_std):
    std = log_std.exp().clamp_min(1e-6)
    base = torch.distributions.Normal(mu, std)
    u = base.rsample()
    a = torch.tanh(u)
    # Correct: base.log_prob(u) - log|det d tanh(u)/du|
    logp = base.log_prob(u).sum(dim=-1) - stable_tanh_logdet_jacobian(u)
    return a, logp

def sample_wrong(mu, log_std):
    std = log_std.exp()
    base = torch.distributions.Normal(mu, std)
    u = base.rsample()
    a = torch.tanh(u)
    # Wrong: the Jacobian correction is missing
    logp = base.log_prob(u).sum(dim=-1)
    return a, logp

D, N = 3, 4096
mu = torch.zeros(N, D)
log_std = torch.zeros(N, D)
a1, lp1 = sample_squashed_gauss(mu, log_std)
a2, lp2 = sample_wrong(mu, log_std)
print("mean|a| correct/wrong:", a1.abs().mean().item(), a2.abs().mean().item())
print("mean logp correct/wrong:", lp1.mean().item(), lp2.mean().item())
print("diff logp (should not be ~0):", (lp1 - lp2).abs().mean().item())

You will see a significant gap between the mean log_prob of the "correct" and "wrong" versions. A mis-written log_prob feeds directly into the α·logπ(a|s) term of the actor loss, distorting the strength of the entropy regularizer and dragging training down.
Below are key snippets that can be dropped directly into a project, covering the three fixes: the tanh log_prob, adaptive α, and twin Q with a target network.
import torch, torch.nn as nn, torch.nn.functional as F, math

LOG_STD_MIN, LOG_STD_MAX = -5.0, 2.0
GAMMA, TAU = 0.99, 0.005

def stable_tanh_logdet_jacobian(u):
    # Stable log(1 - tanh(u)^2), summed over the action dimension
    return (2*(math.log(2) - u - F.softplus(-2*u))).sum(dim=-1)

class Actor(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.mu = nn.Linear(hidden, act_dim)
        self.log_std = nn.Linear(hidden, act_dim)

    def forward(self, obs, deterministic=False, with_logprob=True):
        h = self.net(obs)
        mu = self.mu(h)
        log_std = self.log_std(h).clamp(LOG_STD_MIN, LOG_STD_MAX)  # keep std in a sane range
        std = log_std.exp()
        dist = torch.distributions.Normal(mu, std)
        if deterministic:
            u = mu                 # evaluation: take the mode
        else:
            u = dist.rsample()     # reparameterized sampling (keeps the pathwise gradient)
        a = torch.tanh(u)          # squash into (-1, 1)
        if with_logprob:
            logp = dist.log_prob(u).sum(-1) - stable_tanh_logdet_jacobian(u)
        else:
            logp = None
        return a, logp, mu, log_std

class Critic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden=256):
        super().__init__()
        def qnet():
            return nn.Sequential(
                nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, 1)
            )
        # Twin Q networks: taking their min curbs overestimation bias
        self.q1 = qnet()
        self.q2 = qnet()

    def forward(self, obs, act):
        x = torch.cat([obs, act], dim=-1)
        return self.q1(x), self.q2(x)

@torch.no_grad()
def soft_update(target, online, tau=TAU):
    # Polyak averaging: target <- (1 - tau) * target + tau * online
    for tp, p in zip(target.parameters(), online.parameters()):
        tp.data.mul_(1 - tau).add_(tau * p.data)
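# The update function below performs one full SAC step on a replay batch:
#   (1) critic regression to the entropy-regularized TD target y,
#   (2) actor update against min(Q1, Q2) plus the α-weighted entropy term,
#   (3) dual update of α toward the target entropy (log_prob detached),
#   (4) Polyak soft update of the target critic.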
def sac_update(batch, actor, critic, target_critic,
               opt_actor, opt_critic, log_alpha, opt_alpha,
               target_entropy, alpha_clip=(1e-6, 1e+2)):
    obs, act, rew, next_obs, done = batch  # shapes: obs/act [B, D]; rew/done [B, 1]

    # 1) Critic update: twin Q + target network; bootstrap with (1 - done)
    with torch.no_grad():
        next_a, next_logp, _, _ = actor(next_obs, deterministic=False, with_logprob=True)
        q1_t, q2_t = target_critic(next_obs, next_a)
        # next_logp is [B]; unsqueeze so it broadcasts against the [B, 1] Q-values
        q_target = torch.min(q1_t, q2_t) - log_alpha.exp() * next_logp.unsqueeze(-1)
        y = rew + GAMMA * (1.0 - done) * q_target
    q1, q2 = critic(obs, act)
    critic_loss = F.mse_loss(q1, y) + F.mse_loss(q2, y)
    opt_critic.zero_grad(set_to_none=True)
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(critic.parameters(), 5.0)
    opt_critic.step()

    # 2) Actor update: policy gradient on min(Q1, Q2) plus the entropy term
    # Temporarily freeze critic params here to save memory and avoid gradient leakage
    for p in critic.parameters():
        p.requires_grad_(False)
    a_pi, logp_pi, _, _ = actor(obs, deterministic=False, with_logprob=True)
    q1_pi, q2_pi = critic(obs, a_pi)
    q_pi = torch.min(q1_pi, q2_pi).squeeze(-1)  # [B], matching logp_pi
    alpha = log_alpha.exp().detach()            # treat α as a constant in the actor loss
    actor_loss = (alpha * logp_pi - q_pi).mean()
    opt_actor.zero_grad(set_to_none=True)
    actor_loss.backward()
    opt_actor.step()
    for p in critic.parameters():
        p.requires_grad_(True)

    # 3) Adaptive α (with the gradient cut): the target entropy is usually -act_dim
    # Detach logp here; otherwise α and the policy tug on each other and the update
    # becomes numerically unstable
    alpha_loss = -(log_alpha * (logp_pi.detach() + target_entropy)).mean()
    opt_alpha.zero_grad(set_to_none=True)
    alpha_loss.backward()
    opt_alpha.step()
    with torch.no_grad():
        log_alpha.data.clamp_(math.log(alpha_clip[0]), math.log(alpha_clip[1]))

    # 4) Soft update of the target critic
    soft_update(target_critic, critic)
    return {
        "critic_loss": float(critic_loss.item()),
        "actor_loss": float(actor_loss.item()),
        "alpha": float(log_alpha.exp().item()),
        "entropy": float((-logp_pi).mean().item()),
    }
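For context, here is a minimal wiring sketch showing how these pieces are typically bound together. The dimensions, learning rates, and the random smoke-test batch are placeholders, not values from the original setup:

import copy

obs_dim, act_dim = 17, 6                      # hypothetical dims
actor = Actor(obs_dim, act_dim)
critic = Critic(obs_dim, act_dim)
target_critic = copy.deepcopy(critic)         # target starts as a copy of the online critic
for p in target_critic.parameters():
    p.requires_grad_(False)                   # the target is only touched by soft_update

log_alpha = torch.zeros(1, requires_grad=True)   # α = exp(log_alpha) starts at 1.0
opt_actor = torch.optim.Adam(actor.parameters(), lr=3e-4)
opt_critic = torch.optim.Adam(critic.parameters(), lr=3e-4)
opt_alpha = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -float(act_dim)              # the usual default

# smoke test with a fake replay batch; note rew/done are kept as [B, 1]
B = 32
batch = (torch.randn(B, obs_dim), torch.rand(B, act_dim)*2 - 1,
         torch.randn(B, 1), torch.randn(B, obs_dim), torch.zeros(B, 1))
stats = sac_update(batch, actor, critic, target_critic,
                   opt_actor, opt_critic, log_alpha, opt_alpha, target_entropy)
print(stats)

In a real loop you would interleave environment steps with replay sampling; the smoke test only confirms that shapes line up and gradients flow.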
Key takeaways
- tanh-squashed Gaussian: subtract log |det ∂tanh/∂u| from the base log_prob;
- adaptive α: L(α) = E[-α (logπ(a|s) + H_target)], where logπ must be gradient-detached;
- TD target: the min of the twin target Qs with the -α·logπ(a'|s') entropy correction;
- bootstrap with (1 - terminated), rather than also treating time-limit truncation as termination.

A few cheap runtime guards catch these regressions early:

def assert_logprob_sane(logp):
    v = float(logp.abs().mean())
    assert math.isfinite(v) and v < 100, f"log_prob abnormal: {v}"
def guard_log_std(log_std):
    if (log_std < LOG_STD_MIN - 1e-3).any() or (log_std > LOG_STD_MAX + 1e-3).any():
        print("[warn] log_std out of bounds; consider clamping or regularizing")
def check_alpha_update(log_alpha, logp_pi, target_entropy):
    g = torch.autograd.grad((-log_alpha * (logp_pi + target_entropy)).mean(),
                            [log_alpha], retain_graph=True, allow_unused=True)[0]
    assert g is not None, "α receives no gradient; check the computation graph and optimizer binding"

On target entropy: the conventional default is -act_dim. Some tasks do better with a larger or smaller value, but start from this default and grid-search from there.

On action ranges: if the environment's actions are not in (-1, 1), map a_env = low + (a_tanh + 1)/2 * (high - low) at the environment-interaction boundary (a minimal sketch is included at the end of this post); the policy and the losses are still best handled in tanh space.

SAC's three hidden pitfalls (uncorrected tanh-squash log_prob, a missing gradient detach in adaptive α, and no twin-Q minimum with target bootstrapping) are enough to keep the learning curve flat or random-walking for a long time. Run the check script above once, bake the fix template into your code, and keep monitoring α, entropy, Q values, and log_std against their healthy ranges; this class of stubborn instability can essentially be eliminated for good.
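As referenced above, a minimal sketch of the tanh-to-environment action mapping. The low/high bounds are hypothetical placeholders; in practice read them from env.action_space:

import numpy as np

low = np.array([-2.0, -1.0, 0.0], dtype=np.float32)   # hypothetical Box bounds
high = np.array([2.0, 1.0, 1.0], dtype=np.float32)

def to_env_action(a_tanh: np.ndarray) -> np.ndarray:
    # affine map from (-1, 1) to [low, high]; used only at the env boundary
    return low + (a_tanh + 1.0) / 2.0 * (high - low)

def to_tanh_action(a_env: np.ndarray) -> np.ndarray:
    # inverse map, e.g., when storing demonstrations into a tanh-space buffer
    return 2.0 * (a_env - low) / (high - low) - 1.0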
Originality statement: this article is published on the Tencent Cloud Developer Community with the author's authorization; reproduction without permission is prohibited.
For infringement concerns, contact cloudcommunity@tencent.com for removal.