diff --git a/.context/local-testing-guide.md b/.context/local-testing-guide.md index 6c57954..1cfbaf4 100644 --- a/.context/local-testing-guide.md +++ b/.context/local-testing-guide.md @@ -1,257 +1,81 @@ -# Local Testing Guide for EEGLAB Assistant +# Local Testing Guide for Community Assistants -## Quick Test (Verify Configuration) +Quick reference for testing any community assistant locally. For the full guide, see https://docs.osc.earth/osa/registry/local-testing/ -```bash -cd /Users/yahya/Documents/git/osa-phase1 +## Quick Validation -# Run quick verification -uv run python test_eeglab_interactive.py +```bash +# Validate config loads +uv run pytest tests/test_core/ -k "community" -v + +# Or programmatically +uv run python -c " +from pathlib import Path +from src.core.config.community import CommunityConfig +config = CommunityConfig.from_yaml(Path('src/assistants/COMMUNITY_ID/config.yaml')) +print(f'Loaded: {config.name} with {len(config.documentation)} docs') +" ``` -## Full Backend Testing - -### 1. Set Environment Variables +## Environment Variables ```bash -# Required: OpenRouter API key for LLM export OPENROUTER_API_KEY="your-key-here" - -# Optional: API keys for admin functions (sync) +# Optional: for sync operations export API_KEYS="test-key-123" - -# Optional: Specific EEGLAB key (if community has BYOK) -# export OPENROUTER_API_KEY_EEGLAB="eeglab-specific-key" +# Optional: community-specific key +# export OPENROUTER_API_KEY_COMMUNITY="key" ``` -### 2. Start Backend Server +## Server Testing ```bash -cd /Users/yahya/Documents/git/osa-phase1 - -# Start development server +# Start dev server uv run uvicorn src.api.main:app --reload --port 38528 -``` - -Server will be available at: `http://localhost:38528` - -### 3. Test Endpoints -#### A. List All Communities - -```bash +# List communities (verify yours appears) curl http://localhost:38528/communities | jq -``` - -**Expected response:** -```json -{ - "communities": [ - { - "id": "eeglab", - "name": "EEGLAB", - "description": "EEG signal processing and analysis toolbox", - "status": "available" - }, - { - "id": "hed", - "name": "HED (Hierarchical Event Descriptors)", - ... - } - ] -} -``` - -#### B. Get EEGLAB Community Info -```bash -curl http://localhost:38528/communities/eeglab | jq -``` - -**Expected response:** -```json -{ - "id": "eeglab", - "name": "EEGLAB", - "description": "EEG signal processing and analysis toolbox", - "status": "available", - "documentation_count": 26, - "github_repos": 6, - "has_sync_config": true -} -``` - -#### C. Ask a Question (Simple) - -```bash -curl -X POST http://localhost:38528/eeglab/ask \ - -H "Content-Type: application/json" \ - -d '{ - "question": "What is EEGLAB?", - "api_key": "your-openrouter-key" - }' | jq -``` - -**Expected response:** -```json -{ - "answer": "EEGLAB is an interactive MATLAB toolbox...", - "sources": [ - { - "title": "EEGLAB quickstart", - "url": "https://sccn.github.io/..." - } - ] -} -``` - -#### D. Ask About ICA - -```bash -curl -X POST http://localhost:38528/eeglab/ask \ +# Ask a question +curl -X POST http://localhost:38528/COMMUNITY_ID/ask \ -H "Content-Type: application/json" \ - -d '{ - "question": "How do I run ICA in EEGLAB?", - "api_key": "your-openrouter-key" - }' | jq + -d '{"question": "What is this tool?", "api_key": "your-key"}' | jq ``` -**Should mention:** ICA decomposition, ICLabel, artifact removal - -#### E. 
Test Chat Endpoint +## CLI Testing (No Server Needed) ```bash -curl -X POST http://localhost:38528/eeglab/chat \ - -H "Content-Type: application/json" \ - -d '{ - "messages": [ - {"role": "user", "content": "What preprocessing steps should I do?"} - ], - "api_key": "your-openrouter-key" - }' | jq -``` - -**Should mention:** Filtering, re-referencing, ICA, artifact removal +# Interactive chat +uv run osa chat --community COMMUNITY_ID --standalone -### 4. Test Documentation Retrieval - -The assistant should automatically retrieve docs. Test by asking specific questions: - -```bash -# Should trigger retrieve_eeglab_docs tool -curl -X POST http://localhost:38528/eeglab/ask \ - -H "Content-Type: application/json" \ - -d '{ - "question": "How do I filter my EEG data in EEGLAB?", - "api_key": "your-openrouter-key", - "stream": false - }' | jq '.tool_calls' +# Single question +uv run osa ask --community COMMUNITY_ID "What is this tool?" --standalone ``` -**Expected:** Should call `retrieve_eeglab_docs` with filter-related docs - -### 5. Test via CLI (Easier!) +## Knowledge Sync ```bash -cd /Users/yahya/Documents/git/osa-phase1 - -# Set API key -export OPENROUTER_API_KEY="your-key-here" - -# Start interactive chat -uv run osa chat --community eeglab --standalone - -# Or ask single question -uv run osa ask --community eeglab "What is EEGLAB?" --standalone +uv run osa sync init --community COMMUNITY_ID +uv run osa sync github --community COMMUNITY_ID --full +uv run osa sync papers --community COMMUNITY_ID --citations ``` -**CLI is easier for testing because:** -- Handles API key automatically -- Shows formatted output -- Interactive mode for multi-turn conversations +## Test Checklist -## Test Questions for EEGLAB - -Good test questions to verify configuration: - -1. **Basic Info:** - - "What is EEGLAB?" - - "What can EEGLAB do?" - -2. **Preprocessing:** - - "What preprocessing steps should I do?" - - "How do I filter EEG data?" - - "How do I re-reference my data?" - -3. **ICA:** - - "How do I run ICA in EEGLAB?" - - "What is ICLabel?" - - "How do I remove artifacts with ICA?" - -4. **Plugins:** - - "What is clean_rawdata?" - - "How do I use ASR?" - - "What is the PREP pipeline?" - -5. **Integration:** - - "How do I use EEGLAB with BIDS?" - - "Can I use EEGLAB with Python?" - -6. **Knowledge Base (requires sync):** - - "What are the latest issues in the eeglab repo?" 
- - "Show me recent PRs in ICLabel" - - "Papers about EEGLAB ICA" +- [ ] Config validates without errors +- [ ] Community appears in `/communities` +- [ ] `/ask` endpoint returns relevant answers +- [ ] `/chat` endpoint works for multi-turn +- [ ] Preloaded docs are in context +- [ ] On-demand docs retrieved when relevant +- [ ] Documentation URLs in responses are valid +- [ ] CLI standalone mode works +- [ ] Knowledge sync completes (if configured) +- [ ] Assistant does not hallucinate PR/issue numbers ## Troubleshooting -### Server won't start - -```bash -# Check if port is already in use -lsof -i :38528 - -# Use different port -uv run uvicorn src.api.main:app --reload --port 38529 -``` - -### "Assistant not found" error - -```bash -# Verify EEGLAB is registered -uv run python -c "from src.assistants import discover_assistants, registry; discover_assistants(); print('eeglab' in registry)" -``` - -### Documentation not retrieved - -- Check that `retrieve_eeglab_docs` tool is available -- Check network access to sccn.github.io -- Check tool calls in response - -### Knowledge base empty - -- Knowledge base requires `API_KEYS` env var for sync -- Run sync locally: - ```bash - export API_KEYS="test-key" - uv run osa sync init --community eeglab - uv run osa sync github --community eeglab --full - ``` - -## Expected Behavior - -**What works without knowledge sync:** -- ✓ Assistant creation -- ✓ System prompt -- ✓ Documentation retrieval (fetches from URLs) -- ✓ Answering questions about EEGLAB -- ✓ Providing guidance on workflows - -**What needs knowledge sync:** -- ✗ Searching GitHub issues/PRs -- ✗ Listing recent activity -- ✗ Searching papers -- ✗ Citation counts - -## Next: Epic Branch Workflow - -See `epic-branch-workflow.md` for multi-phase development process. 
+- **Server won't start**: Check port with `lsof -i :38528`
+- **Assistant not found**: Check discovery with `uv run python -c "from src.assistants import discover_assistants, registry; discover_assistants(); print([a.id for a in registry.list_available()])"`
+- **Docs not retrieved**: Test source_url with `curl -I <source_url>`
+- **Knowledge empty**: Run `uv run osa sync init --community COMMUNITY_ID` first
diff --git a/.github/workflows/deploy-dashboard.yml b/.github/workflows/deploy-dashboard.yml
new file mode 100644
index 0000000..0f62b97
--- /dev/null
+++ b/.github/workflows/deploy-dashboard.yml
@@ -0,0 +1,105 @@
+name: Deploy Dashboard to Cloudflare Pages
+
+on:
+  push:
+    branches:
+      - main
+      - develop
+    paths:
+      - 'dashboard/**'
+      - '.github/workflows/deploy-dashboard.yml'
+  pull_request:
+    types: [opened, synchronize, reopened]
+    paths:
+      - 'dashboard/**'
+      - '.github/workflows/deploy-dashboard.yml'
+  workflow_dispatch:
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: read
+      deployments: write
+      pull-requests: write
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Get branch name
+        id: branch
+        run: |
+          if [ "${{ github.event_name }}" = "pull_request" ]; then
+            BRANCH_NAME="${{ github.head_ref }}"
+          else
+            BRANCH_NAME="${{ github.ref_name }}"
+          fi
+          # Sanitize branch name for URL (replace / with -, lowercase)
+          SANITIZED=$(echo "$BRANCH_NAME" | tr '/' '-' | tr '[:upper:]' '[:lower:]')
+          # Cloudflare Pages truncates branch names to ~28 chars for preview URLs
+          TRUNCATED=$(echo "$SANITIZED" | cut -c1-28)
+          echo "name=$TRUNCATED" >> $GITHUB_OUTPUT
+          echo "original=$BRANCH_NAME" >> $GITHUB_OUTPUT
+
+      - name: Deploy to Cloudflare Pages
+        id: deploy
+        uses: cloudflare/wrangler-action@v3
+        with:
+          apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
+          accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
+          # Deploy with branch name so Cloudflare Pages routes correctly:
+          # - main branch -> osa-dash.pages.dev (production), custom domain: status.osc.earth
+          # - develop branch -> develop.osa-dash.pages.dev (preview)
+          # - PR branches -> {branch}.osa-dash.pages.dev (preview)
+          # Dashboard is served at /osa/ path (e.g., status.osc.earth/osa/)
+          command: pages deploy dashboard --project-name=osa-dash --branch=${{ steps.branch.outputs.name }}
+
+      - name: Comment on PR with preview URL
+        if: github.event_name == 'pull_request'
+        uses: actions/github-script@v7
+        with:
+          script: |
+            const branch = '${{ steps.branch.outputs.name }}';
+            const previewUrl = `https://${branch}.osa-dash.pages.dev`;
+
+            // Find existing comment from this workflow (identified by marker)
+            const { data: comments } = await github.rest.issues.listComments({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              issue_number: context.issue.number,
+            });
+
+            const botComment = comments.find(comment =>
+              comment.user.type === 'Bot' &&
+              comment.body.includes('<!-- dashboard-preview -->')
+            );
+
+            const sha = '${{ github.sha }}'.substring(0, 7);
+            const body = [
+              '<!-- dashboard-preview -->',
+              '## Dashboard Preview',
+              '',
+              '| Name | Link |',
+              '|------|------|',
+              `| **Preview URL** | ${previewUrl} |`,
+              '| **Branch** | \`${{ steps.branch.outputs.original }}\` |',
+              `| **Commit** | \`${sha}\` |`,
+              '',
+              'This preview will be updated automatically when you push new commits.'
+ ].join('\n'); + + if (botComment) { + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: botComment.id, + body: body + }); + } else { + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: body + }); + } diff --git a/.github/workflows/sync-worker-cors.yml b/.github/workflows/sync-worker-cors.yml index 3d364aa..4e65c0d 100644 --- a/.github/workflows/sync-worker-cors.yml +++ b/.github/workflows/sync-worker-cors.yml @@ -7,10 +7,13 @@ on: - 'src/assistants/*/config.yaml' - 'scripts/sync_worker_cors.py' - '.github/workflows/sync-worker-cors.yml' + - 'workers/osa-worker/**' pull_request: branches: [main, develop] paths: - 'src/assistants/*/config.yaml' + - 'workers/osa-worker/**' + workflow_dispatch: permissions: contents: write @@ -23,6 +26,7 @@ jobs: - uses: actions/checkout@v4 with: token: ${{ secrets.CI_ADMIN_TOKEN }} + fetch-depth: 0 # Fetch full history so git diff works - name: Set up Python uses: actions/setup-python@v5 @@ -40,16 +44,50 @@ jobs: - name: Check for changes id: check_changes run: | + # Check if CORS sync made changes if git diff --quiet workers/osa-worker/index.js; then - echo "changed=false" >> $GITHUB_OUTPUT echo "No CORS changes detected" + CORS_CHANGED=false else - echo "changed=true" >> $GITHUB_OUTPUT echo "CORS changes detected" + CORS_CHANGED=true + fi + + # Check if worker files were modified in the push that triggered this workflow + # Use the range from the push event (before -> after) + if [ "${{ github.event.before }}" != "0000000000000000000000000000000000000000" ]; then + # Normal push (not first commit) + if git diff --name-only ${{ github.event.before }} ${{ github.sha }} | grep -q "^workers/osa-worker/"; then + echo "Worker files changed in push" + WORKER_CHANGED=true + else + echo "No worker files in push" + WORKER_CHANGED=false + fi + else + # First commit to branch, check current files + if git ls-files workers/osa-worker/ | grep -q .; then + echo "Worker files exist (first commit)" + WORKER_CHANGED=true + else + WORKER_CHANGED=false + fi + fi + + # Set outputs + echo "cors_changed=$CORS_CHANGED" >> $GITHUB_OUTPUT + + # Deploy if either CORS sync changed files OR worker files were pushed + if [ "$CORS_CHANGED" = true ] || [ "$WORKER_CHANGED" = true ]; then + echo "changed=true" >> $GITHUB_OUTPUT + echo "✅ Deployment needed" + else + echo "changed=false" >> $GITHUB_OUTPUT + echo "No deployment needed" fi - - name: Commit changes (main/develop only) - if: steps.check_changes.outputs.changed == 'true' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop') && github.event_name == 'push' + - name: Commit CORS changes (main/develop only) + if: steps.check_changes.outputs.cors_changed == 'true' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop') && github.event_name == 'push' run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" @@ -64,7 +102,7 @@ jobs: run: | npm install -g wrangler cd workers/osa-worker - wrangler deploy --env="" + wrangler deploy echo "✅ Deployed to production worker" - name: Deploy to Cloudflare Workers (dev) @@ -82,9 +120,16 @@ jobs: uses: actions/github-script@v7 with: script: | + let message = '⚠️ **Worker Deployment Required**\n\n'; + if ('${{ steps.check_changes.outputs.cors_changed }}' === 'true') { + message += 'This PR modifies community CORS origins. 
'; + } + message += 'Worker changes detected. After merging to `main` or `develop`, the workflow will automatically deploy the worker.\n\n'; + message += '**Manual deployment (if needed):**\n```bash\ncd workers/osa-worker\nwrangler deploy --env dev # for develop branch\nwrangler deploy # for main branch\n```'; + github.rest.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, - body: '⚠️ **Worker CORS Update Required**\n\nThis PR modifies community CORS origins. After merging, run:\n\n```bash\npython scripts/sync_worker_cors.py\ngit add workers/osa-worker/index.js\ngit commit -m "chore: sync worker CORS"\ncd workers/osa-worker && wrangler deploy\n```\n\nOr merge to main and the sync workflow will auto-commit the changes.' + body: message }) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 371565f..7e8e946 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,7 +4,6 @@ on: push: branches: [main, develop] pull_request: - branches: [main, develop] jobs: lint: diff --git a/CLAUDE.md b/CLAUDE.md index 5e043a5..0740af4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -146,9 +146,16 @@ src/ - **.context/api_key_authorization_design.md** - API key auth and CORS - **.context/security-architecture.md** - Security patterns -### Community Configuration -- **.context/yaml_registry.md** - YAML-based community config -- **.context/community_onboarding_review.md** - Onboarding guidelines +### Community Development (Adding/Modifying Communities) +- **Full docs site**: https://docs.osc.earth/osa/registry/ (canonical reference) + - [Adding a Community](https://docs.osc.earth/osa/registry/quick-start/) - Step-by-step guide + - [Local Testing](https://docs.osc.earth/osa/registry/local-testing/) - Testing a new community end-to-end + - [Schema Reference](https://docs.osc.earth/osa/registry/schema-reference/) - Full YAML config schema + - [Extensions](https://docs.osc.earth/osa/registry/extensions/) - Python plugins and MCP servers +- **.context/yaml_registry.md** - YAML-based community config (internal notes) +- **.context/community_onboarding_review.md** - Onboarding gap analysis +- **.context/local-testing-guide.md** - Quick local testing reference +- **Existing configs to reference**: `src/assistants/hed/config.yaml`, `src/assistants/eeglab/config.yaml` ### Tool System - **.context/tool-system-guide.md** - How tools work and are registered @@ -239,6 +246,38 @@ docker exec osa-dev python -m src.cli.main sync github --full - Dev: https://develop.osa-demo.pages.dev - Prod: https://osa-demo.pages.dev +### Inspecting Knowledge Databases + +Knowledge databases (SQLite) live **inside the Docker containers**, not locally. +Do NOT look for `.db` files in the local repo; they won't be there. 
+
+```bash
+# List databases in a container
+ssh -o "RequestTTY=no" -J hallu hedtools \
+  "docker exec osa find /app/data/knowledge -name '*.db'"
+
+# Containers: osa (prod), osa-dev (dev)
+# Database paths: /app/data/knowledge/{community_id}.db
+# e.g., /app/data/knowledge/eeglab.db, /app/data/knowledge/hed.db
+
+# List tables (no sqlite3 binary; use python)
+ssh -o "RequestTTY=no" -J hallu hedtools \
+  "docker exec osa python3 -c 'import sqlite3; conn = sqlite3.connect(\"/app/data/knowledge/eeglab.db\"); print([r[0] for r in conn.execute(\"SELECT name FROM sqlite_master WHERE type=\\\"table\\\"\")]); conn.close()'"
+
+# Query example: count docstrings
+ssh -o "RequestTTY=no" -J hallu hedtools \
+  "docker exec osa python3 -c 'import sqlite3; conn = sqlite3.connect(\"/app/data/knowledge/eeglab.db\"); print(conn.execute(\"SELECT COUNT(*) FROM docstrings\").fetchone()[0]); conn.close()'"
+
+# Query example: search for a symbol
+ssh -o "RequestTTY=no" -J hallu hedtools \
+  "docker exec osa python3 -c 'import sqlite3; conn = sqlite3.connect(\"/app/data/knowledge/eeglab.db\"); [print(r) for r in conn.execute(\"SELECT symbol_name, file_path FROM docstrings WHERE symbol_name LIKE \\\"%erpimage%\\\"\").fetchall()]; conn.close()'"
+```
+
+**Important notes:**
+- `sqlite3` CLI is not installed in containers; use `python3 -c` with the `sqlite3` module
+- Use `ssh -o "RequestTTY=no"` to avoid interactive shell banners
+- Dev and prod databases may differ; always check the right container
+
 ## References
 
 - **API structure**: `.context/api-structure.md` (read first for API work)
diff --git a/dashboard/_redirects b/dashboard/_redirects
new file mode 100644
index 0000000..57ede61
--- /dev/null
+++ b/dashboard/_redirects
@@ -0,0 +1,3 @@
+/ /osa/ 301
+/osa /osa/index.html 200
+/osa/* /osa/index.html 200
diff --git a/dashboard/osa/index.html b/dashboard/osa/index.html
new file mode 100644
index 0000000..25c0234
--- /dev/null
+++ b/dashboard/osa/index.html
@@ -0,0 +1,784 @@
+<!-- 784 added lines: static dashboard page titled "Open Science Assistant - Dashboard".
+     Markup lost in extraction; only the "Loading..." placeholder text survived. -->
diff --git a/frontend/index.html b/frontend/index.html
index e93573a..3db1630 100644
--- a/frontend/index.html
+++ b/frontend/index.html
@@ -188,7 +188,7 @@
             suggestedQuestions: [
                 'What is BIDS and why should I use it?',
                 'How do I organize my EEG data in BIDS?',
-                'What are the required metadata fields?',
+                'What are the BIDS Common Principles?',
                 'How do I validate my BIDS dataset?'
             ]
         }
diff --git a/src/api/config.py b/src/api/config.py
index 77df6be..6de9bf3 100644
--- a/src/api/config.py
+++ b/src/api/config.py
@@ -1,5 +1,6 @@
 """Configuration management for the OSA API."""
 
+import logging
 from functools import lru_cache
 
 from pydantic import Field
@@ -7,6 +8,8 @@
 from src.version import __version__
 
+logger = logging.getLogger(__name__)
+
 
 class Settings(BaseSettings):
     """Application settings loaded from environment variables."""
@@ -56,6 +59,13 @@ class Settings(BaseSettings):
     )
     require_api_auth: bool = Field(default=True, description="Require API key authentication")
 
+    # Per-community admin keys for scoped dashboard access
+    # Format: "community_id:key1,community_id:key2" (e.g., "hed:abc123,eeglab:xyz789")
+    community_admin_keys: str | None = Field(
+        default=None,
+        description="Per-community admin API keys (format: community_id:key,...)",
+    )
+
     # LLM Provider Settings (server defaults, can be overridden by BYOK)
     openrouter_api_key: str | None = Field(default=None, description="OpenRouter API key")
     openai_api_key: str | None = Field(default=None, description="OpenAI API key")
@@ -127,6 +137,42 @@ class Settings(BaseSettings):
         description="Cron schedule for papers sync (default: weekly Sunday at 3am UTC)",
     )
 
+    def parse_admin_keys(self) -> set[str]:
+        """Parse API_KEYS into a set of valid admin keys.
+
+        Returns:
+            Set of valid API key strings.
+        """
+        if not self.api_keys:
+            return set()
+        return {k.strip() for k in self.api_keys.split(",") if k.strip()}
+
+    def parse_community_admin_keys(self) -> dict[str, set[str]]:
+        """Parse COMMUNITY_ADMIN_KEYS into {community_id: {keys}} mapping.
+
+        Format: "community_id:key1,community_id:key2"
+        Multiple keys per community are supported.
+
+        Returns:
+            Dict mapping community_id to set of valid API keys.
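+
+        Example (illustrative keys, matching the field description above):
+            "hed:abc123,eeglab:xyz789,hed:def456" parses to
+            {"hed": {"abc123", "def456"}, "eeglab": {"xyz789"}}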
+ """ + if not self.community_admin_keys: + return {} + result: dict[str, set[str]] = {} + for entry in self.community_admin_keys.split(","): + entry = entry.strip() + if not entry: + continue + if ":" not in entry: + logger.error("Skipping malformed community_admin_keys entry (no ':'): %r", entry) + continue + community_id, key = entry.split(":", 1) + community_id = community_id.strip() + key = key.strip() + if community_id and key: + result.setdefault(community_id, set()).add(key) + return result + @lru_cache def get_settings() -> Settings: diff --git a/src/api/main.py b/src/api/main.py index 04c7dc4..c2e0178 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -14,11 +14,18 @@ from pydantic import BaseModel from src.api.config import get_settings -from src.api.routers import create_community_router, sync_router +from src.api.routers import ( + create_community_router, + metrics_public_router, + metrics_router, + sync_router, +) from src.api.routers.health import router as health_router from src.api.routers.widget_test import router as widget_test_router from src.api.scheduler import start_scheduler, stop_scheduler from src.assistants import discover_assistants, registry +from src.metrics.db import init_metrics_db +from src.metrics.middleware import MetricsMiddleware logger = logging.getLogger(__name__) @@ -53,6 +60,16 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.state.settings = settings app.state.start_time = datetime.now(UTC) + # Initialize metrics database (non-critical; degrade gracefully if unavailable) + try: + init_metrics_db() + except Exception: + logger.error( + "Failed to initialize metrics database. Metrics collection will be unavailable. " + "Check DATA_DIR permissions and disk space.", + exc_info=True, + ) + # Start background scheduler for knowledge sync scheduler = start_scheduler() app.state.scheduler = scheduler @@ -92,7 +109,7 @@ def _collect_cors_config() -> tuple[list[str], str | None]: Aggregates exact origins and wildcard patterns from: 1. Platform-level settings (Settings.cors_origins) 2. Per-community config (CommunityConfig.cors_origins) - 3. Default platform wildcard (*.osa-demo.pages.dev) + 3. Default platform wildcards (*.osa-demo.pages.dev, *.osa-dash.pages.dev) Returns: Tuple of (exact_origins, origin_regex_pattern). 
@@ -101,10 +118,12 @@ def _collect_cors_config() -> tuple[list[str], str | None]: settings = get_settings() exact_origins: list[str] = list(settings.cors_origins) - # Add default demo page origins + # Add default demo/dashboard page origins exact_origins.append("https://osa-demo.pages.dev") + exact_origins.append("https://osa-dash.pages.dev") wildcard_patterns: list[str] = [ "https://*.osa-demo.pages.dev", # Default: preview/branch deploys + "https://*.osa-dash.pages.dev", # Dashboard: preview/branch deploys ] # Collect from all registered communities @@ -153,6 +172,9 @@ def create_app() -> FastAPI: cors_kwargs["allow_origin_regex"] = origin_regex app.add_middleware(CORSMiddleware, **cors_kwargs) + # Metrics middleware - captures request timing and logs to metrics DB + app.add_middleware(MetricsMiddleware) + # Register routes register_routes(app) @@ -179,6 +201,10 @@ def register_routes(app: FastAPI) -> None: # Sync router (not community-specific) app.include_router(sync_router) + # Metrics routers (admin + public) + app.include_router(metrics_router) + app.include_router(metrics_public_router) + # Health check router app.include_router(health_router) @@ -221,11 +247,16 @@ async def root() -> dict[str, Any]: endpoints[f"GET /{community_id}/sessions"] = f"List active {name} sessions" endpoints[f"GET /{community_id}/sessions/{{session_id}}"] = "Get session info" endpoints[f"DELETE /{community_id}/sessions/{{session_id}}"] = "Delete a session" + endpoints[f"GET /{community_id}/metrics/public"] = f"Public {name} metrics" + endpoints[f"GET /{community_id}/metrics/public/usage"] = f"Public {name} usage stats" # Add non-community endpoints endpoints["GET /sync/status"] = "Knowledge sync status" endpoints["GET /sync/health"] = "Sync health check" endpoints["POST /sync/trigger"] = "Trigger sync (requires API key)" + endpoints["GET /metrics/overview"] = "Metrics overview (requires admin key)" + endpoints["GET /metrics/tokens"] = "Token breakdown (requires admin key)" + endpoints["GET /metrics/public/overview"] = "Public metrics overview" endpoints["GET /health"] = "Health check" return { diff --git a/src/api/routers/__init__.py b/src/api/routers/__init__.py index 76dffb5..39f85ab 100644 --- a/src/api/routers/__init__.py +++ b/src/api/routers/__init__.py @@ -1,6 +1,13 @@ """API routers for Open Science Assistant.""" from src.api.routers.community import create_community_router +from src.api.routers.metrics import router as metrics_router +from src.api.routers.metrics_public import router as metrics_public_router from src.api.routers.sync import router as sync_router -__all__ = ["create_community_router", "sync_router"] +__all__ = [ + "create_community_router", + "metrics_public_router", + "metrics_router", + "sync_router", +] diff --git a/src/api/routers/community.py b/src/api/routers/community.py index 479d655..eeeb8a3 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -7,23 +7,45 @@ import hashlib import json import logging +import os +import re +import sqlite3 +import time import uuid from collections.abc import AsyncGenerator +from dataclasses import dataclass from datetime import UTC, datetime -from typing import Annotated, Literal +from typing import Annotated, Any, Literal -from fastapi import APIRouter, Header, HTTPException, Request +from fastapi import APIRouter, Header, HTTPException, Query, Request from fastapi.responses import StreamingResponse from langchain_core.messages import AIMessage, HumanMessage from pydantic import BaseModel, Field, field_validator 
from src.api.config import get_settings -from src.api.security import RequireAuth +from src.api.security import AuthScope, RequireAuth, RequireScopedAuth from src.assistants import registry from src.assistants.community import CommunityAssistant from src.assistants.community import PageContext as AgentPageContext from src.assistants.registry import AssistantInfo from src.core.services.litellm_llm import create_openrouter_llm +from src.metrics.cost import estimate_cost +from src.metrics.db import ( + RequestLogEntry, + extract_token_usage, + extract_tool_names, + log_request, + metrics_connection, + now_iso, +) +from src.metrics.queries import ( + get_community_summary, + get_public_community_summary, + get_public_usage_stats, + get_quality_metrics, + get_quality_summary, + get_usage_stats, +) logger = logging.getLogger(__name__) @@ -312,16 +334,9 @@ def delete_session(community_id: str, session_id: str) -> bool: def list_sessions(community_id: str) -> list[ChatSession]: """List all active (non-expired) sessions for a community.""" + _evict_expired_sessions(community_id) store = _get_session_store(community_id) - # Filter out expired sessions - active_sessions = [s for s in store.values() if not s.is_expired()] - # Clean up expired sessions from store - expired = [sid for sid, s in store.items() if s.is_expired()] - for sid in expired: - del store[sid] - if expired: - logger.info("Cleaned %d expired sessions from %s", len(expired), community_id) - return active_sessions + return list(store.values()) # --------------------------------------------------------------------------- @@ -329,6 +344,16 @@ def list_sessions(community_id: str) -> list[ChatSession]: # --------------------------------------------------------------------------- +def _match_wildcard_origin(pattern: str, origin: str) -> bool: + """Check if an origin matches a wildcard pattern like 'https://*.example.com'. + + Converts '*' to a subdomain-safe regex and uses fullmatch. + """ + escaped = re.escape(pattern) + regex = escaped.replace(r"\*", r"[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?") + return bool(re.fullmatch(regex, origin)) + + def _is_authorized_origin(origin: str | None, community_id: str) -> bool: """Check if Origin header matches allowed CORS origins. 
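+
+    A '*' in a wildcard pattern matches exactly one subdomain label
+    (illustrative values):
+        "https://*.example.org" matches "https://docs.example.org"
+        but not "https://a.b.example.org" (labels cannot contain dots).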
@@ -350,8 +375,6 @@ def _is_authorized_origin(origin: str | None, community_id: str) -> bool: if not origin: return False - import re - # Platform default origins - always allowed for all communities platform_exact_origins = [ "https://osa-demo.pages.dev", @@ -366,9 +389,7 @@ def _is_authorized_origin(origin: str | None, community_id: str) -> bool: # Check platform wildcard patterns for allowed in platform_wildcard_origins: - escaped = re.escape(allowed) - pattern = escaped.replace(r"\*", r"[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?") - if re.fullmatch(pattern, origin): + if _match_wildcard_origin(allowed, origin): return True # Check community-specific origins @@ -380,19 +401,12 @@ def _is_authorized_origin(origin: str | None, community_id: str) -> bool: if not cors_origins: return False - # Check exact matches first for allowed in cors_origins: - if "*" not in allowed and origin == allowed: - return True - - # Check wildcard patterns - for allowed in cors_origins: - if "*" in allowed: - # Convert wildcard pattern to regex (same logic as main.py) - escaped = re.escape(allowed) - pattern = escaped.replace(r"\*", r"[a-zA-Z0-9]([a-zA-Z0-9\-]*[a-zA-Z0-9])?") - if re.fullmatch(pattern, origin): + if "*" not in allowed: + if origin == allowed: return True + elif _match_wildcard_origin(allowed, origin): + return True return False @@ -426,8 +440,6 @@ def _select_api_key( HTTPException(403): If origin is not authorized and BYOK is not provided HTTPException(500): If no platform API key is configured and no other key is available """ - import os - # Case 1: BYOK provided - always allowed if byok: logger.debug( @@ -604,6 +616,17 @@ def _get_cache_user_id(community_id: str, api_key: str | None, user_id: str | No return f"{community_id}_widget" +@dataclass +class AssistantWithMetrics: + """Community assistant bundled with metadata for metrics logging.""" + + assistant: CommunityAssistant + model: str + key_source: str + langfuse_config: dict | None = None + langfuse_trace_id: str | None = None + + def create_community_assistant( community_id: str, byok: str | None = None, @@ -612,13 +635,13 @@ def create_community_assistant( requested_model: str | None = None, preload_docs: bool = True, page_context: PageContext | None = None, -) -> CommunityAssistant: +) -> AssistantWithMetrics: """Create a community assistant instance with authorization checks. **Authorization:** - - If BYOK provided → always allowed - - If origin matches community CORS → can use community/platform keys - - Otherwise → rejects with 403 (CLI/unauthorized must provide BYOK) + - If BYOK provided -> always allowed + - If origin matches community CORS -> can use community/platform keys + - Otherwise -> rejects with 403 (CLI/unauthorized must provide BYOK) **Model Selection:** - Custom model requests require BYOK @@ -634,7 +657,8 @@ def create_community_assistant( page_context: Optional context about the page where the widget is embedded Returns: - Configured CommunityAssistant instance + AssistantWithMetrics containing the assistant, resolved model, and key source. + Access the assistant via .assistant attribute. 
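+
+        Example (sketch; the BYOK value is a placeholder):
+            awm = create_community_assistant("hed", byok="sk-or-...")
+            result = await awm.assistant.ainvoke(messages, config=awm.langfuse_config)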
Raises: ValueError: If community_id is not registered @@ -683,13 +707,106 @@ def create_community_assistant( title=page_context.title, ) - return registry.create_assistant( + assistant = registry.create_assistant( community_id, model=model, preload_docs=preload_docs, page_context=agent_page_context, ) + # Wire LangFuse tracing if configured + langfuse_config = None + langfuse_trace_id = None + try: + from src.core.services.llm import get_llm_service + except ImportError: + logger.debug("LangFuse tracing not available (module not installed)") + else: + try: + llm_service = get_llm_service(settings) + trace_id = f"{community_id}-{uuid.uuid4().hex[:12]}" + config = llm_service.get_config_with_tracing(trace_id=trace_id) + if config.get("callbacks"): + langfuse_config = config + langfuse_trace_id = trace_id + except Exception: + logger.warning( + "LangFuse tracing setup failed for %s, continuing without it", + community_id, + exc_info=True, + ) + + return AssistantWithMetrics( + assistant=assistant, + model=selected_model, + key_source=key_source, + langfuse_config=langfuse_config, + langfuse_trace_id=langfuse_trace_id, + ) + + +@dataclass +class AgentResult: + """Extracted response content and metrics from an agent invocation.""" + + response_content: str + tool_calls_info: list[ToolCallInfo] + tools_called: list[str] + input_tokens: int + output_tokens: int + total_tokens: int + + +def _extract_agent_result(result: dict) -> AgentResult: + """Extract response content, tool calls, and token usage from agent result. + + Consolidates the common post-invocation logic shared by ask and chat endpoints. + """ + response_content = "" + if result.get("messages"): + last_msg = result["messages"][-1] + if isinstance(last_msg, AIMessage): + content = last_msg.content + response_content = content if isinstance(content, str) else str(content) + + tools_called = extract_tool_names(result) + tool_calls_info = [ + ToolCallInfo(name=tc.get("name", ""), args=tc.get("args", {})) + for tc in result.get("tool_calls", []) + ] + + inp, out, total = extract_token_usage(result) + return AgentResult( + response_content=response_content, + tool_calls_info=tool_calls_info, + tools_called=tools_called, + input_tokens=inp, + output_tokens=out, + total_tokens=total, + ) + + +def _set_metrics_on_request( + http_request: Request, + awm: AssistantWithMetrics, + agent_result: AgentResult, +) -> None: + """Store agent metrics on request.state for the metrics middleware to log.""" + http_request.state.metrics_agent_data = { + "model": awm.model, + "key_source": awm.key_source, + "input_tokens": agent_result.input_tokens, + "output_tokens": agent_result.output_tokens, + "total_tokens": agent_result.total_tokens, + "estimated_cost": estimate_cost( + awm.model, agent_result.input_tokens, agent_result.output_tokens + ), + "tools_called": agent_result.tools_called, + "tool_call_count": len(agent_result.tools_called), + "langfuse_trace_id": awm.langfuse_trace_id, + "stream": False, + } + # --------------------------------------------------------------------------- # Router Factory @@ -765,6 +882,7 @@ async def ask( x_user_id, body.page_context, body.model, + http_request=http_request, ), media_type="text/event-stream", headers={ @@ -774,7 +892,7 @@ async def ask( ) try: - assistant = create_community_assistant( + awm = create_community_assistant( community_id, byok=x_openrouter_key, origin=origin, @@ -783,22 +901,12 @@ async def ask( page_context=body.page_context, ) messages = [HumanMessage(content=body.question)] - result = await 
assistant.ainvoke(messages) - - response_content = "" - if result.get("messages"): - last_msg = result["messages"][-1] - if isinstance(last_msg, AIMessage): - # Handle both string and list content (multimodal responses) - content = last_msg.content - response_content = content if isinstance(content, str) else str(content) + result = await awm.assistant.ainvoke(messages, config=awm.langfuse_config) - tool_calls_info = [ - ToolCallInfo(name=tc.get("name", ""), args=tc.get("args", {})) - for tc in result.get("tool_calls", []) - ] + ar = _extract_agent_result(result) + _set_metrics_on_request(http_request, awm, ar) - return AskResponse(answer=response_content, tool_calls=tool_calls_info) + return AskResponse(answer=ar.response_content, tool_calls=ar.tool_calls_info) except Exception as e: logger.error( @@ -858,7 +966,13 @@ async def chat( if body.stream: return StreamingResponse( _stream_chat_response( - community_id, session, x_openrouter_key, origin, user_id, body.model + community_id, + session, + x_openrouter_key, + origin, + user_id, + body.model, + http_request=http_request, ), media_type="text/event-stream", headers={ @@ -869,31 +983,21 @@ async def chat( ) try: - assistant = create_community_assistant( + awm = create_community_assistant( community_id, byok=x_openrouter_key, origin=origin, user_id=user_id, requested_model=body.model, ) - result = await assistant.ainvoke(session.messages) + result = await awm.assistant.ainvoke(session.messages, config=awm.langfuse_config) - response_content = "" - if result.get("messages"): - last_msg = result["messages"][-1] - if isinstance(last_msg, AIMessage): - # Handle both string and list content (multimodal responses) - content = last_msg.content - response_content = content if isinstance(content, str) else str(content) - - tool_calls_info = [ - ToolCallInfo(name=tc.get("name", ""), args=tc.get("args", {})) - for tc in result.get("tool_calls", []) - ] + ar = _extract_agent_result(result) + _set_metrics_on_request(http_request, awm, ar) # Add assistant message with constraint validation try: - session.add_assistant_message(response_content) + session.add_assistant_message(ar.response_content) except ValueError as e: logger.error("Session limit exceeded: %s", e) raise HTTPException( @@ -903,8 +1007,8 @@ async def chat( return ChatResponse( session_id=session.session_id, - message=ChatMessage(role="assistant", content=response_content), - tool_calls=tool_calls_info, + message=ChatMessage(role="assistant", content=ar.response_content), + tool_calls=ar.tool_calls_info, ) except ValueError as e: @@ -987,9 +1091,189 @@ async def get_community_config() -> CommunityConfigResponse: default_model_provider=default_provider, ) + # ----------------------------------------------------------------------- + # Per-community Metrics Endpoints + # ----------------------------------------------------------------------- + + def _require_community_access(auth: AuthScope) -> None: + """Raise 403 if the scoped key cannot access this community.""" + if not auth.can_access_community(community_id): + raise HTTPException( + status_code=403, + detail=f"Your API key does not have access to {community_id} metrics", + ) + + @router.get("/metrics") + async def community_metrics(auth: RequireScopedAuth) -> dict[str, Any]: + """Get metrics summary for this community. 
Requires admin or community key.""" + _require_community_access(auth) + try: + with metrics_connection() as conn: + return get_community_summary(community_id, conn) + except sqlite3.Error: + logger.exception("Failed to query metrics for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + @router.get("/metrics/usage") + async def community_usage( + auth: RequireScopedAuth, + period: str = Query( + default="daily", + description="Time bucket period", + pattern="^(daily|weekly|monthly)$", + ), + ) -> dict[str, Any]: + """Get time-bucketed usage stats for this community. Requires admin or community key.""" + _require_community_access(auth) + try: + with metrics_connection() as conn: + return get_usage_stats(community_id, period, conn) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except sqlite3.Error: + logger.exception("Failed to query usage stats for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + @router.get("/metrics/quality") + async def community_quality( + auth: RequireScopedAuth, + period: str = Query( + default="daily", + description="Time bucket period", + pattern="^(daily|weekly|monthly)$", + ), + ) -> dict[str, Any]: + """Get quality metrics for this community. Requires admin or community key.""" + _require_community_access(auth) + try: + with metrics_connection() as conn: + return get_quality_metrics(community_id, conn, period) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except sqlite3.Error: + logger.exception("Failed to query quality metrics for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + @router.get("/metrics/quality/summary") + async def community_quality_summary(auth: RequireScopedAuth) -> dict[str, Any]: + """Get overall quality summary for this community. Requires admin or community key.""" + _require_community_access(auth) + try: + with metrics_connection() as conn: + return get_quality_summary(community_id, conn) + except sqlite3.Error: + logger.exception("Failed to query quality summary for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + # ----------------------------------------------------------------------- + # Per-community Public Metrics Endpoints (no auth required) + # ----------------------------------------------------------------------- + + @router.get("/metrics/public") + async def community_metrics_public() -> dict[str, Any]: + """Get public metrics summary for this community. + + Returns request counts, error rate, and top tools. + No tokens, costs, or model information exposed. + """ + try: + with metrics_connection() as conn: + return get_public_community_summary(community_id, conn) + except sqlite3.Error: + logger.exception("Failed to query public metrics for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + @router.get("/metrics/public/usage") + async def community_usage_public( + period: str = Query( + default="daily", + description="Time bucket period", + pattern="^(daily|weekly|monthly)$", + ), + ) -> dict[str, Any]: + """Get public time-bucketed usage stats for this community. + + Returns request counts and errors per time bucket. 
+ No tokens or costs exposed. + """ + try: + with metrics_connection() as conn: + return get_public_usage_stats(community_id, period, conn) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + except sqlite3.Error: + logger.exception("Failed to query public usage stats for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + return router +# --------------------------------------------------------------------------- +# Metrics Helpers +# --------------------------------------------------------------------------- + + +def _log_streaming_metrics( + http_request: Request | None, + community_id: str, + endpoint: str, + awm: AssistantWithMetrics | None, + tools_called: list[str], + start_time: float, + status_code: int, +) -> None: + """Log metrics at the end of a streaming response. + + Called directly from streaming generators since middleware fires + before streaming completes. Wrapped in try/except to never disrupt + the SSE stream on failure. + """ + try: + duration_ms = (time.monotonic() - start_time) * 1000 + request_id = str(uuid.uuid4()) + if http_request: + request_id = getattr(http_request.state, "request_id", request_id) + # Mark as logged so middleware doesn't double-log + http_request.state.metrics_logged = True + + entry = RequestLogEntry( + request_id=request_id, + timestamp=now_iso(), + endpoint=endpoint, + method="POST", + community_id=community_id, + duration_ms=round(duration_ms, 1), + status_code=status_code, + model=awm.model if awm else None, + key_source=awm.key_source if awm else None, + tools_called=tools_called, + stream=True, + tool_call_count=len(tools_called), + langfuse_trace_id=awm.langfuse_trace_id if awm else None, + ) + log_request(entry) + except Exception: + logger.exception("Failed to log streaming metrics for %s", endpoint) + + # --------------------------------------------------------------------------- # Streaming Helpers # --------------------------------------------------------------------------- @@ -1003,6 +1287,7 @@ async def _stream_ask_response( user_id: str | None, page_context: PageContext | None = None, requested_model: str | None = None, + http_request: Request | None = None, ) -> AsyncGenerator[str, None]: """Stream response for ask endpoint with JSON-encoded SSE events. 
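+
+    Can be exercised from a shell with e.g. (sketch; BYOK header omitted,
+    so the Origin must be authorized):
+        curl -N -X POST http://localhost:38528/{community_id}/ask \
+             -H "Content-Type: application/json" \
+             -d '{"question": "...", "stream": true}'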
@@ -1013,8 +1298,12 @@ async def _stream_ask_response( data: {"event": "done"} data: {"event": "error", "message": "error text"} """ + start_time = time.monotonic() + tools_called: list[str] = [] + awm: AssistantWithMetrics | None = None + try: - assistant = create_community_assistant( + awm = create_community_assistant( community_id, byok=byok, origin=origin, @@ -1023,7 +1312,7 @@ async def _stream_ask_response( preload_docs=True, page_context=page_context, ) - graph = assistant.build_graph() + graph = awm.assistant.build_graph() state = { "messages": [HumanMessage(content=question)], @@ -1031,7 +1320,8 @@ async def _stream_ask_response( "tool_calls": [], } - async for event in graph.astream_events(state, version="v2"): + stream_config = awm.langfuse_config or {} + async for event in graph.astream_events(state, version="v2", config=stream_config): kind = event.get("event") if kind == "on_chat_model_stream": @@ -1042,9 +1332,12 @@ async def _stream_ask_response( elif kind == "on_tool_start": tool_input = event.get("data", {}).get("input", {}) + tool_name = event.get("name", "") + if tool_name: + tools_called.append(tool_name) sse_event = { "event": "tool_start", - "name": event.get("name", ""), + "name": tool_name, "input": tool_input if isinstance(tool_input, dict) else {}, } yield f"data: {json.dumps(sse_event)}\n\n" @@ -1061,9 +1354,37 @@ async def _stream_ask_response( sse_event = {"event": "done"} yield f"data: {json.dumps(sse_event)}\n\n" - except HTTPException: - # Don't catch our own HTTP exceptions - let them propagate - raise + # Log metrics at end of streaming + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/ask", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=200, + ) + + except HTTPException as e: + # HTTPException in streaming context (e.g., auth failure, rate limit). + # Cannot re-raise because response headers are already sent as 200. 
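+            # Clients must therefore watch for {"event": "error"} frames in the
+            # stream rather than relying on the HTTP status code.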
+ logger.warning( + "HTTP error in ask streaming for community %s: %d %s", + community_id, + e.status_code, + e.detail, + ) + sse_event = {"event": "error", "message": str(e.detail)} + yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/ask", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=e.status_code, + ) except ValueError as e: # Input validation errors - user's fault logger.warning("Invalid input in streaming for community %s: %s", community_id, e) @@ -1073,10 +1394,17 @@ async def _stream_ask_response( "retryable": False, } yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/ask", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=400, + ) except Exception as e: # Unexpected errors - log with full context - import uuid - error_id = str(uuid.uuid4()) logger.error( "Unexpected streaming error (ID: %s) in ask endpoint for community %s: %s", @@ -1096,6 +1424,15 @@ async def _stream_ask_response( "error_id": error_id, } yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/ask", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=500, + ) async def _stream_chat_response( @@ -1105,6 +1442,7 @@ async def _stream_chat_response( origin: str | None, user_id: str | None, requested_model: str | None = None, + http_request: Request | None = None, ) -> AsyncGenerator[str, None]: """Stream assistant response as JSON-encoded Server-Sent Events. @@ -1115,8 +1453,12 @@ async def _stream_chat_response( data: {"event": "done", "session_id": "..."} data: {"event": "error", "message": "error text"} """ + start_time = time.monotonic() + tools_called: list[str] = [] + awm: AssistantWithMetrics | None = None + try: - assistant = create_community_assistant( + awm = create_community_assistant( community_id, byok=byok, origin=origin, @@ -1124,7 +1466,7 @@ async def _stream_chat_response( requested_model=requested_model, preload_docs=True, ) - graph = assistant.build_graph() + graph = awm.assistant.build_graph() state = { "messages": session.messages.copy(), @@ -1132,9 +1474,10 @@ async def _stream_chat_response( "tool_calls": [], } + stream_config = awm.langfuse_config or {} full_response = "" - async for event in graph.astream_events(state, version="v2"): + async for event in graph.astream_events(state, version="v2", config=stream_config): kind = event.get("event") if kind == "on_chat_model_stream": @@ -1147,9 +1490,12 @@ async def _stream_chat_response( elif kind == "on_tool_start": tool_input = event.get("data", {}).get("input", {}) + tool_name = event.get("name", "") + if tool_name: + tools_called.append(tool_name) sse_event = { "event": "tool_start", - "name": event.get("name", ""), + "name": tool_name, "input": tool_input if isinstance(tool_input, dict) else {}, } yield f"data: {json.dumps(sse_event)}\n\n" @@ -1176,21 +1522,79 @@ async def _stream_chat_response( sse_event = {"event": "done", "session_id": session.session_id} yield f"data: {json.dumps(sse_event)}\n\n" + # Log metrics at end of streaming + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/chat", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=200, + ) + + except 
HTTPException as e: + # HTTPException in streaming context (e.g., auth failure, rate limit). + # Cannot re-raise because response headers are already sent as 200. + logger.warning( + "HTTP error in chat streaming for session %s (community: %s): %d %s", + session.session_id, + community_id, + e.status_code, + e.detail, + ) + sse_event = {"event": "error", "message": str(e.detail)} + yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/chat", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=e.status_code, + ) except ValueError as e: # Session limit errors logger.error("Session limit error: %s", e) sse_event = {"event": "error", "message": str(e)} yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/chat", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=400, + ) except Exception as e: + error_id = str(uuid.uuid4()) logger.error( - "Streaming error in chat endpoint for session %s (community: %s): %s", + "Unexpected streaming error (ID: %s) in chat endpoint for session %s (community: %s): %s", + error_id, session.session_id, community_id, e, exc_info=True, + extra={ + "error_id": error_id, + "community_id": community_id, + "error_type": type(e).__name__, + }, ) sse_event = { "event": "error", "message": "An error occurred while processing your request.", + "error_id": error_id, } yield f"data: {json.dumps(sse_event)}\n\n" + _log_streaming_metrics( + http_request=http_request, + community_id=community_id, + endpoint=f"/{community_id}/chat", + awm=awm, + tools_called=tools_called, + start_time=start_time, + status_code=500, + ) diff --git a/src/api/routers/metrics.py b/src/api/routers/metrics.py new file mode 100644 index 0000000..467c28c --- /dev/null +++ b/src/api/routers/metrics.py @@ -0,0 +1,98 @@ +"""Global metrics API endpoints. + +Provides cross-community metrics overview and token breakdowns. +Supports both global admin keys (see all) and per-community keys (filtered view). +""" + +import logging +import sqlite3 +from typing import Any + +from fastapi import APIRouter, HTTPException, Query + +from src.api.security import RequireScopedAuth +from src.metrics.db import metrics_connection +from src.metrics.queries import ( + get_community_summary, + get_overview, + get_quality_summary, + get_token_breakdown, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/metrics", tags=["Metrics"]) + + +@router.get("/overview") +async def metrics_overview(auth: RequireScopedAuth) -> dict[str, Any]: + """Get cross-community metrics overview. + + Global admin keys see all communities. Per-community keys see only + their community's data wrapped in the same response format. 
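+
+    Illustrative dispatch (matches the logic below; "hed" is an example id):
+        admin key           -> get_overview(conn)                # all communities
+        key scoped to "hed" -> get_community_summary("hed", conn)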
+ """ + try: + with metrics_connection() as conn: + if auth.role == "admin": + return get_overview(conn) + # Community-scoped: return summary for just their community + return get_community_summary(auth.community_id, conn) + except sqlite3.Error: + logger.exception("Failed to query metrics database for overview") + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + +@router.get("/tokens") +async def token_breakdown( + auth: RequireScopedAuth, + community_id: str | None = Query(default=None, description="Filter by community"), +) -> dict[str, Any]: + """Get token usage breakdown by model and key source. + + Global admin keys can filter by any community. Per-community keys + are automatically scoped to their community (community_id parameter ignored). + """ + # Community-scoped keys always filter to their own community + effective_community = community_id + if auth.role == "community": + effective_community = auth.community_id + + try: + with metrics_connection() as conn: + return get_token_breakdown(conn, community_id=effective_community) + except sqlite3.Error: + logger.exception("Failed to query metrics database for token breakdown") + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) + + +@router.get("/quality") +async def quality_overview(auth: RequireScopedAuth) -> dict[str, Any]: + """Get quality metrics overview. + + Global admin keys see quality for all communities. + Per-community keys see quality summary for their community only. + """ + try: + with metrics_connection() as conn: + if auth.role == "community": + return get_quality_summary(auth.community_id, conn) + # Admin: aggregate quality across all communities + overview = get_overview(conn) + communities_data = overview.get("communities", []) + summaries = [] + for c in communities_data: + cid = c["community_id"] + summaries.append(get_quality_summary(cid, conn)) + return {"communities": summaries} + except sqlite3.Error: + logger.exception("Failed to query quality metrics") + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) diff --git a/src/api/routers/metrics_public.py b/src/api/routers/metrics_public.py new file mode 100644 index 0000000..4e38f34 --- /dev/null +++ b/src/api/routers/metrics_public.py @@ -0,0 +1,42 @@ +"""Public metrics API endpoints. + +Exposes non-sensitive aggregate metrics (request counts, error rates) +without authentication. No tokens, costs, or model information. + +Per-community public metrics are served from the community router +at /{community_id}/metrics/public. +""" + +import logging +import sqlite3 +from typing import Any + +from fastapi import APIRouter, HTTPException + +from src.assistants import registry +from src.metrics.db import metrics_connection +from src.metrics.queries import get_public_overview + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/metrics/public", tags=["Public Metrics"]) + + +@router.get("/overview") +async def public_overview() -> dict[str, Any]: + """Get public metrics overview across all communities. + + Returns total requests, error rate, active community count, + and per-community request counts. No tokens, costs, or model info. + All registered communities appear even if they have zero requests. 
+ """ + registered = [info.id for info in registry.list_available()] + try: + with metrics_connection() as conn: + return get_public_overview(conn, registered_communities=registered) + except sqlite3.Error: + logger.exception("Failed to query metrics database for public overview") + raise HTTPException( + status_code=503, + detail="Metrics database is temporarily unavailable.", + ) diff --git a/src/api/scheduler.py b/src/api/scheduler.py index 9629b95..5c79a1a 100644 --- a/src/api/scheduler.py +++ b/src/api/scheduler.py @@ -1,8 +1,9 @@ -"""Background scheduler for automated knowledge sync. +"""Background scheduler for automated tasks. -Uses APScheduler to run periodic sync jobs for: -- GitHub issues/PRs (daily by default) -- Academic papers (weekly by default) +Uses APScheduler to run periodic jobs for: +- GitHub issues/PRs sync (daily by default) +- Academic papers sync (weekly by default) +- Community budget checks with alert issue creation (every 15 minutes) """ import logging @@ -16,6 +17,9 @@ from src.knowledge.db import init_db from src.knowledge.github_sync import sync_repos from src.knowledge.papers_sync import sync_all_papers, sync_citing_papers +from src.metrics.alerts import create_budget_alert_issue +from src.metrics.budget import check_budget +from src.metrics.db import metrics_connection logger = logging.getLogger(__name__) @@ -29,11 +33,7 @@ def _get_communities_with_sync() -> list[str]: Returns: List of community IDs with GitHub repos or citation config. """ - communities = [] - for info in registry.list_all(): - if info.sync_config: - communities.append(info.id) - return communities + return [info.id for info in registry.list_all() if info.sync_config] def _get_community_repos(community_id: str) -> list[str]: @@ -63,6 +63,7 @@ def _get_community_paper_dois(community_id: str) -> list[str]: # Failure tracking for alerting _github_sync_failures = 0 _papers_sync_failures = 0 +_budget_check_failures = 0 MAX_CONSECUTIVE_FAILURES = 3 @@ -166,6 +167,80 @@ def _run_papers_sync() -> None: ) +def _check_community_budgets() -> None: + """Check budget limits for all communities and create alert issues if exceeded.""" + global _budget_check_failures + logger.info("Starting scheduled budget check for all communities") + try: + communities_checked = 0 + communities_failed = 0 + alerts_created = 0 + + with metrics_connection() as conn: + for info in registry.list_all(): + if not info.community_config or not info.community_config.budget: + continue + + budget_cfg = info.community_config.budget + maintainers = info.community_config.maintainers + + try: + budget_status = check_budget( + community_id=info.id, + config=budget_cfg, + conn=conn, + ) + communities_checked += 1 + + if budget_status.needs_alert: + issue_url = create_budget_alert_issue( + budget_status=budget_status, + maintainers=maintainers, + ) + if issue_url: + alerts_created += 1 + logger.warning( + "Budget alert created for %s: %s", + info.id, + issue_url, + ) + except Exception: + communities_failed += 1 + logger.exception("Failed to check budget for community %s", info.id) + + log_level = logging.WARNING if communities_failed else logging.INFO + logger.log( + log_level, + "Budget check complete: %d checked, %d failed, %d alerts created", + communities_checked, + communities_failed, + alerts_created, + ) + if communities_failed: + _budget_check_failures += 1 + if _budget_check_failures >= MAX_CONSECUTIVE_FAILURES: + logger.critical( + "Budget check has had community failures for %d consecutive runs. 
" + "Manual intervention required.", + _budget_check_failures, + ) + else: + _budget_check_failures = 0 + except Exception: + _budget_check_failures += 1 + logger.error( + "Budget check job failed (attempt %d/%d)", + _budget_check_failures, + MAX_CONSECUTIVE_FAILURES, + exc_info=True, + ) + if _budget_check_failures >= MAX_CONSECUTIVE_FAILURES: + logger.critical( + "Budget check has failed %d times consecutively. Manual intervention required.", + _budget_check_failures, + ) + + def start_scheduler() -> BackgroundScheduler | None: """Start the background scheduler with configured sync jobs. @@ -224,6 +299,20 @@ def start_scheduler() -> BackgroundScheduler | None: except ValueError as e: logger.error("Invalid papers sync cron expression: %s", e) + # Add budget check job (every 15 minutes) + try: + budget_trigger = CronTrigger(minute="*/15") + _scheduler.add_job( + _check_community_budgets, + trigger=budget_trigger, + id="budget_check", + name="Community Budget Check", + replace_existing=True, + ) + logger.info("Budget check scheduled: every 15 minutes") + except ValueError as e: + logger.error("Failed to schedule budget check: %s", e) + # Start the scheduler _scheduler.start() logger.info("Background scheduler started") diff --git a/src/api/security.py b/src/api/security.py index bed7366..d69099c 100644 --- a/src/api/security.py +++ b/src/api/security.py @@ -1,6 +1,7 @@ """Security and authentication for the OSA API.""" -from typing import Annotated +from dataclasses import dataclass +from typing import Annotated, Literal from fastapi import Depends, HTTPException, Security, status from fastapi.security import APIKeyHeader @@ -23,7 +24,11 @@ async def verify_api_key( anthropic_key: Annotated[str | None, Security(anthropic_key_header)] = None, openrouter_key: Annotated[str | None, Security(openrouter_key_header)] = None, ) -> str | None: - """Verify the API key if server authentication is enabled. + """Verify the global admin API key if server authentication is enabled. + + Only checks keys from the API_KEYS setting (global admin keys). + Does not handle per-community scoped keys; see verify_scoped_admin_key + for community-level auth. Returns the API key if valid, None if auth is disabled or BYOK is used. Raises HTTPException if auth is enabled but key is invalid (and no BYOK). @@ -45,8 +50,7 @@ async def verify_api_key( if openai_key or anthropic_key or openrouter_key: return None - # Parse comma-separated API keys - valid_keys = {k.strip() for k in settings.api_keys.split(",") if k.strip()} + valid_keys = settings.parse_admin_keys() # If auth is enabled, require valid API key if not api_key: @@ -86,8 +90,7 @@ async def verify_admin_api_key( if not settings.api_keys: return None - # Parse comma-separated API keys - valid_keys = {k.strip() for k in settings.api_keys.split(",") if k.strip()} + valid_keys = settings.parse_admin_keys() # Admin endpoints require valid server API key (no BYOK bypass) if not api_key: @@ -110,6 +113,83 @@ async def verify_admin_api_key( RequireAdminAuth = Annotated[str | None, Depends(verify_admin_api_key)] +@dataclass(frozen=True) +class AuthScope: + """Scoped authentication result for community-level access control. + + Attributes: + role: "admin" for global admin keys, "community" for per-community keys. + community_id: The community this key is scoped to, or None for global admin. + + Use ``can_access_community(id)`` to check whether this scope permits + access to a given community's data. 
+ """ + + role: Literal["admin", "community"] + community_id: str | None = None + + def __post_init__(self) -> None: + if self.role == "community" and not self.community_id: + raise ValueError("community role requires a community_id") + if self.role == "admin" and self.community_id is not None: + raise ValueError("admin role must not have a community_id") + + def can_access_community(self, community_id: str) -> bool: + """Check if this auth scope permits access to a given community.""" + if self.role == "admin": + return True + return self.community_id == community_id + + +async def verify_scoped_admin_key( + api_key: Annotated[str | None, Security(api_key_header)], + settings: Annotated[Settings, Depends(get_settings)], +) -> AuthScope: + """Verify API key and return scoped auth context. + + Checks in order: + 1. Global admin keys (api_keys setting) -> AuthScope(role="admin") + 2. Per-community keys (community_admin_keys) -> AuthScope(role="community", community_id=X) + 3. No match -> 401/403 + + When auth is disabled (require_api_auth=False or no keys configured), + returns admin scope for backward compatibility. + """ + # If auth is not required, grant full admin access + if not settings.require_api_auth: + return AuthScope(role="admin") + + # If no API keys configured at all, auth is effectively disabled + if not settings.api_keys and not settings.community_admin_keys: + return AuthScope(role="admin") + + if not api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key required for metrics access", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + # Check global admin keys first + if settings.api_keys and api_key in settings.parse_admin_keys(): + return AuthScope(role="admin") + + # Check per-community keys + community_keys = settings.parse_community_admin_keys() + for community_id, keys in community_keys.items(): + if api_key in keys: + return AuthScope(role="community", community_id=community_id) + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid API key", + ) + + +# Dependency for scoped admin routes (supports per-community keys) +RequireScopedAuth = Annotated[AuthScope, Depends(verify_scoped_admin_key)] + + class BYOKHeaders: """BYOK (Bring Your Own Key) headers for LLM providers. 
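To make the new dependency concrete, here is a minimal sketch of a route consuming `RequireScopedAuth`. The route path and response body are hypothetical; only `RequireScopedAuth`, `AuthScope.role`, and `AuthScope.can_access_community` come from the diff above.

```python
from fastapi import APIRouter, HTTPException

from src.api.security import RequireScopedAuth

router = APIRouter()


@router.get("/{community_id}/metrics/summary")  # hypothetical path
async def community_summary(community_id: str, auth: RequireScopedAuth) -> dict:
    # Global admin keys pass for any community; per-community keys
    # only for their own community.
    if not auth.can_access_community(community_id):
        raise HTTPException(status_code=403, detail="Key not scoped to this community")
    return {"community_id": community_id, "role": auth.role}
```

Note that `verify_scoped_admin_key` returns admin scope when auth is disabled, so routes like this stay usable in local development with no keys configured.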
diff --git a/src/assistants/bids/__init__.py b/src/assistants/bids/__init__.py new file mode 100644 index 0000000..85222e0 --- /dev/null +++ b/src/assistants/bids/__init__.py @@ -0,0 +1 @@ +# BIDS (Brain Imaging Data Structure) community assistant diff --git a/src/assistants/bids/config.yaml b/src/assistants/bids/config.yaml new file mode 100644 index 0000000..b1f5790 --- /dev/null +++ b/src/assistants/bids/config.yaml @@ -0,0 +1,542 @@ +# BIDS (Brain Imaging Data Structure) Assistant Configuration +# Single source of truth for the BIDS community assistant + +id: bids +name: BIDS (Brain Imaging Data Structure) +description: Standard for organizing and describing neuroimaging and behavioral data +status: available + +# Enable page context tool for widget embedding +enable_page_context: true + +# Allowed CORS origins for widget embedding on community websites +cors_origins: + - https://bids.neuroimaging.io + - https://bids-specification.readthedocs.io + +# Per-community OpenRouter API key (optional) +openrouter_api_key_env_var: "OPENROUTER_API_KEY_BIDS" + +# Community maintainers (GitHub usernames) +maintainers: + - effigies + +# Budget limits for cost management +budget: + daily_limit_usd: 5.00 + monthly_limit_usd: 50.00 + alert_threshold_pct: 80.0 + +# Default model for this community +default_model: "anthropic/claude-haiku-4.5" +default_model_provider: "anthropic" + +# System prompt template +system_prompt: | + You are a technical assistant specialized in the Brain Imaging Data Structure (BIDS), a standard for + organizing and describing outputs of neuroimaging experiments. You provide explanations, troubleshooting, + and step-by-step guidance for organizing data according to the BIDS specification, using the BIDS validator, + and working with BIDS-compatible tools. + + You must stick strictly to the topic of BIDS and avoid digressions. + All responses should be accurate and based on the official BIDS specification and documentation. + + When a user's question is ambiguous, assume the most likely meaning and provide a useful starting point, + but also ask clarifying questions when necessary. + Communicate in a formal and technical style, prioritizing precision and accuracy while remaining clear. + Answers should be structured and easy to follow, with examples where appropriate. + + The BIDS homepage is https://bids.neuroimaging.io/ + The BIDS specification is at https://bids-specification.readthedocs.io/ + The BIDS GitHub organization is at https://github.com/bids-standard + The BIDS validator is at https://github.com/bids-standard/bids-validator + BIDS example datasets are at https://github.com/bids-standard/bids-examples + + You will respond with markdown formatted text. Be concise and include only the most relevant information unless told otherwise. + + ## Modality Awareness + + BIDS covers many data modalities, each with specific rules and file formats. When users ask about a + specific modality, retrieve the corresponding documentation. 
Key modalities include: + + - **MRI** (anatomical, functional, diffusion, fieldmaps, ASL, quantitative MRI) + - **EEG** (electroencephalography) + - **MEG** (magnetoencephalography) + - **iEEG** (intracranial electroencephalography) + - **PET** (positron emission tomography) + - **NIRS** (near-infrared spectroscopy) + - **Motion** (motion capture data) + - **Microscopy** (microscopy imaging data) + - **MRS** (magnetic resonance spectroscopy) + - **EMG** (electromyography) + - **Behavioral** (behavioral experiments with no neural recordings) + - **Genetics** (genetic descriptors associated with brain imaging data) + - **Physiological** (continuous physiological recordings: cardiac, respiratory) + + Each modality may have a dedicated extension paper. When discussing a specific modality, cite + the relevant paper if available. + + ## Schema Awareness + + The BIDS specification is driven by a machine-readable schema (YAML rules in the specification repository). + The BIDS validator uses this schema to check datasets. Key concepts: + + - **Entities**: Key-value pairs in filenames (sub-, ses-, task-, run-, etc.) + - **Suffixes**: File type identifiers (bold, T1w, eeg, channels, events, etc.) + - **Metadata**: JSON sidecar fields required or recommended for each file type + - **Rules**: Validation rules defining allowed combinations of entities, suffixes, and metadata + + When users ask about filename conventions or required metadata, consult the specification docs. + + ## Using Tools Liberally + + You have access to tools for documentation retrieval and knowledge discovery. **Use them proactively.** + + - Tool calls are inexpensive; do not hesitate to retrieve documentation + - Retrieve relevant documentation to ensure your answers are accurate and current + - When users ask about recent activity, issues, or PRs, use the knowledge discovery tools + - When users ask about research papers, use the paper search tool + + ## Using the retrieve_bids_docs Tool + + Before responding, use the retrieve_bids_docs tool to get any documentation you need. + Include links to relevant documents in your response. + + **Important guidelines:** + - Do NOT retrieve docs that have already been preloaded (listed below) + - Retrieve multiple relevant documents at once + - Get background information documents in addition to specific documents for the question + - If you have already loaded a document in this conversation, do not load it again + + {preloaded_docs_section} + + {available_docs_section} + + ## BIDS Validator Guidance + + Users frequently ask about validation errors. When helping with validation: + - Explain what the specific error or warning means + - Show the correct file naming or metadata format + - Point to the relevant section of the specification + - Mention that the BIDS validator can be run online at https://bids-standard.github.io/bids-validator/ + or via CLI with `bids-validator` (Deno-based) or `pip install bids-validator` (Python wrapper) + + ## Converter Tool Awareness + + Users often need help converting data to BIDS format. Common converters include: + - **dcm2bids** / **HeuDiConv** - DICOM to BIDS (MRI) + - **MNE-BIDS** - MEG/EEG/iEEG to BIDS (Python) + - **EEG-BIDS** - EEG to BIDS (MATLAB/EEGLAB plugin) + - **bidscoin** - General-purpose BIDS converter + - **pet2bids** - PET data to BIDS + + When users ask about data conversion, recommend the appropriate converter for their data type. + + ## BIDS Extension Proposals (BEPs) + + BEPs are the mechanism for adding new data types to BIDS. 
When users ask about extending BIDS + or about data types not yet in the specification, explain the BEP process and point them to + https://bids.neuroimaging.io/extensions/beps for the current list. + + ## Citation Guidance + + When users ask about citing BIDS, always recommend citing the canonical paper: + - Gorgolewski et al. (2016), "The brain imaging data structure", doi:10.1038/sdata.2016.44 + + Additionally recommend citing the modality-specific extension paper when applicable. + The full citation list is in the specification's introduction section. + + ## Knowledge Discovery Tools - YOU MUST USE THESE + + You have access to a synced knowledge database with GitHub issues, PRs, and academic papers. + **You MUST use these tools when users ask about recent activity, issues, or PRs.** + + **Available BIDS repositories in the database:** + {repo_list} + + **CRITICAL: When users mention these repos (even by short name), USE THE TOOLS:** + - "specification" or "spec" -> repo="bids-standard/bids-specification" + - "validator" -> repo="bids-standard/bids-validator" + - "website" -> repo="bids-standard/bids-website" + - "examples" -> repo="bids-standard/bids-examples" + + **MANDATORY: Use tools for these question patterns:** + - "What are the latest PRs?" -> CALL `list_bids_recent(item_type="pr")` + - "Latest PRs in the validator?" -> CALL `list_bids_recent(item_type="pr", repo="bids-standard/bids-validator")` + - "Open issues in the spec?" -> CALL `list_bids_recent(item_type="issue", status="open", repo="bids-standard/bids-specification")` + - "Recent activity?" -> CALL `list_bids_recent(limit=10)` + - "Any discussions about derivatives?" -> CALL `search_bids_discussions(query="derivatives")` + - "What is PR 2022 about?" -> CALL `search_bids_discussions(query="2022", include_issues=false, include_prs=true)` + - "What is issue #500?" -> CALL `search_bids_discussions(query="500", include_issues=true, include_prs=false)` + + **Core BIDS papers tracked for citations (DOIs in database):** + {paper_dois} + + **MANDATORY: Use tools for citation/paper questions:** + - "Has anyone cited the BIDS paper?" -> CALL `search_bids_papers(query="BIDS")` + - "Papers about BIDS EEG?" -> CALL `search_bids_papers(query="BIDS EEG")` + + ## CRITICAL: Do Not Hallucinate GitHub References + + **NEVER fabricate issue numbers, PR numbers, discussion links, or commit hashes.** + If you search the knowledge database and find no results, say so explicitly: + "I could not find any matching issues/PRs in the database." + + Do NOT invent plausible-sounding issue numbers or PR titles. Only reference items + that were returned by the knowledge discovery tools with real URLs. + + **DO NOT:** + - Tell users to "visit GitHub" or "check Google Scholar" when you have the data + - Make up PR numbers, issue numbers, paper titles, authors, or citation counts + - Say "I don't have access" -- you DO have access via the tools above + - Create fictional GitHub URLs like "github.com/bids-standard/bids-specification/issues/1234" + unless that exact URL was returned by a tool + + **Present results as discovery:** + - "Here are the recent PRs in the specification: [actual list with real URLs]" + - "There's a related discussion: [real link]" + - "Here are papers related to BIDS: [actual list from database]" + + The knowledge database may not be populated. If you get a message about initializing the database, + explain that the knowledge base is not set up yet. 
+ + {page_context_section} + + {additional_instructions} + +# Documentation sources +# - preload: true = embedded in system prompt (for core docs, ~2-3 max) +# - preload: false/omitted = fetched on demand via retrieve_docs tool +documentation: + # === PRELOADED: Core concepts (2 docs) === + - title: BIDS common principles + url: https://bids-specification.readthedocs.io/en/stable/common-principles.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/common-principles.md + preload: true + category: core + description: Core principles of BIDS including file naming, directory structure, and metadata conventions. + + - title: Getting started with BIDS + url: https://bids.neuroimaging.io/getting_started/ + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/index.md + preload: true + category: core + description: Overview of getting started with BIDS, including links to tutorials and resources. + + # === ON-DEMAND: Specification core (4 docs) === + - title: Introduction and citing BIDS + url: https://bids-specification.readthedocs.io/en/stable/introduction.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/introduction.md + category: specification + description: Introduction to BIDS, motivation, extensions, and how to cite the specification and modality papers. + + - title: BIDS glossary + url: https://bids-specification.readthedocs.io/en/stable/glossary.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/glossary.md + category: specification + description: Definitions of BIDS terminology including entities, suffixes, datatypes, and modalities. + + - title: BIDS extensions + url: https://bids-specification.readthedocs.io/en/stable/extensions.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/extensions.md + category: specification + description: How BIDS is extended through BIDS Extension Proposals (BEPs). + + - title: Longitudinal and multi-site studies + url: https://bids-specification.readthedocs.io/en/stable/longitudinal-and-multi-site-studies.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/longitudinal-and-multi-site-studies.md + category: specification + description: Guidelines for organizing longitudinal studies and multi-site data in BIDS. + + # === ON-DEMAND: Modality-agnostic files (5 docs) === + - title: Dataset description + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/dataset-description.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/dataset-description.md + category: modality_agnostic + description: Required dataset_description.json file and its fields (Name, BIDSVersion, License, etc.). + + - title: Events file + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/events.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/events.md + category: modality_agnostic + description: Events TSV files for recording experimental events (onset, duration, trial_type). 
+ + - title: Phenotypic and assessment data + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/phenotypic-and-assessment-data.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/phenotypic-and-assessment-data.md + category: modality_agnostic + description: Phenotypic and assessment data in the phenotype/ directory (questionnaires, behavioral measures). + + - title: Code description + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/code.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/code.md + category: modality_agnostic + description: Guidelines for including data preparation scripts (deidentification, format conversion) in BIDS datasets. + + - title: Data summary files + url: https://bids-specification.readthedocs.io/en/stable/modality-agnostic-files/data-summary-files.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-agnostic-files/data-summary-files.md + category: modality_agnostic + description: Dataset-level summary files (participants.tsv, samples.tsv, scans.tsv, sessions.tsv). + + # === ON-DEMAND: Modality-specific files (13 docs) === + - title: MRI data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/magnetic-resonance-imaging-data.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/magnetic-resonance-imaging-data.md + category: modality_specific + description: MRI data organization including anatomical, functional, diffusion, fieldmaps, ASL, and quantitative MRI. + + - title: EEG data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/electroencephalography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/electroencephalography.md + category: modality_specific + description: EEG data organization including channel info, electrode positions, and events. + + - title: MEG data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/magnetoencephalography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/magnetoencephalography.md + category: modality_specific + description: MEG data organization including channel info, sensor positions, and cross-talk matrices. + + - title: iEEG data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/intracranial-electroencephalography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/intracranial-electroencephalography.md + category: modality_specific + description: Intracranial EEG data organization including electrode coordinates and photos. + + - title: PET data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/positron-emission-tomography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/positron-emission-tomography.md + category: modality_specific + description: PET data organization including blood data, reconstruction, and radioligand metadata. 
+ + - title: Microscopy data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/microscopy.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/microscopy.md + category: modality_specific + description: Microscopy imaging data organization for tissue samples and staining protocols. + + - title: NIRS data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/near-infrared-spectroscopy.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/near-infrared-spectroscopy.md + category: modality_specific + description: Near-infrared spectroscopy data organization including optode positions and channels. + + - title: Motion data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/motion.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/motion.md + category: modality_specific + description: Motion capture data organization for body tracking and kinematics. + + - title: MRS data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/magnetic-resonance-spectroscopy.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/magnetic-resonance-spectroscopy.md + category: modality_specific + description: Magnetic resonance spectroscopy data organization. + + - title: EMG data + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/electromyography.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/electromyography.md + category: modality_specific + description: Electromyography data organization for muscle activity recordings. + + - title: Behavioral experiments + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/behavioral-experiments.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/behavioral-experiments.md + category: modality_specific + description: Behavioral data organization for experiments with no neural recordings. + + - title: Genetic descriptor + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/genetic-descriptor.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/genetic-descriptor.md + category: modality_specific + description: Genetic descriptors associated with imaging participants. + + - title: Physiological recordings + url: https://bids-specification.readthedocs.io/en/stable/modality-specific-files/physiological-recordings.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/modality-specific-files/physiological-recordings.md + category: modality_specific + description: Continuous physiological recordings (cardiac, respiratory) acquired alongside neural data. + + # === ON-DEMAND: Derivatives (4 docs) === + - title: Derivatives introduction + url: https://bids-specification.readthedocs.io/en/stable/derivatives/introduction.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/derivatives/introduction.md + category: derivatives + description: Overview of BIDS derivatives, naming conventions, and metadata for processed data. 
+ + - title: Imaging derivatives + url: https://bids-specification.readthedocs.io/en/stable/derivatives/imaging.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/derivatives/imaging.md + category: derivatives + description: Imaging-specific derivative conventions (masks, segmentations, atlases). + + - title: Common derivative data types + url: https://bids-specification.readthedocs.io/en/stable/derivatives/common-data-types.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/derivatives/common-data-types.md + category: derivatives + description: Common derivative formats shared across modalities. + + - title: Atlas derivatives + url: https://bids-specification.readthedocs.io/en/stable/derivatives/atlas.html + source_url: https://raw.githubusercontent.com/bids-standard/bids-specification/master/src/derivatives/atlas.md + category: derivatives + description: Atlases as BIDS derivatives, including probability maps and label descriptions. + + # === ON-DEMAND: Website getting started (6 docs) === + - title: BIDS folder structure + url: https://bids.neuroimaging.io/getting_started/folders_and_files/folders + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/folders.md + category: getting_started + description: Overview of the BIDS directory hierarchy (subject, session, datatype folders). + + - title: BIDS file naming + url: https://bids.neuroimaging.io/getting_started/folders_and_files/files + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/files.md + category: getting_started + description: File naming conventions in BIDS (entities, suffixes, extensions). + + - title: BIDS derivatives overview + url: https://bids.neuroimaging.io/getting_started/folders_and_files/derivatives + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/derivatives.md + category: getting_started + description: Accessible overview of derivatives in BIDS for processed data. + + - title: JSON metadata + url: https://bids.neuroimaging.io/getting_started/folders_and_files/metadata/json + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/metadata/json.md + category: getting_started + description: How JSON sidecar files work in BIDS for metadata. + + - title: TSV files + url: https://bids.neuroimaging.io/getting_started/folders_and_files/metadata/tsv + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/folders_and_files/metadata/tsv.md + category: getting_started + description: How TSV tabular files work in BIDS (participants, events, channels, etc.). + + - title: Event annotation tutorial + url: https://bids.neuroimaging.io/getting_started/tutorials/annotation + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/getting_started/tutorials/annotation.md + category: getting_started + description: Tutorial on annotating events in BIDS datasets. + + # === ON-DEMAND: Website FAQ (5 docs) === + - title: General FAQ + url: https://bids.neuroimaging.io/faq/general + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/general.md + category: faq + description: Frequently asked questions about BIDS in general. 
+ + - title: MRI FAQ + url: https://bids.neuroimaging.io/faq/mri + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/mri.md + category: faq + description: Frequently asked questions about MRI data in BIDS. + + - title: EEG FAQ + url: https://bids.neuroimaging.io/faq/eeg + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/eeg.md + category: faq + description: Frequently asked questions about EEG data in BIDS. + + - title: BIDS Apps FAQ + url: https://bids.neuroimaging.io/faq/bids-apps + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/bids-apps.md + category: faq + description: Frequently asked questions about BIDS Apps. + + - title: BIDS Extensions FAQ + url: https://bids.neuroimaging.io/faq/bids-extensions + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/faq/bids-extensions.md + category: faq + description: Frequently asked questions about extending BIDS and the BEP process. + + # === ON-DEMAND: Website tools (3 docs) === + - title: BIDS Validator + url: https://bids.neuroimaging.io/tools/validator + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/tools/validator.md + category: tools + description: Overview of the BIDS validator, installation, and usage. + + - title: BIDS converters + url: https://bids.neuroimaging.io/tools/converters + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/tools/converters.md + category: tools + description: List of tools for converting data to BIDS format. + + - title: BIDS Apps + url: https://bids.neuroimaging.io/tools/bids-apps + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/tools/bids-apps.md + category: tools + description: Overview of BIDS Apps, standardized analysis pipelines. + + # === ON-DEMAND: Extensions and governance (2 docs) === + - title: BIDS Extension Proposals + url: https://bids.neuroimaging.io/extensions/beps + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/extensions/beps.md + category: extensions + description: List and status of all BIDS Extension Proposals (BEPs). + + - title: BEP guidelines + url: https://bids.neuroimaging.io/extensions/guidelines + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/extensions/guidelines.md + category: extensions + description: Guidelines for creating and submitting a new BIDS Extension Proposal. + + # === ON-DEMAND: Schema documentation (2 docs) === + - title: How the BIDS schema works + url: https://bids.neuroimaging.io/standards/schema/how-the-schema-works + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/standards/schema/how-the-schema-works.md + category: schema + description: Explanation of the machine-readable BIDS schema and how it drives validation. + + - title: Schema rules + url: https://bids.neuroimaging.io/standards/schema/schema-rules + source_url: https://raw.githubusercontent.com/bids-standard/bids-website/main/docs/standards/schema/schema-rules.md + category: schema + description: Documentation of schema rules that define BIDS validation logic. 
+ +# GitHub repositories for issue/PR sync +github: + repos: + - bids-standard/bids-specification + - bids-standard/bids-validator + - bids-standard/bids-website + - bids-standard/bids-examples + +# Paper/citation search configuration +# Canonical paper + modality-specific extension papers +citations: + queries: + - Brain Imaging Data Structure + - BIDS neuroimaging + - BIDS specification + - BIDS validator + - BIDS apps + dois: + # Canonical BIDS paper + - "10.1038/sdata.2016.44" # Gorgolewski et al. (2016) - The brain imaging data structure + # Modality-specific extension papers + - "10.1038/s41597-019-0104-8" # EEG-BIDS (Pernet et al., 2019) + - "10.1038/s41597-019-0105-7" # iEEG-BIDS (Holdgraf et al., 2019) + - "10.1038/sdata.2018.110" # MEG-BIDS (Niso et al., 2018) + - "10.1038/s41597-022-01164-1" # PET-BIDS (Norgaard et al., 2021) + - "10.1177/0271678X20905433" # PET guidelines (Knudsen et al., 2020) + - "10.1093/gigascience/giaa104" # Genetics-BIDS (Moreau et al., 2020) + - "10.3389/fnins.2022.871228" # Microscopy-BIDS (Bourget et al., 2022) + - "10.1038/s41597-022-01571-4" # qMRI-BIDS (Karakuzu et al., 2022) + - "10.1038/s41597-022-01615-9" # ASL-BIDS (Clement et al., 2022) + - "10.1038/s41597-024-04136-9" # NIRS-BIDS (Luke et al., 2025) + - "10.1038/s41597-024-03559-8" # Motion-BIDS (Jeung et al., 2024) + - "10.1038/s41597-025-05543-2" # MRS-BIDS (Bouchard et al., 2025) + # Related ecosystem + - "10.1371/journal.pcbi.1005209" # BIDS Apps (Gorgolewski et al., 2017) + +# Discourse forums +discourse: + - url: https://neurostars.org + tags: + - bids + +# No custom extensions for Phase 1 +# Future: schema query tool, BEP status tool, converter recommendation tool diff --git a/src/assistants/bids/tools.py b/src/assistants/bids/tools.py new file mode 100644 index 0000000..d0276c9 --- /dev/null +++ b/src/assistants/bids/tools.py @@ -0,0 +1,10 @@ +"""BIDS community-specific tools. + +Phase 1: No custom tools needed. Documentation retrieval and knowledge +discovery tools are auto-generated from config.yaml. 
+ +Future phases may add: +- Schema query tool (query BIDS schema for valid entities/metadata) +- BEP status tool (query active/completed BEPs) +- Converter recommendation tool (suggest converters for data types) +""" diff --git a/src/assistants/eeglab/config.yaml b/src/assistants/eeglab/config.yaml index a9f9a9b..c08a63e 100644 --- a/src/assistants/eeglab/config.yaml +++ b/src/assistants/eeglab/config.yaml @@ -5,14 +5,25 @@ id: eeglab name: EEGLAB description: EEG signal processing and analysis toolbox status: available -default_model: qwen/qwen3-235b-a22b-2507 -default_model_provider: DeepInfra/FP8 +default_model: anthropic/claude-haiku-4.5 +default_model_provider: anthropic cors_origins: - https://eeglab.org - https://www.eeglab.org - https://sccn.github.io +# Community maintainers (GitHub usernames) +# Used for scoped admin access and budget alert @mentions +maintainers: + - arnodelorme + +# Budget limits for cost management +budget: + daily_limit_usd: 5.00 + monthly_limit_usd: 50.00 + alert_threshold_pct: 80.0 + # System prompt template with runtime-substituted placeholders # Placeholders (in curly braces) are replaced by CommunityAssistant at runtime: # {preloaded_docs_section} - Embedded documentation content diff --git a/src/assistants/hed/config.yaml b/src/assistants/hed/config.yaml index f81da10..647a770 100644 --- a/src/assistants/hed/config.yaml +++ b/src/assistants/hed/config.yaml @@ -19,11 +19,23 @@ cors_origins: # The backend must have this environment variable set openrouter_api_key_env_var: "OPENROUTER_API_KEY_HED" +# Community maintainers (GitHub usernames) +# Used for scoped admin access and budget alert @mentions +maintainers: + - VisLab + - yarikoptic + +# Budget limits for cost management +budget: + daily_limit_usd: 5.00 + monthly_limit_usd: 50.00 + alert_threshold_pct: 80.0 + # Default model for this community (optional) # If specified, overrides the platform-level default model # Format: creator/model-name (OpenRouter format) -default_model: "qwen/qwen3-235b-a22b-2507" -default_model_provider: "DeepInfra/FP8" # Optimized FP8 quantized inference +default_model: "anthropic/claude-haiku-4.5" +default_model_provider: "anthropic" # System prompt template # Available placeholders: {name}, {description}, {repo_list}, {paper_dois}, diff --git a/src/core/config/community.py b/src/core/config/community.py index dbcff55..daf9604 100644 --- a/src/core/config/community.py +++ b/src/core/config/community.py @@ -596,6 +596,35 @@ def validate_agent_roles(self) -> "FAQGenerationConfig": return self +class BudgetConfig(BaseModel): + """Budget limits and alert thresholds for a community. + + When configured, the scheduler periodically checks spend against + these limits and creates GitHub issues when thresholds are exceeded. 
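+
+    Example (YAML, as it appears in a community config.yaml):
+        budget:
+          daily_limit_usd: 5.00
+          monthly_limit_usd: 50.00
+          alert_threshold_pct: 80.0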
+ """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + daily_limit_usd: float = Field(..., gt=0, description="Maximum daily spend in USD") + monthly_limit_usd: float = Field(..., gt=0, description="Maximum monthly spend in USD") + alert_threshold_pct: float = Field( + default=80.0, + ge=0, + le=100, + description="Percentage of limit at which to trigger alert (default: 80%)", + ) + + @model_validator(mode="after") + def validate_limits(self) -> "BudgetConfig": + """Ensure daily limit does not exceed monthly limit.""" + if self.daily_limit_usd > self.monthly_limit_usd: + raise ValueError( + f"daily_limit_usd ({self.daily_limit_usd}) cannot exceed " + f"monthly_limit_usd ({self.monthly_limit_usd})" + ) + return self + + class CommunityConfig(BaseModel): """Configuration for a single research community assistant. @@ -730,6 +759,32 @@ def validate_id(cls, v: str) -> str: If not specified, uses default routing for the model. """ + maintainers: list[str] = Field(default_factory=list) + """GitHub usernames of community maintainers. + + Used for: + - @mentioning in automated alert issues (budget alerts, etc.) + - Documenting who is responsible for the community + + Example: + maintainers: + - octocat + - janedoe + """ + + budget: BudgetConfig | None = None + """Budget limits and alert thresholds for cost management. + + When configured, the scheduler checks spend against these limits + and creates GitHub issues when thresholds are exceeded. + + Example: + budget: + daily_limit_usd: 5.0 + monthly_limit_usd: 50.0 + alert_threshold_pct: 80 + """ + @field_validator("cors_origins") @classmethod def validate_cors_origins(cls, v: list[str]) -> list[str]: @@ -756,6 +811,30 @@ def validate_cors_origins(cls, v: list[str]) -> list[str]: validated.append(origin) return validated + @field_validator("maintainers") + @classmethod + def validate_maintainers(cls, v: list[str]) -> list[str]: + """Validate GitHub usernames in maintainers list. + + GitHub usernames: 1-39 chars, alphanumeric or hyphens, + cannot start or end with hyphen. + """ + gh_username_pattern = re.compile(r"^[a-zA-Z0-9]([a-zA-Z0-9-]{0,37}[a-zA-Z0-9])?$") + validated = [] + for username in v: + username = username.strip() + if not username: + continue + if not gh_username_pattern.match(username): + raise ValueError( + f"Invalid GitHub username: '{username}'. " + "Must be 1-39 alphanumeric characters or hyphens, " + "cannot start/end with hyphen." + ) + if username not in validated: + validated.append(username) + return validated + @field_validator("openrouter_api_key_env_var") @classmethod def validate_openrouter_api_key_env_var(cls, v: str | None) -> str | None: diff --git a/src/knowledge/db.py b/src/knowledge/db.py index 42b6dd3..479d829 100644 --- a/src/knowledge/db.py +++ b/src/knowledge/db.py @@ -38,6 +38,9 @@ UNIQUE(repo, item_type, number) ); +-- Index for direct number lookups (e.g. "What is PR #2022?") +CREATE INDEX IF NOT EXISTS idx_github_items_number ON github_items(number); + -- FTS5 virtual table for full-text search on GitHub items CREATE VIRTUAL TABLE IF NOT EXISTS github_items_fts USING fts5( title, diff --git a/src/knowledge/search.py b/src/knowledge/search.py index 100f066..d16aec5 100644 --- a/src/knowledge/search.py +++ b/src/knowledge/search.py @@ -106,6 +106,59 @@ class SearchResult: created_at: str +def _is_pure_number_query(query: str) -> bool: + """Check if the query is purely a number lookup with no useful text for FTS. 
+ + Returns True for queries like "2022", "#500", "PR 2022", "issue #500" + where FTS would produce poor results (searching for literal "#500" or "PR 10"). + """ + stripped = query.strip() + # Pure number or hash-number + if stripped.lstrip("#").isdigit(): + return True + # Keyword + number with nothing else + return bool(re.fullmatch(r"(?:pr|pull|issue|bug|feature)\s*#?\s*\d+", stripped, re.IGNORECASE)) + + +def _extract_number(query: str) -> int | None: + """Extract an issue/PR number from a query string. + + Handles patterns like "2022", "#2022", "PR 2022", "issue #500". + Pure numeric queries (e.g. "2022") are treated as number lookups first; + this may match a PR/issue number rather than items mentioning the year. + + Returns: + The extracted number, or None if no number pattern found. + """ + # Strip and try common patterns + stripped = query.strip().lstrip("#") + # Direct number + if stripped.isdigit(): + return int(stripped) + # "PR 2022", "issue #500", "pull #2022", etc. + m = re.match(r"(?:pr|pull|issue|bug|feature)\s*#?\s*(\d+)", query.strip(), re.IGNORECASE) + if m: + return int(m.group(1)) + return None + + +def _row_to_result(row: sqlite3.Row) -> SearchResult: + """Convert a database row to a SearchResult.""" + first_message = row["first_message"] or "" + snippet = first_message[:200].strip() + if len(first_message) > 200: + snippet += "..." + return SearchResult( + title=row["title"], + url=row["url"], + snippet=snippet, + source="github", + item_type=row["item_type"], + status=row["status"], + created_at=row["created_at"] or "", + ) + + def search_github_items( query: str, project: str = "hed", @@ -114,10 +167,14 @@ def search_github_items( status: str | None = None, repo: str | None = None, ) -> list[SearchResult]: - """Search GitHub issues and PRs using phrase matching. + """Search GitHub issues and PRs by number, title, or body text. + + When the query contains a number (e.g. "2022", "#500", "PR 2022"), + results matching that number are returned first, followed by + full-text search results. Args: - query: Search phrase (treated as exact phrase, not FTS5 operators) + query: Search phrase, PR/issue number, or keyword project: Assistant/project name for database isolation. Defaults to 'hed'. limit: Maximum number of results item_type: Filter by 'issue' or 'pr' @@ -125,56 +182,74 @@ def search_github_items( repo: Filter by repository name Returns: - List of matching results, ordered by relevance - """ - # Build SQL query with optional filters - sql = """ - SELECT g.title, g.url, g.first_message, g.item_type, g.status, - g.created_at, g.repo - FROM github_items_fts f - JOIN github_items g ON f.rowid = g.id - WHERE github_items_fts MATCH ? + List of matching results, with number matches first """ - params: list[str | int] = [query] - - if item_type: - sql += " AND g.item_type = ?" - params.append(item_type) - if status: - sql += " AND g.status = ?" - params.append(status) - if repo: - sql += " AND g.repo = ?" - params.append(repo) - - sql += " ORDER BY rank LIMIT ?" - params.append(limit) - results = [] + seen_urls: set[str] = set() + try: with get_connection(project) as conn: - # Sanitize user query to prevent FTS5 injection - safe_query = _sanitize_fts5_query(query) - params[0] = safe_query - - for row in conn.execute(sql, params): - # Create snippet from first_message (first 200 chars) - first_message = row["first_message"] or "" - snippet = first_message[:200].strip() - if len(first_message) > 200: - snippet += "..." 
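+            # Two-phase lookup: exact issue/PR-number matches first, then
+            # full-text search fills the remaining result slots.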
+ # Phase 1: Try direct number lookup + number = _extract_number(query) + is_pure_number = _is_pure_number_query(query) + if number is not None: + num_sql = """ + SELECT title, url, first_message, item_type, status, + created_at, repo + FROM github_items WHERE number = ? + """ + num_params: list[str | int] = [number] + if item_type: + num_sql += " AND item_type = ?" + num_params.append(item_type) + if status: + num_sql += " AND status = ?" + num_params.append(status) + if repo: + num_sql += " AND repo = ?" + num_params.append(repo) + + for row in conn.execute(num_sql, num_params): + result = _row_to_result(row) + results.append(result) + seen_urls.add(result.url) + + if not results: + logger.debug("Number lookup for %d found no items", number) + + # Phase 2: Full-text search for remaining slots + # Skip FTS for pure number queries (e.g. "#500", "PR 2022") since + # the sanitized query would search for literal "#500" which won't + # match anything useful in title/body text. + remaining = limit - len(results) + if remaining > 0 and not is_pure_number: + fts_sql = """ + SELECT g.title, g.url, g.first_message, g.item_type, g.status, + g.created_at, g.repo + FROM github_items_fts f + JOIN github_items g ON f.rowid = g.id + WHERE github_items_fts MATCH ? + """ + fts_params: list[str | int] = [_sanitize_fts5_query(query)] + + if item_type: + fts_sql += " AND g.item_type = ?" + fts_params.append(item_type) + if status: + fts_sql += " AND g.status = ?" + fts_params.append(status) + if repo: + fts_sql += " AND g.repo = ?" + fts_params.append(repo) + + fts_sql += " ORDER BY rank LIMIT ?" + fts_params.append(remaining) + + for row in conn.execute(fts_sql, fts_params): + if row["url"] not in seen_urls: + results.append(_row_to_result(row)) + seen_urls.add(row["url"]) - results.append( - SearchResult( - title=row["title"], - url=row["url"], - snippet=snippet, - source="github", - item_type=row["item_type"], - status=row["status"], - created_at=row["created_at"] or "", - ) - ) except sqlite3.OperationalError as e: # Infrastructure failure (corruption, disk full, permissions) - must propagate logger.error( diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py new file mode 100644 index 0000000..d05dfb5 --- /dev/null +++ b/src/metrics/__init__.py @@ -0,0 +1 @@ +"""Metrics collection and request logging for OSA.""" diff --git a/src/metrics/alerts.py b/src/metrics/alerts.py new file mode 100644 index 0000000..ee29693 --- /dev/null +++ b/src/metrics/alerts.py @@ -0,0 +1,163 @@ +"""GitHub issue alerting for budget and operational alerts. + +Creates GitHub issues when budget thresholds are exceeded, +with deduplication to avoid spamming. +""" + +import json +import logging +import subprocess + +from src.metrics.budget import BudgetStatus + +logger = logging.getLogger(__name__) + +# GitHub repo for alert issues (org/repo format) +ALERT_REPO = "OpenScience-Collective/osa" + + +def _issue_exists(title: str, repo: str = ALERT_REPO) -> bool | None: + """Check if an open issue with this title already exists. + + Uses gh CLI to search for existing issues to prevent duplicates. + + Returns: + True if a matching issue exists, False if not, None if the check failed. 
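+
+    Equivalent CLI invocation (mirrors the subprocess call below):
+        gh issue list --repo <org/repo> --state open \
+            --search "<title>" --json title --limit 5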
+ """ + try: + result = subprocess.run( + [ + "gh", + "issue", + "list", + "--repo", + repo, + "--state", + "open", + "--search", + title, + "--json", + "title", + "--limit", + "5", + ], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + logger.error("gh issue list failed: %s", result.stderr) + return None + + issues = json.loads(result.stdout) + return any(issue.get("title") == title for issue in issues) + except Exception: + logger.exception("Failed to check existing issues") + return None + + +def create_budget_alert_issue( + budget_status: BudgetStatus, + maintainers: list[str], + repo: str = ALERT_REPO, +) -> str | None: + """Create a GitHub issue for a budget alert. + + Includes deduplication: checks for existing open issues with the same + title before creating a new one. + + Args: + budget_status: The budget check result with spend/limit data. + maintainers: GitHub usernames to @mention in the issue body. + repo: GitHub repository in org/repo format. + + Returns: + The issue URL if created, None if skipped (duplicate or error). + """ + # Determine alert type + alert_parts = [] + if budget_status.daily_exceeded: + alert_parts.append("daily limit exceeded") + elif budget_status.daily_alert: + alert_parts.append(f"daily spend at {budget_status.daily_pct:.0f}%") + if budget_status.monthly_exceeded: + alert_parts.append("monthly limit exceeded") + elif budget_status.monthly_alert: + alert_parts.append(f"monthly spend at {budget_status.monthly_pct:.0f}%") + + if not alert_parts: + return None + + alert_type = ", ".join(alert_parts) + title = f"[Budget Alert] {budget_status.community_id}: {alert_type}" + + # Check for existing open issue (three-state: True/False/None) + exists = _issue_exists(title, repo) + if exists is None: + logger.warning( + "Could not verify deduplication for %s; suppressing alert to prevent spam. " + "Check gh CLI and GitHub token configuration.", + budget_status.community_id, + ) + return None + if exists: + logger.info( + "Budget alert issue already exists for %s, skipping", budget_status.community_id + ) + return None + + # Build issue body + mentions = ( + " ".join(f"@{m}" for m in maintainers) if maintainers else "No maintainers configured" + ) + + body = f"""## Budget Alert for `{budget_status.community_id}` + +**Alert:** {alert_type} + +### Current Spend + +| Metric | Spend | Limit | Usage | +|--------|-------|-------|-------| +| Daily | ${budget_status.daily_spend_usd:.4f} | ${budget_status.daily_limit_usd:.2f} | {budget_status.daily_pct:.1f}% | +| Monthly | ${budget_status.monthly_spend_usd:.4f} | ${budget_status.monthly_limit_usd:.2f} | {budget_status.monthly_pct:.1f}% | + +### Alert Threshold +Configured at {budget_status.alert_threshold_pct:.0f}% of limits. 
+ +### Maintainers +{mentions} + +--- +*This issue was created automatically by the OSA budget monitoring system.* +""" + + try: + result = subprocess.run( + [ + "gh", + "issue", + "create", + "--repo", + repo, + "--title", + title, + "--body", + body, + "--label", + "cost-management,operations", + ], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + logger.error("Failed to create budget alert issue: %s", result.stderr) + return None + + issue_url = result.stdout.strip() + logger.info("Created budget alert issue: %s", issue_url) + return issue_url + except Exception: + logger.exception("Failed to create budget alert issue for %s", budget_status.community_id) + return None diff --git a/src/metrics/budget.py b/src/metrics/budget.py new file mode 100644 index 0000000..1d02868 --- /dev/null +++ b/src/metrics/budget.py @@ -0,0 +1,121 @@ +"""Budget checking for community cost management. + +Queries the metrics database for current spend and compares against +configured budget limits. +""" + +import logging +import sqlite3 +from dataclasses import dataclass + +from src.core.config.community import BudgetConfig + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BudgetStatus: + """Result of a budget check for a community.""" + + community_id: str + daily_spend_usd: float + monthly_spend_usd: float + daily_limit_usd: float + monthly_limit_usd: float + alert_threshold_pct: float + + def __post_init__(self) -> None: + if self.daily_spend_usd < 0: + raise ValueError(f"daily_spend_usd must be non-negative, got {self.daily_spend_usd}") + if self.monthly_spend_usd < 0: + raise ValueError( + f"monthly_spend_usd must be non-negative, got {self.monthly_spend_usd}" + ) + + @property + def daily_pct(self) -> float: + """Daily spend as percentage of limit.""" + if self.daily_limit_usd <= 0: + return 0.0 + return (self.daily_spend_usd / self.daily_limit_usd) * 100 + + @property + def monthly_pct(self) -> float: + """Monthly spend as percentage of limit.""" + if self.monthly_limit_usd <= 0: + return 0.0 + return (self.monthly_spend_usd / self.monthly_limit_usd) * 100 + + @property + def daily_exceeded(self) -> bool: + """Whether daily spend has reached or exceeded the daily limit.""" + return self.daily_spend_usd >= self.daily_limit_usd + + @property + def monthly_exceeded(self) -> bool: + """Whether monthly spend has reached or exceeded the monthly limit.""" + return self.monthly_spend_usd >= self.monthly_limit_usd + + @property + def daily_alert(self) -> bool: + """Whether daily spend crossed the alert threshold.""" + return self.daily_pct >= self.alert_threshold_pct + + @property + def monthly_alert(self) -> bool: + """Whether monthly spend crossed the alert threshold.""" + return self.monthly_pct >= self.alert_threshold_pct + + @property + def needs_alert(self) -> bool: + """Whether any alert threshold has been crossed.""" + return self.daily_alert or self.monthly_alert + + +def check_budget( + community_id: str, + config: BudgetConfig, + conn: sqlite3.Connection, +) -> BudgetStatus: + """Check current spend against budget limits. + + Queries estimated_cost from request_log for today and current month. + + Args: + community_id: The community identifier. + config: Budget configuration with limits and alert threshold. + conn: SQLite connection. + + Returns: + BudgetStatus with current spend and limit info. 
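+
+    Example:
+        status = check_budget("hed", config, conn)
+        if status.needs_alert:
+            # e.g. hand off to src.metrics.alerts.create_budget_alert_issue
+            ...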
+ """ + # Daily spend (today UTC) + daily_row = conn.execute( + """ + SELECT COALESCE(SUM(estimated_cost), 0) as spend + FROM request_log + WHERE community_id = ? + AND date(timestamp) = date('now') + """, + (community_id,), + ).fetchone() + + # Monthly spend (current month UTC) + monthly_row = conn.execute( + """ + SELECT COALESCE(SUM(estimated_cost), 0) as spend + FROM request_log + WHERE community_id = ? + AND strftime('%Y-%m', timestamp) = strftime('%Y-%m', 'now') + """, + (community_id,), + ).fetchone() + + return BudgetStatus( + community_id=community_id, + daily_spend_usd=round(daily_row["spend"], 6), + monthly_spend_usd=round(monthly_row["spend"], 6), + daily_limit_usd=config.daily_limit_usd, + monthly_limit_usd=config.monthly_limit_usd, + alert_threshold_pct=config.alert_threshold_pct, + ) diff --git a/src/metrics/cost.py b/src/metrics/cost.py new file mode 100644 index 0000000..1fd5036 --- /dev/null +++ b/src/metrics/cost.py @@ -0,0 +1,67 @@ +"""Cost estimation for LLM requests. + +Model pricing table with per-token costs (USD per million tokens). +Pricing is from OpenRouter as of 2025-07; update regularly. +""" + +import logging + +logger = logging.getLogger(__name__) + +# Pricing: USD per 1M tokens (input, output) +# Source: https://openrouter.ai/models +# Last updated: 2025-07 +MODEL_PRICING: dict[str, tuple[float, float]] = { + # Qwen models + "qwen/qwen3-235b-a22b-2507": (0.14, 0.34), + "qwen/qwen3-30b-a3b-2507": (0.07, 0.15), + # OpenAI models + "openai/gpt-4o": (2.50, 10.00), + "openai/gpt-4o-mini": (0.15, 0.60), + "openai/gpt-oss-120b": (0.00, 0.00), # Free tier + "openai/o1": (15.00, 60.00), + "openai/o3-mini": (1.10, 4.40), + # Anthropic models + "anthropic/claude-opus-4": (15.00, 75.00), + "anthropic/claude-sonnet-4": (3.00, 15.00), + "anthropic/claude-haiku-4.5": (0.80, 4.00), + "anthropic/claude-3.5-sonnet": (3.00, 15.00), + # Google models + "google/gemini-2.5-pro-preview": (1.25, 10.00), + "google/gemini-2.5-flash-preview": (0.15, 0.60), + # DeepSeek models + "deepseek/deepseek-chat-v3": (0.14, 0.28), + "deepseek/deepseek-r1": (0.55, 2.19), + # Meta models + "meta-llama/llama-4-maverick": (0.16, 0.40), +} + +# Fallback rate for models not in the pricing table +_FALLBACK_INPUT_RATE = 1.00 # USD per 1M tokens +_FALLBACK_OUTPUT_RATE = 3.00 # USD per 1M tokens + + +def estimate_cost( + model: str | None, + input_tokens: int, + output_tokens: int, +) -> float: + """Estimate the USD cost for a request. + + Args: + model: Model name in OpenRouter format (e.g., "qwen/qwen3-235b-a22b-2507"). + input_tokens: Number of input tokens. + output_tokens: Number of output tokens. + + Returns: + Estimated cost in USD, rounded to 6 decimal places. + """ + if model and model in MODEL_PRICING: + input_rate, output_rate = MODEL_PRICING[model] + else: + if model: + logger.warning("No pricing data for model %s, using fallback rates", model) + input_rate, output_rate = _FALLBACK_INPUT_RATE, _FALLBACK_OUTPUT_RATE + + cost = (input_tokens * input_rate + output_tokens * output_rate) / 1_000_000 + return round(cost, 6) diff --git a/src/metrics/db.py b/src/metrics/db.py new file mode 100644 index 0000000..35b6af5 --- /dev/null +++ b/src/metrics/db.py @@ -0,0 +1,280 @@ +"""Metrics storage layer using SQLite with WAL mode. + +Single SQLite database at {data_dir}/metrics.db stores all request logs. +WAL mode enables concurrent reads during writes. 
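+
+Typical usage (illustrative sketch of the API defined below):
+
+    init_metrics_db()
+    entry = RequestLogEntry(
+        request_id="abc123",  # any unique id; illustrative value
+        timestamp=now_iso(),
+        endpoint="/hed/ask",
+        method="POST",
+        community_id="hed",
+    )
+    log_request(entry)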
+""" + +import json +import logging +import sqlite3 +from collections.abc import Generator +from contextlib import contextmanager +from dataclasses import dataclass, field +from datetime import UTC, datetime +from pathlib import Path + +from langchain_core.messages import AIMessage, BaseMessage + +logger = logging.getLogger(__name__) + +# Track consecutive log_request failures for escalation +_log_request_failures: int = 0 + +METRICS_SCHEMA_SQL = """ +CREATE TABLE IF NOT EXISTS request_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + request_id TEXT NOT NULL, + timestamp TEXT NOT NULL, + community_id TEXT, + endpoint TEXT NOT NULL, + method TEXT NOT NULL, + duration_ms REAL, + status_code INTEGER, + model TEXT, + input_tokens INTEGER, + output_tokens INTEGER, + total_tokens INTEGER, + estimated_cost REAL, + tools_called TEXT, + key_source TEXT, + stream INTEGER DEFAULT 0, + tool_call_count INTEGER DEFAULT 0, + error_message TEXT, + langfuse_trace_id TEXT +); + +CREATE INDEX IF NOT EXISTS idx_request_log_community + ON request_log(community_id); +CREATE INDEX IF NOT EXISTS idx_request_log_timestamp + ON request_log(timestamp); +CREATE INDEX IF NOT EXISTS idx_request_log_community_timestamp + ON request_log(community_id, timestamp); +""" + +# Columns added after initial schema; ALTER TABLE for existing databases +_MIGRATION_COLUMNS = [ + ("tool_call_count", "INTEGER DEFAULT 0"), + ("error_message", "TEXT"), + ("langfuse_trace_id", "TEXT"), +] + + +@dataclass +class RequestLogEntry: + """A single request log entry for the metrics database.""" + + request_id: str + timestamp: str + endpoint: str + method: str + community_id: str | None = None + duration_ms: float | None = None + status_code: int | None = None + model: str | None = None + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None + estimated_cost: float | None = None + tools_called: list[str] = field(default_factory=list) + key_source: str | None = None + stream: bool = False + tool_call_count: int = 0 + error_message: str | None = None + langfuse_trace_id: str | None = None + + +def get_metrics_db_path() -> Path: + """Return path to the metrics SQLite database. + + Uses DATA_DIR environment variable if set (Docker deployments), + otherwise falls back to platform-specific user data directory. + """ + import os + + from platformdirs import user_data_dir + + data_dir_env = os.environ.get("DATA_DIR") + base = Path(data_dir_env) if data_dir_env else Path(user_data_dir("osa", "osc")) + base.mkdir(parents=True, exist_ok=True) + return base / "metrics.db" + + +def get_metrics_connection(db_path: Path | None = None) -> sqlite3.Connection: + """Get a connection to the metrics database. + + Args: + db_path: Optional path override (for testing). + """ + path = db_path or get_metrics_db_path() + conn = sqlite3.connect(str(path)) + conn.row_factory = sqlite3.Row + return conn + + +@contextmanager +def metrics_connection(db_path: Path | None = None) -> Generator[sqlite3.Connection, None, None]: + """Context manager for a metrics database connection. + + Ensures the connection is closed after use, even if an exception occurs. + + Args: + db_path: Optional path override (for testing). + """ + conn = get_metrics_connection(db_path) + try: + yield conn + finally: + conn.close() + + +def _migrate_columns(conn: sqlite3.Connection) -> None: + """Add new columns to existing databases (backward-compatible migration). + + Attempts ALTER TABLE ADD COLUMN for each new column. 
If the column + already exists, SQLite raises OperationalError with 'duplicate column', + which we catch and ignore, making this function idempotent. + """ + for col_name, col_def in _MIGRATION_COLUMNS: + try: + conn.execute(f"ALTER TABLE request_log ADD COLUMN {col_name} {col_def}") + logger.info("Added column %s to request_log", col_name) + except sqlite3.OperationalError as e: + if "duplicate column" in str(e).lower(): + pass # Expected on subsequent runs + else: + logger.error("Failed to add column %s to request_log: %s", col_name, e) + raise + + +def init_metrics_db(db_path: Path | None = None) -> None: + """Initialize the metrics database schema. Idempotent. + + Creates the request_log table and indexes if they don't exist. + Enables WAL mode for concurrent read/write access. + Runs migrations to add new columns to existing databases. + + Args: + db_path: Optional path override (for testing). + """ + conn = get_metrics_connection(db_path) + try: + conn.execute("PRAGMA journal_mode=WAL") + conn.executescript(METRICS_SCHEMA_SQL) + _migrate_columns(conn) + conn.commit() + logger.info("Metrics database initialized at %s", db_path or get_metrics_db_path()) + finally: + conn.close() + + +def log_request(entry: RequestLogEntry, db_path: Path | None = None) -> None: + """Insert a request log entry into the database. + + Args: + entry: The log entry to insert. + db_path: Optional path override (for testing). + """ + global _log_request_failures + conn = get_metrics_connection(db_path) + try: + conn.execute( + """ + INSERT INTO request_log ( + request_id, timestamp, community_id, endpoint, method, + duration_ms, status_code, model, input_tokens, output_tokens, + total_tokens, estimated_cost, tools_called, key_source, stream, + tool_call_count, error_message, langfuse_trace_id + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + entry.request_id, + entry.timestamp, + entry.community_id, + entry.endpoint, + entry.method, + entry.duration_ms, + entry.status_code, + entry.model, + entry.input_tokens, + entry.output_tokens, + entry.total_tokens, + entry.estimated_cost, + json.dumps(entry.tools_called) if entry.tools_called else None, + entry.key_source, + 1 if entry.stream else 0, + entry.tool_call_count, + entry.error_message, + entry.langfuse_trace_id, + ), + ) + conn.commit() + except sqlite3.Error: + _log_request_failures += 1 + logger.exception( + "Failed to log metrics request %s (endpoint=%s, community=%s) [failure #%d]", + entry.request_id, + entry.endpoint, + entry.community_id, + _log_request_failures, + ) + if _log_request_failures >= 10: + logger.critical( + "Metrics DB write has failed %d times. " + "Possible disk/database issue requiring investigation.", + _log_request_failures, + ) + else: + # Reset counter on success + _log_request_failures = 0 + finally: + conn.close() + + +def extract_token_usage(result: dict) -> tuple[int, int, int]: + """Extract token usage from agent result messages. + + Sums usage_metadata from all AIMessages in result["messages"]. + + Args: + result: Agent result dict containing "messages" list. + + Returns: + Tuple of (input_tokens, output_tokens, total_tokens). + Returns (0, 0, 0) if no usage data is available. 
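+
+    Example (illustrative; usage_metadata keys follow LangChain's
+    usage-metadata dict shape):
+
+        msg = AIMessage(content="hi", usage_metadata={
+            "input_tokens": 10, "output_tokens": 5, "total_tokens": 15})
+        extract_token_usage({"messages": [msg]})  # -> (10, 5, 15)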
+ """ + input_tokens = 0 + output_tokens = 0 + total_tokens = 0 + + messages: list[BaseMessage] = result.get("messages", []) + for msg in messages: + if not isinstance(msg, AIMessage): + continue + usage = getattr(msg, "usage_metadata", None) + if usage is None: + continue + # usage_metadata is a dict with input_tokens, output_tokens, total_tokens + if isinstance(usage, dict): + input_tokens += usage.get("input_tokens", 0) + output_tokens += usage.get("output_tokens", 0) + total_tokens += usage.get("total_tokens", 0) + + return input_tokens, output_tokens, total_tokens + + +def extract_tool_names(result: dict) -> list[str]: + """Extract tool names from agent result. + + Args: + result: Agent result dict containing "tool_calls" list. + + Returns: + List of tool names called during the request. + """ + tool_calls = result.get("tool_calls", []) + return [tc.get("name", "") for tc in tool_calls if tc.get("name")] + + +def now_iso() -> str: + """Return current UTC time as ISO string.""" + return datetime.now(UTC).isoformat() diff --git a/src/metrics/middleware.py b/src/metrics/middleware.py new file mode 100644 index 0000000..174b034 --- /dev/null +++ b/src/metrics/middleware.py @@ -0,0 +1,109 @@ +"""Request timing and metrics middleware. + +Captures request-scoped data (endpoint, duration, status_code, timestamp). +For agent requests, the handler sets metrics on request.state which the +middleware reads after the response completes. + +Streaming caveat: For streaming responses, the middleware fires before +streaming completes. Streaming handlers log metrics directly at the end +of the generator instead. +""" + +import logging +import time +import uuid + +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response + +from src.metrics.db import RequestLogEntry, log_request, now_iso + +logger = logging.getLogger(__name__) + +# Path suffixes that indicate a community-scoped route +_COMMUNITY_SUFFIXES = {"ask", "chat", "sessions", "metrics", "config"} + +# Top-level path prefixes that are NOT community IDs +_RESERVED_PREFIXES = {"health", "metrics", "sync", "docs", "redoc", "openapi.json", "frontend"} + + +def _extract_community_id(path: str) -> str | None: + """Extract community_id from community-scoped URL paths. + + Matches patterns like /{community_id}/ask, /{community_id}/metrics/public, + /{community_id}/sessions, etc. Returns None for non-community routes + (health, sync, global metrics, docs). + """ + parts = path.strip("/").split("/") + if not parts or not parts[0]: + return None + # Skip known non-community prefixes + if parts[0] in _RESERVED_PREFIXES: + return None + # Single segment (e.g., /health) or community with sub-path + if len(parts) >= 2 and parts[1] in _COMMUNITY_SUFFIXES: + return parts[0] + return None + + +class MetricsMiddleware(BaseHTTPMiddleware): + """Middleware that logs request timing and metrics. + + Sets request.state.request_id and request.state.start_time for + downstream handlers to use. After the response, logs a basic + request entry unless the handler has set request.state.metrics_logged + (indicating the handler logged its own detailed entry, e.g. streaming). 
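+
+    Example wiring (illustrative):
+
+        app.add_middleware(MetricsMiddleware)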
+ """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + request_id = str(uuid.uuid4()) + request.state.request_id = request_id + request.state.start_time = time.monotonic() + request.state.metrics_logged = False + + response = await call_next(request) + + try: + # If handler already logged metrics (streaming), skip + if getattr(request.state, "metrics_logged", False): + return response + + duration_ms = (time.monotonic() - request.state.start_time) * 1000 + community_id = _extract_community_id(request.url.path) + + # Check if handler set agent metrics on request.state + agent_data = getattr(request.state, "metrics_agent_data", None) + + agent_kwargs = {} + if agent_data and isinstance(agent_data, dict): + agent_kwargs = { + "model": agent_data.get("model"), + "input_tokens": agent_data.get("input_tokens"), + "output_tokens": agent_data.get("output_tokens"), + "total_tokens": agent_data.get("total_tokens"), + "estimated_cost": agent_data.get("estimated_cost"), + "tools_called": agent_data.get("tools_called", []), + "key_source": agent_data.get("key_source"), + "stream": agent_data.get("stream", False), + "tool_call_count": agent_data.get("tool_call_count", 0), + "error_message": agent_data.get("error_message"), + "langfuse_trace_id": agent_data.get("langfuse_trace_id"), + } + + entry = RequestLogEntry( + request_id=request_id, + timestamp=now_iso(), + endpoint=request.url.path, + method=request.method, + community_id=community_id, + duration_ms=round(duration_ms, 1), + status_code=response.status_code, + **agent_kwargs, + ) + + log_request(entry) + except Exception: + logger.exception("Metrics middleware failed for request %s", request_id) + + return response diff --git a/src/metrics/queries.py b/src/metrics/queries.py new file mode 100644 index 0000000..8c2d7be --- /dev/null +++ b/src/metrics/queries.py @@ -0,0 +1,642 @@ +"""Aggregation queries for the metrics database. + +Provides summary statistics, usage breakdowns, and overview queries +for both per-community and cross-community metrics. +""" + +import json +import logging +import sqlite3 +from typing import Any + +logger = logging.getLogger(__name__) + +# SQLite strftime patterns for time-bucketed queries +_PERIOD_FORMAT_MAP = { + "daily": "%Y-%m-%d", + "weekly": "%Y-W%W", + "monthly": "%Y-%m", +} + + +def _validate_period(period: str) -> str: + """Validate and return the strftime format for a period. + + Raises ValueError if period is not one of: daily, weekly, monthly. + """ + if period not in _PERIOD_FORMAT_MAP: + raise ValueError(f"Invalid period: {period}. Must be one of: daily, weekly, monthly") + return _PERIOD_FORMAT_MAP[period] + + +def _count_tools( + community_id: str, conn: sqlite3.Connection, limit: int = 5 +) -> list[dict[str, Any]]: + """Count tool usage from JSON arrays in the tools_called column. + + Returns top tools sorted by count descending. + """ + tool_rows = conn.execute( + """ + SELECT tools_called + FROM request_log + WHERE community_id = ? 
AND tools_called IS NOT NULL + """, + (community_id,), + ).fetchall() + tool_counts: dict[str, int] = {} + for tr in tool_rows: + try: + tools = json.loads(tr["tools_called"]) + for tool in tools: + tool_counts[tool] = tool_counts.get(tool, 0) + 1 + except (json.JSONDecodeError, TypeError): + logger.warning( + "Malformed tools_called data in request_log for community %s: %r", + community_id, + tr["tools_called"], + ) + top = sorted(tool_counts.items(), key=lambda x: x[1], reverse=True)[:limit] + return [{"tool": name, "count": count} for name, count in top] + + +def get_community_summary(community_id: str, conn: sqlite3.Connection) -> dict[str, Any]: + """Get summary statistics for a single community. + + Args: + community_id: The community identifier. + conn: SQLite connection (with row_factory=sqlite3.Row). + + Returns: + Dict with total_requests, total_input_tokens, total_output_tokens, + total_tokens, avg_duration_ms, total_estimated_cost, error_rate, + top_models, top_tools. + """ + row = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms, + COALESCE(SUM(estimated_cost), 0) as total_estimated_cost, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_count + FROM request_log + WHERE community_id = ? + """, + (community_id,), + ).fetchone() + + total = row["total_requests"] + error_rate = row["error_count"] / total if total > 0 else 0.0 + + # Top models + model_rows = conn.execute( + """ + SELECT model, COUNT(*) as count + FROM request_log + WHERE community_id = ? AND model IS NOT NULL + GROUP BY model + ORDER BY count DESC + LIMIT 5 + """, + (community_id,), + ).fetchall() + + return { + "community_id": community_id, + "total_requests": total, + "total_input_tokens": row["total_input_tokens"], + "total_output_tokens": row["total_output_tokens"], + "total_tokens": row["total_tokens"], + "avg_duration_ms": round(row["avg_duration_ms"], 1), + "total_estimated_cost": round(row["total_estimated_cost"], 4), + "error_rate": round(error_rate, 4), + "top_models": [{"model": r["model"], "count": r["count"]} for r in model_rows], + "top_tools": _count_tools(community_id, conn), + } + + +def get_usage_stats( + community_id: str, + period: str, + conn: sqlite3.Connection, +) -> dict[str, Any]: + """Get time-bucketed usage statistics for a community. + + Args: + community_id: The community identifier. + period: One of "daily", "weekly", "monthly". + conn: SQLite connection. + + Returns: + Dict with period, community_id, and buckets list. + """ + fmt = _validate_period(period) + + # Safe to use f-string: fmt is from _PERIOD_FORMAT_MAP whitelist, not user input + rows = conn.execute( + f""" + SELECT + strftime('{fmt}', timestamp) as bucket, + COUNT(*) as requests, + COALESCE(SUM(total_tokens), 0) as tokens, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms, + COALESCE(SUM(estimated_cost), 0) as estimated_cost, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as errors + FROM request_log + WHERE community_id = ? 
+ GROUP BY bucket + ORDER BY bucket + """, + (community_id,), + ).fetchall() + + return { + "community_id": community_id, + "period": period, + "buckets": [ + { + "bucket": r["bucket"], + "requests": r["requests"], + "tokens": r["tokens"], + "avg_duration_ms": round(r["avg_duration_ms"], 1), + "estimated_cost": round(r["estimated_cost"], 4), + "errors": r["errors"], + } + for r in rows + ], + } + + +def get_overview(conn: sqlite3.Connection) -> dict[str, Any]: + """Get cross-community metrics overview. + + Args: + conn: SQLite connection. + + Returns: + Dict with total stats and per-community breakdown. + """ + # Global totals + totals = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms, + COALESCE(SUM(estimated_cost), 0) as total_estimated_cost, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as total_errors + FROM request_log + """ + ).fetchone() + + total_req = totals["total_requests"] + + # Per-community breakdown + community_rows = conn.execute( + """ + SELECT + community_id, + COUNT(*) as requests, + COALESCE(SUM(total_tokens), 0) as tokens, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms, + COALESCE(SUM(estimated_cost), 0) as estimated_cost + FROM request_log + WHERE community_id IS NOT NULL + GROUP BY community_id + ORDER BY requests DESC + """ + ).fetchall() + + return { + "total_requests": total_req, + "total_tokens": totals["total_tokens"], + "avg_duration_ms": round(totals["avg_duration_ms"], 1), + "total_estimated_cost": round(totals["total_estimated_cost"], 4), + "error_rate": round(totals["total_errors"] / total_req, 4) if total_req > 0 else 0.0, + "communities": [ + { + "community_id": r["community_id"], + "requests": r["requests"], + "tokens": r["tokens"], + "avg_duration_ms": round(r["avg_duration_ms"], 1), + "estimated_cost": round(r["estimated_cost"], 4), + } + for r in community_rows + ], + } + + +def get_token_breakdown( + conn: sqlite3.Connection, + community_id: str | None = None, +) -> dict[str, Any]: + """Get token usage breakdown by model and key_source. + + Args: + conn: SQLite connection. + community_id: Optional filter by community. + + Returns: + Dict with by_model and by_key_source breakdowns. + """ + where = "" + params: tuple = () + if community_id: + where = "WHERE community_id = ?" 
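+        # The queries below splice this clause in and flip their own
+        # model/key_source filter between WHERE and AND to match.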
+ params = (community_id,) + + by_model = conn.execute( + f""" + SELECT + model, + COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0) as input_tokens, + COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(SUM(estimated_cost), 0) as estimated_cost + FROM request_log + {where} + {"AND" if where else "WHERE"} model IS NOT NULL + GROUP BY model + ORDER BY total_tokens DESC + """, + params, + ).fetchall() + + by_key_source = conn.execute( + f""" + SELECT + key_source, + COUNT(*) as requests, + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(SUM(estimated_cost), 0) as estimated_cost + FROM request_log + {where} + {"AND" if where else "WHERE"} key_source IS NOT NULL + GROUP BY key_source + ORDER BY requests DESC + """, + params, + ).fetchall() + + return { + "community_id": community_id, + "by_model": [ + { + "model": r["model"], + "requests": r["requests"], + "input_tokens": r["input_tokens"], + "output_tokens": r["output_tokens"], + "total_tokens": r["total_tokens"], + "estimated_cost": round(r["estimated_cost"], 4), + } + for r in by_model + ], + "by_key_source": [ + { + "key_source": r["key_source"], + "requests": r["requests"], + "total_tokens": r["total_tokens"], + "estimated_cost": round(r["estimated_cost"], 4), + } + for r in by_key_source + ], + } + + +# --------------------------------------------------------------------------- +# Public query functions (no tokens, costs, or model info) +# --------------------------------------------------------------------------- + + +def get_public_overview( + conn: sqlite3.Connection, + registered_communities: list[str] | None = None, +) -> dict[str, Any]: + """Get public metrics overview with only non-sensitive data. + + Returns request counts and error rates; no tokens, costs, or model info. + Only counts community-scoped requests (not health checks, sync, etc.). + + Args: + conn: SQLite connection. + registered_communities: If provided, ensures all these communities + appear in the response even if they have zero logged requests. + + Returns: + Dict with total_requests, error_rate, communities_active, + and per-community request counts. 
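+
+    Example (illustrative; metrics_connection lives in src.metrics.db):
+
+        with metrics_connection() as conn:
+            overview = get_public_overview(conn, ["hed", "bids"])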
+ """ + # Only count community-scoped requests, not infrastructure endpoints + totals = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as total_errors + FROM request_log + WHERE community_id IS NOT NULL + """ + ).fetchone() + + total_req = totals["total_requests"] + + community_rows = conn.execute( + """ + SELECT + community_id, + COUNT(*) as requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as errors + FROM request_log + WHERE community_id IS NOT NULL + GROUP BY community_id + ORDER BY requests DESC + """ + ).fetchall() + + # Build community data from DB results + community_data: dict[str, dict[str, Any]] = {} + for r in community_rows: + cid = r["community_id"] + community_data[cid] = { + "community_id": cid, + "requests": r["requests"], + "error_rate": round(r["errors"] / r["requests"], 4) if r["requests"] > 0 else 0.0, + } + + # Include registered communities with zero requests + if registered_communities: + for cid in registered_communities: + if cid not in community_data: + community_data[cid] = { + "community_id": cid, + "requests": 0, + "error_rate": 0.0, + } + + # Sort by requests descending, then alphabetically + communities_list = sorted( + community_data.values(), + key=lambda c: (-c["requests"], c["community_id"]), + ) + + return { + "total_requests": total_req, + "error_rate": round(totals["total_errors"] / total_req, 4) if total_req > 0 else 0.0, + "communities_active": sum(1 for c in communities_list if c["requests"] > 0), + "communities": communities_list, + } + + +def get_public_community_summary(community_id: str, conn: sqlite3.Connection) -> dict[str, Any]: + """Get public summary for a single community. + + Returns request counts and top tools; no tokens, costs, or model info. + + Args: + community_id: The community identifier. + conn: SQLite connection. + + Returns: + Dict with community_id, total_requests, error_rate, top_tools. + """ + row = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_count + FROM request_log + WHERE community_id = ? + """, + (community_id,), + ).fetchone() + + total = row["total_requests"] + error_rate = row["error_count"] / total if total > 0 else 0.0 + + return { + "community_id": community_id, + "total_requests": total, + "error_rate": round(error_rate, 4), + "top_tools": _count_tools(community_id, conn), + } + + +def get_public_usage_stats( + community_id: str, + period: str, + conn: sqlite3.Connection, +) -> dict[str, Any]: + """Get public time-bucketed usage statistics. + + Returns request counts and errors per bucket; no tokens or costs. + + Args: + community_id: The community identifier. + period: One of "daily", "weekly", "monthly". + conn: SQLite connection. + + Returns: + Dict with period, community_id, and buckets list. + """ + fmt = _validate_period(period) + + # Safe to use f-string: fmt is from _PERIOD_FORMAT_MAP whitelist, not user input + rows = conn.execute( + f""" + SELECT + strftime('{fmt}', timestamp) as bucket, + COUNT(*) as requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as errors + FROM request_log + WHERE community_id = ? 
+ GROUP BY bucket + ORDER BY bucket + """, + (community_id,), + ).fetchall() + + return { + "community_id": community_id, + "period": period, + "buckets": [ + { + "bucket": r["bucket"], + "requests": r["requests"], + "errors": r["errors"], + } + for r in rows + ], + } + + +# --------------------------------------------------------------------------- +# Quality metrics queries +# --------------------------------------------------------------------------- + + +def get_quality_metrics( + community_id: str, + conn: sqlite3.Connection, + period: str = "daily", +) -> dict[str, Any]: + """Get quality metrics for a community, bucketed by time period. + + Returns error rates, average tool call counts, and latency percentiles + (p50, p95) per time bucket. + + Args: + community_id: The community identifier. + conn: SQLite connection. + period: One of "daily", "weekly", "monthly". + + Returns: + Dict with period, community_id, and buckets with quality data. + """ + fmt = _validate_period(period) + + rows = conn.execute( + f""" + SELECT + strftime('{fmt}', timestamp) as bucket, + COUNT(*) as requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as errors, + COALESCE(AVG(tool_call_count), 0) as avg_tool_calls, + COUNT(CASE WHEN error_message IS NOT NULL THEN 1 END) as agent_errors, + COUNT(CASE WHEN langfuse_trace_id IS NOT NULL THEN 1 END) as traced_requests + FROM request_log + WHERE community_id = ? + GROUP BY bucket + ORDER BY bucket + """, + (community_id,), + ).fetchall() + + # Compute latency percentiles per bucket + buckets = [] + for r in rows: + bucket_name = r["bucket"] + total = r["requests"] + error_rate = r["errors"] / total if total > 0 else 0.0 + + # Latency percentiles for this bucket + p50, p95 = _fetch_latency_percentiles( + conn, + community_id, + extra_where=f"AND strftime('{fmt}', timestamp) = ?", + extra_params=(bucket_name,), + ) + + buckets.append( + { + "bucket": bucket_name, + "requests": total, + "error_rate": round(error_rate, 4), + "avg_tool_calls": round(r["avg_tool_calls"], 2), + "agent_errors": r["agent_errors"], + "traced_requests": r["traced_requests"], + "p50_duration_ms": round(p50, 1) if p50 is not None else None, + "p95_duration_ms": round(p95, 1) if p95 is not None else None, + } + ) + + return { + "community_id": community_id, + "period": period, + "buckets": buckets, + } + + +def get_quality_summary(community_id: str, conn: sqlite3.Connection) -> dict[str, Any]: + """Get overall quality overview for a community. + + Args: + community_id: The community identifier. + conn: SQLite connection. + + Returns: + Dict with overall quality stats: error_rate, avg_tool_calls, + agent_errors, p50/p95 latency, traced percentage. + """ + row = conn.execute( + """ + SELECT + COUNT(*) as total_requests, + COUNT(CASE WHEN status_code >= 400 THEN 1 END) as error_count, + COALESCE(AVG(tool_call_count), 0) as avg_tool_calls, + COUNT(CASE WHEN error_message IS NOT NULL THEN 1 END) as agent_errors, + COUNT(CASE WHEN langfuse_trace_id IS NOT NULL THEN 1 END) as traced + FROM request_log + WHERE community_id = ? 
+ """, + (community_id,), + ).fetchone() + + total = row["total_requests"] + error_rate = row["error_count"] / total if total > 0 else 0.0 + traced_pct = row["traced"] / total if total > 0 else 0.0 + + # Latency percentiles + p50, p95 = _fetch_latency_percentiles(conn, community_id) + + return { + "community_id": community_id, + "total_requests": total, + "error_rate": round(error_rate, 4), + "avg_tool_calls": round(row["avg_tool_calls"], 2), + "agent_errors": row["agent_errors"], + "traced_pct": round(traced_pct, 4), + "p50_duration_ms": round(p50, 1) if p50 is not None else None, + "p95_duration_ms": round(p95, 1) if p95 is not None else None, + } + + +def _percentile(sorted_values: list[float], pct: float) -> float | None: + """Compute percentile from a sorted list of values. + + Uses floor-rank method: index = floor(pct * len), clamped to valid range. + + Args: + sorted_values: Pre-sorted list of numeric values. + pct: Percentile as fraction (e.g., 0.5 for p50, 0.95 for p95). + + Returns: + The percentile value, or None if list is empty. + """ + if not sorted_values: + return None + idx = int(pct * len(sorted_values)) + idx = min(idx, len(sorted_values) - 1) + return sorted_values[idx] + + +def _fetch_latency_percentiles( + conn: sqlite3.Connection, + community_id: str, + extra_where: str = "", + extra_params: tuple[str, ...] = (), +) -> tuple[float | None, float | None]: + """Fetch p50 and p95 latency for a community with optional filters. + + Args: + conn: SQLite connection with row_factory set. + community_id: The community identifier. + extra_where: Additional SQL WHERE clause (e.g., "AND strftime(...) = ?"). + extra_params: Parameters for the extra WHERE clause. + + Returns: + Tuple of (p50, p95) duration in ms, or None if no data. + """ + durations = conn.execute( + f""" + SELECT duration_ms + FROM request_log + WHERE community_id = ? AND duration_ms IS NOT NULL {extra_where} + ORDER BY duration_ms + """, + (community_id, *extra_params), + ).fetchall() + values = [d["duration_ms"] for d in durations] + return _percentile(values, 0.5), _percentile(values, 0.95) diff --git a/src/version.py b/src/version.py index 71f5185..21df2d0 100644 --- a/src/version.py +++ b/src/version.py @@ -1,7 +1,7 @@ """Version information for OSA.""" -__version__ = "0.5.5" -__version_info__ = (0, 5, 5) +__version__ = "0.6.0.dev0" +__version_info__ = (0, 6, 0, "dev") def get_version() -> str: diff --git a/tests/test_api/test_dashboard.py b/tests/test_api/test_dashboard.py new file mode 100644 index 0000000..cc14100 --- /dev/null +++ b/tests/test_api/test_dashboard.py @@ -0,0 +1,92 @@ +"""Tests for the dashboard static HTML page. + +The dashboard is a standalone static site in dashboard/osa/index.html, +deployed separately to Cloudflare Pages. These tests verify the HTML +contains the expected structure and API references. 
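+
+Run them directly with, e.g.: uv run pytest tests/test_api/test_dashboard.py -v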
+"""
+
+from pathlib import Path
+
+DASHBOARD_HTML_PATH = Path(__file__).parent.parent.parent / "dashboard" / "osa" / "index.html"
+
+
+class TestDashboardHTML:
+    """Tests for dashboard/osa/index.html static file."""
+
+    def test_file_exists(self) -> None:
+        assert DASHBOARD_HTML_PATH.exists(), "dashboard/osa/index.html must exist"
+
+    def test_is_valid_html(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "<!DOCTYPE html>" in content
+        assert "</html>" in content
+
+    def test_contains_page_title(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "Open Science Assistant" in content
+
+    def test_contains_chart_js_cdn(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "chart.js" in content
+
+    def test_references_public_overview_api(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "/metrics/public/overview" in content
+
+    def test_references_community_public_metrics_api(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        # Should use /{community}/metrics/public pattern
+        assert "/metrics/public" in content
+        assert "/metrics/public/usage" in content
+
+    def test_has_client_side_routing(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "getRoute" in content
+        assert "window.location.pathname" in content
+
+    def test_has_aggregate_view(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "renderAggregateView" in content
+        assert "Questions Answered" in content
+
+    def test_has_community_view(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "loadCommunityView" in content
+
+    def test_has_tab_bar(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "tabBar" in content
+        assert "tab-link" in content
+        assert "renderTabs" in content
+
+    def test_has_admin_key_input(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "adminKeyInput" in content
+        assert "Admin Access" in content
+
+    def test_admin_section_hidden_by_default(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "admin-section" in content
+        assert "display: none" in content or "display:none" in content
+
+    def test_has_period_toggle(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "changePeriod" in content
+        assert "daily" in content
+        assert "weekly" in content
+        assert "monthly" in content
+
+    def test_api_base_configurable(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        # Should support ?api= query param or window.OSA_API_BASE override
+        assert "OSA_API_BASE" in content
+
+    def test_has_base_path_constant(self) -> None:
+        content = DASHBOARD_HTML_PATH.read_text()
+        assert "BASE_PATH" in content
+        assert "const BASE_PATH = '/osa'" in content
+
+    def test_cloudflare_redirects_file_exists(self) -> None:
+        redirects_path = DASHBOARD_HTML_PATH.parent.parent / "_redirects"
+        assert redirects_path.exists(), "_redirects needed for Cloudflare Pages SPA routing"
diff --git a/tests/test_api/test_metrics_endpoints.py b/tests/test_api/test_metrics_endpoints.py
new file mode 100644
index 0000000..9dffc60
--- /dev/null
+++ b/tests/test_api/test_metrics_endpoints.py
@@ -0,0 +1,420 @@
+"""Tests for metrics API endpoints."""
+
+import os
+from unittest.mock import patch
+
+import pytest
+from fastapi.testclient import TestClient
+
+from src.api.main import app
+from src.metrics.db import (
+    RequestLogEntry,
+    init_metrics_db,
+    log_request,
+)
+
+ADMIN_KEY = "test-metrics-admin-key"
+COMMUNITY_KEY = "hed-community-key"
+
+
+@pytest.fixture
+def 
metrics_db(tmp_path): + """Create isolated metrics database with sample data.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="r1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=100, + output_tokens=50, + total_tokens=150, + estimated_cost=0.001, + tools_called=["search_docs"], + key_source="platform", + tool_call_count=1, + langfuse_trace_id="trace-001", + ), + RequestLogEntry( + request_id="r2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/chat", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=200, + output_tokens=100, + total_tokens=300, + key_source="byok", + tool_call_count=0, + ), + RequestLogEntry( + request_id="r3", + timestamp="2025-01-15T12:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=500.0, + status_code=500, + model="qwen/qwen3-235b", + error_message="LLM timeout", + tool_call_count=0, + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +@pytest.fixture +def isolated_metrics(metrics_db): + """Patch metrics DB path for all metrics code.""" + with patch("src.metrics.db.get_metrics_db_path", return_value=metrics_db): + yield metrics_db + + +@pytest.fixture +def auth_env(): + """Set up environment for admin auth and clear settings cache.""" + from src.api.config import get_settings + + os.environ["API_KEYS"] = ADMIN_KEY + os.environ["REQUIRE_API_AUTH"] = "true" + get_settings.cache_clear() + + yield + + del os.environ["API_KEYS"] + del os.environ["REQUIRE_API_AUTH"] + get_settings.cache_clear() + + +@pytest.fixture +def scoped_auth_env(): + """Set up environment with both admin and community keys.""" + from src.api.config import get_settings + + os.environ["API_KEYS"] = ADMIN_KEY + os.environ["REQUIRE_API_AUTH"] = "true" + os.environ["COMMUNITY_ADMIN_KEYS"] = f"hed:{COMMUNITY_KEY}" + get_settings.cache_clear() + + yield + + del os.environ["API_KEYS"] + del os.environ["REQUIRE_API_AUTH"] + del os.environ["COMMUNITY_ADMIN_KEYS"] + get_settings.cache_clear() + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +class TestMetricsOverview: + """Tests for GET /metrics/overview.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200_with_admin_key(self, client): + response = client.get("/metrics/overview", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/metrics/overview", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert "total_requests" in data + assert "total_tokens" in data + assert "avg_duration_ms" in data + assert "error_rate" in data + assert "communities" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_401_without_key(self, client): + response = client.get("/metrics/overview") + assert response.status_code == 401 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_403_with_invalid_key(self, client): + response = client.get("/metrics/overview", headers={"X-API-Key": "wrong-key"}) + assert response.status_code == 403 + + +class TestTokenBreakdown: + """Tests for GET /metrics/tokens.""" + + 
@pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200(self, client): + response = client.get("/metrics/tokens", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/metrics/tokens", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert "by_model" in data + assert "by_key_source" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_filter_by_community(self, client): + response = client.get( + "/metrics/tokens", + params={"community_id": "hed"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + data = response.json() + assert data["community_id"] == "hed" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_admin_auth(self, client): + response = client.get("/metrics/tokens") + assert response.status_code == 401 + + +class TestCommunityMetrics: + """Tests for GET /{community_id}/metrics.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200(self, client): + response = client.get("/hed/metrics", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/hed/metrics", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert data["community_id"] == "hed" + assert "total_requests" in data + assert "total_tokens" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_admin_auth(self, client): + response = client.get("/hed/metrics") + assert response.status_code == 401 + + +class TestCommunityUsage: + """Tests for GET /{community_id}/metrics/usage.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_daily_usage(self, client): + response = client.get( + "/hed/metrics/usage", + params={"period": "daily"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 200 + data = response.json() + assert data["period"] == "daily" + assert "buckets" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_monthly_usage(self, client): + response = client.get( + "/hed/metrics/usage", + params={"period": "monthly"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_invalid_period_returns_422(self, client): + """Invalid period rejected by Query pattern validation.""" + response = client.get( + "/hed/metrics/usage", + params={"period": "hourly"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 422 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_admin_auth(self, client): + response = client.get("/hed/metrics/usage", params={"period": "daily"}) + assert response.status_code == 401 + + +class TestCommunityQuality: + """Tests for GET /{community_id}/metrics/quality.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200(self, client): + response = client.get("/hed/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/hed/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert data["community_id"] == "hed" + assert 
data["period"] == "daily" + assert "buckets" in data + assert len(data["buckets"]) > 0 + bucket = data["buckets"][0] + assert "requests" in bucket + assert "error_rate" in bucket + assert "avg_tool_calls" in bucket + assert "agent_errors" in bucket + assert "traced_requests" in bucket + assert "p50_duration_ms" in bucket + assert "p95_duration_ms" in bucket + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_weekly_period(self, client): + response = client.get( + "/hed/metrics/quality", + params={"period": "weekly"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 200 + assert response.json()["period"] == "weekly" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_invalid_period_returns_422(self, client): + response = client.get( + "/hed/metrics/quality", + params={"period": "hourly"}, + headers={"X-API-Key": ADMIN_KEY}, + ) + assert response.status_code == 422 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_auth(self, client): + response = client.get("/hed/metrics/quality") + assert response.status_code == 401 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_quality_data_reflects_errors(self, client): + """Verify error data from fixture is reflected in quality metrics.""" + response = client.get("/hed/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + # All 3 entries are on 2025-01-15, so one daily bucket + assert len(data["buckets"]) == 1 + bucket = data["buckets"][0] + assert bucket["requests"] == 3 + # 1 out of 3 requests has status_code >= 400 + assert bucket["error_rate"] > 0 + assert bucket["agent_errors"] == 1 + assert bucket["traced_requests"] == 1 + + +class TestCommunityQualitySummary: + """Tests for GET /{community_id}/metrics/quality/summary.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200(self, client): + response = client.get("/hed/metrics/quality/summary", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_response_structure(self, client): + response = client.get("/hed/metrics/quality/summary", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert data["community_id"] == "hed" + assert "total_requests" in data + assert "error_rate" in data + assert "avg_tool_calls" in data + assert "agent_errors" in data + assert "traced_pct" in data + assert "p50_duration_ms" in data + assert "p95_duration_ms" in data + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_summary_values(self, client): + """Verify summary aggregates match fixture data.""" + response = client.get("/hed/metrics/quality/summary", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert data["total_requests"] == 3 + assert data["agent_errors"] == 1 + # 1 traced out of 3 + assert 0 < data["traced_pct"] < 1 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_auth(self, client): + response = client.get("/hed/metrics/quality/summary") + assert response.status_code == 401 + + +class TestGlobalQuality: + """Tests for GET /metrics/quality.""" + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_200_for_admin(self, client): + response = client.get("/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_admin_sees_communities_list(self, 
client): + response = client.get("/metrics/quality", headers={"X-API-Key": ADMIN_KEY}) + data = response.json() + assert "communities" in data + assert len(data["communities"]) > 0 + community = data["communities"][0] + assert community["community_id"] == "hed" + assert "error_rate" in community + assert "avg_tool_calls" in community + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_requires_auth(self, client): + response = client.get("/metrics/quality") + assert response.status_code == 401 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_returns_403_with_invalid_key(self, client): + response = client.get("/metrics/quality", headers={"X-API-Key": "wrong-key"}) + assert response.status_code == 403 + + +class TestScopedKeyOnGlobalEndpoints: + """Tests for community-scoped keys on global metrics endpoints.""" + + @pytest.mark.usefixtures("isolated_metrics", "scoped_auth_env") + def test_community_key_on_overview_returns_scoped_data(self, client): + """Community key on /metrics/overview should return only their community summary.""" + response = client.get("/metrics/overview", headers={"X-API-Key": COMMUNITY_KEY}) + assert response.status_code == 200 + data = response.json() + # Scoped key gets community summary, not the full overview with communities list + assert data["community_id"] == "hed" + + @pytest.mark.usefixtures("isolated_metrics", "scoped_auth_env") + def test_community_key_on_tokens_is_auto_scoped(self, client): + """Community key on /metrics/tokens should be auto-scoped to own community.""" + response = client.get( + "/metrics/tokens", + params={"community_id": "eeglab"}, # Try to view another community + headers={"X-API-Key": COMMUNITY_KEY}, + ) + assert response.status_code == 200 + data = response.json() + # Should be forced to own community, ignoring the query param + assert data["community_id"] == "hed" + + @pytest.mark.usefixtures("isolated_metrics", "scoped_auth_env") + def test_community_key_on_quality_returns_own_summary(self, client): + """Community key on /metrics/quality should return only own community quality.""" + response = client.get("/metrics/quality", headers={"X-API-Key": COMMUNITY_KEY}) + assert response.status_code == 200 + data = response.json() + # Scoped key gets single community summary, not communities list + assert data["community_id"] == "hed" + assert "error_rate" in data + + @pytest.mark.usefixtures("isolated_metrics", "scoped_auth_env") + def test_admin_key_still_sees_all(self, client): + """Admin key should still see full data on all global endpoints.""" + response = client.get("/metrics/overview", headers={"X-API-Key": ADMIN_KEY}) + assert response.status_code == 200 + data = response.json() + assert "communities" in data diff --git a/tests/test_api/test_metrics_public.py b/tests/test_api/test_metrics_public.py new file mode 100644 index 0000000..70066a6 --- /dev/null +++ b/tests/test_api/test_metrics_public.py @@ -0,0 +1,398 @@ +"""Tests for public metrics API endpoints. 
+ +Global overview: GET /metrics/public/overview (no auth) +Per-community: GET /{community_id}/metrics/public (no auth) +Per-community usage: GET /{community_id}/metrics/public/usage (no auth) +""" + +import os +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from src.api.main import app +from src.metrics.db import ( + RequestLogEntry, + init_metrics_db, + log_request, +) + + +@pytest.fixture +def metrics_db(tmp_path): + """Create isolated metrics database with sample data.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="r1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=100, + output_tokens=50, + total_tokens=150, + estimated_cost=0.001, + tools_called=["search_docs", "validate_hed"], + key_source="platform", + ), + RequestLogEntry( + request_id="r2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/chat", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=200, + output_tokens=100, + total_tokens=300, + tools_called=["search_docs"], + key_source="byok", + ), + RequestLogEntry( + request_id="r3", + timestamp="2025-01-16T09:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=100.0, + status_code=500, + ), + RequestLogEntry( + request_id="r4", + timestamp="2025-01-15T12:00:00+00:00", + endpoint="/bids/ask", + method="POST", + community_id="bids", + duration_ms=250.0, + status_code=200, + model="anthropic/claude-sonnet", + input_tokens=150, + output_tokens=75, + total_tokens=225, + key_source="platform", + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +@pytest.fixture +def isolated_metrics(metrics_db): + """Patch metrics DB path for all metrics code.""" + with patch("src.metrics.db.get_metrics_db_path", return_value=metrics_db): + yield metrics_db + + +@pytest.fixture +def noauth_env(): + """Disable auth requirement.""" + from src.api.config import get_settings + + os.environ["REQUIRE_API_AUTH"] = "false" + get_settings.cache_clear() + yield + del os.environ["REQUIRE_API_AUTH"] + get_settings.cache_clear() + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app) + + +class TestPublicOverview: + """Tests for GET /metrics/public/overview.""" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_returns_200_without_auth(self, client): + response = client.get("/metrics/public/overview") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_response_structure(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + assert "total_requests" in data + assert "error_rate" in data + assert "communities_active" in data + assert "communities" in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_no_sensitive_fields(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + assert "total_tokens" not in data + assert "total_estimated_cost" not in data + assert "avg_duration_ms" not in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_request_counts_correct(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + assert 
data["total_requests"] == 4 + assert data["communities_active"] == 2 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_error_rate_includes_errors(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + # 1 error out of 4 requests = 0.25 + assert data["error_rate"] == 0.25 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_community_breakdown_no_sensitive_fields(self, client): + response = client.get("/metrics/public/overview") + data = response.json() + for community in data["communities"]: + assert "community_id" in community + assert "requests" in community + assert "error_rate" in community + assert "tokens" not in community + assert "estimated_cost" not in community + + +class TestCommunityPublicMetrics: + """Tests for GET /{community_id}/metrics/public.""" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_returns_200_without_auth(self, client): + response = client.get("/hed/metrics/public") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_response_structure(self, client): + response = client.get("/hed/metrics/public") + data = response.json() + assert data["community_id"] == "hed" + assert "total_requests" in data + assert "error_rate" in data + assert "top_tools" in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_no_sensitive_fields(self, client): + response = client.get("/hed/metrics/public") + data = response.json() + assert "total_tokens" not in data + assert "total_estimated_cost" not in data + assert "top_models" not in data + assert "avg_duration_ms" not in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_top_tools_populated(self, client): + response = client.get("/hed/metrics/public") + data = response.json() + tools = {t["tool"]: t["count"] for t in data["top_tools"]} + assert "search_docs" in tools + assert tools["search_docs"] == 2 + + +class TestCommunityPublicUsage: + """Tests for GET /{community_id}/metrics/public/usage.""" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_daily_usage_returns_200(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "daily"}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_daily_usage_structure(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "daily"}) + data = response.json() + assert data["community_id"] == "hed" + assert data["period"] == "daily" + assert "buckets" in data + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_buckets_no_sensitive_fields(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "daily"}) + data = response.json() + for bucket in data["buckets"]: + assert "bucket" in bucket + assert "requests" in bucket + assert "errors" in bucket + assert "tokens" not in bucket + assert "estimated_cost" not in bucket + assert "avg_duration_ms" not in bucket + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_monthly_usage(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "monthly"}) + assert response.status_code == 200 + data = response.json() + assert data["period"] == "monthly" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_weekly_usage(self, client): + response = client.get("/hed/metrics/public/usage", 
params={"period": "weekly"}) + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_invalid_period_returns_422(self, client): + response = client.get("/hed/metrics/public/usage", params={"period": "hourly"}) + assert response.status_code == 422 + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_default_period_is_daily(self, client): + response = client.get("/hed/metrics/public/usage") + assert response.status_code == 200 + data = response.json() + assert data["period"] == "daily" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_daily_buckets_count_and_errors(self, client): + """Verify bucket count and error values match fixture data.""" + response = client.get("/hed/metrics/public/usage", params={"period": "daily"}) + data = response.json() + buckets = data["buckets"] + # Fixture has HED requests on 2025-01-15 and 2025-01-16 + assert len(buckets) == 2 + bucket_map = {b["bucket"]: b for b in buckets} + # 2025-01-16 has one request with status_code=500 + assert bucket_map["2025-01-16"]["errors"] == 1 + + +class TestPublicAdminBoundary: + """Verify public endpoints work without auth while admin endpoints require it.""" + + @pytest.fixture + def auth_env(self): + """Enable auth with a test API key so admin endpoints reject anonymous requests.""" + from src.api.config import get_settings + + os.environ["REQUIRE_API_AUTH"] = "true" + os.environ["API_KEYS"] = "test-secret-key" + get_settings.cache_clear() + yield + del os.environ["REQUIRE_API_AUTH"] + del os.environ["API_KEYS"] + get_settings.cache_clear() + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_public_overview_no_auth_200(self, client): + response = client.get("/metrics/public/overview") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_admin_overview_no_auth_rejected(self, client): + """Admin endpoint must reject unauthenticated requests.""" + response = client.get("/metrics/overview") + assert response.status_code in (401, 403) + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_admin_tokens_no_auth_rejected(self, client): + """Admin token endpoint must reject unauthenticated requests.""" + response = client.get("/metrics/tokens") + assert response.status_code in (401, 403) + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_public_overview_accessible_with_auth_enabled(self, client): + """Public overview must return 200 even when auth is required.""" + response = client.get("/metrics/public/overview") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_community_public_metrics_accessible_with_auth_enabled(self, client): + """Per-community public metrics must return 200 even when auth is required.""" + response = client.get("/hed/metrics/public") + assert response.status_code == 200 + + @pytest.mark.usefixtures("isolated_metrics", "auth_env") + def test_community_public_usage_accessible_with_auth_enabled(self, client): + """Per-community public usage must return 200 even when auth is required.""" + response = client.get("/hed/metrics/public/usage") + assert response.status_code == 200 + + +class TestEmptyDatabase: + """Verify public endpoints handle empty databases gracefully.""" + + @pytest.fixture + def empty_metrics_db(self, tmp_path): + db_path = tmp_path / "empty_metrics.db" + init_metrics_db(db_path) + return db_path + + @pytest.fixture + def 
isolated_empty_metrics(self, empty_metrics_db): + with patch("src.metrics.db.get_metrics_db_path", return_value=empty_metrics_db): + yield + + @pytest.mark.usefixtures("isolated_empty_metrics", "noauth_env") + def test_overview_empty_db(self, client): + response = client.get("/metrics/public/overview") + assert response.status_code == 200 + data = response.json() + assert data["total_requests"] == 0 + assert data["error_rate"] == 0.0 + assert data["communities_active"] == 0 + # Registered communities still appear with zero requests + for c in data["communities"]: + assert c["requests"] == 0 + assert c["error_rate"] == 0.0 + + @pytest.mark.usefixtures("isolated_empty_metrics", "noauth_env") + def test_community_metrics_empty_db(self, client): + response = client.get("/hed/metrics/public") + assert response.status_code == 200 + data = response.json() + assert data["total_requests"] == 0 + assert data["error_rate"] == 0.0 + assert data["top_tools"] == [] + + @pytest.mark.usefixtures("isolated_empty_metrics", "noauth_env") + def test_community_usage_empty_db(self, client): + response = client.get("/hed/metrics/public/usage") + assert response.status_code == 200 + data = response.json() + assert data["buckets"] == [] + + +class TestCommunityMetricsValues: + """Verify computed values per community match fixture data dynamically.""" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_community_values_from_overview(self, client): + """Check each community's request count and error rate from overview.""" + response = client.get("/metrics/public/overview") + data = response.json() + checked = 0 + for community in data["communities"]: + cid = community["community_id"] + resp = client.get(f"/{cid}/metrics/public") + if resp.status_code != 200: + continue # community route not registered in test app + detail = resp.json() + assert detail["total_requests"] == community["requests"] + assert detail["error_rate"] == community["error_rate"] + if community["requests"] > 0: + checked += 1 + assert checked > 0, "Expected at least one community with requests" + + @pytest.mark.usefixtures("isolated_metrics", "noauth_env") + def test_per_community_tool_counts_consistent(self, client): + """Each tool count should be a positive integer.""" + response = client.get("/metrics/public/overview") + checked = 0 + for community in response.json()["communities"]: + cid = community["community_id"] + resp = client.get(f"/{cid}/metrics/public") + if resp.status_code != 200: + continue # community route not registered in test app + detail = resp.json() + for tool_entry in detail["top_tools"]: + assert isinstance(tool_entry["tool"], str) + assert tool_entry["count"] > 0 + checked += 1 + assert checked > 0, "Expected at least one community with a registered route" diff --git a/tests/test_api/test_scoped_auth.py b/tests/test_api/test_scoped_auth.py new file mode 100644 index 0000000..e1c7ce7 --- /dev/null +++ b/tests/test_api/test_scoped_auth.py @@ -0,0 +1,200 @@ +"""Tests for per-community scoped authentication. + +Tests AuthScope, verify_scoped_admin_key, and per-community metrics access control. +Uses real HTTP requests against FastAPI test apps. 
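+
+Key material comes from environment variables: API_KEYS holds global admin
+keys, and COMMUNITY_ADMIN_KEYS maps scoped keys to one community each as
+comma-separated "community_id:key" pairs (for example
+"hed:hed-key-1,eeglab:eeglab-key-1"); parsing is covered by TestParseAdminKeys.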
+""" + +import os + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from src.api.config import Settings, get_settings +from src.api.security import AuthScope, RequireScopedAuth + + +@pytest.fixture +def app_scoped_auth() -> FastAPI: + """Create a test app with scoped auth endpoints.""" + os.environ["API_KEYS"] = "global-admin-key" + os.environ["REQUIRE_API_AUTH"] = "true" + os.environ["COMMUNITY_ADMIN_KEYS"] = "hed:hed-key-1,eeglab:eeglab-key-1,hed:hed-key-2" + + get_settings.cache_clear() + + app = FastAPI() + + @app.get("/scoped") + async def scoped_route(auth: RequireScopedAuth) -> dict: + return { + "role": auth.role, + "community_id": auth.community_id, + } + + @app.get("/metrics/{community_id}") + async def community_metrics(community_id: str, auth: RequireScopedAuth) -> dict: + if not auth.can_access_community(community_id): + from fastapi import HTTPException + + raise HTTPException(status_code=403, detail="Access denied") + return {"community_id": community_id, "role": auth.role} + + yield app + + for key in ["API_KEYS", "REQUIRE_API_AUTH", "COMMUNITY_ADMIN_KEYS"]: + os.environ.pop(key, None) + get_settings.cache_clear() + + +@pytest.fixture +def client(app_scoped_auth: FastAPI) -> TestClient: + return TestClient(app_scoped_auth) + + +class TestAuthScope: + """Tests for AuthScope dataclass.""" + + def test_admin_can_access_any_community(self): + scope = AuthScope(role="admin") + assert scope.can_access_community("hed") is True + assert scope.can_access_community("eeglab") is True + assert scope.can_access_community("anything") is True + + def test_community_scope_can_access_own(self): + scope = AuthScope(role="community", community_id="hed") + assert scope.can_access_community("hed") is True + + def test_community_scope_cannot_access_other(self): + scope = AuthScope(role="community", community_id="hed") + assert scope.can_access_community("eeglab") is False + + def test_community_role_requires_community_id(self): + with pytest.raises(ValueError, match="community role requires a community_id"): + AuthScope(role="community") + + def test_admin_role_rejects_community_id(self): + with pytest.raises(ValueError, match="admin role must not have a community_id"): + AuthScope(role="admin", community_id="hed") + + +class TestVerifyScopedAdminKey: + """Tests for verify_scoped_admin_key dependency.""" + + def test_global_admin_key_returns_admin_role(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "global-admin-key"}) + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "admin" + assert data["community_id"] is None + + def test_community_key_returns_community_role(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "hed-key-1"}) + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "community" + assert data["community_id"] == "hed" + + def test_second_community_key_works(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "hed-key-2"}) + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "community" + assert data["community_id"] == "hed" + + def test_eeglab_community_key(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "eeglab-key-1"}) + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "community" + assert data["community_id"] == "eeglab" + + def test_no_key_returns_401(self, client: TestClient): + resp = client.get("/scoped") + 
assert resp.status_code == 401 + + def test_invalid_key_returns_403(self, client: TestClient): + resp = client.get("/scoped", headers={"X-API-Key": "wrong-key"}) + assert resp.status_code == 403 + + +class TestScopedCommunityAccess: + """Tests for per-community access control via scoped auth.""" + + def test_admin_can_access_any_community_metrics(self, client: TestClient): + resp = client.get("/metrics/hed", headers={"X-API-Key": "global-admin-key"}) + assert resp.status_code == 200 + assert resp.json()["community_id"] == "hed" + + resp = client.get("/metrics/eeglab", headers={"X-API-Key": "global-admin-key"}) + assert resp.status_code == 200 + assert resp.json()["community_id"] == "eeglab" + + def test_community_key_can_access_own_metrics(self, client: TestClient): + resp = client.get("/metrics/hed", headers={"X-API-Key": "hed-key-1"}) + assert resp.status_code == 200 + assert resp.json()["community_id"] == "hed" + + def test_community_key_cannot_access_other_metrics(self, client: TestClient): + resp = client.get("/metrics/eeglab", headers={"X-API-Key": "hed-key-1"}) + assert resp.status_code == 403 + + +class TestScopedAuthDisabled: + """Tests for scoped auth when authentication is disabled.""" + + def test_no_auth_required_returns_admin_scope(self): + os.environ["REQUIRE_API_AUTH"] = "false" + os.environ.pop("API_KEYS", None) + os.environ.pop("COMMUNITY_ADMIN_KEYS", None) + get_settings.cache_clear() + + app = FastAPI() + + @app.get("/scoped") + async def scoped_route(auth: RequireScopedAuth) -> dict: + return {"role": auth.role, "community_id": auth.community_id} + + client = TestClient(app) + resp = client.get("/scoped") + assert resp.status_code == 200 + data = resp.json() + assert data["role"] == "admin" + assert data["community_id"] is None + + os.environ.pop("REQUIRE_API_AUTH", None) + get_settings.cache_clear() + + +class TestParseAdminKeys: + """Tests for Settings.parse_community_admin_keys().""" + + def test_parse_single_key(self): + s = Settings(community_admin_keys="hed:abc123") + result = s.parse_community_admin_keys() + assert result == {"hed": {"abc123"}} + + def test_parse_multiple_communities(self): + s = Settings(community_admin_keys="hed:key1,eeglab:key2") + result = s.parse_community_admin_keys() + assert result == {"hed": {"key1"}, "eeglab": {"key2"}} + + def test_parse_multiple_keys_same_community(self): + s = Settings(community_admin_keys="hed:key1,hed:key2") + result = s.parse_community_admin_keys() + assert result == {"hed": {"key1", "key2"}} + + def test_parse_empty(self): + s = Settings(community_admin_keys=None) + result = s.parse_community_admin_keys() + assert result == {} + + def test_parse_with_spaces(self): + s = Settings(community_admin_keys=" hed : key1 , eeglab : key2 ") + result = s.parse_community_admin_keys() + assert result == {"hed": {"key1"}, "eeglab": {"key2"}} + + def test_parse_ignores_malformed_entries(self): + s = Settings(community_admin_keys="hed:key1,badentry,eeglab:key2") + result = s.parse_community_admin_keys() + assert result == {"hed": {"key1"}, "eeglab": {"key2"}} diff --git a/tests/test_assistants/test_community_yaml_generic.py b/tests/test_assistants/test_community_yaml_generic.py index 1bc0d54..c796774 100644 --- a/tests/test_assistants/test_community_yaml_generic.py +++ b/tests/test_assistants/test_community_yaml_generic.py @@ -96,6 +96,7 @@ def test_documentation_urls_valid_format(self, community_id): ) @pytest.mark.slow + @pytest.mark.skip(reason="Disabled: upstream HED URL broken (404). 
See #139") def test_documentation_urls_accessible(self, community_id): """All documentation source URLs should return HTTP 200. diff --git a/tests/test_core/test_config/test_community.py b/tests/test_core/test_config/test_community.py index 4659982..27dcbdb 100644 --- a/tests/test_core/test_config/test_community.py +++ b/tests/test_core/test_config/test_community.py @@ -13,6 +13,7 @@ from pydantic import ValidationError from src.core.config.community import ( + BudgetConfig, CitationConfig, CommunitiesConfig, CommunityConfig, @@ -320,6 +321,116 @@ def test_allows_unique_extensions(self) -> None: assert len(config.mcp_servers) == 2 +class TestBudgetConfig: + """Tests for BudgetConfig model.""" + + def test_valid_budget_config(self) -> None: + """Should create BudgetConfig with valid inputs.""" + config = BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert config.daily_limit_usd == 5.0 + assert config.monthly_limit_usd == 50.0 + assert config.alert_threshold_pct == 80.0 + + def test_default_alert_threshold(self) -> None: + """Should default alert_threshold_pct to 80.0.""" + config = BudgetConfig(daily_limit_usd=5.0, monthly_limit_usd=50.0) + assert config.alert_threshold_pct == 80.0 + + def test_rejects_zero_daily_limit(self) -> None: + """Should reject zero daily limit.""" + with pytest.raises(ValidationError): + BudgetConfig(daily_limit_usd=0.0, monthly_limit_usd=50.0) + + def test_rejects_negative_daily_limit(self) -> None: + """Should reject negative daily limit.""" + with pytest.raises(ValidationError): + BudgetConfig(daily_limit_usd=-1.0, monthly_limit_usd=50.0) + + def test_rejects_zero_monthly_limit(self) -> None: + """Should reject zero monthly limit.""" + with pytest.raises(ValidationError): + BudgetConfig(daily_limit_usd=5.0, monthly_limit_usd=0.0) + + def test_rejects_negative_threshold(self) -> None: + """Should reject negative alert threshold.""" + with pytest.raises(ValidationError): + BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=-1.0, + ) + + def test_rejects_threshold_over_100(self) -> None: + """Should reject alert threshold over 100.""" + with pytest.raises(ValidationError): + BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=101.0, + ) + + def test_rejects_extra_fields(self) -> None: + """Should reject extra fields (strict schema).""" + with pytest.raises(ValidationError): + BudgetConfig( + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + unknown_field="value", # type: ignore + ) + + def test_rejects_daily_exceeding_monthly(self) -> None: + """Should reject daily limit greater than monthly limit.""" + with pytest.raises(ValidationError, match="cannot exceed"): + BudgetConfig(daily_limit_usd=100.0, monthly_limit_usd=50.0) + + def test_accepts_equal_daily_and_monthly(self) -> None: + """Should accept daily limit equal to monthly limit.""" + config = BudgetConfig(daily_limit_usd=50.0, monthly_limit_usd=50.0) + assert config.daily_limit_usd == config.monthly_limit_usd + + +class TestCommunityConfigBudget: + """Tests for CommunityConfig.budget field.""" + + def test_budget_none_by_default(self) -> None: + """Should default budget to None.""" + config = CommunityConfig(id="test", name="Test", description="Test") + assert config.budget is None + + def test_budget_config_set(self) -> None: + """Should accept valid budget config.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + budget=BudgetConfig( + daily_limit_usd=5.0, + 
monthly_limit_usd=50.0, + ), + ) + assert config.budget is not None + assert config.budget.daily_limit_usd == 5.0 + + def test_budget_from_yaml_dict(self) -> None: + """Should parse budget from YAML-like dict.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + budget={ + "daily_limit_usd": 10.0, + "monthly_limit_usd": 100.0, + "alert_threshold_pct": 90.0, + }, + ) + assert config.budget.daily_limit_usd == 10.0 + assert config.budget.alert_threshold_pct == 90.0 + + class TestCommunityConfig: """Tests for CommunityConfig model.""" @@ -584,6 +695,85 @@ def test_accepts_numeric_only_domain(self) -> None: assert len(config.cors_origins) == 1 +class TestCommunityConfigMaintainers: + """Tests for CommunityConfig.maintainers validation.""" + + def test_valid_maintainers(self) -> None: + """Should accept valid GitHub usernames.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["octocat", "jane-doe", "user123"], + ) + assert config.maintainers == ["octocat", "jane-doe", "user123"] + + def test_defaults_to_empty(self) -> None: + """Should default to empty list.""" + config = CommunityConfig(id="test", name="Test", description="Test") + assert config.maintainers == [] + + def test_rejects_invalid_username_with_special_chars(self) -> None: + """Should reject usernames with special characters.""" + with pytest.raises(ValidationError, match="Invalid GitHub username"): + CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["bad@user"], + ) + + def test_rejects_username_starting_with_hyphen(self) -> None: + """Should reject usernames starting with hyphen.""" + with pytest.raises(ValidationError, match="Invalid GitHub username"): + CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["-badstart"], + ) + + def test_rejects_username_ending_with_hyphen(self) -> None: + """Should reject usernames ending with hyphen.""" + with pytest.raises(ValidationError, match="Invalid GitHub username"): + CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["badend-"], + ) + + def test_deduplicates_maintainers(self) -> None: + """Should remove duplicate usernames.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["octocat", "octocat", "jane"], + ) + assert config.maintainers == ["octocat", "jane"] + + def test_strips_whitespace(self) -> None: + """Should strip whitespace from usernames.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=[" octocat ", " jane "], + ) + assert config.maintainers == ["octocat", "jane"] + + def test_single_char_username(self) -> None: + """Should accept single-character usernames.""" + config = CommunityConfig( + id="test", + name="Test", + description="Test", + maintainers=["a"], + ) + assert config.maintainers == ["a"] + + class TestCommunitiesConfig: """Tests for CommunitiesConfig model.""" diff --git a/tests/test_knowledge/test_search.py b/tests/test_knowledge/test_search.py index 8f382ae..2460bbb 100644 --- a/tests/test_knowledge/test_search.py +++ b/tests/test_knowledge/test_search.py @@ -11,6 +11,8 @@ from src.knowledge.db import get_connection, init_db, upsert_github_item, upsert_paper from src.knowledge.search import ( SearchResult, + _extract_number, + _is_pure_number_query, _sanitize_fts5_query, search_all, search_github_items, @@ -184,6 +186,134 @@ def test_search_all_returns_both_categories(self, populated_db: Path): assert 
isinstance(results["papers"], list) +class TestNumberExtraction: + """Tests for extracting PR/issue numbers from queries.""" + + def test_plain_number(self): + assert _extract_number("2022") == 2022 + + def test_hash_prefix(self): + assert _extract_number("#500") == 500 + + def test_pr_prefix(self): + assert _extract_number("PR 2022") == 2022 + assert _extract_number("pr #2022") == 2022 + + def test_issue_prefix(self): + assert _extract_number("issue 500") == 500 + assert _extract_number("issue #500") == 500 + + def test_pull_prefix(self): + assert _extract_number("pull #100") == 100 + + def test_no_number(self): + assert _extract_number("validation error") is None + assert _extract_number("how to use BIDS") is None + + def test_number_with_trailing_text(self): + """A number followed by other words is NOT treated as a number lookup.""" + assert _extract_number("123 validation error") is None + + def test_bug_and_feature_prefixes(self): + """bug and feature prefixes are supported per the regex.""" + assert _extract_number("bug 42") == 42 + assert _extract_number("feature #99") == 99 + + def test_whitespace(self): + assert _extract_number(" 2022 ") == 2022 + + +class TestPureNumberQuery: + """Tests for detecting pure number queries.""" + + def test_plain_number(self): + assert _is_pure_number_query("2022") is True + + def test_hash_number(self): + assert _is_pure_number_query("#500") is True + + def test_pr_prefix(self): + assert _is_pure_number_query("PR 2022") is True + assert _is_pure_number_query("issue #500") is True + + def test_text_query(self): + assert _is_pure_number_query("validation error") is False + assert _is_pure_number_query("2022 roadmap plans") is False + + +class TestNumberLookup: + """Tests for searching GitHub items by number.""" + + def test_search_by_number(self, populated_db: Path): + """Search for an item by its number.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("10") + + assert len(results) >= 1 + assert results[0].url == "https://github.com/hed-standard/hed-schemas/pull/10" + assert results[0].title == "Add new sensory tags" + + def test_search_by_hash_number(self, populated_db: Path): + """Search with # prefix.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("#1") + + assert len(results) >= 1 + assert results[0].url == "https://github.com/hed-standard/hed-specification/issues/1" + + def test_search_by_pr_number(self, populated_db: Path): + """Search with 'PR' prefix.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("PR 10") + + assert len(results) >= 1 + assert results[0].item_type == "pr" + assert results[0].url == "https://github.com/hed-standard/hed-schemas/pull/10" + assert results[0].title == "Add new sensory tags" + + def test_number_lookup_with_type_filter(self, populated_db: Path): + """Number lookup respects item_type filter.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + # Item #10 is a PR, filtering for issues should not return it + results = search_github_items("10", item_type="issue") + assert all(r.item_type == "issue" for r in results) + + def test_number_lookup_with_status_filter(self, populated_db: Path): + """Number lookup respects status filter.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + # Item #1 is open, filtering for closed should not return it + results = search_github_items("#1", 
status="closed") + assert all(r.status == "closed" for r in results) + # But filtering for open should return it + results = search_github_items("#1", status="open") + assert len(results) >= 1 + assert results[0].url == "https://github.com/hed-standard/hed-specification/issues/1" + + def test_nonexistent_number_returns_empty(self, populated_db: Path): + """A number that doesn't exist returns empty for pure-number queries.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("#9999") + assert len(results) == 0 + + def test_limit_respected_with_number_match(self, populated_db: Path): + """Limit parameter is respected even with number matches.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("1", limit=1) + assert len(results) <= 1 + + def test_number_lookup_deduplicates(self, populated_db: Path): + """Number match should not appear twice if also found by FTS.""" + with patch("src.knowledge.db.get_db_path", return_value=populated_db): + results = search_github_items("1") + urls = [r.url for r in results] + assert len(urls) == len(set(urls)), "Duplicate URLs in results" + # Number match should be first in results + if len(results) > 0: + assert ( + results[0].url == "https://github.com/hed-standard/hed-specification/issues/1" + ) + + class TestFTS5Sanitization: """Tests for FTS5 query sanitization to prevent injection.""" diff --git a/tests/test_metrics/__init__.py b/tests/test_metrics/__init__.py new file mode 100644 index 0000000..97451ef --- /dev/null +++ b/tests/test_metrics/__init__.py @@ -0,0 +1 @@ +"""Tests for the metrics module.""" diff --git a/tests/test_metrics/test_alerts.py b/tests/test_metrics/test_alerts.py new file mode 100644 index 0000000..464b560 --- /dev/null +++ b/tests/test_metrics/test_alerts.py @@ -0,0 +1,178 @@ +"""Tests for budget alert issue creation.""" + +from unittest.mock import patch + +from src.metrics.alerts import create_budget_alert_issue +from src.metrics.budget import BudgetStatus + + +def _make_budget_status( + community_id: str = "hed", + daily_spend: float = 4.5, + monthly_spend: float = 45.0, + daily_limit: float = 5.0, + monthly_limit: float = 50.0, + alert_pct: float = 80.0, +) -> BudgetStatus: + return BudgetStatus( + community_id=community_id, + daily_spend_usd=daily_spend, + monthly_spend_usd=monthly_spend, + daily_limit_usd=daily_limit, + monthly_limit_usd=monthly_limit, + alert_threshold_pct=alert_pct, + ) + + +class TestCreateBudgetAlertIssue: + """Tests for create_budget_alert_issue().""" + + def test_no_alert_when_no_threshold_crossed(self): + """Returns None when no alert condition is met.""" + status = _make_budget_status(daily_spend=1.0, monthly_spend=10.0) + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result is None + + @patch("src.metrics.alerts._issue_exists", return_value=True) + @patch("src.metrics.alerts.subprocess") + def test_skips_duplicate_issue(self, mock_subprocess, _mock_exists): + """Returns None when issue already exists.""" + status = _make_budget_status() + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result is None + mock_subprocess.run.assert_not_called() + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_creates_issue_for_daily_alert(self, mock_subprocess, _mock_exists): + """Creates issue when daily alert threshold crossed.""" + mock_subprocess.run.return_value = type( + 
"Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/1\n", "stderr": ""}, + )() + + status = _make_budget_status(daily_spend=4.5, monthly_spend=10.0) + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result == "https://github.com/test/issues/1" + + # Verify subprocess.run was called with gh issue create + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + assert cmd[0] == "gh" + assert cmd[1] == "issue" + assert cmd[2] == "create" + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_creates_issue_for_monthly_exceeded(self, mock_subprocess, _mock_exists): + """Creates issue when monthly limit exceeded.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/2\n", "stderr": ""}, + )() + + status = _make_budget_status(daily_spend=1.0, monthly_spend=50.0) + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result == "https://github.com/test/issues/2" + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_issue_body_contains_maintainer_mentions(self, mock_subprocess, _mock_exists): + """Issue body includes @mentions for maintainers.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/3\n", "stderr": ""}, + )() + + status = _make_budget_status() + create_budget_alert_issue(status, maintainers=["VisLab", "yarikoptic"]) + + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + # Find the --body argument + body_idx = cmd.index("--body") + 1 + body = cmd[body_idx] + assert "@VisLab" in body + assert "@yarikoptic" in body + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_issue_title_format(self, mock_subprocess, _mock_exists): + """Issue title follows expected format.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/4\n", "stderr": ""}, + )() + + status = _make_budget_status(daily_spend=5.0) + create_budget_alert_issue(status, maintainers=[]) + + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + title_idx = cmd.index("--title") + 1 + title = cmd[title_idx] + assert title.startswith("[Budget Alert] hed:") + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_issue_has_labels(self, mock_subprocess, _mock_exists): + """Issue is created with cost-management and operations labels.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/5\n", "stderr": ""}, + )() + + status = _make_budget_status() + create_budget_alert_issue(status, maintainers=[]) + + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + label_idx = cmd.index("--label") + 1 + labels = cmd[label_idx] + assert "cost-management" in labels + assert "operations" in labels + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_returns_none_on_gh_failure(self, mock_subprocess, _mock_exists): + """Returns None when gh CLI fails.""" + mock_subprocess.run.return_value = type( + "Result", (), {"returncode": 1, "stdout": "", "stderr": "auth required"} + )() + + status = _make_budget_status() + result 
= create_budget_alert_issue(status, maintainers=[]) + assert result is None + + @patch("src.metrics.alerts._issue_exists", return_value=False) + @patch("src.metrics.alerts.subprocess") + def test_no_maintainers_message(self, mock_subprocess, _mock_exists): + """Body shows 'No maintainers configured' when list is empty.""" + mock_subprocess.run.return_value = type( + "Result", + (), + {"returncode": 0, "stdout": "https://github.com/test/issues/6\n", "stderr": ""}, + )() + + status = _make_budget_status() + create_budget_alert_issue(status, maintainers=[]) + + call_args = mock_subprocess.run.call_args + cmd = call_args[0][0] + body_idx = cmd.index("--body") + 1 + body = cmd[body_idx] + assert "No maintainers configured" in body + + @patch("src.metrics.alerts._issue_exists", return_value=None) + @patch("src.metrics.alerts.subprocess") + def test_suppresses_alert_when_dedup_check_fails(self, mock_subprocess, _mock_exists): + """Returns None and does not create issue when dedup check fails.""" + status = _make_budget_status() + result = create_budget_alert_issue(status, maintainers=["user1"]) + assert result is None + mock_subprocess.run.assert_not_called() diff --git a/tests/test_metrics/test_budget.py b/tests/test_metrics/test_budget.py new file mode 100644 index 0000000..a5b82d0 --- /dev/null +++ b/tests/test_metrics/test_budget.py @@ -0,0 +1,321 @@ +"""Tests for budget checking.""" + +import pytest + +from src.core.config.community import BudgetConfig +from src.metrics.budget import BudgetStatus, check_budget +from src.metrics.db import ( + RequestLogEntry, + get_metrics_connection, + init_metrics_db, + log_request, + now_iso, +) + + +def _make_config( + daily: float = 5.0, + monthly: float = 50.0, + alert_pct: float = 80.0, +) -> BudgetConfig: + return BudgetConfig( + daily_limit_usd=daily, + monthly_limit_usd=monthly, + alert_threshold_pct=alert_pct, + ) + + +@pytest.fixture +def budget_db(tmp_path): + """Create a metrics DB with cost data for budget testing.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="b1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + estimated_cost=0.50, + ), + RequestLogEntry( + request_id="b2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + estimated_cost=0.30, + ), + RequestLogEntry( + request_id="b3", + timestamp="2025-01-14T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=250.0, + status_code=200, + estimated_cost=2.00, + ), + RequestLogEntry( + request_id="b4", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/eeglab/ask", + method="POST", + community_id="eeglab", + duration_ms=150.0, + status_code=200, + estimated_cost=0.10, + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +class TestBudgetStatus: + """Tests for BudgetStatus dataclass.""" + + def test_daily_pct(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=4.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_pct == 80.0 + + def test_monthly_pct(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=0.0, + monthly_spend_usd=40.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.monthly_pct == 80.0 
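+
+    # The two percentage tests above assume daily_pct/monthly_pct are plain
+    # spend-to-limit ratios. A hypothetical sketch of that convention (not
+    # the actual implementation, which lives in src.metrics.budget):
+    #
+    #     @property
+    #     def daily_pct(self) -> float:
+    #         if self.daily_limit_usd <= 0:
+    #             return 0.0  # zero limit maps to 0% (see the next test)
+    #         return self.daily_spend_usd / self.daily_limit_usd * 100.0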
+ + def test_zero_limit_returns_zero_pct(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=1.0, + monthly_spend_usd=1.0, + daily_limit_usd=0.0, + monthly_limit_usd=0.0, + alert_threshold_pct=80.0, + ) + assert status.daily_pct == 0.0 + assert status.monthly_pct == 0.0 + + def test_daily_exceeded(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=5.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_exceeded is True + + def test_daily_not_exceeded(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=4.99, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_exceeded is False + + def test_monthly_exceeded(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=0.0, + monthly_spend_usd=50.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.monthly_exceeded is True + + def test_daily_alert_at_threshold(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=4.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_alert is True + + def test_daily_alert_below_threshold(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=3.99, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.daily_alert is False + + def test_needs_alert_daily(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=4.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.needs_alert is True + + def test_needs_alert_monthly(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=0.0, + monthly_spend_usd=40.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.needs_alert is True + + def test_no_alert_needed(self): + status = BudgetStatus( + community_id="test", + daily_spend_usd=1.0, + monthly_spend_usd=10.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + assert status.needs_alert is False + + def test_rejects_negative_daily_spend(self): + with pytest.raises(ValueError, match="daily_spend_usd must be non-negative"): + BudgetStatus( + community_id="test", + daily_spend_usd=-1.0, + monthly_spend_usd=0.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + + def test_rejects_negative_monthly_spend(self): + with pytest.raises(ValueError, match="monthly_spend_usd must be non-negative"): + BudgetStatus( + community_id="test", + daily_spend_usd=0.0, + monthly_spend_usd=-1.0, + daily_limit_usd=5.0, + monthly_limit_usd=50.0, + alert_threshold_pct=80.0, + ) + + +class TestCheckBudget: + """Tests for check_budget() function.""" + + def test_returns_budget_status(self, budget_db): + config = _make_config() + conn = get_metrics_connection(budget_db) + try: + status = check_budget(community_id="hed", config=config, conn=conn) + assert isinstance(status, BudgetStatus) + assert status.community_id == "hed" + assert status.daily_limit_usd == 5.0 + assert status.monthly_limit_usd == 50.0 + assert status.alert_threshold_pct == 80.0 + finally: + conn.close() + + def test_empty_community_zero_spend(self, budget_db): + config = _make_config() + conn = 
get_metrics_connection(budget_db) + try: + status = check_budget(community_id="nonexistent", config=config, conn=conn) + assert status.daily_spend_usd == 0.0 + assert status.monthly_spend_usd == 0.0 + finally: + conn.close() + + def test_spend_values_are_non_negative(self, budget_db): + config = _make_config() + conn = get_metrics_connection(budget_db) + try: + status = check_budget(community_id="hed", config=config, conn=conn) + assert status.daily_spend_usd >= 0.0 + assert status.monthly_spend_usd >= 0.0 + finally: + conn.close() + + def test_today_spend_with_current_timestamps(self, tmp_path): + """Verify check_budget sums costs for entries with today's timestamp.""" + db_path = tmp_path / "today.db" + init_metrics_db(db_path) + + # Insert entries with current timestamps + for i, cost in enumerate([0.25, 0.35, 0.40]): + log_request( + RequestLogEntry( + request_id=f"today-{i}", + timestamp=now_iso(), + endpoint="/hed/ask", + method="POST", + community_id="hed", + status_code=200, + estimated_cost=cost, + ), + db_path=db_path, + ) + + config = _make_config() + conn = get_metrics_connection(db_path) + try: + status = check_budget(community_id="hed", config=config, conn=conn) + assert status.daily_spend_usd == pytest.approx(1.0, abs=1e-6) + assert status.monthly_spend_usd == pytest.approx(1.0, abs=1e-6) + finally: + conn.close() + + def test_today_spend_triggers_alert(self, tmp_path): + """Verify budget alert triggers when today's spend crosses threshold.""" + db_path = tmp_path / "alert.db" + init_metrics_db(db_path) + + log_request( + RequestLogEntry( + request_id="expensive", + timestamp=now_iso(), + endpoint="/hed/ask", + method="POST", + community_id="hed", + status_code=200, + estimated_cost=4.5, + ), + db_path=db_path, + ) + + config = _make_config(daily=5.0, monthly=50.0, alert_pct=80.0) + conn = get_metrics_connection(db_path) + try: + status = check_budget(community_id="hed", config=config, conn=conn) + assert status.daily_pct == 90.0 + assert status.daily_alert is True + assert status.needs_alert is True + finally: + conn.close() diff --git a/tests/test_metrics/test_cost.py b/tests/test_metrics/test_cost.py new file mode 100644 index 0000000..6286fd0 --- /dev/null +++ b/tests/test_metrics/test_cost.py @@ -0,0 +1,64 @@ +"""Tests for cost estimation.""" + +from src.metrics.cost import MODEL_PRICING, estimate_cost + + +class TestEstimateCost: + """Tests for estimate_cost().""" + + def test_known_model(self): + """Cost for a known model uses its pricing.""" + cost = estimate_cost("openai/gpt-4o", input_tokens=1000, output_tokens=500) + # input: 1000 * 2.50 / 1M = 0.0025, output: 500 * 10.00 / 1M = 0.005 + assert cost == 0.0075 + + def test_unknown_model_uses_fallback(self): + """Unknown model uses fallback rates.""" + cost = estimate_cost("unknown/model", input_tokens=1_000_000, output_tokens=1_000_000) + # fallback: 1.00 input + 3.00 output = 4.00 + assert cost == 4.0 + + def test_none_model_uses_fallback(self): + """None model uses fallback rates.""" + cost = estimate_cost(None, input_tokens=1_000_000, output_tokens=0) + assert cost == 1.0 + + def test_zero_tokens(self): + """Zero tokens should return zero cost.""" + cost = estimate_cost("openai/gpt-4o", input_tokens=0, output_tokens=0) + assert cost == 0.0 + + def test_rounding(self): + """Cost should be rounded to 6 decimal places.""" + cost = estimate_cost("openai/gpt-4o-mini", input_tokens=1, output_tokens=1) + # input: 1 * 0.15 / 1M = 0.00000015, output: 1 * 0.60 / 1M = 0.0000006 + # total: 0.00000075 -> rounds to 0.000001 + 
assert cost == 0.000001 + + def test_all_models_have_pricing(self): + """All models in the pricing table should have valid (input, output) tuples.""" + for model, (input_rate, output_rate) in MODEL_PRICING.items(): + assert isinstance(input_rate, (int, float)), f"{model} has invalid input rate" + assert isinstance(output_rate, (int, float)), f"{model} has invalid output rate" + assert input_rate >= 0, f"{model} has negative input rate" + assert output_rate >= 0, f"{model} has negative output rate" + + def test_qwen_model_cost(self): + """Verify cost for a Qwen model.""" + cost = estimate_cost( + "qwen/qwen3-235b-a22b-2507", + input_tokens=1_000_000, + output_tokens=1_000_000, + ) + # input: 0.14, output: 0.34, total: 0.48 + assert cost == 0.48 + + def test_expensive_model(self): + """Verify cost for an expensive model (Claude Opus 4).""" + cost = estimate_cost( + "anthropic/claude-opus-4", + input_tokens=1_000_000, + output_tokens=1_000_000, + ) + # input: 15.00, output: 75.00, total: 90.00 + assert cost == 90.0 diff --git a/tests/test_metrics/test_db.py b/tests/test_metrics/test_db.py new file mode 100644 index 0000000..c25d702 --- /dev/null +++ b/tests/test_metrics/test_db.py @@ -0,0 +1,235 @@ +"""Tests for metrics database storage layer.""" + +import json + +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from src.metrics.db import ( + RequestLogEntry, + extract_token_usage, + extract_tool_names, + get_metrics_connection, + init_metrics_db, + log_request, + now_iso, +) + + +@pytest.fixture +def metrics_db(tmp_path): + """Create a temporary metrics database.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + return db_path + + +class TestInitMetricsDb: + """Tests for init_metrics_db().""" + + def test_creates_table(self, tmp_path): + """init_metrics_db creates the request_log table.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + conn = get_metrics_connection(db_path) + try: + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='request_log'" + ).fetchall() + assert len(tables) == 1 + finally: + conn.close() + + def test_idempotent(self, tmp_path): + """Calling init_metrics_db twice does not error.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + init_metrics_db(db_path) + + conn = get_metrics_connection(db_path) + try: + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='request_log'" + ).fetchall() + assert len(tables) == 1 + finally: + conn.close() + + def test_wal_mode(self, tmp_path): + """init_metrics_db sets WAL journal mode.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + conn = get_metrics_connection(db_path) + try: + mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + assert mode == "wal" + finally: + conn.close() + + def test_creates_indexes(self, tmp_path): + """init_metrics_db creates the expected indexes.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + conn = get_metrics_connection(db_path) + try: + indexes = conn.execute( + "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='request_log'" + ).fetchall() + index_names = {r[0] for r in indexes} + assert "idx_request_log_community" in index_names + assert "idx_request_log_timestamp" in index_names + assert "idx_request_log_community_timestamp" in index_names + finally: + conn.close() + + +class TestLogRequest: + """Tests for log_request().""" + + def test_inserts_entry(self, metrics_db): + """log_request 
inserts a row that can be read back.""" + entry = RequestLogEntry( + request_id="test-123", + timestamp=now_iso(), + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=150.5, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=100, + output_tokens=50, + total_tokens=150, + tools_called=["search_docs", "validate_hed"], + key_source="platform", + stream=False, + ) + log_request(entry, db_path=metrics_db) + + conn = get_metrics_connection(metrics_db) + try: + row = conn.execute( + "SELECT * FROM request_log WHERE request_id = ?", ("test-123",) + ).fetchone() + assert row is not None + assert row["community_id"] == "hed" + assert row["endpoint"] == "/hed/ask" + assert row["method"] == "POST" + assert row["duration_ms"] == 150.5 + assert row["status_code"] == 200 + assert row["model"] == "qwen/qwen3-235b" + assert row["input_tokens"] == 100 + assert row["output_tokens"] == 50 + assert row["total_tokens"] == 150 + assert row["key_source"] == "platform" + assert row["stream"] == 0 + tools = json.loads(row["tools_called"]) + assert tools == ["search_docs", "validate_hed"] + finally: + conn.close() + + def test_null_agent_fields(self, metrics_db): + """Non-agent requests have NULL agent fields.""" + entry = RequestLogEntry( + request_id="basic-req", + timestamp=now_iso(), + endpoint="/health", + method="GET", + status_code=200, + duration_ms=5.0, + ) + log_request(entry, db_path=metrics_db) + + conn = get_metrics_connection(metrics_db) + try: + row = conn.execute( + "SELECT * FROM request_log WHERE request_id = ?", ("basic-req",) + ).fetchone() + assert row is not None + assert row["model"] is None + assert row["input_tokens"] is None + assert row["output_tokens"] is None + assert row["total_tokens"] is None + assert row["tools_called"] is None + assert row["key_source"] is None + assert row["community_id"] is None + finally: + conn.close() + + def test_stream_flag(self, metrics_db): + """stream=True is stored as 1.""" + entry = RequestLogEntry( + request_id="stream-req", + timestamp=now_iso(), + endpoint="/hed/ask", + method="POST", + stream=True, + ) + log_request(entry, db_path=metrics_db) + + conn = get_metrics_connection(metrics_db) + try: + row = conn.execute( + "SELECT stream FROM request_log WHERE request_id = ?", ("stream-req",) + ).fetchone() + assert row["stream"] == 1 + finally: + conn.close() + + +class TestExtractTokenUsage: + """Tests for extract_token_usage().""" + + def test_extracts_from_ai_messages(self): + """Extracts token usage from AIMessages with usage_metadata.""" + msg = AIMessage(content="hello") + msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + result = {"messages": [HumanMessage(content="hi"), msg]} + + inp, out, total = extract_token_usage(result) + assert inp == 10 + assert out == 5 + assert total == 15 + + def test_sums_multiple_messages(self): + """Sums usage across multiple AIMessages.""" + msg1 = AIMessage(content="first") + msg1.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + msg2 = AIMessage(content="second") + msg2.usage_metadata = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} + result = {"messages": [msg1, msg2]} + + inp, out, total = extract_token_usage(result) + assert inp == 30 + assert out == 15 + assert total == 45 + + def test_returns_zeros_when_no_usage(self): + """Returns (0, 0, 0) when no usage_metadata.""" + result = {"messages": [AIMessage(content="hello")]} + inp, out, total = extract_token_usage(result) + assert (inp, out, total) == 
(0, 0, 0) + + def test_returns_zeros_for_empty_result(self): + """Returns (0, 0, 0) for empty result.""" + assert extract_token_usage({}) == (0, 0, 0) + assert extract_token_usage({"messages": []}) == (0, 0, 0) + + +class TestExtractToolNames: + """Tests for extract_tool_names().""" + + def test_extracts_names(self): + result = {"tool_calls": [{"name": "search"}, {"name": "validate"}]} + assert extract_tool_names(result) == ["search", "validate"] + + def test_empty_tool_calls(self): + assert extract_tool_names({"tool_calls": []}) == [] + assert extract_tool_names({}) == [] + + def test_skips_empty_names(self): + result = {"tool_calls": [{"name": ""}, {"name": "search"}]} + assert extract_tool_names(result) == ["search"] diff --git a/tests/test_metrics/test_middleware.py b/tests/test_metrics/test_middleware.py new file mode 100644 index 0000000..7f3e26c --- /dev/null +++ b/tests/test_metrics/test_middleware.py @@ -0,0 +1,114 @@ +"""Tests for metrics middleware.""" + +from unittest.mock import patch + +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from src.metrics.middleware import MetricsMiddleware, _extract_community_id + + +class TestExtractCommunityId: + """Tests for _extract_community_id helper.""" + + def test_extracts_from_ask(self): + assert _extract_community_id("/hed/ask") == "hed" + + def test_extracts_from_chat(self): + assert _extract_community_id("/bids/chat") == "bids" + + def test_returns_none_for_non_community(self): + assert _extract_community_id("/health") is None + assert _extract_community_id("/sync/status") is None + assert _extract_community_id("/metrics/overview") is None + + def test_returns_none_for_root(self): + assert _extract_community_id("/") is None + + def test_extracts_from_community_metrics_paths(self): + """Community metrics/session/config paths return the community ID.""" + assert _extract_community_id("/hed/metrics") == "hed" + assert _extract_community_id("/hed/metrics/public") == "hed" + assert _extract_community_id("/hed/sessions") == "hed" + assert _extract_community_id("/hed/config") == "hed" + + def test_returns_none_for_deep_community_paths(self): + """Paths under community that don't match known suffixes return None.""" + assert _extract_community_id("/hed/unknown") is None + assert _extract_community_id("/hed/something/else") is None + + +class TestMetricsMiddleware: + """Tests for MetricsMiddleware integration.""" + + @pytest.fixture + def test_app(self): + """Create a minimal FastAPI app with MetricsMiddleware.""" + app = FastAPI() + app.add_middleware(MetricsMiddleware) + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.post("/hed/ask") + async def ask(request: Request): + # Simulate agent metrics + request.state.metrics_agent_data = { + "model": "test-model", + "key_source": "platform", + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "tools_called": ["search"], + "stream": False, + } + return {"answer": "test"} + + @app.post("/hed/streaming") + async def streaming(request: Request): + # Simulate handler that logs its own metrics + request.state.metrics_logged = True + return {"answer": "streamed"} + + with patch("src.metrics.middleware.log_request") as mock_log: + yield app, mock_log + + def test_logs_basic_request(self, test_app): + """Middleware logs basic request data for non-agent endpoints.""" + app, mock_log = test_app + client = TestClient(app) + response = client.get("/health") + assert response.status_code == 200 + + assert 
mock_log.called + entry = mock_log.call_args[0][0] + assert entry.endpoint == "/health" + assert entry.method == "GET" + assert entry.status_code == 200 + assert entry.duration_ms > 0 + assert entry.model is None # non-agent + + def test_picks_up_agent_metrics(self, test_app): + """Middleware reads agent metrics from request.state.""" + app, mock_log = test_app + client = TestClient(app) + response = client.post("/hed/ask", json={}) + assert response.status_code == 200 + + assert mock_log.called + entry = mock_log.call_args[0][0] + assert entry.model == "test-model" + assert entry.key_source == "platform" + assert entry.input_tokens == 10 + assert entry.tools_called == ["search"] + assert entry.community_id == "hed" + + def test_skips_when_handler_logged(self, test_app): + """Middleware skips logging when handler set metrics_logged=True.""" + app, mock_log = test_app + client = TestClient(app) + response = client.post("/hed/streaming", json={}) + assert response.status_code == 200 + assert not mock_log.called diff --git a/tests/test_metrics/test_quality_queries.py b/tests/test_metrics/test_quality_queries.py new file mode 100644 index 0000000..172cd99 --- /dev/null +++ b/tests/test_metrics/test_quality_queries.py @@ -0,0 +1,345 @@ +"""Tests for quality metrics queries.""" + +import pytest + +from src.metrics.db import RequestLogEntry, get_metrics_connection, init_metrics_db, log_request + + +@pytest.fixture +def quality_db(tmp_path): + """Create a metrics DB with quality-related data.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="q1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + model="qwen/qwen3-235b", + tool_call_count=2, + langfuse_trace_id="trace-abc", + ), + RequestLogEntry( + request_id="q2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=500.0, + status_code=200, + model="qwen/qwen3-235b", + tool_call_count=1, + langfuse_trace_id="trace-def", + ), + RequestLogEntry( + request_id="q3", + timestamp="2025-01-15T12:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=1500.0, + status_code=500, + model="qwen/qwen3-235b", + tool_call_count=0, + error_message="LLM timeout", + ), + RequestLogEntry( + request_id="q4", + timestamp="2025-01-16T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + model="qwen/qwen3-235b", + tool_call_count=3, + langfuse_trace_id="trace-ghi", + ), + RequestLogEntry( + request_id="q5", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/eeglab/ask", + method="POST", + community_id="eeglab", + duration_ms=250.0, + status_code=200, + tool_call_count=1, + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +class TestGetQualityMetrics: + """Tests for get_quality_metrics().""" + + def test_daily_buckets(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + assert result["community_id"] == "hed" + assert result["period"] == "daily" + buckets = {b["bucket"]: b for b in result["buckets"]} + assert "2025-01-15" in buckets + assert "2025-01-16" in buckets + finally: + conn.close() + + def test_error_rate_per_bucket(self, quality_db): + from src.metrics.queries import 
get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + # Jan 15: 1 error out of 3 requests + assert abs(buckets["2025-01-15"]["error_rate"] - 0.3333) < 0.01 + # Jan 16: 0 errors + assert buckets["2025-01-16"]["error_rate"] == 0.0 + finally: + conn.close() + + def test_avg_tool_calls(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + # Jan 15: (2+1+0)/3 = 1.0 + assert buckets["2025-01-15"]["avg_tool_calls"] == 1.0 + # Jan 16: 3/1 = 3.0 + assert buckets["2025-01-16"]["avg_tool_calls"] == 3.0 + finally: + conn.close() + + def test_agent_errors(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + assert buckets["2025-01-15"]["agent_errors"] == 1 # "LLM timeout" + assert buckets["2025-01-16"]["agent_errors"] == 0 + finally: + conn.close() + + def test_traced_requests(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + assert buckets["2025-01-15"]["traced_requests"] == 2 + assert buckets["2025-01-16"]["traced_requests"] == 1 + finally: + conn.close() + + def test_latency_percentiles(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("hed", conn, "daily") + buckets = {b["bucket"]: b for b in result["buckets"]} + # Jan 15: durations [200, 500, 1500] sorted + # p50 = values[1] = 500, p95 = values[2] = 1500 + assert buckets["2025-01-15"]["p50_duration_ms"] == 500.0 + assert buckets["2025-01-15"]["p95_duration_ms"] == 1500.0 + finally: + conn.close() + + def test_empty_community(self, quality_db): + from src.metrics.queries import get_quality_metrics + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_metrics("nonexistent", conn, "daily") + assert result["buckets"] == [] + finally: + conn.close() + + +class TestGetQualitySummary: + """Tests for get_quality_summary().""" + + def test_summary_totals(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("hed", conn) + assert result["community_id"] == "hed" + assert result["total_requests"] == 4 + assert abs(result["error_rate"] - 0.25) < 0.01 + finally: + conn.close() + + def test_summary_avg_tool_calls(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("hed", conn) + # (2+1+0+3)/4 = 1.5 + assert result["avg_tool_calls"] == 1.5 + finally: + conn.close() + + def test_summary_traced_pct(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("hed", conn) + # 3 traced out of 4 + assert abs(result["traced_pct"] - 0.75) < 0.01 + finally: + conn.close() + + def test_summary_latency(self, quality_db): + from src.metrics.queries import get_quality_summary + + 
conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("hed", conn) + # durations sorted: [200, 300, 500, 1500] + # nearest-rank p50: idx=int(0.5*4)=2 -> 500 + # nearest-rank p95: idx=int(0.95*4)=3 -> 1500 + assert result["p50_duration_ms"] == 500.0 + assert result["p95_duration_ms"] == 1500.0 + finally: + conn.close() + + def test_empty_community(self, quality_db): + from src.metrics.queries import get_quality_summary + + conn = get_metrics_connection(quality_db) + try: + result = get_quality_summary("nonexistent", conn) + assert result["total_requests"] == 0 + assert result["error_rate"] == 0.0 + assert result["p50_duration_ms"] is None + finally: + conn.close() + + +class TestPercentileEdgeCases: + """Tests for _percentile() helper edge cases.""" + + def test_single_element(self): + from src.metrics.queries import _percentile + + assert _percentile([100.0], 0.5) == 100.0 + assert _percentile([100.0], 0.95) == 100.0 + + def test_two_elements(self): + from src.metrics.queries import _percentile + + values = [100.0, 200.0] + assert _percentile(values, 0.5) == 200.0 # idx=int(0.5*2)=1 + assert _percentile(values, 0.95) == 200.0 # idx=int(0.95*2)=1 + + def test_empty_returns_none(self): + from src.metrics.queries import _percentile + + assert _percentile([], 0.5) is None + + +class TestCountToolsMalformedJSON: + """Tests for _count_tools() with malformed data.""" + + def test_malformed_json_skipped(self, tmp_path): + """Rows with invalid tools_called JSON should be skipped gracefully.""" + + from src.metrics.queries import get_quality_metrics + + db_path = tmp_path / "malformed.db" + init_metrics_db(db_path) + + # Insert a valid entry first + log_request( + RequestLogEntry( + request_id="valid", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + tools_called=["search_docs"], + ), + db_path=db_path, + ) + + # Insert a row with invalid JSON directly via SQL + conn = get_metrics_connection(db_path) + try: + conn.execute( + """INSERT INTO request_log + (request_id, timestamp, endpoint, method, community_id, + duration_ms, status_code, tools_called) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + "bad", + "2025-01-15T11:00:00+00:00", + "/hed/ask", + "POST", + "hed", + 300.0, + 200, + "not-valid-json{{{", + ), + ) + conn.commit() + finally: + conn.close() + + # Query should still work, returning valid tool counts + conn = get_metrics_connection(db_path) + try: + result = get_quality_metrics("hed", conn, "daily") + assert len(result["buckets"]) == 1 + # Should have at least the valid entry's tool + bucket = result["buckets"][0] + assert bucket["requests"] == 2 + finally: + conn.close() + + +class TestMigrationColumns: + """Tests for backward-compatible column migration.""" + + def test_new_columns_exist_after_init(self, quality_db): + """New quality columns should exist in freshly initialized DB.""" + conn = get_metrics_connection(quality_db) + try: + # Query using the new columns + row = conn.execute( + "SELECT tool_call_count, error_message, langfuse_trace_id FROM request_log LIMIT 1" + ).fetchone() + assert row is not None + finally: + conn.close() + + def test_migration_idempotent(self, quality_db): + """Running init_metrics_db twice should not fail.""" + # Second init should be a no-op for migrations + init_metrics_db(quality_db) + conn = get_metrics_connection(quality_db) + try: + count = conn.execute("SELECT COUNT(*) FROM request_log").fetchone()[0] + assert count == 5 # 
Original data preserved + finally: + conn.close() diff --git a/tests/test_metrics/test_queries.py b/tests/test_metrics/test_queries.py new file mode 100644 index 0000000..ae71c40 --- /dev/null +++ b/tests/test_metrics/test_queries.py @@ -0,0 +1,303 @@ +"""Tests for metrics aggregation queries.""" + +import pytest + +from src.metrics.db import RequestLogEntry, get_metrics_connection, init_metrics_db, log_request + + +@pytest.fixture +def populated_db(tmp_path): + """Create a metrics DB with sample data for query testing.""" + db_path = tmp_path / "metrics.db" + init_metrics_db(db_path) + + entries = [ + RequestLogEntry( + request_id="r1", + timestamp="2025-01-15T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=200.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=100, + output_tokens=50, + total_tokens=150, + estimated_cost=0.001, + tools_called=["search_docs"], + key_source="platform", + ), + RequestLogEntry( + request_id="r2", + timestamp="2025-01-15T11:00:00+00:00", + endpoint="/hed/chat", + method="POST", + community_id="hed", + duration_ms=300.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=200, + output_tokens=100, + total_tokens=300, + estimated_cost=0.002, + tools_called=["search_docs", "validate_hed"], + key_source="byok", + ), + RequestLogEntry( + request_id="r3", + timestamp="2025-01-16T10:00:00+00:00", + endpoint="/hed/ask", + method="POST", + community_id="hed", + duration_ms=150.0, + status_code=500, + model="openai/gpt-4o", + input_tokens=50, + output_tokens=0, + total_tokens=50, + key_source="byok", + ), + RequestLogEntry( + request_id="r4", + timestamp="2025-01-15T12:00:00+00:00", + endpoint="/bids/ask", + method="POST", + community_id="bids", + duration_ms=250.0, + status_code=200, + model="qwen/qwen3-235b", + input_tokens=80, + output_tokens=40, + total_tokens=120, + estimated_cost=0.0008, + tools_called=["search_docs"], + key_source="platform", + ), + RequestLogEntry( + request_id="r5", + timestamp="2025-01-15T09:00:00+00:00", + endpoint="/health", + method="GET", + duration_ms=5.0, + status_code=200, + ), + ] + for e in entries: + log_request(e, db_path=db_path) + + return db_path + + +class TestGetCommunitySummary: + """Tests for get_community_summary().""" + + def test_returns_correct_totals(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("hed", conn) + assert result["community_id"] == "hed" + assert result["total_requests"] == 3 + assert result["total_input_tokens"] == 350 + assert result["total_output_tokens"] == 150 + assert result["total_tokens"] == 500 + finally: + conn.close() + + def test_error_rate(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("hed", conn) + # 1 error out of 3 requests + assert abs(result["error_rate"] - 0.3333) < 0.01 + finally: + conn.close() + + def test_top_models(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("hed", conn) + models = {m["model"]: m["count"] for m in result["top_models"]} + assert models["qwen/qwen3-235b"] == 2 + assert models["openai/gpt-4o"] == 1 + finally: + conn.close() + + def test_top_tools(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = 
get_metrics_connection(populated_db) + try: + result = get_community_summary("hed", conn) + tools = {t["tool"]: t["count"] for t in result["top_tools"]} + assert tools["search_docs"] == 2 + assert tools["validate_hed"] == 1 + finally: + conn.close() + + def test_empty_community(self, populated_db): + from src.metrics.queries import get_community_summary + + conn = get_metrics_connection(populated_db) + try: + result = get_community_summary("nonexistent", conn) + assert result["total_requests"] == 0 + assert result["total_tokens"] == 0 + assert result["error_rate"] == 0.0 + finally: + conn.close() + + +class TestGetUsageStats: + """Tests for get_usage_stats().""" + + def test_daily_bucketing(self, populated_db): + from src.metrics.queries import get_usage_stats + + conn = get_metrics_connection(populated_db) + try: + result = get_usage_stats("hed", "daily", conn) + assert result["period"] == "daily" + assert result["community_id"] == "hed" + buckets = {b["bucket"]: b for b in result["buckets"]} + assert "2025-01-15" in buckets + assert "2025-01-16" in buckets + assert buckets["2025-01-15"]["requests"] == 2 + assert buckets["2025-01-16"]["requests"] == 1 + finally: + conn.close() + + def test_monthly_bucketing(self, populated_db): + from src.metrics.queries import get_usage_stats + + conn = get_metrics_connection(populated_db) + try: + result = get_usage_stats("hed", "monthly", conn) + buckets = result["buckets"] + assert len(buckets) == 1 + assert buckets[0]["bucket"] == "2025-01" + assert buckets[0]["requests"] == 3 + finally: + conn.close() + + def test_invalid_period_raises(self, populated_db): + from src.metrics.queries import get_usage_stats + + conn = get_metrics_connection(populated_db) + try: + with pytest.raises(ValueError, match="Invalid period"): + get_usage_stats("hed", "hourly", conn) + finally: + conn.close() + + def test_empty_community_returns_empty_buckets(self, populated_db): + from src.metrics.queries import get_usage_stats + + conn = get_metrics_connection(populated_db) + try: + result = get_usage_stats("nonexistent", "daily", conn) + assert result["buckets"] == [] + finally: + conn.close() + + +class TestGetOverview: + """Tests for get_overview().""" + + def test_overview_totals(self, populated_db): + from src.metrics.queries import get_overview + + conn = get_metrics_connection(populated_db) + try: + result = get_overview(conn) + # 5 total entries + assert result["total_requests"] == 5 + assert result["total_tokens"] == 620 # 150+300+50+120+0 + finally: + conn.close() + + def test_overview_communities(self, populated_db): + from src.metrics.queries import get_overview + + conn = get_metrics_connection(populated_db) + try: + result = get_overview(conn) + communities = {c["community_id"]: c for c in result["communities"]} + assert "hed" in communities + assert "bids" in communities + assert communities["hed"]["requests"] == 3 + assert communities["bids"]["requests"] == 1 + finally: + conn.close() + + def test_error_rate_in_overview(self, populated_db): + from src.metrics.queries import get_overview + + conn = get_metrics_connection(populated_db) + try: + result = get_overview(conn) + # 1 error out of 5 requests + assert abs(result["error_rate"] - 0.2) < 0.01 + finally: + conn.close() + + +class TestGetTokenBreakdown: + """Tests for get_token_breakdown().""" + + def test_by_model(self, populated_db): + from src.metrics.queries import get_token_breakdown + + conn = get_metrics_connection(populated_db) + try: + result = get_token_breakdown(conn) + models = {m["model"]: m 
for m in result["by_model"]} + assert "qwen/qwen3-235b" in models + assert models["qwen/qwen3-235b"]["requests"] == 3 + finally: + conn.close() + + def test_by_key_source(self, populated_db): + from src.metrics.queries import get_token_breakdown + + conn = get_metrics_connection(populated_db) + try: + result = get_token_breakdown(conn) + sources = {s["key_source"]: s for s in result["by_key_source"]} + assert "platform" in sources + assert "byok" in sources + assert sources["platform"]["requests"] == 2 + assert sources["byok"]["requests"] == 2 + finally: + conn.close() + + def test_filter_by_community(self, populated_db): + from src.metrics.queries import get_token_breakdown + + conn = get_metrics_connection(populated_db) + try: + result = get_token_breakdown(conn, community_id="bids") + assert result["community_id"] == "bids" + assert len(result["by_model"]) == 1 + assert result["by_model"][0]["model"] == "qwen/qwen3-235b" + finally: + conn.close() + + def test_empty_db_returns_empty(self, tmp_path): + from src.metrics.queries import get_token_breakdown + + db_path = tmp_path / "empty.db" + init_metrics_db(db_path) + conn = get_metrics_connection(db_path) + try: + result = get_token_breakdown(conn) + assert result["by_model"] == [] + assert result["by_key_source"] == [] + finally: + conn.close() diff --git a/workers/osa-worker/README.md b/workers/osa-worker/README.md index a847f83..9b6b7f7 100644 --- a/workers/osa-worker/README.md +++ b/workers/osa-worker/README.md @@ -2,7 +2,9 @@ Security proxy for the Open Science Assistant backend. Provides: - **Turnstile verification** (visible widget) for bot protection -- **Rate limiting** (IP-based, per-minute and per-hour) +- **Hybrid rate limiting** (IP-based, per-minute and per-hour) + - Per-minute: Built-in API (fast bot protection, <1ms) + - Per-hour: KV (global human abuse prevention) - **CORS validation** for allowed origins - **API key injection** for backend authentication - **BYOK mode** for CLI/programmatic access @@ -39,23 +41,11 @@ npm install -g wrangler wrangler login ``` -### 2. Create KV namespaces for rate limiting +### 2. KV namespaces -```bash -# Production -wrangler kv:namespace create "RATE_LIMITER" -# Copy the ID and update wrangler.toml - -# Development -wrangler kv:namespace create "RATE_LIMITER" --env dev -# Copy the ID and update wrangler.toml [env.dev.kv_namespaces] -``` +KV namespaces are already configured in `wrangler.toml` for per-hour rate limiting. No additional setup needed. -### 3. Update wrangler.toml - -Replace `REPLACE_WITH_KV_ID` and `REPLACE_WITH_DEV_KV_ID` with the IDs from step 2. - -### 4. Set up Turnstile +### 3. Set up Turnstile 1. Go to Cloudflare Dashboard > Turnstile 2. Create a new widget with **Visible** mode @@ -67,7 +57,7 @@ Replace `REPLACE_WITH_KV_ID` and `REPLACE_WITH_DEV_KV_ID` with the IDs from step 4. Copy the Site Key (for frontend integration) 5. Copy the Secret Key (for this worker) -### 5. Set secrets +### 4. Set secrets ```bash # Backend API key (generate with: python -c "import secrets; print(secrets.token_urlsafe(32))") @@ -81,7 +71,7 @@ wrangler secret put BACKEND_API_KEY --env dev wrangler secret put TURNSTILE_SECRET_KEY --env dev ``` -### 6. Deploy +### 5. 
Deploy ```bash # Production @@ -104,10 +94,22 @@ wrangler deploy --env dev ## Rate Limits -| Environment | Per Minute | Per Hour | -|-------------|------------|----------| -| Production | 10 | 100 | -| Development | 60 | 600 | +Hybrid approach for optimal performance and protection: + +| Environment | Per Minute (Bot Protection) | Per Hour (Human Abuse) | +|-------------|----------------------------|----------------------| +| Production | 10 (built-in API) | 20 (KV, global) | +| Development | 60 (built-in API) | 100 (KV, global) | + +**Why hybrid?** +- **Per-minute**: Needs to be fast (<1ms), catches bots immediately → Built-in API +- **Per-hour**: Needs global consistency across edge locations → KV +- **Result**: 50% fewer KV writes (1 vs 2 per request), faster bot protection + +**Rate limit scope:** +- Limits are **per IP address**, not per session +- 20/hour in production = ~1 question every 3 minutes (reasonable for research) +- Prevents abuse while allowing legitimate use ## BYOK Mode diff --git a/workers/osa-worker/index.js b/workers/osa-worker/index.js index e067007..dae2e6a 100644 --- a/workers/osa-worker/index.js +++ b/workers/osa-worker/index.js @@ -9,12 +9,15 @@ * - BYOK mode for CLI/programmatic access */ +// Path segments that are actual routes, never valid community IDs +const RESERVED_PATHS = ['health', 'version', 'feedback', 'communities', 'metrics', 'sync']; + // Worker configuration function getConfig(env) { const isDev = env.ENVIRONMENT === 'development'; return { RATE_LIMIT_PER_MINUTE: isDev ? 60 : 10, - RATE_LIMIT_PER_HOUR: isDev ? 600 : 100, + RATE_LIMIT_PER_HOUR: isDev ? 100 : 20, REQUEST_TIMEOUT: 120000, // 2 minutes for LLM responses IS_DEV: isDev, }; @@ -66,37 +69,85 @@ async function verifyTurnstileToken(token, secretKey, ip) { } /** - * Check rate limit using KV storage + * Hybrid rate limiting approach: + * - Per-minute (bot protection): Built-in API (fast, <1ms, in-memory) + * - Per-hour (human abuse): KV (global consistency, 1 write per request) + * + * Benefits: + * - 50% reduction in KV writes (1 vs 2 per request) + * - Faster bot protection (<1ms vs ~10-50ms for critical first check) + * - Global hourly limits across all edge locations + * + * Known limitation: + * - KV read-then-write is not atomic; concurrent requests from same IP + * may slightly exceed hourly limit. Per-minute guard constrains this. 
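+ *
+ * Flow sketch (summary of the logic below, in order):
+ *   1. KV read  -> reject if the hourly count is already at the limit
+ *                  (read-only, so no per-minute token is consumed)
+ *   2. limit()  -> consume a per-minute token; reject if the bucket is empty
+ *   3. KV write -> increment the hourly counter (the single write per request)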
*/ async function checkRateLimit(request, env, CONFIG) { - if (!env.RATE_LIMITER) return { allowed: true }; - const ip = request.headers.get('CF-Connecting-IP') || 'unknown'; - const now = Math.floor(Date.now() / 1000); - const minuteKey = `rl:min:${ip}:${Math.floor(now / 60)}`; - const hourKey = `rl:hour:${ip}:${Math.floor(now / 3600)}`; - - // Check per-minute limit - const minuteCount = parseInt(await env.RATE_LIMITER.get(minuteKey) || '0'); - if (minuteCount >= CONFIG.RATE_LIMIT_PER_MINUTE) { - return { allowed: false, reason: 'Too many requests per minute' }; + + // Check hourly limit first (KV, read-only, no token consumed) + // This prevents wasting per-minute tokens on already-rejected requests + if (env.RATE_LIMITER_KV) { + try { + const now = Math.floor(Date.now() / 1000); + const hourKey = `rl:hour:${ip}:${Math.floor(now / 3600)}`; + + // Check current count + const hourCount = parseInt(await env.RATE_LIMITER_KV.get(hourKey) || '0', 10); + if (hourCount >= CONFIG.RATE_LIMIT_PER_HOUR) { + return { allowed: false, reason: 'Too many requests per hour' }; + } + } catch (error) { + console.error('Per-hour rate limit check error:', error); + // Fail open for KV errors + } } - // Check per-hour limit - const hourCount = parseInt(await env.RATE_LIMITER.get(hourKey) || '0'); - if (hourCount >= CONFIG.RATE_LIMIT_PER_HOUR) { - return { allowed: false, reason: 'Too many requests per hour' }; + // Check per-minute limit (built-in API, fast, consumes token) + // Only check this AFTER hourly passes to avoid wasting tokens + if (env.RATE_LIMITER_MINUTE) { + try { + const { success } = await env.RATE_LIMITER_MINUTE.limit({ key: ip }); + if (!success) { + return { allowed: false, reason: 'Too many requests per minute' }; + } + } catch (error) { + console.error('Per-minute rate limit check error:', error); + // Fail open for built-in API errors + } } - // Increment counters - await Promise.all([ - env.RATE_LIMITER.put(minuteKey, (minuteCount + 1).toString(), { expirationTtl: 120 }), - env.RATE_LIMITER.put(hourKey, (hourCount + 1).toString(), { expirationTtl: 7200 }), - ]); + // Increment hourly counter (1 write per request instead of 2) + // Done last, after both checks pass + if (env.RATE_LIMITER_KV) { + try { + const now = Math.floor(Date.now() / 1000); + const hourKey = `rl:hour:${ip}:${Math.floor(now / 3600)}`; + const hourCount = parseInt(await env.RATE_LIMITER_KV.get(hourKey) || '0', 10); + await env.RATE_LIMITER_KV.put(hourKey, (hourCount + 1).toString(), { expirationTtl: 7200 }); + } catch (error) { + console.error('Per-hour rate limit increment error:', error); + // Already allowed, so don't fail the request + } + } return { allowed: true }; } +/** + * Check rate limit and return a 429 response if exceeded, or null if allowed. 
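+ *
+ * Typical call site (this exact pattern appears in the route handlers below):
+ *   const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG);
+ *   if (rejected) return rejected;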
+ */ +async function rateLimitOrReject(request, env, corsHeaders, CONFIG) { + const rl = await checkRateLimit(request, env, CONFIG); + if (!rl.allowed) { + return new Response( + JSON.stringify({ error: 'Rate limit exceeded', details: rl.reason }), + { status: 429, headers: { ...corsHeaders, 'Content-Type': 'application/json' } } + ); + } + return null; +} + /** * Check if origin is allowed */ @@ -106,6 +157,8 @@ function isAllowedOrigin(origin) { // Allowed origins for OSA const allowedPatterns = [ 'https://osc.earth', + 'https://bids-specification.readthedocs.io', + 'https://bids.neuroimaging.io', 'https://eeglab.org', 'https://hedtags.org', 'https://sccn.github.io', @@ -120,6 +173,8 @@ function isAllowedOrigin(origin) { if (origin.endsWith('.eeglab.org')) return true; if (origin.endsWith('.github.io')) return true; if (origin.endsWith('.hedtags.org')) return true; + if (origin.endsWith('.neuroimaging.io')) return true; + if (origin.endsWith('.readthedocs.io')) return true; // Allow osa-demo.pages.dev and all subdomains (previews, branches) if (origin === 'https://osa-demo.pages.dev') return true; @@ -155,9 +210,45 @@ function getCorsHeaders(origin) { } /** - * Proxy request to backend + * Validate community ID and return error response if invalid, or null if valid. + */ +function validateCommunityId(communityId, corsHeaders) { + if (RESERVED_PATHS.includes(communityId)) { + return new Response('Not Found', { status: 404, headers: corsHeaders }); + } + if (!isValidCommunityId(communityId)) { + return new Response(JSON.stringify({ error: 'Invalid community ID format' }), { + status: 400, + headers: { ...corsHeaders, 'Content-Type': 'application/json' }, + }); + } + return null; +} + +/** + * Proxy request to backend with the worker's own API key. + * Used for public endpoints and widget traffic where the worker + * authenticates on behalf of the client. */ async function proxyToBackend(request, env, path, body, corsHeaders, CONFIG) { + return _proxyToBackend(request, env, path, body, corsHeaders, CONFIG, 'worker'); +} + +/** + * Proxy request to backend, forwarding the client's own headers. + * Used for admin/authenticated endpoints where the client must provide + * their own API key. The worker does NOT inject its key. + */ +async function proxyToBackendPassthrough(request, env, path, body, corsHeaders, CONFIG) { + return _proxyToBackend(request, env, path, body, corsHeaders, CONFIG, 'client'); +} + +/** + * Internal proxy implementation. 
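+ * Route handlers should not call this directly; proxyToBackend() and
+ * proxyToBackendPassthrough() above select the auth mode.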
+ * + * @param {string} authMode - 'worker' to inject BACKEND_API_KEY, 'client' to forward client's X-API-Key + */ +async function _proxyToBackend(request, env, path, body, corsHeaders, CONFIG, authMode) { const backendUrl = env.BACKEND_URL; if (!backendUrl) { @@ -172,15 +263,23 @@ async function proxyToBackend(request, env, path, body, corsHeaders, CONFIG) { 'Content-Type': 'application/json', }; - // Add backend API key - if (env.BACKEND_API_KEY) { - backendHeaders['X-API-Key'] = env.BACKEND_API_KEY; + // Auth mode: inject worker key or forward client key + if (authMode === 'worker') { + if (env.BACKEND_API_KEY) { + backendHeaders['X-API-Key'] = env.BACKEND_API_KEY; + } + } else { + // Forward client's API key for backend to validate + const clientKey = request.headers.get('X-API-Key'); + if (clientKey) { + backendHeaders['X-API-Key'] = clientKey; + } } // Forward Origin header to backend for origin-based authorization checks - // Backend uses this to validate request source and apply origin-specific policies + // Only forward if the origin passed CORS validation const origin = request.headers.get('Origin'); - if (origin) { + if (origin && isAllowedOrigin(origin)) { backendHeaders['Origin'] = origin; } @@ -214,8 +313,8 @@ async function proxyToBackend(request, env, path, body, corsHeaders, CONFIG) { const text = await response.text(); backendError = { error: text.substring(0, 500) }; } - } catch { - // Use default error if parsing fails + } catch (parseErr) { + console.warn('Failed to parse backend error response:', parseErr.message); } return new Response(JSON.stringify(backendError), { @@ -301,58 +400,79 @@ export default { return await handleFeedback(request, env, corsHeaders, CONFIG); } + // --- Dashboard read-only endpoints (GET only, rate-limited) --- + + // Global public metrics: /metrics/public/overview + if (url.pathname === '/metrics/public/overview' && request.method === 'GET') { + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + return await proxyToBackend(request, env, '/metrics/public/overview', null, corsHeaders, CONFIG); + } + + // Admin metrics endpoints: client must provide their own API key + if (url.pathname.match(/^\/metrics\/(overview|tokens|quality)$/) && request.method === 'GET') { + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + const path = url.pathname + url.search; + return await proxyToBackendPassthrough(request, env, path, null, corsHeaders, CONFIG); + } + + // Sync status endpoints (public, read-only) + if ((url.pathname === '/sync/status' || url.pathname === '/sync/health') && request.method === 'GET') { + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + return await proxyToBackend(request, env, url.pathname, null, corsHeaders, CONFIG); + } + // Community config endpoint: /:communityId/ (GET) const communityConfigMatch = url.pathname.match(/^\/([^\/]+)\/?$/); if (communityConfigMatch && request.method === 'GET') { const communityId = communityConfigMatch[1]; - // Reject reserved path segments as community IDs - const reservedPaths = ['health', 'version', 'feedback', 'communities']; - if (reservedPaths.includes(communityId)) { - return new Response('Not Found', { status: 404, headers: corsHeaders }); - } + const invalid = validateCommunityId(communityId, corsHeaders); + if (invalid) return invalid; - // Validate community ID format - if 
(!isValidCommunityId(communityId)) { - return new Response(JSON.stringify({ error: 'Invalid community ID format' }), { - status: 400, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } - - // Rate limit community config lookups - const rateLimitResult = await checkRateLimit(request, env, CONFIG); - if (!rateLimitResult.allowed) { - return new Response(JSON.stringify({ - error: 'Rate limit exceeded', - details: rateLimitResult.reason, - }), { - status: 429, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; return await proxyToBackend(request, env, `/${communityId}/`, null, corsHeaders, CONFIG); } + // Community public metrics endpoints (GET) + const communityMetricsMatch = url.pathname.match(/^\/([^\/]+)\/(metrics\/public(?:\/usage)?)$/); + if (communityMetricsMatch && request.method === 'GET') { + const communityId = communityMetricsMatch[1]; + + const invalid = validateCommunityId(communityId, corsHeaders); + if (invalid) return invalid; + + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + + return await proxyToBackend(request, env, url.pathname, null, corsHeaders, CONFIG); + } + + // Community sessions endpoint (GET, authenticated -- forward client key) + const communitySessionsMatch = url.pathname.match(/^\/([^\/]+)\/sessions$/); + if (communitySessionsMatch && request.method === 'GET') { + const communityId = communitySessionsMatch[1]; + + const invalid = validateCommunityId(communityId, corsHeaders); + if (invalid) return invalid; + + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; + + return await proxyToBackendPassthrough(request, env, url.pathname, null, corsHeaders, CONFIG); + } + // Community endpoints: /:communityId/ask and /:communityId/chat const communityActionMatch = url.pathname.match(/^\/([^\/]+)\/(ask|chat)$/); if (communityActionMatch && request.method === 'POST') { const [, communityId, action] = communityActionMatch; - // Reject reserved path segments as community IDs - const reservedPaths = ['health', 'version', 'feedback', 'communities']; - if (reservedPaths.includes(communityId)) { - return new Response('Not Found', { status: 404, headers: corsHeaders }); - } - - // Validate community ID format - if (!isValidCommunityId(communityId)) { - return new Response(JSON.stringify({ error: 'Invalid community ID format' }), { - status: 400, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const invalid = validateCommunityId(communityId, corsHeaders); + if (invalid) return invalid; return await handleProtectedEndpoint(request, env, ctx, `/${communityId}/${action}`, corsHeaders, CONFIG); } @@ -380,6 +500,11 @@ function handleRoot(corsHeaders, CONFIG) { 'GET /:communityId/': 'Get community configuration', 'POST /:communityId/ask': 'Ask a single question to a community', 'POST /:communityId/chat': 'Multi-turn conversation with a community', + 'GET /:communityId/metrics/public': 'Public community metrics', + 'GET /:communityId/sessions': 'List sessions (requires API key)', + 'GET /metrics/public/overview': 'Public metrics overview', + 'GET /metrics/overview': 'Admin metrics overview (requires API key)', + 'GET /sync/status': 'Knowledge sync status', 'POST /feedback': 'Submit feedback', 'GET /health': 'Health check', 'GET /version': 'Get API version', @@ -484,16 
+609,8 @@ async function handleProtectedEndpoint(request, env, ctx, path, corsHeaders, CON } // Check rate limit - const rateLimitResult = await checkRateLimit(request, env, CONFIG); - if (!rateLimitResult.allowed) { - return new Response(JSON.stringify({ - error: 'Rate limit exceeded', - details: rateLimitResult.reason, - }), { - status: 429, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; // Remove Turnstile token from body before forwarding const { cf_turnstile_response, ...cleanBody } = body; @@ -505,17 +622,8 @@ async function handleProtectedEndpoint(request, env, ctx, path, corsHeaders, CON * Handle feedback endpoint (rate limited but no Turnstile) */ async function handleFeedback(request, env, corsHeaders, CONFIG) { - // Check rate limit - const rateLimitResult = await checkRateLimit(request, env, CONFIG); - if (!rateLimitResult.allowed) { - return new Response(JSON.stringify({ - error: 'Rate limit exceeded', - details: rateLimitResult.reason, - }), { - status: 429, - headers: { ...corsHeaders, 'Content-Type': 'application/json' }, - }); - } + const rejected = await rateLimitOrReject(request, env, corsHeaders, CONFIG); + if (rejected) return rejected; const body = await request.json(); return await proxyToBackend(request, env, '/feedback', body, corsHeaders, CONFIG); diff --git a/workers/osa-worker/wrangler.toml b/workers/osa-worker/wrangler.toml index 32afe2f..5cdcfa1 100644 --- a/workers/osa-worker/wrangler.toml +++ b/workers/osa-worker/wrangler.toml @@ -10,9 +10,19 @@ compatibility_date = "2024-01-01" BACKEND_URL = "https://api.osc.earth/osa" ENVIRONMENT = "production" -# KV namespace for rate limiting +# Hybrid rate limiting approach: +# - Per-minute (bot protection): Built-in API (fast, in-memory) +# - Per-hour (human abuse): KV (global consistency) + +# Per-minute rate limiter (built-in API, free, <1ms latency) +[[ratelimits]] +name = "RATE_LIMITER_MINUTE" +namespace_id = "1001" +simple = { limit = 10, period = 60 } + +# Per-hour rate limiter (KV, global, 1 write per request) [[kv_namespaces]] -binding = "RATE_LIMITER" +binding = "RATE_LIMITER_KV" id = "8f8506b1a7fb400680a00014312c124d" # Development environment @@ -20,8 +30,15 @@ id = "8f8506b1a7fb400680a00014312c124d" name = "osa-worker-dev" vars = { BACKEND_URL = "https://api.osc.earth/osa-dev", ENVIRONMENT = "development" } +# Per-minute rate limiter (dev has higher limit: 60/min) +[[env.dev.ratelimits]] +name = "RATE_LIMITER_MINUTE" +namespace_id = "1002" +simple = { limit = 60, period = 60 } + +# Per-hour rate limiter (KV) [[env.dev.kv_namespaces]] -binding = "RATE_LIMITER" +binding = "RATE_LIMITER_KV" id = "6d46ef72877b4129b38ef2ca1e1cd5ea" # ============================================================================= @@ -34,23 +51,19 @@ id = "6d46ef72877b4129b38ef2ca1e1cd5ea" # 2. Login to Cloudflare: # wrangler login # -# 3. Create KV namespaces: -# wrangler kv:namespace create "RATE_LIMITER" -# wrangler kv:namespace create "RATE_LIMITER" --env dev -# -# 4. Update this file with the KV namespace IDs from step 3 +# 3. KV namespaces already exist (IDs in this file) # -# 5. Set secrets: +# 4. Set secrets: # wrangler secret put BACKEND_API_KEY # wrangler secret put TURNSTILE_SECRET_KEY # wrangler secret put BACKEND_API_KEY --env dev # wrangler secret put TURNSTILE_SECRET_KEY --env dev # -# 6. Deploy: +# 5. 
Deploy: # wrangler deploy # Production # wrangler deploy --env dev # Development # -# 7. Verify: +# 6. Verify: # curl https://osa-worker..workers.dev/health # # =============================================================================
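Addendum: the quality-metrics tests above pin down a nearest-rank percentile helper
(`src.metrics.queries._percentile`). A minimal sketch consistent with those assertions --
the shipped implementation may differ in details such as type hints or clamping:

    def _percentile(values: list[float], q: float) -> float | None:
        """Nearest-rank percentile: sort, take index int(q * n), clamped to the last element."""
        if not values:
            return None
        ordered = sorted(values)
        idx = min(int(q * len(ordered)), len(ordered) - 1)
        return ordered[idx]

Worked check against the assertions: for durations [200, 300, 500, 1500], p50 uses index
int(0.5 * 4) = 2 -> 500.0 and p95 uses index int(0.95 * 4) = 3 -> 1500.0; for a single
element, both indices clamp to 0, returning that element.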