diff --git a/src/api/recommendations.py b/src/api/recommendations.py index 49f39bc..c3e2062 100644 --- a/src/api/recommendations.py +++ b/src/api/recommendations.py @@ -14,6 +14,13 @@ from src.core.models import User from src.integrations.github.api import github_client +# +from src.rules.ai_rules_scan import ( + scan_repo_for_ai_rule_files, + translate_ai_rule_files_to_yaml, +) +import yaml + logger = structlog.get_logger() router = APIRouter(prefix="/rules", tags=["Recommendations"]) @@ -135,6 +142,62 @@ class MetricConfig(TypedDict): thresholds: dict[str, float] explanation: Callable[[float | int], str] +class ScanAIFilesRequest(BaseModel): + """ + Payload for scanning a repo for AI assistant rule files (Cursor, Claude, Copilot, etc.). + """ + + repo_url: HttpUrl = Field( + ..., description="Full URL of the GitHub repository (e.g., https://github.com/owner/repo)" + ) + github_token: str | None = Field( + None, description="Optional GitHub Personal Access Token (higher rate limits / private repos)" + ) + installation_id: int | None = Field( + None, description="GitHub App installation ID (optional; used to get installation token)" + ) + include_content: bool = Field( + False, description="If True, include file content in response (for translation pipeline)" + ) + + +class ScanAIFilesCandidate(BaseModel): + """A single candidate AI rule file.""" + + path: str = Field(..., description="Repository-relative file path") + has_keywords: bool = Field(..., description="True if content contains known AI-instruction keywords") + content: str | None = Field(None, description="File content; only set when include_content was True") + + +class ScanAIFilesResponse(BaseModel): + """Response from the scan-ai-files endpoint.""" + + repo_full_name: str = Field(..., description="Repository in owner/repo form") + ref: str = Field(..., description="Branch or ref that was scanned (e.g. 
main)") + candidate_files: list[ScanAIFilesCandidate] = Field( + default_factory=list, description="Candidate AI rule files matching path patterns" + ) + warnings: list[str] = Field(default_factory=list, description="Warnings (e.g. rate limit, partial results)") + +class TranslateAIFilesRequest(BaseModel): + """Request for translating AI rule files into .watchflow rules YAML.""" + + repo_url: HttpUrl = Field(..., description="Full URL of the GitHub repository") + github_token: str | None = Field(None, description="Optional GitHub PAT") + installation_id: int | None = Field(None, description="Optional GitHub App installation ID") + + +class TranslateAIFilesResponse(BaseModel): + """Response from translate-ai-files endpoint.""" + + repo_full_name: str = Field(..., description="Repository in owner/repo form") + ref: str = Field(..., description="Branch scanned (e.g. main)") + rules_yaml: str = Field(..., description="Merged rules YAML (rules: [...])") + rules_count: int = Field(..., description="Number of rules in rules_yaml") + ambiguous: list[dict[str, Any]] = Field(default_factory=list, description="Statements that could not be translated") + warnings: list[str] = Field(default_factory=list) + + def _get_severity_label(value: float, thresholds: dict[str, float]) -> tuple[str, str]: """ @@ -420,6 +483,75 @@ def parse_repo_from_url(url: str) -> str: return f"{p.owner}/{p.repo}" +def _ref_to_branch(ref: str | None) -> str | None: + """Convert a full ref (e.g. refs/heads/feature-x) to branch name for use with GitHub API.""" + if not ref or not ref.strip(): + return None + ref = ref.strip() + if ref.startswith("refs/heads/"): + return ref[len("refs/heads/") :].strip() or None + return ref + + +async def get_suggested_rules_from_repo( + repo_full_name: str, + installation_id: int | None, + github_token: str | None, + *, + ref: str | None = None, +) -> tuple[str, int, list[dict[str, Any]], list[str]]: + """ + Run agentic scan+translate for a repo (rules.md, etc. 
-> Watchflow YAML). + Safe to call from event processors; returns empty result on any failure. + Returns (rules_yaml, rules_count, ambiguous_list, rule_sources). + When ref is provided (e.g. from push or PR head), scans that branch; otherwise uses default branch. + """ + try: + repo_data, repo_error = await github_client.get_repository( + repo_full_name, installation_id=installation_id, user_token=github_token + ) + if repo_error or not repo_data: + return ("rules: []\n", 0, [], []) + default_branch = repo_data.get("default_branch") or "main" + scan_ref = _ref_to_branch(ref) if ref else default_branch + if not scan_ref: + scan_ref = default_branch + + tree_entries = await github_client.get_repository_tree( + repo_full_name, + ref=scan_ref, + installation_id=installation_id, + user_token=github_token, + recursive=True, + ) + if not tree_entries: + return ("rules: []\n", 0, [], []) + + async def get_content(path: str): + return await github_client.get_file_content( + repo_full_name, path, installation_id, github_token, ref=scan_ref + ) + + raw_candidates = await scan_repo_for_ai_rule_files( + tree_entries, fetch_content=True, get_file_content=get_content + ) + candidates_with_content = [c for c in raw_candidates if c.get("content")] + if not candidates_with_content: + return ("rules: []\n", 0, [], []) + + rules_yaml, ambiguous, rule_sources = await translate_ai_rule_files_to_yaml(candidates_with_content) + rules_count = 0 + try: + parsed = yaml.safe_load(rules_yaml) + rules_count = len(parsed.get("rules", [])) if isinstance(parsed, dict) else 0 + except Exception: + pass + return (rules_yaml, rules_count, ambiguous, rule_sources) + except Exception as e: + logger.warning("get_suggested_rules_from_repo_failed", repo=repo_full_name, error=str(e)) + return ("rules: []\n", 0, [], []) + + # --- Endpoints --- # Main API surface—keep stable for clients. 
@@ -680,17 +812,18 @@ async def proceed_with_pr( try: # Step 1: Get repository metadata to find default branch - repo_data = await github_client.get_repository( + repo_data, repo_error = await github_client.get_repository( repo_full_name=repo_full_name, installation_id=installation_id, user_token=user_token, ) - if not repo_data: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Repository '{repo_full_name}' not found or access denied.", - ) + if repo_error: + err_status = repo_error["status"] + status_code = status.HTTP_429_TOO_MANY_REQUESTS if err_status == 403 else err_status + if status_code not in (401, 403, 404, 429): + status_code = status.HTTP_502_BAD_GATEWAY + raise HTTPException(status_code=status_code, detail=repo_error["message"]) base_branch = payload.base_branch or repo_data.get("default_branch", "main") @@ -795,3 +928,203 @@ async def proceed_with_pr( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create pull request. Please try again.", ) from e + +@router.post( + "/scan-ai-files", + response_model=ScanAIFilesResponse, + status_code=status.HTTP_200_OK, + summary="Scan repository for AI rule files", + description=( + "Lists files matching *rules*.md, *guidelines*.md, *prompt*.md, .cursor/rules/*.mdc. " + "Optionally fetches content and flags files that contain AI-instruction keywords." + ), + dependencies=[Depends(rate_limiter)], +) +async def scan_ai_rule_files( + request: Request, + payload: ScanAIFilesRequest, + user: User | None = Depends(get_current_user_optional), + ) -> ScanAIFilesResponse: + """ + Scan a repository for AI assistant rule files (Cursor, Claude, Copilot, etc.). 
+ """ + repo_url_str = str(payload.repo_url) + client_ip = request.client.host if request.client else "unknown" + logger.info("scan_ai_files_requested", repo_url=repo_url_str, ip=client_ip) + + try: + repo_full_name = parse_repo_from_url(repo_url_str) + except ValueError as e: + logger.warning("invalid_url_provided", url=repo_url_str, error=str(e)) + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e) + ) from e + + # Resolve token (same as recommend_rules) + github_token = None + if user and user.github_token: + try: + github_token = user.github_token.get_secret_value() + except (AttributeError, TypeError): + github_token = str(user.github_token) if user.github_token else None + elif payload.github_token: + github_token = payload.github_token + elif payload.installation_id: + installation_token = await github_client.get_installation_access_token(payload.installation_id) + if installation_token: + github_token = installation_token + + installation_id = payload.installation_id + + # Default branch + repo_data, repo_error = await github_client.get_repository( + repo_full_name, installation_id=installation_id, user_token=github_token + ) + if repo_error: + err_status = repo_error["status"] + status_code = status.HTTP_429_TOO_MANY_REQUESTS if err_status == 403 else err_status + if status_code not in (401, 403, 404, 429): + status_code = status.HTTP_502_BAD_GATEWAY + raise HTTPException(status_code=status_code, detail=repo_error["message"]) + default_branch = repo_data.get("default_branch") or "main" + ref = default_branch + + # Full tree + tree_entries = await github_client.get_repository_tree( + repo_full_name, + ref=ref, + installation_id=installation_id, + user_token=github_token, + recursive=True, + ) + if not tree_entries: + return ScanAIFilesResponse( + repo_full_name=repo_full_name, + ref=ref, + candidate_files=[], + warnings=["Could not load repository tree; check access and ref."], + ) + + # Optional content fetcher for 
keyword scan (and optionally include in response) + async def get_content(path: str): + return await github_client.get_file_content( + repo_full_name, path, installation_id, github_token + ) + + # Always fetch content so has_keywords is set; strip content in response unless include_content + raw_candidates = await scan_repo_for_ai_rule_files( + tree_entries, + fetch_content=True, + get_file_content=get_content, + ) + + candidates = [ + ScanAIFilesCandidate( + path=c["path"], + has_keywords=c["has_keywords"], + content=c["content"] if payload.include_content else None, + ) + for c in raw_candidates + ] + + return ScanAIFilesResponse( + repo_full_name=repo_full_name, + ref=ref, + candidate_files=candidates, + warnings=[], + ) + +@router.post( + "/translate-ai-files", + response_model=TranslateAIFilesResponse, + status_code=status.HTTP_200_OK, + summary="Translate AI rule files to Watchflow YAML", + description="Scans repo for AI rule files, extracts statements, maps or translates to .watchflow rules YAML.", + dependencies=[Depends(rate_limiter)], +) +async def translate_ai_rule_files( + request: Request, + payload: TranslateAIFilesRequest, + user: User | None = Depends(get_current_user_optional), +) -> TranslateAIFilesResponse: + repo_url_str = str(payload.repo_url) + logger.info("translate_ai_files_requested", repo_url=repo_url_str) + + try: + repo_full_name = parse_repo_from_url(repo_url_str) + except ValueError as e: + logger.warning("invalid_url_provided", url=repo_url_str, error=str(e)) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e)) from e + + github_token = None + if user and user.github_token: + try: + github_token = user.github_token.get_secret_value() + except (AttributeError, TypeError): + github_token = str(user.github_token) if user.github_token else None + elif payload.github_token: + github_token = payload.github_token + elif payload.installation_id: + installation_token = await 
github_client.get_installation_access_token(payload.installation_id) + if installation_token: + github_token = installation_token + installation_id = payload.installation_id + + repo_data, repo_error = await github_client.get_repository( + repo_full_name, installation_id=installation_id, user_token=github_token + ) + if repo_error: + err_status = repo_error["status"] + status_code = status.HTTP_429_TOO_MANY_REQUESTS if err_status == 403 else err_status + if status_code not in (401, 403, 404, 429): + status_code = status.HTTP_502_BAD_GATEWAY + raise HTTPException(status_code=status_code, detail=repo_error["message"]) + default_branch = repo_data.get("default_branch") or "main" + ref = default_branch + + tree_entries = await github_client.get_repository_tree( + repo_full_name, ref=ref, installation_id=installation_id, user_token=github_token, recursive=True + ) + if not tree_entries: + return TranslateAIFilesResponse( + repo_full_name=repo_full_name, + ref=ref, + rules_yaml="rules: []\n", + rules_count=0, + ambiguous=[], + warnings=["Could not load repository tree."], + ) + + async def get_content(path: str): + return await github_client.get_file_content(repo_full_name, path, installation_id, github_token) + + raw_candidates = await scan_repo_for_ai_rule_files( + tree_entries, fetch_content=True, get_file_content=get_content + ) + candidates_with_content = [c for c in raw_candidates if c.get("content")] + if not candidates_with_content: + return TranslateAIFilesResponse( + repo_full_name=repo_full_name, + ref=ref, + rules_yaml="rules: []\n", + rules_count=0, + ambiguous=[], + warnings=["No AI rule file content could be loaded."], + ) + + rules_yaml, ambiguous, rule_sources = await translate_ai_rule_files_to_yaml(candidates_with_content) + rules_count = rules_yaml.count("\n - ") + (1 if rules_yaml.strip() != "rules: []" and " - " in rules_yaml else 0) + try: + parsed = yaml.safe_load(rules_yaml) + rules_count = len(parsed.get("rules", [])) if isinstance(parsed, dict) 
else 0 + except Exception: + pass + + return TranslateAIFilesResponse( + repo_full_name=repo_full_name, + ref=ref, + rules_yaml=rules_yaml, + rules_count=rules_count, + ambiguous=ambiguous, + warnings=[], + ) \ No newline at end of file diff --git a/src/event_processors/pull_request/processor.py b/src/event_processors/pull_request/processor.py index ccbc86a..9a1a07f 100644 --- a/src/event_processors/pull_request/processor.py +++ b/src/event_processors/pull_request/processor.py @@ -3,6 +3,8 @@ from typing import Any from src.agents import get_agent +from src.api.recommendations import get_suggested_rules_from_repo +from src.rules.ai_rules_scan import is_relevant_pr from src.core.models import Violation from src.event_processors.base import BaseEventProcessor, ProcessingResult from src.event_processors.pull_request.enricher import PullRequestEnricher @@ -60,6 +62,32 @@ async def process(self, task: Task) -> ProcessingResult: raise ValueError("Failed to get installation access token") github_token = github_token_optional + # Agentic: scan repo only when relevant (PR targets default branch) + # Use the PR head ref so we scan the branch being proposed, not main. + if is_relevant_pr(task.payload): + try: + pr_head_ref = pr_data.get("head", {}).get("ref") # branch name, e.g. 
feature-x + rules_yaml, rules_count, ambiguous, rule_sources = await get_suggested_rules_from_repo( + repo_full_name, installation_id, github_token, ref=pr_head_ref + ) + logger.info("=" * 80) + logger.info("📋 Suggested rules (agentic scan + translation)") + logger.info(f" Repo: {repo_full_name} | PR #{pr_number} | Ref: {pr_head_ref or 'default'} | Translated rules: {rules_count}") + if rule_sources: + from_mapping = sum(1 for s in rule_sources if s == "mapping") + from_agent = sum(1 for s in rule_sources if s == "agent") + logger.info(" From deterministic mapping: %s | From AI agent: %s", from_mapping, from_agent) + logger.info(" Per-rule source: %s", rule_sources) + if rules_count > 0: + logger.info(" YAML:\n%s", rules_yaml) + if ambiguous: + logger.info(" Ambiguous (not translated): %s", [a.get("statement", "") for a in ambiguous]) + logger.info("=" * 80) + except Exception as e: + logger.warning("Suggested rules scan failed: %s", e) + else: + logger.info("PR not relevant for agentic scan (skip): base ref=%s", task.payload.get("pull_request", {}).get("base", {}).get("ref")) + # 1. 
Enrich event data event_data = await self.enricher.enrich_event_data(task, github_token) api_calls += 1 diff --git a/src/event_processors/push.py b/src/event_processors/push.py index 2e77bf8..b6690bd 100644 --- a/src/event_processors/push.py +++ b/src/event_processors/push.py @@ -3,11 +3,14 @@ from typing import Any from src.agents import get_agent +from src.api.recommendations import get_suggested_rules_from_repo +from src.rules.ai_rules_scan import is_relevant_push from src.core.models import Severity, Violation from src.event_processors.base import BaseEventProcessor, ProcessingResult from src.integrations.github.check_runs import CheckRunManager from src.tasks.task_queue import Task + logger = logging.getLogger(__name__) @@ -62,6 +65,33 @@ async def process(self, task: Task) -> ProcessingResult: error="No installation ID found", ) + # Agentic: scan repo only when relevant (default branch or touched rule files) + # Use the branch that was pushed so we scan that branch's file content, not main. + if is_relevant_push(task.payload): + try: + github_token = await self.github_client.get_installation_access_token(task.installation_id) + push_ref = payload.get("ref") # e.g. 
refs/heads/feature-x + rules_yaml, rules_count, ambiguous, rule_sources = await get_suggested_rules_from_repo( + task.repo_full_name, task.installation_id, github_token, ref=push_ref + ) + logger.info("=" * 80) + logger.info("📋 Suggested rules (agentic scan + translation)") + logger.info(f" Repo: {task.repo_full_name} | Ref: {push_ref or 'default'} | Translated rules: {rules_count}") + if rule_sources: + from_mapping = sum(1 for s in rule_sources if s == "mapping") + from_agent = sum(1 for s in rule_sources if s == "agent") + logger.info(" From deterministic mapping: %s | From AI agent: %s", from_mapping, from_agent) + logger.info(" Per-rule source: %s", rule_sources) + if rules_count > 0: + logger.info(" YAML:\n%s", rules_yaml) + if ambiguous: + logger.info(" Ambiguous (not translated): %s", [a.get("statement", "") for a in ambiguous]) + logger.info("=" * 80) + except Exception as e: + logger.warning("Suggested rules scan failed: %s", e) + else: + logger.info("Push not relevant for agentic scan (skip): ref=%s", task.payload.get("ref")) + rules_optional = await self.rule_provider.get_rules(task.repo_full_name, task.installation_id) rules = rules_optional if rules_optional is not None else [] diff --git a/src/integrations/github/api.py b/src/integrations/github/api.py index 4d6ac85..c1f5f99 100644 --- a/src/integrations/github/api.py +++ b/src/integrations/github/api.py @@ -129,27 +129,51 @@ async def get_installation_access_token(self, installation_id: int) -> str | Non async def get_repository( self, repo_full_name: str, installation_id: int | None = None, user_token: str | None = None - ) -> dict[str, Any] | None: - """Fetch repository metadata (default branch, language, etc.). Supports public access.""" + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: + """ + Fetch repository metadata. Returns (repo_data, None) on success; + (None, {"status": int, "message": str}) on failure for meaningful API responses. 
+ """ headers = await self._get_auth_headers( - installation_id=installation_id, user_token=user_token, allow_anonymous=True + installation_id=installation_id, user_token=user_token ) if not headers: - return None + return ( + None, + {"status": 401, "message": "Authentication required. Provide github_token or installation_id in the request."}, + ) url = f"{config.github.api_base_url}/repos/{repo_full_name}" session = await self._get_session() async with session.get(url, headers=headers) as response: if response.status == 200: data = await response.json() - return cast("dict[str, Any]", data) - return None + return cast("dict[str, Any]", data), None + try: + body = await response.json() + gh_message = body.get("message", "") if isinstance(body, dict) else "" + except Exception: + gh_message = "" + if response.status == 404: + msg = gh_message or "Repository not found or access denied. Check repo name and token permissions." + return None, {"status": 404, "message": msg} + if response.status == 403: + msg = "GitHub API rate limit exceeded. Try again later or provide github_token for higher limits." + if gh_message and "rate limit" in gh_message.lower(): + msg = gh_message + return None, {"status": 403, "message": msg} + if response.status == 401: + return ( + None, + {"status": 401, "message": gh_message or "Invalid or expired token. 
Check github_token or installation_id."}, + ) + return None, {"status": response.status, "message": gh_message or f"GitHub API returned {response.status}."} async def list_directory_any_auth( self, repo_full_name: str, path: str, installation_id: int | None = None, user_token: str | None = None ) -> list[dict[str, Any]]: - """List directory contents using either installation or user token.""" + """List directory contents using installation or user token (auth required).""" headers = await self._get_auth_headers( - installation_id=installation_id, user_token=user_token, allow_anonymous=True + installation_id=installation_id, user_token=user_token ) if not headers: return [] @@ -164,21 +188,80 @@ async def list_directory_any_auth( response.raise_for_status() return [] + + async def get_repository_tree( + self, + repo_full_name: str, + ref: str | None = None, + installation_id: int | None = None, + user_token: str | None = None, + recursive: bool = True, + ) -> list[dict[str, Any]]: + """Get the tree of a repository. 
Requires authentication (github_token or installation_id).""" + headers = await self._get_auth_headers( + installation_id=installation_id, + user_token=user_token, + ) + if not headers: + return [] + ref = ref or "main" + tree_sha = await self._resolve_tree_sha(repo_full_name, ref, headers) + if not tree_sha: + return [] + + url = ( f"{config.github.api_base_url}" + f"/repos/{repo_full_name}/git/trees/{tree_sha}" + f"?recursive={recursive}" ) + + session = await self._get_session() + async with session.get(url, headers=headers) as response: + if response.status != 200: + return [] + data = await response.json() + return cast("list[dict[str, Any]]", data.get("tree", [])) + + + async def _resolve_tree_sha(self, repo_full_name: str, ref: str, headers: dict[str, str]) -> str | None: + """Resolve the SHA of the tree for the given ref (commit SHA from ref -> tree SHA from commit).""" + session = await self._get_session() + ref_url = f"{config.github.api_base_url}/repos/{repo_full_name}/git/ref/heads/{ref}" + async with session.get(ref_url, headers=headers) as response: + if response.status != 200: + return None + data = await response.json() + commit_sha = data.get("object", {}).get("sha") if isinstance(data, dict) else None + if not commit_sha: + return None + commit_url = f"{config.github.api_base_url}/repos/{repo_full_name}/git/commits/{commit_sha}" + async with session.get(commit_url, headers=headers) as response: + if response.status != 200: + return None + commit_data = await response.json() + tree_sha = commit_data.get("tree", {}).get("sha") if isinstance(commit_data, dict) else None + return tree_sha + async def get_file_content( - self, repo_full_name: str, file_path: str, installation_id: int | None, user_token: str | None = None + self, + repo_full_name: str, + file_path: str, + installation_id: int | None, + user_token: str | None = None, + ref: str | None = None, ) -> str | None: """ - Fetches the content of a file from a repository. 
Supports anonymous access for public analysis. + Fetches the content of a file from a repository. Requires authentication (github_token or installation_id). + When ref is provided (branch name, tag, or commit SHA), returns content at that ref; otherwise uses default branch. """ headers = await self._get_auth_headers( installation_id=installation_id, user_token=user_token, accept="application/vnd.github.raw", - allow_anonymous=True, ) if not headers: return None url = f"{config.github.api_base_url}/repos/{repo_full_name}/contents/{file_path}" + if ref: + url = f"{url}?ref={ref}" session = await self._get_session() async with session.get(url, headers=headers) as response: @@ -1030,7 +1113,6 @@ async def fetch_recent_pull_requests( headers = await self._get_auth_headers( installation_id=installation_id, user_token=user_token, - allow_anonymous=True, # Support public repos ) if not headers: logger.error("pr_fetch_auth_failed", repo=repo_full_name, error_type="auth_error") @@ -1139,10 +1221,9 @@ async def execute_graphql( url = f"{config.github.api_base_url}/graphql" payload = {"query": query, "variables": variables} - # Get appropriate headers (can be anonymous for public data or authenticated) - # Priority: user_token > installation_id > anonymous (if allowed) + # Get appropriate headers (auth required: user_token or installation_id) headers = await self._get_auth_headers( - user_token=user_token, installation_id=installation_id, allow_anonymous=True + user_token=user_token, installation_id=installation_id ) if not headers: # Fallback or error? GraphQL usually demands auth. diff --git a/src/rules/ai_rules_scan.py b/src/rules/ai_rules_scan.py new file mode 100644 index 0000000..8da5d13 --- /dev/null +++ b/src/rules/ai_rules_scan.py @@ -0,0 +1,372 @@ +""" +Scan for AI assistant rule files in a repository (Cursor, Claude, Copilot, etc.). 
+Used by the repo-scanning flow to find *rules*.md, *guidelines*.md, *prompt*.md +and .cursor/rules/*.mdc, then optionally flag files that contain instruction keywords. +""" + +import logging +import re +from collections.abc import Awaitable, Callable +from typing import Any, cast +from src.core.utils.patterns import matches_any +import yaml + +logger = logging.getLogger(__name__) + +# --- Path patterns (globs) --- +AI_RULE_FILE_PATTERNS = [ + "*rules*.md", + "*guidelines*.md", + "*prompt*.md", + "**/*rules*.md", + "**/*guidelines*.md", + "**/*prompt*.md", + ".cursor/rules/*.mdc", + ".cursor/rules/**/*.mdc", +] + +# --- Keywords (content) --- +AI_RULE_KEYWORDS = [ + "Cursor rule:", + "Claude:", + "always use", + "never commit", + "Copilot", + "AI assistant", + "when writing code", + "when generating", + "pr title", + "pr description", + "pr size", + "pr approvals", + "pr reviews", + "pr comments", + "pr files", + "pr commits", + "pr branches", + "pr tags", +] + + +def path_matches_ai_rule_patterns(path: str) -> bool: + """Return True if path matches any of the AI rule file glob patterns.""" + if not path or not path.strip(): + return False + normalized = path.replace("\\", "/").strip() + return matches_any(normalized, AI_RULE_FILE_PATTERNS) + + +def content_has_ai_keywords(content: str | None) -> bool: + """Return True if content contains any of the AI rule keywords (case-insensitive).""" + if not content: + return False + lower = content.lower() + return any(kw.lower() in lower for kw in AI_RULE_KEYWORDS) + +def is_relevant_push(payload: dict[str, Any]) -> bool: + """ + Return True if we should run agentic scan for this push. + Relevant when: push is to default branch, or any changed file matches AI rule path patterns. 
+ """ + ref = (payload.get("ref") or "").strip() + repo = payload.get("repository") or {} + default_branch = repo.get("default_branch") or "main" + if ref == f"refs/heads/{default_branch}": + return True + for commit in payload.get("commits") or []: + for path in (commit.get("added") or []) + (commit.get("modified") or []) + (commit.get("removed") or []): + if path and path_matches_ai_rule_patterns(path): + return True + return False + + +def is_relevant_pr(payload: dict[str, Any]) -> bool: + """ + Return True if we should run agentic scan for this PR. + Relevant when: PR targets the repo's default branch. + """ + pr = payload.get("pull_request") or {} + base = pr.get("base") or {} + default_branch = ( + (base.get("repo") or {}).get("default_branch") + or (payload.get("repository") or {}).get("default_branch") + or "main" + ) + return base.get("ref") == default_branch + +def filter_tree_entries_for_ai_rules( + tree_entries: list[dict[str, Any]], + *, + blob_only: bool = True, + ) -> list[dict[str, Any]]: + """ + From a GitHub tree response (list of { path, type, ... }), return entries + that match AI rule file patterns. By default only 'blob' (files) are included. + """ + result = [] + for entry in tree_entries: + if blob_only and entry.get("type") != "blob": + continue + path = entry.get("path") or "" + if path_matches_ai_rule_patterns(path): + result.append(entry) + return cast("list[dict[str, Any]]", result) + + +GetContentFn = Callable[[str], Awaitable[str | None]] + + +async def scan_repo_for_ai_rule_files( + tree_entries: list[dict[str, Any]], + *, + fetch_content: bool = False, + get_file_content: GetContentFn | None = None, + ) -> list[dict[str, Any]]: + """ + Filter tree entries to AI-rule candidates, optionally fetch content and set has_keywords. + + Returns list of { "path", "has_keywords", "content" }. content is only set when fetch_content + is True and get_file_content is provided. 
+ """ + candidates = filter_tree_entries_for_ai_rules(tree_entries, blob_only=True) + results: list[dict[str, Any]] = [] + + for entry in candidates: + path = entry.get("path") or "" + has_keywords = False + content: str | None = None + + if fetch_content and get_file_content: + try: + content = await get_file_content(path) + has_keywords = content_has_ai_keywords(content) + except Exception as e: + logger.warning("ai_rules_scan_fetch_failed path=%s error=%s", path, str(e)) + + results.append({ + "path": path, + "has_keywords": has_keywords, + "content": content, + }) + + return cast("list[dict[str, Any]]", results) + + +# --- Deterministic extraction (parsing) --- + +# Line prefixes that indicate a rule statement (strip prefix, use rest of line or next line). +EXTRACTOR_LINE_PREFIXES = [ + "cursor rule:", + "claude:", + "copilot:", + "rule:", + "guideline:", + "instruction:", +] + +# Phrases that suggest a rule (include the whole line if it contains one of these). +EXTRACTOR_PHRASE_MARKERS = [ + "always use", + "never commit", + "must have", + "should have", + "required to", + "prs must", + "pull requests must", + "every pr", + "all prs", +] + +def extract_rule_statements_from_markdown(content: str) -> list[str]: + """ + Parse markdown content and return a list of rule-like statements (deterministic). + Uses line prefixes (Cursor rule:, Claude:, etc.) and phrase markers (always use, never commit, etc.). 
+ """ + if not content or not content.strip(): + return [] + statements: list[str] = [] + seen: set[str] = set() + lines = content.splitlines() + + for i, line in enumerate(lines): + stripped = line.strip() + if not stripped or len(stripped) > 500: + continue + lower = stripped.lower() + + # 1) Line starts with a known prefix -> rest of line is the statement + for prefix in EXTRACTOR_LINE_PREFIXES: + if lower.startswith(prefix): + rest = stripped[len(prefix) :].strip() + if rest: + normalized = _normalize_statement(rest) + if normalized and normalized not in seen: + statements.append(rest) + seen.add(normalized) + break + else: + # 2) Line contains a phrase marker -> treat whole line as statement + for marker in EXTRACTOR_PHRASE_MARKERS: + if marker in lower: + normalized = _normalize_statement(stripped) + if normalized and normalized not in seen: + statements.append(stripped) + seen.add(normalized) + break + + return statements + + +def _normalize_statement(s: str) -> str: + """Normalize for deduplication: lowercase, collapse whitespace.""" + return " ".join(s.lower().split()) if s else "" + + +# --- Mapping layer (known phrase -> fixed YAML rule; no LLM) --- + +# Each entry: (list of regex patterns or substrings to match, rule dict for .watchflow/rules.yaml) +# Match is case-insensitive. First match wins. +STATEMENT_TO_YAML_MAPPINGS: list[tuple[list[str], dict[str, Any]]] = [ + # PRs must have a linked issue + ( + ["prs must have a linked issue", "pull requests must reference", "require linked issue", "must link an issue"], + { + "description": "PRs must reference an issue (e.g. 
Fixes #123)", + "enabled": True, + "severity": "medium", + "event_types": ["pull_request"], + "parameters": {"require_linked_issue": True}, + }, + ), + # PR title pattern (conventional commits) + ( + ["pr title must match", "use conventional commits", "title must follow convention"], + { + "description": "PR title must follow conventional commits (feat, fix, docs, etc.)", + "enabled": True, + "severity": "medium", + "event_types": ["pull_request"], + "parameters": {"title_pattern": "^feat|^fix|^docs|^style|^refactor|^test|^chore|^perf|^ci|^build|^revert"}, + }, + ), + # Min description length + ( + ["pr description must be", "description length", "min description", "meaningful pr description"], + { + "description": "PR description must be at least 50 characters", + "enabled": True, + "severity": "medium", + "event_types": ["pull_request"], + "parameters": {"min_description_length": 50}, + }, + ), + # Max PR size + ( + ["pr size", "max lines", "limit pr size", "keep prs small"], + { + "description": "PR must not exceed 500 lines changed", + "enabled": True, + "severity": "medium", + "event_types": ["pull_request"], + "parameters": {"max_lines": 500}, + }, + ), + # Min approvals + ( + ["min approvals", "at least one approval", "require approval", "prs need approval"], + { + "description": "PRs require at least one approval", + "enabled": True, + "severity": "high", + "event_types": ["pull_request"], + "parameters": {"min_approvals": 1}, + }, + ), +] + +def try_map_statement_to_yaml(statement: str) -> dict[str, Any] | None: + """ + If the statement matches a known phrase, return the corresponding rule dict (one entry for rules: []). + Otherwise return None (caller should use feasibility agent). 
+ """ + if not statement or not statement.strip(): + return None + lower = statement.lower() + # for patterns, rule_dict in STATEMENT_TO_YAML_MAPPINGS: + # for p in patterns: + # if p in lower: + # return dict(rule_dict) + # return None + + for patterns, rule_dict in STATEMENT_TO_YAML_MAPPINGS: + for p in patterns: + if p in lower: + logger.warning( + "deterministic_mapping_matched statement=%r pattern=%r", + statement[:100], + p, + ) + return dict(rule_dict) + return None + +# --- Translate pipeline (extract -> map or feasibility -> merge YAML) --- + +async def translate_ai_rule_files_to_yaml( + candidates: list[dict[str, Any]], + *, + get_feasibility_agent: Callable[[], Any] | None = None, + ) -> tuple[str, list[dict[str, Any]], list[str]]: + """ + From candidate files (each with "path" and "content"), extract statements, translate to + Watchflow rules (mapping layer first, then feasibility agent), merge into one YAML string. + + Returns: + (rules_yaml_str, ambiguous_list, rule_sources) + - rules_yaml_str: full "rules:\n - ..." YAML. + - ambiguous_list: [{"statement", "path", "reason"}] for statements that could not be translated. + - rule_sources: one of "mapping" or "agent" per rule (same order as rules in rules_yaml). 
+ """ + all_rules: list[dict[str, Any]] = [] + rule_sources: list[str] = [] + ambiguous: list[dict[str, Any]] = [] + + if get_feasibility_agent is None: + from src.agents import get_agent + def _default_agent(): + return get_agent("feasibility") + get_feasibility_agent = _default_agent + + for cand in candidates: + content = cand.get("content") if isinstance(cand.get("content"), str) else None + path = cand.get("path") or "" + if not content: + continue + statements = extract_rule_statements_from_markdown(content) + for st in statements: + # 1) Try deterministic mapping first + mapped = try_map_statement_to_yaml(st) + if mapped is not None: + all_rules.append(mapped) + rule_sources.append("mapping") + continue + # 2) Fall back to feasibility agent + try: + agent = get_feasibility_agent() + result = await agent.execute(rule_description=st) + if result.success and result.data.get("is_feasible") and result.data.get("yaml_content"): + yaml_content = result.data["yaml_content"].strip() + parsed = yaml.safe_load(yaml_content) + if isinstance(parsed, dict) and "rules" in parsed and isinstance(parsed["rules"], list): + for r in parsed["rules"]: + if isinstance(r, dict): + all_rules.append(r) + rule_sources.append("agent") + else: + ambiguous.append({"statement": st, "path": path, "reason": "Feasibility agent returned invalid YAML"}) + else: + ambiguous.append({"statement": st, "path": path, "reason": result.message or "Not feasible"}) + except Exception as e: + ambiguous.append({"statement": st, "path": path, "reason": str(e)}) + + rules_yaml = yaml.dump({"rules": all_rules}, indent=2, sort_keys=False) if all_rules else "rules: []\n" + return rules_yaml, ambiguous, rule_sources \ No newline at end of file diff --git a/tests/integration/test_scan_ai_files.py b/tests/integration/test_scan_ai_files.py new file mode 100644 index 0000000..df384e4 --- /dev/null +++ b/tests/integration/test_scan_ai_files.py @@ -0,0 +1,71 @@ +""" +Integration tests for POST 
class TestScanAIFilesEndpoint:
    """Exercise POST /api/v1/rules/scan-ai-files with the GitHub client stubbed out."""

    @pytest.fixture
    def client(self) -> TestClient:
        """Provide a TestClient bound to the application under test."""
        return TestClient(app)

    def test_scan_ai_files_returns_200_and_list_when_mocked(
        self, client: TestClient
    ) -> None:
        """With GitHub mocked, endpoint returns 200 and candidate_files is a list."""
        repo_payload = {"default_branch": "main", "full_name": "owner/repo"}
        tree_payload = [
            {"path": "README.md", "type": "blob"},
            {"path": "docs/cursor-guidelines.md", "type": "blob"},
        ]

        # get_repository resolves to a (data, error) pair; the tree call to a list of entries.
        stub_repository = patch(
            "src.api.recommendations.github_client.get_repository",
            new_callable=AsyncMock,
            return_value=(repo_payload, None),
        )
        stub_tree = patch(
            "src.api.recommendations.github_client.get_repository_tree",
            new_callable=AsyncMock,
            return_value=tree_payload,
        )
        request_body = {
            "repo_url": "https://github.com/owner/repo",
            "include_content": False,
        }

        with stub_repository, stub_tree:
            response = client.post("/api/v1/rules/scan-ai-files", json=request_body)

        assert response.status_code == 200
        body = response.json()
        for key in ("repo_full_name", "ref", "candidate_files", "warnings"):
            assert key in body
        assert body["repo_full_name"] == "owner/repo"
        assert body["ref"] == "main"

        candidates = body["candidate_files"]
        assert isinstance(candidates, list)
        # At least the matching path should appear
        assert "docs/cursor-guidelines.md" in [item["path"] for item in candidates]
        for item in candidates:
            assert "path" in item
            assert "has_keywords" in item
@pytest.mark.asyncio
async def test_get_repository_tree_success(github_client, mock_aiohttp_session):
    """get_repository_tree returns every tree entry once the ref resolves and the tree GET succeeds."""
    from unittest.mock import AsyncMock, patch

    resolved_sha = "fake_tree_sha_123"
    expected_paths = ("README.md", "docs/guidelines.md", "src/main.py")
    payload = {
        "sha": resolved_sha,
        "tree": [
            {"path": path, "type": "blob", "sha": blob_sha}
            for path, blob_sha in zip(expected_paths, ("a", "b", "c"))
        ],
        "truncated": False,
    }
    tree_response = mock_aiohttp_session.create_mock_response(200, json_data=payload)
    auth_headers = {"Authorization": "Bearer fake", "Accept": "application/vnd.github.v3+json"}

    # Stub both private helpers so only the tree GET itself is exercised.
    stub_headers = patch.object(
        github_client,
        "_get_auth_headers",
        new_callable=AsyncMock,
        return_value=auth_headers,
    )
    stub_tree_sha = patch.object(
        github_client,
        "_resolve_tree_sha",
        new_callable=AsyncMock,
        return_value=resolved_sha,
    )
    with stub_headers, stub_tree_sha:
        mock_aiohttp_session.get.return_value = tree_response
        entries = await github_client.get_repository_tree(
            "owner/repo",
            ref="main",
            installation_id=123,
        )

    assert len(entries) == 3
    returned_paths = [entry["path"] for entry in entries]
    for path in expected_paths:
        assert path in returned_paths
class TestPathMatchesAiRulePatterns:
    """Path-level checks for path_matches_ai_rule_patterns()."""

    _ACCEPTED = [
        "cursor-rules.md",
        "docs/guidelines.md",
        "CONTRIBUTING-guidelines.md",
        "copilot-prompts.md",
        "prompt.md",
        ".cursor/rules/foo.mdc",
        ".cursor/rules/sub/bar.mdc",
        "README-rules-and-conventions.md",
    ]
    _REJECTED = [
        "README.md",
        "docs/readme.md",
        "src/main.py",
        "config.yaml",
        "rules.txt",
        "guidelines.txt",
    ]

    @pytest.mark.parametrize("path", _ACCEPTED)
    def test_matches_candidate_paths(self, path: str) -> None:
        """Every known candidate path is accepted."""
        assert path_matches_ai_rule_patterns(path) is True

    @pytest.mark.parametrize("path", _REJECTED)
    def test_rejects_non_candidate_paths(self, path: str) -> None:
        """Ordinary repository files are rejected."""
        assert path_matches_ai_rule_patterns(path) is False

    def test_empty_or_whitespace_returns_false(self) -> None:
        """Blank paths never match."""
        for blank in ("", " "):
            assert path_matches_ai_rule_patterns(blank) is False

    def test_normalizes_backslashes(self) -> None:
        """A Windows-style separator is accepted for a .cursor/rules path."""
        assert path_matches_ai_rule_patterns(".cursor\\rules\\x.mdc") is True


class TestContentHasAiKeywords:
    """Keyword detection checks for content_has_ai_keywords()."""

    # NOTE(review): `keyword` only documents which phrase each sample targets;
    # the test body does not use it.
    @pytest.mark.parametrize(
        "content,keyword",
        [
            ("Cursor rule: Always use type hints", "Cursor rule:"),
            ("Claude: Prefer immutable data", "Claude:"),
            ("We should always use async/await", "always use"),
            ("never commit secrets", "never commit"),
            ("Use Copilot suggestions wisely", "Copilot"),
            ("AI assistant instructions", "AI assistant"),
            ("when writing code follow style guide", "when writing code"),
            ("when generating docs use templates", "when generating"),
        ],
    )
    def test_detects_keywords(self, content: str, keyword: str) -> None:
        """Each sample containing a known phrase is flagged."""
        assert content_has_ai_keywords(content) is True

    def test_case_insensitive(self) -> None:
        """Upper-cased keywords are still detected."""
        for shouted in ("CURSOR RULE: do something", "CLAUDE: optional"):
            assert content_has_ai_keywords(shouted) is True

    def test_no_keywords_returns_false(self) -> None:
        """Keyword-free, empty, and None content all report False."""
        for plain in ("Just a normal readme.", "", None):
            assert content_has_ai_keywords(plain) is False


class TestFilterTreeEntriesForAiRules:
    """Filtering checks for filter_tree_entries_for_ai_rules()."""

    def test_keeps_only_matching_blobs(self) -> None:
        """Only blobs whose path matches a pattern survive."""
        tree = [
            {"path": "src/main.py", "type": "blob"},
            {"path": "cursor-rules.md", "type": "blob"},
            {"path": "docs/guidelines.md", "type": "blob"},
            {"path": "README.md", "type": "blob"},
            {"path": "docs", "type": "tree"},
        ]
        kept = filter_tree_entries_for_ai_rules(tree, blob_only=True)
        assert sorted(entry["path"] for entry in kept) == ["cursor-rules.md", "docs/guidelines.md"]

    def test_excludes_trees_when_blob_only(self) -> None:
        """A tree entry is dropped when blob_only is set."""
        tree = [
            {"path": ".cursor/rules", "type": "tree"},
            {"path": ".cursor/rules/guidelines.mdc", "type": "blob"},
        ]
        kept = filter_tree_entries_for_ai_rules(tree, blob_only=True)
        assert [entry["path"] for entry in kept] == [".cursor/rules/guidelines.mdc"]

    def test_empty_list_returns_empty(self) -> None:
        """No entries in, no entries out."""
        assert filter_tree_entries_for_ai_rules([]) == []

    def test_includes_trees_when_blob_only_false(self) -> None:
        """A matching blob is still returned when blob_only is False."""
        # NOTE(review): despite the name, no "tree" entry is passed here, so tree
        # inclusion itself is not exercised -- confirm the intended coverage.
        tree = [{"path": "docs/guidelines.md", "type": "blob"}]
        kept = filter_tree_entries_for_ai_rules(tree, blob_only=False)
        assert len(kept) == 1


class TestScanRepoForAiRuleFiles:
    """Async checks for scan_repo_for_ai_rule_files()."""

    @pytest.mark.asyncio
    async def test_filter_only_no_content(self) -> None:
        """Without content fetching, only path filtering is applied."""
        tree = [
            {"path": "cursor-rules.md", "type": "blob"},
            {"path": "src/main.py", "type": "blob"},
        ]
        found = await scan_repo_for_ai_rule_files(
            tree,
            fetch_content=False,
            get_file_content=None,
        )
        assert len(found) == 1
        only = found[0]
        assert only["path"] == "cursor-rules.md"
        assert only["has_keywords"] is False
        assert only["content"] is None

    @pytest.mark.asyncio
    async def test_fetch_content_sets_has_keywords(self) -> None:
        """Fetched content is returned and drives the has_keywords flag."""
        contents = {
            "cursor-rules.md": "Cursor rule: Always use type hints.",
            "docs/guidelines.md": "No AI keywords here.",
        }
        tree = [{"path": path, "type": "blob"} for path in contents]

        async def mock_get_content(path: str) -> str | None:
            return contents.get(path)

        found = await scan_repo_for_ai_rule_files(
            tree,
            fetch_content=True,
            get_file_content=mock_get_content,
        )
        assert len(found) == 2
        by_path = {item["path"]: item for item in found}
        expected_flags = {"cursor-rules.md": True, "docs/guidelines.md": False}
        for path, flag in expected_flags.items():
            assert by_path[path]["has_keywords"] is flag
            assert by_path[path]["content"] == contents[path]

    @pytest.mark.asyncio
    async def test_fetch_failure_keeps_has_keywords_false(self) -> None:
        """A failing content fetch leaves the candidate with no content and no keywords."""

        async def failing_get_content(path: str) -> str | None:
            raise OSError("Network error")

        found = await scan_repo_for_ai_rule_files(
            [{"path": "cursor-rules.md", "type": "blob"}],
            fetch_content=True,
            get_file_content=failing_get_content,
        )
        assert len(found) == 1
        only = found[0]
        assert only["path"] == "cursor-rules.md"
        assert only["has_keywords"] is False
        assert only["content"] is None