diff --git a/.gitignore b/.gitignore
index aceebc1e..4a88d02f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -200,6 +200,10 @@ wandb/
 # Logs
 logs/
 
+# Models
+rf3_latest.pt
+rf3_921.pt
+
 # Other
 *.sif
 *.out
diff --git a/src/modelhub/inference_engines/rf3.py b/src/modelhub/inference_engines/rf3.py
index 28cf344c..bfd93a56 100644
--- a/src/modelhub/inference_engines/rf3.py
+++ b/src/modelhub/inference_engines/rf3.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from os import PathLike
 from pathlib import Path
 
@@ -133,7 +134,28 @@ def __init__(
 
         # Load the training config from the checkpoint
         # TODO: Load checkpoint only once (instead of twice)
-        ranked_logger.info(f"Loading checkpoint from {Path(ckpt_path).resolve()}...")
+        try:
+            resolved_ckpt_path = Path(ckpt_path).resolve(strict=True)
+        except OSError:
+            # If path does not exist as-is, attempt to load from the project directory.
+            # PROJECT_ROOT is looked up lazily: using os.environ["PROJECT_ROOT"] as the
+            # default argument of .get() would raise KeyError whenever it is unset.
+            project_dir = os.environ.get("PROJECT_PATH")
+            if project_dir is None:
+                project_dir = os.environ.get("PROJECT_ROOT")
+            resolved_ckpt_path = None
+            if project_dir is not None:
+                proj_ckpt_path = os.path.join(project_dir, ckpt_path)
+                try:
+                    resolved_ckpt_path = Path(proj_ckpt_path).resolve(strict=True)
+                    ckpt_path = proj_ckpt_path  # Successful
+                except OSError:
+                    resolved_ckpt_path = None  # Not under the project directory either
+            if resolved_ckpt_path is None:
+                # Keep ckpt_path unchanged so the torch.load call below surfaces the error
+                resolved_ckpt_path = Path(ckpt_path).resolve(strict=False)
+
+        ranked_logger.info(f"Loading checkpoint from {resolved_ckpt_path}...")
         checkpoint = torch.load(
             ckpt_path, "cpu", weights_only=False
         )  # We only extract the `train_cfg` from the checkpoint initially