diff --git a/src/workflows/recipe/__init__.py b/src/workflows/recipe/__init__.py
index 0f1973f..828f583 100644
--- a/src/workflows/recipe/__init__.py
+++ b/src/workflows/recipe/__init__.py
@@ -5,6 +5,8 @@
 from collections.abc import Callable
 from typing import Any
 
+from opentelemetry import trace
+
 from workflows.recipe.recipe import Recipe
 from workflows.recipe.validate import validate_recipe
 from workflows.recipe.wrapper import RecipeWrapper
@@ -69,6 +71,60 @@ def unwrap_recipe(header, message):
             message = mangle_for_receiving(message)
         if header.get("workflows-recipe") in {True, "True", "true", 1}:
             rw = RecipeWrapper(message=message, transport=transport_layer)
+            logger.debug("RecipeWrapper created: %s", rw)
+
+            # Tag the current span with the recipe ID and DCID. This is
+            # best-effort instrumentation: every lookup is type-guarded so
+            # that an unexpected message shape does not raise here.
+            span = trace.get_current_span()
+            dcid = None
+            recipe_id = None
+
+            if isinstance(message, dict):
+                # The recipe ID is read from the message environment
+                environment = message.get("environment")
+                if isinstance(environment, dict):
+                    recipe_id = environment.get("ID")
+
+                # The DCID may be in the top-level parameters, in the payload
+                # parameters (most common location) or directly in the
+                # payload. Payloads are not necessarily dictionaries, so only
+                # dictionaries are searched.
+                payload = message.get("payload")
+                candidates = [message.get("parameters")]
+                if isinstance(payload, dict):
+                    candidates += [payload.get("parameters"), payload]
+                for candidate in candidates:
+                    if isinstance(candidate, dict):
+                        dcid = candidate.get("ispyb_dcid") or candidate.get("dcid")
+                        if dcid:
+                            break
+
+            if dcid:
+                span.set_attribute("dcid", dcid)
+                span.add_event("recipe.dcid_extracted", attributes={"dcid": dcid})
+
+            if recipe_id:
+                span.set_attribute("recipe_id", recipe_id)
+                span.add_event(
+                    "recipe.id_extracted", attributes={"recipe_id": recipe_id}
+                )
+
+            # Add span_id and trace_id to the log record so that log entries
+            # can be correlated with the trace
+            span_context = span.get_span_context()
+            if span_context and span_context.is_valid:
+                log_extra = {
+                    "span_id": format(span_context.span_id, "016x"),
+                    "trace_id": format(span_context.trace_id, "032x"),
+                }
+                if dcid:
+                    log_extra["dcid"] = dcid
+                if recipe_id:
+                    log_extra["recipe_id"] = recipe_id
+
+                logger.info("Processing recipe message", extra=log_extra)
+
             if log_extender and rw.environment and rw.environment.get("ID"):
                 with log_extender("recipe_ID", rw.environment["ID"]):
                     return callback(rw, header, message.get("payload"))
diff --git a/src/workflows/services/common_service.py b/src/workflows/services/common_service.py
index de2ef70..5aa8ee6 100644
--- a/src/workflows/services/common_service.py
+++ b/src/workflows/services/common_service.py
@@ -9,8 +9,15 @@
 import time
 from typing import Any
 
+from opentelemetry import trace
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+
 import workflows
 import workflows.logging
+from workflows.transport.middleware.otel_tracing import OTELTracingMiddleware
 
 
 class Status(enum.Enum):
@@ -185,6 +192,37 @@ def start_transport(self):
             self.transport.subscription_callback_set_intercept(
                 self._transport_interceptor
             )
+
+            # Configure OTELTracing
+            resource = Resource.create(
+                {
+                    SERVICE_NAME: self._service_name,
+                }
+            )
+
+            self.log.debug("Configuring OTELTracing")
+            provider = TracerProvider(resource=resource)
+            trace.set_tracer_provider(provider)
+
+            # Configure BatchProcessor and OTLPSpanExporter to point to OTELCollector
+            otlp_exporter = OTLPSpanExporter(
+                endpoint="https://otel.tracing.diamond.ac.uk:4318/v1/traces", timeout=10
+            )
+            span_processor = BatchSpanProcessor(otlp_exporter)
+            provider.add_span_processor(span_processor)
+
+            # Add OTELTracingMiddleware to the transport layer
+            tracer = trace.get_tracer(__name__)
+            otel_middleware = OTELTracingMiddleware(
+                tracer, service_name=self._service_name
+            )
+            self._transport.add_middleware(otel_middleware)
+
+            self.log.debug(
+                "OTELTracingMiddleware added to transport layer of %s",
+                self._service_name,
+            )
+
             metrics = self._environment.get("metrics")
             if metrics:
                 import prometheus_client
diff --git a/src/workflows/transport/middleware/otel_tracing.py b/src/workflows/transport/middleware/otel_tracing.py
new file mode 100644
index 0000000..453ff6c
--- /dev/null
+++ b/src/workflows/transport/middleware/otel_tracing.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+import functools
+from collections.abc import Callable
+
+from opentelemetry import trace
+from opentelemetry.propagate import extract
+
+from workflows.transport.middleware import BaseTransportMiddleware
+
+
+class OTELTracingMiddleware(BaseTransportMiddleware):
+    """Transport middleware that runs subscription callbacks inside a span."""
+
+    def __init__(self, tracer: trace.Tracer, service_name: str):
+        """
+        Initialize the OpenTelemetry Tracing Middleware.
+
+        :param tracer: An OpenTelemetry tracer instance used to create spans.
+        :param service_name: Name of the service, recorded on every span.
+        """
+        self.tracer = tracer
+        self.service_name = service_name
+
+    def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int:
+        """
+        Wrap the subscription callback so each message is handled in a span.
+
+        :param call_next: The next middleware (or transport call) in the chain.
+        :param channel: The channel being subscribed to, recorded on the span.
+        :param callback: The callback invoked with (header, message).
+        :return: The return value of call_next.
+        """
+
+        @functools.wraps(callback)
+        def wrapped_callback(header, message):
+            # Continue the trace propagated in the message headers, if any
+            ctx = extract(header) if header else None
+
+            # Start a new span with the extracted context
+            with self.tracer.start_as_current_span(
+                "transport.subscribe", context=ctx
+            ) as span:
+                span.set_attribute("service_name", self.service_name)
+                span.set_attribute("channel", channel)
+
+                # Call the original callback
+                return callback(header, message)
+
+        # Call the next middleware with the wrapped callback
+        return call_next(channel, wrapped_callback, **kwargs)