Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lerobot/configs/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def __post_init__(self) -> None:
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path)
# Keep as string to preserve forward slashes for HuggingFace repo IDs
self.policy.pretrained_path = policy_path

else:
logger.warning(
Expand Down
2 changes: 1 addition & 1 deletion src/lerobot/configs/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: Path | None = None
pretrained_path: str | None = None

def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device):
Expand Down
42 changes: 37 additions & 5 deletions src/lerobot/configs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def validate(self) -> None:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path)
# Keep as string to preserve forward slashes for HuggingFace repo IDs
self.policy.pretrained_path = policy_path
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
Expand All @@ -92,7 +93,8 @@ def validate(self) -> None:

policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
# Convert to string with forward slashes for consistency
self.policy.pretrained_path = policy_dir.as_posix()
self.checkpoint_path = policy_dir.parent

if self.policy is None:
Expand Down Expand Up @@ -136,11 +138,41 @@ def __get_path_fields__(cls) -> list[str]:
return ["policy"]

def to_dict(self) -> dict[str, Any]:
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
# Convert Path objects to POSIX-style strings to avoid Windows backslashes
if isinstance(self.output_dir, Path):
output_dir_backup = self.output_dir
self.output_dir = self.output_dir.as_posix() # type: ignore[assignment]
if isinstance(self.checkpoint_path, Path):
checkpoint_path_backup = self.checkpoint_path
self.checkpoint_path = self.checkpoint_path.as_posix() # type: ignore[assignment]

try:
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
finally:
# Restore Path objects
if isinstance(self.output_dir, str) and "output_dir_backup" in locals():
self.output_dir = output_dir_backup
if isinstance(self.checkpoint_path, str) and "checkpoint_path_backup" in locals():
self.checkpoint_path = checkpoint_path_backup

def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
# Convert Path objects to POSIX-style strings to avoid Windows backslashes in saved configs
if isinstance(self.output_dir, Path):
output_dir_backup = self.output_dir
self.output_dir = self.output_dir.as_posix() # type: ignore[assignment]
if isinstance(self.checkpoint_path, Path):
checkpoint_path_backup = self.checkpoint_path
self.checkpoint_path = self.checkpoint_path.as_posix() # type: ignore[assignment]

try:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
finally:
# Restore Path objects
if isinstance(self.output_dir, str) and "output_dir_backup" in locals():
self.output_dir = output_dir_backup
if isinstance(self.checkpoint_path, str) and "checkpoint_path_backup" in locals():
self.checkpoint_path = checkpoint_path_backup

@classmethod
def from_pretrained(
Expand Down