diff --git a/src/millstone/runtime/merge_pipeline.py b/src/millstone/runtime/merge_pipeline.py index 52d23a8..69a43ac 100644 --- a/src/millstone/runtime/merge_pipeline.py +++ b/src/millstone/runtime/merge_pipeline.py @@ -98,6 +98,7 @@ def __init__( loc_threshold: int, max_retries: int, tasklist: str = "docs/tasklist.md", + skip_tasklist_mark: bool = False, ): self.repo_dir = Path(repo_dir) self.integration_worktree = Path(integration_worktree) @@ -109,6 +110,7 @@ def __init__( self.policy = policy or {} self.loc_threshold = int(loc_threshold) self.max_retries = int(max_retries) + self.skip_tasklist_mark = skip_tasklist_mark # Bind tasklist operations to the integration checkout. self.tasklist_manager = TasklistManager( @@ -232,35 +234,37 @@ def integrate_eval_and_land( ) # Mark task complete in the integration checkout. - with self.tasklist_lock: - task_already_complete = False - ok = self.tasklist_manager.mark_task_complete_by_id(task_id, taskmap) - if not ok: - completion_state = self.tasklist_manager.task_completion_by_id( - task_id, taskmap - ) - if completion_state is True: - task_already_complete = True - else: - self._reset_hard(base_head) - return IntegrationResult( - success=False, - status="land_fail", - error="task_id_not_found_or_already_complete", + # Skipped for MCP providers — they handle completion externally. + if not self.skip_tasklist_mark: + with self.tasklist_lock: + task_already_complete = False + ok = self.tasklist_manager.mark_task_complete_by_id(task_id, taskmap) + if not ok: + completion_state = self.tasklist_manager.task_completion_by_id( + task_id, taskmap ) - self._git("add", self.tasklist_manager.tasklist) - has_staged_changes = ( - self._git("diff", "--cached", "--quiet", check=False).returncode != 0 - ) - if has_staged_changes: - msg = f"millstone: mark task {task_id} complete" - if task_already_complete: - msg = f"millstone: sync task {task_id} tasklist updates" - self._git( - "commit", - "-m", - msg, + if completion_state is True: + task_already_complete = True + else: + self._reset_hard(base_head) + return IntegrationResult( + success=False, + status="land_fail", + error="task_id_not_found_or_already_complete", + ) + self._git("add", self.tasklist_manager.tasklist) + has_staged_changes = ( + self._git("diff", "--cached", "--quiet", check=False).returncode != 0 ) + if has_staged_changes: + msg = f"millstone: mark task {task_id} complete" + if task_already_complete: + msg = f"millstone: sync task {task_id} tasklist updates" + self._git( + "commit", + "-m", + msg, + ) # Land: update base branch ref via local push. push = subprocess.run( diff --git a/src/millstone/runtime/parallel.py b/src/millstone/runtime/parallel.py index ab738eb..5533aea 100644 --- a/src/millstone/runtime/parallel.py +++ b/src/millstone/runtime/parallel.py @@ -12,7 +12,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from millstone.artifact_providers.mcp import MCPTasklistProvider from millstone.artifacts.eval_manager import EvalManager +from millstone.artifacts.models import TaskStatus from millstone.runtime.locks import AdvisoryLock from millstone.runtime.merge_pipeline import MergePipeline from millstone.runtime.parallel_state import ParallelState @@ -361,6 +363,108 @@ def _recover_state(self) -> None: merge_queue=saved_state.get("merge_queue", []) or [], ) + def _is_mcp_provider(self) -> bool: + """Check whether the configured tasklist provider is MCP-backed.""" + provider = self.orch._outer_loop_manager.tasklist_provider + return isinstance(provider, MCPTasklistProvider) + + def _fetch_tasks_from_provider(self) -> list[dict]: + """Fetch tasks from the MCP provider, returning the scheduler-compatible format. + + Returns the same ``{task_id, checked, title, raw_text, index}`` dicts that + ``TasklistManager.extract_all_task_ids()`` produces so that the rest of the + parallel pipeline works unchanged. + """ + provider = self.orch._outer_loop_manager.tasklist_provider + if not isinstance(provider, MCPTasklistProvider): + raise TypeError("_fetch_tasks_from_provider requires an MCPTasklistProvider") + if provider._agent_callback is None: + provider.set_agent_callback(lambda p, **k: self.orch.run_agent(p, role="author", **k)) + provider.invalidate_cache() + items = provider.list_tasks() + results: list[dict] = [] + for index, item in enumerate(items): + checked = item.status not in (TaskStatus.todo, TaskStatus.in_progress) + results.append( + { + "task_id": item.task_id, + "checked": checked, + "title": item.title, + "raw_text": "", + "index": index, + } + ) + return results + + def _fetch_task_body(self, task_id: str) -> str: + """Fetch full task body from MCP provider via get_task(). + + Returns a formatted text block with title, context, criteria, tests, + and risk — suitable for passing as ``--task`` to a worker subprocess. + + Raises ``RuntimeError`` if the provider returns ``None`` so that callers + can surface the failure instead of silently falling back to title-only. + """ + provider = self.orch._outer_loop_manager.tasklist_provider + if not isinstance(provider, MCPTasklistProvider): + return "" + item = provider.get_task(task_id) + if item is None: + raise RuntimeError(f"MCP get_task('{task_id}') returned None") + parts = [item.title] + if item.context: + parts.append(f" - Context: {item.context}") + if item.criteria: + parts.append(f" - Criteria: {item.criteria}") + if item.tests: + parts.append(f" - Tests: {item.tests}") + if item.risk: + parts.append(f" - Risk: {item.risk}") + return "\n".join(parts) + + def _analyze_tasks_mcp(self, task_dicts: list[dict]) -> tuple[list[dict], list[dict]]: + """Build enriched-task list for MCP tasks (no dependency graph). + + Fetches full task body from the MCP provider so that worker subprocesses + receive meaningful task descriptions via ``--task``. + + Raises ``RuntimeError`` if any task body cannot be fetched. + """ + enriched: list[dict] = [] + for task in task_dicts: + task_id = task["task_id"] + body = self._fetch_task_body(task_id) + risk = None + if body: + # Extract risk from fetched body if present. + for line in body.splitlines(): + stripped = line.strip() + if stripped.lower().startswith("- risk:"): + risk = stripped.split(":", 1)[1].strip() or None + break + enriched.append( + { + "task_id": task_id, + "title": task.get("title", ""), + "group": None, + "file_refs": [], + "risk": risk, + "raw_text": body or task.get("title", ""), + } + ) + return enriched, [] + + def _mark_mcp_task_done(self, task_id: str) -> None: + """Mark a task as done via the MCP provider after a successful merge. + + Raises on failure so the caller can treat it as a merge failure, + preventing the task from being silently left open on the remote. + """ + provider = self.orch._outer_loop_manager.tasklist_provider + if not isinstance(provider, MCPTasklistProvider): + return + provider.update_task_status(task_id, TaskStatus.done) + def _make_eval_manager(self, repo_dir: Path, work_dir: Path) -> EvalManager: work_dir.mkdir(parents=True, exist_ok=True) return EvalManager( @@ -583,7 +687,10 @@ def _dry_run(self) -> int: base_ref = self.orch.base_ref or base_branch base_ref_sha = self._rev_parse(base_ref) - tasks = self.orch._tasklist_manager.extract_all_task_ids() + if self._is_mcp_provider(): + tasks = self._fetch_tasks_from_provider() + else: + tasks = self.orch._tasklist_manager.extract_all_task_ids() max_tasks = max(0, int(self.orch.max_tasks)) pending = [t for t in tasks if not t["checked"]][:max_tasks] @@ -622,6 +729,7 @@ def run(self) -> int: self.orch.parallel_integration_branch, base_ref_sha, ) + use_mcp = self._is_mcp_provider() merge_pipeline = MergePipeline( repo_dir=self.orch.repo_dir, integration_worktree=integration_wt, @@ -634,9 +742,12 @@ def run(self) -> int: loc_threshold=self.orch.loc_threshold, max_retries=self.orch.merge_max_retries, tasklist=self.orch.tasklist, + skip_tasklist_mark=use_mcp, ) - - tasks = self.orch._tasklist_manager.extract_all_task_ids() + if use_mcp: + tasks = self._fetch_tasks_from_provider() + else: + tasks = self.orch._tasklist_manager.extract_all_task_ids() taskmap = {t["task_id"]: {"index": t["index"]} for t in tasks} self.parallel_state.save_taskmap(taskmap) @@ -652,7 +763,10 @@ def run(self) -> int: poll_interval = 0.1 try: - enriched_tasks, dependencies = self._analyze_tasks(pending) + if use_mcp: + enriched_tasks, dependencies = self._analyze_tasks_mcp(pending) + else: + enriched_tasks, dependencies = self._analyze_tasks(pending) scheduler = TaskScheduler( concurrency=max(1, int(self.orch.parallel_concurrency)), high_risk_concurrency=max(1, int(self.orch.high_risk_concurrency)), @@ -790,6 +904,21 @@ def run(self) -> int: ) if merge_res.success: + if use_mcp: + try: + self._mark_mcp_task_done(task_id) + except Exception as exc: + reason = f"mcp_status_update_failed: {exc}" + scheduler.mark_failed(task_id, reason) + task_records[task_id] = { + "status": "failed", + "error": reason, + "completed_at": time.time(), + } + failures = True + in_flight_worktrees.pop(task_id, None) + in_flight_started_at.pop(task_id, None) + continue completed.add(task_id) scheduler.mark_completed(task_id) task_records[task_id] = { @@ -846,7 +975,7 @@ def run(self) -> int: if scheduler.has_remaining() and in_flight: time.sleep(poll_interval) - except ValueError as e: + except (ValueError, RuntimeError) as e: print(f"ERROR: {e}") failures = True finally: diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 2d9a36b..14b2df2 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -1446,3 +1446,356 @@ def spawn_worker(task_id: str, _task_text: str, _worktree_path: Path): assert state is not None assert state["task_records"]["t1"]["status"] == "failed" assert "worker failed" in state["task_records"]["t1"]["error"] + + +class TestMCPProviderWorktrees: + """Tests for MCP provider compatibility with worktree parallel mode.""" + + def test_fetch_tasks_from_mcp_provider(self, temp_repo): + """_fetch_tasks_from_provider converts TasklistItems to scheduler format.""" + from unittest.mock import MagicMock + + from millstone.artifact_providers.mcp import MCPTasklistProvider + from millstone.artifacts.models import TasklistItem, TaskStatus + + base_branch = _git(temp_repo, "rev-parse", "--abbrev-ref", "HEAD").strip() + orch = Orchestrator( + parallel_enabled=True, + base_branch=base_branch, + repo_dir=str(temp_repo), + ) + + mock_provider = MagicMock(spec=MCPTasklistProvider) + mock_provider._agent_callback = lambda p, **k: "" + mock_provider.list_tasks.return_value = [ + TasklistItem(task_id="issue-1", title="First task", status=TaskStatus.todo), + TasklistItem(task_id="issue-2", title="Second task", status=TaskStatus.done), + TasklistItem(task_id="issue-3", title="Third task", status=TaskStatus.todo), + ] + orch._outer_loop_manager.tasklist_provider = mock_provider + + po = ParallelOrchestrator(orch) + assert po._is_mcp_provider() + + tasks = po._fetch_tasks_from_provider() + assert len(tasks) == 3 + assert tasks[0] == { + "task_id": "issue-1", + "checked": False, + "title": "First task", + "raw_text": "", + "index": 0, + } + assert tasks[1]["checked"] is True # done -> checked + assert tasks[2]["checked"] is False # todo -> not checked + + def test_analyze_tasks_mcp_returns_no_dependencies(self, temp_repo): + """_analyze_tasks_mcp fetches full task body and returns enriched tasks + with empty dependency list.""" + from unittest.mock import MagicMock + + from millstone.artifact_providers.mcp import MCPTasklistProvider + from millstone.artifacts.models import TasklistItem, TaskStatus + + base_branch = _git(temp_repo, "rev-parse", "--abbrev-ref", "HEAD").strip() + orch = Orchestrator( + parallel_enabled=True, + base_branch=base_branch, + repo_dir=str(temp_repo), + ) + + mock_provider = MagicMock(spec=MCPTasklistProvider) + mock_provider.get_task.side_effect = [ + TasklistItem( + task_id="t1", + title="Task one", + status=TaskStatus.todo, + context="Do thing one", + risk="high", + ), + TasklistItem( + task_id="t2", + title="Task two", + status=TaskStatus.todo, + criteria="Must work", + ), + ] + orch._outer_loop_manager.tasklist_provider = mock_provider + + po = ParallelOrchestrator(orch) + + task_dicts = [ + {"task_id": "t1", "title": "Task one", "raw_text": "", "index": 0, "checked": False}, + {"task_id": "t2", "title": "Task two", "raw_text": "", "index": 1, "checked": False}, + ] + enriched, deps = po._analyze_tasks_mcp(task_dicts) + assert len(enriched) == 2 + assert deps == [] + assert enriched[0]["task_id"] == "t1" + assert enriched[0]["group"] is None + assert enriched[0]["file_refs"] == [] + # raw_text now contains full task body from get_task() + assert "Task one" in enriched[0]["raw_text"] + assert "Do thing one" in enriched[0]["raw_text"] + assert enriched[0]["risk"] == "high" + # Second task has criteria but no risk + assert "Must work" in enriched[1]["raw_text"] + assert enriched[1]["risk"] is None + assert mock_provider.get_task.call_count == 2 + + def test_dry_run_with_mcp_provider(self, temp_repo, capsys): + """--dry-run with MCP provider fetches tasks from remote backend.""" + from unittest.mock import MagicMock + + from millstone.artifact_providers.mcp import MCPTasklistProvider + from millstone.artifacts.models import TasklistItem, TaskStatus + + base_branch = _git(temp_repo, "rev-parse", "--abbrev-ref", "HEAD").strip() + orch = Orchestrator( + parallel_enabled=True, + dry_run=True, + base_branch=base_branch, + repo_dir=str(temp_repo), + ) + + mock_provider = MagicMock(spec=MCPTasklistProvider) + mock_provider._agent_callback = lambda p, **k: "" + mock_provider.list_tasks.return_value = [ + TasklistItem(task_id="gh-10", title="MCP task A", status=TaskStatus.todo), + TasklistItem(task_id="gh-11", title="MCP task B", status=TaskStatus.todo), + ] + orch._outer_loop_manager.tasklist_provider = mock_provider + + po = ParallelOrchestrator(orch) + rc = po.run() + assert rc == 0 + + captured = capsys.readouterr() + assert "tasks_pending: 2" in captured.out + assert "gh-10" in captured.out + assert "gh-11" in captured.out + mock_provider.list_tasks.assert_called_once() + + def test_run_with_mcp_provider(self, temp_repo): + """Full run() with MCP provider: fetches task body, runs real merge pipeline, + and marks task done via MCP provider.""" + from unittest.mock import MagicMock + + from millstone.artifact_providers.mcp import MCPTasklistProvider + from millstone.artifacts.models import TasklistItem, TaskStatus + + base_branch = _git(temp_repo, "rev-parse", "--abbrev-ref", "HEAD").strip() + + # Detach HEAD so base_branch is not checked out in a worktree. + subprocess.run( + ["git", "checkout", "--detach"], cwd=temp_repo, capture_output=True, check=True + ) + + mock_provider = MagicMock(spec=MCPTasklistProvider) + mock_provider._agent_callback = lambda p, **k: "" + mock_provider.list_tasks.return_value = [ + TasklistItem(task_id="mcp-1", title="Remote task", status=TaskStatus.todo), + ] + # get_task() returns full task details for worker body + mock_provider.get_task.return_value = TasklistItem( + task_id="mcp-1", + title="Remote task", + status=TaskStatus.todo, + context="Implement the remote feature", + criteria="All tests pass", + tests="test_remote.py", + risk="low", + ) + + orch = Orchestrator( + parallel_enabled=True, + base_branch=base_branch, + repo_dir=str(temp_repo), + merge_strategy="cherry-pick", + worktree_cleanup="always", + loc_threshold=1000, + ) + orch._outer_loop_manager.tasklist_provider = mock_provider + + dispatched_tasks: list[str] = [] + dispatched_texts: list[str] = [] + + def worker_runner(task_id: str, task_text: str, worktree_path: Path) -> dict: + dispatched_tasks.append(task_id) + dispatched_texts.append(task_text) + (worktree_path / "output.txt").write_text(f"{task_id}\n") + subprocess.run(["git", "add", "."], cwd=worktree_path, capture_output=True, check=True) + subprocess.run( + ["git", "commit", "-m", f"task {task_id}"], + cwd=worktree_path, + capture_output=True, + check=True, + ) + return { + "status": "success", + "commit_sha": _rev_parse(worktree_path, "HEAD"), + "risk": "low", + } + + po = ParallelOrchestrator( + orch, + worker_runner=worker_runner, + eval_manager_factory=_eval_factory(True), + ) + + rc = po.run() + assert rc == 0 + assert dispatched_tasks == ["mcp-1"] + mock_provider.list_tasks.assert_called_once() + + # Verify worker received full task body (not empty string) + assert len(dispatched_texts) == 1 + task_text = dispatched_texts[0] + assert "Remote task" in task_text + assert "Implement the remote feature" in task_text + assert "All tests pass" in task_text + mock_provider.get_task.assert_called_once_with("mcp-1") + + # Verify MCP provider was called to mark task done + mock_provider.update_task_status.assert_called_once_with("mcp-1", TaskStatus.done) + + def test_run_mcp_status_update_failure(self, temp_repo): + """When update_task_status() fails after merge, run treats it as a failure + to prevent the task being left open remotely and re-executed.""" + from unittest.mock import MagicMock + + from millstone.artifact_providers.mcp import MCPTasklistProvider + from millstone.artifacts.models import TasklistItem, TaskStatus + + base_branch = _git(temp_repo, "rev-parse", "--abbrev-ref", "HEAD").strip() + subprocess.run( + ["git", "checkout", "--detach"], cwd=temp_repo, capture_output=True, check=True + ) + + mock_provider = MagicMock(spec=MCPTasklistProvider) + mock_provider._agent_callback = lambda p, **k: "" + mock_provider.list_tasks.return_value = [ + TasklistItem(task_id="mcp-1", title="Remote task", status=TaskStatus.todo), + ] + mock_provider.get_task.return_value = TasklistItem( + task_id="mcp-1", + title="Remote task", + status=TaskStatus.todo, + context="ctx", + criteria="crit", + tests="t.py", + risk="low", + ) + mock_provider.update_task_status.side_effect = RuntimeError("MCP API unreachable") + + orch = Orchestrator( + parallel_enabled=True, + base_branch=base_branch, + repo_dir=str(temp_repo), + merge_strategy="cherry-pick", + worktree_cleanup="always", + loc_threshold=1000, + ) + orch._outer_loop_manager.tasklist_provider = mock_provider + + def worker_runner(task_id: str, task_text: str, worktree_path: Path) -> dict: + (worktree_path / "output.txt").write_text(f"{task_id}\n") + subprocess.run(["git", "add", "."], cwd=worktree_path, capture_output=True, check=True) + subprocess.run( + ["git", "commit", "-m", f"task {task_id}"], + cwd=worktree_path, + capture_output=True, + check=True, + ) + return { + "status": "success", + "commit_sha": _rev_parse(worktree_path, "HEAD"), + "risk": "low", + } + + po = ParallelOrchestrator( + orch, + worker_runner=worker_runner, + eval_manager_factory=_eval_factory(True), + ) + + rc = po.run() + assert rc == 1, "Run must fail when MCP status update fails" + mock_provider.update_task_status.assert_called_once_with("mcp-1", TaskStatus.done) + + def test_fetch_task_body_raises_on_none(self, temp_repo): + """_fetch_task_body raises RuntimeError when get_task() returns None.""" + from unittest.mock import MagicMock + + from millstone.artifact_providers.mcp import MCPTasklistProvider + + base_branch = _git(temp_repo, "rev-parse", "--abbrev-ref", "HEAD").strip() + mock_provider = MagicMock(spec=MCPTasklistProvider) + mock_provider._agent_callback = lambda p, **k: "" + mock_provider.get_task.return_value = None + + orch = Orchestrator( + parallel_enabled=True, + dry_run=True, + base_branch=base_branch, + repo_dir=str(temp_repo), + ) + orch._outer_loop_manager.tasklist_provider = mock_provider + po = ParallelOrchestrator(orch) + + with pytest.raises(RuntimeError, match="returned None"): + po._fetch_task_body("missing-task") + + def test_run_mcp_get_task_failure_aborts(self, temp_repo): + """When get_task() returns None during analysis, run() fails + instead of silently falling back to title-only worker input.""" + from unittest.mock import MagicMock + + from millstone.artifact_providers.mcp import MCPTasklistProvider + from millstone.artifacts.models import TasklistItem, TaskStatus + + base_branch = _git(temp_repo, "rev-parse", "--abbrev-ref", "HEAD").strip() + subprocess.run( + ["git", "checkout", "--detach"], cwd=temp_repo, capture_output=True, check=True + ) + + mock_provider = MagicMock(spec=MCPTasklistProvider) + mock_provider._agent_callback = lambda p, **k: "" + mock_provider.list_tasks.return_value = [ + TasklistItem(task_id="mcp-1", title="Task one", status=TaskStatus.todo), + ] + mock_provider.get_task.return_value = None # Simulates provider read failure + + orch = Orchestrator( + parallel_enabled=True, + base_branch=base_branch, + repo_dir=str(temp_repo), + merge_strategy="cherry-pick", + worktree_cleanup="always", + loc_threshold=1000, + ) + orch._outer_loop_manager.tasklist_provider = mock_provider + + po = ParallelOrchestrator( + orch, + worker_runner=lambda *_: {"status": "success"}, + eval_manager_factory=_eval_factory(True), + ) + + rc = po.run() + assert rc == 1, "Run must fail when get_task() returns None" + + def test_file_provider_unchanged(self, temp_repo): + """File-based provider path is completely unaffected by MCP changes.""" + base_branch = _git(temp_repo, "rev-parse", "--abbrev-ref", "HEAD").strip() + orch = Orchestrator( + parallel_enabled=True, + dry_run=True, + base_branch=base_branch, + repo_dir=str(temp_repo), + ) + po = ParallelOrchestrator(orch) + assert not po._is_mcp_provider() + # Should still work with file-based extraction (dry run doesn't need tasks) + rc = po.run() + assert rc == 0