Description
Feature request
I propose implementing the unbiased KL divergence estimator with importance sampling correction as described in the DeepSeek-V3.2 paper (Equation 7).
Currently, GRPOTrainer calculates the per-token KL divergence using the standard approximation (Schulman, 2020):
```python
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
```
However, according to the DeepSeek-V3.2 paper, this estimator can be improved by correcting it with the importance sampling ratio between the current policy and the old policy used for data collection.
The proposed formula is the same K3 term scaled by the importance sampling ratio between the current policy and the data-collection policy:

$$
\hat{\mathbb{D}}_{\text{KL}}^{(t)} = \frac{\pi_\theta(o_t)}{\pi_{\theta_{\text{old}}}(o_t)} \left[ \frac{\pi_{\text{ref}}(o_t)}{\pi_\theta(o_t)} - \log \frac{\pi_{\text{ref}}(o_t)}{\pi_\theta(o_t)} - 1 \right]
$$
I propose adding a configuration flag to GRPOConfig (e.g., use_bias_correction_kl) to optionally enable this corrected estimator while keeping the current behavior as the default (see the sketch below).
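A minimal sketch of what the corrected per-token KL term could look like, written against the per-token log-probability tensors used in the snippet above (`per_token_logps`, `ref_per_token_logps`) plus an assumed `old_per_token_logps` from the data-collection policy. Detaching the importance ratio is my assumption, to be checked against Equation 7 of the paper:

```python
import torch


def corrected_per_token_kl(per_token_logps, old_per_token_logps, ref_per_token_logps):
    """K3 KL estimate scaled by the importance sampling ratio pi_theta / pi_theta_old."""
    # Standard K3 estimator (Schulman, 2020), identical to the current GRPOTrainer term.
    log_ratio_ref = ref_per_token_logps - per_token_logps
    per_token_kl = torch.exp(log_ratio_ref) - log_ratio_ref - 1

    # Importance sampling ratio between the current policy and the policy that sampled the data.
    # Detached here so the correction only reweights the K3 gradient (assumption to verify
    # against the paper before opening the PR).
    is_ratio = torch.exp(per_token_logps - old_per_token_logps).detach()

    return is_ratio * per_token_kl
```

When the current and old policies coincide (the first optimization step on freshly sampled data), the ratio is 1 and this reduces exactly to the current K3 term, so the default path would be unchanged.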
Motivation
The motivation is to improve training stability and convergence speed, particularly when the current policy deviates significantly from the reference policy.
As stated in the DeepSeek-V3.2 paper (Section 3.1, Scaling GRPO):
"[...] the original K3 estimator, particularly when the sampled tokens have substantially lower probabilities under the current policy than the reference policy, i.e., $\pi_\theta \ll \pi_{\text{ref}}$. In such cases, the gradient of the K3 estimator assigns disproportionately large, unbounded weights to maximize the likelihood of these tokens, resulting in noisy gradient updates that accumulate to degrade sample quality in subsequent iterations and lead to unstable training dynamics."
By including the importance sampling ratio correction ($\pi_\theta / \pi_{\theta_{\text{old}}}$), the weight assigned to these low-probability tokens stays bounded, which should reduce gradient noise and improve training stability.
Your contribution
Yes, I would love to submit a PR for this. I plan to:
- Add the boolean flag to GRPOConfig.
- Update the GRPOTrainer loss computation to include the importance sampling ratio in the KL term when the flag is enabled (rough sketch after this list).
- Ensure backward compatibility is maintained.
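As a rough illustration of the first two items (names and placement are placeholders; the actual PR would modify `GRPOConfig` and `GRPOTrainer` in place rather than subclassing):

```python
from dataclasses import dataclass, field

from trl import GRPOConfig


@dataclass
class GRPOConfigWithKLCorrection(GRPOConfig):
    # Proposed flag; defaults to False so existing configs keep the current K3-only behavior.
    use_bias_correction_kl: bool = field(
        default=False,
        metadata={
            "help": "Scale the K3 KL estimate by the importance sampling ratio "
            "pi_theta / pi_theta_old (DeepSeek-V3.2, Eq. 7)."
        },
    )


# Inside the trainer's loss computation, the flag would gate the correction, conceptually:
#
#   if self.args.use_bias_correction_kl:
#       per_token_kl = corrected_per_token_kl(per_token_logps, old_per_token_logps, ref_per_token_logps)
#   else:
#       per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) \
#           - (ref_per_token_logps - per_token_logps) - 1
```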