From b411063c56ce583ff3a5d1b5d0788465fb9c8554 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 1 Mar 2026 21:51:52 +0800 Subject: [PATCH] Add check_pr branch state for PR/thread restore Entire-Checkpoint: c4dfc620f91d --- src/agent/graph.py | 12 +++- src/agent/loop.py | 1 + src/agent/nodes.py | 147 +++++++++++++++++++++++++++++++++++++++++++- src/agent/state.py | 11 ++++ src/git_ops/repo.py | 53 ++++++++++++++++ 5 files changed, 222 insertions(+), 2 deletions(-) diff --git a/src/agent/graph.py b/src/agent/graph.py index 7497ff3..06aefa9 100644 --- a/src/agent/graph.py +++ b/src/agent/graph.py @@ -9,6 +9,7 @@ from .nodes import ( check_changes, + check_pr, clone_and_branch, create_pr, create_pr_post, @@ -38,6 +39,13 @@ def route_after_plan(state: AgentState) -> str: return "implement" +def route_after_clone(state: AgentState) -> str: + """Route to PR inspection only when running against a provided branch.""" + if state.get("should_check_pr"): + return "check_pr" + return "setup" + + _NON_CODE_EXTENSIONS = { ".md", ".txt", ".rst", # docs ".yml", ".yaml", ".toml", ".ini", ".cfg", ".conf", # config @@ -106,6 +114,7 @@ def build_graph(checkpointer: BaseCheckpointSaver | None = None) -> CompiledStat # Core infrastructure nodes graph.add_node("clone_and_branch", clone_and_branch) + graph.add_node("check_pr", check_pr) graph.add_node("setup", setup) graph.add_node("reason", reason) graph.add_node("tools", tool_node) @@ -130,7 +139,8 @@ def build_graph(checkpointer: BaseCheckpointSaver | None = None) -> CompiledStat # Edges graph.add_edge(START, "clone_and_branch") - graph.add_edge("clone_and_branch", "setup") + graph.add_conditional_edges("clone_and_branch", route_after_clone, {"check_pr": "check_pr", "setup": "setup"}) + graph.add_edge("check_pr", "setup") graph.add_edge("setup", "plan") # Reasoning loop for each phase diff --git a/src/agent/loop.py b/src/agent/loop.py index d1dcf78..6626607 100644 --- a/src/agent/loop.py +++ b/src/agent/loop.py @@ -58,6 +58,7 @@ async def run(self) -> str: "task": self.task, "settings": self.settings, "repo_config": self.repo_config, + "checkpointer": checkpointer, } if self.branch: initial_state["branch"] = self.branch diff --git a/src/agent/nodes.py b/src/agent/nodes.py index 34ee0bb..ed570bb 100644 --- a/src/agent/nodes.py +++ b/src/agent/nodes.py @@ -8,6 +8,7 @@ import os import re from pathlib import Path +from typing import Any from langchain_core.messages import HumanMessage, RemoveMessage, SystemMessage from langchain_core.runnables import RunnableConfig @@ -72,10 +73,100 @@ def _get_metadata(state: AgentState, config: RunnableConfig) -> tuple[str, str]: # Get model for current prompt type if set, else default prompt_type = state.get("prompt_type", "") model_name = settings.llm.get_model(prompt_type) - thread_id = config.get("configurable", {}).get("thread_id", "unknown") + thread_id = state.get("restored_thread_id") or config.get("configurable", {}).get("thread_id", "unknown") return model_name, thread_id +def _parse_thread_id_from_text(text: str) -> str | None: + """Extract thread_id from a line like 'thread_id: '.""" + for raw_line in text.splitlines(): + line = raw_line.strip() + if line.startswith("thread_id:"): + thread_id = line.partition(":")[2].strip() + if thread_id: + return thread_id + return None + + +def _extract_checkpoint_values(checkpoint_tuple: Any) -> dict[str, Any]: + """Best-effort extraction of channel values from a checkpoint tuple.""" + if not checkpoint_tuple: + return {} + + if isinstance(checkpoint_tuple, dict): + checkpoint = checkpoint_tuple.get("checkpoint", checkpoint_tuple) + else: + checkpoint = getattr(checkpoint_tuple, "checkpoint", None) + + if isinstance(checkpoint, dict): + values = checkpoint.get("channel_values") + if isinstance(values, dict): + return values + + if isinstance(checkpoint_tuple, dict): + values = checkpoint_tuple.get("channel_values") + else: + values = getattr(checkpoint_tuple, "channel_values", None) + if isinstance(values, dict): + return values + + return {} + + +async def _restore_checkpoint_state(checkpointer: Any, thread_id: str) -> dict[str, Any]: + """Best-effort checkpoint restore for a specific thread_id.""" + if not checkpointer or not thread_id: + return {} + + cfg = {"configurable": {"thread_id": thread_id}} + + checkpoint_tuple = None + aget_tuple = getattr(checkpointer, "aget_tuple", None) + if callable(aget_tuple): + checkpoint_tuple = await aget_tuple(cfg) + else: + get_tuple = getattr(checkpointer, "get_tuple", None) + if callable(get_tuple): + maybe = get_tuple(cfg) + if asyncio.iscoroutine(maybe): + checkpoint_tuple = await maybe + else: + checkpoint_tuple = maybe + + return _extract_checkpoint_values(checkpoint_tuple) + + +def _summarize_review_comments(comments: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Normalize review comments to a compact state-friendly structure.""" + normalized: list[dict[str, Any]] = [] + for comment in comments: + user = comment.get("user") or {} + normalized.append( + { + "author": user.get("login", "unknown"), + "body": comment.get("body", ""), + "path": comment.get("path", ""), + "line": comment.get("line"), + "created_at": comment.get("created_at", ""), + "url": comment.get("html_url", ""), + } + ) + return normalized + + +def _derive_review_status(pr: dict[str, Any], reviews: list[dict[str, Any]]) -> str: + """Return review status with fallback to the latest review state.""" + decision = pr.get("reviewDecision") + if decision: + return str(decision) + + if not reviews: + return "" + + latest_state = str(reviews[-1].get("state", "")).strip() + return latest_state + + async def _run_checks(repo_config, repo_dir: Path) -> CommandResult: """Run test suite and linter in parallel, returning combined result.""" coros = [] @@ -180,11 +271,65 @@ async def clone_and_branch(state: AgentState, config: RunnableConfig) -> dict: "git": git, "repo_dir": git.repo_dir, "active_branch": active_branch, + "should_check_pr": bool(branch_input), "readme_preamble": readme_preamble, "messages": [SystemMessage(content=prompts.SYSTEM_PROMPT)], "llm": create_llm(settings.llm).bind_tools(ALL_TOOLS), "attempt": 0, "max_attempts": settings.agent.max_fix_attempts, + "checkpointer": state.get("checkpointer"), + } + + +async def check_pr(state: AgentState, config: RunnableConfig) -> dict: + """Load PR metadata for branch runs and try checkpoint restore from PR thread_id.""" + git = state["git"] + branch = state["active_branch"] + + try: + pr = git.fetch_open_pr_for_branch(branch) + except Exception as exc: # pragma: no cover - defensive for network/gh failures + logger.warning("Failed to fetch PR for branch %s: %s", branch, exc) + return {} + + if not pr: + logger.info("No open PR found for branch %s", branch) + return {} + + pr_body = pr.get("body", "") or "" + thread_id = _parse_thread_id_from_text(pr_body) + + restored_values: dict[str, Any] = {} + checkpoint_restore_ok = False + if thread_id: + try: + restored_values = await _restore_checkpoint_state(state.get("checkpointer"), thread_id) + checkpoint_restore_ok = bool(restored_values) + except Exception as exc: # pragma: no cover - defensive for backend differences + logger.warning("Failed to restore checkpoint for thread_id %s: %s", thread_id, exc) + + reviews: list[dict[str, Any]] = [] + try: + reviews = git.fetch_pr_reviews(pr["number"]) + except Exception as exc: # pragma: no cover - defensive for network/gh failures + logger.warning("Failed to fetch PR reviews for #%s: %s", pr["number"], exc) + + review_comments: list[dict[str, Any]] = [] + try: + review_comments = _summarize_review_comments(git.fetch_pr_review_comments(pr["number"])) + except Exception as exc: # pragma: no cover - defensive for network/gh failures + logger.warning("Failed to fetch PR review comments for #%s: %s", pr["number"], exc) + + return { + "pr_url": pr.get("url", ""), + "pr_number": pr.get("number", 0), + "pr_title": pr.get("title", ""), + "pr_body": pr_body, + "restored_thread_id": thread_id or "", + "checkpoint_restore_ok": checkpoint_restore_ok, + "restored_checkpoint_state": restored_values, + "pr_review_status": _derive_review_status(pr, reviews), + "pr_review_comments": review_comments, } diff --git a/src/agent/state.py b/src/agent/state.py index a497f49..c1ec6ac 100644 --- a/src/agent/state.py +++ b/src/agent/state.py @@ -30,6 +30,8 @@ class AgentState(TypedDict, total=False): git: Annotated[GitRepo, UntrackedValue] repo_dir: Annotated[Path, UntrackedValue] active_branch: Annotated[str, UntrackedValue] + checkpointer: Annotated[Any, UntrackedValue] + should_check_pr: bool # Control flow prompt_type: str @@ -51,6 +53,15 @@ class AgentState(TypedDict, total=False): # PR pr_ready: bool + pr_url: str + pr_number: int + pr_title: str + pr_body: str + pr_review_status: str + pr_review_comments: list[dict[str, Any]] + restored_thread_id: str + checkpoint_restore_ok: bool + restored_checkpoint_state: dict[str, Any] # Token tracking token_usage: dict[str, int] diff --git a/src/git_ops/repo.py b/src/git_ops/repo.py index e04cffc..13a8197 100644 --- a/src/git_ops/repo.py +++ b/src/git_ops/repo.py @@ -10,6 +10,7 @@ import subprocess import tempfile from pathlib import Path +from typing import Any from ..agent.executor import list_files as list_files_from_executor from ..config.settings import GitConfig, RepoConfig @@ -160,6 +161,58 @@ def fetch_last_commit(self, branch: str) -> tuple[str, str]: except subprocess.CalledProcessError as e: raise RuntimeError(f"gh api failed: {e.stderr.strip() or str(e)}") + def fetch_open_pr_for_branch(self, branch: str) -> dict[str, Any] | None: + """Fetch open PR metadata for a branch, if one exists.""" + nwo = self.extract_github_nwo() + if not nwo: + raise RuntimeError(f"Cannot extract owner/repo from URL: {self.repo.url}") + + try: + result = self._gh( + "pr", + "list", + "--repo", + nwo, + "--state", + "open", + "--head", + branch, + "--json", + "number,title,body,url,reviewDecision", + "--limit", + "1", + ) + data = json.loads(result.stdout) + if not data: + return None + return data[0] + except subprocess.CalledProcessError as e: + raise RuntimeError(f"gh pr list failed: {e.stderr.strip() or str(e)}") + + def fetch_pr_reviews(self, pr_number: int) -> list[dict[str, Any]]: + """Fetch PR reviews from GitHub.""" + nwo = self.extract_github_nwo() + if not nwo: + raise RuntimeError(f"Cannot extract owner/repo from URL: {self.repo.url}") + + try: + result = self._gh("api", f"repos/{nwo}/pulls/{pr_number}/reviews") + return json.loads(result.stdout) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"gh api reviews failed: {e.stderr.strip() or str(e)}") + + def fetch_pr_review_comments(self, pr_number: int) -> list[dict[str, Any]]: + """Fetch PR review comments from GitHub.""" + nwo = self.extract_github_nwo() + if not nwo: + raise RuntimeError(f"Cannot extract owner/repo from URL: {self.repo.url}") + + try: + result = self._gh("api", f"repos/{nwo}/pulls/{pr_number}/comments") + return json.loads(result.stdout) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"gh api review comments failed: {e.stderr.strip() or str(e)}") + def _setup_credential_helper(self) -> None: """Configure a git credential helper that supplies the token without embedding it in the remote URL.""" token = self._resolve_token()