
GRPOTrainer: Implement Unbiased KL Estimator from DeepSeek-V3.2 #4637


Description

@jlcanta

Feature request

I propose implementing the unbiased KL divergence estimator with importance sampling correction as described in the DeepSeek-V3.2 paper (Equation 7).

Currently, GRPOTrainer computes the per-token KL divergence using the standard k3 approximation (Schulman, 2020):

```python
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
```

However, according to the DeepSeek-V3.2 paper, this estimator can be improved by correcting it with the importance sampling ratio between the current policy and the old policy used for data collection, $\pi_\theta / \pi_{\text{old}}$.

The proposed formula is:

$$ \mathbb{D}_{KL} \approx \frac{\pi_\theta}{\pi_{\text{old}}} \left( \frac{\pi_{\text{ref}}}{\pi_\theta} - \log \frac{\pi_{\text{ref}}}{\pi_\theta} - 1 \right) $$
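Translating Equation 7 into the trainer's per-token form, the change would look roughly like the sketch below. The name `old_per_token_logps` (log-probs under the rollout policy $\pi_{\text{old}}$) is an assumption of mine, not necessarily the trainer's actual variable name:

```python
import torch

# Dummy tensors standing in for the trainer's per-token log-probs, shape (batch, seq_len).
per_token_logps = torch.randn(2, 4)       # log pi_theta (current policy)
ref_per_token_logps = torch.randn(2, 4)   # log pi_ref (reference policy)
old_per_token_logps = torch.randn(2, 4)   # log pi_old (rollout policy), assumed available

# Current k3 estimator (Schulman, 2020), unchanged:
log_ratio = ref_per_token_logps - per_token_logps            # log(pi_ref / pi_theta)
per_token_kl = torch.exp(log_ratio) - log_ratio - 1

# Proposed correction (DeepSeek-V3.2, Eq. 7): scale by pi_theta / pi_old.
importance_ratio = torch.exp(per_token_logps - old_per_token_logps)
per_token_kl_corrected = importance_ratio * per_token_kl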

I propose adding a configuration flag to GRPOConfig (e.g., use_bias_correction_kl) to optionally enable this estimator while keeping the current behavior as the default.

Motivation

The motivation is to improve training stability and convergence speed, particularly when the current policy deviates significantly from the reference policy.

As stated in the DeepSeek-V3.2 paper (Section 3.1, "Scaling GRPO"):

"the original K3 estimator, particularly when the sampled tokens have substantially lower probabilities under the current policy than the reference policy, i.e., $\pi_\theta \ll \pi_{\text{ref}}$. In such cases, the gradient of the K3 estimator assigns disproportionately large, unbounded weights to maximize the likelihood of these tokens, resulting in noisy gradient updates that accumulate to degrade sample quality in subsequent iterations and lead to unstable training dynamics."

By including the importance sampling ratio correction ($\pi_\theta / \pi_{\text{old}}$), the gradient of the KL estimator becomes unbiased. This eliminates the systematic estimation error and facilitates more stable convergence, which is critical for the reasoning tasks (math/code) where GRPO is primarily used.
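For intuition (this derivation is my own sanity check, not quoted from the paper): taking the expectation of the corrected estimator over tokens sampled from $\pi_{\text{old}}$, the importance ratio turns the off-policy average back into an on-policy one, which recovers the reverse KL exactly:

$$ \mathbb{E}_{x \sim \pi_{\text{old}}}\!\left[\frac{\pi_\theta(x)}{\pi_{\text{old}}(x)}\left(\frac{\pi_{\text{ref}}(x)}{\pi_\theta(x)} - \log\frac{\pi_{\text{ref}}(x)}{\pi_\theta(x)} - 1\right)\right] = \mathbb{E}_{x \sim \pi_\theta}\!\left[\frac{\pi_{\text{ref}}(x)}{\pi_\theta(x)} - \log\frac{\pi_{\text{ref}}(x)}{\pi_\theta(x)} - 1\right] = \mathbb{D}_{KL}(\pi_\theta \,\|\, \pi_{\text{ref}}) $$

since $\mathbb{E}_{\pi_\theta}[\pi_{\text{ref}}/\pi_\theta] = 1$, so only the $-\log$ term survives.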

Your contribution

Yes, I would love to submit a PR for this. I plan to:

  1. Add the boolean flag to GRPOConfig (a rough sketch follows this list).
  2. Update the GRPOTrainer loss computation to include the importance sampling ratio in the KL term when the flag is enabled.
  3. Maintain backward compatibility: with the flag off (the default), the current estimator is used unchanged.
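A minimal sketch of step 1, using a stand-in dataclass rather than the real GRPOConfig (the field name is the one proposed above; the help text and placement are my assumptions). The default of False covers step 3, and step 2 would simply gate the multiplication shown earlier on this flag:

```python
from dataclasses import dataclass, field

@dataclass
class GRPOConfigSketch:
    # Stand-in for the real trl.GRPOConfig; only the proposed field is shown.
    use_bias_correction_kl: bool = field(
        default=False,  # off by default, so the existing k3 estimator is untouched
        metadata={
            "help": "If True, scale the per-token k3 KL estimator by the importance "
            "ratio pi_theta / pi_old, as in DeepSeek-V3.2 (Eq. 7)."
        },
    )
```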
