Skip to content
Draft
119 changes: 85 additions & 34 deletions src/labthings_fastapi/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,26 @@
TypeVar,
overload,
)
from weakref import WeakSet
import weakref
from fastapi import FastAPI, HTTPException, Request, Body, BackgroundTasks
from pydantic import BaseModel, create_model
from anyio.abc import ObjectSendStream

from .base_descriptor import BaseDescriptor
from .base_descriptor import (
BaseDescriptor,
BaseDescriptorInfo,
DescriptorInfoCollection,
)
from .logs import add_thing_log_destination
from .utilities import model_to_dict, wrap_plain_types_in_rootmodel
from .invocations import InvocationModel, InvocationStatus, LogRecordModel
from .dependencies.invocation import NonWarningInvocationID
from .events import Message
from .exceptions import (
InvocationCancelledError,
InvocationError,
NoBlobManagerError,
NotBoundToInstanceError,
NotConnectedToServerError,
)
from .outputs.blob import BlobIOContextDep, blobdata_to_url_ctx
Expand All @@ -61,7 +67,6 @@
)
from .thing_description import type_to_dataschema
from .thing_description._model import ActionAffordance, ActionOp, Form, LinkElement
from .utilities import labthings_data


if TYPE_CHECKING:
Expand Down Expand Up @@ -247,10 +252,10 @@ def response(self, request: Optional[Request] = None) -> InvocationModel:
]
# The line below confuses MyPy because self.action **evaluates to** a Descriptor
# object (i.e. we don't call __get__ on the descriptor).
return self.action.invocation_model( # type: ignore[call-overload]
return self.action.invocation_model( # type: ignore[attr-defined]
status=self.status,
id=self.id,
action=self.thing.path + self.action.name, # type: ignore[call-overload]
action=self.thing.path + self.action.name, # type: ignore[attr-defined]
href=href,
timeStarted=self._start_time,
timeCompleted=self._end_time,
Expand Down Expand Up @@ -290,7 +295,7 @@ def run(self) -> None:
"""
# self.action evaluates to an ActionDescriptor. This confuses mypy,
# which thinks we are calling ActionDescriptor.__get__.
action: ActionDescriptor = self.action # type: ignore[call-overload]
action: ActionDescriptor = self.action # type: ignore[assignment]
logger = self.thing.logger
# The line below saves records matching our ID to ``self._log``
add_thing_log_destination(self.id, self._log)
Expand Down Expand Up @@ -445,10 +450,7 @@ def list_invocations(
i.response(request=request)
for i in self.invocations
if thing is None or i.thing == thing
if action is None or i.action == action # type: ignore[call-overload]
# i.action evaluates to an ActionDescriptor, which confuses mypy - it
# thinks we are calling ActionDescriptor.__get__ but this isn't ever
# called.
if action is None or i.action == action
]

def expire_invocations(self) -> None:
Expand Down Expand Up @@ -625,8 +627,68 @@ def delete_invocation(id: uuid.UUID) -> None:
OwnerT = TypeVar("OwnerT", bound="Thing")


class ActionInfo(
BaseDescriptorInfo[
"ActionDescriptor", OwnerT, Callable[ActionParams, ActionReturn]
],
Generic[OwnerT, ActionParams, ActionReturn],
):
"""Convenient access to the metadata of an action."""

@property
def response_timeout(self) -> float:
"""The time to wait before replying to the HTTP request initiating an action."""
return self.get_descriptor().response_timeout

@property
def retention_time(self) -> float:
"""How long to retain the action's output for, in seconds."""
return self.get_descriptor().retention_time

@property
def input_model(self) -> type[BaseModel]:
"""A Pydantic model for the input parameters of an Action."""
return self.get_descriptor().input_model

@property
def output_model(self) -> type[BaseModel]:
"""A Pydantic model for the output parameters of an Action."""
return self.get_descriptor().output_model

@property
def invocation_model(self) -> type[BaseModel]:
"""A Pydantic model for an invocation of this action."""
return self.get_descriptor().invocation_model

@property
def func(self) -> Callable[Concatenate[OwnerT, ActionParams], ActionReturn]:
"""The function that runs the action."""
return self.get_descriptor().func

def observe(self, stream: ObjectSendStream[Message]) -> None:
"""Observe changes to this property.

Changes to this property will be sent to the supplied stream.

:param stream: The stream to which updated values should be sent.
"""
if self.owning_object is None:
msg = "Can't observe action status from an unbound ActionInfo."
raise NotBoundToInstanceError(msg)
self.owning_object._thing_server_interface.subscribe(self.name, stream)


class ActionCollection(
DescriptorInfoCollection[OwnerT, ActionInfo],
Generic[OwnerT],
):
"""Access to the metadata of each Action."""

_descriptorinfo_class = ActionInfo


class ActionDescriptor(
BaseDescriptor[Callable[ActionParams, ActionReturn]],
BaseDescriptor[OwnerT, Callable[ActionParams, ActionReturn]],
Generic[ActionParams, ActionReturn, OwnerT],
):
"""Wrap actions to enable them to be run over HTTP.
Expand Down Expand Up @@ -691,7 +753,7 @@ def __init__(
)
self.invocation_model.__name__ = f"{name}_invocation"

def __set_name__(self, owner: type[Thing], name: str) -> None:
def __set_name__(self, owner: type[OwnerT], name: str) -> None:
"""Ensure the action name matches the function name.

It's assumed in a few places that the function name and the
Expand All @@ -709,7 +771,7 @@ def __set_name__(self, owner: type[Thing], name: str) -> None:
f"'{self.func.__name__}'",
)

def instance_get(self, obj: Thing) -> Callable[ActionParams, ActionReturn]:
def instance_get(self, obj: OwnerT) -> Callable[ActionParams, ActionReturn]:
"""Return the function, bound to an object as for a normal method.

This currently doesn't validate the arguments, though it may do so
Expand All @@ -721,27 +783,7 @@ def instance_get(self, obj: Thing) -> Callable[ActionParams, ActionReturn]:
descriptor.
:return: the action function, bound to ``obj``.
"""
# `obj` should be of type `OwnerT`, but `BaseDescriptor` currently
# isn't generic in the type of the owning Thing, so we can't express
# that here.
return partial(self.func, obj) # type: ignore[arg-type]

def _observers_set(self, obj: Thing) -> WeakSet:
"""Return a set used to notify changes.

Note that we need to supply the `.Thing` we are looking at, as in
general there may be more than one object of the same type, and
descriptor instances are shared between all instances of their class.

:param obj: The `.Thing` on which the action is being observed.

:return: a weak set of callables to notify on changes to the action.
This is used by websocket endpoints.
"""
ld = labthings_data(obj)
if self.name not in ld.action_observers:
ld.action_observers[self.name] = WeakSet()
return ld.action_observers[self.name]
return partial(self.func, obj)

def emit_changed_event(self, obj: Thing, status: str) -> None:
"""Notify subscribers that the action status has changed.
Expand Down Expand Up @@ -920,6 +962,15 @@ def action_affordance(
output=type_to_dataschema(self.output_model, title=f"{self.name}_output"),
)

def descriptor_info(self, owner: OwnerT | None = None) -> ActionInfo:
"""Return an `.ActionInfo` object describing this action.

The returned object will either refer to the class, or be bound to a particular
instance. If it is bound, more properties will be available - e.g. we will be
able to get the bound function.
"""
return self._descriptor_info(ActionInfo, owner)


@overload
def action(
Expand Down
Loading
Loading