Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.14
hooks:
- id: ruff
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
- id: ruff-format

Expand Down
2 changes: 1 addition & 1 deletion loq.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ max_lines = 945

[[rules]]
path = "src/docket/execution.py"
max_lines = 1020
max_lines = 1025

[[rules]]
path = "src/docket/docket.py"
Expand Down
2 changes: 2 additions & 0 deletions src/docket/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
SharedContext,
_Depends,
_parameter_cache,
annotated_dependency,
get_dependency_parameters,
)
from ._perpetual import Perpetual
Expand Down Expand Up @@ -66,6 +67,7 @@
"DependencyFunction",
"Shared",
"SharedContext",
"annotated_dependency",
"get_dependency_parameters",
# Retry
"ForcedRetry",
Expand Down
14 changes: 10 additions & 4 deletions src/docket/dependencies/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ async def _resolve_factory_value(
return cast(R, raw_value)


def annotated_dependency(param: inspect.Parameter) -> Dependency | None:
"""Return the first Dependency found in an Annotated parameter's metadata, or None."""
metadata = getattr(param.annotation, "__metadata__", ())
return next((item for item in metadata if isinstance(item, Dependency)), None)
Comment on lines +76 to +77

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Resolve postponed annotations before scanning metadata

When tasks are defined in modules using from __future__ import annotations (common on Python 3.10), inspect.signature() exposes param.annotation as a string, so __metadata__ is missing and this helper returns None. That means Annotated[..., Depends(...)] parameters are silently skipped during dependency extraction, and workers later invoke the task without required injected args, raising runtime TypeError for missing parameters instead of resolving dependencies.

Useful? React with 👍 / 👎.



_parameter_cache: dict[
TaskFunction | DependencyFunction[Any],
dict[str, Dependency],
Expand All @@ -89,10 +95,10 @@ def get_dependency_parameters(
signature = get_signature(function)

for parameter, param in signature.parameters.items():
if not isinstance(param.default, Dependency):
continue

dependencies[parameter] = param.default
if isinstance(param.default, Dependency):
dependencies[parameter] = param.default
elif (dep := annotated_dependency(param)) is not None:
dependencies[parameter] = dep

_parameter_cache[function] = dependencies
CACHE_SIZE.set(len(_parameter_cache), {"cache": "parameter"})
Expand Down
7 changes: 5 additions & 2 deletions src/docket/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,13 +990,16 @@ async def subscribe(self) -> AsyncGenerator[StateEvent | ProgressEvent, None]:


def compact_signature(signature: inspect.Signature) -> str:
from .dependencies import Dependency
from .dependencies import Dependency, annotated_dependency

parameters: list[str] = []
dependencies: int = 0

for parameter in signature.parameters.values():
if isinstance(parameter.default, Dependency):
if (
isinstance(parameter.default, Dependency)
or annotated_dependency(parameter) is not None
):
dependencies += 1
continue

Expand Down
162 changes: 162 additions & 0 deletions tests/fundamentals/test_annotated_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Tests for Annotated[T, Depends(...)] style dependency injection."""

from contextlib import asynccontextmanager
from typing import Annotated, AsyncGenerator
from uuid import uuid4

from docket import (
CurrentDocket,
CurrentExecution,
CurrentWorker,
Depends,
Docket,
Execution,
Worker,
)


async def test_annotated_function_dependency(docket: Docket, worker: Worker):
"""A task can declare dependencies using Annotated[T, Depends(fn)] syntax."""

async def get_greeting() -> str:
return f"hello-{uuid4()}"

called = False

async def the_task(greeting: Annotated[str, Depends(get_greeting)]):
assert greeting.startswith("hello-")

nonlocal called
called = True

await docket.add(the_task)() # pyright: ignore[reportCallIssue]
await worker.run_until_finished()

assert called


async def test_annotated_context_manager_dependency(docket: Docket, worker: Worker):
"""Annotated dependencies work with async context manager factories."""

stages: list[str] = []

@asynccontextmanager
async def get_resource() -> AsyncGenerator[str, None]:
stages.append("setup")
yield "resource-value"
stages.append("teardown")

called = False

async def the_task(resource: Annotated[str, Depends(get_resource)]):
assert resource == "resource-value"

nonlocal called
called = True

await docket.add(the_task)() # pyright: ignore[reportCallIssue]
await worker.run_until_finished()

assert called
assert stages == ["setup", "teardown"]


async def test_annotated_contextual_dependencies(docket: Docket, worker: Worker):
"""Annotated syntax works with contextual dependencies like CurrentDocket."""

called = False

async def the_task(
a: str,
this_docket: Annotated[Docket, CurrentDocket()],
this_worker: Annotated[Worker, CurrentWorker()],
this_execution: Annotated[Execution, CurrentExecution()],
):
assert a == "hello"
assert this_docket is docket
assert this_worker is worker
assert isinstance(this_execution, Execution)

nonlocal called
called = True

await docket.add(the_task)("hello") # pyright: ignore[reportCallIssue]
await worker.run_until_finished()

assert called


async def test_annotated_mixed_with_positional_args(docket: Docket, worker: Worker):
"""Annotated dependencies mix freely with regular positional arguments."""

called = False

async def get_config() -> dict[str, int]:
return {"version": 2}

async def the_task(
name: str,
count: int,
config: Annotated[dict[str, int], Depends(get_config)],
):
assert name == "test"
assert count == 42
assert config == {"version": 2}

nonlocal called
called = True

await docket.add(the_task)("test", 42) # pyright: ignore[reportCallIssue]
await worker.run_until_finished()

assert called


async def test_annotated_with_default_style_deps(docket: Docket, worker: Worker):
"""Annotated and default-style dependencies can coexist on the same task."""

async def dep_a() -> str:
return "from-annotated"

async def dep_b() -> str:
return "from-default"

called = False

async def the_task(
a: Annotated[str, Depends(dep_a)],
b: str = Depends(dep_b),
):
assert a == "from-annotated"
assert b == "from-default"

nonlocal called
called = True

await docket.add(the_task)() # pyright: ignore[reportCallIssue]
await worker.run_until_finished()

assert called


async def test_annotated_reusable_type_alias(docket: Docket, worker: Worker):
"""Annotated types can be reused as aliases across multiple tasks."""

async def get_db() -> str:
return "db-connection"

DBConn = Annotated[str, Depends(get_db)]

results: list[str] = []

async def task_one(db: DBConn): # pyright: ignore[reportInvalidTypeForm,reportUnknownParameterType]
results.append(f"one:{db}")

async def task_two(db: DBConn): # pyright: ignore[reportInvalidTypeForm,reportUnknownParameterType]
results.append(f"two:{db}")

await docket.add(task_one)() # pyright: ignore[reportCallIssue,reportUnknownArgumentType]
await docket.add(task_two)() # pyright: ignore[reportCallIssue,reportUnknownArgumentType]
await worker.run_until_finished()

assert sorted(results) == ["one:db-connection", "two:db-connection"]
16 changes: 16 additions & 0 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ async def only_dependencies(
) -> None: ... # pragma: no cover


async def annotated_dependencies(
a: str,
b: Annotated[str, Depends(a_dependency)],
c: Annotated[Docket, CurrentDocket()],
) -> None: ... # pragma: no cover


async def only_annotated_dependencies(
a: Annotated[str, Depends(a_dependency)],
b: Annotated[Docket, CurrentDocket()],
c: Annotated[Worker, CurrentWorker()],
) -> None: ... # pragma: no cover


@pytest.mark.parametrize(
"function, expected",
[
Expand All @@ -57,6 +71,8 @@ async def only_dependencies(
(logged_args, "a: str, b: str = 'foo'"),
(dependencies, "a: str, b: int = 42, ..."),
(only_dependencies, "..."),
(annotated_dependencies, "a: str, ..."),
(only_annotated_dependencies, "..."),
],
)
async def test_compact_signature(
Expand Down