From ba45ced79c92e1b0e2bc58fed4fb04530c563fe5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 20 Jan 2026 17:22:26 +0000 Subject: [PATCH 1/2] more defensive --- conftest.py | 2 +- src/ghstack/clean.py | 344 ++++++++++++++++++++++++++++++++++++ src/ghstack/cli.py | 43 +++++ src/ghstack/github_fake.py | 24 ++- src/ghstack/test_prelude.py | 36 +++- test/clean/basic.py.test | 51 ++++++ 6 files changed, 491 insertions(+), 9 deletions(-) create mode 100644 src/ghstack/clean.py create mode 100644 test/clean/basic.py.test diff --git a/conftest.py b/conftest.py index 4e5f519..9dc7d24 100644 --- a/conftest.py +++ b/conftest.py @@ -23,7 +23,7 @@ def pytest_collect_file(file_path: pathlib.Path, parent): class Script(pytest.File): def collect(self): yield ScriptItem.from_parent(self, name="default", direct=False) - if self.path.parent.name in ["submit", "unlink"]: + if self.path.parent.name in ["submit", "unlink", "clean"]: yield ScriptItem.from_parent(self, name="direct", direct=True) diff --git a/src/ghstack/clean.py b/src/ghstack/clean.py new file mode 100644 index 0000000..db828e6 --- /dev/null +++ b/src/ghstack/clean.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 + +import logging +import re +from typing import Dict, List, Optional, Tuple + +import click + +import ghstack.diff +import ghstack.github +import ghstack.github_utils +import ghstack.shell +from ghstack.types import GhNumber + + +# Regex to match ghstack branch names: gh/{username}/{number}/{kind} +RE_GHSTACK_BRANCH = re.compile(r"^(?:refs/(?:heads|remotes/[^/]+)/)?gh/([^/]+)/([0-9]+)/(.+)$") + + +def parse_ghstack_branch(ref: str) -> Optional[Tuple[str, GhNumber, str]]: + """ + Parse a ghstack branch reference. + + Returns (username, ghnum, kind) if it's a valid ghstack branch, None otherwise. + """ + m = RE_GHSTACK_BRANCH.match(ref) + if m: + return (m.group(1), GhNumber(m.group(2)), m.group(3)) + return None + + +def get_pr_number_for_ghnum( + sh: ghstack.shell.Shell, + remote_name: str, + username: str, + ghnum: GhNumber, + github_url: str, +) -> Optional[int]: + """ + Get the GitHub PR number associated with a ghstack ghnum by reading the orig branch. + + Returns the PR number if found, None otherwise. + """ + orig_ref = f"{remote_name}/gh/{username}/{ghnum}/orig" + + try: + # Try to get the commit message from the orig branch + commit_msg = sh.git("log", "-1", "--format=%B", orig_ref) + except RuntimeError: + # Branch doesn't exist or can't be read + return None + + # Use ghstack's own PullRequestResolved.search() to find the PR + # This handles all formats: "Pull Request resolved:", "Pull-Request-resolved:", "Pull-Request:" + # as well as the legacy "gh-metadata:" format + pr_resolved = ghstack.diff.PullRequestResolved.search(commit_msg, github_url) + if pr_resolved is not None: + return int(pr_resolved.number) + + return None + + +def find_pr_by_head_ref( + github: ghstack.github.GitHubEndpoint, + repo_owner: str, + repo_name: str, + head_ref: str, +) -> Optional[Tuple[int, bool]]: + """ + Find a PR by its head ref name. + + Returns (pr_number, is_closed) if found, None if no PR exists for this head ref. + Raises on API errors (fail loudly). + """ + # Query for PRs with this head ref + result = github.graphql( + """ + query ($owner: String!, $name: String!, $headRefName: String!) { + repository(name: $name, owner: $owner) { + pullRequests(headRefName: $headRefName, first: 1) { + nodes { + number + closed + } + } + } + } + """, + owner=repo_owner, + name=repo_name, + headRefName=head_ref, + ) + prs = result["data"]["repository"]["pullRequests"]["nodes"] + if not prs: + return None + pr = prs[0] + return (pr["number"], pr["closed"]) + + +def check_pr_closed( + github: ghstack.github.GitHubEndpoint, + repo_owner: str, + repo_name: str, + pr_number: int, +) -> bool: + """ + Check if a PR is closed. + + Returns True if the PR is closed, False if it's open. + Raises if the PR doesn't exist or on API errors (fail loudly). + """ + result = github.graphql( + """ + query ($owner: String!, $name: String!, $number: Int!) { + repository(name: $name, owner: $owner) { + pullRequest(number: $number) { + closed + } + } + } + """, + owner=repo_owner, + name=repo_name, + number=pr_number, + ) + pr = result["data"]["repository"]["pullRequest"] + if pr is None: + # PR doesn't exist - treat as closed (it was deleted) + return True + return pr["closed"] + + +def main( + *, + github: ghstack.github.GitHubEndpoint, + sh: ghstack.shell.Shell, + github_url: str, + remote_name: str, + dry_run: bool = False, + clean_local: bool = False, + username: Optional[str] = None, + repo_owner: Optional[str] = None, + repo_name: Optional[str] = None, + force: bool = False, +) -> List[str]: + """ + Clean up orphan ghstack branches. + + An orphan branch is a ghstack-managed branch whose associated PR has been + closed (either merged or manually closed). + + Args: + github: GitHub API endpoint + sh: Shell for executing git commands + github_url: GitHub URL (e.g., 'github.com') + remote_name: Name of the remote (e.g., 'origin') + dry_run: If True, list branches without deleting + clean_local: If True, also prune local tracking branches + username: If provided, only clean branches for this user + repo_owner: Repository owner (inferred from remote if not provided) + repo_name: Repository name (inferred from remote if not provided) + force: If True, skip confirmation prompt + + Returns: + List of branch names that were deleted (or would be deleted in dry-run mode) + """ + # Get repo info if not provided + if repo_owner is None or repo_name is None: + repo_info = ghstack.github_utils.get_github_repo_info( + github=github, + sh=sh, + repo_owner=repo_owner, + repo_name=repo_name, + github_url=github_url, + remote_name=remote_name, + ) + repo_owner = repo_info["name_with_owner"]["owner"] + repo_name = repo_info["name_with_owner"]["name"] + + # Fetch latest state from remote + logging.info(f"Fetching from {remote_name}...") + sh.git("fetch", "--prune", remote_name) + + # List all ghstack branches on the remote + refs_output = sh.git( + "for-each-ref", + f"refs/remotes/{remote_name}/gh/", + "--format=%(refname)", + ) + + if not refs_output.strip(): + logging.info("No ghstack branches found.") + return [] + + refs = refs_output.strip().split("\n") + + # Group branches by (username, ghnum) + branches_by_ghnum: Dict[Tuple[str, GhNumber], List[str]] = {} + + for ref in refs: + parsed = parse_ghstack_branch(ref) + if parsed is None: + continue + + branch_username, ghnum, kind = parsed + + # Filter by username if specified + if username is not None and branch_username != username: + continue + + key = (branch_username, ghnum) + if key not in branches_by_ghnum: + branches_by_ghnum[key] = [] + + # Extract just the branch name without refs/remotes/{remote}/ + branch_name = f"gh/{branch_username}/{ghnum}/{kind}" + branches_by_ghnum[key].append(branch_name) + + if not branches_by_ghnum: + if username: + logging.info(f"No ghstack branches found for user '{username}'.") + else: + logging.info("No ghstack branches found.") + return [] + + logging.info(f"Found {len(branches_by_ghnum)} ghstack PR(s) to check...") + + # Check which PRs are closed + # Cache pr_number -> is_closed to handle multiple ghnums mapping to same PR + pr_closed_cache: Dict[int, bool] = {} + orphan_branches: List[str] = [] + + for (branch_username, ghnum), branches in branches_by_ghnum.items(): + # Get the PR number from the orig branch + pr_number = get_pr_number_for_ghnum( + sh, remote_name, branch_username, ghnum, github_url + ) + + if pr_number is None: + # Can't determine PR number from orig branch (missing or corrupted) + # Try to find PR by querying GitHub for the head ref + head_ref = f"gh/{branch_username}/{ghnum}/head" + logging.info( + f"Missing orig branch for gh/{branch_username}/{ghnum}, " + f"querying GitHub by head ref..." + ) + pr_info = find_pr_by_head_ref(github, repo_owner, repo_name, head_ref) + + if pr_info is None: + # No PR exists for this head ref - truly orphan + logging.info( + f"No PR found for gh/{branch_username}/{ghnum}, treating as orphan" + ) + orphan_branches.extend(branches) + continue + + pr_number, is_closed = pr_info + pr_closed_cache[pr_number] = is_closed + if is_closed: + logging.info(f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is closed") + orphan_branches.extend(branches) + else: + logging.debug(f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is still open") + continue + + # Check cache first (handles multiple ghnums mapping to same PR) + if pr_number in pr_closed_cache: + is_closed = pr_closed_cache[pr_number] + else: + # Query GitHub for PR status (raises on API error - fail loudly) + is_closed = check_pr_closed(github, repo_owner, repo_name, pr_number) + pr_closed_cache[pr_number] = is_closed + + if is_closed: + logging.info(f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is closed") + orphan_branches.extend(branches) + else: + logging.debug(f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is still open") + + if not orphan_branches: + logging.info("No orphan branches found.") + return [] + + # Sort branches for consistent output + orphan_branches.sort() + + # Display branches to be deleted + click.echo("\nOrphan branches that would be deleted:") + for branch in orphan_branches: + click.echo(f" {branch}") + click.echo(f"\nTotal: {len(orphan_branches)} branch(es)") + + if dry_run: + click.echo("\nRun without --dry-run to delete these branches.") + return orphan_branches + + # Confirm before deleting (unless --force is specified) + if not force: + click.echo("\n" + "=" * 60) + click.echo("WARNING: THIS OPERATION IS IRREVERSIBLE!") + click.echo("These branches will be permanently deleted from the remote.") + click.echo("=" * 60) + response = click.prompt( + "\nType 'delete' to confirm deletion", + default="", + show_default=False, + ) + if response.strip().lower() != "delete": + click.echo("Aborted. No branches were deleted.") + return [] + + # Delete branches on remote + click.echo(f"\nDeleting {len(orphan_branches)} orphan branch(es)...") + + # Delete in batches to avoid command line length limits + batch_size = 50 + deleted_branches: List[str] = [] + + for i in range(0, len(orphan_branches), batch_size): + batch = orphan_branches[i:i + batch_size] + try: + sh.git("push", remote_name, "--delete", *batch) + deleted_branches.extend(batch) + for branch in batch: + click.echo(f" Deleted: {branch}") + except RuntimeError as e: + logging.warning(f"Failed to delete some branches: {e}") + # Try deleting individually to identify which ones failed + for branch in batch: + try: + sh.git("push", remote_name, "--delete", branch) + deleted_branches.append(branch) + click.echo(f" Deleted: {branch}") + except RuntimeError: + logging.warning(f" Failed to delete: {branch}") + + # Optionally prune local tracking branches + if clean_local: + logging.info("Pruning local tracking branches...") + sh.git("fetch", "--prune", remote_name) + + click.echo(f"\nSuccessfully deleted {len(deleted_branches)} branch(es).") + + return deleted_branches diff --git a/src/ghstack/cli.py b/src/ghstack/cli.py index 1a7f4f7..f3db3b4 100644 --- a/src/ghstack/cli.py +++ b/src/ghstack/cli.py @@ -10,6 +10,7 @@ import ghstack.checkout import ghstack.cherry_pick import ghstack.circleci_real +import ghstack.clean import ghstack.config import ghstack.github_real import ghstack.land @@ -140,6 +141,48 @@ def checkout(same_base: bool, pull_request: str) -> None: ) +@main.command("clean") +@click.option( + "--dry-run", + is_flag=True, + help="List orphan branches without deleting them", +) +@click.option( + "--force", + "-f", + is_flag=True, + help="Skip confirmation prompt before deleting", +) +@click.option( + "--local", + is_flag=True, + help="Also prune local tracking branches after cleaning remote", +) +@click.option( + "--user", + default=None, + help="Only clean branches for a specific GitHub username", +) +def clean(dry_run: bool, force: bool, local: bool, user: Optional[str]) -> None: + """ + Clean up orphan ghstack branches + + Identifies and deletes ghstack-managed branches whose associated PRs + have been closed (either merged or manually closed). + """ + with cli_context() as (shell, config, github): + ghstack.clean.main( + github=github, + sh=shell, + github_url=config.github_url, + remote_name=config.remote_name, + dry_run=dry_run, + clean_local=local, + username=user, + force=force, + ) + + @main.command("cherry-pick") @click.option( "--stack", diff --git a/src/ghstack/github_fake.py b/src/ghstack/github_fake.py index 622b6af..198318f 100644 --- a/src/ghstack/github_fake.py +++ b/src/ghstack/github_fake.py @@ -206,15 +206,25 @@ def pullRequest( ) -> "PullRequest": return github_state(info).pull_request(self, number) - def pullRequests(self, info: GraphQLResolveInfo) -> "PullRequestConnection": - return PullRequestConnection( - nodes=list( - filter( - lambda pr: self == pr.repository(info), - github_state(info).pull_requests.values(), - ) + def pullRequests( + self, + info: GraphQLResolveInfo, + headRefName: Optional[str] = None, + first: Optional[int] = None, + ) -> "PullRequestConnection": + prs = list( + filter( + lambda pr: self == pr.repository(info), + github_state(info).pull_requests.values(), ) ) + # Filter by headRefName if specified + if headRefName is not None: + prs = [pr for pr in prs if pr.headRefName == headRefName] + # Limit results if first is specified + if first is not None: + prs = prs[:first] + return PullRequestConnection(nodes=prs) # TODO: This should take which repository the ref is in # This only works if you have upstream_sh diff --git a/src/ghstack/test_prelude.py b/src/ghstack/test_prelude.py index 567fadb..ca98343 100644 --- a/src/ghstack/test_prelude.py +++ b/src/ghstack/test_prelude.py @@ -14,7 +14,7 @@ import ghstack.checkout import ghstack.cherry_pick - +import ghstack.clean import ghstack.github import ghstack.github_fake import ghstack.github_utils @@ -32,6 +32,7 @@ "gh_submit", "gh_land", "gh_unlink", + "gh_clean", "gh_cherry_pick", "gh_checkout", "GitCommitHash", @@ -53,6 +54,7 @@ "get_github", "get_pr_reviewers", "get_pr_labels", + "close_pr", "tick", "captured_output", ] @@ -270,6 +272,38 @@ def gh_checkout(pull_request: str, same_base: bool = False) -> None: ) +def gh_clean( + dry_run: bool = False, + clean_local: bool = False, + username: Optional[str] = None, + force: bool = True, # Default to True in tests to skip confirmation prompt +) -> List[str]: + """Clean up orphan ghstack branches.""" + self = CTX + return ghstack.clean.main( + github=self.github, + sh=self.sh, + github_url="github.com", + remote_name="origin", + dry_run=dry_run, + clean_local=clean_local, + username=username, + repo_owner="pytorch", + repo_name="pytorch", + force=force, + ) + + +def close_pr(pr_number: int) -> None: + """Close a PR (for testing orphan branch cleanup).""" + self = CTX + repo = self.github.state.repository("pytorch", "pytorch") + pr = self.github.state.pull_request( + repo, ghstack.github_fake.GitHubNumber(pr_number) + ) + pr.closed = True + + def write_file_and_add(filename: str, contents: str) -> None: self = CTX with self.sh.open(filename, "w") as f: diff --git a/test/clean/basic.py.test b/test/clean/basic.py.test new file mode 100644 index 0000000..fedc2b6 --- /dev/null +++ b/test/clean/basic.py.test @@ -0,0 +1,51 @@ +from ghstack.test_prelude import * + +init_test() + +# Create a stack with two commits +commit("A") +commit("B") +gh_submit("Initial") + +# Verify branches exist before cleaning +upstream_sh = get_upstream_sh() +refs_before = upstream_sh.git("for-each-ref", "refs/heads/gh/", "--format=%(refname)") +assert "gh/ezyang/1/" in refs_before, f"Expected branches for PR 1, got: {refs_before}" +assert "gh/ezyang/2/" in refs_before, f"Expected branches for PR 2, got: {refs_before}" + +# Close PR #500 (first PR in the stack) +close_pr(500) + +# Run clean in dry-run mode first +with captured_output() as (out, err): + deleted_dry = gh_clean(dry_run=True) + +# Branches should still exist after dry run +refs_after_dry = upstream_sh.git("for-each-ref", "refs/heads/gh/", "--format=%(refname)") +assert "gh/ezyang/1/" in refs_after_dry, "Branches should still exist after dry run" + +# Verify dry run found the right branches +if is_direct(): + # Direct mode has head, orig, next + assert len(deleted_dry) >= 2, f"Expected at least 2 branches, got: {deleted_dry}" +else: + # Non-direct mode has head, base, orig + assert len(deleted_dry) >= 2, f"Expected at least 2 branches, got: {deleted_dry}" + +# Now actually clean +with captured_output() as (out, err): + deleted = gh_clean(dry_run=False) + +# Verify the closed PR branches were deleted +refs_after = upstream_sh.git("for-each-ref", "refs/heads/gh/", "--format=%(refname)") +assert "gh/ezyang/1/" not in refs_after, f"Expected branches for PR 1 to be deleted, got: {refs_after}" + +# Verify the open PR branches still exist +assert "gh/ezyang/2/" in refs_after, f"Expected branches for PR 2 to still exist, got: {refs_after}" + +# Clean again - should find nothing to clean +with captured_output() as (out, err): + deleted_again = gh_clean(dry_run=False) +assert len(deleted_again) == 0, f"Expected no branches to delete, got: {deleted_again}" + +ok() From 918e0a59a262b7e72c80aacbe4cf238d8e70933a Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 20 Jan 2026 17:29:49 +0000 Subject: [PATCH 2/2] linter --- src/ghstack/clean.py | 18 +++++++++++++----- test/clean/basic.py.test | 12 +++++++++--- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/ghstack/clean.py b/src/ghstack/clean.py index db828e6..4013203 100644 --- a/src/ghstack/clean.py +++ b/src/ghstack/clean.py @@ -14,7 +14,9 @@ # Regex to match ghstack branch names: gh/{username}/{number}/{kind} -RE_GHSTACK_BRANCH = re.compile(r"^(?:refs/(?:heads|remotes/[^/]+)/)?gh/([^/]+)/([0-9]+)/(.+)$") +RE_GHSTACK_BRANCH = re.compile( + r"^(?:refs/(?:heads|remotes/[^/]+)/)?gh/([^/]+)/([0-9]+)/(.+)$" +) def parse_ghstack_branch(ref: str) -> Optional[Tuple[str, GhNumber, str]]: @@ -257,10 +259,14 @@ def main( pr_number, is_closed = pr_info pr_closed_cache[pr_number] = is_closed if is_closed: - logging.info(f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is closed") + logging.info( + f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is closed" + ) orphan_branches.extend(branches) else: - logging.debug(f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is still open") + logging.debug( + f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is still open" + ) continue # Check cache first (handles multiple ghnums mapping to same PR) @@ -275,7 +281,9 @@ def main( logging.info(f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is closed") orphan_branches.extend(branches) else: - logging.debug(f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is still open") + logging.debug( + f"PR #{pr_number} (gh/{branch_username}/{ghnum}) is still open" + ) if not orphan_branches: logging.info("No orphan branches found.") @@ -317,7 +325,7 @@ def main( deleted_branches: List[str] = [] for i in range(0, len(orphan_branches), batch_size): - batch = orphan_branches[i:i + batch_size] + batch = orphan_branches[i : i + batch_size] try: sh.git("push", remote_name, "--delete", *batch) deleted_branches.extend(batch) diff --git a/test/clean/basic.py.test b/test/clean/basic.py.test index fedc2b6..3a7056f 100644 --- a/test/clean/basic.py.test +++ b/test/clean/basic.py.test @@ -21,7 +21,9 @@ with captured_output() as (out, err): deleted_dry = gh_clean(dry_run=True) # Branches should still exist after dry run -refs_after_dry = upstream_sh.git("for-each-ref", "refs/heads/gh/", "--format=%(refname)") +refs_after_dry = upstream_sh.git( + "for-each-ref", "refs/heads/gh/", "--format=%(refname)" +) assert "gh/ezyang/1/" in refs_after_dry, "Branches should still exist after dry run" # Verify dry run found the right branches @@ -38,10 +40,14 @@ with captured_output() as (out, err): # Verify the closed PR branches were deleted refs_after = upstream_sh.git("for-each-ref", "refs/heads/gh/", "--format=%(refname)") -assert "gh/ezyang/1/" not in refs_after, f"Expected branches for PR 1 to be deleted, got: {refs_after}" +assert ( + "gh/ezyang/1/" not in refs_after +), f"Expected branches for PR 1 to be deleted, got: {refs_after}" # Verify the open PR branches still exist -assert "gh/ezyang/2/" in refs_after, f"Expected branches for PR 2 to still exist, got: {refs_after}" +assert ( + "gh/ezyang/2/" in refs_after +), f"Expected branches for PR 2 to still exist, got: {refs_after}" # Clean again - should find nothing to clean with captured_output() as (out, err):