From 8fe9da3e8b57f07bb67ed2d7a1c675a77ea8e378 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Wed, 28 Jan 2026 17:08:00 -0800 Subject: [PATCH 01/26] fix: update test mocks for create_openrouter_llm import location Tests were patching src.knowledge.faq_summarizer.create_openrouter_llm but the import is now inside the function from src.core.services.litellm_llm. Fixed both test instances to patch the correct module. --- tests/test_integration/test_mailman_e2e.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_integration/test_mailman_e2e.py b/tests/test_integration/test_mailman_e2e.py index 4a8f7aa..c5f548b 100644 --- a/tests/test_integration/test_mailman_e2e.py +++ b/tests/test_integration/test_mailman_e2e.py @@ -175,7 +175,7 @@ def mock_fetch(url: str, **kwargs): # noqa: ARG001 patch("src.knowledge.db.get_db_path", return_value=e2e_test_db), patch("src.knowledge.mailman_sync._fetch_page", side_effect=mock_fetch), patch( - "src.knowledge.faq_summarizer.create_openrouter_llm", + "src.core.services.litellm_llm.create_openrouter_llm", side_effect=[mock_scoring_model, mock_summary_model], ), ): @@ -311,7 +311,7 @@ def mock_fetch(url: str, **kwargs): # noqa: ARG001 patch("src.knowledge.db.get_db_path", return_value=e2e_test_db), patch("src.knowledge.mailman_sync._fetch_page", side_effect=mock_fetch), patch( - "src.knowledge.faq_summarizer.create_openrouter_llm", + "src.core.services.litellm_llm.create_openrouter_llm", side_effect=[mock_scoring_model, mock_summary_model], ), ): From ad3d25297bce6bb5f9e7b0ed3ee21659ad92bfd9 Mon Sep 17 00:00:00 2001 From: "Seyed (Yahya) Shirazi" Date: Wed, 28 Jan 2026 23:39:07 -0800 Subject: [PATCH 02/26] Implement hybrid rate limiting: built-in API + KV (#130) Implement hybrid rate limiting approach: - Per-minute (bot protection): Built-in Rate Limiting API - Per-hour (human abuse): Workers KV - Rate limits: 10/min, 20/hour (prod), 60/min, 100/hour (dev) - 50% reduction in KV writes - Auto-deploy 
on worker file changes Closes #129 --- .github/workflows/sync-worker-cors.yml | 55 +++++++++++++++++-- workers/osa-worker/README.md | 46 ++++++++-------- workers/osa-worker/index.js | 76 +++++++++++++++++++------- workers/osa-worker/wrangler.toml | 35 ++++++++---- 4 files changed, 152 insertions(+), 60 deletions(-) diff --git a/.github/workflows/sync-worker-cors.yml b/.github/workflows/sync-worker-cors.yml index 3d364aa..d9758f0 100644 --- a/.github/workflows/sync-worker-cors.yml +++ b/.github/workflows/sync-worker-cors.yml @@ -7,10 +7,12 @@ on: - 'src/assistants/*/config.yaml' - 'scripts/sync_worker_cors.py' - '.github/workflows/sync-worker-cors.yml' + - 'workers/osa-worker/**' pull_request: branches: [main, develop] paths: - 'src/assistants/*/config.yaml' + - 'workers/osa-worker/**' permissions: contents: write @@ -40,16 +42,50 @@ jobs: - name: Check for changes id: check_changes run: | + # Check if CORS sync made changes if git diff --quiet workers/osa-worker/index.js; then - echo "changed=false" >> $GITHUB_OUTPUT echo "No CORS changes detected" + CORS_CHANGED=false else - echo "changed=true" >> $GITHUB_OUTPUT echo "CORS changes detected" + CORS_CHANGED=true + fi + + # Check if worker files were modified in the push that triggered this workflow + # Use the range from the push event (before -> after) + if [ "${{ github.event.before }}" != "0000000000000000000000000000000000000000" ]; then + # Normal push (not first commit) + if git diff --name-only ${{ github.event.before }} ${{ github.sha }} | grep -q "^workers/osa-worker/"; then + echo "Worker files changed in push" + WORKER_CHANGED=true + else + echo "No worker files in push" + WORKER_CHANGED=false + fi + else + # First commit to branch, check current files + if git ls-files workers/osa-worker/ | grep -q .; then + echo "Worker files exist (first commit)" + WORKER_CHANGED=true + else + WORKER_CHANGED=false + fi + fi + + # Set outputs + echo "cors_changed=$CORS_CHANGED" >> $GITHUB_OUTPUT + + # Deploy if either 
CORS sync changed files OR worker files were pushed + if [ "$CORS_CHANGED" = true ] || [ "$WORKER_CHANGED" = true ]; then + echo "changed=true" >> $GITHUB_OUTPUT + echo "✅ Deployment needed" + else + echo "changed=false" >> $GITHUB_OUTPUT + echo "No deployment needed" fi - - name: Commit changes (main/develop only) - if: steps.check_changes.outputs.changed == 'true' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop') && github.event_name == 'push' + - name: Commit CORS changes (main/develop only) + if: steps.check_changes.outputs.cors_changed == 'true' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop') && github.event_name == 'push' run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" @@ -64,7 +100,7 @@ jobs: run: | npm install -g wrangler cd workers/osa-worker - wrangler deploy --env="" + wrangler deploy echo "✅ Deployed to production worker" - name: Deploy to Cloudflare Workers (dev) @@ -82,9 +118,16 @@ jobs: uses: actions/github-script@v7 with: script: | + let message = '⚠️ **Worker Deployment Required**\n\n'; + if ('${{ steps.check_changes.outputs.cors_changed }}' === 'true') { + message += 'This PR modifies community CORS origins. '; + } + message += 'Worker changes detected. After merging to `main` or `develop`, the workflow will automatically deploy the worker.\n\n'; + message += '**Manual deployment (if needed):**\n```bash\ncd workers/osa-worker\nwrangler deploy --env dev # for develop branch\nwrangler deploy # for main branch\n```'; + github.rest.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, - body: '⚠️ **Worker CORS Update Required**\n\nThis PR modifies community CORS origins. 
After merging, run:\n\n```bash\npython scripts/sync_worker_cors.py\ngit add workers/osa-worker/index.js\ngit commit -m "chore: sync worker CORS"\ncd workers/osa-worker && wrangler deploy\n```\n\nOr merge to main and the sync workflow will auto-commit the changes.' + body: message }) diff --git a/workers/osa-worker/README.md b/workers/osa-worker/README.md index a847f83..9b6b7f7 100644 --- a/workers/osa-worker/README.md +++ b/workers/osa-worker/README.md @@ -2,7 +2,9 @@ Security proxy for the Open Science Assistant backend. Provides: - **Turnstile verification** (visible widget) for bot protection -- **Rate limiting** (IP-based, per-minute and per-hour) +- **Hybrid rate limiting** (IP-based, per-minute and per-hour) + - Per-minute: Built-in API (fast bot protection, <1ms) + - Per-hour: KV (global human abuse prevention) - **CORS validation** for allowed origins - **API key injection** for backend authentication - **BYOK mode** for CLI/programmatic access @@ -39,23 +41,11 @@ npm install -g wrangler wrangler login ``` -### 2. Create KV namespaces for rate limiting +### 2. KV namespaces -```bash -# Production -wrangler kv:namespace create "RATE_LIMITER" -# Copy the ID and update wrangler.toml - -# Development -wrangler kv:namespace create "RATE_LIMITER" --env dev -# Copy the ID and update wrangler.toml [env.dev.kv_namespaces] -``` +KV namespaces are already configured in `wrangler.toml` for per-hour rate limiting. No additional setup needed. -### 3. Update wrangler.toml - -Replace `REPLACE_WITH_KV_ID` and `REPLACE_WITH_DEV_KV_ID` with the IDs from step 2. - -### 4. Set up Turnstile +### 3. Set up Turnstile 1. Go to Cloudflare Dashboard > Turnstile 2. Create a new widget with **Visible** mode @@ -67,7 +57,7 @@ Replace `REPLACE_WITH_KV_ID` and `REPLACE_WITH_DEV_KV_ID` with the IDs from step 4. Copy the Site Key (for frontend integration) 5. Copy the Secret Key (for this worker) -### 5. Set secrets +### 4. 
Set secrets ```bash # Backend API key (generate with: python -c "import secrets; print(secrets.token_urlsafe(32))") @@ -81,7 +71,7 @@ wrangler secret put BACKEND_API_KEY --env dev wrangler secret put TURNSTILE_SECRET_KEY --env dev ``` -### 6. Deploy +### 5. Deploy ```bash # Production @@ -104,10 +94,22 @@ wrangler deploy --env dev ## Rate Limits -| Environment | Per Minute | Per Hour | -|-------------|------------|----------| -| Production | 10 | 100 | -| Development | 60 | 600 | +Hybrid approach for optimal performance and protection: + +| Environment | Per Minute (Bot Protection) | Per Hour (Human Abuse) | +|-------------|----------------------------|----------------------| +| Production | 10 (built-in API) | 20 (KV, global) | +| Development | 60 (built-in API) | 100 (KV, global) | + +**Why hybrid?** +- **Per-minute**: Needs to be fast (<1ms), catches bots immediately → Built-in API +- **Per-hour**: Needs global consistency across edge locations → KV +- **Result**: 50% fewer KV writes (1 vs 2 per request), faster bot protection + +**Rate limit scope:** +- Limits are **per IP address**, not per session +- 20/hour in production = ~1 question every 3 minutes (reasonable for research) +- Prevents abuse while allowing legitimate use ## BYOK Mode diff --git a/workers/osa-worker/index.js b/workers/osa-worker/index.js index e067007..dbe3119 100644 --- a/workers/osa-worker/index.js +++ b/workers/osa-worker/index.js @@ -14,7 +14,7 @@ function getConfig(env) { const isDev = env.ENVIRONMENT === 'development'; return { RATE_LIMIT_PER_MINUTE: isDev ? 60 : 10, - RATE_LIMIT_PER_HOUR: isDev ? 600 : 100, + RATE_LIMIT_PER_HOUR: isDev ? 
100 : 20, REQUEST_TIMEOUT: 120000, // 2 minutes for LLM responses IS_DEV: isDev, }; @@ -66,33 +66,67 @@ async function verifyTurnstileToken(token, secretKey, ip) { } /** - * Check rate limit using KV storage + * Hybrid rate limiting approach: + * - Per-minute (bot protection): Built-in API (fast, <1ms, in-memory) + * - Per-hour (human abuse): KV (global consistency, 1 write per request) + * + * Benefits: + * - 50% reduction in KV writes (1 vs 2 per request) + * - Faster bot protection (<1ms vs ~10-50ms for critical first check) + * - Global hourly limits across all edge locations + * + * Known limitation: + * - KV read-then-write is not atomic; concurrent requests from same IP + * may slightly exceed hourly limit. Per-minute guard constrains this. */ async function checkRateLimit(request, env, CONFIG) { - if (!env.RATE_LIMITER) return { allowed: true }; - const ip = request.headers.get('CF-Connecting-IP') || 'unknown'; - const now = Math.floor(Date.now() / 1000); - const minuteKey = `rl:min:${ip}:${Math.floor(now / 60)}`; - const hourKey = `rl:hour:${ip}:${Math.floor(now / 3600)}`; - - // Check per-minute limit - const minuteCount = parseInt(await env.RATE_LIMITER.get(minuteKey) || '0'); - if (minuteCount >= CONFIG.RATE_LIMIT_PER_MINUTE) { - return { allowed: false, reason: 'Too many requests per minute' }; + + // Check hourly limit first (KV, read-only, no token consumed) + // This prevents wasting per-minute tokens on already-rejected requests + if (env.RATE_LIMITER_KV) { + try { + const now = Math.floor(Date.now() / 1000); + const hourKey = `rl:hour:${ip}:${Math.floor(now / 3600)}`; + + // Check current count + const hourCount = parseInt(await env.RATE_LIMITER_KV.get(hourKey) || '0', 10); + if (hourCount >= CONFIG.RATE_LIMIT_PER_HOUR) { + return { allowed: false, reason: 'Too many requests per hour' }; + } + } catch (error) { + console.error('Per-hour rate limit check error:', error); + // Fail open for KV errors + } } - // Check per-hour limit - const hourCount 
= parseInt(await env.RATE_LIMITER.get(hourKey) || '0'); - if (hourCount >= CONFIG.RATE_LIMIT_PER_HOUR) { - return { allowed: false, reason: 'Too many requests per hour' }; + // Check per-minute limit (built-in API, fast, consumes token) + // Only check this AFTER hourly passes to avoid wasting tokens + if (env.RATE_LIMITER_MINUTE) { + try { + const { success } = await env.RATE_LIMITER_MINUTE.limit({ key: ip }); + if (!success) { + return { allowed: false, reason: 'Too many requests per minute' }; + } + } catch (error) { + console.error('Per-minute rate limit check error:', error); + // Fail open for built-in API errors + } } - // Increment counters - await Promise.all([ - env.RATE_LIMITER.put(minuteKey, (minuteCount + 1).toString(), { expirationTtl: 120 }), - env.RATE_LIMITER.put(hourKey, (hourCount + 1).toString(), { expirationTtl: 7200 }), - ]); + // Increment hourly counter (1 write per request instead of 2) + // Done last, after both checks pass + if (env.RATE_LIMITER_KV) { + try { + const now = Math.floor(Date.now() / 1000); + const hourKey = `rl:hour:${ip}:${Math.floor(now / 3600)}`; + const hourCount = parseInt(await env.RATE_LIMITER_KV.get(hourKey) || '0', 10); + await env.RATE_LIMITER_KV.put(hourKey, (hourCount + 1).toString(), { expirationTtl: 7200 }); + } catch (error) { + console.error('Per-hour rate limit increment error:', error); + // Already allowed, so don't fail the request + } + } return { allowed: true }; } diff --git a/workers/osa-worker/wrangler.toml b/workers/osa-worker/wrangler.toml index 32afe2f..5cdcfa1 100644 --- a/workers/osa-worker/wrangler.toml +++ b/workers/osa-worker/wrangler.toml @@ -10,9 +10,19 @@ compatibility_date = "2024-01-01" BACKEND_URL = "https://api.osc.earth/osa" ENVIRONMENT = "production" -# KV namespace for rate limiting +# Hybrid rate limiting approach: +# - Per-minute (bot protection): Built-in API (fast, in-memory) +# - Per-hour (human abuse): KV (global consistency) + +# Per-minute rate limiter (built-in API, free, 
<1ms latency) +[[ratelimits]] +name = "RATE_LIMITER_MINUTE" +namespace_id = "1001" +simple = { limit = 10, period = 60 } + +# Per-hour rate limiter (KV, global, 1 write per request) [[kv_namespaces]] -binding = "RATE_LIMITER" +binding = "RATE_LIMITER_KV" id = "8f8506b1a7fb400680a00014312c124d" # Development environment @@ -20,8 +30,15 @@ id = "8f8506b1a7fb400680a00014312c124d" name = "osa-worker-dev" vars = { BACKEND_URL = "https://api.osc.earth/osa-dev", ENVIRONMENT = "development" } +# Per-minute rate limiter (dev has higher limit: 60/min) +[[env.dev.ratelimits]] +name = "RATE_LIMITER_MINUTE" +namespace_id = "1002" +simple = { limit = 60, period = 60 } + +# Per-hour rate limiter (KV) [[env.dev.kv_namespaces]] -binding = "RATE_LIMITER" +binding = "RATE_LIMITER_KV" id = "6d46ef72877b4129b38ef2ca1e1cd5ea" # ============================================================================= @@ -34,23 +51,19 @@ id = "6d46ef72877b4129b38ef2ca1e1cd5ea" # 2. Login to Cloudflare: # wrangler login # -# 3. Create KV namespaces: -# wrangler kv:namespace create "RATE_LIMITER" -# wrangler kv:namespace create "RATE_LIMITER" --env dev -# -# 4. Update this file with the KV namespace IDs from step 3 +# 3. KV namespaces already exist (IDs in this file) # -# 5. Set secrets: +# 4. Set secrets: # wrangler secret put BACKEND_API_KEY # wrangler secret put TURNSTILE_SECRET_KEY # wrangler secret put BACKEND_API_KEY --env dev # wrangler secret put TURNSTILE_SECRET_KEY --env dev # -# 6. Deploy: +# 5. Deploy: # wrangler deploy # Production # wrangler deploy --env dev # Development # -# 7. Verify: +# 6. 
Verify: # curl https://osa-worker..workers.dev/health # # ============================================================================= From 02c1dbe97d8456481bfa093c0838dc8f88bbc9f4 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Wed, 28 Jan 2026 23:43:06 -0800 Subject: [PATCH 03/26] Fix CI: fetch full history for git diff to work Problem: Git diff fails with 'fatal: bad object' because GitHub Actions uses shallow checkout (depth=1) and github.event.before doesn't exist. Solution: Set fetch-depth=0 to fetch full history, allowing git diff to work properly when detecting worker file changes. This will trigger deployment on next push. --- .github/workflows/sync-worker-cors.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/sync-worker-cors.yml b/.github/workflows/sync-worker-cors.yml index d9758f0..08f56cf 100644 --- a/.github/workflows/sync-worker-cors.yml +++ b/.github/workflows/sync-worker-cors.yml @@ -25,6 +25,7 @@ jobs: - uses: actions/checkout@v4 with: token: ${{ secrets.CI_ADMIN_TOKEN }} + fetch-depth: 0 # Fetch full history so git diff works - name: Set up Python uses: actions/setup-python@v5 From 3c2a37dcc8bd171ebb5e07f0ee1f1bd5d40512fe Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Mon, 2 Feb 2026 08:59:54 -0800 Subject: [PATCH 04/26] CI: run lint and tests on all PRs Remove branch filter from pull_request trigger in tests.yml so lint and unit tests run on PRs targeting any branch, not just main/develop. This ensures feature branch PRs to epic branches still get CI coverage. 
--- .github/workflows/tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 371565f..7e8e946 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,7 +4,6 @@ on: push: branches: [main, develop] pull_request: - branches: [main, develop] jobs: lint: From bf26109898477a8654b9926e0c6db398a395ff14 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Mon, 2 Feb 2026 08:59:54 -0800 Subject: [PATCH 05/26] CI: run lint and tests on all PRs Remove branch filter from pull_request trigger in tests.yml so lint and unit tests run on PRs targeting any branch, not just main/develop. This ensures feature branch PRs to epic branches still get CI coverage. --- .github/workflows/tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 371565f..7e8e946 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,7 +4,6 @@ on: push: branches: [main, develop] pull_request: - branches: [main, develop] jobs: lint: From 6c9ef5de11a0db0688945ce1ea49d63ab5b7bc42 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Mon, 2 Feb 2026 09:09:06 -0800 Subject: [PATCH 06/26] Disable broken URL test until upstream fix Skip test_documentation_urls_accessible; HED docs URL returns 404 due to upstream repo change. See #139. --- tests/test_assistants/test_community_yaml_generic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_assistants/test_community_yaml_generic.py b/tests/test_assistants/test_community_yaml_generic.py index 1bc0d54..c796774 100644 --- a/tests/test_assistants/test_community_yaml_generic.py +++ b/tests/test_assistants/test_community_yaml_generic.py @@ -96,6 +96,7 @@ def test_documentation_urls_valid_format(self, community_id): ) @pytest.mark.slow + @pytest.mark.skip(reason="Disabled: upstream HED URL broken (404). 
See #139") def test_documentation_urls_accessible(self, community_id): """All documentation source URLs should return HTTP 200. From 062293ce2795fcee2dc673e66e0ddf35528557b9 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Mon, 2 Feb 2026 09:09:06 -0800 Subject: [PATCH 07/26] Disable broken URL test until upstream fix Skip test_documentation_urls_accessible; HED docs URL returns 404 due to upstream repo change. See #139. --- tests/test_assistants/test_community_yaml_generic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_assistants/test_community_yaml_generic.py b/tests/test_assistants/test_community_yaml_generic.py index 1bc0d54..c796774 100644 --- a/tests/test_assistants/test_community_yaml_generic.py +++ b/tests/test_assistants/test_community_yaml_generic.py @@ -96,6 +96,7 @@ def test_documentation_urls_valid_format(self, community_id): ) @pytest.mark.slow + @pytest.mark.skip(reason="Disabled: upstream HED URL broken (404). See #139") def test_documentation_urls_accessible(self, community_id): """All documentation source URLs should return HTTP 200. 
From 652ba1dc2a9af855359274fcd27014f16966c9b3 Mon Sep 17 00:00:00 2001 From: "Seyed (Yahya) Shirazi" Date: Mon, 2 Feb 2026 09:11:11 -0800 Subject: [PATCH 08/26] feat: backend metrics collection and request logging (#138) * feat: add backend metrics collection and request logging - Add src/metrics/ package with SQLite storage (WAL mode), aggregation queries, and request timing middleware - Add global /metrics/overview and /metrics/tokens endpoints - Add per-community /{id}/metrics and /{id}/metrics/usage - Log token usage, model, key_source, tools for ask/chat - Streaming handlers log metrics at end of generator - Middleware captures timing for all requests - All metrics endpoints require admin auth Closes #134 * Address PR review: error handling and type safety fixes - Wrap middleware dispatch in try/except so metrics never crash requests - Wrap init_metrics_db() in try/except for graceful degradation - Always return AssistantWithMetrics (remove return_metrics flag) - Narrow log_request except to sqlite3.Error - Add try/except to _log_streaming_metrics - Log metrics on streaming error paths (400/500) - Add sqlite3 error handling to all metrics endpoints (503) - Add logger + warning for malformed JSON in queries.py - Fix middleware ordering comment - Remove redundant inline import uuid - Move get_metrics_connection to top-level import * CI: run lint and tests on all PRs Remove branch filter from pull_request trigger in tests.yml so lint and unit tests run on PRs targeting any branch, not just main/develop. This ensures feature branch PRs to epic branches still get CI coverage. * Disable broken URL test until upstream fix Skip test_documentation_urls_accessible; HED docs URL returns 404 due to upstream repo change. See #139. 
--- src/api/main.py | 22 +- src/api/routers/__init__.py | 3 +- src/api/routers/community.py | 265 ++++++++++++++++++-- src/api/routers/metrics.py | 61 +++++ src/metrics/__init__.py | 1 + src/metrics/db.py | 209 ++++++++++++++++ src/metrics/middleware.py | 92 +++++++ src/metrics/queries.py | 290 ++++++++++++++++++++++ tests/test_api/test_metrics_endpoints.py | 211 ++++++++++++++++ tests/test_metrics/__init__.py | 1 + tests/test_metrics/test_db.py | 235 ++++++++++++++++++ tests/test_metrics/test_middleware.py | 102 ++++++++ tests/test_metrics/test_queries.py | 303 +++++++++++++++++++++++ 13 files changed, 1773 insertions(+), 22 deletions(-) create mode 100644 src/api/routers/metrics.py create mode 100644 src/metrics/__init__.py create mode 100644 src/metrics/db.py create mode 100644 src/metrics/middleware.py create mode 100644 src/metrics/queries.py create mode 100644 tests/test_api/test_metrics_endpoints.py create mode 100644 tests/test_metrics/__init__.py create mode 100644 tests/test_metrics/test_db.py create mode 100644 tests/test_metrics/test_middleware.py create mode 100644 tests/test_metrics/test_queries.py diff --git a/src/api/main.py b/src/api/main.py index 04c7dc4..2dcc139 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -14,11 +14,13 @@ from pydantic import BaseModel from src.api.config import get_settings -from src.api.routers import create_community_router, sync_router +from src.api.routers import create_community_router, metrics_router, sync_router from src.api.routers.health import router as health_router from src.api.routers.widget_test import router as widget_test_router from src.api.scheduler import start_scheduler, stop_scheduler from src.assistants import discover_assistants, registry +from src.metrics.db import init_metrics_db +from src.metrics.middleware import MetricsMiddleware logger = logging.getLogger(__name__) @@ -53,6 +55,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.state.settings = settings 
app.state.start_time = datetime.now(UTC) + # Initialize metrics database (non-critical; degrade gracefully if unavailable) + try: + init_metrics_db() + except Exception: + logger.error( + "Failed to initialize metrics database. Metrics collection will be unavailable. " + "Check DATA_DIR permissions and disk space.", + exc_info=True, + ) + # Start background scheduler for knowledge sync scheduler = start_scheduler() app.state.scheduler = scheduler @@ -153,6 +165,9 @@ def create_app() -> FastAPI: cors_kwargs["allow_origin_regex"] = origin_regex app.add_middleware(CORSMiddleware, **cors_kwargs) + # Metrics middleware - captures request timing and logs to metrics DB + app.add_middleware(MetricsMiddleware) + # Register routes register_routes(app) @@ -179,6 +194,9 @@ def register_routes(app: FastAPI) -> None: # Sync router (not community-specific) app.include_router(sync_router) + # Metrics router (global metrics endpoints) + app.include_router(metrics_router) + # Health check router app.include_router(health_router) @@ -226,6 +244,8 @@ async def root() -> dict[str, Any]: endpoints["GET /sync/status"] = "Knowledge sync status" endpoints["GET /sync/health"] = "Sync health check" endpoints["POST /sync/trigger"] = "Trigger sync (requires API key)" + endpoints["GET /metrics/overview"] = "Metrics overview (requires admin key)" + endpoints["GET /metrics/tokens"] = "Token breakdown (requires admin key)" endpoints["GET /health"] = "Health check" return { diff --git a/src/api/routers/__init__.py b/src/api/routers/__init__.py index 76dffb5..c2a7692 100644 --- a/src/api/routers/__init__.py +++ b/src/api/routers/__init__.py @@ -1,6 +1,7 @@ """API routers for Open Science Assistant.""" from src.api.routers.community import create_community_router +from src.api.routers.metrics import router as metrics_router from src.api.routers.sync import router as sync_router -__all__ = ["create_community_router", "sync_router"] +__all__ = ["create_community_router", "metrics_router", 
"sync_router"] diff --git a/src/api/routers/community.py b/src/api/routers/community.py index 479d655..4062ffd 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -7,23 +7,34 @@ import hashlib import json import logging +import time import uuid from collections.abc import AsyncGenerator +from dataclasses import dataclass from datetime import UTC, datetime -from typing import Annotated, Literal +from typing import Annotated, Any, Literal -from fastapi import APIRouter, Header, HTTPException, Request +from fastapi import APIRouter, Header, HTTPException, Query, Request from fastapi.responses import StreamingResponse from langchain_core.messages import AIMessage, HumanMessage from pydantic import BaseModel, Field, field_validator from src.api.config import get_settings -from src.api.security import RequireAuth +from src.api.security import RequireAdminAuth, RequireAuth from src.assistants import registry from src.assistants.community import CommunityAssistant from src.assistants.community import PageContext as AgentPageContext from src.assistants.registry import AssistantInfo from src.core.services.litellm_llm import create_openrouter_llm +from src.metrics.db import ( + RequestLogEntry, + extract_token_usage, + extract_tool_names, + get_metrics_connection, + log_request, + now_iso, +) +from src.metrics.queries import get_community_summary, get_usage_stats logger = logging.getLogger(__name__) @@ -604,6 +615,15 @@ def _get_cache_user_id(community_id: str, api_key: str | None, user_id: str | No return f"{community_id}_widget" +@dataclass +class AssistantWithMetrics: + """Community assistant bundled with metadata for metrics logging.""" + + assistant: CommunityAssistant + model: str + key_source: str + + def create_community_assistant( community_id: str, byok: str | None = None, @@ -612,13 +632,13 @@ def create_community_assistant( requested_model: str | None = None, preload_docs: bool = True, page_context: PageContext | None = None, -) -> 
CommunityAssistant: +) -> AssistantWithMetrics: """Create a community assistant instance with authorization checks. **Authorization:** - - If BYOK provided → always allowed - - If origin matches community CORS → can use community/platform keys - - Otherwise → rejects with 403 (CLI/unauthorized must provide BYOK) + - If BYOK provided -> always allowed + - If origin matches community CORS -> can use community/platform keys + - Otherwise -> rejects with 403 (CLI/unauthorized must provide BYOK) **Model Selection:** - Custom model requests require BYOK @@ -634,7 +654,8 @@ def create_community_assistant( page_context: Optional context about the page where the widget is embedded Returns: - Configured CommunityAssistant instance + AssistantWithMetrics containing the assistant, resolved model, and key source. + Access the assistant via .assistant attribute. Raises: ValueError: If community_id is not registered @@ -683,13 +704,19 @@ def create_community_assistant( title=page_context.title, ) - return registry.create_assistant( + assistant = registry.create_assistant( community_id, model=model, preload_docs=preload_docs, page_context=agent_page_context, ) + return AssistantWithMetrics( + assistant=assistant, + model=selected_model, + key_source=key_source, + ) + # --------------------------------------------------------------------------- # Router Factory @@ -765,6 +792,7 @@ async def ask( x_user_id, body.page_context, body.model, + http_request=http_request, ), media_type="text/event-stream", headers={ @@ -774,7 +802,7 @@ async def ask( ) try: - assistant = create_community_assistant( + awm = create_community_assistant( community_id, byok=x_openrouter_key, origin=origin, @@ -782,6 +810,7 @@ async def ask( requested_model=body.model, page_context=body.page_context, ) + assistant = awm.assistant messages = [HumanMessage(content=body.question)] result = await assistant.ainvoke(messages) @@ -798,6 +827,18 @@ async def ask( for tc in result.get("tool_calls", []) ] + # Store agent 
metrics on request.state for middleware to log + inp, out, total = extract_token_usage(result) + http_request.state.metrics_agent_data = { + "model": awm.model, + "key_source": awm.key_source, + "input_tokens": inp, + "output_tokens": out, + "total_tokens": total, + "tools_called": extract_tool_names(result), + "stream": False, + } + return AskResponse(answer=response_content, tool_calls=tool_calls_info) except Exception as e: @@ -858,7 +899,13 @@ async def chat( if body.stream: return StreamingResponse( _stream_chat_response( - community_id, session, x_openrouter_key, origin, user_id, body.model + community_id, + session, + x_openrouter_key, + origin, + user_id, + body.model, + http_request=http_request, ), media_type="text/event-stream", headers={ @@ -869,13 +916,14 @@ async def chat( ) try: - assistant = create_community_assistant( + awm = create_community_assistant( community_id, byok=x_openrouter_key, origin=origin, user_id=user_id, requested_model=body.model, ) + assistant = awm.assistant result = await assistant.ainvoke(session.messages) response_content = "" @@ -891,6 +939,18 @@ async def chat( for tc in result.get("tool_calls", []) ] + # Store agent metrics on request.state for middleware to log + inp, out, total = extract_token_usage(result) + http_request.state.metrics_agent_data = { + "model": awm.model, + "key_source": awm.key_source, + "input_tokens": inp, + "output_tokens": out, + "total_tokens": total, + "tools_called": extract_tool_names(result), + "stream": False, + } + # Add assistant message with constraint validation try: session.add_assistant_message(response_content) @@ -987,9 +1047,102 @@ async def get_community_config() -> CommunityConfigResponse: default_model_provider=default_provider, ) + # ----------------------------------------------------------------------- + # Per-community Metrics Endpoints + # ----------------------------------------------------------------------- + + @router.get("/metrics") + async def community_metrics(_auth: 
RequireAdminAuth) -> dict[str, Any]: + """Get metrics summary for this community. Requires admin auth.""" + import sqlite3 + + conn = get_metrics_connection() + try: + return get_community_summary(community_id, conn) + except sqlite3.Error: + logger.exception("Failed to query metrics for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + finally: + conn.close() + + @router.get("/metrics/usage") + async def community_usage( + _auth: RequireAdminAuth, + period: str = Query( + default="daily", + description="Time bucket period", + pattern="^(daily|weekly|monthly)$", + ), + ) -> dict[str, Any]: + """Get time-bucketed usage stats for this community. Requires admin auth.""" + import sqlite3 + + conn = get_metrics_connection() + try: + return get_usage_stats(community_id, period, conn) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except sqlite3.Error: + logger.exception("Failed to query usage stats for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + finally: + conn.close() + return router +# --------------------------------------------------------------------------- +# Metrics Helpers +# --------------------------------------------------------------------------- + + +def _log_streaming_metrics( + http_request: Request | None, + community_id: str, + endpoint: str, + awm: AssistantWithMetrics | None, + tools_called: list[str], + start_time: float, + status_code: int, +) -> None: + """Log metrics at the end of a streaming response. + + Called directly from streaming generators since middleware fires + before streaming completes. Wrapped in try/except to never disrupt + the SSE stream on failure. 
+ """ + try: + duration_ms = (time.monotonic() - start_time) * 1000 + request_id = str(uuid.uuid4()) + if http_request: + request_id = getattr(http_request.state, "request_id", request_id) + # Mark as logged so middleware doesn't double-log + http_request.state.metrics_logged = True + + entry = RequestLogEntry( + request_id=request_id, + timestamp=now_iso(), + endpoint=endpoint, + method="POST", + community_id=community_id, + duration_ms=round(duration_ms, 1), + status_code=status_code, + model=awm.model if awm else None, + key_source=awm.key_source if awm else None, + tools_called=tools_called, + stream=True, + ) + log_request(entry) + except Exception: + logger.exception("Failed to log streaming metrics for %s", endpoint) + + # --------------------------------------------------------------------------- # Streaming Helpers # --------------------------------------------------------------------------- @@ -1003,6 +1156,7 @@ async def _stream_ask_response( user_id: str | None, page_context: PageContext | None = None, requested_model: str | None = None, + http_request: Request | None = None, ) -> AsyncGenerator[str, None]: """Stream response for ask endpoint with JSON-encoded SSE events. 
@@ -1013,8 +1167,12 @@ async def _stream_ask_response( data: {"event": "done"} data: {"event": "error", "message": "error text"} """ + start_time = time.monotonic() + tools_called: list[str] = [] + awm: AssistantWithMetrics | None = None + try: - assistant = create_community_assistant( + awm = create_community_assistant( community_id, byok=byok, origin=origin, @@ -1023,7 +1181,7 @@ async def _stream_ask_response( preload_docs=True, page_context=page_context, ) - graph = assistant.build_graph() + graph = awm.assistant.build_graph() state = { "messages": [HumanMessage(content=question)], @@ -1042,9 +1200,12 @@ async def _stream_ask_response( elif kind == "on_tool_start": tool_input = event.get("data", {}).get("input", {}) + tool_name = event.get("name", "") + if tool_name: + tools_called.append(tool_name) sse_event = { "event": "tool_start", - "name": event.get("name", ""), + "name": tool_name, "input": tool_input if isinstance(tool_input, dict) else {}, } yield f"data: {json.dumps(sse_event)}\n\n" @@ -1061,6 +1222,17 @@ async def _stream_ask_response( sse_event = {"event": "done"} yield f"data: {json.dumps(sse_event)}\n\n" + # Log metrics at end of streaming + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/ask", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=200, + ) + except HTTPException: # Don't catch our own HTTP exceptions - let them propagate raise @@ -1073,10 +1245,17 @@ async def _stream_ask_response( "retryable": False, } yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/ask", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=400, + ) except Exception as e: # Unexpected errors - log with full context - import uuid - error_id = str(uuid.uuid4()) logger.error( "Unexpected streaming error (ID: %s) in ask endpoint for community %s: 
%s", @@ -1096,6 +1275,15 @@ async def _stream_ask_response( "error_id": error_id, } yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/ask", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=500, + ) async def _stream_chat_response( @@ -1105,6 +1293,7 @@ async def _stream_chat_response( origin: str | None, user_id: str | None, requested_model: str | None = None, + http_request: Request | None = None, ) -> AsyncGenerator[str, None]: """Stream assistant response as JSON-encoded Server-Sent Events. @@ -1115,8 +1304,12 @@ async def _stream_chat_response( data: {"event": "done", "session_id": "..."} data: {"event": "error", "message": "error text"} """ + start_time = time.monotonic() + tools_called: list[str] = [] + awm: AssistantWithMetrics | None = None + try: - assistant = create_community_assistant( + awm = create_community_assistant( community_id, byok=byok, origin=origin, @@ -1124,7 +1317,7 @@ async def _stream_chat_response( requested_model=requested_model, preload_docs=True, ) - graph = assistant.build_graph() + graph = awm.assistant.build_graph() state = { "messages": session.messages.copy(), @@ -1147,9 +1340,12 @@ async def _stream_chat_response( elif kind == "on_tool_start": tool_input = event.get("data", {}).get("input", {}) + tool_name = event.get("name", "") + if tool_name: + tools_called.append(tool_name) sse_event = { "event": "tool_start", - "name": event.get("name", ""), + "name": tool_name, "input": tool_input if isinstance(tool_input, dict) else {}, } yield f"data: {json.dumps(sse_event)}\n\n" @@ -1176,11 +1372,31 @@ async def _stream_chat_response( sse_event = {"event": "done", "session_id": session.session_id} yield f"data: {json.dumps(sse_event)}\n\n" + # Log metrics at end of streaming + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/chat", + awm=awm, 
+ tools_called=tools_called, + start_time=start_time, + status_code=200, + ) + except ValueError as e: # Session limit errors logger.error("Session limit error: %s", e) sse_event = {"event": "error", "message": str(e)} yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/chat", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=400, + ) except Exception as e: logger.error( "Streaming error in chat endpoint for session %s (community: %s): %s", @@ -1194,3 +1410,12 @@ async def _stream_chat_response( "message": "An error occurred while processing your request.", } yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/chat", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=500, + ) diff --git a/src/api/routers/metrics.py b/src/api/routers/metrics.py new file mode 100644 index 0000000..752be95 --- /dev/null +++ b/src/api/routers/metrics.py @@ -0,0 +1,61 @@ +"""Global metrics API endpoints. + +Provides cross-community metrics overview and token breakdowns. +All endpoints require admin authentication. +""" + +import logging +import sqlite3 +from typing import Any + +from fastapi import APIRouter, HTTPException, Query + +from src.api.security import RequireAdminAuth +from src.metrics.db import get_metrics_connection +from src.metrics.queries import get_overview, get_token_breakdown + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/metrics", tags=["Metrics"]) + + +@router.get("/overview") +async def metrics_overview(_auth: RequireAdminAuth) -> dict[str, Any]: + """Get cross-community metrics overview. + + Returns total requests, tokens, average duration, error rate, + and per-community breakdown. 
+ """ + conn = get_metrics_connection() + try: + return get_overview(conn) + except sqlite3.Error: + logger.exception("Failed to query metrics database for overview") + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + finally: + conn.close() + + +@router.get("/tokens") +async def token_breakdown( + _auth: RequireAdminAuth, + community_id: str | None = Query(default=None, description="Filter by community"), +) -> dict[str, Any]: + """Get token usage breakdown by model and key source. + + Optionally filter by community_id. + """ + conn = get_metrics_connection() + try: + return get_token_breakdown(conn, community_id=community_id) + except sqlite3.Error: + logger.exception("Failed to query metrics database for token breakdown") + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + finally: + conn.close() diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py new file mode 100644 index 0000000..d05dfb5 --- /dev/null +++ b/src/metrics/__init__.py @@ -0,0 +1 @@ +"""Metrics collection and request logging for OSA.""" diff --git a/src/metrics/db.py b/src/metrics/db.py new file mode 100644 index 0000000..f716046 --- /dev/null +++ b/src/metrics/db.py @@ -0,0 +1,209 @@ +"""Metrics storage layer using SQLite with WAL mode. + +Single SQLite database at {data_dir}/metrics.db stores all request logs. +WAL mode enables concurrent reads during writes. 
+""" + +import json +import logging +import sqlite3 +from dataclasses import dataclass, field +from datetime import UTC, datetime +from pathlib import Path + +from langchain_core.messages import AIMessage, BaseMessage + +logger = logging.getLogger(__name__) + +METRICS_SCHEMA_SQL = """ +CREATE TABLE IF NOT EXISTS request_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + request_id TEXT NOT NULL, + timestamp TEXT NOT NULL, + community_id TEXT, + endpoint TEXT NOT NULL, + method TEXT NOT NULL, + duration_ms REAL, + status_code INTEGER, + model TEXT, + input_tokens INTEGER, + output_tokens INTEGER, + total_tokens INTEGER, + estimated_cost REAL, + tools_called TEXT, + key_source TEXT, + stream INTEGER DEFAULT 0 +); + +CREATE INDEX IF NOT EXISTS idx_request_log_community + ON request_log(community_id); +CREATE INDEX IF NOT EXISTS idx_request_log_timestamp + ON request_log(timestamp); +CREATE INDEX IF NOT EXISTS idx_request_log_community_timestamp + ON request_log(community_id, timestamp); +""" + + +@dataclass +class RequestLogEntry: + """A single request log entry for the metrics database.""" + + request_id: str + timestamp: str + endpoint: str + method: str + community_id: str | None = None + duration_ms: float | None = None + status_code: int | None = None + model: str | None = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None + estimated_cost: float | None = None + tools_called: list[str] = field(default_factory=list) + key_source: str | None = None + stream: bool = False + + +def get_metrics_db_path() -> Path: + """Return path to the metrics SQLite database. + + Uses DATA_DIR environment variable if set (Docker deployments), + otherwise falls back to platform-specific user data directory. 
+ """ + import os + + from platformdirs import user_data_dir + + data_dir_env = os.environ.get("DATA_DIR") + base = Path(data_dir_env) if data_dir_env else Path(user_data_dir("osa", "osc")) + base.mkdir(parents=True, exist_ok=True) + return base / "metrics.db" + + +def get_metrics_connection(db_path: Path | None = None) -> sqlite3.Connection: + """Get a connection to the metrics database. + + Args: + db_path: Optional path override (for testing). + """ + path = db_path or get_metrics_db_path() + conn = sqlite3.connect(str(path)) + conn.row_factory = sqlite3.Row + return conn + + +def init_metrics_db(db_path: Path | None = None) -> None: + """Initialize the metrics database schema. Idempotent. + + Creates the request_log table and indexes if they don't exist. + Enables WAL mode for concurrent read/write access. + + Args: + db_path: Optional path override (for testing). + """ + conn = get_metrics_connection(db_path) + try: + conn.execute("PRAGMA journal_mode=WAL") + conn.executescript(METRICS_SCHEMA_SQL) + conn.commit() + logger.info("Metrics database initialized at %s", db_path or get_metrics_db_path()) + finally: + conn.close() + + +def log_request(entry: RequestLogEntry, db_path: Path | None = None) -> None: + """Insert a request log entry into the database. + + Args: + entry: The log entry to insert. + db_path: Optional path override (for testing). + """ + conn = get_metrics_connection(db_path) + try: + conn.execute( + """ + INSERT INTO request_log ( + request_id, timestamp, community_id, endpoint, method, + duration_ms, status_code, model, input_tokens, output_tokens, + total_tokens, estimated_cost, tools_called, key_source, stream + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + entry.request_id, + entry.timestamp, + entry.community_id, + entry.endpoint, + entry.method, + entry.duration_ms, + entry.status_code, + entry.model, + entry.input_tokens, + entry.output_tokens, + entry.total_tokens, + entry.estimated_cost, + json.dumps(entry.tools_called) if entry.tools_called else None, + entry.key_source, + 1 if entry.stream else 0, + ), + ) + conn.commit() + except sqlite3.Error: + logger.exception( + "Failed to log metrics request %s (endpoint=%s, community=%s)", + entry.request_id, + entry.endpoint, + entry.community_id, + ) + finally: + conn.close() + + +def extract_token_usage(result: dict) -> tuple[int, int, int]: + """Extract token usage from agent result messages. + + Sums usage_metadata from all AIMessages in result["messages"]. + + Args: + result: Agent result dict containing "messages" list. + + Returns: + Tuple of (input_tokens, output_tokens, total_tokens). + Returns (0, 0, 0) if no usage data is available. + """ + input_tokens = 0 + output_tokens = 0 + total_tokens = 0 + + messages: list[BaseMessage] = result.get("messages", []) + for msg in messages: + if not isinstance(msg, AIMessage): + continue + usage = getattr(msg, "usage_metadata", None) + if usage is None: + continue + # usage_metadata is a dict with input_tokens, output_tokens, total_tokens + if isinstance(usage, dict): + input_tokens += usage.get("input_tokens", 0) + output_tokens += usage.get("output_tokens", 0) + total_tokens += usage.get("total_tokens", 0) + + return input_tokens, output_tokens, total_tokens + + +def extract_tool_names(result: dict) -> list[str]: + """Extract tool names from agent result. + + Args: + result: Agent result dict containing "tool_calls" list. + + Returns: + List of tool names called during the request. 
+ """ + tool_calls = result.get("tool_calls", []) + return [tc.get("name", "") for tc in tool_calls if tc.get("name")] + + +def now_iso() -> str: + """Return current UTC time as ISO string.""" + return datetime.now(UTC).isoformat() diff --git a/src/metrics/middleware.py b/src/metrics/middleware.py new file mode 100644 index 0000000..7b94d46 --- /dev/null +++ b/src/metrics/middleware.py @@ -0,0 +1,92 @@ +"""Request timing and metrics middleware. + +Captures request-scoped data (endpoint, duration, status_code, timestamp). +For agent requests, the handler sets metrics on request.state which the +middleware reads after the response completes. + +Streaming caveat: For streaming responses, the middleware fires before +streaming completes. Streaming handlers log metrics directly at the end +of the generator instead. +""" + +import logging +import time +import uuid + +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response + +from src.metrics.db import RequestLogEntry, log_request, now_iso + +logger = logging.getLogger(__name__) + +# Path segments that indicate a community route +_COMMUNITY_ENDPOINTS = {"/ask", "/chat"} + + +def _extract_community_id(path: str) -> str | None: + """Extract community_id from URL path like /{community_id}/ask.""" + parts = path.strip("/").split("/") + if len(parts) >= 2 and f"/{parts[1]}" in _COMMUNITY_ENDPOINTS: + return parts[0] + return None + + +class MetricsMiddleware(BaseHTTPMiddleware): + """Middleware that logs request timing and metrics. + + Sets request.state.request_id and request.state.start_time for + downstream handlers to use. After the response, logs a basic + request entry unless the handler has set request.state.metrics_logged + (indicating the handler logged its own detailed entry, e.g. streaming). 
+ """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + request_id = str(uuid.uuid4()) + request.state.request_id = request_id + request.state.start_time = time.monotonic() + request.state.metrics_logged = False + + response = await call_next(request) + + try: + # If handler already logged metrics (streaming), skip + if getattr(request.state, "metrics_logged", False): + return response + + duration_ms = (time.monotonic() - request.state.start_time) * 1000 + community_id = _extract_community_id(request.url.path) + + # Check if handler set agent metrics on request.state + agent_data = getattr(request.state, "metrics_agent_data", None) + + agent_kwargs = {} + if agent_data and isinstance(agent_data, dict): + agent_kwargs = { + "model": agent_data.get("model"), + "input_tokens": agent_data.get("input_tokens"), + "output_tokens": agent_data.get("output_tokens"), + "total_tokens": agent_data.get("total_tokens"), + "estimated_cost": agent_data.get("estimated_cost"), + "tools_called": agent_data.get("tools_called", []), + "key_source": agent_data.get("key_source"), + "stream": agent_data.get("stream", False), + } + + entry = RequestLogEntry( + request_id=request_id, + timestamp=now_iso(), + endpoint=request.url.path, + method=request.method, + community_id=community_id, + duration_ms=round(duration_ms, 1), + status_code=response.status_code, + **agent_kwargs, + ) + + log_request(entry) + except Exception: + logger.exception("Metrics middleware failed for request %s", request_id) + + return response diff --git a/src/metrics/queries.py b/src/metrics/queries.py new file mode 100644 index 0000000..ec67e0a --- /dev/null +++ b/src/metrics/queries.py @@ -0,0 +1,290 @@ +"""Aggregation queries for the metrics database. + +Provides summary statistics, usage breakdowns, and overview queries +for both per-community and cross-community metrics. 
+""" + +import json +import logging +import sqlite3 +from typing import Any + +logger = logging.getLogger(__name__) + + +def get_community_summary(community_id: str, conn: sqlite3.Connection) -> dict[str, Any]: + """Get summary statistics for a single community. + + Args: + community_id: The community identifier. + conn: SQLite connection (with row_factory=sqlite3.Row). + + Returns: + Dict with total_requests, total_tokens, avg_duration_ms, + error_rate, top_models, top_tools. + """ + row = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms, + COALESCE(SUM(estimated_cost), 0) as total_estimated_cost, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_count + FROM request_log + WHERE community_id = ? + """, + (community_id,), + ).fetchone() + + total = row["total_requests"] + error_rate = row["error_count"] / total if total > 0 else 0.0 + + # Top models + model_rows = conn.execute( + """ + SELECT model, COUNT(*) as count + FROM request_log + WHERE community_id = ? AND model IS NOT NULL + GROUP BY model + ORDER BY count DESC + LIMIT 5 + """, + (community_id,), + ).fetchall() + + # Top tools (from JSON array in tools_called column) + tool_rows = conn.execute( + """ + SELECT tools_called + FROM request_log + WHERE community_id = ? 
AND tools_called IS NOT NULL + """, + (community_id,), + ).fetchall() + tool_counts: dict[str, int] = {} + for tr in tool_rows: + try: + tools = json.loads(tr["tools_called"]) + for tool in tools: + tool_counts[tool] = tool_counts.get(tool, 0) + 1 + except (json.JSONDecodeError, TypeError): + logger.warning( + "Malformed tools_called data in request_log for community %s: %r", + community_id, + tr["tools_called"], + ) + top_tools = sorted(tool_counts.items(), key=lambda x: x[1], reverse=True)[:5] + + return { + "community_id": community_id, + "total_requests": total, + "total_input_tokens": row["total_input_tokens"], + "total_output_tokens": row["total_output_tokens"], + "total_tokens": row["total_tokens"], + "avg_duration_ms": round(row["avg_duration_ms"], 1), + "total_estimated_cost": round(row["total_estimated_cost"], 4), + "error_rate": round(error_rate, 4), + "top_models": [{"model": r["model"], "count": r["count"]} for r in model_rows], + "top_tools": [{"tool": t[0], "count": t[1]} for t in top_tools], + } + + +def get_usage_stats( + community_id: str, + period: str, + conn: sqlite3.Connection, +) -> dict[str, Any]: + """Get time-bucketed usage statistics for a community. + + Args: + community_id: The community identifier. + period: One of "daily", "weekly", "monthly". + conn: SQLite connection. + + Returns: + Dict with period, community_id, and buckets list. + """ + # SQLite strftime patterns for bucketing + format_map = { + "daily": "%Y-%m-%d", + "weekly": "%Y-W%W", + "monthly": "%Y-%m", + } + if period not in format_map: + raise ValueError(f"Invalid period: {period}. 
Must be one of: daily, weekly, monthly") + + fmt = format_map[period] + + # Safe to use f-string: fmt is from a hardcoded whitelist, not user input + rows = conn.execute( + f""" + SELECT + strftime('{fmt}', timestamp) as bucket, + COUNT(*) as requests, + COALESCE(SUM(total_tokens), 0) as tokens, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms, + COALESCE(SUM(estimated_cost), 0) as estimated_cost, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as errors + FROM request_log + WHERE community_id = ? + GROUP BY bucket + ORDER BY bucket + """, + (community_id,), + ).fetchall() + + return { + "community_id": community_id, + "period": period, + "buckets": [ + { + "bucket": r["bucket"], + "requests": r["requests"], + "tokens": r["tokens"], + "avg_duration_ms": round(r["avg_duration_ms"], 1), + "estimated_cost": round(r["estimated_cost"], 4), + "errors": r["errors"], + } + for r in rows + ], + } + + +def get_overview(conn: sqlite3.Connection) -> dict[str, Any]: + """Get cross-community metrics overview. + + Args: + conn: SQLite connection. + + Returns: + Dict with total stats and per-community breakdown. 
+ """ + # Global totals + totals = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms, + COALESCE(SUM(estimated_cost), 0) as total_estimated_cost, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as total_errors + FROM request_log + """ + ).fetchone() + + total_req = totals["total_requests"] + + # Per-community breakdown + community_rows = conn.execute( + """ + SELECT + community_id, + COUNT(*) as requests, + COALESCE(SUM(total_tokens), 0) as tokens, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms, + COALESCE(SUM(estimated_cost), 0) as estimated_cost + FROM request_log + WHERE community_id IS NOT NULL + GROUP BY community_id + ORDER BY requests DESC + """ + ).fetchall() + + return { + "total_requests": total_req, + "total_tokens": totals["total_tokens"], + "avg_duration_ms": round(totals["avg_duration_ms"], 1), + "total_estimated_cost": round(totals["total_estimated_cost"], 4), + "error_rate": round(totals["total_errors"] / total_req, 4) if total_req > 0 else 0.0, + "communities": [ + { + "community_id": r["community_id"], + "requests": r["requests"], + "tokens": r["tokens"], + "avg_duration_ms": round(r["avg_duration_ms"], 1), + "estimated_cost": round(r["estimated_cost"], 4), + } + for r in community_rows + ], + } + + +def get_token_breakdown( + conn: sqlite3.Connection, + community_id: str | None = None, +) -> dict[str, Any]: + """Get token usage breakdown by model and key_source. + + Args: + conn: SQLite connection. + community_id: Optional filter by community. + + Returns: + Dict with by_model and by_key_source breakdowns. + """ + where = "" + params: tuple = () + if community_id: + where = "WHERE community_id = ?" 
+ params = (community_id,) + + by_model = conn.execute( + f""" + SELECT + model, + COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0) as input_tokens, + COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(SUM(estimated_cost), 0) as estimated_cost + FROM request_log + {where} + {"AND" if where else "WHERE"} model IS NOT NULL + GROUP BY model + ORDER BY total_tokens DESC + """, + params, + ).fetchall() + + by_key_source = conn.execute( + f""" + SELECT + key_source, + COUNT(*) as requests, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(SUM(estimated_cost), 0) as estimated_cost + FROM request_log + {where} + {"AND" if where else "WHERE"} key_source IS NOT NULL + GROUP BY key_source + ORDER BY requests DESC + """, + params, + ).fetchall() + + return { + "community_id": community_id, + "by_model": [ + { + "model": r["model"], + "requests": r["requests"], + "input_tokens": r["input_tokens"], + "output_tokens": r["output_tokens"], + "total_tokens": r["total_tokens"], + "estimated_cost": round(r["estimated_cost"], 4), + } + for r in by_model + ], + "by_key_source": [ + { + "key_source": r["key_source"], + "requests": r["requests"], + "total_tokens": r["total_tokens"], + "estimated_cost": round(r["estimated_cost"], 4), + } + for r in by_key_source + ], + } diff --git a/tests/test_api/test_metrics_endpoints.py b/tests/test_api/test_metrics_endpoints.py new file mode 100644 index 0000000..685dd24 --- /dev/null +++ b/tests/test_api/test_metrics_endpoints.py @@ -0,0 +1,211 @@ +"""Tests for metrics API endpoints.""" + +import os +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from src.api.main import app +from src.metrics.db import ( + RequestLogEntry, + init_metrics_db, + log_request, +) + +ADMIN_KEY = "test-metrics-admin-key" + + +@pytest.fixture +def metrics_db(tmp_path): + """Create isolated metrics database with sample data.""" + db_path = tmp_path / 
"metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="r1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=100, + output_tokens=50, + total_tokens=150, + estimated_cost=0.001, + tools_called=["search_docs"], + key_source="platform", + ), + RequestLogEntry( + request_id="r2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/chat", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=200, + output_tokens=100, + total_tokens=300, + key_source="byok", + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +@pytest.fixture +def isolated_metrics(metrics_db): + """Patch metrics DB path for all metrics code.""" + with patch("src.metrics.db.get_metrics_db_path", return_value=metrics_db): + yield metrics_db + + +@pytest.fixture +def auth_env(): + """Set up environment for admin auth and clear settings cache.""" + from src.api.config import get_settings + + os.environ["API_KEYS"] = ADMIN_KEY + os.environ["REQUIRE_API_AUTH"] = "true" + get_settings.cache_clear() + + yield + + del os.environ["API_KEYS"] + del os.environ["REQUIRE_API_AUTH"] + get_settings.cache_clear() + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +class TestMetricsOverview: + """Tests for GET /metrics/overview.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200_with_admin_key(self, client): + response = client.get("/metrics/overview", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/metrics/overview", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert "total_requests" in data + assert 
"total_tokens" in data + assert "avg_duration_ms" in data + assert "error_rate" in data + assert "communities" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_401_without_key(self, client): + response = client.get("/metrics/overview") + assert response.status_code == 401 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_403_with_invalid_key(self, client): + response = client.get("/metrics/overview", headers={"X-API-Key": "wrong-key"}) + assert response.status_code == 403 + + +class TestTokenBreakdown: + """Tests for GET /metrics/tokens.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200(self, client): + response = client.get("/metrics/tokens", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/metrics/tokens", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert "by_model" in data + assert "by_key_source" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_filter_by_community(self, client): + response = client.get( + "/metrics/tokens", + params={"community_id": "hed"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + data = response.json() + assert data["community_id"] == "hed" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_admin_auth(self, client): + response = client.get("/metrics/tokens") + assert response.status_code == 401 + + +class TestCommunityMetrics: + """Tests for GET /{community_id}/metrics.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200(self, client): + response = client.get("/hed/metrics", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = 
client.get("/hed/metrics", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert data["community_id"] == "hed" + assert "total_requests" in data + assert "total_tokens" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_admin_auth(self, client): + response = client.get("/hed/metrics") + assert response.status_code == 401 + + +class TestCommunityUsage: + """Tests for GET /{community_id}/metrics/usage.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_daily_usage(self, client): + response = client.get( + "/hed/metrics/usage", + params={"period": "daily"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 200 + data = response.json() + assert data["period"] == "daily" + assert "buckets" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_monthly_usage(self, client): + response = client.get( + "/hed/metrics/usage", + params={"period": "monthly"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_invalid_period_returns_422(self, client): + """Invalid period rejected by Query pattern validation.""" + response = client.get( + "/hed/metrics/usage", + params={"period": "hourly"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 422 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_admin_auth(self, client): + response = client.get("/hed/metrics/usage", params={"period": "daily"}) + assert response.status_code == 401 diff --git a/tests/test_metrics/__init__.py b/tests/test_metrics/__init__.py new file mode 100644 index 0000000..97451ef --- /dev/null +++ b/tests/test_metrics/__init__.py @@ -0,0 +1 @@ +"""Tests for the metrics module.""" diff --git a/tests/test_metrics/test_db.py b/tests/test_metrics/test_db.py new file mode 100644 index 0000000..c25d702 --- /dev/null +++ 
b/tests/test_metrics/test_db.py @@ -0,0 +1,235 @@ +"""Tests for metrics database storage layer.""" + +import json + +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from src.metrics.db import ( + RequestLogEntry, + extract_token_usage, + extract_tool_names, + get_metrics_connection, + init_metrics_db, + log_request, + now_iso, +) + + +@pytest.fixture +def metrics_db(tmp_path): + """Create a temporary metrics database.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + return db_path + + +class TestInitMetricsDb: + """Tests for init_metrics_db().""" + + def test_creates_table(self, tmp_path): + """init_metrics_db creates the request_log table.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + conn = get_metrics_connection(db_path) + try: + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='request_log'" + ).fetchall() + assert len(tables) == 1 + finally: + conn.close() + + def test_idempotent(self, tmp_path): + """Calling init_metrics_db twice does not error.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + init_metrics_db(db_path) + + conn = get_metrics_connection(db_path) + try: + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='request_log'" + ).fetchall() + assert len(tables) == 1 + finally: + conn.close() + + def test_wal_mode(self, tmp_path): + """init_metrics_db sets WAL journal mode.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + conn = get_metrics_connection(db_path) + try: + mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + assert mode == "wal" + finally: + conn.close() + + def test_creates_indexes(self, tmp_path): + """init_metrics_db creates the expected indexes.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + conn = get_metrics_connection(db_path) + try: + indexes = conn.execute( + "SELECT name FROM sqlite_master WHERE type='index' AND 
tbl_name='request_log'" + ).fetchall() + index_names = {r[0] for r in indexes} + assert "idx_request_log_community" in index_names + assert "idx_request_log_timestamp" in index_names + assert "idx_request_log_community_timestamp" in index_names + finally: + conn.close() + + +class TestLogRequest: + """Tests for log_request().""" + + def test_inserts_entry(self, metrics_db): + """log_request inserts a row that can be read back.""" + entry = RequestLogEntry( + request_id="test-123", + timestamp=now_iso(), + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=150.5, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=100, + output_tokens=50, + total_tokens=150, + tools_called=["search_docs", "validate_hed"], + key_source="platform", + stream=False, + ) + log_request(entry, db_path=metrics_db) + + conn = get_metrics_connection(metrics_db) + try: + row = conn.execute( + "SELECT * FROM request_log WHERE request_id = ?", ("test-123",) + ).fetchone() + assert row is not None + assert row["community_id"] == "hed" + assert row["endpoint"] == "/hed/ask" + assert row["method"] == "POST" + assert row["duration_ms"] == 150.5 + assert row["status_code"] == 200 + assert row["model"] == "qwen/qwen3-235b" + assert row["input_tokens"] == 100 + assert row["output_tokens"] == 50 + assert row["total_tokens"] == 150 + assert row["key_source"] == "platform" + assert row["stream"] == 0 + tools = json.loads(row["tools_called"]) + assert tools == ["search_docs", "validate_hed"] + finally: + conn.close() + + def test_null_agent_fields(self, metrics_db): + """Non-agent requests have NULL agent fields.""" + entry = RequestLogEntry( + request_id="basic-req", + timestamp=now_iso(), + endpoint="/health", + method="GET", + status_code=200, + duration_ms=5.0, + ) + log_request(entry, db_path=metrics_db) + + conn = get_metrics_connection(metrics_db) + try: + row = conn.execute( + "SELECT * FROM request_log WHERE request_id = ?", ("basic-req",) + ).fetchone() + assert 
row is not None + assert row["model"] is None + assert row["input_tokens"] is None + assert row["output_tokens"] is None + assert row["total_tokens"] is None + assert row["tools_called"] is None + assert row["key_source"] is None + assert row["community_id"] is None + finally: + conn.close() + + def test_stream_flag(self, metrics_db): + """stream=True is stored as 1.""" + entry = RequestLogEntry( + request_id="stream-req", + timestamp=now_iso(), + endpoint="/hed/ask", + method="POST", + stream=True, + ) + log_request(entry, db_path=metrics_db) + + conn = get_metrics_connection(metrics_db) + try: + row = conn.execute( + "SELECT stream FROM request_log WHERE request_id = ?", ("stream-req",) + ).fetchone() + assert row["stream"] == 1 + finally: + conn.close() + + +class TestExtractTokenUsage: + """Tests for extract_token_usage().""" + + def test_extracts_from_ai_messages(self): + """Extracts token usage from AIMessages with usage_metadata.""" + msg = AIMessage(content="hello") + msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + result = {"messages": [HumanMessage(content="hi"), msg]} + + inp, out, total = extract_token_usage(result) + assert inp == 10 + assert out == 5 + assert total == 15 + + def test_sums_multiple_messages(self): + """Sums usage across multiple AIMessages.""" + msg1 = AIMessage(content="first") + msg1.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + msg2 = AIMessage(content="second") + msg2.usage_metadata = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} + result = {"messages": [msg1, msg2]} + + inp, out, total = extract_token_usage(result) + assert inp == 30 + assert out == 15 + assert total == 45 + + def test_returns_zeros_when_no_usage(self): + """Returns (0, 0, 0) when no usage_metadata.""" + result = {"messages": [AIMessage(content="hello")]} + inp, out, total = extract_token_usage(result) + assert (inp, out, total) == (0, 0, 0) + + def 
test_returns_zeros_for_empty_result(self): + """Returns (0, 0, 0) for empty result.""" + assert extract_token_usage({}) == (0, 0, 0) + assert extract_token_usage({"messages": []}) == (0, 0, 0) + + +class TestExtractToolNames: + """Tests for extract_tool_names().""" + + def test_extracts_names(self): + result = {"tool_calls": [{"name": "search"}, {"name": "validate"}]} + assert extract_tool_names(result) == ["search", "validate"] + + def test_empty_tool_calls(self): + assert extract_tool_names({"tool_calls": []}) == [] + assert extract_tool_names({}) == [] + + def test_skips_empty_names(self): + result = {"tool_calls": [{"name": ""}, {"name": "search"}]} + assert extract_tool_names(result) == ["search"] diff --git a/tests/test_metrics/test_middleware.py b/tests/test_metrics/test_middleware.py new file mode 100644 index 0000000..5ddbc68 --- /dev/null +++ b/tests/test_metrics/test_middleware.py @@ -0,0 +1,102 @@ +"""Tests for metrics middleware.""" + +from unittest.mock import patch + +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from src.metrics.middleware import MetricsMiddleware, _extract_community_id + + +class TestExtractCommunityId: + """Tests for _extract_community_id helper.""" + + def test_extracts_from_ask(self): + assert _extract_community_id("/hed/ask") == "hed" + + def test_extracts_from_chat(self): + assert _extract_community_id("/bids/chat") == "bids" + + def test_returns_none_for_non_community(self): + assert _extract_community_id("/health") is None + assert _extract_community_id("/sync/status") is None + assert _extract_community_id("/metrics/overview") is None + + def test_returns_none_for_root(self): + assert _extract_community_id("/") is None + + +class TestMetricsMiddleware: + """Tests for MetricsMiddleware integration.""" + + @pytest.fixture + def test_app(self): + """Create a minimal FastAPI app with MetricsMiddleware.""" + app = FastAPI() + app.add_middleware(MetricsMiddleware) + + 
@app.get("/health") + async def health(): + return {"status": "ok"} + + @app.post("/hed/ask") + async def ask(request: Request): + # Simulate agent metrics + request.state.metrics_agent_data = { + "model": "test-model", + "key_source": "platform", + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "tools_called": ["search"], + "stream": False, + } + return {"answer": "test"} + + @app.post("/hed/streaming") + async def streaming(request: Request): + # Simulate handler that logs its own metrics + request.state.metrics_logged = True + return {"answer": "streamed"} + + with patch("src.metrics.middleware.log_request") as mock_log: + yield app, mock_log + + def test_logs_basic_request(self, test_app): + """Middleware logs basic request data for non-agent endpoints.""" + app, mock_log = test_app + client = TestClient(app) + response = client.get("/health") + assert response.status_code == 200 + + assert mock_log.called + entry = mock_log.call_args[0][0] + assert entry.endpoint == "/health" + assert entry.method == "GET" + assert entry.status_code == 200 + assert entry.duration_ms > 0 + assert entry.model is None # non-agent + + def test_picks_up_agent_metrics(self, test_app): + """Middleware reads agent metrics from request.state.""" + app, mock_log = test_app + client = TestClient(app) + response = client.post("/hed/ask", json={}) + assert response.status_code == 200 + + assert mock_log.called + entry = mock_log.call_args[0][0] + assert entry.model == "test-model" + assert entry.key_source == "platform" + assert entry.input_tokens == 10 + assert entry.tools_called == ["search"] + assert entry.community_id == "hed" + + def test_skips_when_handler_logged(self, test_app): + """Middleware skips logging when handler set metrics_logged=True.""" + app, mock_log = test_app + client = TestClient(app) + response = client.post("/hed/streaming", json={}) + assert response.status_code == 200 + assert not mock_log.called diff --git a/tests/test_metrics/test_queries.py 
b/tests/test_metrics/test_queries.py new file mode 100644 index 0000000..ae71c40 --- /dev/null +++ b/tests/test_metrics/test_queries.py @@ -0,0 +1,303 @@ +"""Tests for metrics aggregation queries.""" + +import pytest + +from src.metrics.db import RequestLogEntry, get_metrics_connection, init_metrics_db, log_request + + +@pytest.fixture +def populated_db(tmp_path): + """Create a metrics DB with sample data for query testing.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="r1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=100, + output_tokens=50, + total_tokens=150, + estimated_cost=0.001, + tools_called=["search_docs"], + key_source="platform", + ), + RequestLogEntry( + request_id="r2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/chat", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=200, + output_tokens=100, + total_tokens=300, + estimated_cost=0.002, + tools_called=["search_docs", "validate_hed"], + key_source="byok", + ), + RequestLogEntry( + request_id="r3", + timestamp="2025-01-16T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=150.0, + status_code=500, + model="openai/gpt-4o", + input_tokens=50, + output_tokens=0, + total_tokens=50, + key_source="byok", + ), + RequestLogEntry( + request_id="r4", + timestamp="2025-01-15T12:00:00+00:00", + endpoint="/bids/ask", + method="POST", + community_id="bids", + duration_ms=250.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=80, + output_tokens=40, + total_tokens=120, + estimated_cost=0.0008, + tools_called=["search_docs"], + key_source="platform", + ), + RequestLogEntry( + request_id="r5", + timestamp="2025-01-15T09:00:00+00:00", + endpoint="/health", + method="GET", + 
duration_ms=5.0, + status_code=200, + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +class TestGetCommunitySummary: + """Tests for get_community_summary().""" + + def test_returns_correct_totals(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("hed", conn) + assert result["community_id"] == "hed" + assert result["total_requests"] == 3 + assert result["total_input_tokens"] == 350 + assert result["total_output_tokens"] == 150 + assert result["total_tokens"] == 500 + finally: + conn.close() + + def test_error_rate(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("hed", conn) + # 1 error out of 3 requests + assert abs(result["error_rate"] - 0.3333) < 0.01 + finally: + conn.close() + + def test_top_models(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("hed", conn) + models = {m["model"]: m["count"] for m in result["top_models"]} + assert models["qwen/qwen3-235b"] == 2 + assert models["openai/gpt-4o"] == 1 + finally: + conn.close() + + def test_top_tools(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("hed", conn) + tools = {t["tool"]: t["count"] for t in result["top_tools"]} + assert tools["search_docs"] == 2 + assert tools["validate_hed"] == 1 + finally: + conn.close() + + def test_empty_community(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("nonexistent", conn) + assert result["total_requests"] == 0 + assert result["total_tokens"] == 0 + assert 
result["error_rate"] == 0.0 + finally: + conn.close() + + +class TestGetUsageStats: + """Tests for get_usage_stats().""" + + def test_daily_bucketing(self, populated_db): + from src.metrics.queries import get_usage_stats + + conn = get_metrics_connection(populated_db) + try: + result = get_usage_stats("hed", "daily", conn) + assert result["period"] == "daily" + assert result["community_id"] == "hed" + buckets = {b["bucket"]: b for b in result["buckets"]} + assert "2025-01-15" in buckets + assert "2025-01-16" in buckets + assert buckets["2025-01-15"]["requests"] == 2 + assert buckets["2025-01-16"]["requests"] == 1 + finally: + conn.close() + + def test_monthly_bucketing(self, populated_db): + from src.metrics.queries import get_usage_stats + + conn = get_metrics_connection(populated_db) + try: + result = get_usage_stats("hed", "monthly", conn) + buckets = result["buckets"] + assert len(buckets) == 1 + assert buckets[0]["bucket"] == "2025-01" + assert buckets[0]["requests"] == 3 + finally: + conn.close() + + def test_invalid_period_raises(self, populated_db): + from src.metrics.queries import get_usage_stats + + conn = get_metrics_connection(populated_db) + try: + with pytest.raises(ValueError, match="Invalid period"): + get_usage_stats("hed", "hourly", conn) + finally: + conn.close() + + def test_empty_community_returns_empty_buckets(self, populated_db): + from src.metrics.queries import get_usage_stats + + conn = get_metrics_connection(populated_db) + try: + result = get_usage_stats("nonexistent", "daily", conn) + assert result["buckets"] == [] + finally: + conn.close() + + +class TestGetOverview: + """Tests for get_overview().""" + + def test_overview_totals(self, populated_db): + from src.metrics.queries import get_overview + + conn = get_metrics_connection(populated_db) + try: + result = get_overview(conn) + # 5 total entries + assert result["total_requests"] == 5 + assert result["total_tokens"] == 620 # 150+300+50+120+0 + finally: + conn.close() + + def 
test_overview_communities(self, populated_db): + from src.metrics.queries import get_overview + + conn = get_metrics_connection(populated_db) + try: + result = get_overview(conn) + communities = {c["community_id"]: c for c in result["communities"]} + assert "hed" in communities + assert "bids" in communities + assert communities["hed"]["requests"] == 3 + assert communities["bids"]["requests"] == 1 + finally: + conn.close() + + def test_error_rate_in_overview(self, populated_db): + from src.metrics.queries import get_overview + + conn = get_metrics_connection(populated_db) + try: + result = get_overview(conn) + # 1 error out of 5 requests + assert abs(result["error_rate"] - 0.2) < 0.01 + finally: + conn.close() + + +class TestGetTokenBreakdown: + """Tests for get_token_breakdown().""" + + def test_by_model(self, populated_db): + from src.metrics.queries import get_token_breakdown + + conn = get_metrics_connection(populated_db) + try: + result = get_token_breakdown(conn) + models = {m["model"]: m for m in result["by_model"]} + assert "qwen/qwen3-235b" in models + assert models["qwen/qwen3-235b"]["requests"] == 3 + finally: + conn.close() + + def test_by_key_source(self, populated_db): + from src.metrics.queries import get_token_breakdown + + conn = get_metrics_connection(populated_db) + try: + result = get_token_breakdown(conn) + sources = {s["key_source"]: s for s in result["by_key_source"]} + assert "platform" in sources + assert "byok" in sources + assert sources["platform"]["requests"] == 2 + assert sources["byok"]["requests"] == 2 + finally: + conn.close() + + def test_filter_by_community(self, populated_db): + from src.metrics.queries import get_token_breakdown + + conn = get_metrics_connection(populated_db) + try: + result = get_token_breakdown(conn, community_id="bids") + assert result["community_id"] == "bids" + assert len(result["by_model"]) == 1 + assert result["by_model"][0]["model"] == "qwen/qwen3-235b" + finally: + conn.close() + + def 
test_empty_db_returns_empty(self, tmp_path): + from src.metrics.queries import get_token_breakdown + + db_path = tmp_path / "empty.db" + init_metrics_db(db_path) + conn = get_metrics_connection(db_path) + try: + result = get_token_breakdown(conn) + assert result["by_model"] == [] + assert result["by_key_source"] == [] + finally: + conn.close() From 85bc5f38d8ecb8dde978f5ee100d52b24e80f4b7 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Mon, 2 Feb 2026 21:36:41 -0800 Subject: [PATCH 09/26] Add backend database inspection docs to CLAUDE.md Knowledge databases live inside Docker containers, not locally. Added SSH + docker exec examples for listing tables, querying docstrings, and searching symbols to avoid wasted debugging time. --- CLAUDE.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index 5e043a5..8d679b4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -239,6 +239,38 @@ docker exec osa-dev python -m src.cli.main sync github --full - Dev: https://develop.osa-demo.pages.dev - Prod: https://osa-demo.pages.dev +### Inspecting Knowledge Databases + +Knowledge databases (SQLite) live **inside the Docker containers**, not locally. +Do NOT look for `.db` files in the local repo; they won't be there. 
+ +```bash +# List databases in a container +ssh -o "RequestTTY=no" -J hallu hedtools \ + "docker exec osa find /app/data/knowledge -name '*.db'" + +# Containers: osa (prod), osa-dev (dev) +# Database paths: /app/data/knowledge/{community_id}.db +# e.g., /app/data/knowledge/eeglab.db, /app/data/knowledge/hed.db + +# List tables (no sqlite3 binary; use python) +ssh -o "RequestTTY=no" -J hallu hedtools \ + "docker exec osa python3 -c 'import sqlite3; conn = sqlite3.connect(\"/app/data/knowledge/eeglab.db\"); print([r[0] for r in conn.execute(\"SELECT name FROM sqlite_master WHERE type=\\\"table\\\"\")]); conn.close()'" + +# Query example: count docstrings +ssh -o "RequestTTY=no" -J hallu hedtools \ + "docker exec osa python3 -c 'import sqlite3; conn = sqlite3.connect(\"/app/data/knowledge/eeglab.db\"); print(conn.execute(\"SELECT COUNT(*) FROM docstrings\").fetchone()[0]); conn.close()'" + +# Query example: search for a symbol +ssh -o "RequestTTY=no" -J hallu hedtools \ + "docker exec osa python3 -c 'import sqlite3; conn = sqlite3.connect(\"/app/data/knowledge/eeglab.db\"); [print(r) for r in conn.execute(\"SELECT symbol_name, file_path FROM docstrings WHERE symbol_name LIKE \\\"%erpimage%\\\"\").fetchall()]; conn.close()'" +``` + +**Important notes:** +- `sqlite3` CLI is not installed in containers; use `python3 -c` with the `sqlite3` module +- Use `ssh -o "RequestTTY=no"` to avoid interactive shell banners +- Dev and prod databases may differ; always check the right container + ## References - **API structure**: `.context/api-structure.md` (read first for API work) From 0253e54e5e3dbb49eecd2639f76f75c75b2351fc Mon Sep 17 00:00:00 2001 From: "Seyed (Yahya) Shirazi" Date: Mon, 2 Feb 2026 21:45:51 -0800 Subject: [PATCH 10/26] Phase 2: Dashboard frontend with public metrics and community tabs (#140) * Add dashboard frontend with public metrics endpoints - Add public query functions (no tokens/costs/models exposed) - Create /metrics/public/* endpoints (no auth required) 
- Build /dashboard page with Chart.js, community tabs, admin unlock - Register new routers in main.py - Add tests for public endpoints and dashboard page (28 new tests) * Restructure dashboard as standalone static site - Move per-community public metrics to community router (/{community_id}/metrics/public, /{community_id}/metrics/public/usage) - Keep only global /metrics/public/overview in metrics_public router - Remove FastAPI dashboard router - Add dashboard/ as standalone static site for Cloudflare Pages: / = aggregate overview, /{community} = community detail - Client-side routing with configurable API base URL - Add _redirects for Cloudflare Pages SPA routing - Update tests for new route structure * Add CI workflow for dashboard Cloudflare Pages deploy Deploys dashboard/ to osa-dash.pages.dev via wrangler. Same pattern as existing deploy-pages.yml for the demo widget: - main -> osa-dash.pages.dev (production) - develop -> develop.osa-dash.pages.dev - PRs -> {branch}.osa-dash.pages.dev with preview URL comment * Add dynamic community tab bar to dashboard Tabs are populated from /metrics/public/overview API so new communities appear automatically. Navigation uses simple links (All -> /, community -> /{id}) with active tab highlighting. 
* Address PR review findings: XSS, error handling, tests - Fix XSS: add escapeHtml() helper, sanitize all innerHTML interpolations, use encodeURIComponent() for URL path segments - Move get_metrics_connection() inside try blocks in all metrics endpoints - Add console.error/warn to all JavaScript catch blocks (no silent failures) - Improve admin section UX: defer visibility until data loads successfully - Extract shared helpers in queries.py (_count_tools, _validate_period) - Add test classes: TestPublicAdminBoundary, TestEmptyDatabase, TestCommunityMetricsValues with dynamic community cross-checks - Fix admin boundary tests to use auth_env fixture with test API key * Address round-2 review: XSS, error logging, auth tests, simplify - Fix single-quote XSS in onclick handlers: use encodeURIComponent for communityId in changePeriod calls, decode in changePeriod - Validate health status against known values instead of escapeHtml - Add console.warn to sync/health .catch() blocks - Add console.error to loadCommunityView catch block - Add auth-enabled tests proving public endpoints stay accessible when REQUIRE_API_AUTH=true (core security contract) - Add metrics_connection() context manager in db.py; simplify all endpoint handlers from nested try/try/finally to with-statement - Use tuple unpacking in _count_tools for clarity --- .github/workflows/deploy-dashboard.yml | 104 ++++ dashboard/_redirects | 1 + dashboard/index.html | 762 +++++++++++++++++++++++++ src/api/main.py | 13 +- src/api/routers/__init__.py | 8 +- src/api/routers/community.py | 72 ++- src/api/routers/metrics_public.py | 39 ++ src/metrics/db.py | 18 + src/metrics/queries.py | 229 ++++++-- tests/test_api/test_dashboard.py | 87 +++ tests/test_api/test_metrics_public.py | 395 +++++++++++++ 11 files changed, 1674 insertions(+), 54 deletions(-) create mode 100644 .github/workflows/deploy-dashboard.yml create mode 100644 dashboard/_redirects create mode 100644 dashboard/index.html create mode 100644 
src/api/routers/metrics_public.py create mode 100644 tests/test_api/test_dashboard.py create mode 100644 tests/test_api/test_metrics_public.py diff --git a/.github/workflows/deploy-dashboard.yml b/.github/workflows/deploy-dashboard.yml new file mode 100644 index 0000000..cc4ad16 --- /dev/null +++ b/.github/workflows/deploy-dashboard.yml @@ -0,0 +1,104 @@ +name: Deploy Dashboard to Cloudflare Pages + +on: + push: + branches: + - main + - develop + paths: + - 'dashboard/**' + - '.github/workflows/deploy-dashboard.yml' + pull_request: + types: [opened, synchronize, reopened] + paths: + - 'dashboard/**' + - '.github/workflows/deploy-dashboard.yml' + workflow_dispatch: + +jobs: + deploy: + runs-on: ubuntu-latest + permissions: + contents: read + deployments: write + pull-requests: write + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Get branch name + id: branch + run: | + if [ "${{ github.event_name }}" = "pull_request" ]; then + BRANCH_NAME="${{ github.head_ref }}" + else + BRANCH_NAME="${{ github.ref_name }}" + fi + # Sanitize branch name for URL (replace / with -, lowercase) + SANITIZED=$(echo "$BRANCH_NAME" | tr '/' '-' | tr '[:upper:]' '[:lower:]') + # Cloudflare Pages truncates branch names to ~28 chars for preview URLs + TRUNCATED=$(echo "$SANITIZED" | cut -c1-28) + echo "name=$TRUNCATED" >> $GITHUB_OUTPUT + echo "original=$BRANCH_NAME" >> $GITHUB_OUTPUT + + - name: Deploy to Cloudflare Pages + id: deploy + uses: cloudflare/wrangler-action@v3 + with: + apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }} + accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }} + # Deploy with branch name so Cloudflare Pages routes correctly: + # - main branch -> osa-dash.pages.dev (production) + # - develop branch -> develop.osa-dash.pages.dev (preview) + # - PR branches -> {branch}.osa-dash.pages.dev (preview) + command: pages deploy dashboard --project-name=osa-dash --branch=${{ steps.branch.outputs.name }} + + - name: Comment on PR with preview URL + if: 
github.event_name == 'pull_request' + uses: actions/github-script@v7 + with: + script: | + const branch = '${{ steps.branch.outputs.name }}'; + const previewUrl = `https://${branch}.osa-dash.pages.dev`; + + // Find existing comment from this workflow + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }); + + const botComment = comments.find(comment => + comment.user.type === 'Bot' && + comment.body.includes('') + ); + + const sha = '${{ github.sha }}'.substring(0, 7); + const body = [ + '', + '## Dashboard Preview', + '', + '| Name | Link |', + '|------|------|', + `| **Preview URL** | ${previewUrl} |`, + '| **Branch** | \`${{ steps.branch.outputs.original }}\` |', + `| **Commit** | \`${sha}\` |`, + '', + 'This preview will be updated automatically when you push new commits.' + ].join('\n'); + + if (botComment) { + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: botComment.id, + body: body + }); + } else { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: body + }); + } diff --git a/dashboard/_redirects b/dashboard/_redirects new file mode 100644 index 0000000..bbb3e7a --- /dev/null +++ b/dashboard/_redirects @@ -0,0 +1 @@ +/* /index.html 200 diff --git a/dashboard/index.html b/dashboard/index.html new file mode 100644 index 0000000..7147083 --- /dev/null +++ b/dashboard/index.html @@ -0,0 +1,762 @@ + + + + + + Open Science Assistant - Dashboard + + + + +
+ +
+
Loading...
+ + + +
+ + + + diff --git a/src/api/main.py b/src/api/main.py index 2dcc139..9d6566a 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -14,7 +14,12 @@ from pydantic import BaseModel from src.api.config import get_settings -from src.api.routers import create_community_router, metrics_router, sync_router +from src.api.routers import ( + create_community_router, + metrics_public_router, + metrics_router, + sync_router, +) from src.api.routers.health import router as health_router from src.api.routers.widget_test import router as widget_test_router from src.api.scheduler import start_scheduler, stop_scheduler @@ -194,8 +199,9 @@ def register_routes(app: FastAPI) -> None: # Sync router (not community-specific) app.include_router(sync_router) - # Metrics router (global metrics endpoints) + # Metrics routers (admin + public) app.include_router(metrics_router) + app.include_router(metrics_public_router) # Health check router app.include_router(health_router) @@ -239,6 +245,8 @@ async def root() -> dict[str, Any]: endpoints[f"GET /{community_id}/sessions"] = f"List active {name} sessions" endpoints[f"GET /{community_id}/sessions/{{session_id}}"] = "Get session info" endpoints[f"DELETE /{community_id}/sessions/{{session_id}}"] = "Delete a session" + endpoints[f"GET /{community_id}/metrics/public"] = f"Public {name} metrics" + endpoints[f"GET /{community_id}/metrics/public/usage"] = f"Public {name} usage stats" # Add non-community endpoints endpoints["GET /sync/status"] = "Knowledge sync status" @@ -246,6 +254,7 @@ async def root() -> dict[str, Any]: endpoints["POST /sync/trigger"] = "Trigger sync (requires API key)" endpoints["GET /metrics/overview"] = "Metrics overview (requires admin key)" endpoints["GET /metrics/tokens"] = "Token breakdown (requires admin key)" + endpoints["GET /metrics/public/overview"] = "Public metrics overview" endpoints["GET /health"] = "Health check" return { diff --git a/src/api/routers/__init__.py b/src/api/routers/__init__.py index c2a7692..39f85ab 
100644 --- a/src/api/routers/__init__.py +++ b/src/api/routers/__init__.py @@ -2,6 +2,12 @@ from src.api.routers.community import create_community_router from src.api.routers.metrics import router as metrics_router +from src.api.routers.metrics_public import router as metrics_public_router from src.api.routers.sync import router as sync_router -__all__ = ["create_community_router", "metrics_router", "sync_router"] +__all__ = [ + "create_community_router", + "metrics_public_router", + "metrics_router", + "sync_router", +] diff --git a/src/api/routers/community.py b/src/api/routers/community.py index 4062ffd..fadf738 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -7,6 +7,7 @@ import hashlib import json import logging +import sqlite3 import time import uuid from collections.abc import AsyncGenerator @@ -30,11 +31,16 @@ RequestLogEntry, extract_token_usage, extract_tool_names, - get_metrics_connection, log_request, + metrics_connection, now_iso, ) -from src.metrics.queries import get_community_summary, get_usage_stats +from src.metrics.queries import ( + get_community_summary, + get_public_community_summary, + get_public_usage_stats, + get_usage_stats, +) logger = logging.getLogger(__name__) @@ -1054,19 +1060,15 @@ async def get_community_config() -> CommunityConfigResponse: @router.get("/metrics") async def community_metrics(_auth: RequireAdminAuth) -> dict[str, Any]: """Get metrics summary for this community. 
Requires admin auth.""" - import sqlite3 - - conn = get_metrics_connection() try: - return get_community_summary(community_id, conn) + with metrics_connection() as conn: + return get_community_summary(community_id, conn) except sqlite3.Error: logger.exception("Failed to query metrics for community %s", community_id) raise HTTPException( status_code=503, detail="Metrics database is temporarily unavailable.", ) - finally: - conn.close() @router.get("/metrics/usage") async def community_usage( @@ -1078,11 +1080,9 @@ async def community_usage( ), ) -> dict[str, Any]: """Get time-bucketed usage stats for this community. Requires admin auth.""" - import sqlite3 - - conn = get_metrics_connection() try: - return get_usage_stats(community_id, period, conn) + with metrics_connection() as conn: + return get_usage_stats(community_id, period, conn) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) from e except sqlite3.Error: @@ -1091,8 +1091,52 @@ async def community_usage( status_code=503, detail="Metrics database is temporarily unavailable.", ) - finally: - conn.close() + + # ----------------------------------------------------------------------- + # Per-community Public Metrics Endpoints (no auth required) + # ----------------------------------------------------------------------- + + @router.get("/metrics/public") + async def community_metrics_public() -> dict[str, Any]: + """Get public metrics summary for this community. + + Returns request counts, error rate, and top tools. + No tokens, costs, or model information exposed. 
+ """ + try: + with metrics_connection() as conn: + return get_public_community_summary(community_id, conn) + except sqlite3.Error: + logger.exception("Failed to query public metrics for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + @router.get("/metrics/public/usage") + async def community_usage_public( + period: str = Query( + default="daily", + description="Time bucket period", + pattern="^(daily|weekly|monthly)$", + ), + ) -> dict[str, Any]: + """Get public time-bucketed usage stats for this community. + + Returns request counts and errors per time bucket. + No tokens or costs exposed. + """ + try: + with metrics_connection() as conn: + return get_public_usage_stats(community_id, period, conn) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except sqlite3.Error: + logger.exception("Failed to query public usage stats for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) return router diff --git a/src/api/routers/metrics_public.py b/src/api/routers/metrics_public.py new file mode 100644 index 0000000..2aaf1f9 --- /dev/null +++ b/src/api/routers/metrics_public.py @@ -0,0 +1,39 @@ +"""Public metrics API endpoints. + +Exposes non-sensitive aggregate metrics (request counts, error rates) +without authentication. No tokens, costs, or model information. + +Per-community public metrics are served from the community router +at /{community_id}/metrics/public. 
+""" + +import logging +import sqlite3 +from typing import Any + +from fastapi import APIRouter, HTTPException + +from src.metrics.db import metrics_connection +from src.metrics.queries import get_public_overview + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/metrics/public", tags=["Public Metrics"]) + + +@router.get("/overview") +async def public_overview() -> dict[str, Any]: + """Get public metrics overview across all communities. + + Returns total requests, error rate, active community count, + and per-community request counts. No tokens, costs, or model info. + """ + try: + with metrics_connection() as conn: + return get_public_overview(conn) + except sqlite3.Error: + logger.exception("Failed to query metrics database for public overview") + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) diff --git a/src/metrics/db.py b/src/metrics/db.py index f716046..c72842f 100644 --- a/src/metrics/db.py +++ b/src/metrics/db.py @@ -7,6 +7,8 @@ import json import logging import sqlite3 +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field from datetime import UTC, datetime from pathlib import Path @@ -93,6 +95,22 @@ def get_metrics_connection(db_path: Path | None = None) -> sqlite3.Connection: return conn +@contextmanager +def metrics_connection(db_path: Path | None = None) -> Generator[sqlite3.Connection, None, None]: + """Context manager for a metrics database connection. + + Ensures the connection is closed after use, even if an exception occurs. + + Args: + db_path: Optional path override (for testing). + """ + conn = get_metrics_connection(db_path) + try: + yield conn + finally: + conn.close() + + def init_metrics_db(db_path: Path | None = None) -> None: """Initialize the metrics database schema. Idempotent. 
diff --git a/src/metrics/queries.py b/src/metrics/queries.py index ec67e0a..6281639 100644 --- a/src/metrics/queries.py +++ b/src/metrics/queries.py @@ -11,6 +11,54 @@ logger = logging.getLogger(__name__) +# SQLite strftime patterns for time-bucketed queries +_PERIOD_FORMAT_MAP = { + "daily": "%Y-%m-%d", + "weekly": "%Y-W%W", + "monthly": "%Y-%m", +} + + +def _validate_period(period: str) -> str: + """Validate and return the strftime format for a period. + + Raises ValueError if period is not one of: daily, weekly, monthly. + """ + if period not in _PERIOD_FORMAT_MAP: + raise ValueError(f"Invalid period: {period}. Must be one of: daily, weekly, monthly") + return _PERIOD_FORMAT_MAP[period] + + +def _count_tools( + community_id: str, conn: sqlite3.Connection, limit: int = 5 +) -> list[dict[str, Any]]: + """Count tool usage from JSON arrays in the tools_called column. + + Returns top tools sorted by count descending. + """ + tool_rows = conn.execute( + """ + SELECT tools_called + FROM request_log + WHERE community_id = ? AND tools_called IS NOT NULL + """, + (community_id,), + ).fetchall() + tool_counts: dict[str, int] = {} + for tr in tool_rows: + try: + tools = json.loads(tr["tools_called"]) + for tool in tools: + tool_counts[tool] = tool_counts.get(tool, 0) + 1 + except (json.JSONDecodeError, TypeError): + logger.warning( + "Malformed tools_called data in request_log for community %s: %r", + community_id, + tr["tools_called"], + ) + top = sorted(tool_counts.items(), key=lambda x: x[1], reverse=True)[:limit] + return [{"tool": name, "count": count} for name, count in top] + def get_community_summary(community_id: str, conn: sqlite3.Connection) -> dict[str, Any]: """Get summary statistics for a single community. @@ -20,8 +68,9 @@ def get_community_summary(community_id: str, conn: sqlite3.Connection) -> dict[s conn: SQLite connection (with row_factory=sqlite3.Row). 
Returns: - Dict with total_requests, total_tokens, avg_duration_ms, - error_rate, top_models, top_tools. + Dict with total_requests, total_input_tokens, total_output_tokens, + total_tokens, avg_duration_ms, total_estimated_cost, error_rate, + top_models, top_tools. """ row = conn.execute( """ @@ -55,29 +104,6 @@ def get_community_summary(community_id: str, conn: sqlite3.Connection) -> dict[s (community_id,), ).fetchall() - # Top tools (from JSON array in tools_called column) - tool_rows = conn.execute( - """ - SELECT tools_called - FROM request_log - WHERE community_id = ? AND tools_called IS NOT NULL - """, - (community_id,), - ).fetchall() - tool_counts: dict[str, int] = {} - for tr in tool_rows: - try: - tools = json.loads(tr["tools_called"]) - for tool in tools: - tool_counts[tool] = tool_counts.get(tool, 0) + 1 - except (json.JSONDecodeError, TypeError): - logger.warning( - "Malformed tools_called data in request_log for community %s: %r", - community_id, - tr["tools_called"], - ) - top_tools = sorted(tool_counts.items(), key=lambda x: x[1], reverse=True)[:5] - return { "community_id": community_id, "total_requests": total, @@ -88,7 +114,7 @@ def get_community_summary(community_id: str, conn: sqlite3.Connection) -> dict[s "total_estimated_cost": round(row["total_estimated_cost"], 4), "error_rate": round(error_rate, 4), "top_models": [{"model": r["model"], "count": r["count"]} for r in model_rows], - "top_tools": [{"tool": t[0], "count": t[1]} for t in top_tools], + "top_tools": _count_tools(community_id, conn), } @@ -107,18 +133,9 @@ def get_usage_stats( Returns: Dict with period, community_id, and buckets list. """ - # SQLite strftime patterns for bucketing - format_map = { - "daily": "%Y-%m-%d", - "weekly": "%Y-W%W", - "monthly": "%Y-%m", - } - if period not in format_map: - raise ValueError(f"Invalid period: {period}. 
Must be one of: daily, weekly, monthly") - - fmt = format_map[period] + fmt = _validate_period(period) - # Safe to use f-string: fmt is from a hardcoded whitelist, not user input + # Safe to use f-string: fmt is from _PERIOD_FORMAT_MAP whitelist, not user input rows = conn.execute( f""" SELECT @@ -288,3 +305,141 @@ def get_token_breakdown( for r in by_key_source ], } + + +# --------------------------------------------------------------------------- +# Public query functions (no tokens, costs, or model info) +# --------------------------------------------------------------------------- + + +def get_public_overview(conn: sqlite3.Connection) -> dict[str, Any]: + """Get public metrics overview with only non-sensitive data. + + Returns request counts and error rates; no tokens, costs, or model info. + + Args: + conn: SQLite connection. + + Returns: + Dict with total_requests, error_rate, communities_active, + and per-community request counts. + """ + totals = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as total_errors + FROM request_log + """ + ).fetchone() + + total_req = totals["total_requests"] + + community_rows = conn.execute( + """ + SELECT + community_id, + COUNT(*) as requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as errors + FROM request_log + WHERE community_id IS NOT NULL + GROUP BY community_id + ORDER BY requests DESC + """ + ).fetchall() + + return { + "total_requests": total_req, + "error_rate": round(totals["total_errors"] / total_req, 4) if total_req > 0 else 0.0, + "communities_active": len(community_rows), + "communities": [ + { + "community_id": r["community_id"], + "requests": r["requests"], + "error_rate": round(r["errors"] / r["requests"], 4) if r["requests"] > 0 else 0.0, + } + for r in community_rows + ], + } + + +def get_public_community_summary(community_id: str, conn: sqlite3.Connection) -> dict[str, Any]: + """Get public summary for a single community. 
+ + Returns request counts and top tools; no tokens, costs, or model info. + + Args: + community_id: The community identifier. + conn: SQLite connection. + + Returns: + Dict with community_id, total_requests, error_rate, top_tools. + """ + row = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_count + FROM request_log + WHERE community_id = ? + """, + (community_id,), + ).fetchone() + + total = row["total_requests"] + error_rate = row["error_count"] / total if total > 0 else 0.0 + + return { + "community_id": community_id, + "total_requests": total, + "error_rate": round(error_rate, 4), + "top_tools": _count_tools(community_id, conn), + } + + +def get_public_usage_stats( + community_id: str, + period: str, + conn: sqlite3.Connection, +) -> dict[str, Any]: + """Get public time-bucketed usage statistics. + + Returns request counts and errors per bucket; no tokens or costs. + + Args: + community_id: The community identifier. + period: One of "daily", "weekly", "monthly". + conn: SQLite connection. + + Returns: + Dict with period, community_id, and buckets list. + """ + fmt = _validate_period(period) + + # Safe to use f-string: fmt is from _PERIOD_FORMAT_MAP whitelist, not user input + rows = conn.execute( + f""" + SELECT + strftime('{fmt}', timestamp) as bucket, + COUNT(*) as requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as errors + FROM request_log + WHERE community_id = ? + GROUP BY bucket + ORDER BY bucket + """, + (community_id,), + ).fetchall() + + return { + "community_id": community_id, + "period": period, + "buckets": [ + { + "bucket": r["bucket"], + "requests": r["requests"], + "errors": r["errors"], + } + for r in rows + ], + } diff --git a/tests/test_api/test_dashboard.py b/tests/test_api/test_dashboard.py new file mode 100644 index 0000000..d15b00d --- /dev/null +++ b/tests/test_api/test_dashboard.py @@ -0,0 +1,87 @@ +"""Tests for the dashboard static HTML page. 
+ +The dashboard is a standalone static site in dashboard/index.html, +deployed separately to Cloudflare Pages. These tests verify the HTML +contains the expected structure and API references. +""" + +from pathlib import Path + +DASHBOARD_HTML_PATH = Path(__file__).parent.parent.parent / "dashboard" / "index.html" + + +class TestDashboardHTML: + """Tests for dashboard/index.html static file.""" + + def test_file_exists(self) -> None: + assert DASHBOARD_HTML_PATH.exists(), "dashboard/index.html must exist" + + def test_is_valid_html(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "" in content + assert "" in content + + def test_contains_page_title(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "Open Science Assistant" in content + + def test_contains_chart_js_cdn(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "chart.js" in content + + def test_references_public_overview_api(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "/metrics/public/overview" in content + + def test_references_community_public_metrics_api(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + # Should use /{community}/metrics/public pattern + assert "/metrics/public" in content + assert "/metrics/public/usage" in content + + def test_has_client_side_routing(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "getRoute" in content + assert "window.location.pathname" in content + + def test_has_aggregate_view(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "renderAggregateView" in content + assert "Questions Answered" in content + + def test_has_community_view(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "loadCommunityView" in content + + def test_has_tab_bar(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "tabBar" in content + assert "tab-link" in content + assert "renderTabs" in content + + def test_has_admin_key_input(self) -> 
None: + content = DASHBOARD_HTML_PATH.read_text() + assert "adminKeyInput" in content + assert "Admin Access" in content + + def test_admin_section_hidden_by_default(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "admin-section" in content + assert "display: none" in content or "display:none" in content + + def test_has_period_toggle(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "changePeriod" in content + assert "daily" in content + assert "weekly" in content + assert "monthly" in content + + def test_api_base_configurable(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + # Should support ?api= query param or window.OSA_API_BASE override + assert "OSA_API_BASE" in content + + def test_cloudflare_redirects_file_exists(self) -> None: + redirects_path = DASHBOARD_HTML_PATH.parent / "_redirects" + assert redirects_path.exists(), "_redirects needed for Cloudflare Pages SPA routing" diff --git a/tests/test_api/test_metrics_public.py b/tests/test_api/test_metrics_public.py new file mode 100644 index 0000000..8f8f427 --- /dev/null +++ b/tests/test_api/test_metrics_public.py @@ -0,0 +1,395 @@ +"""Tests for public metrics API endpoints. 
+ +Global overview: GET /metrics/public/overview (no auth) +Per-community: GET /{community_id}/metrics/public (no auth) +Per-community usage: GET /{community_id}/metrics/public/usage (no auth) +""" + +import os +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from src.api.main import app +from src.metrics.db import ( + RequestLogEntry, + init_metrics_db, + log_request, +) + + +@pytest.fixture +def metrics_db(tmp_path): + """Create isolated metrics database with sample data.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="r1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=100, + output_tokens=50, + total_tokens=150, + estimated_cost=0.001, + tools_called=["search_docs", "validate_hed"], + key_source="platform", + ), + RequestLogEntry( + request_id="r2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/chat", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=200, + output_tokens=100, + total_tokens=300, + tools_called=["search_docs"], + key_source="byok", + ), + RequestLogEntry( + request_id="r3", + timestamp="2025-01-16T09:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=100.0, + status_code=500, + ), + RequestLogEntry( + request_id="r4", + timestamp="2025-01-15T12:00:00+00:00", + endpoint="/bids/ask", + method="POST", + community_id="bids", + duration_ms=250.0, + status_code=200, + model="anthropic/claude-sonnet", + input_tokens=150, + output_tokens=75, + total_tokens=225, + key_source="platform", + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +@pytest.fixture +def isolated_metrics(metrics_db): + """Patch metrics DB path for all metrics code.""" + with 
patch("src.metrics.db.get_metrics_db_path", return_value=metrics_db): + yield metrics_db + + +@pytest.fixture +def noauth_env(): + """Disable auth requirement.""" + from src.api.config import get_settings + + os.environ["REQUIRE_API_AUTH"] = "false" + get_settings.cache_clear() + yield + del os.environ["REQUIRE_API_AUTH"] + get_settings.cache_clear() + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +class TestPublicOverview: + """Tests for GET /metrics/public/overview.""" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_returns_200_without_auth(self, client): + response = client.get("/metrics/public/overview") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_response_structure(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + assert "total_requests" in data + assert "error_rate" in data + assert "communities_active" in data + assert "communities" in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_no_sensitive_fields(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + assert "total_tokens" not in data + assert "total_estimated_cost" not in data + assert "avg_duration_ms" not in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_request_counts_correct(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + assert data["total_requests"] == 4 + assert data["communities_active"] == 2 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_error_rate_includes_errors(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + # 1 error out of 4 requests = 0.25 + assert data["error_rate"] == 0.25 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_community_breakdown_no_sensitive_fields(self, 
client): + response = client.get("/metrics/public/overview") + data = response.json() + for community in data["communities"]: + assert "community_id" in community + assert "requests" in community + assert "error_rate" in community + assert "tokens" not in community + assert "estimated_cost" not in community + + +class TestCommunityPublicMetrics: + """Tests for GET /{community_id}/metrics/public.""" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_returns_200_without_auth(self, client): + response = client.get("/hed/metrics/public") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_response_structure(self, client): + response = client.get("/hed/metrics/public") + data = response.json() + assert data["community_id"] == "hed" + assert "total_requests" in data + assert "error_rate" in data + assert "top_tools" in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_no_sensitive_fields(self, client): + response = client.get("/hed/metrics/public") + data = response.json() + assert "total_tokens" not in data + assert "total_estimated_cost" not in data + assert "top_models" not in data + assert "avg_duration_ms" not in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_top_tools_populated(self, client): + response = client.get("/hed/metrics/public") + data = response.json() + tools = {t["tool"]: t["count"] for t in data["top_tools"]} + assert "search_docs" in tools + assert tools["search_docs"] == 2 + + +class TestCommunityPublicUsage: + """Tests for GET /{community_id}/metrics/public/usage.""" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_daily_usage_returns_200(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "daily"}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_daily_usage_structure(self, client): + 
response = client.get("/hed/metrics/public/usage", params={"period": "daily"}) + data = response.json() + assert data["community_id"] == "hed" + assert data["period"] == "daily" + assert "buckets" in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_buckets_no_sensitive_fields(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "daily"}) + data = response.json() + for bucket in data["buckets"]: + assert "bucket" in bucket + assert "requests" in bucket + assert "errors" in bucket + assert "tokens" not in bucket + assert "estimated_cost" not in bucket + assert "avg_duration_ms" not in bucket + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_monthly_usage(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "monthly"}) + assert response.status_code == 200 + data = response.json() + assert data["period"] == "monthly" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_weekly_usage(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "weekly"}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_invalid_period_returns_422(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "hourly"}) + assert response.status_code == 422 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_default_period_is_daily(self, client): + response = client.get("/hed/metrics/public/usage") + assert response.status_code == 200 + data = response.json() + assert data["period"] == "daily" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_daily_buckets_count_and_errors(self, client): + """Verify bucket count and error values match fixture data.""" + response = client.get("/hed/metrics/public/usage", params={"period": "daily"}) + data = response.json() + buckets = data["buckets"] + # Fixture has 
HED requests on 2025-01-15 and 2025-01-16 + assert len(buckets) == 2 + bucket_map = {b["bucket"]: b for b in buckets} + # 2025-01-16 has one request with status_code=500 + assert bucket_map["2025-01-16"]["errors"] == 1 + + +class TestPublicAdminBoundary: + """Verify public endpoints work without auth while admin endpoints require it.""" + + @pytest.fixture + def auth_env(self): + """Enable auth with a test API key so admin endpoints reject anonymous requests.""" + from src.api.config import get_settings + + os.environ["REQUIRE_API_AUTH"] = "true" + os.environ["API_KEYS"] = "test-secret-key" + get_settings.cache_clear() + yield + del os.environ["REQUIRE_API_AUTH"] + del os.environ["API_KEYS"] + get_settings.cache_clear() + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_public_overview_no_auth_200(self, client): + response = client.get("/metrics/public/overview") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_admin_overview_no_auth_rejected(self, client): + """Admin endpoint must reject unauthenticated requests.""" + response = client.get("/metrics/overview") + assert response.status_code in (401, 403) + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_admin_tokens_no_auth_rejected(self, client): + """Admin token endpoint must reject unauthenticated requests.""" + response = client.get("/metrics/tokens") + assert response.status_code in (401, 403) + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_public_overview_accessible_with_auth_enabled(self, client): + """Public overview must return 200 even when auth is required.""" + response = client.get("/metrics/public/overview") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_community_public_metrics_accessible_with_auth_enabled(self, client): + """Per-community public metrics must return 200 even when auth is required.""" + response = 
client.get("/hed/metrics/public") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_community_public_usage_accessible_with_auth_enabled(self, client): + """Per-community public usage must return 200 even when auth is required.""" + response = client.get("/hed/metrics/public/usage") + assert response.status_code == 200 + + +class TestEmptyDatabase: + """Verify public endpoints handle empty databases gracefully.""" + + @pytest.fixture + def empty_metrics_db(self, tmp_path): + db_path = tmp_path / "empty_metrics.db" + init_metrics_db(db_path) + return db_path + + @pytest.fixture + def isolated_empty_metrics(self, empty_metrics_db): + with patch("src.metrics.db.get_metrics_db_path", return_value=empty_metrics_db): + yield + + @pytest.mark.usefixtures("isolated_empty_metrics", "noauth_env") + def test_overview_empty_db(self, client): + response = client.get("/metrics/public/overview") + assert response.status_code == 200 + data = response.json() + assert data["total_requests"] == 0 + assert data["error_rate"] == 0.0 + assert data["communities_active"] == 0 + assert data["communities"] == [] + + @pytest.mark.usefixtures("isolated_empty_metrics", "noauth_env") + def test_community_metrics_empty_db(self, client): + response = client.get("/hed/metrics/public") + assert response.status_code == 200 + data = response.json() + assert data["total_requests"] == 0 + assert data["error_rate"] == 0.0 + assert data["top_tools"] == [] + + @pytest.mark.usefixtures("isolated_empty_metrics", "noauth_env") + def test_community_usage_empty_db(self, client): + response = client.get("/hed/metrics/public/usage") + assert response.status_code == 200 + data = response.json() + assert data["buckets"] == [] + + +class TestCommunityMetricsValues: + """Verify computed values per community match fixture data dynamically.""" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_community_values_from_overview(self, client): 
+ """Check each community's request count and error rate from overview.""" + response = client.get("/metrics/public/overview") + data = response.json() + checked = 0 + for community in data["communities"]: + cid = community["community_id"] + resp = client.get(f"/{cid}/metrics/public") + if resp.status_code != 200: + continue # community route not registered in test app + detail = resp.json() + assert detail["total_requests"] == community["requests"] + assert detail["total_requests"] > 0 + assert detail["error_rate"] == community["error_rate"] + checked += 1 + assert checked > 0, "Expected at least one community with a registered route" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_per_community_tool_counts_consistent(self, client): + """Each tool count should be a positive integer.""" + response = client.get("/metrics/public/overview") + checked = 0 + for community in response.json()["communities"]: + cid = community["community_id"] + resp = client.get(f"/{cid}/metrics/public") + if resp.status_code != 200: + continue # community route not registered in test app + detail = resp.json() + for tool_entry in detail["top_tools"]: + assert isinstance(tool_entry["tool"], str) + assert tool_entry["count"] > 0 + checked += 1 + assert checked > 0, "Expected at least one community with a registered route" From 94bf3bbba0bf9c2eea973bcc6a36d6580b90b235 Mon Sep 17 00:00:00 2001 From: "Seyed (Yahya) Shirazi" Date: Tue, 3 Feb 2026 09:12:15 -0800 Subject: [PATCH 11/26] Phase 3: Serve dashboard from /osa/ base path (#143) * Serve dashboard from /osa/ base path for status.osc.earth Move dashboard/index.html to dashboard/osa/index.html and add BASE_PATH constant to strip /osa prefix in client-side router. Update all internal links (tabs, community cards) to use the /osa/ prefix. Update _redirects for SPA routing under /osa/. Update dashboard tests for new file location. 
* Handle /osa without trailing slash in _redirects Add explicit /osa rule alongside /osa/* to ensure the path without trailing slash also serves the SPA index. --- .github/workflows/deploy-dashboard.yml | 3 ++- dashboard/_redirects | 4 +++- dashboard/{ => osa}/index.html | 14 +++++++++----- tests/test_api/test_dashboard.py | 15 ++++++++++----- 4 files changed, 24 insertions(+), 12 deletions(-) rename dashboard/{ => osa}/index.html (98%) diff --git a/.github/workflows/deploy-dashboard.yml b/.github/workflows/deploy-dashboard.yml index cc4ad16..0f62b97 100644 --- a/.github/workflows/deploy-dashboard.yml +++ b/.github/workflows/deploy-dashboard.yml @@ -48,9 +48,10 @@ jobs: apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }} accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }} # Deploy with branch name so Cloudflare Pages routes correctly: - # - main branch -> osa-dash.pages.dev (production) + # - main branch -> osa-dash.pages.dev (production), custom domain: status.osc.earth # - develop branch -> develop.osa-dash.pages.dev (preview) # - PR branches -> {branch}.osa-dash.pages.dev (preview) + # Dashboard is served at /osa/ path (e.g., status.osc.earth/osa/) command: pages deploy dashboard --project-name=osa-dash --branch=${{ steps.branch.outputs.name }} - name: Comment on PR with preview URL diff --git a/dashboard/_redirects b/dashboard/_redirects index bbb3e7a..57ede61 100644 --- a/dashboard/_redirects +++ b/dashboard/_redirects @@ -1 +1,3 @@ -/* /index.html 200 +/ /osa/ 301 +/osa /osa/index.html 200 +/osa/* /osa/index.html 200 diff --git a/dashboard/index.html b/dashboard/osa/index.html similarity index 98% rename from dashboard/index.html rename to dashboard/osa/index.html index 7147083..c08122b 100644 --- a/dashboard/index.html +++ b/dashboard/osa/index.html @@ -322,10 +322,14 @@

Admin Access

// ----------------------------------------------------------------------- // Router - read path to decide view // ----------------------------------------------------------------------- + const BASE_PATH = '/osa'; + function getRoute() { - const path = window.location.pathname.replace(/^\/+|\/+$/g, ''); + let path = window.location.pathname; + if (path.startsWith(BASE_PATH)) path = path.slice(BASE_PATH.length); + path = path.replace(/^\/+|\/+$/g, ''); if (!path) return { view: 'aggregate', community: null }; - return { view: 'community', community: path.split('/')[0] }; + return { view: 'community', community: decodeURIComponent(path.split('/')[0]) }; } document.addEventListener('DOMContentLoaded', async () => { @@ -364,11 +368,11 @@

Admin Access

const tabBar = document.getElementById('tabBar'); if (!communities || communities.length === 0) return; - let html = `All`; + let html = `All`; for (const c of communities) { const isActive = c.community_id === activeCommunity; const safe = escapeHtml(c.community_id); - html += `${safe.toUpperCase()}`; + html += `${safe.toUpperCase()}`; } tabBar.innerHTML = html; } @@ -389,7 +393,7 @@

Admin Access

? ((1 - c.error_rate) * 100).toFixed(1) : '0.0'; const safe = escapeHtml(c.community_id); return ` - +

${safe.toUpperCase()}

${c.requests.toLocaleString()} requests diff --git a/tests/test_api/test_dashboard.py b/tests/test_api/test_dashboard.py index d15b00d..cc14100 100644 --- a/tests/test_api/test_dashboard.py +++ b/tests/test_api/test_dashboard.py @@ -1,20 +1,20 @@ """Tests for the dashboard static HTML page. -The dashboard is a standalone static site in dashboard/index.html, +The dashboard is a standalone static site in dashboard/osa/index.html, deployed separately to Cloudflare Pages. These tests verify the HTML contains the expected structure and API references. """ from pathlib import Path -DASHBOARD_HTML_PATH = Path(__file__).parent.parent.parent / "dashboard" / "index.html" +DASHBOARD_HTML_PATH = Path(__file__).parent.parent.parent / "dashboard" / "osa" / "index.html" class TestDashboardHTML: - """Tests for dashboard/index.html static file.""" + """Tests for dashboard/osa/index.html static file.""" def test_file_exists(self) -> None: - assert DASHBOARD_HTML_PATH.exists(), "dashboard/index.html must exist" + assert DASHBOARD_HTML_PATH.exists(), "dashboard/osa/index.html must exist" def test_is_valid_html(self) -> None: content = DASHBOARD_HTML_PATH.read_text() @@ -82,6 +82,11 @@ def test_api_base_configurable(self) -> None: # Should support ?api= query param or window.OSA_API_BASE override assert "OSA_API_BASE" in content + def test_has_base_path_constant(self) -> None: + content = DASHBOARD_HTML_PATH.read_text() + assert "BASE_PATH" in content + assert "const BASE_PATH = '/osa'" in content + def test_cloudflare_redirects_file_exists(self) -> None: - redirects_path = DASHBOARD_HTML_PATH.parent / "_redirects" + redirects_path = DASHBOARD_HTML_PATH.parent.parent / "_redirects" assert redirects_path.exists(), "_redirects needed for Cloudflare Pages SPA routing" From fb3fcf7b3cbabf10b131eaf28f831bc03ca9e432 Mon Sep 17 00:00:00 2001 From: "Seyed (Yahya) Shirazi" Date: Tue, 3 Feb 2026 14:29:21 -0800 Subject: [PATCH 12/26] Phase 4: Auth, quality metrics, cost/budget alerting (#147) * 
Add per-community auth, quality metrics, and budget alerting Phase 4 of the community dashboard: per-community scoped authentication (AuthScope + community admin keys), LangFuse observability wiring, quality metrics (error rates, latency percentiles, tool call tracking), cost estimation with model pricing table, budget checking with configurable limits, and automated GitHub issue alerting when spend thresholds are exceeded. Includes scheduled budget check job (every 15 min) and sample budget configs for HED and EEGLAB communities. Tested: 1152 passed, 68% coverage. * Address PR review: simplify code and fix error handling - Fix _issue_exists to return True on error (prevent duplicate spam) - Simplify redundant exception tuples in alerts.py - Extract _require_community_access helper (4 duplicated blocks) - Extract parse_admin_keys method on Settings (3 duplicated parsers) - Strengthen AuthScope with Literal type, frozen, validation - Make BudgetStatus frozen (immutable snapshot) - Add BudgetConfig cross-field validation (daily <= monthly) - Share single DB connection in budget check loop - Add budget check failure escalation (matching sync pattern) - Split LangFuse except into ImportError vs Exception - Improve _migrate_columns to re-raise unexpected errors - Log warnings for malformed community_admin_keys entries - Bump unknown model fallback logging from debug to warning * Fix streaming metrics fields and add quality endpoint tests Add tool_call_count and langfuse_trace_id to streaming metrics logging so streaming requests capture the same quality data as non-streaming. Add 14 endpoint tests covering community quality, quality summary, and global quality API routes. 
* Address round 2 review: fix alerts, docstrings, tests, simplify - Fix _issue_exists to return None on failure instead of True, with warning-level logging when dedup check fails - Fix stale pricing date and deduplicate fallback branches - Extract shared _fetch_latency_percentiles helper in queries - Make BudgetConfig frozen (immutable after parsing) - Fix inaccurate docstrings: maintainers usage, _percentile method name, _migrate_columns idempotency, regex claims - Fix get_quality_summary docstring key name mismatch - Track per-community scheduler failures for critical alerting - Upgrade malformed config entry log from WARNING to ERROR - Add tests: AuthScope validation, BudgetConfig daily>monthly, community-scoped keys on global endpoints, dedup failure --- src/api/config.py | 46 +++ src/api/routers/community.py | 110 +++++++- src/api/routers/metrics.py | 63 ++++- src/api/scheduler.py | 94 +++++++ src/api/security.py | 86 +++++- src/assistants/eeglab/config.yaml | 11 + src/assistants/hed/config.yaml | 12 + src/core/config/community.py | 79 ++++++ src/metrics/alerts.py | 163 +++++++++++ src/metrics/budget.py | 115 ++++++++ src/metrics/cost.py | 67 +++++ src/metrics/db.py | 44 ++- src/metrics/middleware.py | 3 + src/metrics/queries.py | 170 ++++++++++++ tests/test_api/test_metrics_endpoints.py | 209 ++++++++++++++ tests/test_api/test_scoped_auth.py | 200 +++++++++++++ tests/test_core/test_config/test_community.py | 190 +++++++++++++ tests/test_metrics/test_alerts.py | 178 ++++++++++++ tests/test_metrics/test_budget.py | 243 ++++++++++++++++ tests/test_metrics/test_cost.py | 64 +++++ tests/test_metrics/test_quality_queries.py | 262 ++++++++++++++++++ 21 files changed, 2380 insertions(+), 29 deletions(-) create mode 100644 src/metrics/alerts.py create mode 100644 src/metrics/budget.py create mode 100644 src/metrics/cost.py create mode 100644 tests/test_api/test_scoped_auth.py create mode 100644 tests/test_metrics/test_alerts.py create mode 100644 
tests/test_metrics/test_budget.py create mode 100644 tests/test_metrics/test_cost.py create mode 100644 tests/test_metrics/test_quality_queries.py diff --git a/src/api/config.py b/src/api/config.py index 77df6be..6de9bf3 100644 --- a/src/api/config.py +++ b/src/api/config.py @@ -1,5 +1,6 @@ """Configuration management for the OSA API.""" +import logging from functools import lru_cache from pydantic import Field @@ -7,6 +8,8 @@ from src.version import __version__ +logger = logging.getLogger(__name__) + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -56,6 +59,13 @@ class Settings(BaseSettings): ) require_api_auth: bool = Field(default=True, description="Require API key authentication") + # Per-community admin keys for scoped dashboard access + # Format: "community_id:key1,community_id:key2" (e.g., "hed:abc123,eeglab:xyz789") + community_admin_keys: str | None = Field( + default=None, + description="Per-community admin API keys (format: community_id:key,...)", + ) + # LLM Provider Settings (server defaults, can be overridden by BYOK) openrouter_api_key: str | None = Field(default=None, description="OpenRouter API key") openai_api_key: str | None = Field(default=None, description="OpenAI API key") @@ -127,6 +137,42 @@ class Settings(BaseSettings): description="Cron schedule for papers sync (default: weekly Sunday at 3am UTC)", ) + def parse_admin_keys(self) -> set[str]: + """Parse API_KEYS into a set of valid admin keys. + + Returns: + Set of valid API key strings. + """ + if not self.api_keys: + return set() + return {k.strip() for k in self.api_keys.split(",") if k.strip()} + + def parse_community_admin_keys(self) -> dict[str, set[str]]: + """Parse COMMUNITY_ADMIN_KEYS into {community_id: {keys}} mapping. + + Format: "community_id:key1,community_id:key2" + Multiple keys per community are supported. + + Returns: + Dict mapping community_id to set of valid API keys. 
+ """ + if not self.community_admin_keys: + return {} + result: dict[str, set[str]] = {} + for entry in self.community_admin_keys.split(","): + entry = entry.strip() + if not entry: + continue + if ":" not in entry: + logger.error("Skipping malformed community_admin_keys entry (no ':'): %r", entry) + continue + community_id, key = entry.split(":", 1) + community_id = community_id.strip() + key = key.strip() + if community_id and key: + result.setdefault(community_id, set()).add(key) + return result + @lru_cache def get_settings() -> Settings: diff --git a/src/api/routers/community.py b/src/api/routers/community.py index fadf738..75ddc94 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -21,12 +21,13 @@ from pydantic import BaseModel, Field, field_validator from src.api.config import get_settings -from src.api.security import RequireAdminAuth, RequireAuth +from src.api.security import AuthScope, RequireAuth, RequireScopedAuth from src.assistants import registry from src.assistants.community import CommunityAssistant from src.assistants.community import PageContext as AgentPageContext from src.assistants.registry import AssistantInfo from src.core.services.litellm_llm import create_openrouter_llm +from src.metrics.cost import estimate_cost from src.metrics.db import ( RequestLogEntry, extract_token_usage, @@ -39,6 +40,8 @@ get_community_summary, get_public_community_summary, get_public_usage_stats, + get_quality_metrics, + get_quality_summary, get_usage_stats, ) @@ -628,6 +631,8 @@ class AssistantWithMetrics: assistant: CommunityAssistant model: str key_source: str + langfuse_config: dict | None = None + langfuse_trace_id: str | None = None def create_community_assistant( @@ -717,10 +722,34 @@ def create_community_assistant( page_context=agent_page_context, ) + # Wire LangFuse tracing if configured + langfuse_config = None + langfuse_trace_id = None + try: + from src.core.services.llm import get_llm_service + except ImportError: + 
logger.debug("LangFuse tracing not available (module not installed)") + else: + try: + llm_service = get_llm_service(settings) + trace_id = f"{community_id}-{uuid.uuid4().hex[:12]}" + config = llm_service.get_config_with_tracing(trace_id=trace_id) + if config.get("callbacks"): + langfuse_config = config + langfuse_trace_id = trace_id + except Exception: + logger.warning( + "LangFuse tracing setup failed for %s, continuing without it", + community_id, + exc_info=True, + ) + return AssistantWithMetrics( assistant=assistant, model=selected_model, key_source=key_source, + langfuse_config=langfuse_config, + langfuse_trace_id=langfuse_trace_id, ) @@ -818,7 +847,7 @@ async def ask( ) assistant = awm.assistant messages = [HumanMessage(content=body.question)] - result = await assistant.ainvoke(messages) + result = await assistant.ainvoke(messages, config=awm.langfuse_config) response_content = "" if result.get("messages"): @@ -828,6 +857,7 @@ async def ask( content = last_msg.content response_content = content if isinstance(content, str) else str(content) + tools_called = extract_tool_names(result) tool_calls_info = [ ToolCallInfo(name=tc.get("name", ""), args=tc.get("args", {})) for tc in result.get("tool_calls", []) @@ -841,7 +871,10 @@ async def ask( "input_tokens": inp, "output_tokens": out, "total_tokens": total, - "tools_called": extract_tool_names(result), + "estimated_cost": estimate_cost(awm.model, inp, out), + "tools_called": tools_called, + "tool_call_count": len(tools_called), + "langfuse_trace_id": awm.langfuse_trace_id, "stream": False, } @@ -930,7 +963,7 @@ async def chat( requested_model=body.model, ) assistant = awm.assistant - result = await assistant.ainvoke(session.messages) + result = await assistant.ainvoke(session.messages, config=awm.langfuse_config) response_content = "" if result.get("messages"): @@ -940,6 +973,7 @@ async def chat( content = last_msg.content response_content = content if isinstance(content, str) else str(content) + tools_called = 
extract_tool_names(result) tool_calls_info = [ ToolCallInfo(name=tc.get("name", ""), args=tc.get("args", {})) for tc in result.get("tool_calls", []) @@ -953,7 +987,10 @@ async def chat( "input_tokens": inp, "output_tokens": out, "total_tokens": total, - "tools_called": extract_tool_names(result), + "estimated_cost": estimate_cost(awm.model, inp, out), + "tools_called": tools_called, + "tool_call_count": len(tools_called), + "langfuse_trace_id": awm.langfuse_trace_id, "stream": False, } @@ -1057,9 +1094,18 @@ async def get_community_config() -> CommunityConfigResponse: # Per-community Metrics Endpoints # ----------------------------------------------------------------------- + def _require_community_access(auth: AuthScope) -> None: + """Raise 403 if the scoped key cannot access this community.""" + if not auth.can_access_community(community_id): + raise HTTPException( + status_code=403, + detail=f"Your API key does not have access to {community_id} metrics", + ) + @router.get("/metrics") - async def community_metrics(_auth: RequireAdminAuth) -> dict[str, Any]: - """Get metrics summary for this community. Requires admin auth.""" + async def community_metrics(auth: RequireScopedAuth) -> dict[str, Any]: + """Get metrics summary for this community. Requires admin or community key.""" + _require_community_access(auth) try: with metrics_connection() as conn: return get_community_summary(community_id, conn) @@ -1072,14 +1118,15 @@ async def community_metrics(_auth: RequireAdminAuth) -> dict[str, Any]: @router.get("/metrics/usage") async def community_usage( - _auth: RequireAdminAuth, + auth: RequireScopedAuth, period: str = Query( default="daily", description="Time bucket period", pattern="^(daily|weekly|monthly)$", ), ) -> dict[str, Any]: - """Get time-bucketed usage stats for this community. Requires admin auth.""" + """Get time-bucketed usage stats for this community. 
Requires admin or community key.""" + _require_community_access(auth) try: with metrics_connection() as conn: return get_usage_stats(community_id, period, conn) @@ -1092,6 +1139,43 @@ async def community_usage( detail="Metrics database is temporarily unavailable.", ) + @router.get("/metrics/quality") + async def community_quality( + auth: RequireScopedAuth, + period: str = Query( + default="daily", + description="Time bucket period", + pattern="^(daily|weekly|monthly)$", + ), + ) -> dict[str, Any]: + """Get quality metrics for this community. Requires admin or community key.""" + _require_community_access(auth) + try: + with metrics_connection() as conn: + return get_quality_metrics(community_id, conn, period) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except sqlite3.Error: + logger.exception("Failed to query quality metrics for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + @router.get("/metrics/quality/summary") + async def community_quality_summary(auth: RequireScopedAuth) -> dict[str, Any]: + """Get overall quality summary for this community. 
Requires admin or community key.""" + _require_community_access(auth) + try: + with metrics_connection() as conn: + return get_quality_summary(community_id, conn) + except sqlite3.Error: + logger.exception("Failed to query quality summary for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + # ----------------------------------------------------------------------- # Per-community Public Metrics Endpoints (no auth required) # ----------------------------------------------------------------------- @@ -1181,6 +1265,8 @@ def _log_streaming_metrics( key_source=awm.key_source if awm else None, tools_called=tools_called, stream=True, + tool_call_count=len(tools_called), + langfuse_trace_id=awm.langfuse_trace_id if awm else None, ) log_request(entry) except Exception: @@ -1233,7 +1319,8 @@ async def _stream_ask_response( "tool_calls": [], } - async for event in graph.astream_events(state, version="v2"): + stream_config = awm.langfuse_config or {} + async for event in graph.astream_events(state, version="v2", config=stream_config): kind = event.get("event") if kind == "on_chat_model_stream": @@ -1369,9 +1456,10 @@ async def _stream_chat_response( "tool_calls": [], } + stream_config = awm.langfuse_config or {} full_response = "" - async for event in graph.astream_events(state, version="v2"): + async for event in graph.astream_events(state, version="v2", config=stream_config): kind = event.get("event") if kind == "on_chat_model_stream": diff --git a/src/api/routers/metrics.py b/src/api/routers/metrics.py index 752be95..6dc4f4a 100644 --- a/src/api/routers/metrics.py +++ b/src/api/routers/metrics.py @@ -1,7 +1,7 @@ """Global metrics API endpoints. Provides cross-community metrics overview and token breakdowns. -All endpoints require admin authentication. +Supports both global admin keys (see all) and per-community keys (filtered view). 
""" import logging @@ -10,9 +10,14 @@ from fastapi import APIRouter, HTTPException, Query -from src.api.security import RequireAdminAuth +from src.api.security import RequireScopedAuth from src.metrics.db import get_metrics_connection -from src.metrics.queries import get_overview, get_token_breakdown +from src.metrics.queries import ( + get_community_summary, + get_overview, + get_quality_summary, + get_token_breakdown, +) logger = logging.getLogger(__name__) @@ -20,15 +25,18 @@ @router.get("/overview") -async def metrics_overview(_auth: RequireAdminAuth) -> dict[str, Any]: +async def metrics_overview(auth: RequireScopedAuth) -> dict[str, Any]: """Get cross-community metrics overview. - Returns total requests, tokens, average duration, error rate, - and per-community breakdown. + Global admin keys see all communities. Per-community keys see only + their community's data wrapped in the same response format. """ conn = get_metrics_connection() try: - return get_overview(conn) + if auth.role == "admin": + return get_overview(conn) + # Community-scoped: return summary for just their community + return get_community_summary(auth.community_id, conn) except sqlite3.Error: logger.exception("Failed to query metrics database for overview") raise HTTPException( @@ -41,16 +49,22 @@ async def metrics_overview(_auth: RequireAdminAuth) -> dict[str, Any]: @router.get("/tokens") async def token_breakdown( - _auth: RequireAdminAuth, + auth: RequireScopedAuth, community_id: str | None = Query(default=None, description="Filter by community"), ) -> dict[str, Any]: """Get token usage breakdown by model and key source. - Optionally filter by community_id. + Global admin keys can filter by any community. Per-community keys + are automatically scoped to their community (community_id parameter ignored). 
""" + # Community-scoped keys always filter to their own community + effective_community = community_id + if auth.role == "community": + effective_community = auth.community_id + conn = get_metrics_connection() try: - return get_token_breakdown(conn, community_id=community_id) + return get_token_breakdown(conn, community_id=effective_community) except sqlite3.Error: logger.exception("Failed to query metrics database for token breakdown") raise HTTPException( @@ -59,3 +73,32 @@ async def token_breakdown( ) finally: conn.close() + + +@router.get("/quality") +async def quality_overview(auth: RequireScopedAuth) -> dict[str, Any]: + """Get quality metrics overview. + + Global admin keys see quality for all communities. + Per-community keys see quality summary for their community only. + """ + conn = get_metrics_connection() + try: + if auth.role == "community": + return get_quality_summary(auth.community_id, conn) + # Admin: aggregate quality across all communities + overview = get_overview(conn) + communities_data = overview.get("communities", []) + summaries = [] + for c in communities_data: + cid = c["community_id"] + summaries.append(get_quality_summary(cid, conn)) + return {"communities": summaries} + except sqlite3.Error: + logger.exception("Failed to query quality metrics") + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + finally: + conn.close() diff --git a/src/api/scheduler.py b/src/api/scheduler.py index 9629b95..0eb7c47 100644 --- a/src/api/scheduler.py +++ b/src/api/scheduler.py @@ -16,6 +16,9 @@ from src.knowledge.db import init_db from src.knowledge.github_sync import sync_repos from src.knowledge.papers_sync import sync_all_papers, sync_citing_papers +from src.metrics.alerts import create_budget_alert_issue +from src.metrics.budget import check_budget +from src.metrics.db import metrics_connection logger = logging.getLogger(__name__) @@ -63,6 +66,7 @@ def _get_community_paper_dois(community_id: str) 
-> list[str]: # Failure tracking for alerting _github_sync_failures = 0 _papers_sync_failures = 0 +_budget_check_failures = 0 MAX_CONSECUTIVE_FAILURES = 3 @@ -166,6 +170,82 @@ def _run_papers_sync() -> None: ) +def _check_community_budgets() -> None: + """Check budget limits for all communities and create alert issues if exceeded.""" + global _budget_check_failures + logger.info("Starting scheduled budget check for all communities") + try: + communities_checked = 0 + communities_failed = 0 + alerts_created = 0 + + with metrics_connection() as conn: + for info in registry.list_all(): + if not info.community_config or not info.community_config.budget: + continue + + budget_cfg = info.community_config.budget + maintainers = info.community_config.maintainers + + try: + budget_status = check_budget( + community_id=info.id, + daily_limit_usd=budget_cfg.daily_limit_usd, + monthly_limit_usd=budget_cfg.monthly_limit_usd, + alert_threshold_pct=budget_cfg.alert_threshold_pct, + conn=conn, + ) + communities_checked += 1 + + if budget_status.needs_alert: + issue_url = create_budget_alert_issue( + budget_status=budget_status, + maintainers=maintainers, + ) + if issue_url: + alerts_created += 1 + logger.warning( + "Budget alert created for %s: %s", + info.id, + issue_url, + ) + except Exception: + communities_failed += 1 + logger.exception("Failed to check budget for community %s", info.id) + + log_level = logging.WARNING if communities_failed else logging.INFO + logger.log( + log_level, + "Budget check complete: %d checked, %d failed, %d alerts created", + communities_checked, + communities_failed, + alerts_created, + ) + if communities_failed: + _budget_check_failures += 1 + if _budget_check_failures >= MAX_CONSECUTIVE_FAILURES: + logger.critical( + "Budget check has had community failures for %d consecutive runs. 
" + "Manual intervention required.", + _budget_check_failures, + ) + else: + _budget_check_failures = 0 + except Exception: + _budget_check_failures += 1 + logger.error( + "Budget check job failed (attempt %d/%d)", + _budget_check_failures, + MAX_CONSECUTIVE_FAILURES, + exc_info=True, + ) + if _budget_check_failures >= MAX_CONSECUTIVE_FAILURES: + logger.critical( + "Budget check has failed %d times consecutively. Manual intervention required.", + _budget_check_failures, + ) + + def start_scheduler() -> BackgroundScheduler | None: """Start the background scheduler with configured sync jobs. @@ -224,6 +304,20 @@ def start_scheduler() -> BackgroundScheduler | None: except ValueError as e: logger.error("Invalid papers sync cron expression: %s", e) + # Add budget check job (every 15 minutes) + try: + budget_trigger = CronTrigger(minute="*/15") + _scheduler.add_job( + _check_community_budgets, + trigger=budget_trigger, + id="budget_check", + name="Community Budget Check", + replace_existing=True, + ) + logger.info("Budget check scheduled: every 15 minutes") + except ValueError as e: + logger.error("Failed to schedule budget check: %s", e) + # Start the scheduler _scheduler.start() logger.info("Background scheduler started") diff --git a/src/api/security.py b/src/api/security.py index bed7366..5d92036 100644 --- a/src/api/security.py +++ b/src/api/security.py @@ -1,6 +1,7 @@ """Security and authentication for the OSA API.""" -from typing import Annotated +from dataclasses import dataclass +from typing import Annotated, Literal from fastapi import Depends, HTTPException, Security, status from fastapi.security import APIKeyHeader @@ -45,8 +46,7 @@ async def verify_api_key( if openai_key or anthropic_key or openrouter_key: return None - # Parse comma-separated API keys - valid_keys = {k.strip() for k in settings.api_keys.split(",") if k.strip()} + valid_keys = settings.parse_admin_keys() # If auth is enabled, require valid API key if not api_key: @@ -86,8 +86,7 @@ async def 
verify_admin_api_key( if not settings.api_keys: return None - # Parse comma-separated API keys - valid_keys = {k.strip() for k in settings.api_keys.split(",") if k.strip()} + valid_keys = settings.parse_admin_keys() # Admin endpoints require valid server API key (no BYOK bypass) if not api_key: @@ -110,6 +109,83 @@ async def verify_admin_api_key( RequireAdminAuth = Annotated[str | None, Depends(verify_admin_api_key)] +@dataclass(frozen=True) +class AuthScope: + """Scoped authentication result for community-level access control. + + Attributes: + role: "admin" for global admin keys, "community" for per-community keys. + community_id: The community this key is scoped to, or None for global admin. + + Use ``can_access_community(id)`` to check whether this scope permits + access to a given community's data. + """ + + role: Literal["admin", "community"] + community_id: str | None = None + + def __post_init__(self) -> None: + if self.role == "community" and not self.community_id: + raise ValueError("community role requires a community_id") + if self.role == "admin" and self.community_id is not None: + raise ValueError("admin role must not have a community_id") + + def can_access_community(self, community_id: str) -> bool: + """Check if this auth scope permits access to a given community.""" + if self.role == "admin": + return True + return self.community_id == community_id + + +async def verify_scoped_admin_key( + api_key: Annotated[str | None, Security(api_key_header)], + settings: Annotated[Settings, Depends(get_settings)], +) -> AuthScope: + """Verify API key and return scoped auth context. + + Checks in order: + 1. Global admin keys (api_keys setting) -> AuthScope(role="admin") + 2. Per-community keys (community_admin_keys) -> AuthScope(role="community", community_id=X) + 3. No match -> 401/403 + + When auth is disabled (require_api_auth=False or no keys configured), + returns admin scope for backward compatibility. 
+ """ + # If auth is not required, grant full admin access + if not settings.require_api_auth: + return AuthScope(role="admin") + + # If no API keys configured at all, auth is effectively disabled + if not settings.api_keys and not settings.community_admin_keys: + return AuthScope(role="admin") + + if not api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key required for metrics access", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + # Check global admin keys first + if settings.api_keys and api_key in settings.parse_admin_keys(): + return AuthScope(role="admin") + + # Check per-community keys + community_keys = settings.parse_community_admin_keys() + for community_id, keys in community_keys.items(): + if api_key in keys: + return AuthScope(role="community", community_id=community_id) + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid API key", + ) + + +# Dependency for scoped admin routes (supports per-community keys) +RequireScopedAuth = Annotated[AuthScope, Depends(verify_scoped_admin_key)] + + class BYOKHeaders: """BYOK (Bring Your Own Key) headers for LLM providers. 
diff --git a/src/assistants/eeglab/config.yaml b/src/assistants/eeglab/config.yaml index a9f9a9b..3a1baa6 100644 --- a/src/assistants/eeglab/config.yaml +++ b/src/assistants/eeglab/config.yaml @@ -13,6 +13,17 @@ cors_origins: - https://www.eeglab.org - https://sccn.github.io +# Community maintainers (GitHub usernames) +# Used for scoped admin access and budget alert @mentions +maintainers: + - arnodelorme + +# Budget limits for cost management +budget: + daily_limit_usd: 5.00 + monthly_limit_usd: 50.00 + alert_threshold_pct: 80.0 + # System prompt template with runtime-substituted placeholders # Placeholders (in curly braces) are replaced by CommunityAssistant at runtime: # {preloaded_docs_section} - Embedded documentation content diff --git a/src/assistants/hed/config.yaml b/src/assistants/hed/config.yaml index f81da10..7c7c5c9 100644 --- a/src/assistants/hed/config.yaml +++ b/src/assistants/hed/config.yaml @@ -19,6 +19,18 @@ cors_origins: # The backend must have this environment variable set openrouter_api_key_env_var: "OPENROUTER_API_KEY_HED" +# Community maintainers (GitHub usernames) +# Used for scoped admin access and budget alert @mentions +maintainers: + - VisLab + - yarikoptic + +# Budget limits for cost management +budget: + daily_limit_usd: 5.00 + monthly_limit_usd: 50.00 + alert_threshold_pct: 80.0 + # Default model for this community (optional) # If specified, overrides the platform-level default model # Format: creator/model-name (OpenRouter format) diff --git a/src/core/config/community.py b/src/core/config/community.py index dbcff55..daf9604 100644 --- a/src/core/config/community.py +++ b/src/core/config/community.py @@ -596,6 +596,35 @@ def validate_agent_roles(self) -> "FAQGenerationConfig": return self +class BudgetConfig(BaseModel): + """Budget limits and alert thresholds for a community. + + When configured, the scheduler periodically checks spend against + these limits and creates GitHub issues when thresholds are exceeded. 
+ """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + daily_limit_usd: float = Field(..., gt=0, description="Maximum daily spend in USD") + monthly_limit_usd: float = Field(..., gt=0, description="Maximum monthly spend in USD") + alert_threshold_pct: float = Field( + default=80.0, + ge=0, + le=100, + description="Percentage of limit at which to trigger alert (default: 80%)", + ) + + @model_validator(mode="after") + def validate_limits(self) -> "BudgetConfig": + """Ensure daily limit does not exceed monthly limit.""" + if self.daily_limit_usd > self.monthly_limit_usd: + raise ValueError( + f"daily_limit_usd ({self.daily_limit_usd}) cannot exceed " + f"monthly_limit_usd ({self.monthly_limit_usd})" + ) + return self + + class CommunityConfig(BaseModel): """Configuration for a single research community assistant. @@ -730,6 +759,32 @@ def validate_id(cls, v: str) -> str: If not specified, uses default routing for the model. """ + maintainers: list[str] = Field(default_factory=list) + """GitHub usernames of community maintainers. + + Used for: + - @mentioning in automated alert issues (budget alerts, etc.) + - Documenting who is responsible for the community + + Example: + maintainers: + - octocat + - janedoe + """ + + budget: BudgetConfig | None = None + """Budget limits and alert thresholds for cost management. + + When configured, the scheduler checks spend against these limits + and creates GitHub issues when thresholds are exceeded. + + Example: + budget: + daily_limit_usd: 5.0 + monthly_limit_usd: 50.0 + alert_threshold_pct: 80 + """ + @field_validator("cors_origins") @classmethod def validate_cors_origins(cls, v: list[str]) -> list[str]: @@ -756,6 +811,30 @@ def validate_cors_origins(cls, v: list[str]) -> list[str]: validated.append(origin) return validated + @field_validator("maintainers") + @classmethod + def validate_maintainers(cls, v: list[str]) -> list[str]: + """Validate GitHub usernames in maintainers list. 
+ + GitHub usernames: 1-39 chars, alphanumeric or hyphens, + cannot start or end with hyphen. + """ + gh_username_pattern = re.compile(r"^[a-zA-Z0-9]([a-zA-Z0-9-]{0,37}[a-zA-Z0-9])?$") + validated = [] + for username in v: + username = username.strip() + if not username: + continue + if not gh_username_pattern.match(username): + raise ValueError( + f"Invalid GitHub username: '{username}'. " + "Must be 1-39 alphanumeric characters or hyphens, " + "cannot start/end with hyphen." + ) + if username not in validated: + validated.append(username) + return validated + @field_validator("openrouter_api_key_env_var") @classmethod def validate_openrouter_api_key_env_var(cls, v: str | None) -> str | None: diff --git a/src/metrics/alerts.py b/src/metrics/alerts.py new file mode 100644 index 0000000..ee29693 --- /dev/null +++ b/src/metrics/alerts.py @@ -0,0 +1,163 @@ +"""GitHub issue alerting for budget and operational alerts. + +Creates GitHub issues when budget thresholds are exceeded, +with deduplication to avoid spamming. +""" + +import json +import logging +import subprocess + +from src.metrics.budget import BudgetStatus + +logger = logging.getLogger(__name__) + +# GitHub repo for alert issues (org/repo format) +ALERT_REPO = "OpenScience-Collective/osa" + + +def _issue_exists(title: str, repo: str = ALERT_REPO) -> bool | None: + """Check if an open issue with this title already exists. + + Uses gh CLI to search for existing issues to prevent duplicates. + + Returns: + True if a matching issue exists, False if not, None if the check failed. 
+ """ + try: + result = subprocess.run( + [ + "gh", + "issue", + "list", + "--repo", + repo, + "--state", + "open", + "--search", + title, + "--json", + "title", + "--limit", + "5", + ], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + logger.error("gh issue list failed: %s", result.stderr) + return None + + issues = json.loads(result.stdout) + return any(issue.get("title") == title for issue in issues) + except Exception: + logger.exception("Failed to check existing issues") + return None + + +def create_budget_alert_issue( + budget_status: BudgetStatus, + maintainers: list[str], + repo: str = ALERT_REPO, +) -> str | None: + """Create a GitHub issue for a budget alert. + + Includes deduplication: checks for existing open issues with the same + title before creating a new one. + + Args: + budget_status: The budget check result with spend/limit data. + maintainers: GitHub usernames to @mention in the issue body. + repo: GitHub repository in org/repo format. + + Returns: + The issue URL if created, None if skipped (duplicate or error). + """ + # Determine alert type + alert_parts = [] + if budget_status.daily_exceeded: + alert_parts.append("daily limit exceeded") + elif budget_status.daily_alert: + alert_parts.append(f"daily spend at {budget_status.daily_pct:.0f}%") + if budget_status.monthly_exceeded: + alert_parts.append("monthly limit exceeded") + elif budget_status.monthly_alert: + alert_parts.append(f"monthly spend at {budget_status.monthly_pct:.0f}%") + + if not alert_parts: + return None + + alert_type = ", ".join(alert_parts) + title = f"[Budget Alert] {budget_status.community_id}: {alert_type}" + + # Check for existing open issue (three-state: True/False/None) + exists = _issue_exists(title, repo) + if exists is None: + logger.warning( + "Could not verify deduplication for %s; suppressing alert to prevent spam. 
" + "Check gh CLI and GitHub token configuration.", + budget_status.community_id, + ) + return None + if exists: + logger.info( + "Budget alert issue already exists for %s, skipping", budget_status.community_id + ) + return None + + # Build issue body + mentions = ( + " ".join(f"@{m}" for m in maintainers) if maintainers else "No maintainers configured" + ) + + body = f"""## Budget Alert for `{budget_status.community_id}` + +**Alert:** {alert_type} + +### Current Spend + +| Metric | Spend | Limit | Usage | +|--------|-------|-------|-------| +| Daily | ${budget_status.daily_spend_usd:.4f} | ${budget_status.daily_limit_usd:.2f} | {budget_status.daily_pct:.1f}% | +| Monthly | ${budget_status.monthly_spend_usd:.4f} | ${budget_status.monthly_limit_usd:.2f} | {budget_status.monthly_pct:.1f}% | + +### Alert Threshold +Configured at {budget_status.alert_threshold_pct:.0f}% of limits. + +### Maintainers +{mentions} + +--- +*This issue was created automatically by the OSA budget monitoring system.* +""" + + try: + result = subprocess.run( + [ + "gh", + "issue", + "create", + "--repo", + repo, + "--title", + title, + "--body", + body, + "--label", + "cost-management,operations", + ], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + logger.error("Failed to create budget alert issue: %s", result.stderr) + return None + + issue_url = result.stdout.strip() + logger.info("Created budget alert issue: %s", issue_url) + return issue_url + except Exception: + logger.exception("Failed to create budget alert issue for %s", budget_status.community_id) + return None diff --git a/src/metrics/budget.py b/src/metrics/budget.py new file mode 100644 index 0000000..b0d0e9c --- /dev/null +++ b/src/metrics/budget.py @@ -0,0 +1,115 @@ +"""Budget checking for community cost management. + +Queries the metrics database for current spend and compares against +configured budget limits. 
+""" + +import logging +import sqlite3 +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BudgetStatus: + """Result of a budget check for a community.""" + + community_id: str + daily_spend_usd: float + monthly_spend_usd: float + daily_limit_usd: float + monthly_limit_usd: float + alert_threshold_pct: float + + @property + def daily_pct(self) -> float: + """Daily spend as percentage of limit.""" + if self.daily_limit_usd <= 0: + return 0.0 + return (self.daily_spend_usd / self.daily_limit_usd) * 100 + + @property + def monthly_pct(self) -> float: + """Monthly spend as percentage of limit.""" + if self.monthly_limit_usd <= 0: + return 0.0 + return (self.monthly_spend_usd / self.monthly_limit_usd) * 100 + + @property + def daily_exceeded(self) -> bool: + """Whether daily spend has reached or exceeded the daily limit.""" + return self.daily_spend_usd >= self.daily_limit_usd + + @property + def monthly_exceeded(self) -> bool: + """Whether monthly spend has reached or exceeded the monthly limit.""" + return self.monthly_spend_usd >= self.monthly_limit_usd + + @property + def daily_alert(self) -> bool: + """Whether daily spend crossed the alert threshold.""" + return self.daily_pct >= self.alert_threshold_pct + + @property + def monthly_alert(self) -> bool: + """Whether monthly spend crossed the alert threshold.""" + return self.monthly_pct >= self.alert_threshold_pct + + @property + def needs_alert(self) -> bool: + """Whether any alert threshold has been crossed.""" + return self.daily_alert or self.monthly_alert + + +def check_budget( + community_id: str, + daily_limit_usd: float, + monthly_limit_usd: float, + alert_threshold_pct: float, + conn: sqlite3.Connection, +) -> BudgetStatus: + """Check current spend against budget limits. + + Queries estimated_cost from request_log for today and current month. + + Args: + community_id: The community identifier. + daily_limit_usd: Maximum daily spend. 
+ monthly_limit_usd: Maximum monthly spend. + alert_threshold_pct: Alert threshold percentage. + conn: SQLite connection. + + Returns: + BudgetStatus with current spend and limit info. + """ + # Daily spend (today UTC) + daily_row = conn.execute( + """ + SELECT COALESCE(SUM(estimated_cost), 0) as spend + FROM request_log + WHERE community_id = ? + AND date(timestamp) = date('now') + """, + (community_id,), + ).fetchone() + + # Monthly spend (current month UTC) + monthly_row = conn.execute( + """ + SELECT COALESCE(SUM(estimated_cost), 0) as spend + FROM request_log + WHERE community_id = ? + AND strftime('%Y-%m', timestamp) = strftime('%Y-%m', 'now') + """, + (community_id,), + ).fetchone() + + return BudgetStatus( + community_id=community_id, + daily_spend_usd=round(daily_row["spend"], 6), + monthly_spend_usd=round(monthly_row["spend"], 6), + daily_limit_usd=daily_limit_usd, + monthly_limit_usd=monthly_limit_usd, + alert_threshold_pct=alert_threshold_pct, + ) diff --git a/src/metrics/cost.py b/src/metrics/cost.py new file mode 100644 index 0000000..1fd5036 --- /dev/null +++ b/src/metrics/cost.py @@ -0,0 +1,67 @@ +"""Cost estimation for LLM requests. + +Model pricing table with per-token costs (USD per million tokens). +Pricing is from OpenRouter as of 2025-07; update regularly. 
+""" + +import logging + +logger = logging.getLogger(__name__) + +# Pricing: USD per 1M tokens (input, output) +# Source: https://openrouter.ai/models +# Last updated: 2025-07 +MODEL_PRICING: dict[str, tuple[float, float]] = { + # Qwen models + "qwen/qwen3-235b-a22b-2507": (0.14, 0.34), + "qwen/qwen3-30b-a3b-2507": (0.07, 0.15), + # OpenAI models + "openai/gpt-4o": (2.50, 10.00), + "openai/gpt-4o-mini": (0.15, 0.60), + "openai/gpt-oss-120b": (0.00, 0.00), # Free tier + "openai/o1": (15.00, 60.00), + "openai/o3-mini": (1.10, 4.40), + # Anthropic models + "anthropic/claude-opus-4": (15.00, 75.00), + "anthropic/claude-sonnet-4": (3.00, 15.00), + "anthropic/claude-haiku-4.5": (0.80, 4.00), + "anthropic/claude-3.5-sonnet": (3.00, 15.00), + # Google models + "google/gemini-2.5-pro-preview": (1.25, 10.00), + "google/gemini-2.5-flash-preview": (0.15, 0.60), + # DeepSeek models + "deepseek/deepseek-chat-v3": (0.14, 0.28), + "deepseek/deepseek-r1": (0.55, 2.19), + # Meta models + "meta-llama/llama-4-maverick": (0.16, 0.40), +} + +# Fallback rate for models not in the pricing table +_FALLBACK_INPUT_RATE = 1.00 # USD per 1M tokens +_FALLBACK_OUTPUT_RATE = 3.00 # USD per 1M tokens + + +def estimate_cost( + model: str | None, + input_tokens: int, + output_tokens: int, +) -> float: + """Estimate the USD cost for a request. + + Args: + model: Model name in OpenRouter format (e.g., "qwen/qwen3-235b-a22b-2507"). + input_tokens: Number of input tokens. + output_tokens: Number of output tokens. + + Returns: + Estimated cost in USD, rounded to 6 decimal places. 
+ """ + if model and model in MODEL_PRICING: + input_rate, output_rate = MODEL_PRICING[model] + else: + if model: + logger.warning("No pricing data for model %s, using fallback rates", model) + input_rate, output_rate = _FALLBACK_INPUT_RATE, _FALLBACK_OUTPUT_RATE + + cost = (input_tokens * input_rate + output_tokens * output_rate) / 1_000_000 + return round(cost, 6) diff --git a/src/metrics/db.py b/src/metrics/db.py index c72842f..900747c 100644 --- a/src/metrics/db.py +++ b/src/metrics/db.py @@ -34,7 +34,10 @@ estimated_cost REAL, tools_called TEXT, key_source TEXT, - stream INTEGER DEFAULT 0 + stream INTEGER DEFAULT 0, + tool_call_count INTEGER DEFAULT 0, + error_message TEXT, + langfuse_trace_id TEXT ); CREATE INDEX IF NOT EXISTS idx_request_log_community @@ -45,6 +48,13 @@ ON request_log(community_id, timestamp); """ +# Columns added after initial schema; ALTER TABLE for existing databases +_MIGRATION_COLUMNS = [ + ("tool_call_count", "INTEGER DEFAULT 0"), + ("error_message", "TEXT"), + ("langfuse_trace_id", "TEXT"), +] + @dataclass class RequestLogEntry: @@ -65,6 +75,9 @@ class RequestLogEntry: tools_called: list[str] = field(default_factory=list) key_source: str | None = None stream: bool = False + tool_call_count: int = 0 + error_message: str | None = None + langfuse_trace_id: str | None = None def get_metrics_db_path() -> Path: @@ -111,11 +124,31 @@ def metrics_connection(db_path: Path | None = None) -> Generator[sqlite3.Connect conn.close() +def _migrate_columns(conn: sqlite3.Connection) -> None: + """Add new columns to existing databases (backward-compatible migration). + + Attempts ALTER TABLE ADD COLUMN for each new column. If the column + already exists, SQLite raises OperationalError with 'duplicate column', + which we catch and ignore, making this function idempotent. 
+ """ + for col_name, col_def in _MIGRATION_COLUMNS: + try: + conn.execute(f"ALTER TABLE request_log ADD COLUMN {col_name} {col_def}") + logger.info("Added column %s to request_log", col_name) + except sqlite3.OperationalError as e: + if "duplicate column" in str(e).lower(): + pass # Expected on subsequent runs + else: + logger.error("Failed to add column %s to request_log: %s", col_name, e) + raise + + def init_metrics_db(db_path: Path | None = None) -> None: """Initialize the metrics database schema. Idempotent. Creates the request_log table and indexes if they don't exist. Enables WAL mode for concurrent read/write access. + Runs migrations to add new columns to existing databases. Args: db_path: Optional path override (for testing). @@ -124,6 +157,7 @@ def init_metrics_db(db_path: Path | None = None) -> None: try: conn.execute("PRAGMA journal_mode=WAL") conn.executescript(METRICS_SCHEMA_SQL) + _migrate_columns(conn) conn.commit() logger.info("Metrics database initialized at %s", db_path or get_metrics_db_path()) finally: @@ -144,8 +178,9 @@ def log_request(entry: RequestLogEntry, db_path: Path | None = None) -> None: INSERT INTO request_log ( request_id, timestamp, community_id, endpoint, method, duration_ms, status_code, model, input_tokens, output_tokens, - total_tokens, estimated_cost, tools_called, key_source, stream - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + total_tokens, estimated_cost, tools_called, key_source, stream, + tool_call_count, error_message, langfuse_trace_id + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( entry.request_id, @@ -163,6 +198,9 @@ def log_request(entry: RequestLogEntry, db_path: Path | None = None) -> None: json.dumps(entry.tools_called) if entry.tools_called else None, entry.key_source, 1 if entry.stream else 0, + entry.tool_call_count, + entry.error_message, + entry.langfuse_trace_id, ), ) conn.commit() diff --git a/src/metrics/middleware.py b/src/metrics/middleware.py index 7b94d46..92fbf17 100644 --- a/src/metrics/middleware.py +++ b/src/metrics/middleware.py @@ -72,6 +72,9 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - "tools_called": agent_data.get("tools_called", []), "key_source": agent_data.get("key_source"), "stream": agent_data.get("stream", False), + "tool_call_count": agent_data.get("tool_call_count", 0), + "error_message": agent_data.get("error_message"), + "langfuse_trace_id": agent_data.get("langfuse_trace_id"), } entry = RequestLogEntry( diff --git a/src/metrics/queries.py b/src/metrics/queries.py index 6281639..aa5073e 100644 --- a/src/metrics/queries.py +++ b/src/metrics/queries.py @@ -443,3 +443,173 @@ def get_public_usage_stats( for r in rows ], } + + +# --------------------------------------------------------------------------- +# Quality metrics queries +# --------------------------------------------------------------------------- + + +def get_quality_metrics( + community_id: str, + conn: sqlite3.Connection, + period: str = "daily", +) -> dict[str, Any]: + """Get quality metrics for a community, bucketed by time period. + + Returns error rates, average tool call counts, and latency percentiles + (p50, p95) per time bucket. + + Args: + community_id: The community identifier. + conn: SQLite connection. + period: One of "daily", "weekly", "monthly". + + Returns: + Dict with period, community_id, and buckets with quality data. 
+ """ + fmt = _validate_period(period) + + rows = conn.execute( + f""" + SELECT + strftime('{fmt}', timestamp) as bucket, + COUNT(*) as requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as errors, + COALESCE(AVG(tool_call_count), 0) as avg_tool_calls, + COUNT(CASE WHEN error_message IS NOT NULL THEN 1 END) as agent_errors, + COUNT(CASE WHEN langfuse_trace_id IS NOT NULL THEN 1 END) as traced_requests + FROM request_log + WHERE community_id = ? + GROUP BY bucket + ORDER BY bucket + """, + (community_id,), + ).fetchall() + + # Compute latency percentiles per bucket + buckets = [] + for r in rows: + bucket_name = r["bucket"] + total = r["requests"] + error_rate = r["errors"] / total if total > 0 else 0.0 + + # Latency percentiles for this bucket + p50, p95 = _fetch_latency_percentiles( + conn, + community_id, + extra_where=f"AND strftime('{fmt}', timestamp) = ?", + extra_params=(bucket_name,), + ) + + buckets.append( + { + "bucket": bucket_name, + "requests": total, + "error_rate": round(error_rate, 4), + "avg_tool_calls": round(r["avg_tool_calls"], 2), + "agent_errors": r["agent_errors"], + "traced_requests": r["traced_requests"], + "p50_duration_ms": round(p50, 1) if p50 is not None else None, + "p95_duration_ms": round(p95, 1) if p95 is not None else None, + } + ) + + return { + "community_id": community_id, + "period": period, + "buckets": buckets, + } + + +def get_quality_summary(community_id: str, conn: sqlite3.Connection) -> dict[str, Any]: + """Get overall quality overview for a community. + + Args: + community_id: The community identifier. + conn: SQLite connection. + + Returns: + Dict with overall quality stats: error_rate, avg_tool_calls, + agent_errors, p50/p95 latency, traced percentage. 
+ """ + row = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_count, + COALESCE(AVG(tool_call_count), 0) as avg_tool_calls, + COUNT(CASE WHEN error_message IS NOT NULL THEN 1 END) as agent_errors, + COUNT(CASE WHEN langfuse_trace_id IS NOT NULL THEN 1 END) as traced + FROM request_log + WHERE community_id = ? + """, + (community_id,), + ).fetchone() + + total = row["total_requests"] + error_rate = row["error_count"] / total if total > 0 else 0.0 + traced_pct = row["traced"] / total if total > 0 else 0.0 + + # Latency percentiles + p50, p95 = _fetch_latency_percentiles(conn, community_id) + + return { + "community_id": community_id, + "total_requests": total, + "error_rate": round(error_rate, 4), + "avg_tool_calls": round(row["avg_tool_calls"], 2), + "agent_errors": row["agent_errors"], + "traced_pct": round(traced_pct, 4), + "p50_duration_ms": round(p50, 1) if p50 is not None else None, + "p95_duration_ms": round(p95, 1) if p95 is not None else None, + } + + +def _percentile(sorted_values: list[float], pct: float) -> float | None: + """Compute percentile from a sorted list of values. + + Uses floor-rank method: index = floor(pct * len), clamped to valid range. + + Args: + sorted_values: Pre-sorted list of numeric values. + pct: Percentile as fraction (e.g., 0.5 for p50, 0.95 for p95). + + Returns: + The percentile value, or None if list is empty. + """ + if not sorted_values: + return None + idx = int(pct * len(sorted_values)) + idx = min(idx, len(sorted_values) - 1) + return sorted_values[idx] + + +def _fetch_latency_percentiles( + conn: sqlite3.Connection, + community_id: str, + extra_where: str = "", + extra_params: tuple[str, ...] = (), +) -> tuple[float | None, float | None]: + """Fetch p50 and p95 latency for a community with optional filters. + + Args: + conn: SQLite connection with row_factory set. + community_id: The community identifier. 
+ extra_where: Additional SQL WHERE clause (e.g., "AND strftime(...) = ?"). + extra_params: Parameters for the extra WHERE clause. + + Returns: + Tuple of (p50, p95) duration in ms, or None if no data. + """ + durations = conn.execute( + f""" + SELECT duration_ms + FROM request_log + WHERE community_id = ? AND duration_ms IS NOT NULL {extra_where} + ORDER BY duration_ms + """, + (community_id, *extra_params), + ).fetchall() + values = [d["duration_ms"] for d in durations] + return _percentile(values, 0.5), _percentile(values, 0.95) diff --git a/tests/test_api/test_metrics_endpoints.py b/tests/test_api/test_metrics_endpoints.py index 685dd24..9dffc60 100644 --- a/tests/test_api/test_metrics_endpoints.py +++ b/tests/test_api/test_metrics_endpoints.py @@ -14,6 +14,7 @@ ) ADMIN_KEY = "test-metrics-admin-key" +COMMUNITY_KEY = "hed-community-key" @pytest.fixture @@ -38,6 +39,8 @@ def metrics_db(tmp_path): estimated_cost=0.001, tools_called=["search_docs"], key_source="platform", + tool_call_count=1, + langfuse_trace_id="trace-001", ), RequestLogEntry( request_id="r2", @@ -52,6 +55,19 @@ def metrics_db(tmp_path): output_tokens=100, total_tokens=300, key_source="byok", + tool_call_count=0, + ), + RequestLogEntry( + request_id="r3", + timestamp="2025-01-15T12:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=500.0, + status_code=500, + model="qwen/qwen3-235b", + error_message="LLM timeout", + tool_call_count=0, ), ] for e in entries: @@ -83,6 +99,24 @@ def auth_env(): get_settings.cache_clear() +@pytest.fixture +def scoped_auth_env(): + """Set up environment with both admin and community keys.""" + from src.api.config import get_settings + + os.environ["API_KEYS"] = ADMIN_KEY + os.environ["REQUIRE_API_AUTH"] = "true" + os.environ["COMMUNITY_ADMIN_KEYS"] = f"hed:{COMMUNITY_KEY}" + get_settings.cache_clear() + + yield + + del os.environ["API_KEYS"] + del os.environ["REQUIRE_API_AUTH"] + del os.environ["COMMUNITY_ADMIN_KEYS"] + 
get_settings.cache_clear() + + @pytest.fixture def client(): """Create test client.""" @@ -209,3 +243,178 @@ def test_invalid_period_returns_422(self, client): def test_requires_admin_auth(self, client): response = client.get("/hed/metrics/usage", params={"period": "daily"}) assert response.status_code == 401 + + +class TestCommunityQuality: + """Tests for GET /{community_id}/metrics/quality.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200(self, client): + response = client.get("/hed/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/hed/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert data["community_id"] == "hed" + assert data["period"] == "daily" + assert "buckets" in data + assert len(data["buckets"]) > 0 + bucket = data["buckets"][0] + assert "requests" in bucket + assert "error_rate" in bucket + assert "avg_tool_calls" in bucket + assert "agent_errors" in bucket + assert "traced_requests" in bucket + assert "p50_duration_ms" in bucket + assert "p95_duration_ms" in bucket + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_weekly_period(self, client): + response = client.get( + "/hed/metrics/quality", + params={"period": "weekly"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 200 + assert response.json()["period"] == "weekly" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_invalid_period_returns_422(self, client): + response = client.get( + "/hed/metrics/quality", + params={"period": "hourly"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 422 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_auth(self, client): + response = client.get("/hed/metrics/quality") + assert response.status_code == 
401 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_quality_data_reflects_errors(self, client): + """Verify error data from fixture is reflected in quality metrics.""" + response = client.get("/hed/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + # All 3 entries are on 2025-01-15, so one daily bucket + assert len(data["buckets"]) == 1 + bucket = data["buckets"][0] + assert bucket["requests"] == 3 + # 1 out of 3 requests has status_code >= 400 + assert bucket["error_rate"] > 0 + assert bucket["agent_errors"] == 1 + assert bucket["traced_requests"] == 1 + + +class TestCommunityQualitySummary: + """Tests for GET /{community_id}/metrics/quality/summary.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200(self, client): + response = client.get("/hed/metrics/quality/summary", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/hed/metrics/quality/summary", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert data["community_id"] == "hed" + assert "total_requests" in data + assert "error_rate" in data + assert "avg_tool_calls" in data + assert "agent_errors" in data + assert "traced_pct" in data + assert "p50_duration_ms" in data + assert "p95_duration_ms" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_summary_values(self, client): + """Verify summary aggregates match fixture data.""" + response = client.get("/hed/metrics/quality/summary", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert data["total_requests"] == 3 + assert data["agent_errors"] == 1 + # 1 traced out of 3 + assert 0 < data["traced_pct"] < 1 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_auth(self, client): + response = client.get("/hed/metrics/quality/summary") + assert 
response.status_code == 401 + + +class TestGlobalQuality: + """Tests for GET /metrics/quality.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200_for_admin(self, client): + response = client.get("/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_admin_sees_communities_list(self, client): + response = client.get("/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert "communities" in data + assert len(data["communities"]) > 0 + community = data["communities"][0] + assert community["community_id"] == "hed" + assert "error_rate" in community + assert "avg_tool_calls" in community + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_auth(self, client): + response = client.get("/metrics/quality") + assert response.status_code == 401 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_403_with_invalid_key(self, client): + response = client.get("/metrics/quality", headers={"X-API-Key": "wrong-key"}) + assert response.status_code == 403 + + +class TestScopedKeyOnGlobalEndpoints: + """Tests for community-scoped keys on global metrics endpoints.""" + + @pytest.mark.usefixtures("isolated_metrics", "scoped_auth_env") + def test_community_key_on_overview_returns_scoped_data(self, client): + """Community key on /metrics/overview should return only their community summary.""" + response = client.get("/metrics/overview", headers={"X-API-Key": COMMUNITY_KEY}) + assert response.status_code == 200 + data = response.json() + # Scoped key gets community summary, not the full overview with communities list + assert data["community_id"] == "hed" + + @pytest.mark.usefixtures("isolated_metrics", "scoped_auth_env") + def test_community_key_on_tokens_is_auto_scoped(self, client): + """Community key on /metrics/tokens should be auto-scoped to own community.""" 
+ response = client.get( + "/metrics/tokens", + params={"community_id": "eeglab"}, # Try to view another community + headers={"X-API-Key": COMMUNITY_KEY}, + ) + assert response.status_code == 200 + data = response.json() + # Should be forced to own community, ignoring the query param + assert data["community_id"] == "hed" + + @pytest.mark.usefixtures("isolated_metrics", "scoped_auth_env") + def test_community_key_on_quality_returns_own_summary(self, client): + """Community key on /metrics/quality should return only own community quality.""" + response = client.get("/metrics/quality", headers={"X-API-Key": COMMUNITY_KEY}) + assert response.status_code == 200 + data = response.json() + # Scoped key gets single community summary, not communities list + assert data["community_id"] == "hed" + assert "error_rate" in data + + @pytest.mark.usefixtures("isolated_metrics", "scoped_auth_env") + def test_admin_key_still_sees_all(self, client): + """Admin key should still see full data on all global endpoints.""" + response = client.get("/metrics/overview", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + data = response.json() + assert "communities" in data diff --git a/tests/test_api/test_scoped_auth.py b/tests/test_api/test_scoped_auth.py new file mode 100644 index 0000000..e1c7ce7 --- /dev/null +++ b/tests/test_api/test_scoped_auth.py @@ -0,0 +1,200 @@ +"""Tests for per-community scoped authentication. + +Tests AuthScope, verify_scoped_admin_key, and per-community metrics access control. +Uses real HTTP requests against FastAPI test apps. 
+""" + +import os + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from src.api.config import Settings, get_settings +from src.api.security import AuthScope, RequireScopedAuth + + +@pytest.fixture +def app_scoped_auth() -> FastAPI: + """Create a test app with scoped auth endpoints.""" + os.environ["API_KEYS"] = "global-admin-key" + os.environ["REQUIRE_API_AUTH"] = "true" + os.environ["COMMUNITY_ADMIN_KEYS"] = "hed:hed-key-1,eeglab:eeglab-key-1,hed:hed-key-2" + + get_settings.cache_clear() + + app = FastAPI() + + @app.get("/scoped") + async def scoped_route(auth: RequireScopedAuth) -> dict: + return { + "role": auth.role, + "community_id": auth.community_id, + } + + @app.get("/metrics/{community_id}") + async def community_metrics(community_id: str, auth: RequireScopedAuth) -> dict: + if not auth.can_access_community(community_id): + from fastapi import HTTPException + + raise HTTPException(status_code=403, detail="Access denied") + return {"community_id": community_id, "role": auth.role} + + yield app + + for key in ["API_KEYS", "REQUIRE_API_AUTH", "COMMUNITY_ADMIN_KEYS"]: + os.environ.pop(key, None) + get_settings.cache_clear() + + +@pytest.fixture +def client(app_scoped_auth: FastAPI) -> TestClient: + return TestClient(app_scoped_auth) + + +class TestAuthScope: + """Tests for AuthScope dataclass.""" + + def test_admin_can_access_any_community(self): + scope = AuthScope(role="admin") + assert scope.can_access_community("hed") is True + assert scope.can_access_community("eeglab") is True + assert scope.can_access_community("anything") is True + + def test_community_scope_can_access_own(self): + scope = AuthScope(role="community", community_id="hed") + assert scope.can_access_community("hed") is True + + def test_community_scope_cannot_access_other(self): + scope = AuthScope(role="community", community_id="hed") + assert scope.can_access_community("eeglab") is False + + def test_community_role_requires_community_id(self): + 
with pytest.raises(ValueError, match="community role requires a community_id"): + AuthScope(role="community") + + def test_admin_role_rejects_community_id(self): + with pytest.raises(ValueError, match="admin role must not have a community_id"): + AuthScope(role="admin", community_id="hed") + + +class TestVerifyScopedAdminKey: + """Tests for verify_scoped_admin_key dependency.""" + + def test_global_admin_key_returns_admin_role(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "global-admin-key"}) + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "admin" + assert data["community_id"] is None + + def test_community_key_returns_community_role(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "hed-key-1"}) + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "community" + assert data["community_id"] == "hed" + + def test_second_community_key_works(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "hed-key-2"}) + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "community" + assert data["community_id"] == "hed" + + def test_eeglab_community_key(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "eeglab-key-1"}) + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "community" + assert data["community_id"] == "eeglab" + + def test_no_key_returns_401(self, client: TestClient): + resp = client.get("/scoped") + assert resp.status_code == 401 + + def test_invalid_key_returns_403(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "wrong-key"}) + assert resp.status_code == 403 + + +class TestScopedCommunityAccess: + """Tests for per-community access control via scoped auth.""" + + def test_admin_can_access_any_community_metrics(self, client: TestClient): + resp = client.get("/metrics/hed", headers={"X-API-Key": 
"global-admin-key"}) + assert resp.status_code == 200 + assert resp.json()["community_id"] == "hed" + + resp = client.get("/metrics/eeglab", headers={"X-API-Key": "global-admin-key"}) + assert resp.status_code == 200 + assert resp.json()["community_id"] == "eeglab" + + def test_community_key_can_access_own_metrics(self, client: TestClient): + resp = client.get("/metrics/hed", headers={"X-API-Key": "hed-key-1"}) + assert resp.status_code == 200 + assert resp.json()["community_id"] == "hed" + + def test_community_key_cannot_access_other_metrics(self, client: TestClient): + resp = client.get("/metrics/eeglab", headers={"X-API-Key": "hed-key-1"}) + assert resp.status_code == 403 + + +class TestScopedAuthDisabled: + """Tests for scoped auth when authentication is disabled.""" + + def test_no_auth_required_returns_admin_scope(self): + os.environ["REQUIRE_API_AUTH"] = "false" + os.environ.pop("API_KEYS", None) + os.environ.pop("COMMUNITY_ADMIN_KEYS", None) + get_settings.cache_clear() + + app = FastAPI() + + @app.get("/scoped") + async def scoped_route(auth: RequireScopedAuth) -> dict: + return {"role": auth.role, "community_id": auth.community_id} + + client = TestClient(app) + resp = client.get("/scoped") + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "admin" + assert data["community_id"] is None + + os.environ.pop("REQUIRE_API_AUTH", None) + get_settings.cache_clear() + + +class TestParseAdminKeys: + """Tests for Settings.parse_community_admin_keys().""" + + def test_parse_single_key(self): + s = Settings(community_admin_keys="hed:abc123") + result = s.parse_community_admin_keys() + assert result == {"hed": {"abc123"}} + + def test_parse_multiple_communities(self): + s = Settings(community_admin_keys="hed:key1,eeglab:key2") + result = s.parse_community_admin_keys() + assert result == {"hed": {"key1"}, "eeglab": {"key2"}} + + def test_parse_multiple_keys_same_community(self): + s = Settings(community_admin_keys="hed:key1,hed:key2") + 
result = s.parse_community_admin_keys() + assert result == {"hed": {"key1", "key2"}} + + def test_parse_empty(self): + s = Settings(community_admin_keys=None) + result = s.parse_community_admin_keys() + assert result == {} + + def test_parse_with_spaces(self): + s = Settings(community_admin_keys=" hed : key1 , eeglab : key2 ") + result = s.parse_community_admin_keys() + assert result == {"hed": {"key1"}, "eeglab": {"key2"}} + + def test_parse_ignores_malformed_entries(self): + s = Settings(community_admin_keys="hed:key1,badentry,eeglab:key2") + result = s.parse_community_admin_keys() + assert result == {"hed": {"key1"}, "eeglab": {"key2"}} diff --git a/tests/test_core/test_config/test_community.py b/tests/test_core/test_config/test_community.py index 4659982..27dcbdb 100644 --- a/tests/test_core/test_config/test_community.py +++ b/tests/test_core/test_config/test_community.py @@ -13,6 +13,7 @@ from pydantic import ValidationError from src.core.config.community import ( + BudgetConfig, CitationConfig, CommunitiesConfig, CommunityConfig, @@ -320,6 +321,116 @@ def test_allows_unique_extensions(self) -> None: assert len(config.mcp_servers) == 2 +class TestBudgetConfig: + """Tests for BudgetConfig model.""" + + def test_valid_budget_config(self) -> None: + """Should create BudgetConfig with valid inputs.""" + config = BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert config.daily_limit_usd == 5.0 + assert config.monthly_limit_usd == 50.0 + assert config.alert_threshold_pct == 80.0 + + def test_default_alert_threshold(self) -> None: + """Should default alert_threshold_pct to 80.0.""" + config = BudgetConfig(daily_limit_usd=5.0, monthly_limit_usd=50.0) + assert config.alert_threshold_pct == 80.0 + + def test_rejects_zero_daily_limit(self) -> None: + """Should reject zero daily limit.""" + with pytest.raises(ValidationError): + BudgetConfig(daily_limit_usd=0.0, monthly_limit_usd=50.0) + + def 
test_rejects_negative_daily_limit(self) -> None: + """Should reject negative daily limit.""" + with pytest.raises(ValidationError): + BudgetConfig(daily_limit_usd=-1.0, monthly_limit_usd=50.0) + + def test_rejects_zero_monthly_limit(self) -> None: + """Should reject zero monthly limit.""" + with pytest.raises(ValidationError): + BudgetConfig(daily_limit_usd=5.0, monthly_limit_usd=0.0) + + def test_rejects_negative_threshold(self) -> None: + """Should reject negative alert threshold.""" + with pytest.raises(ValidationError): + BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=-1.0, + ) + + def test_rejects_threshold_over_100(self) -> None: + """Should reject alert threshold over 100.""" + with pytest.raises(ValidationError): + BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=101.0, + ) + + def test_rejects_extra_fields(self) -> None: + """Should reject extra fields (strict schema).""" + with pytest.raises(ValidationError): + BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + unknown_field="value", # type: ignore + ) + + def test_rejects_daily_exceeding_monthly(self) -> None: + """Should reject daily limit greater than monthly limit.""" + with pytest.raises(ValidationError, match="cannot exceed"): + BudgetConfig(daily_limit_usd=100.0, monthly_limit_usd=50.0) + + def test_accepts_equal_daily_and_monthly(self) -> None: + """Should accept daily limit equal to monthly limit.""" + config = BudgetConfig(daily_limit_usd=50.0, monthly_limit_usd=50.0) + assert config.daily_limit_usd == config.monthly_limit_usd + + +class TestCommunityConfigBudget: + """Tests for CommunityConfig.budget field.""" + + def test_budget_none_by_default(self) -> None: + """Should default budget to None.""" + config = CommunityConfig(id="test", name="Test", description="Test") + assert config.budget is None + + def test_budget_config_set(self) -> None: + """Should accept valid budget config.""" + config = 
CommunityConfig( + id="test", + name="Test", + description="Test", + budget=BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + ), + ) + assert config.budget is not None + assert config.budget.daily_limit_usd == 5.0 + + def test_budget_from_yaml_dict(self) -> None: + """Should parse budget from YAML-like dict.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + budget={ + "daily_limit_usd": 10.0, + "monthly_limit_usd": 100.0, + "alert_threshold_pct": 90.0, + }, + ) + assert config.budget.daily_limit_usd == 10.0 + assert config.budget.alert_threshold_pct == 90.0 + + class TestCommunityConfig: """Tests for CommunityConfig model.""" @@ -584,6 +695,85 @@ def test_accepts_numeric_only_domain(self) -> None: assert len(config.cors_origins) == 1 +class TestCommunityConfigMaintainers: + """Tests for CommunityConfig.maintainers validation.""" + + def test_valid_maintainers(self) -> None: + """Should accept valid GitHub usernames.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["octocat", "jane-doe", "user123"], + ) + assert config.maintainers == ["octocat", "jane-doe", "user123"] + + def test_defaults_to_empty(self) -> None: + """Should default to empty list.""" + config = CommunityConfig(id="test", name="Test", description="Test") + assert config.maintainers == [] + + def test_rejects_invalid_username_with_special_chars(self) -> None: + """Should reject usernames with special characters.""" + with pytest.raises(ValidationError, match="Invalid GitHub username"): + CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["bad@user"], + ) + + def test_rejects_username_starting_with_hyphen(self) -> None: + """Should reject usernames starting with hyphen.""" + with pytest.raises(ValidationError, match="Invalid GitHub username"): + CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["-badstart"], + ) + + def 
test_rejects_username_ending_with_hyphen(self) -> None: + """Should reject usernames ending with hyphen.""" + with pytest.raises(ValidationError, match="Invalid GitHub username"): + CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["badend-"], + ) + + def test_deduplicates_maintainers(self) -> None: + """Should remove duplicate usernames.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["octocat", "octocat", "jane"], + ) + assert config.maintainers == ["octocat", "jane"] + + def test_strips_whitespace(self) -> None: + """Should strip whitespace from usernames.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=[" octocat ", " jane "], + ) + assert config.maintainers == ["octocat", "jane"] + + def test_single_char_username(self) -> None: + """Should accept single-character usernames.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["a"], + ) + assert config.maintainers == ["a"] + + class TestCommunitiesConfig: """Tests for CommunitiesConfig model.""" diff --git a/tests/test_metrics/test_alerts.py b/tests/test_metrics/test_alerts.py new file mode 100644 index 0000000..464b560 --- /dev/null +++ b/tests/test_metrics/test_alerts.py @@ -0,0 +1,178 @@ +"""Tests for budget alert issue creation.""" + +from unittest.mock import patch + +from src.metrics.alerts import create_budget_alert_issue +from src.metrics.budget import BudgetStatus + + +def _make_budget_status( + community_id: str = "hed", + daily_spend: float = 4.5, + monthly_spend: float = 45.0, + daily_limit: float = 5.0, + monthly_limit: float = 50.0, + alert_pct: float = 80.0, +) -> BudgetStatus: + return BudgetStatus( + community_id=community_id, + daily_spend_usd=daily_spend, + monthly_spend_usd=monthly_spend, + daily_limit_usd=daily_limit, + monthly_limit_usd=monthly_limit, + alert_threshold_pct=alert_pct, + ) + + +class 
TestCreateBudgetAlertIssue: + """Tests for create_budget_alert_issue().""" + + def test_no_alert_when_no_threshold_crossed(self): + """Returns None when no alert condition is met.""" + status = _make_budget_status(daily_spend=1.0, monthly_spend=10.0) + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result is None + + @patch("src.metrics.alerts._issue_exists", return_value=True) + @patch("src.metrics.alerts.subprocess") + def test_skips_duplicate_issue(self, mock_subprocess, _mock_exists): + """Returns None when issue already exists.""" + status = _make_budget_status() + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result is None + mock_subprocess.run.assert_not_called() + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_creates_issue_for_daily_alert(self, mock_subprocess, _mock_exists): + """Creates issue when daily alert threshold crossed.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/1\n", "stderr": ""}, + )() + + status = _make_budget_status(daily_spend=4.5, monthly_spend=10.0) + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result == "https://github.com/test/issues/1" + + # Verify subprocess.run was called with gh issue create + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + assert cmd[0] == "gh" + assert cmd[1] == "issue" + assert cmd[2] == "create" + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_creates_issue_for_monthly_exceeded(self, mock_subprocess, _mock_exists): + """Creates issue when monthly limit exceeded.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/2\n", "stderr": ""}, + )() + + status = _make_budget_status(daily_spend=1.0, monthly_spend=50.0) + 
result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result == "https://github.com/test/issues/2" + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_issue_body_contains_maintainer_mentions(self, mock_subprocess, _mock_exists): + """Issue body includes @mentions for maintainers.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/3\n", "stderr": ""}, + )() + + status = _make_budget_status() + create_budget_alert_issue(status, maintainers=["VisLab", "yarikoptic"]) + + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + # Find the --body argument + body_idx = cmd.index("--body") + 1 + body = cmd[body_idx] + assert "@VisLab" in body + assert "@yarikoptic" in body + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_issue_title_format(self, mock_subprocess, _mock_exists): + """Issue title follows expected format.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/4\n", "stderr": ""}, + )() + + status = _make_budget_status(daily_spend=5.0) + create_budget_alert_issue(status, maintainers=[]) + + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + title_idx = cmd.index("--title") + 1 + title = cmd[title_idx] + assert title.startswith("[Budget Alert] hed:") + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_issue_has_labels(self, mock_subprocess, _mock_exists): + """Issue is created with cost-management and operations labels.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/5\n", "stderr": ""}, + )() + + status = _make_budget_status() + create_budget_alert_issue(status, maintainers=[]) + + 
call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + label_idx = cmd.index("--label") + 1 + labels = cmd[label_idx] + assert "cost-management" in labels + assert "operations" in labels + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_returns_none_on_gh_failure(self, mock_subprocess, _mock_exists): + """Returns None when gh CLI fails.""" + mock_subprocess.run.return_value = type( + "Result", (), {"returncode": 1, "stdout": "", "stderr": "auth required"} + )() + + status = _make_budget_status() + result = create_budget_alert_issue(status, maintainers=[]) + assert result is None + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_no_maintainers_message(self, mock_subprocess, _mock_exists): + """Body shows 'No maintainers configured' when list is empty.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/6\n", "stderr": ""}, + )() + + status = _make_budget_status() + create_budget_alert_issue(status, maintainers=[]) + + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + body_idx = cmd.index("--body") + 1 + body = cmd[body_idx] + assert "No maintainers configured" in body + + @patch("src.metrics.alerts._issue_exists", return_value=None) + @patch("src.metrics.alerts.subprocess") + def test_suppresses_alert_when_dedup_check_fails(self, mock_subprocess, _mock_exists): + """Returns None and does not create issue when dedup check fails.""" + status = _make_budget_status() + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result is None + mock_subprocess.run.assert_not_called() diff --git a/tests/test_metrics/test_budget.py b/tests/test_metrics/test_budget.py new file mode 100644 index 0000000..dad462c --- /dev/null +++ b/tests/test_metrics/test_budget.py @@ -0,0 +1,243 @@ +"""Tests for budget checking.""" + 
+import pytest + +from src.metrics.budget import BudgetStatus, check_budget +from src.metrics.db import RequestLogEntry, get_metrics_connection, init_metrics_db, log_request + + +@pytest.fixture +def budget_db(tmp_path): + """Create a metrics DB with cost data for budget testing.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="b1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + estimated_cost=0.50, + ), + RequestLogEntry( + request_id="b2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + estimated_cost=0.30, + ), + RequestLogEntry( + request_id="b3", + timestamp="2025-01-14T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=250.0, + status_code=200, + estimated_cost=2.00, + ), + RequestLogEntry( + request_id="b4", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/eeglab/ask", + method="POST", + community_id="eeglab", + duration_ms=150.0, + status_code=200, + estimated_cost=0.10, + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +class TestBudgetStatus: + """Tests for BudgetStatus dataclass.""" + + def test_daily_pct(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=4.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_pct == 80.0 + + def test_monthly_pct(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=0.0, + monthly_spend_usd=40.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.monthly_pct == 80.0 + + def test_zero_limit_returns_zero_pct(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=1.0, + monthly_spend_usd=1.0, + 
daily_limit_usd=0.0, + monthly_limit_usd=0.0, + alert_threshold_pct=80.0, + ) + assert status.daily_pct == 0.0 + assert status.monthly_pct == 0.0 + + def test_daily_exceeded(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=5.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_exceeded is True + + def test_daily_not_exceeded(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=4.99, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_exceeded is False + + def test_monthly_exceeded(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=0.0, + monthly_spend_usd=50.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.monthly_exceeded is True + + def test_daily_alert_at_threshold(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=4.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_alert is True + + def test_daily_alert_below_threshold(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=3.99, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_alert is False + + def test_needs_alert_daily(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=4.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.needs_alert is True + + def test_needs_alert_monthly(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=0.0, + monthly_spend_usd=40.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.needs_alert is True + + def test_no_alert_needed(self): + status = BudgetStatus( + 
community_id="test", + daily_spend_usd=1.0, + monthly_spend_usd=10.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.needs_alert is False + + +class TestCheckBudget: + """Tests for check_budget() function. + + Note: check_budget queries today's date with date('now'), so the test + data from 2025-01-15 won't appear as "today". We test that the query + runs without error and returns the expected structure. + """ + + def test_returns_budget_status(self, budget_db): + conn = get_metrics_connection(budget_db) + try: + status = check_budget( + community_id="hed", + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + conn=conn, + ) + assert isinstance(status, BudgetStatus) + assert status.community_id == "hed" + assert status.daily_limit_usd == 5.0 + assert status.monthly_limit_usd == 50.0 + assert status.alert_threshold_pct == 80.0 + finally: + conn.close() + + def test_empty_community_zero_spend(self, budget_db): + conn = get_metrics_connection(budget_db) + try: + status = check_budget( + community_id="nonexistent", + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + conn=conn, + ) + assert status.daily_spend_usd == 0.0 + assert status.monthly_spend_usd == 0.0 + finally: + conn.close() + + def test_spend_values_are_non_negative(self, budget_db): + conn = get_metrics_connection(budget_db) + try: + status = check_budget( + community_id="hed", + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + conn=conn, + ) + assert status.daily_spend_usd >= 0.0 + assert status.monthly_spend_usd >= 0.0 + finally: + conn.close() diff --git a/tests/test_metrics/test_cost.py b/tests/test_metrics/test_cost.py new file mode 100644 index 0000000..6286fd0 --- /dev/null +++ b/tests/test_metrics/test_cost.py @@ -0,0 +1,64 @@ +"""Tests for cost estimation.""" + +from src.metrics.cost import MODEL_PRICING, estimate_cost + + +class TestEstimateCost: + """Tests for 
estimate_cost().""" + + def test_known_model(self): + """Cost for a known model uses its pricing.""" + cost = estimate_cost("openai/gpt-4o", input_tokens=1000, output_tokens=500) + # input: 1000 * 2.50 / 1M = 0.0025, output: 500 * 10.00 / 1M = 0.005 + assert cost == 0.0075 + + def test_unknown_model_uses_fallback(self): + """Unknown model uses fallback rates.""" + cost = estimate_cost("unknown/model", input_tokens=1_000_000, output_tokens=1_000_000) + # fallback: 1.00 input + 3.00 output = 4.00 + assert cost == 4.0 + + def test_none_model_uses_fallback(self): + """None model uses fallback rates.""" + cost = estimate_cost(None, input_tokens=1_000_000, output_tokens=0) + assert cost == 1.0 + + def test_zero_tokens(self): + """Zero tokens should return zero cost.""" + cost = estimate_cost("openai/gpt-4o", input_tokens=0, output_tokens=0) + assert cost == 0.0 + + def test_rounding(self): + """Cost should be rounded to 6 decimal places.""" + cost = estimate_cost("openai/gpt-4o-mini", input_tokens=1, output_tokens=1) + # input: 1 * 0.15 / 1M = 0.00000015, output: 1 * 0.60 / 1M = 0.0000006 + # total: 0.00000075 -> rounds to 0.000001 + assert cost == 0.000001 + + def test_all_models_have_pricing(self): + """All models in the pricing table should have valid (input, output) tuples.""" + for model, (input_rate, output_rate) in MODEL_PRICING.items(): + assert isinstance(input_rate, (int, float)), f"{model} has invalid input rate" + assert isinstance(output_rate, (int, float)), f"{model} has invalid output rate" + assert input_rate >= 0, f"{model} has negative input rate" + assert output_rate >= 0, f"{model} has negative output rate" + + def test_qwen_model_cost(self): + """Verify cost for a Qwen model.""" + cost = estimate_cost( + "qwen/qwen3-235b-a22b-2507", + input_tokens=1_000_000, + output_tokens=1_000_000, + ) + # input: 0.14, output: 0.34, total: 0.48 + assert cost == 0.48 + + def test_expensive_model(self): + """Verify cost for an expensive model (Claude Opus 4).""" + 
cost = estimate_cost( + "anthropic/claude-opus-4", + input_tokens=1_000_000, + output_tokens=1_000_000, + ) + # input: 15.00, output: 75.00, total: 90.00 + assert cost == 90.0 diff --git a/tests/test_metrics/test_quality_queries.py b/tests/test_metrics/test_quality_queries.py new file mode 100644 index 0000000..29a0c2e --- /dev/null +++ b/tests/test_metrics/test_quality_queries.py @@ -0,0 +1,262 @@ +"""Tests for quality metrics queries.""" + +import pytest + +from src.metrics.db import RequestLogEntry, get_metrics_connection, init_metrics_db, log_request + + +@pytest.fixture +def quality_db(tmp_path): + """Create a metrics DB with quality-related data.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="q1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + model="qwen/qwen3-235b", + tool_call_count=2, + langfuse_trace_id="trace-abc", + ), + RequestLogEntry( + request_id="q2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=500.0, + status_code=200, + model="qwen/qwen3-235b", + tool_call_count=1, + langfuse_trace_id="trace-def", + ), + RequestLogEntry( + request_id="q3", + timestamp="2025-01-15T12:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=1500.0, + status_code=500, + model="qwen/qwen3-235b", + tool_call_count=0, + error_message="LLM timeout", + ), + RequestLogEntry( + request_id="q4", + timestamp="2025-01-16T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + model="qwen/qwen3-235b", + tool_call_count=3, + langfuse_trace_id="trace-ghi", + ), + RequestLogEntry( + request_id="q5", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/eeglab/ask", + method="POST", + community_id="eeglab", + duration_ms=250.0, + status_code=200, + 
tool_call_count=1, + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +class TestGetQualityMetrics: + """Tests for get_quality_metrics().""" + + def test_daily_buckets(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + assert result["community_id"] == "hed" + assert result["period"] == "daily" + buckets = {b["bucket"]: b for b in result["buckets"]} + assert "2025-01-15" in buckets + assert "2025-01-16" in buckets + finally: + conn.close() + + def test_error_rate_per_bucket(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + # Jan 15: 1 error out of 3 requests + assert abs(buckets["2025-01-15"]["error_rate"] - 0.3333) < 0.01 + # Jan 16: 0 errors + assert buckets["2025-01-16"]["error_rate"] == 0.0 + finally: + conn.close() + + def test_avg_tool_calls(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + # Jan 15: (2+1+0)/3 = 1.0 + assert buckets["2025-01-15"]["avg_tool_calls"] == 1.0 + # Jan 16: 3/1 = 3.0 + assert buckets["2025-01-16"]["avg_tool_calls"] == 3.0 + finally: + conn.close() + + def test_agent_errors(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + assert buckets["2025-01-15"]["agent_errors"] == 1 # "LLM timeout" + assert buckets["2025-01-16"]["agent_errors"] == 0 + finally: + conn.close() + + def test_traced_requests(self, quality_db): + from 
src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + assert buckets["2025-01-15"]["traced_requests"] == 2 + assert buckets["2025-01-16"]["traced_requests"] == 1 + finally: + conn.close() + + def test_latency_percentiles(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + # Jan 15: durations [200, 500, 1500] sorted + # p50 = values[1] = 500, p95 = values[2] = 1500 + assert buckets["2025-01-15"]["p50_duration_ms"] == 500.0 + assert buckets["2025-01-15"]["p95_duration_ms"] == 1500.0 + finally: + conn.close() + + def test_empty_community(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("nonexistent", conn, "daily") + assert result["buckets"] == [] + finally: + conn.close() + + +class TestGetQualitySummary: + """Tests for get_quality_summary().""" + + def test_summary_totals(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("hed", conn) + assert result["community_id"] == "hed" + assert result["total_requests"] == 4 + assert abs(result["error_rate"] - 0.25) < 0.01 + finally: + conn.close() + + def test_summary_avg_tool_calls(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("hed", conn) + # (2+1+0+3)/4 = 1.5 + assert result["avg_tool_calls"] == 1.5 + finally: + conn.close() + + def test_summary_traced_pct(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = 
get_metrics_connection(quality_db) + try: + result = get_quality_summary("hed", conn) + # 3 traced out of 4 + assert abs(result["traced_pct"] - 0.75) < 0.01 + finally: + conn.close() + + def test_summary_latency(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("hed", conn) + # durations sorted: [200, 300, 500, 1500] + # nearest-rank p50: idx=int(0.5*4)=2 -> 500 + # nearest-rank p95: idx=int(0.95*4)=3 -> 1500 + assert result["p50_duration_ms"] == 500.0 + assert result["p95_duration_ms"] == 1500.0 + finally: + conn.close() + + def test_empty_community(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("nonexistent", conn) + assert result["total_requests"] == 0 + assert result["error_rate"] == 0.0 + assert result["p50_duration_ms"] is None + finally: + conn.close() + + +class TestMigrationColumns: + """Tests for backward-compatible column migration.""" + + def test_new_columns_exist_after_init(self, quality_db): + """New quality columns should exist in freshly initialized DB.""" + conn = get_metrics_connection(quality_db) + try: + # Query using the new columns + row = conn.execute( + "SELECT tool_call_count, error_message, langfuse_trace_id FROM request_log LIMIT 1" + ).fetchone() + assert row is not None + finally: + conn.close() + + def test_migration_idempotent(self, quality_db): + """Running init_metrics_db twice should not fail.""" + # Second init should be a no-op for migrations + init_metrics_db(quality_db) + conn = get_metrics_connection(quality_db) + try: + count = conn.execute("SELECT COUNT(*) FROM request_log").fetchone()[0] + assert count == 5 # Original data preserved + finally: + conn.close() From e69f38f458783883d8b4adf70a5461355d08b064 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Wed, 4 Feb 2026 03:40:08 -0800 Subject: [PATCH 
13/26] Address PR #148 review findings Code fixes: - Handle HTTPException in streaming generators as SSE error events (cannot re-raise after response headers sent) - Extract _match_wildcard_origin helper, AgentResult dataclass, _extract_agent_result/_set_metrics_on_request to deduplicate ask/chat endpoints - Use metrics_connection() context manager in metrics router - Add failure counting with escalation logging to log_request() - Refactor check_budget() to accept BudgetConfig instead of 3 floats - Add __post_init__ validation to BudgetStatus for non-negative spend - Simplify list_sessions to reuse _evict_expired_sessions - Move inline imports (re, os) to top-level - Simplify _get_communities_with_sync to list comprehension Docstring fixes: - Clarify verify_api_key handles only global admin keys - Update scheduler module docstring to mention budget checks New tests: - check_budget with today's timestamps (exercises date('now') SQL) - Budget alert trigger with current-day spend - BudgetStatus rejects negative spend values - _percentile edge cases (single element, two elements, empty list) - _count_tools with malformed JSON - _extract_community_id documents intentional None for metrics paths --- src/api/routers/community.py | 243 ++++++++++++--------- src/api/routers/metrics.py | 44 ++-- src/api/scheduler.py | 19 +- src/api/security.py | 6 +- src/metrics/budget.py | 24 +- src/metrics/db.py | 17 +- tests/test_metrics/test_budget.py | 134 +++++++++--- tests/test_metrics/test_middleware.py | 13 ++ tests/test_metrics/test_quality_queries.py | 83 +++++++ 9 files changed, 409 insertions(+), 174 deletions(-) diff --git a/src/api/routers/community.py b/src/api/routers/community.py index 75ddc94..eeeb8a3 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -7,6 +7,8 @@ import hashlib import json import logging +import os +import re import sqlite3 import time import uuid @@ -332,16 +334,9 @@ def delete_session(community_id: str, session_id: str) -> 
bool: def list_sessions(community_id: str) -> list[ChatSession]: """List all active (non-expired) sessions for a community.""" + _evict_expired_sessions(community_id) store = _get_session_store(community_id) - # Filter out expired sessions - active_sessions = [s for s in store.values() if not s.is_expired()] - # Clean up expired sessions from store - expired = [sid for sid, s in store.items() if s.is_expired()] - for sid in expired: - del store[sid] - if expired: - logger.info("Cleaned %d expired sessions from %s", len(expired), community_id) - return active_sessions + return list(store.values()) # --------------------------------------------------------------------------- @@ -349,6 +344,16 @@ def list_sessions(community_id: str) -> list[ChatSession]: # --------------------------------------------------------------------------- +def _match_wildcard_origin(pattern: str, origin: str) -> bool: + """Check if an origin matches a wildcard pattern like 'https://*.example.com'. + + Converts '*' to a subdomain-safe regex and uses fullmatch. + """ + escaped = re.escape(pattern) + regex = escaped.replace(r"\*", r"[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?") + return bool(re.fullmatch(regex, origin)) + + def _is_authorized_origin(origin: str | None, community_id: str) -> bool: """Check if Origin header matches allowed CORS origins. 
@@ -370,8 +375,6 @@ def _is_authorized_origin(origin: str | None, community_id: str) -> bool: if not origin: return False - import re - # Platform default origins - always allowed for all communities platform_exact_origins = [ "https://osa-demo.pages.dev", @@ -386,9 +389,7 @@ def _is_authorized_origin(origin: str | None, community_id: str) -> bool: # Check platform wildcard patterns for allowed in platform_wildcard_origins: - escaped = re.escape(allowed) - pattern = escaped.replace(r"\*", r"[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?") - if re.fullmatch(pattern, origin): + if _match_wildcard_origin(allowed, origin): return True # Check community-specific origins @@ -400,19 +401,12 @@ def _is_authorized_origin(origin: str | None, community_id: str) -> bool: if not cors_origins: return False - # Check exact matches first - for allowed in cors_origins: - if "*" not in allowed and origin == allowed: - return True - - # Check wildcard patterns for allowed in cors_origins: - if "*" in allowed: - # Convert wildcard pattern to regex (same logic as main.py) - escaped = re.escape(allowed) - pattern = escaped.replace(r"\*", r"[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?") - if re.fullmatch(pattern, origin): + if "*" not in allowed: + if origin == allowed: return True + elif _match_wildcard_origin(allowed, origin): + return True return False @@ -446,8 +440,6 @@ def _select_api_key( HTTPException(403): If origin is not authorized and BYOK is not provided HTTPException(500): If no platform API key is configured and no other key is available """ - import os - # Case 1: BYOK provided - always allowed if byok: logger.debug( @@ -753,6 +745,69 @@ def create_community_assistant( ) +@dataclass +class AgentResult: + """Extracted response content and metrics from an agent invocation.""" + + response_content: str + tool_calls_info: list[ToolCallInfo] + tools_called: list[str] + input_tokens: int + output_tokens: int + total_tokens: int + + +def _extract_agent_result(result: dict) -> AgentResult: + 
"""Extract response content, tool calls, and token usage from agent result. + + Consolidates the common post-invocation logic shared by ask and chat endpoints. + """ + response_content = "" + if result.get("messages"): + last_msg = result["messages"][-1] + if isinstance(last_msg, AIMessage): + content = last_msg.content + response_content = content if isinstance(content, str) else str(content) + + tools_called = extract_tool_names(result) + tool_calls_info = [ + ToolCallInfo(name=tc.get("name", ""), args=tc.get("args", {})) + for tc in result.get("tool_calls", []) + ] + + inp, out, total = extract_token_usage(result) + return AgentResult( + response_content=response_content, + tool_calls_info=tool_calls_info, + tools_called=tools_called, + input_tokens=inp, + output_tokens=out, + total_tokens=total, + ) + + +def _set_metrics_on_request( + http_request: Request, + awm: AssistantWithMetrics, + agent_result: AgentResult, +) -> None: + """Store agent metrics on request.state for the metrics middleware to log.""" + http_request.state.metrics_agent_data = { + "model": awm.model, + "key_source": awm.key_source, + "input_tokens": agent_result.input_tokens, + "output_tokens": agent_result.output_tokens, + "total_tokens": agent_result.total_tokens, + "estimated_cost": estimate_cost( + awm.model, agent_result.input_tokens, agent_result.output_tokens + ), + "tools_called": agent_result.tools_called, + "tool_call_count": len(agent_result.tools_called), + "langfuse_trace_id": awm.langfuse_trace_id, + "stream": False, + } + + # --------------------------------------------------------------------------- # Router Factory # --------------------------------------------------------------------------- @@ -845,40 +900,13 @@ async def ask( requested_model=body.model, page_context=body.page_context, ) - assistant = awm.assistant messages = [HumanMessage(content=body.question)] - result = await assistant.ainvoke(messages, config=awm.langfuse_config) - - response_content = "" - if 
result.get("messages"): - last_msg = result["messages"][-1] - if isinstance(last_msg, AIMessage): - # Handle both string and list content (multimodal responses) - content = last_msg.content - response_content = content if isinstance(content, str) else str(content) - - tools_called = extract_tool_names(result) - tool_calls_info = [ - ToolCallInfo(name=tc.get("name", ""), args=tc.get("args", {})) - for tc in result.get("tool_calls", []) - ] - - # Store agent metrics on request.state for middleware to log - inp, out, total = extract_token_usage(result) - http_request.state.metrics_agent_data = { - "model": awm.model, - "key_source": awm.key_source, - "input_tokens": inp, - "output_tokens": out, - "total_tokens": total, - "estimated_cost": estimate_cost(awm.model, inp, out), - "tools_called": tools_called, - "tool_call_count": len(tools_called), - "langfuse_trace_id": awm.langfuse_trace_id, - "stream": False, - } - - return AskResponse(answer=response_content, tool_calls=tool_calls_info) + result = await awm.assistant.ainvoke(messages, config=awm.langfuse_config) + + ar = _extract_agent_result(result) + _set_metrics_on_request(http_request, awm, ar) + + return AskResponse(answer=ar.response_content, tool_calls=ar.tool_calls_info) except Exception as e: logger.error( @@ -962,41 +990,14 @@ async def chat( user_id=user_id, requested_model=body.model, ) - assistant = awm.assistant - result = await assistant.ainvoke(session.messages, config=awm.langfuse_config) - - response_content = "" - if result.get("messages"): - last_msg = result["messages"][-1] - if isinstance(last_msg, AIMessage): - # Handle both string and list content (multimodal responses) - content = last_msg.content - response_content = content if isinstance(content, str) else str(content) - - tools_called = extract_tool_names(result) - tool_calls_info = [ - ToolCallInfo(name=tc.get("name", ""), args=tc.get("args", {})) - for tc in result.get("tool_calls", []) - ] - - # Store agent metrics on request.state for 
middleware to log - inp, out, total = extract_token_usage(result) - http_request.state.metrics_agent_data = { - "model": awm.model, - "key_source": awm.key_source, - "input_tokens": inp, - "output_tokens": out, - "total_tokens": total, - "estimated_cost": estimate_cost(awm.model, inp, out), - "tools_called": tools_called, - "tool_call_count": len(tools_called), - "langfuse_trace_id": awm.langfuse_trace_id, - "stream": False, - } + result = await awm.assistant.ainvoke(session.messages, config=awm.langfuse_config) + + ar = _extract_agent_result(result) + _set_metrics_on_request(http_request, awm, ar) # Add assistant message with constraint validation try: - session.add_assistant_message(response_content) + session.add_assistant_message(ar.response_content) except ValueError as e: logger.error("Session limit exceeded: %s", e) raise HTTPException( @@ -1006,8 +1007,8 @@ async def chat( return ChatResponse( session_id=session.session_id, - message=ChatMessage(role="assistant", content=response_content), - tool_calls=tool_calls_info, + message=ChatMessage(role="assistant", content=ar.response_content), + tool_calls=ar.tool_calls_info, ) except ValueError as e: @@ -1364,9 +1365,26 @@ async def _stream_ask_response( status_code=200, ) - except HTTPException: - # Don't catch our own HTTP exceptions - let them propagate - raise + except HTTPException as e: + # HTTPException in streaming context (e.g., auth failure, rate limit). + # Cannot re-raise because response headers are already sent as 200. 
+ logger.warning( + "HTTP error in ask streaming for community %s: %d %s", + community_id, + e.status_code, + e.detail, + ) + sse_event = {"event": "error", "message": str(e.detail)} + yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/ask", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=e.status_code, + ) except ValueError as e: # Input validation errors - user's fault logger.warning("Invalid input in streaming for community %s: %s", community_id, e) @@ -1515,6 +1533,27 @@ async def _stream_chat_response( status_code=200, ) + except HTTPException as e: + # HTTPException in streaming context (e.g., auth failure, rate limit). + # Cannot re-raise because response headers are already sent as 200. + logger.warning( + "HTTP error in chat streaming for session %s (community: %s): %d %s", + session.session_id, + community_id, + e.status_code, + e.detail, + ) + sse_event = {"event": "error", "message": str(e.detail)} + yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/chat", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=e.status_code, + ) except ValueError as e: # Session limit errors logger.error("Session limit error: %s", e) @@ -1530,16 +1569,24 @@ async def _stream_chat_response( status_code=400, ) except Exception as e: + error_id = str(uuid.uuid4()) logger.error( - "Streaming error in chat endpoint for session %s (community: %s): %s", + "Unexpected streaming error (ID: %s) in chat endpoint for session %s (community: %s): %s", + error_id, session.session_id, community_id, e, exc_info=True, + extra={ + "error_id": error_id, + "community_id": community_id, + "error_type": type(e).__name__, + }, ) sse_event = { "event": "error", "message": "An error occurred while processing your request.", + 
"error_id": error_id, } yield f"data: {json.dumps(sse_event)}\n\n" _log_streaming_metrics( diff --git a/src/api/routers/metrics.py b/src/api/routers/metrics.py index 6dc4f4a..467c28c 100644 --- a/src/api/routers/metrics.py +++ b/src/api/routers/metrics.py @@ -11,7 +11,7 @@ from fastapi import APIRouter, HTTPException, Query from src.api.security import RequireScopedAuth -from src.metrics.db import get_metrics_connection +from src.metrics.db import metrics_connection from src.metrics.queries import ( get_community_summary, get_overview, @@ -31,20 +31,18 @@ async def metrics_overview(auth: RequireScopedAuth) -> dict[str, Any]: Global admin keys see all communities. Per-community keys see only their community's data wrapped in the same response format. """ - conn = get_metrics_connection() try: - if auth.role == "admin": - return get_overview(conn) - # Community-scoped: return summary for just their community - return get_community_summary(auth.community_id, conn) + with metrics_connection() as conn: + if auth.role == "admin": + return get_overview(conn) + # Community-scoped: return summary for just their community + return get_community_summary(auth.community_id, conn) except sqlite3.Error: logger.exception("Failed to query metrics database for overview") raise HTTPException( status_code=503, detail="Metrics database is temporarily unavailable.", ) - finally: - conn.close() @router.get("/tokens") @@ -62,17 +60,15 @@ async def token_breakdown( if auth.role == "community": effective_community = auth.community_id - conn = get_metrics_connection() try: - return get_token_breakdown(conn, community_id=effective_community) + with metrics_connection() as conn: + return get_token_breakdown(conn, community_id=effective_community) except sqlite3.Error: logger.exception("Failed to query metrics database for token breakdown") raise HTTPException( status_code=503, detail="Metrics database is temporarily unavailable.", ) - finally: - conn.close() @router.get("/quality") @@ -82,23 
+78,21 @@ async def quality_overview(auth: RequireScopedAuth) -> dict[str, Any]: Global admin keys see quality for all communities. Per-community keys see quality summary for their community only. """ - conn = get_metrics_connection() try: - if auth.role == "community": - return get_quality_summary(auth.community_id, conn) - # Admin: aggregate quality across all communities - overview = get_overview(conn) - communities_data = overview.get("communities", []) - summaries = [] - for c in communities_data: - cid = c["community_id"] - summaries.append(get_quality_summary(cid, conn)) - return {"communities": summaries} + with metrics_connection() as conn: + if auth.role == "community": + return get_quality_summary(auth.community_id, conn) + # Admin: aggregate quality across all communities + overview = get_overview(conn) + communities_data = overview.get("communities", []) + summaries = [] + for c in communities_data: + cid = c["community_id"] + summaries.append(get_quality_summary(cid, conn)) + return {"communities": summaries} except sqlite3.Error: logger.exception("Failed to query quality metrics") raise HTTPException( status_code=503, detail="Metrics database is temporarily unavailable.", ) - finally: - conn.close() diff --git a/src/api/scheduler.py b/src/api/scheduler.py index 0eb7c47..5c79a1a 100644 --- a/src/api/scheduler.py +++ b/src/api/scheduler.py @@ -1,8 +1,9 @@ -"""Background scheduler for automated knowledge sync. +"""Background scheduler for automated tasks. -Uses APScheduler to run periodic sync jobs for: -- GitHub issues/PRs (daily by default) -- Academic papers (weekly by default) +Uses APScheduler to run periodic jobs for: +- GitHub issues/PRs sync (daily by default) +- Academic papers sync (weekly by default) +- Community budget checks with alert issue creation (every 15 minutes) """ import logging @@ -32,11 +33,7 @@ def _get_communities_with_sync() -> list[str]: Returns: List of community IDs with GitHub repos or citation config. 
""" - communities = [] - for info in registry.list_all(): - if info.sync_config: - communities.append(info.id) - return communities + return [info.id for info in registry.list_all() if info.sync_config] def _get_community_repos(community_id: str) -> list[str]: @@ -190,9 +187,7 @@ def _check_community_budgets() -> None: try: budget_status = check_budget( community_id=info.id, - daily_limit_usd=budget_cfg.daily_limit_usd, - monthly_limit_usd=budget_cfg.monthly_limit_usd, - alert_threshold_pct=budget_cfg.alert_threshold_pct, + config=budget_cfg, conn=conn, ) communities_checked += 1 diff --git a/src/api/security.py b/src/api/security.py index 5d92036..d69099c 100644 --- a/src/api/security.py +++ b/src/api/security.py @@ -24,7 +24,11 @@ async def verify_api_key( anthropic_key: Annotated[str | None, Security(anthropic_key_header)] = None, openrouter_key: Annotated[str | None, Security(openrouter_key_header)] = None, ) -> str | None: - """Verify the API key if server authentication is enabled. + """Verify the global admin API key if server authentication is enabled. + + Only checks keys from the API_KEYS setting (global admin keys). + Does not handle per-community scoped keys; see verify_scoped_admin_key + for community-level auth. Returns the API key if valid, None if auth is disabled or BYOK is used. Raises HTTPException if auth is enabled but key is invalid (and no BYOK). 
diff --git a/src/metrics/budget.py b/src/metrics/budget.py index b0d0e9c..1d02868 100644 --- a/src/metrics/budget.py +++ b/src/metrics/budget.py @@ -8,6 +8,8 @@ import sqlite3 from dataclasses import dataclass +from src.core.config.community import BudgetConfig + logger = logging.getLogger(__name__) @@ -22,6 +24,14 @@ class BudgetStatus: monthly_limit_usd: float alert_threshold_pct: float + def __post_init__(self) -> None: + if self.daily_spend_usd < 0: + raise ValueError(f"daily_spend_usd must be non-negative, got {self.daily_spend_usd}") + if self.monthly_spend_usd < 0: + raise ValueError( + f"monthly_spend_usd must be non-negative, got {self.monthly_spend_usd}" + ) + @property def daily_pct(self) -> float: """Daily spend as percentage of limit.""" @@ -64,9 +74,7 @@ def needs_alert(self) -> bool: def check_budget( community_id: str, - daily_limit_usd: float, - monthly_limit_usd: float, - alert_threshold_pct: float, + config: BudgetConfig, conn: sqlite3.Connection, ) -> BudgetStatus: """Check current spend against budget limits. @@ -75,9 +83,7 @@ def check_budget( Args: community_id: The community identifier. - daily_limit_usd: Maximum daily spend. - monthly_limit_usd: Maximum monthly spend. - alert_threshold_pct: Alert threshold percentage. + config: Budget configuration with limits and alert threshold. conn: SQLite connection. 
Returns: @@ -109,7 +115,7 @@ def check_budget( community_id=community_id, daily_spend_usd=round(daily_row["spend"], 6), monthly_spend_usd=round(monthly_row["spend"], 6), - daily_limit_usd=daily_limit_usd, - monthly_limit_usd=monthly_limit_usd, - alert_threshold_pct=alert_threshold_pct, + daily_limit_usd=config.daily_limit_usd, + monthly_limit_usd=config.monthly_limit_usd, + alert_threshold_pct=config.alert_threshold_pct, ) diff --git a/src/metrics/db.py b/src/metrics/db.py index 900747c..35b6af5 100644 --- a/src/metrics/db.py +++ b/src/metrics/db.py @@ -17,6 +17,9 @@ logger = logging.getLogger(__name__) +# Track consecutive log_request failures for escalation +_log_request_failures: int = 0 + METRICS_SCHEMA_SQL = """ CREATE TABLE IF NOT EXISTS request_log ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -171,6 +174,7 @@ def log_request(entry: RequestLogEntry, db_path: Path | None = None) -> None: entry: The log entry to insert. db_path: Optional path override (for testing). """ + global _log_request_failures conn = get_metrics_connection(db_path) try: conn.execute( @@ -205,12 +209,23 @@ def log_request(entry: RequestLogEntry, db_path: Path | None = None) -> None: ) conn.commit() except sqlite3.Error: + _log_request_failures += 1 logger.exception( - "Failed to log metrics request %s (endpoint=%s, community=%s)", + "Failed to log metrics request %s (endpoint=%s, community=%s) [failure #%d]", entry.request_id, entry.endpoint, entry.community_id, + _log_request_failures, ) + if _log_request_failures >= 10: + logger.critical( + "Metrics DB write has failed %d times. 
" + "Possible disk/database issue requiring investigation.", + _log_request_failures, + ) + else: + # Reset counter on success + _log_request_failures = 0 finally: conn.close() diff --git a/tests/test_metrics/test_budget.py b/tests/test_metrics/test_budget.py index dad462c..a5b82d0 100644 --- a/tests/test_metrics/test_budget.py +++ b/tests/test_metrics/test_budget.py @@ -2,8 +2,27 @@ import pytest +from src.core.config.community import BudgetConfig from src.metrics.budget import BudgetStatus, check_budget -from src.metrics.db import RequestLogEntry, get_metrics_connection, init_metrics_db, log_request +from src.metrics.db import ( + RequestLogEntry, + get_metrics_connection, + init_metrics_db, + log_request, + now_iso, +) + + +def _make_config( + daily: float = 5.0, + monthly: float = 50.0, + alert_pct: float = 80.0, +) -> BudgetConfig: + return BudgetConfig( + daily_limit_usd=daily, + monthly_limit_usd=monthly, + alert_threshold_pct=alert_pct, + ) @pytest.fixture @@ -185,25 +204,37 @@ def test_no_alert_needed(self): ) assert status.needs_alert is False + def test_rejects_negative_daily_spend(self): + with pytest.raises(ValueError, match="daily_spend_usd must be non-negative"): + BudgetStatus( + community_id="test", + daily_spend_usd=-1.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) -class TestCheckBudget: - """Tests for check_budget() function. + def test_rejects_negative_monthly_spend(self): + with pytest.raises(ValueError, match="monthly_spend_usd must be non-negative"): + BudgetStatus( + community_id="test", + daily_spend_usd=0.0, + monthly_spend_usd=-1.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) - Note: check_budget queries today's date with date('now'), so the test - data from 2025-01-15 won't appear as "today". We test that the query - runs without error and returns the expected structure. 
- """ + +class TestCheckBudget: + """Tests for check_budget() function.""" def test_returns_budget_status(self, budget_db): + config = _make_config() conn = get_metrics_connection(budget_db) try: - status = check_budget( - community_id="hed", - daily_limit_usd=5.0, - monthly_limit_usd=50.0, - alert_threshold_pct=80.0, - conn=conn, - ) + status = check_budget(community_id="hed", config=config, conn=conn) assert isinstance(status, BudgetStatus) assert status.community_id == "hed" assert status.daily_limit_usd == 5.0 @@ -213,31 +244,78 @@ def test_returns_budget_status(self, budget_db): conn.close() def test_empty_community_zero_spend(self, budget_db): + config = _make_config() conn = get_metrics_connection(budget_db) try: - status = check_budget( - community_id="nonexistent", - daily_limit_usd=5.0, - monthly_limit_usd=50.0, - alert_threshold_pct=80.0, - conn=conn, - ) + status = check_budget(community_id="nonexistent", config=config, conn=conn) assert status.daily_spend_usd == 0.0 assert status.monthly_spend_usd == 0.0 finally: conn.close() def test_spend_values_are_non_negative(self, budget_db): + config = _make_config() conn = get_metrics_connection(budget_db) try: - status = check_budget( - community_id="hed", - daily_limit_usd=5.0, - monthly_limit_usd=50.0, - alert_threshold_pct=80.0, - conn=conn, - ) + status = check_budget(community_id="hed", config=config, conn=conn) assert status.daily_spend_usd >= 0.0 assert status.monthly_spend_usd >= 0.0 finally: conn.close() + + def test_today_spend_with_current_timestamps(self, tmp_path): + """Verify check_budget sums costs for entries with today's timestamp.""" + db_path = tmp_path / "today.db" + init_metrics_db(db_path) + + # Insert entries with current timestamps + for i, cost in enumerate([0.25, 0.35, 0.40]): + log_request( + RequestLogEntry( + request_id=f"today-{i}", + timestamp=now_iso(), + endpoint="/hed/ask", + method="POST", + community_id="hed", + status_code=200, + estimated_cost=cost, + ), + db_path=db_path, 
+ ) + + config = _make_config() + conn = get_metrics_connection(db_path) + try: + status = check_budget(community_id="hed", config=config, conn=conn) + assert status.daily_spend_usd == pytest.approx(1.0, abs=1e-6) + assert status.monthly_spend_usd == pytest.approx(1.0, abs=1e-6) + finally: + conn.close() + + def test_today_spend_triggers_alert(self, tmp_path): + """Verify budget alert triggers when today's spend crosses threshold.""" + db_path = tmp_path / "alert.db" + init_metrics_db(db_path) + + log_request( + RequestLogEntry( + request_id="expensive", + timestamp=now_iso(), + endpoint="/hed/ask", + method="POST", + community_id="hed", + status_code=200, + estimated_cost=4.5, + ), + db_path=db_path, + ) + + config = _make_config(daily=5.0, monthly=50.0, alert_pct=80.0) + conn = get_metrics_connection(db_path) + try: + status = check_budget(community_id="hed", config=config, conn=conn) + assert status.daily_pct == 90.0 + assert status.daily_alert is True + assert status.needs_alert is True + finally: + conn.close() diff --git a/tests/test_metrics/test_middleware.py b/tests/test_metrics/test_middleware.py index 5ddbc68..e4412e5 100644 --- a/tests/test_metrics/test_middleware.py +++ b/tests/test_metrics/test_middleware.py @@ -26,6 +26,19 @@ def test_returns_none_for_non_community(self): def test_returns_none_for_root(self): assert _extract_community_id("/") is None + def test_returns_none_for_community_metrics_paths(self): + """Community metrics/session paths intentionally return None. + + These endpoints handle community_id directly in route handlers + rather than through middleware extraction. 
+ """ + assert _extract_community_id("/hed/metrics") is None + assert _extract_community_id("/hed/metrics/public") is None + assert _extract_community_id("/hed/metrics/usage") is None + assert _extract_community_id("/hed/metrics/quality") is None + assert _extract_community_id("/hed/sessions") is None + assert _extract_community_id("/hed/config") is None + class TestMetricsMiddleware: """Tests for MetricsMiddleware integration.""" diff --git a/tests/test_metrics/test_quality_queries.py b/tests/test_metrics/test_quality_queries.py index 29a0c2e..172cd99 100644 --- a/tests/test_metrics/test_quality_queries.py +++ b/tests/test_metrics/test_quality_queries.py @@ -235,6 +235,89 @@ def test_empty_community(self, quality_db): conn.close() +class TestPercentileEdgeCases: + """Tests for _percentile() helper edge cases.""" + + def test_single_element(self): + from src.metrics.queries import _percentile + + assert _percentile([100.0], 0.5) == 100.0 + assert _percentile([100.0], 0.95) == 100.0 + + def test_two_elements(self): + from src.metrics.queries import _percentile + + values = [100.0, 200.0] + assert _percentile(values, 0.5) == 200.0 # idx=int(0.5*2)=1 + assert _percentile(values, 0.95) == 200.0 # idx=int(0.95*2)=1 + + def test_empty_returns_none(self): + from src.metrics.queries import _percentile + + assert _percentile([], 0.5) is None + + +class TestCountToolsMalformedJSON: + """Tests for _count_tools() with malformed data.""" + + def test_malformed_json_skipped(self, tmp_path): + """Rows with invalid tools_called JSON should be skipped gracefully.""" + + from src.metrics.queries import get_quality_metrics + + db_path = tmp_path / "malformed.db" + init_metrics_db(db_path) + + # Insert a valid entry first + log_request( + RequestLogEntry( + request_id="valid", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + tools_called=["search_docs"], + ), + db_path=db_path, + ) + + # 
Insert a row with invalid JSON directly via SQL + conn = get_metrics_connection(db_path) + try: + conn.execute( + """INSERT INTO request_log + (request_id, timestamp, endpoint, method, community_id, + duration_ms, status_code, tools_called) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + "bad", + "2025-01-15T11:00:00+00:00", + "/hed/ask", + "POST", + "hed", + 300.0, + 200, + "not-valid-json{{{", + ), + ) + conn.commit() + finally: + conn.close() + + # Query should still work, returning valid tool counts + conn = get_metrics_connection(db_path) + try: + result = get_quality_metrics("hed", conn, "daily") + assert len(result["buckets"]) == 1 + # Should have at least the valid entry's tool + bucket = result["buckets"][0] + assert bucket["requests"] == 2 + finally: + conn.close() + + class TestMigrationColumns: """Tests for backward-compatible column migration.""" From 4bcc07ae9c2c818a0957d33239830f6a6e6b2978 Mon Sep 17 00:00:00 2001 From: "Seyed (Yahya) Shirazi" Date: Wed, 4 Feb 2026 12:06:48 -0800 Subject: [PATCH 14/26] docs: add community development pointers to CLAUDE.md and generalize local testing guide (#150) Add community development section to CLAUDE.md with links to the documentation site registry guides (adding a community, local testing, schema reference, extensions). Generalize .context/local-testing-guide.md from EEGLAB-specific to community-agnostic with placeholder COMMUNITY_ID. --- .context/local-testing-guide.md | 272 ++++++-------------------------- CLAUDE.md | 13 +- 2 files changed, 58 insertions(+), 227 deletions(-) diff --git a/.context/local-testing-guide.md b/.context/local-testing-guide.md index 6c57954..1cfbaf4 100644 --- a/.context/local-testing-guide.md +++ b/.context/local-testing-guide.md @@ -1,257 +1,81 @@ -# Local Testing Guide for EEGLAB Assistant +# Local Testing Guide for Community Assistants -## Quick Test (Verify Configuration) +Quick reference for testing any community assistant locally. 
For the full guide, see https://docs.osc.earth/osa/registry/local-testing/ -```bash -cd /Users/yahya/Documents/git/osa-phase1 +## Quick Validation -# Run quick verification -uv run python test_eeglab_interactive.py +```bash +# Validate config loads +uv run pytest tests/test_core/ -k "community" -v + +# Or programmatically +uv run python -c " +from pathlib import Path +from src.core.config.community import CommunityConfig +config = CommunityConfig.from_yaml(Path('src/assistants/COMMUNITY_ID/config.yaml')) +print(f'Loaded: {config.name} with {len(config.documentation)} docs') +" ``` -## Full Backend Testing - -### 1. Set Environment Variables +## Environment Variables ```bash -# Required: OpenRouter API key for LLM export OPENROUTER_API_KEY="your-key-here" - -# Optional: API keys for admin functions (sync) +# Optional: for sync operations export API_KEYS="test-key-123" - -# Optional: Specific EEGLAB key (if community has BYOK) -# export OPENROUTER_API_KEY_EEGLAB="eeglab-specific-key" +# Optional: community-specific key +# export OPENROUTER_API_KEY_COMMUNITY="key" ``` -### 2. Start Backend Server +## Server Testing ```bash -cd /Users/yahya/Documents/git/osa-phase1 - -# Start development server +# Start dev server uv run uvicorn src.api.main:app --reload --port 38528 -``` - -Server will be available at: `http://localhost:38528` - -### 3. Test Endpoints -#### A. List All Communities - -```bash +# List communities (verify yours appears) curl http://localhost:38528/communities | jq -``` - -**Expected response:** -```json -{ - "communities": [ - { - "id": "eeglab", - "name": "EEGLAB", - "description": "EEG signal processing and analysis toolbox", - "status": "available" - }, - { - "id": "hed", - "name": "HED (Hierarchical Event Descriptors)", - ... - } - ] -} -``` - -#### B. 
Get EEGLAB Community Info -```bash -curl http://localhost:38528/communities/eeglab | jq -``` - -**Expected response:** -```json -{ - "id": "eeglab", - "name": "EEGLAB", - "description": "EEG signal processing and analysis toolbox", - "status": "available", - "documentation_count": 26, - "github_repos": 6, - "has_sync_config": true -} -``` - -#### C. Ask a Question (Simple) - -```bash -curl -X POST http://localhost:38528/eeglab/ask \ - -H "Content-Type: application/json" \ - -d '{ - "question": "What is EEGLAB?", - "api_key": "your-openrouter-key" - }' | jq -``` - -**Expected response:** -```json -{ - "answer": "EEGLAB is an interactive MATLAB toolbox...", - "sources": [ - { - "title": "EEGLAB quickstart", - "url": "https://sccn.github.io/..." - } - ] -} -``` - -#### D. Ask About ICA - -```bash -curl -X POST http://localhost:38528/eeglab/ask \ +# Ask a question +curl -X POST http://localhost:38528/COMMUNITY_ID/ask \ -H "Content-Type: application/json" \ - -d '{ - "question": "How do I run ICA in EEGLAB?", - "api_key": "your-openrouter-key" - }' | jq + -d '{"question": "What is this tool?", "api_key": "your-key"}' | jq ``` -**Should mention:** ICA decomposition, ICLabel, artifact removal - -#### E. Test Chat Endpoint +## CLI Testing (No Server Needed) ```bash -curl -X POST http://localhost:38528/eeglab/chat \ - -H "Content-Type: application/json" \ - -d '{ - "messages": [ - {"role": "user", "content": "What preprocessing steps should I do?"} - ], - "api_key": "your-openrouter-key" - }' | jq -``` - -**Should mention:** Filtering, re-referencing, ICA, artifact removal +# Interactive chat +uv run osa chat --community COMMUNITY_ID --standalone -### 4. Test Documentation Retrieval - -The assistant should automatically retrieve docs. 
Test by asking specific questions: - -```bash -# Should trigger retrieve_eeglab_docs tool -curl -X POST http://localhost:38528/eeglab/ask \ - -H "Content-Type: application/json" \ - -d '{ - "question": "How do I filter my EEG data in EEGLAB?", - "api_key": "your-openrouter-key", - "stream": false - }' | jq '.tool_calls' +# Single question +uv run osa ask --community COMMUNITY_ID "What is this tool?" --standalone ``` -**Expected:** Should call `retrieve_eeglab_docs` with filter-related docs - -### 5. Test via CLI (Easier!) +## Knowledge Sync ```bash -cd /Users/yahya/Documents/git/osa-phase1 - -# Set API key -export OPENROUTER_API_KEY="your-key-here" - -# Start interactive chat -uv run osa chat --community eeglab --standalone - -# Or ask single question -uv run osa ask --community eeglab "What is EEGLAB?" --standalone +uv run osa sync init --community COMMUNITY_ID +uv run osa sync github --community COMMUNITY_ID --full +uv run osa sync papers --community COMMUNITY_ID --citations ``` -**CLI is easier for testing because:** -- Handles API key automatically -- Shows formatted output -- Interactive mode for multi-turn conversations +## Test Checklist -## Test Questions for EEGLAB - -Good test questions to verify configuration: - -1. **Basic Info:** - - "What is EEGLAB?" - - "What can EEGLAB do?" - -2. **Preprocessing:** - - "What preprocessing steps should I do?" - - "How do I filter EEG data?" - - "How do I re-reference my data?" - -3. **ICA:** - - "How do I run ICA in EEGLAB?" - - "What is ICLabel?" - - "How do I remove artifacts with ICA?" - -4. **Plugins:** - - "What is clean_rawdata?" - - "How do I use ASR?" - - "What is the PREP pipeline?" - -5. **Integration:** - - "How do I use EEGLAB with BIDS?" - - "Can I use EEGLAB with Python?" - -6. **Knowledge Base (requires sync):** - - "What are the latest issues in the eeglab repo?" 
- - "Show me recent PRs in ICLabel" - - "Papers about EEGLAB ICA" +- [ ] Config validates without errors +- [ ] Community appears in `/communities` +- [ ] `/ask` endpoint returns relevant answers +- [ ] `/chat` endpoint works for multi-turn +- [ ] Preloaded docs are in context +- [ ] On-demand docs retrieved when relevant +- [ ] Documentation URLs in responses are valid +- [ ] CLI standalone mode works +- [ ] Knowledge sync completes (if configured) +- [ ] Assistant does not hallucinate PR/issue numbers ## Troubleshooting -### Server won't start - -```bash -# Check if port is already in use -lsof -i :38528 - -# Use different port -uv run uvicorn src.api.main:app --reload --port 38529 -``` - -### "Assistant not found" error - -```bash -# Verify EEGLAB is registered -uv run python -c "from src.assistants import discover_assistants, registry; discover_assistants(); print('eeglab' in registry)" -``` - -### Documentation not retrieved - -- Check that `retrieve_eeglab_docs` tool is available -- Check network access to sccn.github.io -- Check tool calls in response - -### Knowledge base empty - -- Knowledge base requires `API_KEYS` env var for sync -- Run sync locally: - ```bash - export API_KEYS="test-key" - uv run osa sync init --community eeglab - uv run osa sync github --community eeglab --full - ``` - -## Expected Behavior - -**What works without knowledge sync:** -- ✓ Assistant creation -- ✓ System prompt -- ✓ Documentation retrieval (fetches from URLs) -- ✓ Answering questions about EEGLAB -- ✓ Providing guidance on workflows - -**What needs knowledge sync:** -- ✗ Searching GitHub issues/PRs -- ✗ Listing recent activity -- ✗ Searching papers -- ✗ Citation counts - -## Next: Epic Branch Workflow - -See `epic-branch-workflow.md` for multi-phase development process. 
+- **Server won't start**: Check port with `lsof -i :38528`
+- **Assistant not found**: Check discovery with `uv run python -c "from src.assistants import discover_assistants, registry; discover_assistants(); print([a.id for a in registry.list_available()])"`
+- **Docs not retrieved**: Test source_url with `curl -I <source_url>`
+- **Knowledge empty**: Run `uv run osa sync init --community COMMUNITY_ID` first

diff --git a/CLAUDE.md b/CLAUDE.md
index 8d679b4..0740af4 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -146,9 +146,16 @@ src/
 - **.context/api_key_authorization_design.md** - API key auth and CORS
 - **.context/security-architecture.md** - Security patterns
 
-### Community Configuration
-- **.context/yaml_registry.md** - YAML-based community config
-- **.context/community_onboarding_review.md** - Onboarding guidelines
+### Community Development (Adding/Modifying Communities)
+- **Full docs site**: https://docs.osc.earth/osa/registry/ (canonical reference)
+  - [Adding a Community](https://docs.osc.earth/osa/registry/quick-start/) - Step-by-step guide
+  - [Local Testing](https://docs.osc.earth/osa/registry/local-testing/) - Testing a new community end-to-end
+  - [Schema Reference](https://docs.osc.earth/osa/registry/schema-reference/) - Full YAML config schema
+  - [Extensions](https://docs.osc.earth/osa/registry/extensions/) - Python plugins and MCP servers
+- **.context/yaml_registry.md** - YAML-based community config (internal notes)
+- **.context/community_onboarding_review.md** - Onboarding gap analysis
+- **.context/local-testing-guide.md** - Quick local testing reference
+- **Existing configs to reference**: `src/assistants/hed/config.yaml`, `src/assistants/eeglab/config.yaml`
 
 ### Tool System
 - **.context/tool-system-guide.md** - How tools work and are registered

From cf66a4d7ae1ff987b223fca8530d1f7cb5e0f6c2 Mon Sep 17 00:00:00 2001
From: "Seyed (Yahya) Shirazi"
Date: Wed, 4 Feb 2026 13:33:25 -0800
Subject: [PATCH 15/26] feat: add BIDS community assistant (Phase 1) (#151)
* feat: add BIDS community assistant (Phase 1) Add BIDS (Brain Imaging Data Structure) as a new community with: - 45 documentation sources (2 preloaded, 43 on-demand) covering the specification, all 12 modalities, derivatives, website getting-started guides, FAQs, tools, schema docs, and BEP process - 4 GitHub repos for issue/PR sync (specification, validator, website, examples) - 14 citation DOIs (canonical paper + 12 modality-specific extension papers + BIDS Apps) - System prompt with modality awareness, schema awareness, validator guidance, converter recommendations, and explicit anti-hallucination instructions for GitHub references - NeuroStars discourse integration (bids tag) - Budget and maintainer configuration All documentation URLs verified (raw GitHub + readthedocs stable). Config validates correctly and community discovery works (336 tests pass, no failures). Closes #149 * fix: address PR review findings for BIDS config - Fix phenotypic data description (phenotype/ directory, not participants.tsv) - Fix code description (data preparation scripts, not analysis code) - Add participants.tsv to data summary files description - Fix behavioral wording (no neural recordings, not no neuroimaging) - Fix genetics wording in system prompt (brain imaging data) - Add missing physiological recordings documentation entry (13 modality docs) - Add Physiological to system prompt modality list --- src/assistants/bids/__init__.py | 1 + src/assistants/bids/config.yaml | 540 ++++++++++++++++++++++++++++++++ src/assistants/bids/tools.py | 10 + 3 files changed, 551 insertions(+) create mode 100644 src/assistants/bids/__init__.py create mode 100644 src/assistants/bids/config.yaml create mode 100644 src/assistants/bids/tools.py diff --git a/src/assistants/bids/__init__.py b/src/assistants/bids/__init__.py new file mode 100644 index 0000000..85222e0 --- /dev/null +++ b/src/assistants/bids/__init__.py @@ -0,0 +1 @@ +# BIDS (Brain Imaging Data Structure) community assistant diff 
--git a/src/assistants/bids/config.yaml b/src/assistants/bids/config.yaml new file mode 100644 index 0000000..afdfe63 --- /dev/null +++ b/src/assistants/bids/config.yaml @@ -0,0 +1,540 @@ +# BIDS (Brain Imaging Data Structure) Assistant Configuration +# Single source of truth for the BIDS community assistant + +id: bids +name: BIDS (Brain Imaging Data Structure) +description: Standard for organizing and describing neuroimaging and behavioral data +status: available + +# Enable page context tool for widget embedding +enable_page_context: true + +# Allowed CORS origins for widget embedding on community websites +cors_origins: + - https://bids.neuroimaging.io + - https://bids-specification.readthedocs.io + +# Per-community OpenRouter API key (optional) +openrouter_api_key_env_var: "OPENROUTER_API_KEY_BIDS" + +# Community maintainers (GitHub usernames) +maintainers: + - effigies + +# Budget limits for cost management +budget: + daily_limit_usd: 5.00 + monthly_limit_usd: 50.00 + alert_threshold_pct: 80.0 + +# Default model for this community +default_model: "qwen/qwen3-235b-a22b-2507" +default_model_provider: "DeepInfra/FP8" + +# System prompt template +system_prompt: | + You are a technical assistant specialized in the Brain Imaging Data Structure (BIDS), a standard for + organizing and describing outputs of neuroimaging experiments. You provide explanations, troubleshooting, + and step-by-step guidance for organizing data according to the BIDS specification, using the BIDS validator, + and working with BIDS-compatible tools. + + You must stick strictly to the topic of BIDS and avoid digressions. + All responses should be accurate and based on the official BIDS specification and documentation. + + When a user's question is ambiguous, assume the most likely meaning and provide a useful starting point, + but also ask clarifying questions when necessary. + Communicate in a formal and technical style, prioritizing precision and accuracy while remaining clear. 
+ Answers should be structured and easy to follow, with examples where appropriate. + + The BIDS homepage is https://bids.neuroimaging.io/ + The BIDS specification is at https://bids-specification.readthedocs.io/ + The BIDS GitHub organization is at https://github.com/bids-standard + The BIDS validator is at https://github.com/bids-standard/bids-validator + BIDS example datasets are at https://github.com/bids-standard/bids-examples + + You will respond with markdown formatted text. Be concise and include only the most relevant information unless told otherwise. + + ## Modality Awareness + + BIDS covers many data modalities, each with specific rules and file formats. When users ask about a + specific modality, retrieve the corresponding documentation. Key modalities include: + + - **MRI** (anatomical, functional, diffusion, fieldmaps, ASL, quantitative MRI) + - **EEG** (electroencephalography) + - **MEG** (magnetoencephalography) + - **iEEG** (intracranial electroencephalography) + - **PET** (positron emission tomography) + - **NIRS** (near-infrared spectroscopy) + - **Motion** (motion capture data) + - **Microscopy** (microscopy imaging data) + - **MRS** (magnetic resonance spectroscopy) + - **EMG** (electromyography) + - **Behavioral** (behavioral experiments with no neural recordings) + - **Genetics** (genetic descriptors associated with brain imaging data) + - **Physiological** (continuous physiological recordings: cardiac, respiratory) + + Each modality may have a dedicated extension paper. When discussing a specific modality, cite + the relevant paper if available. + + ## Schema Awareness + + The BIDS specification is driven by a machine-readable schema (YAML rules in the specification repository). + The BIDS validator uses this schema to check datasets. Key concepts: + + - **Entities**: Key-value pairs in filenames (sub-, ses-, task-, run-, etc.) + - **Suffixes**: File type identifiers (bold, T1w, eeg, channels, events, etc.) 
+ - **Metadata**: JSON sidecar fields required or recommended for each file type + - **Rules**: Validation rules defining allowed combinations of entities, suffixes, and metadata + + When users ask about filename conventions or required metadata, consult the specification docs. + + ## Using Tools Liberally + + You have access to tools for documentation retrieval and knowledge discovery. **Use them proactively.** + + - Tool calls are inexpensive; do not hesitate to retrieve documentation + - Retrieve relevant documentation to ensure your answers are accurate and current + - When users ask about recent activity, issues, or PRs, use the knowledge discovery tools + - When users ask about research papers, use the paper search tool + + ## Using the retrieve_bids_docs Tool + + Before responding, use the retrieve_bids_docs tool to get any documentation you need. + Include links to relevant documents in your response. + + **Important guidelines:** + - Do NOT retrieve docs that have already been preloaded (listed below) + - Retrieve multiple relevant documents at once + - Get background information documents in addition to specific documents for the question + - If you have already loaded a document in this conversation, do not load it again + + {preloaded_docs_section} + + {available_docs_section} + + ## BIDS Validator Guidance + + Users frequently ask about validation errors. When helping with validation: + - Explain what the specific error or warning means + - Show the correct file naming or metadata format + - Point to the relevant section of the specification + - Mention that the BIDS validator can be run online at https://bids-standard.github.io/bids-validator/ + or via CLI with `bids-validator` (Deno-based) or `pip install bids-validator` (Python wrapper) + + ## Converter Tool Awareness + + Users often need help converting data to BIDS format. 
Common converters include: + - **dcm2bids** / **HeuDiConv** - DICOM to BIDS (MRI) + - **MNE-BIDS** - MEG/EEG/iEEG to BIDS (Python) + - **EEG-BIDS** - EEG to BIDS (MATLAB/EEGLAB plugin) + - **bidscoin** - General-purpose BIDS converter + - **pet2bids** - PET data to BIDS + + When users ask about data conversion, recommend the appropriate converter for their data type. + + ## BIDS Extension Proposals (BEPs) + + BEPs are the mechanism for adding new data types to BIDS. When users ask about extending BIDS + or about data types not yet in the specification, explain the BEP process and point them to + https://bids.neuroimaging.io/extensions/beps for the current list. + + ## Citation Guidance + + When users ask about citing BIDS, always recommend citing the canonical paper: + - Gorgolewski et al. (2016), "The brain imaging data structure", doi:10.1038/sdata.2016.44 + + Additionally recommend citing the modality-specific extension paper when applicable. + The full citation list is in the specification's introduction section. + + ## Knowledge Discovery Tools - YOU MUST USE THESE + + You have access to a synced knowledge database with GitHub issues, PRs, and academic papers. + **You MUST use these tools when users ask about recent activity, issues, or PRs.** + + **Available BIDS repositories in the database:** + {repo_list} + + **CRITICAL: When users mention these repos (even by short name), USE THE TOOLS:** + - "specification" or "spec" -> repo="bids-standard/bids-specification" + - "validator" -> repo="bids-standard/bids-validator" + - "website" -> repo="bids-standard/bids-website" + - "examples" -> repo="bids-standard/bids-examples" + + **MANDATORY: Use tools for these question patterns:** + - "What are the latest PRs?" -> CALL `list_bids_recent(item_type="pr")` + - "Latest PRs in the validator?" -> CALL `list_bids_recent(item_type="pr", repo="bids-standard/bids-validator")` + - "Open issues in the spec?" 
-> CALL `list_bids_recent(item_type="issue", status="open", repo="bids-standard/bids-specification")` + - "Recent activity?" -> CALL `list_bids_recent(limit=10)` + - "Any discussions about derivatives?" -> CALL `search_bids_discussions(query="derivatives")` + + **Core BIDS papers tracked for citations (DOIs in database):** + {paper_dois} + + **MANDATORY: Use tools for citation/paper questions:** + - "Has anyone cited the BIDS paper?" -> CALL `search_bids_papers(query="BIDS")` + - "Papers about BIDS EEG?" -> CALL `search_bids_papers(query="BIDS EEG")` + + ## CRITICAL: Do Not Hallucinate GitHub References + + **NEVER fabricate issue numbers, PR numbers, discussion links, or commit hashes.** + If you search the knowledge database and find no results, say so explicitly: + "I could not find any matching issues/PRs in the database." + + Do NOT invent plausible-sounding issue numbers or PR titles. Only reference items + that were returned by the knowledge discovery tools with real URLs. + + **DO NOT:** + - Tell users to "visit GitHub" or "check Google Scholar" when you have the data + - Make up PR numbers, issue numbers, paper titles, authors, or citation counts + - Say "I don't have access" -- you DO have access via the tools above + - Create fictional GitHub URLs like "github.com/bids-standard/bids-specification/issues/1234" + unless that exact URL was returned by a tool + + **Present results as discovery:** + - "Here are the recent PRs in the specification: [actual list with real URLs]" + - "There's a related discussion: [real link]" + - "Here are papers related to BIDS: [actual list from database]" + + The knowledge database may not be populated. If you get a message about initializing the database, + explain that the knowledge base is not set up yet. 
+ + {page_context_section} + + {additional_instructions} + +# Documentation sources +# - preload: true = embedded in system prompt (for core docs, ~2-3 max) +# - preload: false/omitted = fetched on demand via retrieve_docs tool +documentation: + # === PRELOADED: Core concepts (2 docs) === + - title: BIDS common principles + url: https://bids-specification.readthedocs.io/en/stable/common-principles.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/common-principles.md + preload: true + category: core + description: Core principles of BIDS including file naming, directory structure, and metadata conventions. + + - title: Getting started with BIDS + url: https://bids.neuroimaging.io/getting_started/ + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/index.md + preload: true + category: core + description: Overview of getting started with BIDS, including links to tutorials and resources. + + # === ON-DEMAND: Specification core (4 docs) === + - title: Introduction and citing BIDS + url: https://bids-specification.readthedocs.io/en/stable/introduction.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/introduction.md + category: specification + description: Introduction to BIDS, motivation, extensions, and how to cite the specification and modality papers. + + - title: BIDS glossary + url: https://bids-specification.readthedocs.io/en/stable/glossary.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/glossary.md + category: specification + description: Definitions of BIDS terminology including entities, suffixes, datatypes, and modalities. 
+ + - title: BIDS extensions + url: https://bids-specification.readthedocs.io/en/stable/extensions.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/extensions.md + category: specification + description: How BIDS is extended through BIDS Extension Proposals (BEPs). + + - title: Longitudinal and multi-site studies + url: https://bids-specification.readthedocs.io/en/stable/longitudinal-and-multi-site-studies.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/longitudinal-and-multi-site-studies.md + category: specification + description: Guidelines for organizing longitudinal studies and multi-site data in BIDS. + + # === ON-DEMAND: Modality-agnostic files (5 docs) === + - title: Dataset description + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/dataset-description.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/dataset-description.md + category: modality_agnostic + description: Required dataset_description.json file and its fields (Name, BIDSVersion, License, etc.). + + - title: Events file + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/events.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/events.md + category: modality_agnostic + description: Events TSV files for recording experimental events (onset, duration, trial_type). 
+ + - title: Phenotypic and assessment data + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/phenotypic-and-assessment-data.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/phenotypic-and-assessment-data.md + category: modality_agnostic + description: Phenotypic and assessment data in the phenotype/ directory (questionnaires, behavioral measures). + + - title: Code description + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/code.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/code.md + category: modality_agnostic + description: Guidelines for including data preparation scripts (deidentification, format conversion) in BIDS datasets. + + - title: Data summary files + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/data-summary-files.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/data-summary-files.md + category: modality_agnostic + description: Dataset-level summary files (participants.tsv, samples.tsv, scans.tsv, sessions.tsv). + + # === ON-DEMAND: Modality-specific files (13 docs) === + - title: MRI data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/magnetic-resonance-imaging-data.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/magnetic-resonance-imaging-data.md + category: modality_specific + description: MRI data organization including anatomical, functional, diffusion, fieldmaps, ASL, and quantitative MRI. 
+ + - title: EEG data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/electroencephalography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/electroencephalography.md + category: modality_specific + description: EEG data organization including channel info, electrode positions, and events. + + - title: MEG data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/magnetoencephalography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/magnetoencephalography.md + category: modality_specific + description: MEG data organization including channel info, sensor positions, and cross-talk matrices. + + - title: iEEG data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/intracranial-electroencephalography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/intracranial-electroencephalography.md + category: modality_specific + description: Intracranial EEG data organization including electrode coordinates and photos. + + - title: PET data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/positron-emission-tomography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/positron-emission-tomography.md + category: modality_specific + description: PET data organization including blood data, reconstruction, and radioligand metadata. 
+ + - title: Microscopy data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/microscopy.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/microscopy.md + category: modality_specific + description: Microscopy imaging data organization for tissue samples and staining protocols. + + - title: NIRS data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/near-infrared-spectroscopy.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/near-infrared-spectroscopy.md + category: modality_specific + description: Near-infrared spectroscopy data organization including optode positions and channels. + + - title: Motion data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/motion.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/motion.md + category: modality_specific + description: Motion capture data organization for body tracking and kinematics. + + - title: MRS data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/magnetic-resonance-spectroscopy.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/magnetic-resonance-spectroscopy.md + category: modality_specific + description: Magnetic resonance spectroscopy data organization. + + - title: EMG data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/electromyography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/electromyography.md + category: modality_specific + description: Electromyography data organization for muscle activity recordings. 
+ + - title: Behavioral experiments + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/behavioral-experiments.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/behavioral-experiments.md + category: modality_specific + description: Behavioral data organization for experiments with no neural recordings. + + - title: Genetic descriptor + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/genetic-descriptor.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/genetic-descriptor.md + category: modality_specific + description: Genetic descriptors associated with imaging participants. + + - title: Physiological recordings + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/physiological-recordings.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/physiological-recordings.md + category: modality_specific + description: Continuous physiological recordings (cardiac, respiratory) acquired alongside neural data. + + # === ON-DEMAND: Derivatives (4 docs) === + - title: Derivatives introduction + url: https://bids-specification.readthedocs.io/en/stable/derivatives/introduction.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/derivatives/introduction.md + category: derivatives + description: Overview of BIDS derivatives, naming conventions, and metadata for processed data. + + - title: Imaging derivatives + url: https://bids-specification.readthedocs.io/en/stable/derivatives/imaging.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/derivatives/imaging.md + category: derivatives + description: Imaging-specific derivative conventions (masks, segmentations, atlases). 
+ + - title: Common derivative data types + url: https://bids-specification.readthedocs.io/en/stable/derivatives/common-data-types.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/derivatives/common-data-types.md + category: derivatives + description: Common derivative formats shared across modalities. + + - title: Atlas derivatives + url: https://bids-specification.readthedocs.io/en/stable/derivatives/atlas.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/derivatives/atlas.md + category: derivatives + description: Atlases as BIDS derivatives, including probability maps and label descriptions. + + # === ON-DEMAND: Website getting started (6 docs) === + - title: BIDS folder structure + url: https://bids.neuroimaging.io/getting_started/folders_and_files/folders + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/folders.md + category: getting_started + description: Overview of the BIDS directory hierarchy (subject, session, datatype folders). + + - title: BIDS file naming + url: https://bids.neuroimaging.io/getting_started/folders_and_files/files + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/files.md + category: getting_started + description: File naming conventions in BIDS (entities, suffixes, extensions). + + - title: BIDS derivatives overview + url: https://bids.neuroimaging.io/getting_started/folders_and_files/derivatives + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/derivatives.md + category: getting_started + description: Accessible overview of derivatives in BIDS for processed data. 
+ + - title: JSON metadata + url: https://bids.neuroimaging.io/getting_started/folders_and_files/metadata/json + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/metadata/json.md + category: getting_started + description: How JSON sidecar files work in BIDS for metadata. + + - title: TSV files + url: https://bids.neuroimaging.io/getting_started/folders_and_files/metadata/tsv + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/metadata/tsv.md + category: getting_started + description: How TSV tabular files work in BIDS (participants, events, channels, etc.). + + - title: Event annotation tutorial + url: https://bids.neuroimaging.io/getting_started/tutorials/annotation + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/tutorials/annotation.md + category: getting_started + description: Tutorial on annotating events in BIDS datasets. + + # === ON-DEMAND: Website FAQ (5 docs) === + - title: General FAQ + url: https://bids.neuroimaging.io/faq/general + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/general.md + category: faq + description: Frequently asked questions about BIDS in general. + + - title: MRI FAQ + url: https://bids.neuroimaging.io/faq/mri + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/mri.md + category: faq + description: Frequently asked questions about MRI data in BIDS. + + - title: EEG FAQ + url: https://bids.neuroimaging.io/faq/eeg + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/eeg.md + category: faq + description: Frequently asked questions about EEG data in BIDS. 
+ + - title: BIDS Apps FAQ + url: https://bids.neuroimaging.io/faq/bids-apps + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/bids-apps.md + category: faq + description: Frequently asked questions about BIDS Apps. + + - title: BIDS Extensions FAQ + url: https://bids.neuroimaging.io/faq/bids-extensions + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/bids-extensions.md + category: faq + description: Frequently asked questions about extending BIDS and the BEP process. + + # === ON-DEMAND: Website tools (3 docs) === + - title: BIDS Validator + url: https://bids.neuroimaging.io/tools/validator + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/tools/validator.md + category: tools + description: Overview of the BIDS validator, installation, and usage. + + - title: BIDS converters + url: https://bids.neuroimaging.io/tools/converters + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/tools/converters.md + category: tools + description: List of tools for converting data to BIDS format. + + - title: BIDS Apps + url: https://bids.neuroimaging.io/tools/bids-apps + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/tools/bids-apps.md + category: tools + description: Overview of BIDS Apps, standardized analysis pipelines. + + # === ON-DEMAND: Extensions and governance (2 docs) === + - title: BIDS Extension Proposals + url: https://bids.neuroimaging.io/extensions/beps + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/extensions/beps.md + category: extensions + description: List and status of all BIDS Extension Proposals (BEPs). 
+ + - title: BEP guidelines + url: https://bids.neuroimaging.io/extensions/guidelines + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/extensions/guidelines.md + category: extensions + description: Guidelines for creating and submitting a new BIDS Extension Proposal. + + # === ON-DEMAND: Schema documentation (2 docs) === + - title: How the BIDS schema works + url: https://bids.neuroimaging.io/standards/schema/how-the-schema-works + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/standards/schema/how-the-schema-works.md + category: schema + description: Explanation of the machine-readable BIDS schema and how it drives validation. + + - title: Schema rules + url: https://bids.neuroimaging.io/standards/schema/schema-rules + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/standards/schema/schema-rules.md + category: schema + description: Documentation of schema rules that define BIDS validation logic. + +# GitHub repositories for issue/PR sync +github: + repos: + - bids-standard/bids-specification + - bids-standard/bids-validator + - bids-standard/bids-website + - bids-standard/bids-examples + +# Paper/citation search configuration +# Canonical paper + modality-specific extension papers +citations: + queries: + - Brain Imaging Data Structure + - BIDS neuroimaging + - BIDS specification + - BIDS validator + - BIDS apps + dois: + # Canonical BIDS paper + - "10.1038/sdata.2016.44" # Gorgolewski et al. 
(2016) - The brain imaging data structure + # Modality-specific extension papers + - "10.1038/s41597-019-0104-8" # EEG-BIDS (Pernet et al., 2019) + - "10.1038/s41597-019-0105-7" # iEEG-BIDS (Holdgraf et al., 2019) + - "10.1038/sdata.2018.110" # MEG-BIDS (Niso et al., 2018) + - "10.1038/s41597-022-01164-1" # PET-BIDS (Norgaard et al., 2022) + - "10.1177/0271678X20905433" # PET guidelines (Knudsen et al., 2020) + - "10.1093/gigascience/giaa104" # Genetics-BIDS (Moreau et al., 2020) + - "10.3389/fnins.2022.871228" # Microscopy-BIDS (Bourget et al., 2022) + - "10.1038/s41597-022-01571-4" # qMRI-BIDS (Karakuzu et al., 2022) + - "10.1038/s41597-022-01615-9" # ASL-BIDS (Clement et al., 2022) + - "10.1038/s41597-024-04136-9" # NIRS-BIDS (Luke et al., 2025) + - "10.1038/s41597-024-03559-8" # Motion-BIDS (Jeung et al., 2024) + - "10.1038/s41597-025-05543-2" # MRS-BIDS (Bouchard et al., 2025) + # Related ecosystem + - "10.1371/journal.pcbi.1005209" # BIDS Apps (Gorgolewski et al., 2017) + +# Discourse forums +discourse: + - url: https://neurostars.org + tags: + - bids + +# No custom extensions for Phase 1 +# Future: schema query tool, BEP status tool, converter recommendation tool diff --git a/src/assistants/bids/tools.py b/src/assistants/bids/tools.py new file mode 100644 index 0000000..d0276c9 --- /dev/null +++ b/src/assistants/bids/tools.py @@ -0,0 +1,10 @@ +"""BIDS community-specific tools. + +Phase 1: No custom tools needed. Documentation retrieval and knowledge +discovery tools are auto-generated from config.yaml.
+ +Future phases may add: +- Schema query tool (query BIDS schema for valid entities/metadata) +- BEP status tool (query active/completed BEPs) +- Converter recommendation tool (suggest converters for data types) +""" From 00bc18f305c5fd679548d5f1ecfaa65359887614 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 4 Feb 2026 21:33:39 +0000 Subject: [PATCH 16/26] chore: sync worker CORS from community configs [skip ci] --- workers/osa-worker/index.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/workers/osa-worker/index.js b/workers/osa-worker/index.js index dbe3119..0937998 100644 --- a/workers/osa-worker/index.js +++ b/workers/osa-worker/index.js @@ -140,6 +140,8 @@ function isAllowedOrigin(origin) { // Allowed origins for OSA const allowedPatterns = [ 'https://osc.earth', + 'https://bids-specification.readthedocs.io', + 'https://bids.neuroimaging.io', 'https://eeglab.org', 'https://hedtags.org', 'https://sccn.github.io', @@ -154,6 +156,8 @@ function isAllowedOrigin(origin) { if (origin.endsWith('.eeglab.org')) return true; if (origin.endsWith('.github.io')) return true; if (origin.endsWith('.hedtags.org')) return true; + if (origin.endsWith('.neuroimaging.io')) return true; + if (origin.endsWith('.readthedocs.io')) return true; // Allow osa-demo.pages.dev and all subdomains (previews, branches) if (origin === 'https://osa-demo.pages.dev') return true; From 0f501285fddc7e749decd1ec1556b5a23adf09ce Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Wed, 4 Feb 2026 13:45:45 -0800 Subject: [PATCH 17/26] ci: add workflow_dispatch trigger to CORS sync workflow Allows manual re-triggering when automated runs fail due to transient issues (e.g. expired tokens, push race conditions). 
--- .github/workflows/sync-worker-cors.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/sync-worker-cors.yml b/.github/workflows/sync-worker-cors.yml index 08f56cf..4e65c0d 100644 --- a/.github/workflows/sync-worker-cors.yml +++ b/.github/workflows/sync-worker-cors.yml @@ -13,6 +13,7 @@ on: paths: - 'src/assistants/*/config.yaml' - 'workers/osa-worker/**' + workflow_dispatch: permissions: contents: write From a0b7e1272b317f9d3eafdee949ef51fb0513d331 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Wed, 4 Feb 2026 17:42:15 -0800 Subject: [PATCH 18/26] fix: update BIDS suggested question to Common Principles Replace "What are the required metadata fields?" with "What are the BIDS Common Principles?" as a more foundational starting question for new users. --- frontend/index.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/index.html b/frontend/index.html index e93573a..3db1630 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -188,7 +188,7 @@ suggestedQuestions: [ 'What is BIDS and why should I use it?', 'How do I organize my EEG data in BIDS?', - 'What are the required metadata fields?', + 'What are the BIDS Common Principles?', 'How do I validate my BIDS dataset?' ] } From 0d25a2b5f191b9676ec3987b07ebab170e7a8aaf Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Wed, 4 Feb 2026 18:19:33 -0800 Subject: [PATCH 19/26] fix: add PR/issue number lookup patterns to BIDS system prompt The model was not calling knowledge tools when asked about specific PR/issue numbers. Add explicit patterns showing how to search by number. --- src/assistants/bids/config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/assistants/bids/config.yaml b/src/assistants/bids/config.yaml index afdfe63..6214a93 100644 --- a/src/assistants/bids/config.yaml +++ b/src/assistants/bids/config.yaml @@ -166,6 +166,8 @@ system_prompt: | - "Open issues in the spec?" 
-> CALL `list_bids_recent(item_type="issue", status="open", repo="bids-standard/bids-specification")` - "Recent activity?" -> CALL `list_bids_recent(limit=10)` - "Any discussions about derivatives?" -> CALL `search_bids_discussions(query="derivatives")` + - "What is PR 2022 about?" -> CALL `search_bids_discussions(query="2022", include_issues=false, include_prs=true)` + - "What is issue #500?" -> CALL `search_bids_discussions(query="500", include_issues=true, include_prs=false)` **Core BIDS papers tracked for citations (DOIs in database):** {paper_dois} From b305ab244608f7aa0bbf6057ff81d8da4ea3de10 Mon Sep 17 00:00:00 2001 From: "Seyed (Yahya) Shirazi" Date: Wed, 4 Feb 2026 18:37:45 -0800 Subject: [PATCH 20/26] feat: number-based lookup and switch to Claude Haiku 4.5 (#154) * feat: add number-based lookup to GitHub item search When query contains a PR/issue number (e.g. "2022", "#500", "PR 2022"), search now does a direct number lookup first, then falls back to full-text search for remaining slots. Deduplicates results. Fixes #153 * feat: switch all communities from Qwen to Claude Haiku 4.5 Qwen was not reliably calling knowledge tools when asked about specific PRs/issues. Claude Haiku 4.5 via Anthropic provider is more stable for tool calling and has caching enabled. 
* Address PR review: add number index, skip FTS for pure number queries, expand tests - Add idx_github_items_number index for direct number lookups - Add _is_pure_number_query() to skip FTS when query is just a number pattern - Add debug logging when number lookup finds no results - Add tests for status filter, nonexistent numbers, limits, bug/feature prefixes - Strengthen assertions on existing number lookup tests * Fix ruff SIM103: return condition directly in _is_pure_number_query * Fix ruff formatting --- src/assistants/bids/config.yaml | 4 +- src/assistants/eeglab/config.yaml | 4 +- src/assistants/hed/config.yaml | 4 +- src/knowledge/db.py | 3 + src/knowledge/search.py | 169 ++++++++++++++++++++-------- tests/test_knowledge/test_search.py | 130 +++++++++++++++++++++ 6 files changed, 261 insertions(+), 53 deletions(-) diff --git a/src/assistants/bids/config.yaml b/src/assistants/bids/config.yaml index 6214a93..b1f5790 100644 --- a/src/assistants/bids/config.yaml +++ b/src/assistants/bids/config.yaml @@ -28,8 +28,8 @@ budget: alert_threshold_pct: 80.0 # Default model for this community -default_model: "qwen/qwen3-235b-a22b-2507" -default_model_provider: "DeepInfra/FP8" +default_model: "anthropic/claude-haiku-4.5" +default_model_provider: "anthropic" # System prompt template system_prompt: | diff --git a/src/assistants/eeglab/config.yaml b/src/assistants/eeglab/config.yaml index 3a1baa6..c08a63e 100644 --- a/src/assistants/eeglab/config.yaml +++ b/src/assistants/eeglab/config.yaml @@ -5,8 +5,8 @@ id: eeglab name: EEGLAB description: EEG signal processing and analysis toolbox status: available -default_model: qwen/qwen3-235b-a22b-2507 -default_model_provider: DeepInfra/FP8 +default_model: anthropic/claude-haiku-4.5 +default_model_provider: anthropic cors_origins: - https://eeglab.org diff --git a/src/assistants/hed/config.yaml b/src/assistants/hed/config.yaml index 7c7c5c9..647a770 100644 --- a/src/assistants/hed/config.yaml +++ b/src/assistants/hed/config.yaml @@ 
-34,8 +34,8 @@ budget: # Default model for this community (optional) # If specified, overrides the platform-level default model # Format: creator/model-name (OpenRouter format) -default_model: "qwen/qwen3-235b-a22b-2507" -default_model_provider: "DeepInfra/FP8" # Optimized FP8 quantized inference +default_model: "anthropic/claude-haiku-4.5" +default_model_provider: "anthropic" # System prompt template # Available placeholders: {name}, {description}, {repo_list}, {paper_dois}, diff --git a/src/knowledge/db.py b/src/knowledge/db.py index 42b6dd3..479d829 100644 --- a/src/knowledge/db.py +++ b/src/knowledge/db.py @@ -38,6 +38,9 @@ UNIQUE(repo, item_type, number) ); +-- Index for direct number lookups (e.g. "What is PR #2022?") +CREATE INDEX IF NOT EXISTS idx_github_items_number ON github_items(number); + -- FTS5 virtual table for full-text search on GitHub items CREATE VIRTUAL TABLE IF NOT EXISTS github_items_fts USING fts5( title, diff --git a/src/knowledge/search.py b/src/knowledge/search.py index 100f066..d16aec5 100644 --- a/src/knowledge/search.py +++ b/src/knowledge/search.py @@ -106,6 +106,59 @@ class SearchResult: created_at: str +def _is_pure_number_query(query: str) -> bool: + """Check if the query is purely a number lookup with no useful text for FTS. + + Returns True for queries like "2022", "#500", "PR 2022", "issue #500" + where FTS would produce poor results (searching for literal "#500" or "PR 10"). + """ + stripped = query.strip() + # Pure number or hash-number + if stripped.lstrip("#").isdigit(): + return True + # Keyword + number with nothing else + return bool(re.fullmatch(r"(?:pr|pull|issue|bug|feature)\s*#?\s*\d+", stripped, re.IGNORECASE)) + + +def _extract_number(query: str) -> int | None: + """Extract an issue/PR number from a query string. + + Handles patterns like "2022", "#2022", "PR 2022", "issue #500". + Pure numeric queries (e.g. 
"2022") are treated as number lookups first; + this may match a PR/issue number rather than items mentioning the year. + + Returns: + The extracted number, or None if no number pattern found. + """ + # Strip and try common patterns + stripped = query.strip().lstrip("#") + # Direct number + if stripped.isdigit(): + return int(stripped) + # "PR 2022", "issue #500", "pull #2022", etc. + m = re.match(r"(?:pr|pull|issue|bug|feature)\s*#?\s*(\d+)", query.strip(), re.IGNORECASE) + if m: + return int(m.group(1)) + return None + + +def _row_to_result(row: sqlite3.Row) -> SearchResult: + """Convert a database row to a SearchResult.""" + first_message = row["first_message"] or "" + snippet = first_message[:200].strip() + if len(first_message) > 200: + snippet += "..." + return SearchResult( + title=row["title"], + url=row["url"], + snippet=snippet, + source="github", + item_type=row["item_type"], + status=row["status"], + created_at=row["created_at"] or "", + ) + + def search_github_items( query: str, project: str = "hed", @@ -114,10 +167,14 @@ def search_github_items( status: str | None = None, repo: str | None = None, ) -> list[SearchResult]: - """Search GitHub issues and PRs using phrase matching. + """Search GitHub issues and PRs by number, title, or body text. + + When the query contains a number (e.g. "2022", "#500", "PR 2022"), + results matching that number are returned first, followed by + full-text search results. Args: - query: Search phrase (treated as exact phrase, not FTS5 operators) + query: Search phrase, PR/issue number, or keyword project: Assistant/project name for database isolation. Defaults to 'hed'. 
limit: Maximum number of results item_type: Filter by 'issue' or 'pr' @@ -125,56 +182,74 @@ def search_github_items( repo: Filter by repository name Returns: - List of matching results, ordered by relevance - """ - # Build SQL query with optional filters - sql = """ - SELECT g.title, g.url, g.first_message, g.item_type, g.status, - g.created_at, g.repo - FROM github_items_fts f - JOIN github_items g ON f.rowid = g.id - WHERE github_items_fts MATCH ? + List of matching results, with number matches first """ - params: list[str | int] = [query] - - if item_type: - sql += " AND g.item_type = ?" - params.append(item_type) - if status: - sql += " AND g.status = ?" - params.append(status) - if repo: - sql += " AND g.repo = ?" - params.append(repo) - - sql += " ORDER BY rank LIMIT ?" - params.append(limit) - results = [] + seen_urls: set[str] = set() + try: with get_connection(project) as conn: - # Sanitize user query to prevent FTS5 injection - safe_query = _sanitize_fts5_query(query) - params[0] = safe_query - - for row in conn.execute(sql, params): - # Create snippet from first_message (first 200 chars) - first_message = row["first_message"] or "" - snippet = first_message[:200].strip() - if len(first_message) > 200: - snippet += "..." + # Phase 1: Try direct number lookup + number = _extract_number(query) + is_pure_number = _is_pure_number_query(query) + if number is not None: + num_sql = """ + SELECT title, url, first_message, item_type, status, + created_at, repo + FROM github_items WHERE number = ? + """ + num_params: list[str | int] = [number] + if item_type: + num_sql += " AND item_type = ?" + num_params.append(item_type) + if status: + num_sql += " AND status = ?" + num_params.append(status) + if repo: + num_sql += " AND repo = ?" 
+ num_params.append(repo) + + for row in conn.execute(num_sql, num_params): + result = _row_to_result(row) + results.append(result) + seen_urls.add(result.url) + + if not results: + logger.debug("Number lookup for %d found no items", number) + + # Phase 2: Full-text search for remaining slots + # Skip FTS for pure number queries (e.g. "#500", "PR 2022") since + # the sanitized query would search for literal "#500" which won't + # match anything useful in title/body text. + remaining = limit - len(results) + if remaining > 0 and not is_pure_number: + fts_sql = """ + SELECT g.title, g.url, g.first_message, g.item_type, g.status, + g.created_at, g.repo + FROM github_items_fts f + JOIN github_items g ON f.rowid = g.id + WHERE github_items_fts MATCH ? + """ + fts_params: list[str | int] = [_sanitize_fts5_query(query)] + + if item_type: + fts_sql += " AND g.item_type = ?" + fts_params.append(item_type) + if status: + fts_sql += " AND g.status = ?" + fts_params.append(status) + if repo: + fts_sql += " AND g.repo = ?" + fts_params.append(repo) + + fts_sql += " ORDER BY rank LIMIT ?" 
+ fts_params.append(remaining) + + for row in conn.execute(fts_sql, fts_params): + if row["url"] not in seen_urls: + results.append(_row_to_result(row)) + seen_urls.add(row["url"]) - results.append( - SearchResult( - title=row["title"], - url=row["url"], - snippet=snippet, - source="github", - item_type=row["item_type"], - status=row["status"], - created_at=row["created_at"] or "", - ) - ) except sqlite3.OperationalError as e: # Infrastructure failure (corruption, disk full, permissions) - must propagate logger.error( diff --git a/tests/test_knowledge/test_search.py b/tests/test_knowledge/test_search.py index 8f382ae..2460bbb 100644 --- a/tests/test_knowledge/test_search.py +++ b/tests/test_knowledge/test_search.py @@ -11,6 +11,8 @@ from src.knowledge.db import get_connection, init_db, upsert_github_item, upsert_paper from src.knowledge.search import ( SearchResult, + _extract_number, + _is_pure_number_query, _sanitize_fts5_query, search_all, search_github_items, @@ -184,6 +186,134 @@ def test_search_all_returns_both_categories(self, populated_db: Path): assert isinstance(results["papers"], list) +class TestNumberExtraction: + """Tests for extracting PR/issue numbers from queries.""" + + def test_plain_number(self): + assert _extract_number("2022") == 2022 + + def test_hash_prefix(self): + assert _extract_number("#500") == 500 + + def test_pr_prefix(self): + assert _extract_number("PR 2022") == 2022 + assert _extract_number("pr #2022") == 2022 + + def test_issue_prefix(self): + assert _extract_number("issue 500") == 500 + assert _extract_number("issue #500") == 500 + + def test_pull_prefix(self): + assert _extract_number("pull #100") == 100 + + def test_no_number(self): + assert _extract_number("validation error") is None + assert _extract_number("how to use BIDS") is None + + def test_number_with_trailing_text(self): + """A number followed by other words is NOT treated as a number lookup.""" + assert _extract_number("123 validation error") is None + + def 
test_bug_and_feature_prefixes(self): + """bug and feature prefixes are supported per the regex.""" + assert _extract_number("bug 42") == 42 + assert _extract_number("feature #99") == 99 + + def test_whitespace(self): + assert _extract_number(" 2022 ") == 2022 + + +class TestPureNumberQuery: + """Tests for detecting pure number queries.""" + + def test_plain_number(self): + assert _is_pure_number_query("2022") is True + + def test_hash_number(self): + assert _is_pure_number_query("#500") is True + + def test_pr_prefix(self): + assert _is_pure_number_query("PR 2022") is True + assert _is_pure_number_query("issue #500") is True + + def test_text_query(self): + assert _is_pure_number_query("validation error") is False + assert _is_pure_number_query("2022 roadmap plans") is False + + +class TestNumberLookup: + """Tests for searching GitHub items by number.""" + + def test_search_by_number(self, populated_db: Path): + """Search for an item by its number.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("10") + + assert len(results) >= 1 + assert results[0].url == "https://github.com/hed-standard/hed-schemas/pull/10" + assert results[0].title == "Add new sensory tags" + + def test_search_by_hash_number(self, populated_db: Path): + """Search with # prefix.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("#1") + + assert len(results) >= 1 + assert results[0].url == "https://github.com/hed-standard/hed-specification/issues/1" + + def test_search_by_pr_number(self, populated_db: Path): + """Search with 'PR' prefix.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("PR 10") + + assert len(results) >= 1 + assert results[0].item_type == "pr" + assert results[0].url == "https://github.com/hed-standard/hed-schemas/pull/10" + assert results[0].title == "Add new sensory tags" + + def 
test_number_lookup_with_type_filter(self, populated_db: Path): + """Number lookup respects item_type filter.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + # Item #10 is a PR, filtering for issues should not return it + results = search_github_items("10", item_type="issue") + assert all(r.item_type == "issue" for r in results) + + def test_number_lookup_with_status_filter(self, populated_db: Path): + """Number lookup respects status filter.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + # Item #1 is open, filtering for closed should not return it + results = search_github_items("#1", status="closed") + assert all(r.status == "closed" for r in results) + # But filtering for open should return it + results = search_github_items("#1", status="open") + assert len(results) >= 1 + assert results[0].url == "https://github.com/hed-standard/hed-specification/issues/1" + + def test_nonexistent_number_returns_empty(self, populated_db: Path): + """A number that doesn't exist returns empty for pure-number queries.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("#9999") + assert len(results) == 0 + + def test_limit_respected_with_number_match(self, populated_db: Path): + """Limit parameter is respected even with number matches.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("1", limit=1) + assert len(results) <= 1 + + def test_number_lookup_deduplicates(self, populated_db: Path): + """Number match should not appear twice if also found by FTS.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("1") + urls = [r.url for r in results] + assert len(urls) == len(set(urls)), "Duplicate URLs in results" + # Number match should be first in results + if len(results) > 0: + assert ( + results[0].url == 
"https://github.com/hed-standard/hed-specification/issues/1" + ) + + class TestFTS5Sanitization: """Tests for FTS5 query sanitization to prevent injection.""" From b41995fabc56a477d41528b58df9b7e799d96434 Mon Sep 17 00:00:00 2001 From: "Seyed (Yahya) Shirazi" Date: Thu, 5 Feb 2026 05:50:25 -0800 Subject: [PATCH 21/26] fix: dashboard CORS origin, worker routes, and API auto-detect (#155) * fix: add dashboard CORS origin and missing worker routes Worker: - Add osa-dash.pages.dev to CORS allowlist (dashboard origin) - Add routes for dashboard endpoints: /metrics/public/overview, /metrics/{overview,tokens,quality}, /sync/{status,health}, /{community}/metrics/public, /{community}/sessions Dashboard: - Auto-detect API backend from hostname instead of falling back to window.location.origin (which fails for deployed dashboards) - osa-dash.pages.dev -> api.osc.earth/osa (prod) - *.osa-dash.pages.dev -> api.osc.earth/osa-dev (dev/preview) * fix: add osa-dash.pages.dev to backend CORS allowlist The dashboard (osa-dash.pages.dev) was being rejected by the backend's CORS middleware. Add both exact and wildcard patterns for the dashboard Pages project. 
* fix: harden worker security and improve dashboard errors Worker: - Split proxy into worker-key vs client-key passthrough modes - Admin/sessions endpoints now forward client key (not worker key) - Extract RESERVED_PATHS constant, rateLimitOrReject helper - Extract validateCommunityId with consistent reserved path checks - Add rate limiting to all endpoints including admin and sync - Validate https:// protocol on subdomain CORS checks - Only forward Origin header if CORS-validated - Replace bare catch blocks with console.warn logging Dashboard: - Add status-code-aware error messages (429, 404, 401, 500) - Replace generic "Failed to load" with actionable user guidance --- dashboard/osa/index.html | 32 ++++-- src/api/main.py | 6 +- workers/osa-worker/index.js | 220 ++++++++++++++++++++++++------------ 3 files changed, 176 insertions(+), 82 deletions(-) diff --git a/dashboard/osa/index.html b/dashboard/osa/index.html index c08122b..25c0234 100644 --- a/dashboard/osa/index.html +++ b/dashboard/osa/index.html @@ -287,13 +287,18 @@

Admin Access

// ----------------------------------------------------------------------- // Configuration // ----------------------------------------------------------------------- - // API base URL. Priority: ?api= query param > window.OSA_API_BASE > same origin. - // Same-origin default only works for local dev; deployed dashboards MUST set - // ?api= or window.OSA_API_BASE (e.g., https://api.osc.earth/osa). + // API base URL. Priority: ?api= query param > window.OSA_API_BASE > auto-detect. + // Auto-detection maps known dashboard hosts to their API backends. + // Falls back to same-origin for local dev (uvicorn serves both). const API_BASE = (function() { const params = new URLSearchParams(window.location.search); if (params.get('api')) return params.get('api').replace(/\/+$/, ''); if (window.OSA_API_BASE) return window.OSA_API_BASE.replace(/\/+$/, ''); + const host = window.location.hostname; + if (host === 'osa-dash.pages.dev' || host === 'status.osc.earth') + return 'https://api.osc.earth/osa'; + if (host.endsWith('.osa-dash.pages.dev')) + return 'https://api.osc.earth/osa-dev'; return window.location.origin; })(); @@ -303,6 +308,15 @@

Admin Access

return div.innerHTML; } + function userFriendlyError(status, context) { + if (!status) return `Unable to reach the API server. Check your connection or try again later.`; + if (status === 429) return `Rate limit exceeded. Please wait a moment and try again.`; + if (status === 404) return `${context || 'Resource'} not found. It may not exist or the URL may be incorrect.`; + if (status === 401 || status === 403) return `Access denied. An API key may be required.`; + if (status >= 500) return `The server encountered an error (${status}). Try again later.`; + return `Unexpected error (HTTP ${status}). Try again later.`; + } + // ----------------------------------------------------------------------- // State // ----------------------------------------------------------------------- @@ -337,8 +351,10 @@

Admin Access

const route = getRoute(); // Always fetch overview first to populate tabs + let overviewStatus = null; try { const resp = await fetch(`${API_BASE}/metrics/public/overview`); + overviewStatus = resp.status; if (!resp.ok) throw new Error(`HTTP ${resp.status}`); overviewData = await resp.json(); renderTabs(overviewData.communities, route.community); @@ -356,7 +372,8 @@

Admin Access

document.getElementById('adminCard').style.display = ''; } else { app.className = ''; - app.innerHTML = '
Failed to load metrics.
'; + const msg = userFriendlyError(overviewStatus, 'Metrics overview'); + app.innerHTML = `
${escapeHtml(msg)}
`; } } }); @@ -447,8 +464,8 @@

Communities

fetch(`${API_BASE}/sync/health`).catch(err => { console.warn('Health check fetch failed (non-critical):', err.message); return null; }), ]); - if (!summaryResp.ok) throw new Error(`HTTP ${summaryResp.status}`); - if (!usageResp.ok) throw new Error(`HTTP ${usageResp.status}`); + const failedStatus = !summaryResp.ok ? summaryResp.status : (!usageResp.ok ? usageResp.status : null); + if (failedStatus) throw { status: failedStatus }; const summary = await summaryResp.json(); const usage = await usageResp.json(); @@ -462,7 +479,8 @@

Communities

} catch (err) { console.error('Failed to load community view:', communityId, err); app.className = ''; - app.innerHTML = `
Failed to load data for ${safeName}: ${escapeHtml(err.message)}
`; + const msg = userFriendlyError(err.status || null, `Community "${safeName}"`); + app.innerHTML = `
${escapeHtml(msg)}
`; } } diff --git a/src/api/main.py b/src/api/main.py index 9d6566a..c2e0178 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -109,7 +109,7 @@ def _collect_cors_config() -> tuple[list[str], str | None]: Aggregates exact origins and wildcard patterns from: 1. Platform-level settings (Settings.cors_origins) 2. Per-community config (CommunityConfig.cors_origins) - 3. Default platform wildcard (*.osa-demo.pages.dev) + 3. Default platform wildcards (*.osa-demo.pages.dev, *.osa-dash.pages.dev) Returns: Tuple of (exact_origins, origin_regex_pattern). @@ -118,10 +118,12 @@ def _collect_cors_config() -> tuple[list[str], str | None]: settings = get_settings() exact_origins: list[str] = list(settings.cors_origins) - # Add default demo page origins + # Add default demo/dashboard page origins exact_origins.append("https://osa-demo.pages.dev") + exact_origins.append("https://osa-dash.pages.dev") wildcard_patterns: list[str] = [ "https://*.osa-demo.pages.dev", # Default: preview/branch deploys + "https://*.osa-dash.pages.dev", # Dashboard: preview/branch deploys ] # Collect from all registered communities diff --git a/workers/osa-worker/index.js b/workers/osa-worker/index.js index 0937998..52ac31b 100644 --- a/workers/osa-worker/index.js +++ b/workers/osa-worker/index.js @@ -9,6 +9,9 @@ * - BYOK mode for CLI/programmatic access */ +// Path segments that are actual routes, never valid community IDs +const RESERVED_PATHS = ['health', 'version', 'feedback', 'communities', 'metrics', 'sync']; + // Worker configuration function getConfig(env) { const isDev = env.ENVIRONMENT === 'development'; @@ -131,6 +134,20 @@ async function checkRateLimit(request, env, CONFIG) { return { allowed: true }; } +/** + * Check rate limit and return a 429 response if exceeded, or null if allowed. 
+ */ +async function rateLimitOrReject(request, env, corsHeaders, CONFIG) { + const rl = await checkRateLimit(request, env, CONFIG); + if (!rl.allowed) { + return new Response( + JSON.stringify({ error: 'Rate limit exceeded', details: rl.reason }), + { status: 429, headers: { ...corsHeaders, 'Content-Type': 'application/json' } } + ); + } + return null; +} + /** * Check if origin is allowed */ @@ -152,16 +169,20 @@ function isAllowedOrigin(origin) { // Check exact matches if (allowedPatterns.includes(origin)) return true; - // Check subdomains - if (origin.endsWith('.eeglab.org')) return true; - if (origin.endsWith('.github.io')) return true; - if (origin.endsWith('.hedtags.org')) return true; - if (origin.endsWith('.neuroimaging.io')) return true; - if (origin.endsWith('.readthedocs.io')) return true; + // Check subdomains (require https:// protocol) + if (origin.startsWith('https://') && origin.endsWith('.eeglab.org')) return true; + if (origin.startsWith('https://') && origin.endsWith('.github.io')) return true; + if (origin.startsWith('https://') && origin.endsWith('.hedtags.org')) return true; + if (origin.startsWith('https://') && origin.endsWith('.neuroimaging.io')) return true; + if (origin.startsWith('https://') && origin.endsWith('.readthedocs.io')) return true; // Allow osa-demo.pages.dev and all subdomains (previews, branches) if (origin === 'https://osa-demo.pages.dev') return true; - if (origin.endsWith('.osa-demo.pages.dev')) return true; + if (origin.startsWith('https://') && origin.endsWith('.osa-demo.pages.dev')) return true; + + // Allow osa-dash.pages.dev (dashboard) and all subdomains + if (origin === 'https://osa-dash.pages.dev') return true; + if (origin.startsWith('https://') && origin.endsWith('.osa-dash.pages.dev')) return true; // Allow localhost for development if (origin.startsWith('http://localhost:')) return true; @@ -193,9 +214,45 @@ function getCorsHeaders(origin) { } /** - * Proxy request to backend + * Validate community ID and 
return error response if invalid, or null if valid. + */ +function validateCommunityId(communityId, corsHeaders) { + if (RESERVED_PATHS.includes(communityId)) { + return new Response('Not Found', { status: 404, headers: corsHeaders }); + } + if (!isValidCommunityId(communityId)) { + return new Response(JSON.stringify({ error: 'Invalid community ID format' }), { + status: 400, + headers: { ...corsHeaders, 'Content-Type': 'application/json' }, + }); + } + return null; +} + +/** + * Proxy request to backend with the worker's own API key. + * Used for public endpoints and widget traffic where the worker + * authenticates on behalf of the client. */ async function proxyToBackend(request, env, path, body, corsHeaders, CONFIG) { + return _proxyToBackend(request, env, path, body, corsHeaders, CONFIG, 'worker'); +} + +/** + * Proxy request to backend, forwarding the client's own headers. + * Used for admin/authenticated endpoints where the client must provide + * their own API key. The worker does NOT inject its key. + */ +async function proxyToBackendPassthrough(request, env, path, body, corsHeaders, CONFIG) { + return _proxyToBackend(request, env, path, body, corsHeaders, CONFIG, 'client'); +} + +/** + * Internal proxy implementation. 
+ * + * @param {string} authMode - 'worker' to inject BACKEND_API_KEY, 'client' to forward client's X-API-Key + */ +async function _proxyToBackend(request, env, path, body, corsHeaders, CONFIG, authMode) { const backendUrl = env.BACKEND_URL; if (!backendUrl) { @@ -210,15 +267,23 @@ async function proxyToBackend(request, env, path, body, corsHeaders, CONFIG) { 'Content-Type': 'application/json', }; - // Add backend API key - if (env.BACKEND_API_KEY) { - backendHeaders['X-API-Key'] = env.BACKEND_API_KEY; + // Auth mode: inject worker key or forward client key + if (authMode === 'worker') { + if (env.BACKEND_API_KEY) { + backendHeaders['X-API-Key'] = env.BACKEND_API_KEY; + } + } else { + // Forward client's API key for backend to validate + const clientKey = request.headers.get('X-API-Key'); + if (clientKey) { + backendHeaders['X-API-Key'] = clientKey; + } } // Forward Origin header to backend for origin-based authorization checks - // Backend uses this to validate request source and apply origin-specific policies + // Only forward if the origin passed CORS validation const origin = request.headers.get('Origin'); - if (origin) { + if (origin && isAllowedOrigin(origin)) { backendHeaders['Origin'] = origin; } @@ -252,8 +317,8 @@ async function proxyToBackend(request, env, path, body, corsHeaders, CONFIG) { const text = await response.text(); backendError = { error: text.substring(0, 500) }; } - } catch { - // Use default error if parsing fails + } catch (parseErr) { + console.warn('Failed to parse backend error response:', parseErr.message); } return new Response(JSON.stringify(backendError), { @@ -339,58 +404,79 @@ export default { return await handleFeedback(request, env, corsHeaders, CONFIG); } + // --- Dashboard read-only endpoints (GET only, rate-limited) --- + + // Global public metrics: /metrics/public/overview + if (url.pathname === '/metrics/public/overview' && request.method === 'GET') { + const rejected = await rateLimitOrReject(request, env, corsHeaders, 
CONFIG); + if (rejected) return rejected; + return await proxyToBackend(request, env, '/metrics/public/overview', null, corsHeaders, CONFIG); + } + + // Admin metrics endpoints: client must provide their own API key + if (url.pathname.match(/^\/metrics\/(overview|tokens|quality)$/) && request.method === 'GET') { + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + const path = url.pathname + url.search; + return await proxyToBackendPassthrough(request, env, path, null, corsHeaders, CONFIG); + } + + // Sync status endpoints (public, read-only) + if ((url.pathname === '/sync/status' || url.pathname === '/sync/health') && request.method === 'GET') { + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + return await proxyToBackend(request, env, url.pathname, null, corsHeaders, CONFIG); + } + // Community config endpoint: /:communityId/ (GET) const communityConfigMatch = url.pathname.match(/^\/([^\/]+)\/?$/); if (communityConfigMatch && request.method === 'GET') { const communityId = communityConfigMatch[1]; - // Reject reserved path segments as community IDs - const reservedPaths = ['health', 'version', 'feedback', 'communities']; - if (reservedPaths.includes(communityId)) { - return new Response('Not Found', { status: 404, headers: corsHeaders }); - } - - // Validate community ID format - if (!isValidCommunityId(communityId)) { - return new Response(JSON.stringify({ error: 'Invalid community ID format' }), { - status: 400, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const invalid = validateCommunityId(communityId, corsHeaders); + if (invalid) return invalid; - // Rate limit community config lookups - const rateLimitResult = await checkRateLimit(request, env, CONFIG); - if (!rateLimitResult.allowed) { - return new Response(JSON.stringify({ - error: 'Rate limit exceeded', - details: rateLimitResult.reason, - }), { - 
status: 429, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; return await proxyToBackend(request, env, `/${communityId}/`, null, corsHeaders, CONFIG); } + // Community public metrics endpoints (GET) + const communityMetricsMatch = url.pathname.match(/^\/([^\/]+)\/(metrics\/public(?:\/usage)?)$/); + if (communityMetricsMatch && request.method === 'GET') { + const communityId = communityMetricsMatch[1]; + + const invalid = validateCommunityId(communityId, corsHeaders); + if (invalid) return invalid; + + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + + return await proxyToBackend(request, env, url.pathname, null, corsHeaders, CONFIG); + } + + // Community sessions endpoint (GET, authenticated -- forward client key) + const communitySessionsMatch = url.pathname.match(/^\/([^\/]+)\/sessions$/); + if (communitySessionsMatch && request.method === 'GET') { + const communityId = communitySessionsMatch[1]; + + const invalid = validateCommunityId(communityId, corsHeaders); + if (invalid) return invalid; + + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + + return await proxyToBackendPassthrough(request, env, url.pathname, null, corsHeaders, CONFIG); + } + // Community endpoints: /:communityId/ask and /:communityId/chat const communityActionMatch = url.pathname.match(/^\/([^\/]+)\/(ask|chat)$/); if (communityActionMatch && request.method === 'POST') { const [, communityId, action] = communityActionMatch; - // Reject reserved path segments as community IDs - const reservedPaths = ['health', 'version', 'feedback', 'communities']; - if (reservedPaths.includes(communityId)) { - return new Response('Not Found', { status: 404, headers: corsHeaders }); - } - - // Validate community ID format - if 
(!isValidCommunityId(communityId)) { - return new Response(JSON.stringify({ error: 'Invalid community ID format' }), { - status: 400, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const invalid = validateCommunityId(communityId, corsHeaders); + if (invalid) return invalid; return await handleProtectedEndpoint(request, env, ctx, `/${communityId}/${action}`, corsHeaders, CONFIG); } @@ -418,6 +504,11 @@ function handleRoot(corsHeaders, CONFIG) { 'GET /:communityId/': 'Get community configuration', 'POST /:communityId/ask': 'Ask a single question to a community', 'POST /:communityId/chat': 'Multi-turn conversation with a community', + 'GET /:communityId/metrics/public': 'Public community metrics', + 'GET /:communityId/sessions': 'List sessions (requires API key)', + 'GET /metrics/public/overview': 'Public metrics overview', + 'GET /metrics/overview': 'Admin metrics overview (requires API key)', + 'GET /sync/status': 'Knowledge sync status', 'POST /feedback': 'Submit feedback', 'GET /health': 'Health check', 'GET /version': 'Get API version', @@ -522,16 +613,8 @@ async function handleProtectedEndpoint(request, env, ctx, path, corsHeaders, CON } // Check rate limit - const rateLimitResult = await checkRateLimit(request, env, CONFIG); - if (!rateLimitResult.allowed) { - return new Response(JSON.stringify({ - error: 'Rate limit exceeded', - details: rateLimitResult.reason, - }), { - status: 429, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; // Remove Turnstile token from body before forwarding const { cf_turnstile_response, ...cleanBody } = body; @@ -543,17 +626,8 @@ async function handleProtectedEndpoint(request, env, ctx, path, corsHeaders, CON * Handle feedback endpoint (rate limited but no Turnstile) */ async function handleFeedback(request, env, corsHeaders, CONFIG) { - // Check rate limit - 
const rateLimitResult = await checkRateLimit(request, env, CONFIG); - if (!rateLimitResult.allowed) { - return new Response(JSON.stringify({ - error: 'Rate limit exceeded', - details: rateLimitResult.reason, - }), { - status: 429, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; const body = await request.json(); return await proxyToBackend(request, env, '/feedback', body, corsHeaders, CONFIG); From 09f651ad7283f4b640379bb3ec0906038f0af0a2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 5 Feb 2026 13:50:35 +0000 Subject: [PATCH 22/26] chore: sync worker CORS from community configs [skip ci] --- workers/osa-worker/index.js | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/workers/osa-worker/index.js b/workers/osa-worker/index.js index 52ac31b..dae2e6a 100644 --- a/workers/osa-worker/index.js +++ b/workers/osa-worker/index.js @@ -169,20 +169,16 @@ function isAllowedOrigin(origin) { // Check exact matches if (allowedPatterns.includes(origin)) return true; - // Check subdomains (require https:// protocol) - if (origin.startsWith('https://') && origin.endsWith('.eeglab.org')) return true; - if (origin.startsWith('https://') && origin.endsWith('.github.io')) return true; - if (origin.startsWith('https://') && origin.endsWith('.hedtags.org')) return true; - if (origin.startsWith('https://') && origin.endsWith('.neuroimaging.io')) return true; - if (origin.startsWith('https://') && origin.endsWith('.readthedocs.io')) return true; + // Check subdomains + if (origin.endsWith('.eeglab.org')) return true; + if (origin.endsWith('.github.io')) return true; + if (origin.endsWith('.hedtags.org')) return true; + if (origin.endsWith('.neuroimaging.io')) return true; + if (origin.endsWith('.readthedocs.io')) return true; // Allow osa-demo.pages.dev and all subdomains (previews, branches) if 
(origin === 'https://osa-demo.pages.dev') return true; - if (origin.startsWith('https://') && origin.endsWith('.osa-demo.pages.dev')) return true; - - // Allow osa-dash.pages.dev (dashboard) and all subdomains - if (origin === 'https://osa-dash.pages.dev') return true; - if (origin.startsWith('https://') && origin.endsWith('.osa-dash.pages.dev')) return true; + if (origin.endsWith('.osa-demo.pages.dev')) return true; // Allow localhost for development if (origin.startsWith('http://localhost:')) return true; From 006a130d67254398f2f7aaa03e1a6b95c0a7aefc Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Thu, 5 Feb 2026 05:51:41 -0800 Subject: [PATCH 23/26] bump: 0.5.5.dev0 -> 0.6.0.dev0 for dashboard release --- src/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/version.py b/src/version.py index 92084bd..21df2d0 100644 --- a/src/version.py +++ b/src/version.py @@ -1,7 +1,7 @@ """Version information for OSA.""" -__version__ = "0.5.5.dev0" -__version_info__ = (0, 5, 5) +__version__ = "0.6.0.dev0" +__version_info__ = (0, 6, 0, "dev") def get_version() -> str: From 11de1b63b3296679e7e312c643912ceecbd46ffa Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Thu, 5 Feb 2026 05:56:36 -0800 Subject: [PATCH 24/26] fix: apply worker security hardening lost in squash merge Re-apply changes from d12c98e that were listed in the squash merge commit message but not included in the actual diff: - Split proxy into worker-key vs client-key passthrough modes - Admin/sessions endpoints forward client key (not worker key) - Extract RESERVED_PATHS, rateLimitOrReject, validateCommunityId - Rate limiting on all endpoints including admin and sync - Validate https:// protocol on subdomain CORS checks - Only forward Origin header if CORS-validated - Replace bare catch blocks with console.warn logging --- workers/osa-worker/index.js | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/workers/osa-worker/index.js 
b/workers/osa-worker/index.js index dae2e6a..52ac31b 100644 --- a/workers/osa-worker/index.js +++ b/workers/osa-worker/index.js @@ -169,16 +169,20 @@ function isAllowedOrigin(origin) { // Check exact matches if (allowedPatterns.includes(origin)) return true; - // Check subdomains - if (origin.endsWith('.eeglab.org')) return true; - if (origin.endsWith('.github.io')) return true; - if (origin.endsWith('.hedtags.org')) return true; - if (origin.endsWith('.neuroimaging.io')) return true; - if (origin.endsWith('.readthedocs.io')) return true; + // Check subdomains (require https:// protocol) + if (origin.startsWith('https://') && origin.endsWith('.eeglab.org')) return true; + if (origin.startsWith('https://') && origin.endsWith('.github.io')) return true; + if (origin.startsWith('https://') && origin.endsWith('.hedtags.org')) return true; + if (origin.startsWith('https://') && origin.endsWith('.neuroimaging.io')) return true; + if (origin.startsWith('https://') && origin.endsWith('.readthedocs.io')) return true; // Allow osa-demo.pages.dev and all subdomains (previews, branches) if (origin === 'https://osa-demo.pages.dev') return true; - if (origin.endsWith('.osa-demo.pages.dev')) return true; + if (origin.startsWith('https://') && origin.endsWith('.osa-demo.pages.dev')) return true; + + // Allow osa-dash.pages.dev (dashboard) and all subdomains + if (origin === 'https://osa-dash.pages.dev') return true; + if (origin.startsWith('https://') && origin.endsWith('.osa-dash.pages.dev')) return true; // Allow localhost for development if (origin.startsWith('http://localhost:')) return true; From 2b1fe729bb4c369677505a1257f429cc54b5ec0a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 5 Feb 2026 13:56:56 +0000 Subject: [PATCH 25/26] chore: sync worker CORS from community configs [skip ci] --- workers/osa-worker/index.js | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/workers/osa-worker/index.js b/workers/osa-worker/index.js 
index 52ac31b..dae2e6a 100644 --- a/workers/osa-worker/index.js +++ b/workers/osa-worker/index.js @@ -169,20 +169,16 @@ function isAllowedOrigin(origin) { // Check exact matches if (allowedPatterns.includes(origin)) return true; - // Check subdomains (require https:// protocol) - if (origin.startsWith('https://') && origin.endsWith('.eeglab.org')) return true; - if (origin.startsWith('https://') && origin.endsWith('.github.io')) return true; - if (origin.startsWith('https://') && origin.endsWith('.hedtags.org')) return true; - if (origin.startsWith('https://') && origin.endsWith('.neuroimaging.io')) return true; - if (origin.startsWith('https://') && origin.endsWith('.readthedocs.io')) return true; + // Check subdomains + if (origin.endsWith('.eeglab.org')) return true; + if (origin.endsWith('.github.io')) return true; + if (origin.endsWith('.hedtags.org')) return true; + if (origin.endsWith('.neuroimaging.io')) return true; + if (origin.endsWith('.readthedocs.io')) return true; // Allow osa-demo.pages.dev and all subdomains (previews, branches) if (origin === 'https://osa-demo.pages.dev') return true; - if (origin.startsWith('https://') && origin.endsWith('.osa-demo.pages.dev')) return true; - - // Allow osa-dash.pages.dev (dashboard) and all subdomains - if (origin === 'https://osa-dash.pages.dev') return true; - if (origin.startsWith('https://') && origin.endsWith('.osa-dash.pages.dev')) return true; + if (origin.endsWith('.osa-demo.pages.dev')) return true; // Allow localhost for development if (origin.startsWith('http://localhost:')) return true; From 0fd17a4c012d98a6bfcb76c8b65acba7e7a47c9d Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Thu, 5 Feb 2026 06:44:39 -0800 Subject: [PATCH 26/26] fix: dashboard total counts all requests, missing communities - Total now counts only community-scoped requests, not health checks, sync, and other infrastructure endpoints - All registered communities appear in overview even with 0 requests - Middleware 
_extract_community_id now recognizes metrics, sessions, and config endpoints (not just /ask and /chat) --- src/api/routers/metrics_public.py | 5 ++- src/metrics/middleware.py | 22 ++++++++++--- src/metrics/queries.py | 47 +++++++++++++++++++++------ tests/test_api/test_metrics_public.py | 11 ++++--- tests/test_metrics/test_middleware.py | 23 +++++++------ 5 files changed, 77 insertions(+), 31 deletions(-) diff --git a/src/api/routers/metrics_public.py b/src/api/routers/metrics_public.py index 2aaf1f9..4e38f34 100644 --- a/src/api/routers/metrics_public.py +++ b/src/api/routers/metrics_public.py @@ -13,6 +13,7 @@ from fastapi import APIRouter, HTTPException +from src.assistants import registry from src.metrics.db import metrics_connection from src.metrics.queries import get_public_overview @@ -27,10 +28,12 @@ async def public_overview() -> dict[str, Any]: Returns total requests, error rate, active community count, and per-community request counts. No tokens, costs, or model info. + All registered communities appear even if they have zero requests. 
""" + registered = [info.id for info in registry.list_available()] try: with metrics_connection() as conn: - return get_public_overview(conn) + return get_public_overview(conn, registered_communities=registered) except sqlite3.Error: logger.exception("Failed to query metrics database for public overview") raise HTTPException( diff --git a/src/metrics/middleware.py b/src/metrics/middleware.py index 92fbf17..174b034 100644 --- a/src/metrics/middleware.py +++ b/src/metrics/middleware.py @@ -21,14 +21,28 @@ logger = logging.getLogger(__name__) -# Path segments that indicate a community route -_COMMUNITY_ENDPOINTS = {"/ask", "/chat"} +# Path suffixes that indicate a community-scoped route +_COMMUNITY_SUFFIXES = {"ask", "chat", "sessions", "metrics", "config"} + +# Top-level path prefixes that are NOT community IDs +_RESERVED_PREFIXES = {"health", "metrics", "sync", "docs", "redoc", "openapi.json", "frontend"} def _extract_community_id(path: str) -> str | None: - """Extract community_id from URL path like /{community_id}/ask.""" + """Extract community_id from community-scoped URL paths. + + Matches patterns like /{community_id}/ask, /{community_id}/metrics/public, + /{community_id}/sessions, etc. Returns None for non-community routes + (health, sync, global metrics, docs). 
+ """ parts = path.strip("/").split("/") - if len(parts) >= 2 and f"/{parts[1]}" in _COMMUNITY_ENDPOINTS: + if not parts or not parts[0]: + return None + # Skip known non-community prefixes + if parts[0] in _RESERVED_PREFIXES: + return None + # Single segment (e.g., /health) or community with sub-path + if len(parts) >= 2 and parts[1] in _COMMUNITY_SUFFIXES: return parts[0] return None diff --git a/src/metrics/queries.py b/src/metrics/queries.py index aa5073e..8c2d7be 100644 --- a/src/metrics/queries.py +++ b/src/metrics/queries.py @@ -312,24 +312,32 @@ def get_token_breakdown( # --------------------------------------------------------------------------- -def get_public_overview(conn: sqlite3.Connection) -> dict[str, Any]: +def get_public_overview( + conn: sqlite3.Connection, + registered_communities: list[str] | None = None, +) -> dict[str, Any]: """Get public metrics overview with only non-sensitive data. Returns request counts and error rates; no tokens, costs, or model info. + Only counts community-scoped requests (not health checks, sync, etc.). Args: conn: SQLite connection. + registered_communities: If provided, ensures all these communities + appear in the response even if they have zero logged requests. Returns: Dict with total_requests, error_rate, communities_active, and per-community request counts. 
""" + # Only count community-scoped requests, not infrastructure endpoints totals = conn.execute( """ SELECT COUNT(*) as total_requests, COUNT(CASE WHEN status_code >= 400 THEN 1 END) as total_errors FROM request_log + WHERE community_id IS NOT NULL """ ).fetchone() @@ -348,18 +356,37 @@ def get_public_overview(conn: sqlite3.Connection) -> dict[str, Any]: """ ).fetchall() + # Build community data from DB results + community_data: dict[str, dict[str, Any]] = {} + for r in community_rows: + cid = r["community_id"] + community_data[cid] = { + "community_id": cid, + "requests": r["requests"], + "error_rate": round(r["errors"] / r["requests"], 4) if r["requests"] > 0 else 0.0, + } + + # Include registered communities with zero requests + if registered_communities: + for cid in registered_communities: + if cid not in community_data: + community_data[cid] = { + "community_id": cid, + "requests": 0, + "error_rate": 0.0, + } + + # Sort by requests descending, then alphabetically + communities_list = sorted( + community_data.values(), + key=lambda c: (-c["requests"], c["community_id"]), + ) + return { "total_requests": total_req, "error_rate": round(totals["total_errors"] / total_req, 4) if total_req > 0 else 0.0, - "communities_active": len(community_rows), - "communities": [ - { - "community_id": r["community_id"], - "requests": r["requests"], - "error_rate": round(r["errors"] / r["requests"], 4) if r["requests"] > 0 else 0.0, - } - for r in community_rows - ], + "communities_active": sum(1 for c in communities_list if c["requests"] > 0), + "communities": communities_list, } diff --git a/tests/test_api/test_metrics_public.py b/tests/test_api/test_metrics_public.py index 8f8f427..70066a6 100644 --- a/tests/test_api/test_metrics_public.py +++ b/tests/test_api/test_metrics_public.py @@ -337,7 +337,10 @@ def test_overview_empty_db(self, client): assert data["total_requests"] == 0 assert data["error_rate"] == 0.0 assert data["communities_active"] == 0 - assert 
data["communities"] == [] + # Registered communities still appear with zero requests + for c in data["communities"]: + assert c["requests"] == 0 + assert c["error_rate"] == 0.0 @pytest.mark.usefixtures("isolated_empty_metrics", "noauth_env") def test_community_metrics_empty_db(self, client): @@ -372,10 +375,10 @@ def test_community_values_from_overview(self, client): continue # community route not registered in test app detail = resp.json() assert detail["total_requests"] == community["requests"] - assert detail["total_requests"] > 0 assert detail["error_rate"] == community["error_rate"] - checked += 1 - assert checked > 0, "Expected at least one community with a registered route" + if community["requests"] > 0: + checked += 1 + assert checked > 0, "Expected at least one community with requests" @pytest.mark.usefixtures("isolated_metrics", "noauth_env") def test_per_community_tool_counts_consistent(self, client): diff --git a/tests/test_metrics/test_middleware.py b/tests/test_metrics/test_middleware.py index e4412e5..7f3e26c 100644 --- a/tests/test_metrics/test_middleware.py +++ b/tests/test_metrics/test_middleware.py @@ -26,18 +26,17 @@ def test_returns_none_for_non_community(self): def test_returns_none_for_root(self): assert _extract_community_id("/") is None - def test_returns_none_for_community_metrics_paths(self): - """Community metrics/session paths intentionally return None. - - These endpoints handle community_id directly in route handlers - rather than through middleware extraction. 
- """ - assert _extract_community_id("/hed/metrics") is None - assert _extract_community_id("/hed/metrics/public") is None - assert _extract_community_id("/hed/metrics/usage") is None - assert _extract_community_id("/hed/metrics/quality") is None - assert _extract_community_id("/hed/sessions") is None - assert _extract_community_id("/hed/config") is None + def test_extracts_from_community_metrics_paths(self): + """Community metrics/session/config paths return the community ID.""" + assert _extract_community_id("/hed/metrics") == "hed" + assert _extract_community_id("/hed/metrics/public") == "hed" + assert _extract_community_id("/hed/sessions") == "hed" + assert _extract_community_id("/hed/config") == "hed" + + def test_returns_none_for_deep_community_paths(self): + """Paths under community that don't match known suffixes return None.""" + assert _extract_community_id("/hed/unknown") is None + assert _extract_community_id("/hed/something/else") is None class TestMetricsMiddleware: