From 05eb1c72476ec3bcfe3f9b92dfa29c81407a9849 Mon Sep 17 00:00:00 2001 From: Abhinav Rao <39939017+AetherPrior@users.noreply.github.com> Date: Thu, 19 Feb 2026 16:32:53 +0000 Subject: [PATCH] Add handling for Parquet files and better test-case capturing for Java --- executor_docker/docker/.gitignore | 2 + .../docker/juliet-java-env/Dockerfile | 4 +- .../juliet-java-env/compile-and-test-patch.sh | 93 +++++++++-- .../juliet-java-env/compile-and-test.sh | 93 +++++++++-- .../docker/juliet-java-env/dataset/.gitkeep | 1 + executor_docker/server/server_utils.py | 154 ++++++++++++++++-- 6 files changed, 310 insertions(+), 37 deletions(-) create mode 100644 executor_docker/docker/.gitignore create mode 100644 executor_docker/docker/juliet-java-env/dataset/.gitkeep diff --git a/executor_docker/docker/.gitignore b/executor_docker/docker/.gitignore new file mode 100644 index 0000000..1d42889 --- /dev/null +++ b/executor_docker/docker/.gitignore @@ -0,0 +1,2 @@ +juliet-java-env/dataset/* +!juliet-java-env/dataset/.gitkeep \ No newline at end of file diff --git a/executor_docker/docker/juliet-java-env/Dockerfile b/executor_docker/docker/juliet-java-env/Dockerfile index 31e468a..63d6e93 100644 --- a/executor_docker/docker/juliet-java-env/Dockerfile +++ b/executor_docker/docker/juliet-java-env/Dockerfile @@ -1,4 +1,4 @@ -FROM openjdk:17-jdk-slim +FROM eclipse-temurin:17-jdk-jammy # install necessary packages RUN apt-get update && apt-get install -y \ @@ -38,7 +38,7 @@ RUN chmod +x /usr/local/bin/*.sh # environment variables ENV MAVEN_OPTS="-Xmx512m" ENV JAVA_TOOL_OPTIONS="-Xmx512m" -ENV JAVA_HOME=/usr/local/openjdk-17 +ENV JAVA_HOME=/opt/java/openjdk ENV PATH="$JAVA_HOME/bin:$PATH" RUN mkdir -p /tmp/java-eval diff --git a/executor_docker/docker/juliet-java-env/compile-and-test-patch.sh b/executor_docker/docker/juliet-java-env/compile-and-test-patch.sh index f2eb67f..e4c6fae 100644 --- a/executor_docker/docker/juliet-java-env/compile-and-test-patch.sh +++ b/executor_docker/docker/juliet-java-env/compile-and-test-patch.sh @@ -6,7 +6,7 @@ set -e # Set JAVA_HOME environment variable -export JAVA_HOME=/usr/local/openjdk-17 +export JAVA_HOME=${JAVA_HOME:-/opt/java/openjdk} export PATH="$JAVA_HOME/bin:$PATH" # Set Maven local repository to a temporary location @@ -119,16 +119,85 @@ content = re.sub(r'(private\s+[\w<>\[\],\s]+\s+\w+\s*\([^)]*\))(?!\s*throws)(\s* # 5. Handle any other public void methods in test files (catch-all) content = re.sub(r'(public\s+void\s+\w+\s*\([^)]*\))(?!\s*throws)(\s*\{)', r'\1 throws Throwable\2', content) -# Handle lambda expressions that call methods throwing Throwable -# Wrap lambda bodies in try-catch blocks if they contain method calls -if 'captureStdOut(() -> {' in content: - # Find all lambda expressions and wrap their content in try-catch - content = re.sub( - r'(captureStdOut\(\(\) -> \{)(.*?)(\}\))', - r'\1\n try {\2\n } catch (Throwable t) {\n throw new RuntimeException(t);\n }\3', - content, - flags=re.DOTALL - ) +# Handle captureStdOut lambdas that may call methods throwing Throwable +def wrap_capture_stdout_lambdas(src: str) -> str: + marker = 'captureStdOut(() ->' + out = [] + cursor = 0 + + while True: + idx = src.find(marker, cursor) + if idx == -1: + out.append(src[cursor:]) + break + + out.append(src[cursor:idx]) + + start_paren = idx + len('captureStdOut') + if start_paren >= len(src) or src[start_paren] != '(': + out.append(src[idx:idx + len(marker)]) + cursor = idx + len(marker) + continue + + depth = 1 + pos = start_paren + 1 + while pos < len(src) and depth > 0: + if src[pos] == '(': + depth += 1 + elif src[pos] == ')': + depth -= 1 + pos += 1 + + if depth != 0: + out.append(src[idx:]) + break + + close_paren = pos - 1 + inside = src[start_paren + 1:close_paren] + m = re.match(r'\s*\(\s*\)\s*->\s*(.*)\s*$', inside, flags=re.DOTALL) + if not m: + out.append(src[idx:close_paren + 1]) + cursor = close_paren + 1 + continue + + rhs = m.group(1).strip() + + if 'catch (Throwable t)' in rhs: + out.append(src[idx:close_paren + 1]) + cursor = close_paren + 1 + continue + + if rhs.startswith('{') and rhs.endswith('}'): + body = rhs[1:-1].strip('\n') + wrapped = ( + 'captureStdOut(() -> {\n' + ' try {\n' + f'{body}\n' + ' } catch (Throwable t) {\n' + ' throw new RuntimeException(t);\n' + ' }\n' + ' })' + ) + else: + statement = rhs + if not statement.endswith(';'): + statement += ';' + wrapped = ( + 'captureStdOut(() -> {\n' + ' try {\n' + f' {statement}\n' + ' } catch (Throwable t) {\n' + ' throw new RuntimeException(t);\n' + ' }\n' + ' })' + ) + + out.append(wrapped) + cursor = close_paren + 1 + + return ''.join(out) + +content = wrap_capture_stdout_lambdas(content) # Fix static method calls - if test is calling ClassName.methodName(), create instance instead # Extract class name from the test file name instead of template file @@ -178,8 +247,10 @@ echo "✓ Test compilation successful" # Run tests echo "=== Running tests ===" +set +e TEST_OUTPUT=$(timeout $TIMEOUT_DURATION mvn test 2>&1) TEST_EXIT_CODE=$? +set -e # Output the test results for debugging echo "$TEST_OUTPUT" diff --git a/executor_docker/docker/juliet-java-env/compile-and-test.sh b/executor_docker/docker/juliet-java-env/compile-and-test.sh index 65a5be5..f968031 100644 --- a/executor_docker/docker/juliet-java-env/compile-and-test.sh +++ b/executor_docker/docker/juliet-java-env/compile-and-test.sh @@ -6,7 +6,7 @@ set -e # Set JAVA_HOME environment variable -export JAVA_HOME=/usr/local/openjdk-17 +export JAVA_HOME=${JAVA_HOME:-/opt/java/openjdk} export PATH="$JAVA_HOME/bin:$PATH" # Set Maven local repository to a temporary location @@ -150,16 +150,85 @@ content = re.sub(r'(private\s+[\w<>\[\],\s]+\s+\w+\s*\([^)]*\))(?!\s*throws)(\s* # 5. Handle any other public void methods in test files (catch-all) content = re.sub(r'(public\s+void\s+\w+\s*\([^)]*\))(?!\s*throws)(\s*\{)', r'\1 throws Throwable\2', content) -# Handle lambda expressions that call methods throwing Throwable -# Wrap lambda bodies in try-catch blocks if they contain method calls -if 'captureStdOut(() -> {' in content: - # Find all lambda expressions and wrap their content in try-catch - content = re.sub( - r'(captureStdOut\(\(\) -> \{)(.*?)(\}\))', - r'\1\n try {\2\n } catch (Throwable t) {\n throw new RuntimeException(t);\n }\3', - content, - flags=re.DOTALL - ) +# Handle captureStdOut lambdas that may call methods throwing Throwable +def wrap_capture_stdout_lambdas(src: str) -> str: + marker = 'captureStdOut(() ->' + out = [] + cursor = 0 + + while True: + idx = src.find(marker, cursor) + if idx == -1: + out.append(src[cursor:]) + break + + out.append(src[cursor:idx]) + + start_paren = idx + len('captureStdOut') + if start_paren >= len(src) or src[start_paren] != '(': + out.append(src[idx:idx + len(marker)]) + cursor = idx + len(marker) + continue + + depth = 1 + pos = start_paren + 1 + while pos < len(src) and depth > 0: + if src[pos] == '(': + depth += 1 + elif src[pos] == ')': + depth -= 1 + pos += 1 + + if depth != 0: + out.append(src[idx:]) + break + + close_paren = pos - 1 + inside = src[start_paren + 1:close_paren] + m = re.match(r'\s*\(\s*\)\s*->\s*(.*)\s*$', inside, flags=re.DOTALL) + if not m: + out.append(src[idx:close_paren + 1]) + cursor = close_paren + 1 + continue + + rhs = m.group(1).strip() + + if 'catch (Throwable t)' in rhs: + out.append(src[idx:close_paren + 1]) + cursor = close_paren + 1 + continue + + if rhs.startswith('{') and rhs.endswith('}'): + body = rhs[1:-1].strip('\n') + wrapped = ( + 'captureStdOut(() -> {\n' + ' try {\n' + f'{body}\n' + ' } catch (Throwable t) {\n' + ' throw new RuntimeException(t);\n' + ' }\n' + ' })' + ) + else: + statement = rhs + if not statement.endswith(';'): + statement += ';' + wrapped = ( + 'captureStdOut(() -> {\n' + ' try {\n' + f' {statement}\n' + ' } catch (Throwable t) {\n' + ' throw new RuntimeException(t);\n' + ' }\n' + ' })' + ) + + out.append(wrapped) + cursor = close_paren + 1 + + return ''.join(out) + +content = wrap_capture_stdout_lambdas(content) # Fix static method calls - if test is calling ClassName.methodName(), create instance instead # Extract class name from the test file name instead of template file @@ -209,8 +278,10 @@ echo "✓ Test compilation successful" # Run tests echo "=== Running tests ===" +set +e TEST_OUTPUT=$(timeout $TIMEOUT_DURATION mvn test 2>&1) TEST_EXIT_CODE=$? +set -e # Output the test results for debugging echo "$TEST_OUTPUT" diff --git a/executor_docker/docker/juliet-java-env/dataset/.gitkeep b/executor_docker/docker/juliet-java-env/dataset/.gitkeep new file mode 100644 index 0000000..78f5a3f --- /dev/null +++ b/executor_docker/docker/juliet-java-env/dataset/.gitkeep @@ -0,0 +1 @@ +# dataset files go in this folder \ No newline at end of file diff --git a/executor_docker/server/server_utils.py b/executor_docker/server/server_utils.py index c6479d3..4219892 100644 --- a/executor_docker/server/server_utils.py +++ b/executor_docker/server/server_utils.py @@ -1,11 +1,14 @@ import base64 +import json from enum import IntEnum +from functools import lru_cache from pathlib import Path from typing import Literal from uuid import uuid4 import docker import requests +from datasets import load_dataset from docker.errors import DockerException from fastapi import HTTPException @@ -27,6 +30,60 @@ class CustomExitCode(IntEnum): CUSTOM_ERROR_MESSAGES = {CustomExitCode.Timeout: "Timeout waiting for the program"} +JULIET_JAVA_DATASET_PATH = ( + Path(__file__).resolve().parents[1] / "docker" / "juliet-java-env" / "dataset" +) + + +@lru_cache(maxsize=1) +def _load_juliet_java_secure_dataset(): + return load_dataset(str(JULIET_JAVA_DATASET_PATH), split="java_secure_coding") + + +@lru_cache(maxsize=1) +def _load_juliet_java_patch_dataset(): + return load_dataset(str(JULIET_JAVA_DATASET_PATH), split="java_patch_generation") + + +@lru_cache(maxsize=1) +def _build_juliet_java_secure_index(): + dataset = _load_juliet_java_secure_dataset() + return {dataset[i]["id"]: i for i in range(len(dataset))} + + +@lru_cache(maxsize=1) +def _build_juliet_java_patch_index(): + dataset = _load_juliet_java_patch_dataset() + return {dataset[i]["id"]: i for i in range(len(dataset))} + + +def _extract_unit_test(meta_data): + if isinstance(meta_data, str): + meta_data = json.loads(meta_data) + if isinstance(meta_data, dict): + return meta_data.get("unit_test", "") + return "" + + +def _get_juliet_java_secure_task_artifacts(task_id: str): + index = _build_juliet_java_secure_index().get(task_id) + if index is None: + raise FileNotFoundError(f"Task id not found in java_secure_coding split: {task_id}") + row = _load_juliet_java_secure_dataset()[index] + template_code = row.get("context", "") + unit_test = _extract_unit_test(row.get("meta_data", "")) + return template_code, unit_test + + +def _get_juliet_java_patch_task_test(task_id: str): + index = _build_juliet_java_patch_index().get(task_id) + if index is None: + raise FileNotFoundError(f"Task id not found in java_patch_generation split: {task_id}") + row = _load_juliet_java_patch_dataset()[index] + unit_test = _extract_unit_test(row.get("meta_data", "")) + return unit_test + + def _post_process_result(res: dict): if res["exit_code"] == CustomExitCode.Timeout: res["output"] = CUSTOM_ERROR_MESSAGES[CustomExitCode.Timeout] @@ -107,6 +164,14 @@ def run_juliet_java_container( solution_code = solution_path.read_text() print(f"[DEBUG] Solution code length: {len(solution_code)} characters") + use_legacy_dataset_files = True + try: + template_code, unit_test_code = _get_juliet_java_secure_task_artifacts(task_id) + if template_code and unit_test_code: + use_legacy_dataset_files = False + except Exception as e: + print(f"[DEBUG] Failed loading parquet task artifacts, trying legacy files: {e}") + # Use the same Docker testing logic as our original implementation client = docker.from_env() container = None @@ -117,12 +182,28 @@ def run_juliet_java_container( "ascii" ) - # Run Docker command using juliet-java-local image - cmd = [ - "bash", - "-c", - f"echo '{encoded_solution}' | base64 -d > /workspace/solution.java && cd /workspace && /usr/local/bin/compile-and-test.sh {masked_file} {test_file} solution.java", - ] + if use_legacy_dataset_files: + cmd = [ + "bash", + "-c", + f"echo '{encoded_solution}' | base64 -d > /workspace/solution.java && cd /workspace && /usr/local/bin/compile-and-test.sh {masked_file} {test_file} solution.java", + ] + else: + encoded_template = base64.b64encode(template_code.encode("utf-8")).decode("ascii") + encoded_test = base64.b64encode(unit_test_code.encode("utf-8")).decode("ascii") + cmd = [ + "bash", + "-c", + " && ".join( + [ + f"echo '{encoded_template}' | base64 -d > /workspace/template.java", + f"echo '{encoded_test}' | base64 -d > /workspace/test.java", + f"echo '{encoded_solution}' | base64 -d > /workspace/solution.java", + "cd /workspace", + "/usr/local/bin/compile-and-test.sh template.java test.java solution.java", + ] + ), + ] container = client.containers.run( image=image, @@ -140,13 +221,23 @@ def run_juliet_java_container( docker_output = b"".join(out) except requests.exceptions.ReadTimeout: + print("[DEBUG] Java test timed out while waiting for Docker API response") + return CustomExitCode.Timeout, b"Timeout waiting for Java test" + except requests.exceptions.RequestException as e: + # Docker SDK sometimes wraps API read timeouts as generic request exceptions. + if "Read timed out" in str(e): + print(f"[DEBUG] Java test request timeout: {e}") + return CustomExitCode.Timeout, b"Timeout waiting for Java test" raise HTTPException( - status_code=500, detail="Timeout waiting for Java test" + status_code=500, detail=f"Unexpected Java test request error: {e}" ) from None except DockerException as e: print(f"[DEBUG] Docker error: {str(e)}") return CustomExitCode.Timeout, str(e).encode("utf-8") except Exception as e: + if "Read timed out" in str(e): + print(f"[DEBUG] Java test timeout from generic exception: {e}") + return CustomExitCode.Timeout, b"Timeout waiting for Java test" raise HTTPException( status_code=500, detail=f"Unexpected Java test error: {e}" ) from None @@ -214,6 +305,14 @@ def run_juliet_java_patch_container( patched_code = solution_path.read_text() print(f"[DEBUG] Patched code length: {len(patched_code)} characters") + use_legacy_dataset_files = True + try: + unit_test_code = _get_juliet_java_patch_task_test(task_id) + if unit_test_code: + use_legacy_dataset_files = False + except Exception as e: + print(f"[DEBUG] Failed loading parquet patch artifacts, trying legacy files: {e}") + # Use Docker to test the complete patched file client = docker.from_env() container = None @@ -224,12 +323,26 @@ def run_juliet_java_patch_container( "ascii" ) - # Run Docker command using compile-and-test-patch.sh script - cmd = [ - "bash", - "-c", - f"echo '{encoded_patched}' | base64 -d > /workspace/patched.java && cd /workspace && /usr/local/bin/compile-and-test-patch.sh {test_file} patched.java", - ] + if use_legacy_dataset_files: + cmd = [ + "bash", + "-c", + f"echo '{encoded_patched}' | base64 -d > /workspace/patched.java && cd /workspace && /usr/local/bin/compile-and-test-patch.sh {test_file} patched.java", + ] + else: + encoded_test = base64.b64encode(unit_test_code.encode("utf-8")).decode("ascii") + cmd = [ + "bash", + "-c", + " && ".join( + [ + f"echo '{encoded_test}' | base64 -d > /workspace/test.java", + f"echo '{encoded_patched}' | base64 -d > /workspace/patched.java", + "cd /workspace", + "/usr/local/bin/compile-and-test-patch.sh test.java patched.java", + ] + ), + ] container = client.containers.run( image=image, @@ -247,9 +360,24 @@ def run_juliet_java_patch_container( return exit_code, b"".join(outputs) + except requests.exceptions.ReadTimeout: + print("[DEBUG] Java patch test timed out while waiting for Docker API response") + return CustomExitCode.Timeout, b"Timeout waiting for Java patch test" + except requests.exceptions.RequestException as e: + if "Read timed out" in str(e): + print(f"[DEBUG] Java patch test request timeout: {e}") + return CustomExitCode.Timeout, b"Timeout waiting for Java patch test" + print(f"[DEBUG] Java patch request error: {e}") + return 1, f"Unexpected Java patch request error: {e}".encode("utf-8") except DockerException as e: print(f"[DEBUG] Docker error in patch container: {str(e)}") return CustomExitCode.Timeout, str(e).encode("utf-8") + except Exception as e: + if "Read timed out" in str(e): + print(f"[DEBUG] Java patch timeout from generic exception: {e}") + return CustomExitCode.Timeout, b"Timeout waiting for Java patch test" + print(f"[DEBUG] Unexpected Java patch test error: {e}") + return 1, f"Unexpected Java patch test error: {e}".encode("utf-8") finally: # Clean up the container if it exists if container: