Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"typing-extensions>=3",
"click<9,>=8",
"flake8<8.0.0,>=7.0.0",
"pyyaml<7,>=6",
]
name = "ghstack"
version = "0.14.0"
Expand Down
39 changes: 39 additions & 0 deletions src/ghstack/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,42 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
Returns: parsed JSON response
"""
pass

# Merge rules related API methods

def get_pr_reviews(self, owner: str, repo: str, number: int) -> Any:
"""Get reviews for a pull request."""
return self.get(f"repos/{owner}/{repo}/pulls/{number}/reviews")

def get_pr_files(self, owner: str, repo: str, number: int) -> Any:
"""Get files changed in a pull request."""
return self.get(f"repos/{owner}/{repo}/pulls/{number}/files")

def get_check_runs(self, owner: str, repo: str, ref: str) -> Any:
"""Get check runs for a commit ref."""
return self.get(f"repos/{owner}/{repo}/commits/{ref}/check-runs")

def get_team_members(self, org: str, team_slug: str) -> Any:
"""Get members of a team."""
return self.get(f"orgs/{org}/teams/{team_slug}/members")

def get_file_contents(
self, owner: str, repo: str, path: str, ref: str = "HEAD"
) -> str:
"""
Get the contents of a file from the repository.
Returns the decoded file contents as a string.
"""
import base64

result = self.get(f"repos/{owner}/{repo}/contents/{path}?ref={ref}")
content = result.get("content", "")
encoding = result.get("encoding", "")
if encoding == "base64":
return base64.b64decode(content).decode("utf-8")
return content

def post_issue_comment(self, owner: str, repo: str, number: int, body: str) -> Any:
"""Post a comment on an issue or pull request."""
return self.post(f"repos/{owner}/{repo}/issues/{number}/comments", body=body)
86 changes: 86 additions & 0 deletions src/ghstack/github_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@ def repository(self, info: GraphQLResolveInfo) -> Repository:
return github_state(info).repositories[self._repository]


@dataclass
class PullRequestReview:
user: str
state: str # APPROVED, CHANGES_REQUESTED, COMMENTED, etc.


@dataclass
class CheckRun:
name: str
status: str # queued, in_progress, completed
conclusion: Optional[str] # success, failure, neutral, cancelled, skipped, etc.


@dataclass
class PullRequest(Node):
baseRef: Optional[Ref]
Expand All @@ -274,6 +287,10 @@ class PullRequest(Node):
url: str
reviewers: List[str] = dataclasses.field(default_factory=list)
labels: List[str] = dataclasses.field(default_factory=list)
# Merge rules related fields
files: List[str] = dataclasses.field(default_factory=list)
reviews: List[PullRequestReview] = dataclasses.field(default_factory=list)
check_runs: List[CheckRun] = dataclasses.field(default_factory=list)

def repository(self, info: GraphQLResolveInfo) -> Repository:
return github_state(info).repositories[self._repository]
Expand Down Expand Up @@ -464,6 +481,75 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
# For now, pretend all branches are not protected
raise ghstack.github.NotFoundError()

# GET /repos/{owner}/{repo}/pulls/{number}/reviews
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)/reviews$", path):
state = self.state
repo = state.repository(m.group(1), m.group(2))
pr = state.pull_request(repo, GitHubNumber(int(m.group(3))))
return [
{"user": {"login": r.user}, "state": r.state} for r in pr.reviews
]

# GET /repos/{owner}/{repo}/pulls/{number}/files
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)/files$", path):
state = self.state
repo = state.repository(m.group(1), m.group(2))
pr = state.pull_request(repo, GitHubNumber(int(m.group(3))))
return [{"filename": f} for f in pr.files]

# GET /repos/{owner}/{repo}/pulls/{number}
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)$", path):
state = self.state
repo = state.repository(m.group(1), m.group(2))
pr = state.pull_request(repo, GitHubNumber(int(m.group(3))))
head_sha = ""
if pr.headRef:
head_sha = pr.headRef.target.oid
return {
"number": pr.number,
"title": pr.title,
"body": pr.body,
"head": {"sha": head_sha},
"base": {"ref": pr.baseRefName},
}

# GET /repos/{owner}/{repo}/commits/{ref}/check-runs
if m := re.match(
r"^repos/([^/]+)/([^/]+)/commits/([^/]+)/check-runs$", path
):
# For the fake endpoint, we need to find the PR by head SHA
# and return its check runs
state = self.state
ref = m.group(3)
# Search for PR with matching head ref
for pr in state.pull_requests.values():
if pr.headRef and pr.headRef.target.oid == ref:
return {
"total_count": len(pr.check_runs),
"check_runs": [
{
"name": c.name,
"status": c.status,
"conclusion": c.conclusion,
}
for c in pr.check_runs
],
}
# No matching PR found
return {"total_count": 0, "check_runs": []}

# GET /orgs/{org}/teams/{team_slug}/members
if m := re.match(r"^orgs/([^/]+)/teams/([^/]+)/members$", path):
# Return empty list for fake endpoint
return []

# GET /repos/{owner}/{repo}/contents/{path}
if m := re.match(
r"^repos/([^/]+)/([^/]+)/contents/(.+?)(?:\?ref=(.+))?$", path
):
# Return a NotFoundError for the fake endpoint
raise ghstack.github.NotFoundError()

elif method == "post":
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path):
return self._create_pull(
Expand Down
Loading