Fix KTOTrainer CUDA error for large-vocab models via tensor indexing #4635
This PR resolves a CUDA error in KTOTrainer when training models with extremely large vocabularies (e.g., Qwen3-VL, ~151k tokens). It enables KTO training on multimodal models with extended vocabularies such as Qwen3-VL.
The problem was first observed downstream in Unsloth during KTO training. After investigation, the root cause was traced to the Hugging Face TRL implementation. This PR fixes it upstream so that downstream projects (e.g., Unsloth) no longer need local patches. The fix was tested successfully in Unsloth with Qwen3-VL.
The error was caused by Python list-based fancy indexing of tensors with shapes like:
[batch_size, seq_len, vocab_size]
For extremely large vocab sizes, this might lead to:
CUDA error: invalid configuration argument
Downstream issue reference:
unslothai/unsloth#3675
Root Cause
In trl/trainer/kto_trainer.py, Python lists are used to split the batch, and those same lists are then used for fancy indexing on large tensors.
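For illustration, a simplified sketch of the problematic pattern (variable names and shapes are illustrative, not quoted verbatim from the file):

```python
import torch

# Shapes are scaled down for the sketch; in the failing case the last
# dimension is the vocab size (~151k for Qwen3-VL).
batch_size, seq_len, vocab_size = 4, 256, 32
completion_logits = torch.randn(batch_size, seq_len, vocab_size)  # on CUDA in practice
labels = [True, False, True, False]  # per-example split of the batch

# The batch is split with plain Python lists ...
chosen_idx = [i for i in range(batch_size) if labels[i]]
rejected_idx = [i for i in range(batch_size) if not labels[i]]

# ... and those lists are used for fancy indexing on the large logits tensor.
# With ~150k vocab sizes on CUDA this can raise
# "CUDA error: invalid configuration argument".
chosen_logits = completion_logits[chosen_idx, ...]
rejected_logits = completion_logits[rejected_idx, ...]
```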
A similar pattern can be found in the reference_logps branch of get_batch_loss_metrics(...).
Fix
Python list indexing is replaced with device-aware tensor indices and torch.index_select, which dispatches to optimized CUDA kernels.
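A minimal sketch of the replacement, continuing the example above (assumed names, not the exact diff):

```python
import torch

# Build index tensors on the same device as the logits instead of
# keeping plain Python lists.
chosen_idx_t = torch.tensor(chosen_idx, dtype=torch.long, device=completion_logits.device)
rejected_idx_t = torch.tensor(rejected_idx, dtype=torch.long, device=completion_logits.device)

# index_select along dim 0 dispatches to optimized CUDA gather kernels and
# avoids the Python-list fancy-indexing path entirely.
chosen_logits = completion_logits.index_select(0, chosen_idx_t)
rejected_logits = completion_logits.index_select(0, rejected_idx_t)
```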
Why This Works
- index_select uses optimized CUDA kernels
- indices are on the correct device
- avoids Python list fancy indexing
- no CPU-GPU sync overhead
- identical behavior for normal vocab sizes
- robust for 150k+ vocab models
All existing logic paths remain unchanged for normal vocab scenarios.
Testing
Unit Tests
Executed:
python3.10 -m pytest tests/test_kto_trainer.py -v
All KTO tests pass:
- TestKTOTrainer::test_kto_trainer[...]
- test_kto_trainer_with_ref_model_is_model
- test_tokenize_and_process_tokens
- test_kto_trainer_without_providing_ref_model
- test_kto_trainer_generate_during_eval_no_wandb
- test_compute_metrics
(LoRA / liger tests skipped as per upstream config.)
Large-Vocab Simulation
A dummy tensor was created with shape:
[4, 256, 151_936]
Compared:
old: tensor[list, ...]
new: tensor.index_select(0, idx_tensor)
Verification:
torch.allclose(old, new) == True
shapes match
values match
So behavior is unchanged — only the indexing method is safer.
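A standalone sketch of this equivalence check (the exact script is not part of the PR; the indices chosen here are illustrative):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(4, 256, 151_936, device=device)  # ~Qwen3-VL vocab size

idx_list = [0, 2]  # example split of the batch
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=logits.device)

old = logits[idx_list, ...]               # old: Python list fancy indexing
new = logits.index_select(0, idx_tensor)  # new: tensor-based index_select

assert old.shape == new.shape             # shapes match
assert torch.allclose(old, new)           # values match
print("old and new indexing paths are equivalent")
```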
Backward Compatibility
- no API changes
- no behavior changes for normal vocab sizes
- only internal indexing logic updated
- fixes KTO training for large vocab models (e.g. Qwen family)
Impact
This allows KTO training for multilingual and multimodal models with extended vocabularies without triggering CUDA kernel launch errors. Downstream frameworks like Unsloth will get this fix automatically when they update TRL, so they no longer need local patches.