diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee0267f..38cc5c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.14.14 hooks: - - id: ruff + - id: ruff-check args: [--fix, --exit-non-zero-on-fix, --show-fixes] - id: ruff-format diff --git a/loq.toml b/loq.toml index dfee2cd..9086a7d 100644 --- a/loq.toml +++ b/loq.toml @@ -18,7 +18,7 @@ max_lines = 945 [[rules]] path = "src/docket/execution.py" -max_lines = 1020 +max_lines = 1025 [[rules]] path = "src/docket/docket.py" diff --git a/src/docket/dependencies/__init__.py b/src/docket/dependencies/__init__.py index 2a51077..f863c88 100644 --- a/src/docket/dependencies/__init__.py +++ b/src/docket/dependencies/__init__.py @@ -32,6 +32,7 @@ SharedContext, _Depends, _parameter_cache, + annotated_dependency, get_dependency_parameters, ) from ._perpetual import Perpetual @@ -66,6 +67,7 @@ "DependencyFunction", "Shared", "SharedContext", + "annotated_dependency", "get_dependency_parameters", # Retry "ForcedRetry", diff --git a/src/docket/dependencies/_functional.py b/src/docket/dependencies/_functional.py index 2d1f0ad..3a03ffb 100644 --- a/src/docket/dependencies/_functional.py +++ b/src/docket/dependencies/_functional.py @@ -71,6 +71,12 @@ async def _resolve_factory_value( return cast(R, raw_value) +def annotated_dependency(param: inspect.Parameter) -> Dependency | None: + """Return the first Dependency found in an Annotated parameter's metadata, or None.""" + metadata = getattr(param.annotation, "__metadata__", ()) + return next((item for item in metadata if isinstance(item, Dependency)), None) + + _parameter_cache: dict[ TaskFunction | DependencyFunction[Any], dict[str, Dependency], @@ -89,10 +95,10 @@ def get_dependency_parameters( signature = get_signature(function) for parameter, param in signature.parameters.items(): - if not isinstance(param.default, Dependency): - continue - - dependencies[parameter] = param.default + if isinstance(param.default, Dependency): + dependencies[parameter] = param.default + elif (dep := annotated_dependency(param)) is not None: + dependencies[parameter] = dep _parameter_cache[function] = dependencies CACHE_SIZE.set(len(_parameter_cache), {"cache": "parameter"}) diff --git a/src/docket/execution.py b/src/docket/execution.py index c28d659..7140172 100644 --- a/src/docket/execution.py +++ b/src/docket/execution.py @@ -990,13 +990,16 @@ async def subscribe(self) -> AsyncGenerator[StateEvent | ProgressEvent, None]: def compact_signature(signature: inspect.Signature) -> str: - from .dependencies import Dependency + from .dependencies import Dependency, annotated_dependency parameters: list[str] = [] dependencies: int = 0 for parameter in signature.parameters.values(): - if isinstance(parameter.default, Dependency): + if ( + isinstance(parameter.default, Dependency) + or annotated_dependency(parameter) is not None + ): dependencies += 1 continue diff --git a/tests/fundamentals/test_annotated_dependencies.py b/tests/fundamentals/test_annotated_dependencies.py new file mode 100644 index 0000000..8b11679 --- /dev/null +++ b/tests/fundamentals/test_annotated_dependencies.py @@ -0,0 +1,162 @@ +"""Tests for Annotated[T, Depends(...)] style dependency injection.""" + +from contextlib import asynccontextmanager +from typing import Annotated, AsyncGenerator +from uuid import uuid4 + +from docket import ( + CurrentDocket, + CurrentExecution, + CurrentWorker, + Depends, + Docket, + Execution, + Worker, +) + + +async def test_annotated_function_dependency(docket: Docket, worker: Worker): + """A task can declare dependencies using Annotated[T, Depends(fn)] syntax.""" + + async def get_greeting() -> str: + return f"hello-{uuid4()}" + + called = False + + async def the_task(greeting: Annotated[str, Depends(get_greeting)]): + assert greeting.startswith("hello-") + + nonlocal called + called = True + + await docket.add(the_task)() # pyright: ignore[reportCallIssue] + await worker.run_until_finished() + + assert called + + +async def test_annotated_context_manager_dependency(docket: Docket, worker: Worker): + """Annotated dependencies work with async context manager factories.""" + + stages: list[str] = [] + + @asynccontextmanager + async def get_resource() -> AsyncGenerator[str, None]: + stages.append("setup") + yield "resource-value" + stages.append("teardown") + + called = False + + async def the_task(resource: Annotated[str, Depends(get_resource)]): + assert resource == "resource-value" + + nonlocal called + called = True + + await docket.add(the_task)() # pyright: ignore[reportCallIssue] + await worker.run_until_finished() + + assert called + assert stages == ["setup", "teardown"] + + +async def test_annotated_contextual_dependencies(docket: Docket, worker: Worker): + """Annotated syntax works with contextual dependencies like CurrentDocket.""" + + called = False + + async def the_task( + a: str, + this_docket: Annotated[Docket, CurrentDocket()], + this_worker: Annotated[Worker, CurrentWorker()], + this_execution: Annotated[Execution, CurrentExecution()], + ): + assert a == "hello" + assert this_docket is docket + assert this_worker is worker + assert isinstance(this_execution, Execution) + + nonlocal called + called = True + + await docket.add(the_task)("hello") # pyright: ignore[reportCallIssue] + await worker.run_until_finished() + + assert called + + +async def test_annotated_mixed_with_positional_args(docket: Docket, worker: Worker): + """Annotated dependencies mix freely with regular positional arguments.""" + + called = False + + async def get_config() -> dict[str, int]: + return {"version": 2} + + async def the_task( + name: str, + count: int, + config: Annotated[dict[str, int], Depends(get_config)], + ): + assert name == "test" + assert count == 42 + assert config == {"version": 2} + + nonlocal called + called = True + + await docket.add(the_task)("test", 42) # pyright: ignore[reportCallIssue] + await worker.run_until_finished() + + assert called + + +async def test_annotated_with_default_style_deps(docket: Docket, worker: Worker): + """Annotated and default-style dependencies can coexist on the same task.""" + + async def dep_a() -> str: + return "from-annotated" + + async def dep_b() -> str: + return "from-default" + + called = False + + async def the_task( + a: Annotated[str, Depends(dep_a)], + b: str = Depends(dep_b), + ): + assert a == "from-annotated" + assert b == "from-default" + + nonlocal called + called = True + + await docket.add(the_task)() # pyright: ignore[reportCallIssue] + await worker.run_until_finished() + + assert called + + +async def test_annotated_reusable_type_alias(docket: Docket, worker: Worker): + """Annotated types can be reused as aliases across multiple tasks.""" + + async def get_db() -> str: + return "db-connection" + + DBConn = Annotated[str, Depends(get_db)] + + results: list[str] = [] + + async def task_one(db: DBConn): # pyright: ignore[reportInvalidTypeForm,reportUnknownParameterType] + results.append(f"one:{db}") + + async def task_two(db: DBConn): # pyright: ignore[reportInvalidTypeForm,reportUnknownParameterType] + results.append(f"two:{db}") + + await docket.add(task_one)() # pyright: ignore[reportCallIssue,reportUnknownArgumentType] + await docket.add(task_two)() # pyright: ignore[reportCallIssue,reportUnknownArgumentType] + await worker.run_until_finished() + + assert sorted(results) == ["one:db-connection", "two:db-connection"] diff --git a/tests/test_execution.py b/tests/test_execution.py index 67df7d8..5600bc1 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -47,6 +47,20 @@ async def only_dependencies( ) -> None: ... # pragma: no cover +async def annotated_dependencies( + a: str, + b: Annotated[str, Depends(a_dependency)], + c: Annotated[Docket, CurrentDocket()], +) -> None: ... # pragma: no cover + + +async def only_annotated_dependencies( + a: Annotated[str, Depends(a_dependency)], + b: Annotated[Docket, CurrentDocket()], + c: Annotated[Worker, CurrentWorker()], +) -> None: ... # pragma: no cover + + @pytest.mark.parametrize( "function, expected", [ @@ -57,6 +71,8 @@ async def only_dependencies( (logged_args, "a: str, b: str = 'foo'"), (dependencies, "a: str, b: int = 42, ..."), (only_dependencies, "..."), + (annotated_dependencies, "a: str, ..."), + (only_annotated_dependencies, "..."), ], ) async def test_compact_signature(