Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .nodes import (
check_changes,
check_pr,
clone_and_branch,
create_pr,
create_pr_post,
Expand Down Expand Up @@ -38,6 +39,13 @@ def route_after_plan(state: AgentState) -> str:
return "implement"


def route_after_clone(state: AgentState) -> str:
    """Choose the node that runs right after ``clone_and_branch``.

    Returns ``"check_pr"`` when the ``should_check_pr`` flag in *state* is
    truthy, and ``"setup"`` otherwise (including when the flag is absent).
    """
    return "check_pr" if state.get("should_check_pr") else "setup"


_NON_CODE_EXTENSIONS = {
".md", ".txt", ".rst", # docs
".yml", ".yaml", ".toml", ".ini", ".cfg", ".conf", # config
Expand Down Expand Up @@ -106,6 +114,7 @@ def build_graph(checkpointer: BaseCheckpointSaver | None = None) -> CompiledStat

# Core infrastructure nodes
graph.add_node("clone_and_branch", clone_and_branch)
graph.add_node("check_pr", check_pr)
graph.add_node("setup", setup)
graph.add_node("reason", reason)
graph.add_node("tools", tool_node)
Expand All @@ -130,7 +139,8 @@ def build_graph(checkpointer: BaseCheckpointSaver | None = None) -> CompiledStat

# Edges
graph.add_edge(START, "clone_and_branch")
graph.add_edge("clone_and_branch", "setup")
graph.add_conditional_edges("clone_and_branch", route_after_clone, {"check_pr": "check_pr", "setup": "setup"})
graph.add_edge("check_pr", "setup")
graph.add_edge("setup", "plan")

# Reasoning loop for each phase
Expand Down
1 change: 1 addition & 0 deletions src/agent/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ async def run(self) -> str:
"task": self.task,
"settings": self.settings,
"repo_config": self.repo_config,
"checkpointer": checkpointer,
}
if self.branch:
initial_state["branch"] = self.branch
Expand Down
147 changes: 146 additions & 1 deletion src/agent/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import re
from pathlib import Path
from typing import Any

from langchain_core.messages import HumanMessage, RemoveMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
Expand Down Expand Up @@ -72,10 +73,100 @@ def _get_metadata(state: AgentState, config: RunnableConfig) -> tuple[str, str]:
# Get model for current prompt type if set, else default
prompt_type = state.get("prompt_type", "")
model_name = settings.llm.get_model(prompt_type)
thread_id = config.get("configurable", {}).get("thread_id", "unknown")
thread_id = state.get("restored_thread_id") or config.get("configurable", {}).get("thread_id", "unknown")
return model_name, thread_id


def _parse_thread_id_from_text(text: str) -> str | None:
"""Extract thread_id from a line like 'thread_id: <value>'."""
for raw_line in text.splitlines():
line = raw_line.strip()
if line.startswith("thread_id:"):
thread_id = line.partition(":")[2].strip()
if thread_id:
return thread_id
return None


def _extract_checkpoint_values(checkpoint_tuple: Any) -> dict[str, Any]:
    """Pull the ``channel_values`` mapping out of a checkpoint container.

    The container may be a dict or an object with attributes. Lookup order:

    1. ``channel_values`` inside the container's ``checkpoint`` payload
       (for a dict without a ``checkpoint`` key, the dict itself is used
       as the payload);
    2. ``channel_values`` directly on the container.

    Anything falsy, or any candidate that is not a dict, results in ``{}``.
    """
    if not checkpoint_tuple:
        return {}

    is_mapping = isinstance(checkpoint_tuple, dict)

    # Preferred location: nested under the ``checkpoint`` payload.
    if is_mapping:
        payload = checkpoint_tuple.get("checkpoint", checkpoint_tuple)
    else:
        payload = getattr(checkpoint_tuple, "checkpoint", None)
    nested = payload.get("channel_values") if isinstance(payload, dict) else None
    if isinstance(nested, dict):
        return nested

    # Fallback location: directly on the container.
    if is_mapping:
        direct = checkpoint_tuple.get("channel_values")
    else:
        direct = getattr(checkpoint_tuple, "channel_values", None)
    return direct if isinstance(direct, dict) else {}


async def _restore_checkpoint_state(checkpointer: Any, thread_id: str) -> dict[str, Any]:
    """Best-effort checkpoint restore for a specific thread_id.

    Tries the checkpointer's ``aget_tuple`` first and falls back to
    ``get_tuple`` when the async method is missing or raises
    ``NotImplementedError``. Either method may return its result directly
    or as an awaitable.

    Args:
        checkpointer: Object exposing ``aget_tuple`` and/or ``get_tuple``;
            a falsy value disables the restore.
        thread_id: Thread whose checkpoint should be looked up; an empty
            value disables the restore.

    Returns:
        The channel values extracted from the checkpoint, or ``{}`` when
        nothing could be loaded.

    Raises:
        NotImplementedError: ``aget_tuple`` is unimplemented and there is no
            callable ``get_tuple`` to fall back to.
        Exception: Any other error raised by the checkpointer propagates.
    """
    import inspect  # local import: only this helper needs it

    if not checkpointer or not thread_id:
        return {}

    cfg = {"configurable": {"thread_id": thread_id}}
    aget_tuple = getattr(checkpointer, "aget_tuple", None)
    get_tuple = getattr(checkpointer, "get_tuple", None)

    if callable(aget_tuple):
        try:
            found = aget_tuple(cfg)
            if inspect.isawaitable(found):
                found = await found
        except NotImplementedError:
            # Some savers define ``aget_tuple`` only as a stub that raises;
            # fall through to the sync API when one exists.
            if not callable(get_tuple):
                raise
        else:
            return _extract_checkpoint_values(found)

    if not callable(get_tuple):
        return _extract_checkpoint_values(None)

    found = get_tuple(cfg)
    # isawaitable also covers futures/tasks, not just coroutine objects.
    if inspect.isawaitable(found):
        found = await found
    return _extract_checkpoint_values(found)


def _summarize_review_comments(comments: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Reduce raw review-comment dicts to a small, uniform shape.

    Each output entry carries ``author``, ``body``, ``path``, ``line``,
    ``created_at`` and ``url``. Missing fields fall back to ``""``
    (``None`` for ``line``), and a missing or empty ``user`` yields the
    author ``"unknown"``. Input order is preserved.
    """

    def _compact(raw: dict[str, Any]) -> dict[str, Any]:
        # ``user`` may be absent or null; treat both as "no author info".
        author_info = raw.get("user") or {}
        return {
            "author": author_info.get("login", "unknown"),
            "body": raw.get("body", ""),
            "path": raw.get("path", ""),
            "line": raw.get("line"),
            "created_at": raw.get("created_at", ""),
            "url": raw.get("html_url", ""),
        }

    return [_compact(raw) for raw in comments]


def _derive_review_status(pr: dict[str, Any], reviews: list[dict[str, Any]]) -> str:
    """Determine the PR's review status.

    A truthy ``reviewDecision`` on *pr* wins and is returned as a string.
    Otherwise the stripped ``state`` of the last entry in *reviews* is used
    (the list is assumed to be ordered oldest-first). Returns ``""`` when
    neither source provides a value.
    """
    explicit = pr.get("reviewDecision")
    if explicit:
        return str(explicit)
    return str(reviews[-1].get("state", "")).strip() if reviews else ""


async def _run_checks(repo_config, repo_dir: Path) -> CommandResult:
"""Run test suite and linter in parallel, returning combined result."""
coros = []
Expand Down Expand Up @@ -180,11 +271,65 @@ async def clone_and_branch(state: AgentState, config: RunnableConfig) -> dict:
"git": git,
"repo_dir": git.repo_dir,
"active_branch": active_branch,
"should_check_pr": bool(branch_input),
"readme_preamble": readme_preamble,
"messages": [SystemMessage(content=prompts.SYSTEM_PROMPT)],
"llm": create_llm(settings.llm).bind_tools(ALL_TOOLS),
"attempt": 0,
"max_attempts": settings.agent.max_fix_attempts,
"checkpointer": state.get("checkpointer"),
}


async def check_pr(state: AgentState, config: RunnableConfig) -> dict:
    """Load PR metadata for branch runs and try checkpoint restore from PR thread_id.

    Steps:

    1. Look up the open PR for ``state["active_branch"]``.
    2. Parse a ``thread_id: <value>`` line out of the PR body and, if found,
       try to restore that thread's checkpoint via ``state["checkpointer"]``.
    3. Fetch the PR's reviews and review comments.

    The PR lookup, checkpoint restore and review fetches are best-effort:
    a failure is logged as a warning and the node carries on with empty
    defaults instead of raising.

    Args:
        state: Must provide ``git`` and ``active_branch``; ``checkpointer``
            is optional.
        config: Not used by this node.

    Returns:
        ``{}`` when the PR lookup fails or no open PR exists. Otherwise a
        state update with ``pr_url``, ``pr_number``, ``pr_title``,
        ``pr_body``, ``restored_thread_id``, ``checkpoint_restore_ok``,
        ``restored_checkpoint_state``, ``pr_review_status`` and
        ``pr_review_comments``.
    """
    git = state["git"]
    branch = state["active_branch"]

    try:
        pr = git.fetch_open_pr_for_branch(branch)
    except Exception as exc:  # pragma: no cover - defensive for network/gh failures
        logger.warning("Failed to fetch PR for branch %s: %s", branch, exc)
        return {}

    if not pr:
        logger.info("No open PR found for branch %s", branch)
        return {}

    # Read the number once with .get(). Indexing pr["number"] inside the
    # except handlers below would raise a second, unhandled KeyError when the
    # key is missing, defeating the best-effort handling.
    pr_number = pr.get("number", 0)
    pr_body = pr.get("body", "") or ""
    thread_id = _parse_thread_id_from_text(pr_body)

    restored_values: dict[str, Any] = {}
    checkpoint_restore_ok = False
    if thread_id:
        try:
            restored_values = await _restore_checkpoint_state(state.get("checkpointer"), thread_id)
            # An empty result counts as "nothing restored".
            checkpoint_restore_ok = bool(restored_values)
        except Exception as exc:  # pragma: no cover - defensive for backend differences
            logger.warning("Failed to restore checkpoint for thread_id %s: %s", thread_id, exc)

    reviews: list[dict[str, Any]] = []
    try:
        reviews = git.fetch_pr_reviews(pr_number)
    except Exception as exc:  # pragma: no cover - defensive for network/gh failures
        logger.warning("Failed to fetch PR reviews for #%s: %s", pr_number, exc)

    review_comments: list[dict[str, Any]] = []
    try:
        review_comments = _summarize_review_comments(git.fetch_pr_review_comments(pr_number))
    except Exception as exc:  # pragma: no cover - defensive for network/gh failures
        logger.warning("Failed to fetch PR review comments for #%s: %s", pr_number, exc)

    return {
        "pr_url": pr.get("url", ""),
        "pr_number": pr_number,
        "pr_title": pr.get("title", ""),
        "pr_body": pr_body,
        "restored_thread_id": thread_id or "",
        "checkpoint_restore_ok": checkpoint_restore_ok,
        "restored_checkpoint_state": restored_values,
        "pr_review_status": _derive_review_status(pr, reviews),
        "pr_review_comments": review_comments,
    }


Expand Down
11 changes: 11 additions & 0 deletions src/agent/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class AgentState(TypedDict, total=False):
git: Annotated[GitRepo, UntrackedValue]
repo_dir: Annotated[Path, UntrackedValue]
active_branch: Annotated[str, UntrackedValue]
checkpointer: Annotated[Any, UntrackedValue]
should_check_pr: bool

# Control flow
prompt_type: str
Expand All @@ -51,6 +53,15 @@ class AgentState(TypedDict, total=False):

# PR
pr_ready: bool
pr_url: str
pr_number: int
pr_title: str
pr_body: str
pr_review_status: str
pr_review_comments: list[dict[str, Any]]
restored_thread_id: str
checkpoint_restore_ok: bool
restored_checkpoint_state: dict[str, Any]

# Token tracking
token_usage: dict[str, int]
Expand Down
53 changes: 53 additions & 0 deletions src/git_ops/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import subprocess
import tempfile
from pathlib import Path
from typing import Any

from ..agent.executor import list_files as list_files_from_executor
from ..config.settings import GitConfig, RepoConfig
Expand Down Expand Up @@ -160,6 +161,58 @@ def fetch_last_commit(self, branch: str) -> tuple[str, str]:
except subprocess.CalledProcessError as e:
raise RuntimeError(f"gh api failed: {e.stderr.strip() or str(e)}")

def fetch_open_pr_for_branch(self, branch: str) -> dict[str, Any] | None:
"""Fetch open PR metadata for a branch, if one exists."""
nwo = self.extract_github_nwo()
if not nwo:
raise RuntimeError(f"Cannot extract owner/repo from URL: {self.repo.url}")

try:
result = self._gh(
"pr",
"list",
"--repo",
nwo,
"--state",
"open",
"--head",
branch,
"--json",
"number,title,body,url,reviewDecision",
"--limit",
"1",
)
data = json.loads(result.stdout)
if not data:
return None
return data[0]
except subprocess.CalledProcessError as e:
raise RuntimeError(f"gh pr list failed: {e.stderr.strip() or str(e)}")

def fetch_pr_reviews(self, pr_number: int) -> list[dict[str, Any]]:
    """Fetch PR reviews from GitHub.

    Calls the ``pulls/{pr_number}/reviews`` REST endpoint through ``gh api``
    with ``per_page=100`` so that up to 100 reviews arrive in one response
    (GitHub's default page size is 30). Only that single page is fetched.

    Args:
        pr_number: Number of the pull request.

    Returns:
        The decoded JSON response (a list of review objects).

    Raises:
        RuntimeError: The owner/repo cannot be derived from the repo URL, or
            the ``gh`` command fails (the original ``CalledProcessError`` is
            attached as ``__cause__``).
    """
    nwo = self.extract_github_nwo()
    if not nwo:
        raise RuntimeError(f"Cannot extract owner/repo from URL: {self.repo.url}")

    try:
        result = self._gh("api", f"repos/{nwo}/pulls/{pr_number}/reviews?per_page=100")
    except subprocess.CalledProcessError as e:
        # stderr is None when the process output was not captured.
        detail = (e.stderr or "").strip() or str(e)
        raise RuntimeError(f"gh api reviews failed: {detail}") from e

    return json.loads(result.stdout)

def fetch_pr_review_comments(self, pr_number: int) -> list[dict[str, Any]]:
    """Fetch PR review comments from GitHub.

    Calls the ``pulls/{pr_number}/comments`` REST endpoint through ``gh api``
    with ``per_page=100`` so that up to 100 comments arrive in one response
    (GitHub's default page size is 30). Only that single page is fetched.

    Args:
        pr_number: Number of the pull request.

    Returns:
        The decoded JSON response (a list of review-comment objects).

    Raises:
        RuntimeError: The owner/repo cannot be derived from the repo URL, or
            the ``gh`` command fails (the original ``CalledProcessError`` is
            attached as ``__cause__``).
    """
    nwo = self.extract_github_nwo()
    if not nwo:
        raise RuntimeError(f"Cannot extract owner/repo from URL: {self.repo.url}")

    try:
        result = self._gh("api", f"repos/{nwo}/pulls/{pr_number}/comments?per_page=100")
    except subprocess.CalledProcessError as e:
        # stderr is None when the process output was not captured.
        detail = (e.stderr or "").strip() or str(e)
        raise RuntimeError(f"gh api review comments failed: {detail}") from e

    return json.loads(result.stdout)

def _setup_credential_helper(self) -> None:
"""Configure a git credential helper that supplies the token without embedding it in the remote URL."""
token = self._resolve_token()
Expand Down