From 28a86448b7a321e0769ed7247e092cb7f0994c80 Mon Sep 17 00:00:00 2001 From: Julien Danjou Date: Wed, 25 Feb 2026 09:04:27 +0100 Subject: [PATCH] refactor(ci): validate GitHub event payloads with Pydantic models Replace untyped dict manipulation of GitHub webhook payloads with Pydantic models, parsing once at the boundary in get_github_event() and passing typed objects throughout. This removes ~40 lines of manual isinstance/get checks and catches schema violations early. Co-Authored-By: Claude Opus 4.6 Change-Id: I06f75cbee3238d0418560159c2888c70560a1e46 Claude-Session-Id: fbc3e1a8-2cec-43d2-8105-298313058e7a --- mergify_cli/ci/detector.py | 18 +-- mergify_cli/ci/git_refs/detector.py | 136 +++++++----------- mergify_cli/ci/github_event.py | 36 +++++ mergify_cli/ci/queue/metadata.py | 19 ++- .../ci/git_refs/test_git_refs_detector.py | 31 ++++ mergify_cli/tests/ci/push_event.json | 46 ++++++ mergify_cli/tests/ci/queue/test_metadata.py | 41 ++++++ mergify_cli/tests/ci/test_cli.py | 2 + mergify_cli/tests/ci/test_github_event.py | 75 ++++++++++ mergify_cli/tests/test_utils.py | 3 +- mergify_cli/utils.py | 9 +- 11 files changed, 310 insertions(+), 106 deletions(-) create mode 100644 mergify_cli/ci/github_event.py create mode 100644 mergify_cli/tests/ci/push_event.json create mode 100644 mergify_cli/tests/ci/test_github_event.py diff --git a/mergify_cli/ci/detector.py b/mergify_cli/ci/detector.py index a96d8b68..575ef8b8 100644 --- a/mergify_cli/ci/detector.py +++ b/mergify_cli/ci/detector.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import os import pathlib import re @@ -84,11 +83,13 @@ def get_head_ref_name() -> str | None: def get_github_actions_head_sha() -> str | None: if os.getenv("GITHUB_EVENT_NAME") == "pull_request": - # NOTE(leo): we want the head sha of pull request - event_raw_path = os.getenv("GITHUB_EVENT_PATH") - if event_raw_path and ((event_path := pathlib.Path(event_raw_path)).is_file()): - event = json.loads(event_path.read_bytes()) - return str(event["pull_request"]["head"]["sha"]) + try: + _, event = utils.get_github_event() + except utils.GitHubEventNotFoundError: + pass + else: + if event.pull_request and event.pull_request.head: + return event.pull_request.head.sha return os.getenv("GITHUB_SHA") @@ -193,10 +194,9 @@ def get_github_pull_request_number() -> int | None: _, event = utils.get_github_event() except utils.GitHubEventNotFoundError: return None - pr = event.get("pull_request") - if not isinstance(pr, dict): + if event.pull_request is None: return None - return typing.cast("int", pr["number"]) + return event.pull_request.number case _: return None diff --git a/mergify_cli/ci/git_refs/detector.py b/mergify_cli/ci/git_refs/detector.py index deb29a10..07070190 100644 --- a/mergify_cli/ci/git_refs/detector.py +++ b/mergify_cli/ci/git_refs/detector.py @@ -10,6 +10,10 @@ from mergify_cli.ci.scopes import exceptions +if typing.TYPE_CHECKING: + from mergify_cli.ci import github_event + + GITHUB_ACTIONS_BASE_OUTPUT_NAME = "base" GITHUB_ACTIONS_HEAD_OUTPUT_NAME = "head" @@ -18,55 +22,6 @@ class BaseNotFoundError(exceptions.ScopesError): pass -def _detect_base_from_merge_queue_payload(ev: dict[str, typing.Any]) -> str | None: - content = queue_metadata.extract_from_event(ev) - if content: - return content["checking_base_sha"] - return None - - -def _detect_head_from_event(ev: dict[str, typing.Any]) -> str | None: - pr = ev.get("pull_request") - if isinstance(pr, dict): - sha = pr.get("head", {}).get("sha") - if isinstance(sha, str) and sha: - return sha - - return None - - -def _detect_base_from_event(ev: dict[str, typing.Any]) -> str | None: - pr = ev.get("pull_request") - if isinstance(pr, dict): - sha = pr.get("base", {}).get("sha") - if isinstance(sha, str) and sha: - return sha - return None - - -def _detect_default_branch_from_event(ev: dict[str, typing.Any]) -> str | None: - repo = ev.get("repository") - if isinstance(repo, dict): - sha = repo.get("default_branch") - if isinstance(sha, str) and sha: - return sha - return None - - -def _detect_head_from_push_event(ev: dict[str, typing.Any]) -> str | None: - sha = ev.get("after") - if isinstance(sha, str) and sha: - return sha - return None - - -def _detect_base_from_push_event(ev: dict[str, typing.Any]) -> str | None: - sha = ev.get("before") - if isinstance(sha, str) and sha: - return sha - return None - - ReferencesSource = typing.Literal[ "manual", "merge_queue", @@ -92,46 +47,63 @@ def maybe_write_to_github_outputs(self) -> None: fh.write(f"{GITHUB_ACTIONS_HEAD_OUTPUT_NAME}={self.head}\n") +def _detect_from_pull_request_event( + ev: github_event.GitHubEvent, +) -> References | None: + head = "HEAD" + if ev.pull_request and ev.pull_request.head: + head = ev.pull_request.head.sha + + # 0) merge-queue PR override + content = queue_metadata.extract_from_event(ev) + if content: + return References(content["checking_base_sha"], head, "merge_queue") + + # 1) standard event payload + if ev.pull_request and ev.pull_request.base: + return References(ev.pull_request.base.sha, head, "github_event_pull_request") + + # 2) repository default branch fallback + if ev.repository and ev.repository.default_branch: + return References( + ev.repository.default_branch, + head, + "github_event_pull_request", + ) + + return None + + +def _detect_from_push_event(ev: github_event.GitHubEvent) -> References | None: + head_sha = ev.after or "HEAD" + if ev.before: + return References(ev.before, head_sha, "github_event_push") + + if ev.repository and ev.repository.default_branch: + return References(ev.repository.default_branch, "HEAD", "github_event_push") + + return None + + def detect() -> References: try: event_name, event = utils.get_github_event() except utils.GitHubEventNotFoundError: # fallback to last commit return References("HEAD^", "HEAD", "fallback_last_commit") + + if event_name in queue_metadata.PULL_REQUEST_EVENTS: + result = _detect_from_pull_request_event(event) + if result: + return result + + elif event_name == "push": + result = _detect_from_push_event(event) + if result: + return result + else: - if event_name in queue_metadata.PULL_REQUEST_EVENTS: - head = _detect_head_from_event(event) or "HEAD" - # 0) merge-queue PR override - mq_sha = _detect_base_from_merge_queue_payload(event) - if mq_sha: - return References(mq_sha, head, "merge_queue") - - # 1) standard event payload - event_sha = _detect_base_from_event(event) - if event_sha: - return References(event_sha, head, "github_event_pull_request") - - # 2) standard event payload - event_sha = _detect_default_branch_from_event(event) - if event_sha: - return References( - event_sha, - head, - "github_event_pull_request", - ) - - elif event_name == "push": - head_sha = _detect_head_from_push_event(event) or "HEAD" - base_sha = _detect_base_from_push_event(event) - if base_sha: - return References(base_sha, head_sha, "github_event_push") - - event_sha = _detect_default_branch_from_event(event) - if event_sha: - return References(event_sha, "HEAD", "github_event_push") - - else: - return References(None, "HEAD", "github_event_other") + return References(None, "HEAD", "github_event_other") msg = "Could not detect base SHA. Provide GITHUB_EVENT_NAME / GITHUB_EVENT_PATH." raise BaseNotFoundError(msg) diff --git a/mergify_cli/ci/github_event.py b/mergify_cli/ci/github_event.py new file mode 100644 index 00000000..29c1f006 --- /dev/null +++ b/mergify_cli/ci/github_event.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import pydantic + + +class GitRef(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra="ignore") + + sha: str + ref: str | None = None + + +class PullRequest(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra="ignore") + + number: int + title: str = "" + body: str | None = None + base: GitRef | None = None + head: GitRef | None = None + + +class Repository(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra="ignore") + + default_branch: str | None = None + + +class GitHubEvent(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra="ignore") + + pull_request: PullRequest | None = None + repository: Repository | None = None + # push events + before: str | None = None + after: str | None = None diff --git a/mergify_cli/ci/queue/metadata.py b/mergify_cli/ci/queue/metadata.py index bec9650b..1a7215b6 100644 --- a/mergify_cli/ci/queue/metadata.py +++ b/mergify_cli/ci/queue/metadata.py @@ -8,6 +8,10 @@ from mergify_cli import utils +if typing.TYPE_CHECKING: + from mergify_cli.ci import github_event + + class MergeQueuePullRequest(typing.TypedDict): number: int @@ -46,23 +50,18 @@ def _yaml_docs_from_fenced_blocks(body: str) -> MergeQueueMetadata | None: return None -def extract_from_event(ev: dict[str, typing.Any]) -> MergeQueueMetadata | None: - pr = ev.get("pull_request") - if not isinstance(pr, dict): - return None - title = pr.get("title") or "" - if not isinstance(title, str): +def extract_from_event(ev: github_event.GitHubEvent) -> MergeQueueMetadata | None: + if ev.pull_request is None: return None - if not title.startswith("merge queue: "): + if not ev.pull_request.title.startswith("merge queue: "): return None - body = pr.get("body") - if not body: + if not ev.pull_request.body: click.echo( "WARNING: MQ pull request without body, skipping metadata extraction", err=True, ) return None - ref = _yaml_docs_from_fenced_blocks(body) + ref = _yaml_docs_from_fenced_blocks(ev.pull_request.body) if ref is None: click.echo( "WARNING: MQ pull request body without Mergify metadata, skipping metadata extraction", diff --git a/mergify_cli/tests/ci/git_refs/test_git_refs_detector.py b/mergify_cli/tests/ci/git_refs/test_git_refs_detector.py index 93fba93d..7982fd55 100644 --- a/mergify_cli/tests/ci/git_refs/test_git_refs_detector.py +++ b/mergify_cli/tests/ci/git_refs/test_git_refs_detector.py @@ -84,6 +84,7 @@ def test_detect_base_from_pull_request_event_path( ) -> None: event_data = { "pull_request": { + "number": 1, "base": {"sha": "abc123"}, "head": {"sha": "xyz987"}, }, @@ -109,6 +110,7 @@ def test_detect_base_merge_queue_override( ) -> None: event_data = { "pull_request": { + "number": 1, "title": "merge queue: embarking #1 together", "body": "```yaml\nchecking_base_sha: xyz789\n```", "base": {"sha": "abc123"}, @@ -143,6 +145,35 @@ def test_detect_base_no_info( detector.detect() +def test_detect_no_github_event( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("GITHUB_EVENT_NAME", raising=False) + monkeypatch.delenv("GITHUB_EVENT_PATH", raising=False) + + result = detector.detect() + + assert result == detector.References("HEAD^", "HEAD", "fallback_last_commit") + + +def test_detect_push_event_no_info( + tmp_path: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + event_data: dict[str, str] = {} + event_file = tmp_path / "event.json" + event_file.write_text(json.dumps(event_data)) + + monkeypatch.setenv("GITHUB_EVENT_NAME", "push") + monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) + + with pytest.raises( + detector.BaseNotFoundError, + match="Could not detect base SHA", + ): + detector.detect() + + def test_detect_unhandled_event( monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, diff --git a/mergify_cli/tests/ci/push_event.json b/mergify_cli/tests/ci/push_event.json new file mode 100644 index 00000000..f25dd231 --- /dev/null +++ b/mergify_cli/tests/ci/push_event.json @@ -0,0 +1,46 @@ +{ + "ref": "refs/heads/main", + "before": "773db6b5c5f77d0c70c75e6dacef1684cb03495f", + "after": "10068d193546082d802676bb310a570d0898e061", + "created": false, + "deleted": false, + "forced": false, + "pusher": { + "name": "mergify[bot]", + "email": "37929162+mergify[bot]@users.noreply.github.com" + }, + "repository": { + "id": 368096773, + "node_id": "MDEwOlJlcG9zaXRvcnkzNjgwOTY3NzM=", + "name": "mergify-cli", + "full_name": "Mergifyio/mergify-cli", + "private": false, + "owner": { + "login": "Mergifyio", + "id": 37838584, + "node_id": "MDEyOk9yZ2FuaXphdGlvbjM3ODM4NTg0", + "type": "Organization", + "site_admin": false + }, + "html_url": "https://github.com/Mergifyio/mergify-cli", + "description": "Mergify CLI tool", + "fork": false, + "url": "https://api.github.com/repos/Mergifyio/mergify-cli", + "default_branch": "main", + "visibility": "public" + }, + "head_commit": { + "id": "10068d193546082d802676bb310a570d0898e061", + "message": "chore(deps): update dependency uv to v0.10.6", + "timestamp": "2026-02-25T07:34:33Z", + "author": { + "name": "renovate[bot]", + "email": "29139614+renovate[bot]@users.noreply.github.com" + } + }, + "sender": { + "login": "mergify[bot]", + "id": 37929162, + "type": "Bot" + } +} diff --git a/mergify_cli/tests/ci/queue/test_metadata.py b/mergify_cli/tests/ci/queue/test_metadata.py index 83e17b11..e4de582a 100644 --- a/mergify_cli/tests/ci/queue/test_metadata.py +++ b/mergify_cli/tests/ci/queue/test_metadata.py @@ -57,6 +57,7 @@ def test_detect_merge_queue( ) -> None: event_data = { "pull_request": { + "number": 10, "title": "merge queue: embarking #1 and #2 together", "body": "```yaml\n---\nchecking_base_sha: xyz789\npull_requests:\n - number: 1\n - number: 2\nprevious_failed_batches:\n - draft_pr_number: 5\n checked_pull_requests:\n - 1\n - 3\n...\n```", "base": {"sha": "abc123"}, @@ -78,12 +79,52 @@ def test_detect_merge_queue( ] +def test_detect_merge_queue_no_body( + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, +) -> None: + event_data = { + "pull_request": { + "number": 10, + "title": "merge queue: embarking #1 together", + }, + } + event_file = tmp_path / "event.json" + event_file.write_text(json.dumps(event_data)) + + monkeypatch.setenv("GITHUB_EVENT_NAME", "pull_request") + monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) + + assert metadata.detect() is None + + +def test_detect_merge_queue_body_without_yaml( + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, +) -> None: + event_data = { + "pull_request": { + "number": 10, + "title": "merge queue: embarking #1 together", + "body": "No yaml metadata here", + }, + } + event_file = tmp_path / "event.json" + event_file.write_text(json.dumps(event_data)) + + monkeypatch.setenv("GITHUB_EVENT_NAME", "pull_request") + monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) + + assert metadata.detect() is None + + def test_detect_not_merge_queue( monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, ) -> None: event_data = { "pull_request": { + "number": 5, "title": "feat: add something", "body": "Some description", "base": {"sha": "abc123"}, diff --git a/mergify_cli/tests/ci/test_cli.py b/mergify_cli/tests/ci/test_cli.py index 1e74e385..bc256dcc 100644 --- a/mergify_cli/tests/ci/test_cli.py +++ b/mergify_cli/tests/ci/test_cli.py @@ -521,6 +521,7 @@ def test_queue_info( ) -> None: event_data = { "pull_request": { + "number": 10, "title": "merge queue: embarking #1 and #2 together", "body": "```yaml\n---\nchecking_base_sha: xyz789\npull_requests:\n - number: 1\n - number: 2\nprevious_failed_batches: []\n...\n```", "base": {"sha": "abc123"}, @@ -557,6 +558,7 @@ def test_queue_info_not_merge_queue( ) -> None: event_data = { "pull_request": { + "number": 5, "title": "feat: add something", "body": "Some description", "base": {"sha": "abc123"}, diff --git a/mergify_cli/tests/ci/test_github_event.py b/mergify_cli/tests/ci/test_github_event.py new file mode 100644 index 00000000..90806a66 --- /dev/null +++ b/mergify_cli/tests/ci/test_github_event.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import json +import pathlib + +import pydantic +import pytest + +from mergify_cli.ci.github_event import GitHubEvent + + +PULL_REQUEST_EVENT = pathlib.Path(__file__).parent / "pull_request.json" +PUSH_EVENT = pathlib.Path(__file__).parent / "push_event.json" + + +def test_parse_real_pull_request_event() -> None: + raw = json.loads(PULL_REQUEST_EVENT.read_bytes()) + event = GitHubEvent.model_validate(raw) + + assert event.pull_request is not None + assert event.pull_request.number == 2 + assert event.pull_request.title == "Update the README with new information." + assert event.pull_request.body is not None + + assert event.pull_request.head is not None + assert event.pull_request.head.sha == "ec26c3e57ca3a959ca5aad62de7213c562f8c821" + assert event.pull_request.head.ref == "changes" + + assert event.pull_request.base is not None + assert event.pull_request.base.sha == "f95f852bd8fca8fcc58a9a2d6c842781e32a215e" + assert event.pull_request.base.ref == "master" + + assert event.repository is not None + assert event.repository.default_branch == "master" + + # push-event fields should be None + assert event.before is None + assert event.after is None + + +def test_parse_real_push_event() -> None: + raw = json.loads(PUSH_EVENT.read_bytes()) + event = GitHubEvent.model_validate(raw) + + assert event.before == "773db6b5c5f77d0c70c75e6dacef1684cb03495f" + assert event.after == "10068d193546082d802676bb310a570d0898e061" + assert event.pull_request is None + assert event.repository is not None + assert event.repository.default_branch == "main" + + +def test_parse_empty_event() -> None: + event = GitHubEvent.model_validate({}) + + assert event.pull_request is None + assert event.repository is None + assert event.before is None + assert event.after is None + + +def test_parse_minimal_pull_request() -> None: + raw = {"pull_request": {"number": 42}} + event = GitHubEvent.model_validate(raw) + + assert event.pull_request is not None + assert event.pull_request.number == 42 + assert not event.pull_request.title + assert event.pull_request.body is None + assert event.pull_request.base is None + assert event.pull_request.head is None + + +def test_parse_pull_request_missing_number() -> None: + with pytest.raises(pydantic.ValidationError, match="number"): + GitHubEvent.model_validate({"pull_request": {"title": "oops"}}) diff --git a/mergify_cli/tests/test_utils.py b/mergify_cli/tests/test_utils.py index 98a929b1..7b78f9c7 100644 --- a/mergify_cli/tests/test_utils.py +++ b/mergify_cli/tests/test_utils.py @@ -127,7 +127,8 @@ def test_get_github_event_success( monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) name, event = utils.get_github_event() assert name == "pull_request" - assert event == event_data + assert event.pull_request is not None + assert event.pull_request.number == 123 def test_get_github_event_not_found(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/mergify_cli/utils.py b/mergify_cli/utils.py index cebb4350..c7cece38 100644 --- a/mergify_cli/utils.py +++ b/mergify_cli/utils.py @@ -18,7 +18,6 @@ import asyncio import dataclasses import functools -import json import os import pathlib import typing @@ -28,6 +27,7 @@ from mergify_cli import VERSION from mergify_cli import console +from mergify_cli.ci import github_event if typing.TYPE_CHECKING: @@ -312,15 +312,16 @@ class GitHubEventNotFoundError(Exception): pass -def get_github_event() -> tuple[str, typing.Any]: +def get_github_event() -> tuple[str, github_event.GitHubEvent]: event_name = os.environ.get("GITHUB_EVENT_NAME") if not event_name: raise GitHubEventNotFoundError event_path = os.environ.get("GITHUB_EVENT_PATH") if event_path and pathlib.Path(event_path).is_file(): try: - with pathlib.Path(event_path).open("r", encoding="utf-8") as f: - return event_name, json.load(f) + return event_name, github_event.GitHubEvent.model_validate_json( + pathlib.Path(event_path).read_text(encoding="utf-8"), + ) except FileNotFoundError: pass raise GitHubEventNotFoundError