首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >tanh-squash 的 log_prob 未修正、α 自适应错写、没有双 Q 取最小的三连坑

tanh-squash 的 log_prob 未修正、α 自适应错写、没有双 Q 取最小的三连坑

原创
作者头像
九年义务漏网鲨鱼
发布2025-12-20 11:51:16
发布2025-12-20 11:51:16
2580
举报

SAC 学不稳?tanh-squash 的 log_prob 未修正、α 自适应错写、没有双 Q 取最小的三连坑

场景:在连续控制(Pendulum/HalfCheetah/Walker 等)上复现 SAC。训练能学,但回报抖动大、迟迟上不去,α(熵系数)时而爆、时而缩到几乎 0。复盘后常见三件事:

  1. 高斯策略经 tanh 压缩到动作空间后,log_prob 没做雅可比修正;
  2. α 的自适应目标写错或忘记对 log_prob 断梯度,导致 α 发散或塌陷;
  3. 价值分支没有双 Q 取最小(或目标没用 target critic),目标偏乐观,训练极不稳。

下面给出最小复现实验(CPU 可跑小段验证)与端到端修复模板。


Bug 现象

  • 平均回报长时间卡在一个低值(比如 Pendulum 接近 -600~-400),或早期好、后期崩。
  • 策略熵剧烈波动;α 从 0.001 一路暴涨到 10+,或直接衰到 1e-6。
  • 价值损失周期性高峰,Q 值量级逐步抬升;换更小 lr 也只能暂时缓解。

场景复现(tanh-squash log_prob 校验脚本)

保存为 sac_tanh_logprob_check.py,CPU 即可运行,观察两种写法的数值差异。

代码语言:python
复制
# sac_tanh_logprob_check.py
import torch, torch.nn.functional as F, math
torch.manual_seed(0)

def stable_tanh_logdet_jacobian(u):
    # log(1 - tanh(u)^2) 的稳定写法:2*(log(2) - u - softplus(-2u))
    return (2*(math.log(2) - u - F.softplus(-2*u))).sum(dim=-1)

def sample_squashed_gauss(mu, log_std):
    std = log_std.exp().clamp_min(1e-6)
    base = torch.distributions.Normal(mu, std)
    u = base.rsample()
    a = torch.tanh(u)
    # 正确:base.log_prob(u) - log|det d tanh(u)/du|
    logp = base.log_prob(u).sum(dim=-1) - stable_tanh_logdet_jacobian(u)
    return a, logp

def sample_wrong(mu, log_std):
    std = log_std.exp()
    base = torch.distributions.Normal(mu, std)
    u = base.rsample()
    a = torch.tanh(u)
    # 错误:漏掉雅可比修正
    logp = base.log_prob(u).sum(dim=-1)
    return a, logp

D, N = 3, 4096
mu = torch.zeros(N, D)
log_std = torch.zeros(N, D)

a1, lp1 = sample_squashed_gauss(mu, log_std)
a2, lp2 = sample_wrong(mu, log_std)

print("mean|a| correct/ wrong:", a1.abs().mean().item(), a2.abs().mean().item())
print("mean logp correct/ wrong:", lp1.mean().item(), lp2.mean().item())
print("diff logp (should not be ~0):", (lp1 - lp2).abs().mean().item())

你会看到“错误版”与“正确版”的 log_prob 均值差距显著。错写 log_prob 会直接影响 actor loss 中的 α·logπ(a|s),使熵正则的力度完全失真,从而拖垮训练。


端到端修复模板(最小 SAC 更新步)

下面是可直接嵌入项目的关键片段,涵盖三处修复:tanh-log_prob、α 自适应、双 Q 与 target。

代码语言:python
复制
import torch, torch.nn as nn, torch.nn.functional as F, math

LOG_STD_MIN, LOG_STD_MAX = -5.0, 2.0
GAMMA, TAU = 0.99, 0.005

def stable_tanh_logdet_jacobian(u):
    return (2*(math.log(2) - u - F.softplus(-2*u))).sum(dim=-1)

class Actor(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.mu = nn.Linear(hidden, act_dim)
        self.log_std = nn.Linear(hidden, act_dim)

    def forward(self, obs, deterministic=False, with_logprob=True):
        h = self.net(obs)
        mu = self.mu(h)
        log_std = self.log_std(h).clamp(LOG_STD_MIN, LOG_STD_MAX)
        std = log_std.exp()
        dist = torch.distributions.Normal(mu, std)

        if deterministic:
            u = mu
        else:
            u = dist.rsample()                      # 重参数化采样
        a = torch.tanh(u)                           # 压到 (-1,1)

        if with_logprob:
            logp = dist.log_prob(u).sum(-1) - stable_tanh_logdet_jacobian(u)
        else:
            logp = None
        return a, logp, mu, log_std

class Critic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden=256):
        super().__init__()
        def qnet():
            return nn.Sequential(
                nn.Linear(obs_dim + act_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, 1)
            )
        self.q1 = qnet()
        self.q2 = qnet()
    def forward(self, obs, act):
        x = torch.cat([obs, act], dim=-1)
        return self.q1(x), self.q2(x)

@torch.no_grad()
def soft_update(target, online, tau=TAU):
    for tp, p in zip(target.parameters(), online.parameters()):
        tp.data.mul_(1-tau).add_(tau*p.data)

def sac_update(batch, actor, critic, target_critic,
               opt_actor, opt_critic, log_alpha, opt_alpha,
               target_entropy, alpha_clip=(1e-6, 1e+2)):
    obs, act, rew, next_obs, done = batch  # 形状:[B, D]
    # 1) critic 更新:双 Q + target,bootstrap 用 (1 - done)
    with torch.no_grad():
        next_a, next_logp, _, _ = actor(next_obs, deterministic=False, with_logprob=True)
        q1_t, q2_t = target_critic(next_obs, next_a)
        q_target = torch.min(q1_t, q2_t) - log_alpha.exp() * next_logp
        y = rew + GAMMA * (1.0 - done) * q_target

    q1, q2 = critic(obs, act)
    critic_loss = F.mse_loss(q1, y) + F.mse_loss(q2, y)

    opt_critic.zero_grad(set_to_none=True)
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(critic.parameters(), 5.0)
    opt_critic.step()

    # 2) actor 更新:最小 Q 上的策略梯度 + 熵项
    #   为节省显存/防泄漏,可在此处临时冻结 critic 参数
    for p in critic.parameters(): p.requires_grad_(False)
    a_pi, logp_pi, _, _ = actor(obs, deterministic=False, with_logprob=True)
    q1_pi, q2_pi = critic(obs, a_pi)
    q_pi = torch.min(q1_pi, q2_pi)
    actor_loss = (log_alpha.exp() * logp_pi - q_pi).mean()

    opt_actor.zero_grad(set_to_none=True)
    actor_loss.backward()
    opt_actor.step()
    for p in critic.parameters(): p.requires_grad_(True)

    # 3) α 自适应(断梯度):目标熵一般取 -act_dim
    #   注意对 logp 断梯度;否则 α 与策略互相拉扯,数值极不稳
    alpha_loss = -(log_alpha * (logp_pi.detach() + target_entropy)).mean()
    opt_alpha.zero_grad(set_to_none=True)
    alpha_loss.backward()
    opt_alpha.step()
    with torch.no_grad():
        log_alpha.data.clamp_(math.log(alpha_clip[0]), math.log(alpha_clip[1]))

    # 4) 软更新 target critic
    soft_update(target_critic, critic)

    return {
        "critic_loss": float(critic_loss.item()),
        "actor_loss": float(actor_loss.item()),
        "alpha": float(log_alpha.exp().item()),
        "entropy": float((-logp_pi).mean().item()),
    }

要点回顾

  • tanh-squash 后的 log_prob 必须减去雅可比:log |det ∂tanh/∂u|
  • α 自适应:L(α) = E[-α (logπ(a|s) + H_target)],其中 logπ 需断梯度;
  • 价值分支:双 Q 取最小,目标用 target_critic,且含 - α·logπ(s',a') 的熵校正。

Debug 过程

  1. 打印关键统计 每 N 步记录:
  • α、策略熵、log_std 的均值与范围(建议把 log_std clamp 到 -5, 2);
  • Q 值绝对值均值;当其持续上扬,优先排查“无双 Q”“无 target”“done 掩码错写”。
  1. 校验 log_prob 是否修正 对同一批前向,把“未修正”和“已修正”的 log_prob 打印出来;若两者均值/方差接近,说明你可能错误地在动作空间而非 pre-tanh 空间上求了 base.log_prob。
  2. α 梯度路径 临时把 α 的学习率调大一些,观察 α 单独更新的稳定性;若 α 与 actor 同步抖得很厉害,通常是忘了对 log_prob 断梯度。
  3. target 更新与 done 掩码 断点检查 target 参数是否真的在变化(soft-update 生效),以及 bootstrap 系数是否为 (1 - terminated) 而非同时把 time-limit 截断也当终止。

监控与护栏

代码语言:python
复制
def assert_logprob_sane(logp):
    v = float(logp.abs().mean())
    assert math.isfinite(v) and v < 100, f"log_prob 异常:{v}"

def guard_log_std(log_std):
    if (log_std < LOG_STD_MIN - 1e-3).any() or (log_std > LOG_STD_MAX + 1e-3).any():
        print("[warn] log_std 超界,考虑 clamp 或正则")

def check_alpha_update(log_alpha, logp_pi, target_entropy):
    g = torch.autograd.grad((-log_alpha * (logp_pi + target_entropy)).mean(),
                            [log_alpha], retain_graph=True, allow_unused=True)[0]
    assert g is not None, "α 无梯度,请检查计算图与优化器绑定"

常见问答

  • 目标熵取多少合适 一般取 -act_dim。某些任务更大或更小会更好,但先从该默认值出发再做网格搜索。
  • actor 更新时要不要对 critic 断梯度 梯度不应更新 critic 参数;最省心的方法是在计算 actorloss 前把 critic 参数 `requires_grad(False)`,更新后再恢复。
  • 需要 value 网络吗 SAC v2 移除显式 V 网络,使用双 Q 的最小值替代;大多数实现采用 v2。
  • 为什么我的动作边界不是 (-1,1) 环境动作空间若不是该范围,需线性映射 a_env = low + (a_tanh+1)/2 * (high-low),同时在环境交互处做映射即可;策略与损失仍建议在 tanh 空间处理。

结语

SAC 的“隐形三连坑”——tanh-squash log_prob 未修正、α 自适应断梯度缺失、没有双 Q 最小与 target bootstrapping——足以让曲线长期低迷或随机游走。把上面的校验脚本跑一遍,再把修复模板固化到代码里,并持续监控 α/熵/Q 值与 log_std 的健康区间,这类顽固不稳问题基本可以一次性清零

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • SAC 学不稳?tanh-squash 的 log_prob 未修正、α 自适应错写、没有双 Q 取最小的三连坑
    • Bug 现象
    • 场景复现(tanh-squash log_prob 校验脚本)
    • 端到端修复模板(最小 SAC 更新步)
    • Debug 过程
    • 监控与护栏
    • 常见问答
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档