class DjangoCarrier:
    """Expose a Django ``request.META``-style mapping as a carrier for ``extract``.

    Propagators look headers up by their wire name (for example
    ``traceparent`` or ``baggage``), while the mapping passed in here uses
    keys such as ``HTTP_TRACEPARENT``. Names are therefore compared after
    normalization: lower-cased, underscores turned into hyphens, and a
    leading ``http-`` prefix dropped.

    Matching is exact on the normalized name. A suffix match would let an
    unrelated header that merely ends with the requested key (for example
    ``HTTP_X_OTHER_BAGGAGE`` for ``baggage``) shadow the real one.
    """

    _PREFIX = "http-"

    def __init__(self, headers):
        # Mapping of header name -> value; it is only ever read.
        self.headers = headers

    @classmethod
    def _normalize(cls, name):
        """Return *name* in lower-case, hyphenated form without the ``http-`` prefix."""
        normalized = str(name).lower().replace("_", "-")
        if normalized.startswith(cls._PREFIX):
            normalized = normalized[len(cls._PREFIX):]
        return normalized

    def get(self, key, default=None):
        """Return the value of the header named *key*, or *default* if absent.

        Args:
            key: Header name as requested by the propagator; case and
                ``_``/``-`` differences are ignored.
            default: Value returned when no header matches.
        """
        wanted = self._normalize(key)
        for name, value in self.headers.items():
            if self._normalize(name) == wanted:
                return value
        return default
def test_traceparent_and_baggage(self):
    """Send TRACEPARENT and BAGGAGE headers and check both reach the exported span.

    The request carries a W3C ``traceparent`` plus a ``baggage`` entry for
    ``cf.ray_id``. The span at index 5 of the six exported spans must belong
    to the injected trace, have the injected span as its parent, and carry
    the baggage value as the ``baggage.cf.ray_id`` attribute.
    """
    expected_trace_id = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
    expected_parent_id = "bbbbbbbbbbbbbbbb"
    request_headers = {
        "HTTP_TRACEPARENT": f"00-{expected_trace_id}-{expected_parent_id}-01",
        "HTTP_BAGGAGE": "cf.ray_id=xyz",
    }

    response = self.client.get(f"{self.url}", **request_headers)
    self.assertEqual(response.status_code, 200)

    # Re-extract the context from the headers the server actually received.
    ctx = extract(DjangoCarrier(response.wsgi_request.META))
    self.assertTrue(trace.get_current_span(ctx).get_span_context().is_valid)

    # Inspect what the in-memory exporter captured for this request.
    finished = memory_exporter.get_finished_spans()
    self.assertEqual(len(finished), 6)
    self.assertEqual(finished[0].resource.attributes.get('service.name'), 'marketing-api')
    request_span = finished[5]

    # Trace id and parent span id must be the injected ones.
    self.assertEqual(f"{request_span.context.trace_id:032x}", expected_trace_id)
    self.assertEqual(f"{request_span.parent.span_id:016x}", expected_parent_id)

    # Baggage is visible in the extracted context ...
    self.assertEqual(get_baggage("cf.ray_id", context=ctx), "xyz")
    # ... and copied onto the span as an attribute.
    self.assertEqual(request_span.attributes.get("baggage.cf.ray_id"), "xyz")
@patch('backend.otel_instrumentation.logger')
def test_shutdown_handles_force_flush_exception(self, mock_logger):
    """A force_flush failure must neither escape shutdown() nor skip provider.shutdown()."""
    failing_provider = MagicMock()
    failing_provider.force_flush.side_effect = RuntimeError("network error")
    DjangoTelemetry._provider = failing_provider

    # Must return normally even though force_flush raises.
    DjangoTelemetry.shutdown()

    # Both provider calls happened exactly once, in spite of the error ...
    for provider_call in (failing_provider.force_flush, failing_provider.shutdown):
        provider_call.assert_called_once()
    # ... and the failure was reported through logger.exception.
    mock_logger.exception.assert_called()
def env_bool(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` if the variable is unset. Otherwise ``True`` when the
        value, ignoring case and surrounding whitespace, is one of ``1``,
        ``true``, ``t``, ``yes``, ``y`` or ``on``; ``False`` for anything
        else, including the empty string.
    """
    value = os.getenv(name)
    if value is None:
        return default
    # Accept the usual truthy spellings so that e.g. FLAG=yes or FLAG=on
    # is not silently treated as disabled.
    return value.strip().lower() in ("1", "true", "t", "yes", "y", "on")
# Exporter configuration is read once, when this module is imported.
OTEL_EXPORTER_OTLP_ENDPOINT = os.getenv('OTEL_EXPORTER_OTLP_ENDPOINT')
OTEL_EXPORTER_MODE = os.getenv('OTEL_EXPORTER_MODE')

# Upper bound, in milliseconds, for flushing pending spans at shutdown.
SHUTDOWN_TIMEOUT_MILLIS = 5000

logger = logging.getLogger(__name__)


class DjangoTelemetry:
    """Configure OpenTelemetry tracing for the Django process and shut it down cleanly.

    All state lives on the class: ``_provider`` is the ``TracerProvider``
    created by :meth:`setup` (``None`` if none was created) and
    ``_shutdown_called`` makes :meth:`shutdown` run at most once.
    """

    _provider = None
    _shutdown_called = False

    @staticmethod
    def request_hook(span, request):
        """Copy the ``Cf-Ray`` header and ``cf.ray_id`` baggage onto the request span."""
        if not span.is_recording():
            return

        # Attach CF-Ray header (empty string when the header is missing).
        span.set_attribute("cf.ray_id", request.headers.get("Cf-Ray", ""))

        # Attach baggage if present.
        baggage_val = baggage_api.get_baggage("cf.ray_id")
        if baggage_val:
            span.set_attribute("baggage.cf.ray_id", baggage_val)

    @staticmethod
    def response_hook(span, request, response):
        """Record the response body length for responses exposing ``content``."""
        if span.is_recording() and hasattr(response, "content"):
            span.set_attribute("http.response.length", len(response.content))

    @staticmethod
    def mysql_hook(span, instance, cursor, statement, parameters):
        """Enrich MySQL spans with DB info.

        Best-effort: a failure while setting attributes is logged at debug
        level and never propagates into the instrumented call.
        """
        if not span.is_recording():
            return
        try:
            span.set_attribute("db.system", "mysql")
            span.set_attribute("db.name", os.getenv('DB_NAME', 'db'))
            span.set_attribute("db.statement", statement)
        except Exception:
            logger.debug(
                "DjangoTelemetry::mysql_hook - failed to enrich span.",
                exc_info=True,
            )

    @staticmethod
    def redis_hook(span, instance, args, kwargs):
        """Enrich Redis spans with command + first key.

        Best-effort: a failure while setting attributes is logged at debug
        level and never propagates into the instrumented call.
        """
        if not span.is_recording():
            return
        try:
            cmd = args[0] if args else ""
            span.set_attribute("db.system", "redis")
            span.set_attribute("redis.command", cmd)
            if len(args) > 1:
                # Add first key only (avoid leaking big payloads)
                span.set_attribute("redis.key", str(args[1]))
        except Exception:
            logger.debug(
                "DjangoTelemetry::redis_hook - failed to enrich span.",
                exc_info=True,
            )

    @classmethod
    def setup(cls, environment):
        """Install the tracer provider, exporter and instrumentors.

        Does nothing when *environment* is ``"test"``. A provider is created
        only when ``OTEL_EXPORTER_MODE`` is set; ``"otel_endpoint"`` and
        ``"console"`` attach a batching exporter, any other value leaves the
        provider without a span processor. The Django, MySQL and Redis
        instrumentors are installed in every non-test environment.
        """
        if environment == "test":
            return

        if OTEL_EXPORTER_MODE:
            resource = Resource.create({
                "service.name": os.getenv("OTEL_SERVICE_NAME", "marketing-api")
            })
            # Provider with resource
            provider = TracerProvider(resource=resource)
            trace.set_tracer_provider(provider)
            cls._provider = provider
            if OTEL_EXPORTER_MODE == "otel_endpoint":
                exporter = OTLPSpanExporter(endpoint=OTEL_EXPORTER_OTLP_ENDPOINT)
                provider.add_span_processor(BatchSpanProcessor(exporter))
            elif OTEL_EXPORTER_MODE == "console":
                exporter = ConsoleSpanExporter()
                provider.add_span_processor(BatchSpanProcessor(exporter))
            cls._register_shutdown_hooks()

        # Django
        DjangoInstrumentor().instrument(
            request_hook=cls.request_hook,
            response_hook=cls.response_hook,
        )
        # MySQL
        MySQLClientInstrumentor().instrument(
            enable_commenter=True,
            cursor_instrumentation_enabled=True,
            span_callback=cls.mysql_hook,
        )
        # Redis
        RedisInstrumentor().instrument(
            tracer_provider=trace.get_tracer_provider(),
            request_hook=cls.redis_hook,
        )

    @classmethod
    def _register_shutdown_hooks(cls):
        """Register atexit and SIGTERM handlers for graceful span flushing.

        The SIGTERM handler flushes first and then defers to whatever was
        installed before it: a previous Python handler is called, and the
        default disposition is restored and the signal re-raised.

        ``signal.signal`` may only be called from the main thread. If this
        runs elsewhere, the resulting ``ValueError`` is logged and only the
        atexit hook remains, instead of the error aborting start-up.
        """
        atexit.register(cls.shutdown)

        previous_handler = signal.getsignal(signal.SIGTERM)

        def _sigterm_handler(signum, frame):
            cls.shutdown()
            if callable(previous_handler) and previous_handler not in (
                signal.SIG_DFL, signal.SIG_IGN
            ):
                previous_handler(signum, frame)
            elif previous_handler == signal.SIG_DFL:
                # Restore the default action and re-deliver the signal.
                signal.signal(signal.SIGTERM, signal.SIG_DFL)
                signal.raise_signal(signal.SIGTERM)

        try:
            signal.signal(signal.SIGTERM, _sigterm_handler)
        except ValueError:
            logger.warning(
                "DjangoTelemetry::_register_shutdown_hooks - SIGTERM handler "
                "not installed (not running in the main thread); relying on atexit."
            )

    @classmethod
    def shutdown(cls):
        """Flush and shut down the TracerProvider.

        Runs at most once per process; later calls return immediately.
        Errors from ``force_flush`` or ``shutdown`` are logged, never raised,
        and a failed flush does not prevent the provider shutdown.
        """
        if cls._shutdown_called:
            return
        cls._shutdown_called = True

        provider = cls._provider
        if provider is None:
            return

        try:
            logger.info("DjangoTelemetry::shutdown - flushing spans before shutdown.")
            flushed = provider.force_flush(timeout_millis=SHUTDOWN_TIMEOUT_MILLIS)
            if not flushed:
                logger.warning(
                    "DjangoTelemetry::shutdown - force_flush timed out after %d ms.",
                    SHUTDOWN_TIMEOUT_MILLIS,
                )
        except Exception:
            logger.exception("DjangoTelemetry::shutdown - error during force_flush.")

        try:
            provider.shutdown()
        except Exception:
            logger.exception("DjangoTelemetry::shutdown - error during provider shutdown.")