From d3eaee44cdeff6b9b8ff684b23529bc363dc18ce Mon Sep 17 00:00:00 2001 From: Rocco Moretti Date: Tue, 30 Sep 2025 15:56:13 -0500 Subject: [PATCH] feat: Checkpoint locating support for running out-of-directory Currently, only the current directory is searched for the model checkpoint. While manual specification of the checkpoint path is possible, this requires that you A) remember to manually specify the path and B) know what the proper path is. To make running outside of the installation directory easier, add a fall-back to search the project path for the checkpoint file. This applies only if you would otherwise get an error with the current behavior. --- .gitignore | 4 ++++ src/modelhub/inference_engines/rf3.py | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index aceebc1e..4a88d02f 100644 --- a/.gitignore +++ b/.gitignore @@ -200,6 +200,10 @@ wandb/ # Logs logs/ +# Models +rf3_latest.pt +rf3_921.pt + # Other *.sif *.out diff --git a/src/modelhub/inference_engines/rf3.py b/src/modelhub/inference_engines/rf3.py index 28cf344c..bfd93a56 100644 --- a/src/modelhub/inference_engines/rf3.py +++ b/src/modelhub/inference_engines/rf3.py @@ -1,4 +1,5 @@ import logging +import os from os import PathLike from pathlib import Path @@ -133,7 +134,19 @@ def __init__( # Load the training config from the checkpoint # TODO: Load checkpoint only once (instead of twice) - ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...") + try: + resolved_ckpt_path = Path(ckpt_path).resolve(strict=True) + except OSError: + # If path does not exist as-is, attempt to load from the project directory + proj_ckpt_path = os.path.join( os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), ckpt_path ) + try: + resolved_ckpt_path = Path(proj_ckpt_path).resolve(strict=True) + ckpt_path = proj_ckpt_path # Successful + except OSError: + resolved_ckpt_path = Path(ckpt_path).resolve(strict=False) + pass # Just keep the 
current ckpt_path + + ranked_logger.info(f"Loading checkpoint from {resolved_ckpt_path}...") checkpoint = torch.load( ckpt_path, "cpu", weights_only=False ) # We only extract the `train_cfg` from the checkpoint initially