diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py index 41155a4f..045566b6 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -16,7 +16,7 @@ from .common import opt_acquisition_timeout, opt_duration_partial, opt_selector from .login import relogin_client from jumpstarter.common.utils import launch_shell -from jumpstarter.config.client import ClientConfigV1Alpha1 +from jumpstarter.config.client import ClientConfigV1Alpha1, raise_expired_token_error from jumpstarter.config.exporter import ExporterConfigV1Alpha1 @@ -98,8 +98,7 @@ async def _shell_with_signal_handling( if token: remaining = get_token_remaining_seconds(token) if remaining is not None and remaining <= 0: - from jumpstarter.common.exceptions import ConnectionError - raise ConnectionError("token is expired") + raise_expired_token_error(config) async with create_task_group() as tg: tg.start_soon(signal_handler, tg.cancel_scope) diff --git a/python/packages/jumpstarter/jumpstarter/config/client.py b/python/packages/jumpstarter/jumpstarter/config/client.py index 448f7d03..56084fb4 100644 --- a/python/packages/jumpstarter/jumpstarter/config/client.py +++ b/python/packages/jumpstarter/jumpstarter/config/client.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta from functools import wraps from pathlib import Path -from typing import Annotated, ClassVar, Literal, Optional, Self +from typing import Annotated, ClassVar, Literal, NoReturn, Optional, Self import grpc import yaml @@ -50,17 +50,27 @@ def wrapper(*args, **kwargs): return wrapper +def _attach_config_if_expired_token(exc: ConnectionError, config: ClientConfigV1Alpha1) -> None: + """Attach config to a ConnectionError so re-auth can use it. No-op if not token-expired.""" + if "token is expired" in str(exc): + exc.set_config(config) + + +def raise_expired_token_error(config: ClientConfigV1Alpha1) -> NoReturn: + """Raise ConnectionError for expired token with config attached so re-auth can run.""" + err = ConnectionError("token is expired") + err.set_config(config) + raise err + + def _handle_connection_error(f): @wraps(f) async def wrapper(*args, **kwargs): try: return await f(*args, **kwargs) except ConnectionError as e: - if "token is expired" in str(e): - # args[0] should be self for instance methods - e.set_config(args[0]) - raise e - except Exception: + # args[0] is self for instance methods + _attach_config_if_expired_token(e, args[0]) raise return wrapper @@ -291,13 +301,9 @@ async def lease_async( ) as lease: yield lease - # this replicates _handle_connection_error, the decorator doesn't work with asynccontextmanager + # decorator doesn't work with asynccontextmanager, so we use the same helper except ConnectionError as e: - if "token is expired" in str(e): - # args[0] should be self for instance methods - e.set_config(self) - raise e - except Exception: + _attach_config_if_expired_token(e, self) raise @classmethod