20 commits
14a7a4d
Add multitask diffusion transformer policy
brysonjones Nov 13, 2025
ab97d5c
Merge branch 'main' into feature/add-multitask-dit
brysonjones Nov 13, 2025
8b9fada
expand the observation encoder to support differnt size encoders for …
brysonjones Nov 21, 2025
34499cb
Merge branch 'main' into feature/add-multitask-dit
brysonjones Nov 29, 2025
a0d5a08
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 9, 2025
46ebcc2
add RoPE attention module as this is shown to help training dynamics …
brysonjones Dec 9, 2025
22714af
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 10, 2025
55e19ff
update readme and citations for multitask dit policy
brysonjones Dec 10, 2025
adabb37
remove dino vision encoder and simplify text and vision encoders by r…
brysonjones Dec 10, 2025
6f85601
adjust factory comment
brysonjones Dec 10, 2025
cdacc09
update docstring for multitask dit policy processor file
brysonjones Dec 10, 2025
103230c
simplify config for multitask dit by merging and flattening everythin…
brysonjones Dec 10, 2025
b92dc82
add references to the modeling file comments
brysonjones Dec 10, 2025
3b2a4f5
merge all modules files into the main modeling file
brysonjones Dec 10, 2025
3a16a00
add torch.no_grad decorators
brysonjones Dec 10, 2025
5524a0d
split up select action return statement
brysonjones Dec 10, 2025
10cfc17
remove redundant asserts
brysonjones Dec 10, 2025
f1ac454
add tutorial to training with multi_task_dit
brysonjones Dec 10, 2025
d49d339
Merge branch 'main' into feature/add-multitask-dit
brysonjones Dec 10, 2025
ba968e8
fix bugs when testing on hardware
brysonjones Dec 11, 2025
294 changes: 294 additions & 0 deletions docs/source/multitask_dit.mdx
@@ -0,0 +1,294 @@
# Multi-Task DiT Policy

The Multi-Task Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture that leverages a large DiT with text and vision conditioning for multi-task robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.

## Model Overview

The model uses:

- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation

This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
VLAs, with only ~450M parameters and significantly less training.
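
For intuition, here is a minimal sketch of how the pieces fit together. All class and argument names below are hypothetical illustrations of the data flow, not the actual LeRobot implementation:

```python
import torch
from torch import nn


class MultiTaskDiTSketch(nn.Module):
    """Illustrative only: observation/language conditioning for a DiT action head."""

    def __init__(self, vision_encoder: nn.Module, text_encoder: nn.Module, dit: nn.Module,
                 text_dim: int = 512, hidden_dim: int = 512):
        super().__init__()
        self.vision_encoder = vision_encoder              # CLIP ViT, fine-tuned with a reduced LR
        self.text_encoder = text_encoder                  # CLIP text model, kept frozen
        self.text_proj = nn.Linear(text_dim, hidden_dim)  # learnable projection of text features
        self.dit = dit                                    # transformer denoising the action chunk

    def forward(self, images: list[torch.Tensor], token_ids: torch.Tensor,
                noisy_actions: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        # Encode every camera view and the language instruction into conditioning tokens.
        vision_tokens = torch.cat([self.vision_encoder(img) for img in images], dim=1)
        with torch.no_grad():                             # frozen text encoder weights
            text_tokens = self.text_encoder(token_ids)
        cond = torch.cat([vision_tokens, self.text_proj(text_tokens)], dim=1)
        # The DiT predicts the noise (diffusion) or velocity (flow matching) for the action chunk.
        return self.dit(noisy_actions, timestep, cond)
```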

## Installation Requirements

Multi-Task DiT Policy has additional dependencies. Install them with:

```bash
pip install lerobot[multi_task_dit]
```

This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.

## Usage

To use Multi-Task DiT in your LeRobot configuration, specify the policy type as:

```bash
--policy.type=multi_task_dit
```
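
If you are configuring a policy programmatically rather than via the CLI, the same selection goes through LeRobot's policy factory. A minimal sketch, assuming the config's defaults are acceptable for your setup:

```python
from lerobot.policies.factory import get_policy_class, make_policy_config

config = make_policy_config("multi_task_dit")    # -> MultiTaskDiTConfig
policy_cls = get_policy_class("multi_task_dit")  # -> MultiTaskDiTPolicy
print(policy_cls.__name__, type(config).__name__)
```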

## Training

### Basic Training Command

Here's a complete training command for training Multi-Task DiT on your dataset:

```bash
lerobot-train \
--dataset.repo_id={{MY_DATASET_ID}} \
--output_dir={{MY_OUTPUT_DIR}} \
--policy.type=multi_task_dit \
--policy.device=cuda \
--policy.repo_id={{MY_REPO_ID}} \
--batch_size=32 \
--steps=5000 \
--save_freq=500 \
--log_freq=100 \
--wandb.enable=true
```

### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)

For reliable performance, start with these suggested default hyperparameters:

```bash
lerobot-train \
--dataset.repo_id={{MY_DATASET_ID}} \
--output_dir={{MY_OUTPUT_DIR}} \
--policy.type=multi_task_dit \
--policy.device=cuda \
--batch_size=320 \
--steps=30000 \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.repo_id={{MY_REPO_ID}} \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--wandb.enable=true
```

**Key Parameters:**

- **Batch Size**: 192-320 - if you have access to a GPU that can support this range, you will get the best training dynamics
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
- **Training Steps**: >30k steps recommended for a single task

### Training Configuration Parameters

#### Objective Selection

Choose between diffusion and flow matching:

```bash
# Diffusion objective (default)
# noise_scheduler_type can also be "DDIM"; a lower num_inference_steps gives faster inference
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_train_timesteps=100 \
--policy.num_inference_steps=10

# Flow matching objective
# timestep_sampling_strategy can also be "uniform", but beta tends to perform better in practice
# integration_method can also be "rk4"
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.num_integration_steps=100 \
--policy.integration_method=euler
```
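
Conceptually, the two objectives differ only in how actions are generated at inference time. The sketch below assumes a hypothetical `model(x, t, cond)` denoiser and, for the diffusion case, a diffusers-style scheduler; it is illustrative, not the policy's actual API:

```python
import torch


def sample_flow_matching(model, cond, horizon, action_dim, num_integration_steps=100):
    # Euler integration of the learned velocity field from noise (t=0) to actions (t=1).
    x = torch.randn(1, horizon, action_dim)
    dt = 1.0 / num_integration_steps
    for i in range(num_integration_steps):
        t = torch.full((1,), i * dt)
        x = x + dt * model(x, t, cond)      # model predicts the velocity dx/dt
    return x


def sample_diffusion(model, scheduler, cond, horizon, action_dim):
    # Reverse diffusion with a DDPM/DDIM scheduler (fewer steps -> faster inference).
    x = torch.randn(1, horizon, action_dim)
    for t in scheduler.timesteps:
        noise_pred = model(x, t, cond)      # model predicts the added noise
        x = scheduler.step(noise_pred, t, x).prev_sample
    return x
```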

#### Transformer Architecture

Adjust model capacity based on dataset size:

```bash
# Small datasets (< 100 examples)
--policy.num_layers=4 \
--policy.hidden_dim=512

# Medium datasets (100-5k examples) - default
--policy.num_layers=6 \
--policy.hidden_dim=512

# Large datasets (> 5k examples)
--policy.num_layers=8 \
--policy.hidden_dim=512
```

#### Vision Encoder Configuration

```bash
# Use a different CLIP model for more expressivity, at the cost of inference time
--policy.vision_encoder_name=openai/clip-vit-large-patch14

# Image preprocessing (you may need to resize your images for inference speed-ups)
--policy.image_resize_shape=[XXX,YYY] \
--policy.image_crop_shape=[224,224] \
--policy.image_crop_is_random=true # Random during training, center at inference
```
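
For reference, the crop behavior described above corresponds to something like the following torchvision transforms (a sketch of the idea only; the policy applies its own preprocessing internally):

```python
from torchvision import transforms

# Note: torchvision expects sizes as (height, width).
train_transform = transforms.Compose([
    transforms.Resize((240, 320)),      # optional resize for speed, cf. image_resize_shape
    transforms.RandomCrop((224, 224)),  # random crop during training
])
eval_transform = transforms.Compose([
    transforms.Resize((240, 320)),
    transforms.CenterCrop((224, 224)),  # deterministic center crop at inference
])
```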

#### Learning Rate Configuration

The vision encoder uses a separate learning rate multiplier; 1/10th of the base learning rate is suggested as the ideal starting point:

```bash
--policy.optimizer_lr=2e-5 \
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
```
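
Under the hood this amounts to giving the vision encoder its own parameter group in the optimizer. A minimal sketch, assuming a `policy` module with a `vision_encoder` submodule (names are hypothetical):

```python
import torch

optimizer_lr = 2e-5
vision_encoder_lr_multiplier = 0.1

vision_params = list(policy.vision_encoder.parameters())
other_params = [p for n, p in policy.named_parameters() if not n.startswith("vision_encoder")]

optimizer = torch.optim.AdamW([
    {"params": vision_params, "lr": optimizer_lr * vision_encoder_lr_multiplier},
    {"params": other_params, "lr": optimizer_lr},
])
```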

### Training Tuning Guidelines

#### 1. Flow Matching with Beta Sampling

Consider switching to flow matching with a beta timestep sampling distribution for potentially improved performance:

```bash
--policy.objective=flow_matching \
--policy.timestep_sampling_strategy=beta \
--policy.timestep_sampling_alpha=1.5 \
--policy.timestep_sampling_beta=1.0 \
--policy.timestep_sampling_s=0.999
```

This hasn't been shown to be a silver bullet across every use case, but it occasionally results in smoother and more consistent actions.
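
For intuition, beta sampling simply draws training timesteps from a Beta distribution instead of uniformly, concentrating training signal on a chosen region of the noise schedule. A minimal sketch (the exact parameterization used by the implementation may differ):

```python
import torch


def sample_timesteps(batch_size, alpha=1.5, beta=1.0, s=0.999):
    # Draw t in [0, 1) from Beta(alpha, beta), then scale by s to stay strictly below 1.
    return s * torch.distributions.Beta(alpha, beta).sample((batch_size,))


t = sample_timesteps(320)  # one timestep per element of the training batch
```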

#### 2. Number of Transformer Layers

Match model capacity to your dataset size:

- **Small datasets** (< 100 examples): Reduce to 4 layers
- **Large datasets** (> 5k examples): Increase to 8 layers

#### 3. `horizon` Tuning

The model can be sensitive to the horizon you choose. Start with around a 1 second horizon based on your control frequency:

- **30 Hz frequency**: `horizon=30`
- **10 Hz frequency**: `horizon=10`

Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.

#### 4. `n_action_steps` Sensitivity

The model can also be very sensitive to `n_action_steps`. Start with a value corresponding to roughly 0.8 seconds at your control frequency and tune from there:

- **Lower values**: More reactive but potentially less stable for long-horizon tasks
- **Higher values**: Better for long-horizon execution, but the policy runs open-loop for longer, which limits its ability to recover from failures (see the sketch below)
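
Together, `horizon` and `n_action_steps` define a receding-horizon loop: the policy predicts a full action chunk but only part of it is executed before re-planning. Pseudocode sketch with hypothetical `policy` and `env` objects:

```python
horizon, n_action_steps = 32, 24

obs = env.reset()
done = False
while not done:
    chunk = policy.predict_chunk(obs)      # (horizon, action_dim) action sequence
    for action in chunk[:n_action_steps]:  # execute only the first n_action_steps
        obs, done = env.step(action)
        if done:
            break
    # the remaining horizon - n_action_steps actions are discarded and re-predicted
```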

### Inference Tuning

For faster inference, use DDIM with fewer sampling steps:

```bash
--policy.noise_scheduler_type=DDIM \
--policy.num_inference_steps=10
```

### Resuming Training

To resume training from a checkpoint:

```bash
lerobot-train \
--config_path=$OUTPUT_DIR/checkpoints/00001000/pretrained_model/train_config.json \
--resume=true \
--output_dir=$OUTPUT_DIR
```

The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training).

## Common Failure Modes and Debugging

Training these models can be finicky. Here are common failure modes and debugging approaches:

### Idling / No Motion

The model may "collapse" during inference, resulting in static or no motion. This can occur when:

1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. If you are still seeing this once you have more than 300 examples, the task may be too complex.

2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.

**Debugging tips:**

- Increase dataset size (double until you get to over 300 examples)
- Train for longer, up to 100k steps, even when the loss flatlines
- Check that the model is receiving proper language instructions, or increase the diversity of instructions

### Executing the Wrong Task

Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.

**Potential causes:**

- Language instruction ambiguity
- Insufficient task-specific training data
- Model confusion between similar tasks in the multitask dataset

**Debugging tips:**

- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
- Check the task distribution in your training dataset and add sampling weight to the failing/ignored task (see the sketch below)
- Consider task-specific fine-tuning
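
One simple way to add that weighting is to over-sample the failing task with PyTorch's `WeightedRandomSampler` (a sketch; it assumes you can produce one integer task label per dataset example):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical per-example task labels, e.g. 0 = "pick red block", 1 = "pick blue block".
task_ids = torch.tensor([0, 0, 0, 0, 1, 1])

counts = torch.bincount(task_ids)
weights = (1.0 / counts.float())[task_ids]  # rarer tasks get proportionally larger weights
sampler = WeightedRandomSampler(weights, num_samples=len(task_ids), replacement=True)
# Pass sampler=sampler to your DataLoader instead of shuffle=True.
```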

### Training Instability

If training loss is unstable or diverging:

- Try adjusting learning rate between `1e-5` and `3e-4`
- Increase batch size if possible
- Check that your dataset normalization is correct
- Verify image preprocessing is working correctly

## Performance Considerations

### GPU Requirements

- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable inference speed
- **Training**: A GPU with enough VRAM to fit batch sizes of >64 is ideal; the exact requirement will vary with the number of image observations, image resolution, etc.

### Batch Size Recommendations

- **Minimum**: 64 (less than this may result in unstable training)
- **Recommended**: 256-320 (best performance, requires larger GPU)

## Example: Training on Custom Dataset

Here's a complete example training on a custom dataset:

```bash
lerobot-train \
--dataset.repo_id={{MY_DATASET_ID}} \
--output_dir={{MY_OUTPUT_DIR}} \
--policy.type=multi_task_dit \
--policy.device=cuda \
--batch_size=320 \
--steps=30000 \
--save_freq=1000 \
--log_freq=100 \
--eval_freq=1000 \
--policy.horizon=32 \
--policy.n_action_steps=24 \
--policy.objective=diffusion \
--policy.noise_scheduler_type=DDPM \
--policy.num_layers=6 \
--policy.hidden_dim=512 \
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
--policy.image_resize_shape=[320,240] \
--policy.image_crop_shape=[224,224] \
--wandb.enable=true \
--wandb.project=multitask_dit \
--policy.repo_id={{MY_REPO_ID}}
```

## References

For more details on the technical implementation and architecture, see:

- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -122,6 +122,7 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
multi_task_dit = ["lerobot[transformers-dep]"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
2 changes: 1 addition & 1 deletion src/lerobot/__init__.py
@@ -157,7 +157,7 @@
)

# lists all available policies from `lerobot/policies`
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
available_policies = ["act", "multi_task_dit", "diffusion", "tdmpc", "vqbet"]

# lists all available robots from `lerobot/robots`
available_robots = [
2 changes: 2 additions & 0 deletions src/lerobot/policies/__init__.py
@@ -15,6 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
@@ -26,6 +27,7 @@
__all__ = [
"ACTConfig",
"DiffusionConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI05Config",
"SmolVLAConfig",
23 changes: 20 additions & 3 deletions src/lerobot/policies/factory.py
@@ -32,6 +32,7 @@
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -61,7 +62,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:

Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".

Returns:
The policy class corresponding to the given name.
@@ -81,6 +82,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.act.modeling_act import ACTPolicy

return ACTPolicy
elif name == "multi_task_dit":
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy

return MultiTaskDiTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy

@@ -129,8 +134,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:

Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier".
**kwargs: Keyword arguments to be passed to the configuration class constructor.

Returns:
@@ -145,6 +150,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "multi_task_dit":
return MultiTaskDiTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
@@ -289,6 +296,16 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)

elif isinstance(policy_cfg, MultiTaskDiTConfig):
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)

processors = make_multi_task_dit_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)

elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
