PPO 近端策略优化算法

PPO(Proximal Policy Optimization,近端策略优化)是OpenAI于2017年提出的一种策略梯度强化学习算法,以其高效性、稳定性和易实现性成为强化学习领域的主流算法,也是ChatGPT等大模型RLHF训练的核心技术。

核心原理

问题背景

传统策略梯度方法存在两大痛点:

问题描述
更新步长敏感步长过大易导致策略崩溃,步长过小则收敛缓慢
样本利用率低需大量环境交互数据,单次更新后数据即失效

TRPO(Trust Region Policy Optimization)通过约束策略更新幅度解决这些问题,但实现复杂、计算成本高。PPO通过简化约束机制,在保持稳定性的同时大幅降低实现复杂度。

PPO的解决方案

技术手段作用
Clipped Surrogate Objective限制策略更新幅度,确保新策略与旧策略差异可控
重要性采样复用旧策略采集的数据,提升样本效率
自适应KL惩罚替代TRPO的复杂约束优化,降低计算成本

数学推导

策略梯度基础

策略梯度目标函数:

其中 为优势函数,衡量动作的相对价值。

PPO目标函数

PPO引入重要性采样比,构建clipped目标函数:

关键符号说明:

符号含义
概率比率:$\frac{\pi_\theta(a_t
优势函数估计
裁剪参数(通常为0.2)
限制在 区间

Clip机制的核心作用:

  • 限制 区间
  • 取最小值确保优化方向保守,避免过度偏离旧策略
  • 当优势为正时,鼓励该动作但不过度增加概率
  • 当优势为负时,抑制该动作但不过度降低概率

优势估计(GAE)

广义优势估计(Generalized Advantage Estimation):

其中:

  • :折扣因子(通常为0.99)
  • :GAE参数(通常为0.95),控制偏差-方差权衡

PyTorch代码实现

核心算法实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
 
class PPO:
    def __init__(self, actor_critic, clip_param=0.2, lr=3e-4, 
                 ent_coef=0.01, gamma=0.99, gae_lambda=0.95):
        self.actor_critic = actor_critic
        self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr)
        self.clip_param = clip_param
        self.ent_coef = ent_coef  # 熵正则化系数
        self.gamma = gamma
        self.gae_lambda = gae_lambda
 
    def compute_gae(self, rewards, values, next_values, dones):
        """计算广义优势估计"""
        advantages = []
        gae = 0
        
        # 从后往前计算
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
            gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
            advantages.append(gae)
        
        advantages = torch.tensor(advantages[::-1], dtype=torch.float32)
        # 归一化优势函数
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return advantages
 
    def update(self, rollouts):
        """PPO更新步骤"""
        obs, actions, old_log_probs, returns, advantages = rollouts.sample()
 
        # 计算新策略的概率和熵
        dist, values = self.actor_critic(obs)
        new_log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()
 
        # 重要性采样比
        ratio = (new_log_probs - old_log_probs).exp()
        
        # PPO裁剪目标函数
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.clip_param, 
                           1 + self.clip_param) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
 
        # 价值函数损失(使用Huber损失)
        value_loss = 0.5 * (returns - values).pow(2).mean()
 
        # 总损失(含熵正则化)
        loss = policy_loss + value_loss - self.ent_coef * entropy
 
        # 梯度更新
        self.optimizer.zero_grad()
        loss.backward()
        # 梯度裁剪防止梯度爆炸
        torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 0.5)
        self.optimizer.step()
        
        return policy_loss.item(), value_loss.item(), entropy.item()

算法流程伪代码

for epoch in 1, 2, ..., N:
    # 1. 收集数据
    for t in 1, 2, ..., T:
        使用当前策略 π_θ_old 与环境交互
        存储 {s_t, a_t, r_t, log_prob_t, V_t}
    
    # 2. 计算优势与回报
    使用GAE算法计算每个时间步的优势值 A_t
    计算回报 G_t = A_t + V_t
    
    # 3. 优化策略(多轮更新)
    for k in 1, 2, ..., K:
        随机采样一个 batch 数据
        计算重要性采样比 r_t(θ)
        计算 clipped 目标函数 L^CLIP
        更新策略网络参数 θ
        更新价值网络参数 φ

大模型RLHF中的应用

RLHF三阶段流程

PPO是RLHF(Reinforcement Learning from Human Feedback)第三阶段的核心算法:

阶段方法目标
Stage 1SFT(监督微调)让模型学会基本对话能力
Stage 2RM(奖励模型训练)训练一个能判断回答质量的模型
Stage 3PPO(强化学习优化)使用奖励模型优化策略模型

RLHF中的四个关键模型

在PPO阶段,涉及四个核心模型:

模型角色状态作用
Actor Model演员可训练目标语言模型,生成回复
Reference Model参考冻结防止Actor偏离原始SFT模型太远
Critic Model评论家可训练预测总收益
Reward Model奖励冻结计算即时收益

RLHF中的奖励计算

在LLM场景下,奖励函数设计特殊:

非最后token位置:

最后token位置:

即:整个回复的奖励分数只在最后一个token位置给出,其余位置主要由KL散度惩罚构成,防止模型偏离参考模型太远。

RLHF中的Actor Loss

其中优势函数 基于GAE计算:

RLHF中的Critic Loss

其中 (基于GAE计算的实际收益)。

同样会对预测值 进行裁剪,限制其在 的一定范围内。

应用案例

模型应用说明
ChatGPT / GPT-4PPO是RLHF训练的核心算法,确保输出符合人类偏好
Claude使用类似RLHF框架,PPO用于安全性和有用性对齐
Llama 2采用PPO进行RLHF训练,开源完整训练流程
InstructGPTOpenAI早期PPO应用示范

算法优势与局限

优势

特点说明
稳定性限制策略更新幅度,避免训练崩溃
效率支持大规模并行训练,样本利用率高
可控性KL散度惩罚确保输出质量
易实现相比TRPO实现简单,计算成本低

局限性与改进

局限性改进方向
局部最优陷阱PPO-Adaptive:动态调整Clip范围
高维动作空间调整困难POP:解耦策略与价值函数更新频率
探索能力受限结合元学习(如Meta-PPO)

与其他算法对比

PPO vs TRPO

特性PPOTRPO
约束方式裁剪目标函数KL散度硬约束
实现复杂度简单复杂(需二阶优化)
计算成本
性能相当或更好较好

PPO vs DPO

DPO(Direct Preference Optimization)是PPO的简化替代方案:

特性PPODPO
训练流程需奖励模型 + PPO优化直接优化,无需奖励模型
计算成本高(四个模型)低(仅需策略模型)
稳定性需调参更稳定
效果更强(复杂场景)较好(简单场景)

参考资源

原论文

教程与博客

开源实现

相关笔记