Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/packages/jumpstarter-cli/jumpstarter_cli/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .common import opt_acquisition_timeout, opt_duration_partial, opt_selector
from .login import relogin_client
from jumpstarter.common.utils import launch_shell
from jumpstarter.config.client import ClientConfigV1Alpha1
from jumpstarter.config.client import ClientConfigV1Alpha1, raise_expired_token_error
from jumpstarter.config.exporter import ExporterConfigV1Alpha1


Expand Down Expand Up @@ -98,8 +98,7 @@ async def _shell_with_signal_handling(
if token:
remaining = get_token_remaining_seconds(token)
if remaining is not None and remaining <= 0:
from jumpstarter.common.exceptions import ConnectionError
raise ConnectionError("token is expired")
raise_expired_token_error(config)

async with create_task_group() as tg:
tg.start_soon(signal_handler, tg.cancel_scope)
Expand Down
30 changes: 18 additions & 12 deletions python/packages/jumpstarter/jumpstarter/config/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime, timedelta
from functools import wraps
from pathlib import Path
from typing import Annotated, ClassVar, Literal, Optional, Self
from typing import Annotated, ClassVar, Literal, NoReturn, Optional, Self

import grpc
import yaml
Expand Down Expand Up @@ -50,17 +50,27 @@ def wrapper(*args, **kwargs):
return wrapper


def _attach_config_if_expired_token(exc: ConnectionError, config: ClientConfigV1Alpha1) -> None:
"""Attach config to a ConnectionError so re-auth can use it. No-op if not token-expired."""
if "token is expired" in str(exc):
exc.set_config(config)


def raise_expired_token_error(config: ClientConfigV1Alpha1) -> NoReturn:
"""Raise ConnectionError for expired token with config attached so re-auth can run."""
err = ConnectionError("token is expired")
err.set_config(config)
raise err


def _handle_connection_error(f):
@wraps(f)
async def wrapper(*args, **kwargs):
try:
return await f(*args, **kwargs)
except ConnectionError as e:
if "token is expired" in str(e):
# args[0] should be self for instance methods
e.set_config(args[0])
raise e
except Exception:
# args[0] is self for instance methods
_attach_config_if_expired_token(e, args[0])
raise

return wrapper
Expand Down Expand Up @@ -291,13 +301,9 @@ async def lease_async(
) as lease:
yield lease

# this replicates _handle_connection_error, the decorator doesn't work with asynccontextmanager
# decorator doesn't work with asynccontextmanager, so we use the same helper
except ConnectionError as e:
if "token is expired" in str(e):
# args[0] should be self for instance methods
e.set_config(self)
raise e
except Exception:
_attach_config_if_expired_token(e, self)
raise

@classmethod
Expand Down
Loading