Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions src/millstone/runtime/merge_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
loc_threshold: int,
max_retries: int,
tasklist: str = "docs/tasklist.md",
skip_tasklist_mark: bool = False,
):
self.repo_dir = Path(repo_dir)
self.integration_worktree = Path(integration_worktree)
Expand All @@ -109,6 +110,7 @@ def __init__(
self.policy = policy or {}
self.loc_threshold = int(loc_threshold)
self.max_retries = int(max_retries)
self.skip_tasklist_mark = skip_tasklist_mark

# Bind tasklist operations to the integration checkout.
self.tasklist_manager = TasklistManager(
Expand Down Expand Up @@ -232,35 +234,37 @@ def integrate_eval_and_land(
)

# Mark task complete in the integration checkout.
with self.tasklist_lock:
task_already_complete = False
ok = self.tasklist_manager.mark_task_complete_by_id(task_id, taskmap)
if not ok:
completion_state = self.tasklist_manager.task_completion_by_id(
task_id, taskmap
)
if completion_state is True:
task_already_complete = True
else:
self._reset_hard(base_head)
return IntegrationResult(
success=False,
status="land_fail",
error="task_id_not_found_or_already_complete",
# Skipped for MCP providers — they handle completion externally.
if not self.skip_tasklist_mark:
with self.tasklist_lock:
task_already_complete = False
ok = self.tasklist_manager.mark_task_complete_by_id(task_id, taskmap)
if not ok:
completion_state = self.tasklist_manager.task_completion_by_id(
task_id, taskmap
)
self._git("add", self.tasklist_manager.tasklist)
has_staged_changes = (
self._git("diff", "--cached", "--quiet", check=False).returncode != 0
)
if has_staged_changes:
msg = f"millstone: mark task {task_id} complete"
if task_already_complete:
msg = f"millstone: sync task {task_id} tasklist updates"
self._git(
"commit",
"-m",
msg,
if completion_state is True:
task_already_complete = True
else:
self._reset_hard(base_head)
return IntegrationResult(
success=False,
status="land_fail",
error="task_id_not_found_or_already_complete",
)
self._git("add", self.tasklist_manager.tasklist)
has_staged_changes = (
self._git("diff", "--cached", "--quiet", check=False).returncode != 0
)
if has_staged_changes:
msg = f"millstone: mark task {task_id} complete"
if task_already_complete:
msg = f"millstone: sync task {task_id} tasklist updates"
self._git(
"commit",
"-m",
msg,
)

# Land: update base branch ref via local push.
push = subprocess.run(
Expand Down
139 changes: 134 additions & 5 deletions src/millstone/runtime/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

from millstone.artifact_providers.mcp import MCPTasklistProvider
from millstone.artifacts.eval_manager import EvalManager
from millstone.artifacts.models import TaskStatus
from millstone.runtime.locks import AdvisoryLock
from millstone.runtime.merge_pipeline import MergePipeline
from millstone.runtime.parallel_state import ParallelState
Expand Down Expand Up @@ -361,6 +363,108 @@ def _recover_state(self) -> None:
merge_queue=saved_state.get("merge_queue", []) or [],
)

def _is_mcp_provider(self) -> bool:
    """Return whether the orchestrator's configured tasklist provider is MCP-backed."""
    outer = self.orch._outer_loop_manager
    return isinstance(outer.tasklist_provider, MCPTasklistProvider)

def _fetch_tasks_from_provider(self) -> list[dict]:
    """Fetch tasks from the MCP provider in the scheduler-compatible format.

    Produces the same ``{task_id, checked, title, raw_text, index}`` dicts
    that ``TasklistManager.extract_all_task_ids()`` emits, so the rest of
    the parallel pipeline works unchanged.
    """
    provider = self.orch._outer_loop_manager.tasklist_provider
    if not isinstance(provider, MCPTasklistProvider):
        raise TypeError("_fetch_tasks_from_provider requires an MCPTasklistProvider")
    # Lazily wire an agent callback so the provider can drive an author agent.
    if provider._agent_callback is None:
        provider.set_agent_callback(lambda p, **k: self.orch.run_agent(p, role="author", **k))
    provider.invalidate_cache()
    # A task still counted as open is one that is todo or in_progress.
    open_states = (TaskStatus.todo, TaskStatus.in_progress)
    return [
        {
            "task_id": item.task_id,
            "checked": item.status not in open_states,
            "title": item.title,
            "raw_text": "",
            "index": index,
        }
        for index, item in enumerate(provider.list_tasks())
    ]

def _fetch_task_body(self, task_id: str) -> str:
    """Fetch the full task body from the MCP provider via ``get_task()``.

    Returns a formatted text block with title, context, criteria, tests,
    and risk — suitable for passing as ``--task`` to a worker subprocess.

    Raises ``RuntimeError`` if the provider returns ``None`` so that callers
    can surface the failure instead of silently falling back to title-only.
    """
    provider = self.orch._outer_loop_manager.tasklist_provider
    if not isinstance(provider, MCPTasklistProvider):
        return ""
    item = provider.get_task(task_id)
    if item is None:
        raise RuntimeError(f"MCP get_task('{task_id}') returned None")
    # Optional sections appear only when the provider populated them.
    sections = [
        ("Context", item.context),
        ("Criteria", item.criteria),
        ("Tests", item.tests),
        ("Risk", item.risk),
    ]
    lines = [item.title]
    lines.extend(f" - {label}: {value}" for label, value in sections if value)
    return "\n".join(lines)

def _analyze_tasks_mcp(self, task_dicts: list[dict]) -> tuple[list[dict], list[dict]]:
    """Build enriched-task list for MCP tasks (no dependency graph).

    Fetches the full task body from the MCP provider so that worker
    subprocesses receive meaningful task descriptions via ``--task``.

    Raises ``RuntimeError`` if any task body cannot be fetched.
    """

    def risk_from(body_text: str) -> str | None:
        # The body embeds risk as a "- Risk: ..." bullet; pull out its value.
        for raw_line in body_text.splitlines():
            candidate = raw_line.strip()
            if candidate.lower().startswith("- risk:"):
                return candidate.split(":", 1)[1].strip() or None
        return None

    enriched: list[dict] = []
    for entry in task_dicts:
        tid = entry["task_id"]
        body = self._fetch_task_body(tid)
        enriched.append(
            {
                "task_id": tid,
                "title": entry.get("title", ""),
                "group": None,
                "file_refs": [],
                "risk": risk_from(body) if body else None,
                "raw_text": body or entry.get("title", ""),
            }
        )
    return enriched, []

def _mark_mcp_task_done(self, task_id: str) -> None:
    """Mark a task as done via the MCP provider after a successful merge.

    Propagates any provider exception so the caller can treat the failure
    as a merge failure instead of silently leaving the task open remotely.
    """
    provider = self.orch._outer_loop_manager.tasklist_provider
    if isinstance(provider, MCPTasklistProvider):
        provider.update_task_status(task_id, TaskStatus.done)

def _make_eval_manager(self, repo_dir: Path, work_dir: Path) -> EvalManager:
work_dir.mkdir(parents=True, exist_ok=True)
return EvalManager(
Expand Down Expand Up @@ -583,7 +687,10 @@ def _dry_run(self) -> int:
base_ref = self.orch.base_ref or base_branch
base_ref_sha = self._rev_parse(base_ref)

tasks = self.orch._tasklist_manager.extract_all_task_ids()
if self._is_mcp_provider():
tasks = self._fetch_tasks_from_provider()
else:
tasks = self.orch._tasklist_manager.extract_all_task_ids()
max_tasks = max(0, int(self.orch.max_tasks))
pending = [t for t in tasks if not t["checked"]][:max_tasks]

Expand Down Expand Up @@ -622,6 +729,7 @@ def run(self) -> int:
self.orch.parallel_integration_branch,
base_ref_sha,
)
use_mcp = self._is_mcp_provider()
merge_pipeline = MergePipeline(
repo_dir=self.orch.repo_dir,
integration_worktree=integration_wt,
Expand All @@ -634,9 +742,12 @@ def run(self) -> int:
loc_threshold=self.orch.loc_threshold,
max_retries=self.orch.merge_max_retries,
tasklist=self.orch.tasklist,
skip_tasklist_mark=use_mcp,
)

tasks = self.orch._tasklist_manager.extract_all_task_ids()
if use_mcp:
tasks = self._fetch_tasks_from_provider()
else:
tasks = self.orch._tasklist_manager.extract_all_task_ids()
taskmap = {t["task_id"]: {"index": t["index"]} for t in tasks}
self.parallel_state.save_taskmap(taskmap)

Expand All @@ -652,7 +763,10 @@ def run(self) -> int:
poll_interval = 0.1

try:
enriched_tasks, dependencies = self._analyze_tasks(pending)
if use_mcp:
enriched_tasks, dependencies = self._analyze_tasks_mcp(pending)
else:
enriched_tasks, dependencies = self._analyze_tasks(pending)
scheduler = TaskScheduler(
concurrency=max(1, int(self.orch.parallel_concurrency)),
high_risk_concurrency=max(1, int(self.orch.high_risk_concurrency)),
Expand Down Expand Up @@ -790,6 +904,21 @@ def run(self) -> int:
)

if merge_res.success:
if use_mcp:
try:
self._mark_mcp_task_done(task_id)
except Exception as exc:
reason = f"mcp_status_update_failed: {exc}"
scheduler.mark_failed(task_id, reason)
task_records[task_id] = {
"status": "failed",
"error": reason,
"completed_at": time.time(),
}
failures = True
in_flight_worktrees.pop(task_id, None)
in_flight_started_at.pop(task_id, None)
continue
completed.add(task_id)
scheduler.mark_completed(task_id)
task_records[task_id] = {
Expand Down Expand Up @@ -846,7 +975,7 @@ def run(self) -> int:

if scheduler.has_remaining() and in_flight:
time.sleep(poll_interval)
except ValueError as e:
except (ValueError, RuntimeError) as e:
print(f"ERROR: {e}")
failures = True
finally:
Expand Down
Loading