diff --git a/gptdiff/applydiff.py b/gptdiff/applydiff.py
index 1d5bf62..090aa71 100644
--- a/gptdiff/applydiff.py
+++ b/gptdiff/applydiff.py
@@ -9,6 +9,21 @@
 import hashlib
 from collections import defaultdict
 
+
+def _strip_diff_fence(diff_text: str) -> str:
+    """Remove wrapping triple backticks from a diff block."""
+    stripped = diff_text.strip()
+    if stripped.startswith("```") and stripped.endswith("```"):
+        lines = stripped.splitlines()
+        if lines[0].startswith("```"):
+            lines = lines[1:]
+        if lines and lines[-1].startswith("```"):
+            lines = lines[:-1]
+        if lines and lines[0].strip().lower() == "diff":
+            lines = lines[1:]
+        return "\n".join(lines)
+    return diff_text
+
 def apply_diff(project_dir, diff_text):
     """
     Applies a unified diff (as generated by git diff) to the files in project_dir
@@ -112,6 +127,7 @@ def apply_patch_to_file(file_path, patch):
         return True
 
     # Parse the diff into per-file patches.
+    diff_text = _strip_diff_fence(diff_text)
     file_patches = parse_diff_per_file(diff_text)
     if not file_patches:
         print("No file patches found in diff.")
@@ -192,6 +208,8 @@ def parse_diff_per_file(diff_text):
     Uses 'b/' prefix detection from git diffs to determine target paths
     This doesn't work all the time and needs to be revised with stronger models
     """
+    diff_text = _strip_diff_fence(diff_text)
+
     def dedup_diffs(diffs):
         groups = defaultdict(list)
         for key, value in diffs:
diff --git a/tests/test_multidiff.py b/tests/test_multidiff.py
index a35a46b..96f2731 100644
--- a/tests/test_multidiff.py
+++ b/tests/test_multidiff.py
@@ -41,9 +41,15 @@ def test_fail_diff_through_call_llm(monkeypatch):
     def dummy_call_llm(api_key, base_url, model, messages, max_tokens, budget_tokens, temperature):
         return DummyResponse(diff_str, prompt_tokens=10, completion_tokens=20, total_tokens=30)
 
-    # Patch call_llm in the gptdiff module with our dummy function.
+    # Patch call_llm and tiktoken.get_encoding to avoid network access.
     monkeypatch.setattr("gptdiff.gptdiff.call_llm", dummy_call_llm)
 
+    class DummyEnc:
+        def encode(self, text):
+            return [0] * len(text)
+
+    monkeypatch.setattr("tiktoken.get_encoding", lambda name: DummyEnc())
+
     # generate_diff calls call_llm_for_diff internally, which now uses our dummy_call_llm.
     result = generate_diff("dummy environment", "dummy goal", model="test-model")
 
diff --git a/tests/test_parse_diff_per_file.py b/tests/test_parse_diff_per_file.py
index 3d41305..b69ce76 100644
--- a/tests/test_parse_diff_per_file.py
+++ b/tests/test_parse_diff_per_file.py
@@ -169,5 +169,22 @@ def test_parse_diff_per_file_unconventional_header():
     assert "+++ game.js" in patch, "Expected patch to include '+++ game.js'"
     assert "+let player" in patch, "Expected patch to include added lines"
 
+
+def test_parse_diff_with_code_fence():
+    diff_text = """```diff
+diff --git a/file.txt b/file.txt
+--- a/file.txt
++++ b/file.txt
+@@ -1 +1 @@
+-old
++new
+```"""
+
+    result = parse_diff_per_file(diff_text)
+    assert len(result) == 1
+    file_path, patch = result[0]
+    assert file_path == "file.txt"
+    assert "+new" in patch
+
 if __name__ == '__main__':
     unittest.main()