diff --git a/Cargo.lock b/Cargo.lock index a2db172..e60ca8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5239,7 +5239,3 @@ name = "zmij" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02aae0f83f69aafc94776e879363e9771d7ecbffe2c7fbb6c14c5e00dfe88439" - -[[patch.unused]] -name = "ratatui-core" -version = "0.1.0-beta.0" diff --git a/src/prompts.rs b/src/prompts.rs index 7be3ee5..e4ec4a0 100644 --- a/src/prompts.rs +++ b/src/prompts.rs @@ -16,6 +16,9 @@ const ESH_SCRIPT: &str = include_str!("../lib/esh/esh"); /// Filename for custom system prompt additions pub const SYSTEM_MD_FILENAME: &str = "SYSTEM.md"; +/// Filename for correction memory +pub const CORRECTIONS_FILENAME: &str = "corrections.md"; + /// Welcome message shown when the application starts pub const WELCOME_MESSAGE: &str = "Welcome to Codey! I'm your AI coding assistant. How can I help you today?"; @@ -36,6 +39,7 @@ You have access to the following tools: - `spawn_agent`: Spawn a sub-agent to handle a subtask - `list_agents` / `get_agent`: Check status and retrieve results from sub-agents - `list_background_tasks` / `get_background_task`: Check on background tool executions +- `record_correction`: Record a correction when a command fails and you find a better approach ## Guidelines @@ -70,6 +74,12 @@ For long-running operations, you can execute tools in the background by adding ` Use `list_background_tasks` to check status and `get_background_task` to retrieve results when complete. You will be notified with a message when background tasks finish. +### Recording Corrections +When a shell command or approach fails and you find a better way to accomplish the goal, use `record_correction` to save this knowledge. This helps avoid repeating the same mistakes in future sessions. For example: +- A command that doesn't exist on this system but has an alternative +- A path that was wrong but you found the correct one +- A syntax that didn't work but another did + ### General - Be concise but thorough - Explain what you're doing before executing tools @@ -118,6 +128,7 @@ You have read-only access to: /// 1. The base system prompt (static) /// 2. User SYSTEM.md from ~/.config/codey/ (optional, dynamic) /// 3. Project SYSTEM.md from .codey/ (optional, dynamic) +/// 4. Project corrections.md from .codey/ (optional, contains learned corrections) /// /// SYSTEM.md files are processed through [esh](https://github.com/jirutka/esh), /// allowing embedded shell commands using `<%= command %>` syntax. @@ -125,6 +136,7 @@ You have read-only access to: pub struct SystemPrompt { user_path: Option, project_path: PathBuf, + corrections_path: PathBuf, } impl SystemPrompt { @@ -132,10 +144,12 @@ impl SystemPrompt { pub fn new() -> Self { let user_path = Config::config_dir().map(|d| d.join(SYSTEM_MD_FILENAME)); let project_path = Path::new(CODEY_DIR).join(SYSTEM_MD_FILENAME); + let corrections_path = Path::new(CODEY_DIR).join(CORRECTIONS_FILENAME); Self { user_path, project_path, + corrections_path, } } @@ -146,6 +160,7 @@ impl SystemPrompt { /// - Base system prompt /// - User SYSTEM.md content (if exists) /// - Project SYSTEM.md content (if exists) + /// - Project corrections.md content (if exists) pub fn build(&self) -> String { let mut prompt = SYSTEM_PROMPT.to_string(); @@ -163,9 +178,27 @@ impl SystemPrompt { prompt.push_str(&content); } + // Append corrections.md if it exists (learned corrections from previous sessions) + if let Some(content) = self.load_corrections() { + prompt.push_str("\n\n## Learned Corrections\n\n"); + prompt.push_str("The following corrections were learned from previous sessions. "); + prompt.push_str("Use this knowledge to avoid repeating the same mistakes:\n\n"); + prompt.push_str(&content); + } + prompt } + /// Load corrections from the corrections.md file. + fn load_corrections(&self) -> Option { + if !self.corrections_path.exists() { + return None; + } + fs::read_to_string(&self.corrections_path) + .ok() + .filter(|s| !s.is_empty()) + } + /// Load and process a SYSTEM.md file through esh, falling back to raw content. fn load_system_md(&self, path: &Path) -> Option { if !path.exists() { diff --git a/src/tools/handlers.rs b/src/tools/handlers.rs index 8e613f4..c017eec 100644 --- a/src/tools/handlers.rs +++ b/src/tools/handlers.rs @@ -443,3 +443,92 @@ impl EffectHandler for FetchHtml { } } } + +// ============================================================================= +// Correction memory handlers +// ============================================================================= + +use crate::config::CODEY_DIR; +use chrono::Local; + +/// Filename for storing corrections +pub const CORRECTIONS_FILENAME: &str = "corrections.md"; + +/// Append a correction to the corrections file +pub struct AppendCorrection { + pub goal: String, + pub failed_attempt: String, + pub successful_approach: String, +} + +#[async_trait::async_trait] +impl EffectHandler for AppendCorrection { + async fn call(self: Box) -> Step { + let codey_dir = PathBuf::from(CODEY_DIR); + + // Create .codey directory if it doesn't exist + if !codey_dir.exists() { + if let Err(e) = fs::create_dir_all(&codey_dir) { + return Step::Error(format!( + "Failed to create {} directory: {}", + CODEY_DIR, e + )); + } + } + + let corrections_path = codey_dir.join(CORRECTIONS_FILENAME); + let timestamp = Local::now().format("%Y-%m-%d %H:%M"); + + // Format the correction entry + let entry = format!( + "\n## Correction ({timestamp})\n\n\ + **Goal:** {}\n\n\ + **Failed approach:** `{}`\n\n\ + **Successful approach:** `{}`\n\n\ + ---\n", + self.goal, self.failed_attempt, self.successful_approach + ); + + // Check if file exists to add header for new files + let content = if corrections_path.exists() { + entry + } else { + format!( + "# Corrections\n\n\ + This file contains corrections learned during previous sessions.\n\n\ + ---\n{}", + entry + ) + }; + + // Append to the file + use std::io::Write; + let mut file = match fs::OpenOptions::new() + .create(true) + .append(true) + .open(&corrections_path) + { + Ok(f) => f, + Err(e) => { + return Step::Error(format!( + "Failed to open {}: {}", + corrections_path.display(), + e + )) + } + }; + + if let Err(e) = file.write_all(content.as_bytes()) { + return Step::Error(format!( + "Failed to write correction to {}: {}", + corrections_path.display(), + e + )); + } + + Step::Output(format!( + "Correction recorded in {}", + corrections_path.display() + )) + } +} diff --git a/src/tools/impls/mod.rs b/src/tools/impls/mod.rs index 0738c98..7cad7b2 100644 --- a/src/tools/impls/mod.rs +++ b/src/tools/impls/mod.rs @@ -5,6 +5,7 @@ mod fetch_html; mod fetch_url; mod open_file; mod read_file; +mod record_correction; mod shell; mod spawn_agent; mod web_search; @@ -20,6 +21,7 @@ pub use fetch_html::FetchHtmlTool; pub use fetch_url::FetchUrlTool; pub use open_file::OpenFileTool; pub use read_file::ReadFileTool; +pub use record_correction::RecordCorrectionTool; pub use shell::ShellTool; pub use spawn_agent::{init_agent_context, update_agent_oauth, SpawnAgentTool}; pub use web_search::WebSearchTool; diff --git a/src/tools/impls/record_correction.rs b/src/tools/impls/record_correction.rs new file mode 100644 index 0000000..158a967 --- /dev/null +++ b/src/tools/impls/record_correction.rs @@ -0,0 +1,204 @@ +//! Record correction tool +//! +//! Allows the agent to record corrections when a shell command fails and +//! a subsequent approach succeeds. These corrections are stored in +//! `.codey/corrections.md` and loaded into the system prompt to help +//! the agent avoid repeating the same mistakes. + +use super::{handlers, Tool, ToolPipeline}; +use crate::impl_tool_block; +use crate::transcript::{ + render_agent_label, render_prefix, render_result, Block, BlockType, Status, ToolBlock, +}; +use ratatui::{ + style::{Color, Style}, + text::{Line, Span}, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +/// Record correction display block +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RecordCorrectionBlock { + pub call_id: String, + pub tool_name: String, + pub params: serde_json::Value, + pub status: Status, + pub text: String, + #[serde(default)] + pub background: bool, + /// Agent label for sub-agent tools + #[serde(default, skip_serializing_if = "Option::is_none")] + pub agent_label: Option, +} + +impl RecordCorrectionBlock { + pub fn new( + call_id: impl Into, + tool_name: impl Into, + params: serde_json::Value, + background: bool, + ) -> Self { + Self { + call_id: call_id.into(), + tool_name: tool_name.into(), + params, + status: Status::Pending, + text: String::new(), + background, + agent_label: None, + } + } + + pub fn from_params( + call_id: &str, + tool_name: &str, + params: serde_json::Value, + background: bool, + ) -> Option { + let _: RecordCorrectionParams = serde_json::from_value(params.clone()).ok()?; + Some(Self::new(call_id, tool_name, params, background)) + } +} + +#[typetag::serde] +impl Block for RecordCorrectionBlock { + impl_tool_block!(BlockType::Tool); + + fn render(&self, _width: u16) -> Vec> { + let mut lines = Vec::new(); + + let goal = self + .params["goal"] + .as_str() + .unwrap_or("") + .chars() + .take(40) + .collect::(); + + // Format: [agent_label] record_correction(goal preview...) + lines.push(Line::from(vec![ + self.render_status(), + render_agent_label(self.agent_label.as_deref()), + render_prefix(self.background), + Span::styled("record_correction", Style::default().fg(Color::Magenta)), + Span::styled("(", Style::default().fg(Color::DarkGray)), + Span::styled( + if goal.len() == 40 { + format!("{}...", goal) + } else { + goal + }, + Style::default().fg(Color::Green), + ), + Span::styled(")", Style::default().fg(Color::DarkGray)), + ])); + + if !self.text.is_empty() { + lines.extend(render_result(&self.text, 3)); + } + + lines + } + + fn call_id(&self) -> Option<&str> { + Some(&self.call_id) + } + + fn tool_name(&self) -> Option<&str> { + Some(&self.tool_name) + } + + fn params(&self) -> Option<&serde_json::Value> { + Some(&self.params) + } + + fn set_agent_label(&mut self, label: String) { + self.agent_label = Some(label); + } + + fn agent_label(&self) -> Option<&str> { + self.agent_label.as_deref() + } +} + +/// Tool for recording corrections when shell commands fail and a better approach is found +pub struct RecordCorrectionTool; + +#[derive(Debug, Deserialize)] +struct RecordCorrectionParams { + /// What the agent was trying to accomplish + goal: String, + /// The command or approach that failed + failed_attempt: String, + /// The command or approach that succeeded + successful_approach: String, +} + +impl RecordCorrectionTool { + pub const NAME: &'static str = "mcp_record_correction"; +} + +impl Tool for RecordCorrectionTool { + fn name(&self) -> &'static str { + Self::NAME + } + + fn description(&self) -> &'static str { + "Record a correction when a shell command or approach fails and you find a better way. \ + This helps avoid repeating the same mistakes in future sessions." + } + + fn schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "goal": { + "type": "string", + "description": "Brief description of what you were trying to accomplish (1-2 sentences)" + }, + "failed_attempt": { + "type": "string", + "description": "The command or approach that didn't work" + }, + "successful_approach": { + "type": "string", + "description": "The command or approach that worked instead" + } + }, + "required": ["goal", "failed_attempt", "successful_approach"] + }) + } + + fn compose(&self, params: serde_json::Value) -> ToolPipeline { + let parsed: Result = serde_json::from_value(params.clone()); + let params = match parsed { + Ok(p) => p, + Err(e) => { + return ToolPipeline::error(format!("Invalid params: {}", e)); + } + }; + + ToolPipeline::new() + .then(handlers::AppendCorrection { + goal: params.goal, + failed_attempt: params.failed_attempt, + successful_approach: params.successful_approach, + }) + } + + fn create_block( + &self, + call_id: &str, + params: serde_json::Value, + background: bool, + ) -> Box { + if let Some(block) = + RecordCorrectionBlock::from_params(call_id, self.name(), params.clone(), background) + { + Box::new(block) + } else { + Box::new(ToolBlock::new(call_id, self.name(), params, background)) + } + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 99494d2..2b830ca 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -32,6 +32,7 @@ pub mod names { pub const GET_BACKGROUND_TASK: &str = "mcp_get_background_task"; pub const LIST_AGENTS: &str = "mcp_list_agents"; pub const GET_AGENT: &str = "mcp_get_agent"; + pub const RECORD_CORRECTION: &str = "mcp_record_correction"; } use std::collections::HashMap; @@ -43,7 +44,7 @@ pub use exec::{ToolCall, ToolDecision, ToolEvent, ToolExecutor}; pub use impls::{ init_agent_context, update_agent_oauth, EditFileTool, FetchHtmlTool, FetchUrlTool, GetAgentTool, GetBackgroundTaskTool, ListAgentsTool, ListBackgroundTasksTool, OpenFileTool, - ReadFileTool, ShellTool, SpawnAgentTool, WebSearchTool, WriteFileTool, + ReadFileTool, RecordCorrectionTool, ShellTool, SpawnAgentTool, WebSearchTool, WriteFileTool, }; #[cfg(feature = "cli")] pub use browser::init_browser_context; @@ -153,6 +154,7 @@ impl ToolRegistry { registry.register(Arc::new(GetBackgroundTaskTool)); registry.register(Arc::new(ListAgentsTool)); registry.register(Arc::new(GetAgentTool)); + registry.register(Arc::new(RecordCorrectionTool)); registry }