diff --git a/src/cli.rs b/src/cli.rs index a35fecb..e304780 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,11 +1,15 @@ use std::env; use std::error::Error; +use std::path::PathBuf; use env_logger::Builder as LoggerBuilder; #[derive(Debug)] pub enum CliAction { - Run { log_level: Option }, + Run { + log_level: Option, + workdir: Option, + }, Help, Version, } @@ -15,6 +19,7 @@ where I: Iterator, { let mut log_level = None; + let mut workdir = None; let mut iter = args.peekable(); while let Some(arg) = iter.next() { @@ -29,6 +34,16 @@ where .next() .ok_or_else(|| "--log-level requires a value".to_string())?; log_level = Some(value); + } else if let Some(path) = arg.strip_prefix("--workdir=") { + if path.is_empty() { + return Err("--workdir requires a value".to_string()); + } + workdir = Some(PathBuf::from(path)); + } else if arg == "--workdir" { + let value = iter + .next() + .ok_or_else(|| "--workdir requires a value".to_string())?; + workdir = Some(PathBuf::from(value)); } else { return Err(format!("Unknown argument: {arg}")); } @@ -36,11 +51,11 @@ where } } - Ok(CliAction::Run { log_level }) + Ok(CliAction::Run { log_level, workdir }) } pub fn print_usage() { - println!("Usage: codex-tools-mcp [OPTIONS]\n\nOptions:\n --log-level Override default log level (info)\n -V, --version Print version information\n -h, --help Print this help message"); + println!("Usage: codex-tools-mcp [OPTIONS]\n\nOptions:\n --log-level Override default log level (info)\n --workdir Set process working directory before serving\n -V, --version Print version information\n -h, --help Print this help message"); } pub fn init_logging(log_level: Option) -> Result<(), Box> { diff --git a/src/main.rs b/src/main.rs index 5aae337..dc36412 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ mod tools; use cli::{init_logging, parse_cli, print_usage, CliAction}; use std::env; use std::error::Error; +use std::io; use std::process; fn main() { @@ -26,7 +27,18 @@ fn try_main() -> Result<(), Box> { println!("{}", cli::version_string()); Ok(()) } - CliAction::Run { log_level } => { + CliAction::Run { log_level, workdir } => { + if let Some(workdir) = workdir { + env::set_current_dir(&workdir).map_err(|err| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!( + "Failed to set working directory to {}: {err}", + workdir.display() + ), + ) + })?; + } init_logging(log_level)?; server::run_server()?; Ok(()) diff --git a/tests/integration.rs b/tests/integration.rs index 38334d6..a12a77c 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -54,3 +54,47 @@ fn applies_patch_and_creates_file() { let contents = std::fs::read_to_string(&hello_path).expect("hello.txt created"); assert_eq!(contents.trim(), "hello world!"); } + +#[test] +fn applies_patch_in_explicit_workdir() { + let launch_dir = tempdir().expect("create launch temp dir"); + let work_dir = tempdir().expect("create work temp dir"); + + let mut cmd = Command::cargo_bin("codex-tools-mcp").expect("binary exists"); + cmd.arg("--log-level").arg("error"); + cmd.arg("--workdir").arg(work_dir.path()); + cmd.current_dir(launch_dir.path()); + + let input = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","clientInfo":{"name":"test","version":"0"},"capabilities":{}}} +{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"apply_patch","arguments":{"input":"*** Begin Patch\n*** Add File: hello.txt\n+hello workdir!\n*** End Patch\n"}}} +"#; + let mut child = cmd + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .expect("spawn server"); + + if let Some(stdin) = child.stdin.as_mut() { + use std::io::Write; + let _ = stdin.write_all(input.as_bytes()); + } + + let output = child.wait_with_output().expect("collect output"); + assert!( + output.status.success(), + "process exited with failure.\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + let hello_path = work_dir.path().join("hello.txt"); + let contents = std::fs::read_to_string(&hello_path).expect("hello.txt created in --workdir"); + assert_eq!(contents.trim(), "hello workdir!"); + + let wrong_path = launch_dir.path().join("hello.txt"); + assert!( + !wrong_path.exists(), + "hello.txt should not be created in process cwd" + ); +}