Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ wandb/
# Logs
logs/

# Models
rf3_latest.pt
rf3_921.pt

# Other
*.sif
*.out
Expand Down
15 changes: 14 additions & 1 deletion src/modelhub/inference_engines/rf3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from os import PathLike
from pathlib import Path

Expand Down Expand Up @@ -133,7 +134,19 @@ def __init__(

# Load the training config from the checkpoint
# TODO: Load checkpoint only once (instead of twice)
ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...")
# Resolve the checkpoint location. `resolve(strict=True)` raises an OSError
# subclass (FileNotFoundError) when the path does not exist, which triggers
# the project-directory fallback below.
try:
    resolved_ckpt_path = Path(ckpt_path).resolve(strict=True)
except OSError:
    # If path does not exist as-is, attempt to load from the project directory.
    # The environment variables are looked up lazily: PROJECT_ROOT is consulted
    # only when PROJECT_PATH is unset, and when neither is set the fallback is
    # skipped instead of raising KeyError (which `except OSError` would not catch).
    project_dir = os.environ.get("PROJECT_PATH")
    if project_dir is None:
        project_dir = os.environ.get("PROJECT_ROOT")

    resolved_ckpt_path = None
    if project_dir is not None:
        proj_ckpt_path = os.path.join(project_dir, ckpt_path)
        try:
            resolved_ckpt_path = Path(proj_ckpt_path).resolve(strict=True)
        except OSError:
            pass  # Not in the project directory either; keep the current ckpt_path
        else:
            ckpt_path = proj_ckpt_path  # Successful

    if resolved_ckpt_path is None:
        # Neither location exists: resolve non-strictly so the path can still be
        # reported in the log message before the load itself fails.
        resolved_ckpt_path = Path(ckpt_path).resolve(strict=False)

ranked_logger.info(f"Loading checkpoint from {resolved_ckpt_path}...")
checkpoint = torch.load(
ckpt_path, "cpu", weights_only=False
) # We only extract the `train_cfg` from the checkpoint initially
Expand Down