From 937dc4de3e6b0bbe61a8f5d3c4aa6c057f5c373d Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Tue, 27 Jan 2026 14:29:03 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- pyproject.toml | 1 + src/ghstack/github.py | 41 +++++ src/ghstack/github_fake.py | 83 ++++++++++ src/ghstack/merge_rules.py | 311 +++++++++++++++++++++++++++++++++++++ uv.lock | 2 + 5 files changed, 438 insertions(+) create mode 100644 src/ghstack/merge_rules.py diff --git a/pyproject.toml b/pyproject.toml index 7715388..4b9cb00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "typing-extensions>=3", "click<9,>=8", "flake8<8.0.0,>=7.0.0", + "pyyaml<7,>=6", ] name = "ghstack" version = "0.14.0" diff --git a/src/ghstack/github.py b/src/ghstack/github.py index 31343ca..d813f89 100644 --- a/src/ghstack/github.py +++ b/src/ghstack/github.py @@ -95,3 +95,44 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: Returns: parsed JSON response """ pass + + # Merge rules related API methods + + def get_pr_reviews(self, owner: str, repo: str, number: int) -> Any: + """Get reviews for a pull request.""" + return self.get(f"repos/{owner}/{repo}/pulls/{number}/reviews") + + def get_pr_files(self, owner: str, repo: str, number: int) -> Any: + """Get files changed in a pull request.""" + return self.get(f"repos/{owner}/{repo}/pulls/{number}/files") + + def get_check_runs(self, owner: str, repo: str, ref: str) -> Any: + """Get check runs for a commit ref.""" + return self.get(f"repos/{owner}/{repo}/commits/{ref}/check-runs") + + def get_team_members(self, org: str, team_slug: str) -> Any: + """Get members of a team.""" + return self.get(f"orgs/{org}/teams/{team_slug}/members") + + def get_file_contents( + self, owner: str, repo: str, path: str, ref: str = "HEAD" + ) -> str: + """ + Get the contents of a file from the repository. + + Returns the decoded file contents as a string. + """ + import base64 + + result = self.get(f"repos/{owner}/{repo}/contents/{path}?ref={ref}") + content = result.get("content", "") + encoding = result.get("encoding", "") + if encoding == "base64": + return base64.b64decode(content).decode("utf-8") + return content + + def post_issue_comment( + self, owner: str, repo: str, number: int, body: str + ) -> Any: + """Post a comment on an issue or pull request.""" + return self.post(f"repos/{owner}/{repo}/issues/{number}/comments", body=body) diff --git a/src/ghstack/github_fake.py b/src/ghstack/github_fake.py index 622b6af..7b2ec4d 100644 --- a/src/ghstack/github_fake.py +++ b/src/ghstack/github_fake.py @@ -257,6 +257,19 @@ def repository(self, info: GraphQLResolveInfo) -> Repository: return github_state(info).repositories[self._repository] +@dataclass +class PullRequestReview: + user: str + state: str # APPROVED, CHANGES_REQUESTED, COMMENTED, etc. + + +@dataclass +class CheckRun: + name: str + status: str # queued, in_progress, completed + conclusion: Optional[str] # success, failure, neutral, cancelled, skipped, etc. + + @dataclass class PullRequest(Node): baseRef: Optional[Ref] @@ -274,6 +287,10 @@ class PullRequest(Node): url: str reviewers: List[str] = dataclasses.field(default_factory=list) labels: List[str] = dataclasses.field(default_factory=list) + # Merge rules related fields + files: List[str] = dataclasses.field(default_factory=list) + reviews: List[PullRequestReview] = dataclasses.field(default_factory=list) + check_runs: List[CheckRun] = dataclasses.field(default_factory=list) def repository(self, info: GraphQLResolveInfo) -> Repository: return github_state(info).repositories[self._repository] @@ -464,6 +481,72 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: # For now, pretend all branches are not protected raise ghstack.github.NotFoundError() + # GET /repos/{owner}/{repo}/pulls/{number}/reviews + if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)/reviews$", path): + state = self.state + repo = state.repository(m.group(1), m.group(2)) + pr = state.pull_request(repo, GitHubNumber(int(m.group(3)))) + return [ + {"user": {"login": r.user}, "state": r.state} + for r in pr.reviews + ] + + # GET /repos/{owner}/{repo}/pulls/{number}/files + if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)/files$", path): + state = self.state + repo = state.repository(m.group(1), m.group(2)) + pr = state.pull_request(repo, GitHubNumber(int(m.group(3)))) + return [{"filename": f} for f in pr.files] + + # GET /repos/{owner}/{repo}/pulls/{number} + if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)$", path): + state = self.state + repo = state.repository(m.group(1), m.group(2)) + pr = state.pull_request(repo, GitHubNumber(int(m.group(3)))) + head_sha = "" + if pr.headRef: + head_sha = pr.headRef.target.oid + return { + "number": pr.number, + "title": pr.title, + "body": pr.body, + "head": {"sha": head_sha}, + "base": {"ref": pr.baseRefName}, + } + + # GET /repos/{owner}/{repo}/commits/{ref}/check-runs + if m := re.match(r"^repos/([^/]+)/([^/]+)/commits/([^/]+)/check-runs$", path): + # For the fake endpoint, we need to find the PR by head SHA + # and return its check runs + state = self.state + ref = m.group(3) + # Search for PR with matching head ref + for pr in state.pull_requests.values(): + if pr.headRef and pr.headRef.target.oid == ref: + return { + "total_count": len(pr.check_runs), + "check_runs": [ + { + "name": c.name, + "status": c.status, + "conclusion": c.conclusion, + } + for c in pr.check_runs + ], + } + # No matching PR found + return {"total_count": 0, "check_runs": []} + + # GET /orgs/{org}/teams/{team_slug}/members + if m := re.match(r"^orgs/([^/]+)/teams/([^/]+)/members$", path): + # Return empty list for fake endpoint + return [] + + # GET /repos/{owner}/{repo}/contents/{path} + if m := re.match(r"^repos/([^/]+)/([^/]+)/contents/(.+?)(?:\?ref=(.+))?$", path): + # Return a NotFoundError for the fake endpoint + raise ghstack.github.NotFoundError() + elif method == "post": if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path): return self._create_pull( diff --git a/src/ghstack/merge_rules.py b/src/ghstack/merge_rules.py new file mode 100644 index 0000000..edc5bba --- /dev/null +++ b/src/ghstack/merge_rules.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 + +""" +Merge Rules Engine for ghstack land. + +This module provides functionality to load, parse, and validate merge rules +that control when PRs can be landed. Rules specify required approvers and +CI checks based on file patterns. +""" + +import fnmatch +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Set + +import yaml + +import ghstack.github + + +class MergeValidationError(RuntimeError): + """Raised when merge validation fails.""" + + def __init__(self, result: "ValidationResult"): + self.result = result + super().__init__(self._format_message()) + + def _format_message(self) -> str: + lines = [f"Merge validation failed for PR #{self.result.pr_number}"] + if self.result.rule_name: + lines.append(f"Rule: {self.result.rule_name}") + if self.result.errors: + lines.append("Errors:") + for error in self.result.errors: + lines.append(f" - {error}") + return "\n".join(lines) + + +@dataclass +class MergeRule: + """Represents a single merge rule configuration.""" + + name: str + patterns: List[str] + approved_by: List[str] + mandatory_checks_name: List[str] + ignore_flaky_failures: bool = False + + +@dataclass +class ValidationResult: + """Result of validating a PR against merge rules.""" + + valid: bool + pr_number: int + rule_name: Optional[str] = None + errors: List[str] = field(default_factory=list) + matched_files: List[str] = field(default_factory=list) + + +class MergeRulesLoader: + """Loads merge rules from repository or local files.""" + + def __init__( + self, + github: "ghstack.github.GitHubEndpoint", + owner: str, + repo: str, + ): + self.github = github + self.owner = owner + self.repo = repo + + def load_from_repo(self, ref: str = "HEAD") -> List[MergeRule]: + """Load merge rules from the repository's .github/merge_rules.yaml file.""" + try: + content = self.github.get_file_contents( + self.owner, self.repo, ".github/merge_rules.yaml", ref + ) + return self._parse_yaml(content) + except ghstack.github.NotFoundError: + logging.debug("No merge_rules.yaml found in repository") + return [] + except Exception as e: + logging.warning(f"Failed to load merge rules: {e}") + return [] + + def load_from_file(self, path: str) -> List[MergeRule]: + """Load merge rules from a local file path.""" + with open(path, encoding="utf-8") as f: + content = f.read() + return self._parse_yaml(content) + + def _parse_yaml(self, content: str) -> List[MergeRule]: + """Parse YAML content into a list of MergeRule objects.""" + data = yaml.safe_load(content) + if not isinstance(data, list): + raise ValueError("merge_rules.yaml must be a list of rules") + + rules = [] + for item in data: + rule = MergeRule( + name=item.get("name", "unnamed"), + patterns=item.get("patterns", []), + approved_by=item.get("approved_by", []), + mandatory_checks_name=item.get("mandatory_checks_name", []), + ignore_flaky_failures=item.get("ignore_flaky_failures", False), + ) + rules.append(rule) + return rules + + +class MergeValidator: + """Validates PRs against merge rules.""" + + def __init__( + self, + github: "ghstack.github.GitHubEndpoint", + owner: str, + repo: str, + ): + self.github = github + self.owner = owner + self.repo = repo + self._team_cache: Dict[str, Set[str]] = {} + + def get_pr_files(self, pr_number: int) -> List[str]: + """Get list of files changed in a PR.""" + files = self.github.get_pr_files(self.owner, self.repo, pr_number) + return [f["filename"] for f in files] + + def get_pr_approvers(self, pr_number: int) -> Set[str]: + """Get set of users who have approved the PR.""" + reviews = self.github.get_pr_reviews(self.owner, self.repo, pr_number) + approvers: Set[str] = set() + + # Track the latest review state for each user + user_states: Dict[str, str] = {} + for review in reviews: + user = review.get("user", {}).get("login", "") + state = review.get("state", "") + if user and state: + user_states[user] = state + + # Only count users whose latest review is APPROVED + for user, state in user_states.items(): + if state == "APPROVED": + approvers.add(user) + + return approvers + + def get_pr_check_statuses(self, pr_number: int) -> Dict[str, str]: + """Get CI check statuses for a PR's head commit.""" + # First get the PR to find the head SHA + pr_info = self.github.get( + f"repos/{self.owner}/{self.repo}/pulls/{pr_number}" + ) + head_sha = pr_info.get("head", {}).get("sha", "") + if not head_sha: + return {} + + check_runs = self.github.get_check_runs(self.owner, self.repo, head_sha) + statuses: Dict[str, str] = {} + for check in check_runs.get("check_runs", []): + name = check.get("name", "") + conclusion = check.get("conclusion") + status = check.get("status", "") + # Use conclusion if available, otherwise use status + if conclusion: + statuses[name] = conclusion + else: + statuses[name] = status + return statuses + + def expand_team_members(self, team_ref: str) -> Set[str]: + """ + Expand a team reference (org/team-slug) to its members. + + Returns an empty set if the reference isn't a team or if the + API call fails. + """ + if "/" not in team_ref: + # Not a team reference, return as single user + return {team_ref} + + if team_ref in self._team_cache: + return self._team_cache[team_ref] + + try: + org, team_slug = team_ref.split("/", 1) + members = self.github.get_team_members(org, team_slug) + member_logins = {m.get("login", "") for m in members if m.get("login")} + self._team_cache[team_ref] = member_logins + return member_logins + except Exception as e: + logging.warning(f"Failed to expand team {team_ref}: {e}") + # Return as single user if team expansion fails + return {team_ref} + + def find_matching_rule( + self, files: List[str], rules: List[MergeRule] + ) -> Optional[MergeRule]: + """ + Find the first rule that matches any of the given files. + + Rules are matched in order - first matching rule wins. + """ + for rule in rules: + for file_path in files: + for pattern in rule.patterns: + if fnmatch.fnmatch(file_path, pattern): + return rule + return None + + def validate_pr( + self, pr_number: int, rules: List[MergeRule] + ) -> ValidationResult: + """ + Validate a PR against the provided merge rules. + + Returns a ValidationResult indicating whether the PR passes + all required checks and approvals. + """ + files = self.get_pr_files(pr_number) + + if not files: + return ValidationResult( + valid=True, + pr_number=pr_number, + rule_name=None, + errors=[], + matched_files=[], + ) + + rule = self.find_matching_rule(files, rules) + if rule is None: + # No matching rule, PR passes by default + return ValidationResult( + valid=True, + pr_number=pr_number, + rule_name=None, + errors=[], + matched_files=files, + ) + + errors: List[str] = [] + + # Validate approvers + if rule.approved_by: + if "any" not in rule.approved_by: + approvers = self.get_pr_approvers(pr_number) + required_approvers: Set[str] = set() + + for approver_ref in rule.approved_by: + required_approvers.update(self.expand_team_members(approver_ref)) + + if not approvers.intersection(required_approvers): + missing = ", ".join(sorted(rule.approved_by)) + errors.append(f"Missing required approval from: {missing}") + + # Validate CI checks + if rule.mandatory_checks_name: + check_statuses = self.get_pr_check_statuses(pr_number) + + for check_name in rule.mandatory_checks_name: + status = check_statuses.get(check_name, "missing") + if status == "missing": + errors.append(f'Check "{check_name}" has not run') + elif status not in ("success", "neutral", "skipped"): + if status == "in_progress" or status == "queued": + errors.append( + f'Check "{check_name}" has not completed (status: {status})' + ) + elif rule.ignore_flaky_failures: + logging.info( + f'Check "{check_name}" failed but ignore_flaky_failures is set' + ) + else: + errors.append( + f'Check "{check_name}" has not passed (status: {status})' + ) + + return ValidationResult( + valid=len(errors) == 0, + pr_number=pr_number, + rule_name=rule.name, + errors=errors, + matched_files=files, + ) + + +def format_validation_error_comment(result: ValidationResult) -> str: + """Format a validation result as a markdown comment for posting to GitHub.""" + lines = [f"## Merge validation failed for PR #{result.pr_number}"] + + if result.rule_name: + lines.append(f"\n**Rule:** {result.rule_name}") + + if result.errors: + lines.append("\n### Errors:") + for error in result.errors: + lines.append(f"- {error}") + + if result.matched_files: + lines.append("\n### Matched Files:") + for file_path in result.matched_files[:10]: # Limit to first 10 + lines.append(f"- `{file_path}`") + if len(result.matched_files) > 10: + lines.append(f"- ... and {len(result.matched_files) - 10} more files") + + return "\n".join(lines) diff --git a/uv.lock b/uv.lock index b25d84e..3d30c27 100644 --- a/uv.lock +++ b/uv.lock @@ -414,6 +414,7 @@ dependencies = [ { name = "aiohttp" }, { name = "click" }, { name = "flake8" }, + { name = "pyyaml" }, { name = "requests" }, { name = "typing-extensions" }, ] @@ -440,6 +441,7 @@ requires-dist = [ { name = "click", specifier = ">=8,<9" }, { name = "flake8", specifier = ">=7.0.0,<8.0.0" }, { name = "importlib-metadata", marker = "python_full_version < '3.8'", specifier = ">=1.4" }, + { name = "pyyaml", specifier = ">=6,<7" }, { name = "requests", specifier = ">=2,<3" }, { name = "typing-extensions", specifier = ">=3" }, ] From 4d5f7aac27df640400e23c8f95a3fb3de55ad709 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Tue, 27 Jan 2026 14:31:42 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- src/ghstack/github.py | 4 +--- src/ghstack/github_fake.py | 11 +++++++---- src/ghstack/merge_rules.py | 8 ++------ 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/ghstack/github.py b/src/ghstack/github.py index d813f89..b7b1f11 100644 --- a/src/ghstack/github.py +++ b/src/ghstack/github.py @@ -131,8 +131,6 @@ def get_file_contents( return base64.b64decode(content).decode("utf-8") return content - def post_issue_comment( - self, owner: str, repo: str, number: int, body: str - ) -> Any: + def post_issue_comment(self, owner: str, repo: str, number: int, body: str) -> Any: """Post a comment on an issue or pull request.""" return self.post(f"repos/{owner}/{repo}/issues/{number}/comments", body=body) diff --git a/src/ghstack/github_fake.py b/src/ghstack/github_fake.py index 7b2ec4d..7f6e6f7 100644 --- a/src/ghstack/github_fake.py +++ b/src/ghstack/github_fake.py @@ -487,8 +487,7 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: repo = state.repository(m.group(1), m.group(2)) pr = state.pull_request(repo, GitHubNumber(int(m.group(3)))) return [ - {"user": {"login": r.user}, "state": r.state} - for r in pr.reviews + {"user": {"login": r.user}, "state": r.state} for r in pr.reviews ] # GET /repos/{owner}/{repo}/pulls/{number}/files @@ -515,7 +514,9 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: } # GET /repos/{owner}/{repo}/commits/{ref}/check-runs - if m := re.match(r"^repos/([^/]+)/([^/]+)/commits/([^/]+)/check-runs$", path): + if m := re.match( + r"^repos/([^/]+)/([^/]+)/commits/([^/]+)/check-runs$", path + ): # For the fake endpoint, we need to find the PR by head SHA # and return its check runs state = self.state @@ -543,7 +544,9 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any: return [] # GET /repos/{owner}/{repo}/contents/{path} - if m := re.match(r"^repos/([^/]+)/([^/]+)/contents/(.+?)(?:\?ref=(.+))?$", path): + if m := re.match( + r"^repos/([^/]+)/([^/]+)/contents/(.+?)(?:\?ref=(.+))?$", path + ): # Return a NotFoundError for the fake endpoint raise ghstack.github.NotFoundError() diff --git a/src/ghstack/merge_rules.py b/src/ghstack/merge_rules.py index edc5bba..06a928f 100644 --- a/src/ghstack/merge_rules.py +++ b/src/ghstack/merge_rules.py @@ -152,9 +152,7 @@ def get_pr_approvers(self, pr_number: int) -> Set[str]: def get_pr_check_statuses(self, pr_number: int) -> Dict[str, str]: """Get CI check statuses for a PR's head commit.""" # First get the PR to find the head SHA - pr_info = self.github.get( - f"repos/{self.owner}/{self.repo}/pulls/{pr_number}" - ) + pr_info = self.github.get(f"repos/{self.owner}/{self.repo}/pulls/{pr_number}") head_sha = pr_info.get("head", {}).get("sha", "") if not head_sha: return {} @@ -212,9 +210,7 @@ def find_matching_rule( return rule return None - def validate_pr( - self, pr_number: int, rules: List[MergeRule] - ) -> ValidationResult: + def validate_pr(self, pr_number: int, rules: List[MergeRule]) -> ValidationResult: """ Validate a PR against the provided merge rules.