@bhuvanprakash
This PR resolves a CUDA error in KTOTrainer when training models with extremely large vocabularies (e.g., Qwen3-VL, ~151k tokens). It thereby enables KTO training on multimodal models with extended vocabularies such as Qwen3-VL and similar models.

The problem was first observed downstream in Unsloth during KTO training. After investigation, the root cause was traced to the Hugging Face TRL implementation. This PR fixes the root cause upstream so that downstream projects (e.g., Unsloth) no longer need local patches. Verified successfully in Unsloth with Qwen3-VL after this fix.

The error resulted from Python list-based fancy indexing of tensors shaped like:

[batch_size, seq_len, vocab_size]

For extremely large vocab sizes, this can trigger:

CUDA error: invalid configuration argument

Downstream issue reference:

unslothai/unsloth#3675

Root Cause

In trl/trainer/kto_trainer.py, Python lists are used to split the batch into chosen/rejected examples, and those same lists are then used for fancy indexing on large tensors:

chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i]]
rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["label"][i]]

chosen_logps = completion_logps[chosen_idx, ...]
rejected_logps = completion_logps[rejected_idx, ...]
chosen_logits = completion_logits[chosen_idx, ...]
rejected_logits = completion_logits[rejected_idx, ...]
With extremely large vocab tensors ([batch, seq, vocab]), this causes a CUDA kernel to be launched with an invalid configuration, crashing training.

A similar pattern can be found in the reference_logps branch of get_batch_loss_metrics(...).

Fix

Python list indexing is replaced with device-aware index tensors and torch.index_select, which dispatches to optimized CUDA gather kernels.

forward(...)

device = completion_logits.device
labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device)
chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1)
rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1)

chosen_logps = completion_logps.index_select(0, chosen_idx)
rejected_logps = completion_logps.index_select(0, rejected_idx)
chosen_logits = completion_logits.index_select(0, chosen_idx)
rejected_logits = completion_logits.index_select(0, rejected_idx)
get_batch_loss_metrics(...) (when using reference_logps)

device = batch["reference_logps"].device
labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device)
chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1)
rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1)

reference_chosen_logps = batch["reference_logps"].index_select(0, chosen_idx)
reference_rejected_logps = batch["reference_logps"].index_select(0, rejected_idx)
The semantics are exactly the same, but the indexing now goes through safe, optimized CUDA kernels.
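The two snippets above repeat the same split logic. As a standalone illustration of the pattern, here is a minimal sketch (the helper name split_by_label is hypothetical, not part of the PR):

```python
import torch

def split_by_label(tensor, labels):
    # Hypothetical helper illustrating the PR's indexing pattern:
    # build a boolean mask on the tensor's device, then gather rows
    # with index_select instead of Python-list fancy indexing.
    labels = torch.as_tensor(labels, dtype=torch.bool, device=tensor.device)
    chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1)
    rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1)
    return tensor.index_select(0, chosen_idx), tensor.index_select(0, rejected_idx)

logps = torch.randn(4, 16)  # stand-in for [batch, seq] log-probs
chosen, rejected = split_by_label(logps, [True, False, True, False])
print(chosen.shape, rejected.shape)  # torch.Size([2, 16]) torch.Size([2, 16])
```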

Why This Works

  • index_select uses optimized CUDA kernels
  • indices are on the correct device
  • avoids Python list fancy indexing
  • no CPU-GPU sync overhead
  • identical behavior for normal vocab sizes
  • robust for 150k+ vocab models
All existing logic paths remain unchanged for normal vocab scenarios.

Testing

Unit Tests

Executed:

python3.10 -m pytest tests/test_kto_trainer.py -v

All KTO tests pass:

  • TestKTOTrainer::test_kto_trainer[...]
  • test_kto_trainer_with_ref_model_is_model
  • test_tokenize_and_process_tokens
  • test_kto_trainer_without_providing_ref_model
  • test_kto_trainer_generate_during_eval_no_wandb
  • test_compute_metrics
(LoRA / liger tests skipped as per upstream config.)

Large-Vocab Simulation

A dummy tensor was created with shape:

[4, 256, 151_936]

Compared:

old: tensor[list, ...]
new: tensor.index_select(0, idx_tensor)

Verification:

  • torch.allclose(old, new) == True
  • shapes match
  • values match
So behavior is unchanged; only the indexing method is safer.
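A CPU-sized version of this equivalence check can be sketched as follows (shapes shrunk from [4, 256, 151_936] so it runs anywhere; the comparison itself is the same):

```python
import torch

# Shrunken stand-in for the [batch, seq, vocab] logits tensor.
logits = torch.randn(4, 8, 32)
labels = [True, False, True, False]

# Old path: Python-list fancy indexing.
chosen_list_idx = [i for i in range(logits.shape[0]) if labels[i]]
old = logits[chosen_list_idx, ...]

# New path: device-aware index tensor + index_select.
labels_t = torch.as_tensor(labels, dtype=torch.bool, device=logits.device)
chosen_idx = torch.nonzero(labels_t, as_tuple=False).view(-1)
new = logits.index_select(0, chosen_idx)

print(old.shape == new.shape)    # True
print(torch.allclose(old, new))  # True
```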

Backward Compatibility

  • no API changes
  • no behavior changes for normal vocab sizes
  • only internal indexing logic updated
  • fixes KTO training for large vocab models (e.g. Qwen family)
Impact

This enables KTO training for multilingual and multimodal models with extended vocabularies without triggering CUDA kernel launch errors. Downstream frameworks like Unsloth will pick up the fix automatically when they update TRL; hence, they no longer need local patches.

Copilot AI left a comment

Pull request overview

This PR addresses a CUDA error that occurs when training KTOTrainer with models that have extremely large vocabularies (e.g., Qwen3-VL with ~151k vocab size). The fix converts Python list-based fancy indexing to tensor-based indexing operations to prevent invalid CUDA kernel configurations.

Key Changes

  • Replaced Python list fancy indexing with torch.tensor() conversion + index_select() in the forward method
  • Applied the same fix to the get_batch_loss_metrics method when using pre-computed reference log probabilities
  • Ensures indices are on the correct device for CUDA operations
