diff --git a/src/cortex-agents/src/mention.rs b/src/cortex-agents/src/mention.rs index 81d9d71c..59bb9550 100644 --- a/src/cortex-agents/src/mention.rs +++ b/src/cortex-agents/src/mention.rs @@ -17,6 +17,46 @@ use regex::Regex; use std::sync::LazyLock; +/// Safely get the string slice up to the given byte position. +/// +/// Returns the slice `&text[..pos]` if `pos` is at a valid UTF-8 character boundary. +/// If `pos` is inside a multi-byte character, finds the nearest valid boundary +/// by searching backwards. +fn safe_slice_up_to(text: &str, pos: usize) -> &str { + if pos >= text.len() { + return text; + } + if text.is_char_boundary(pos) { + return &text[..pos]; + } + // Find the nearest valid boundary by searching backwards + let mut valid_pos = pos; + while valid_pos > 0 && !text.is_char_boundary(valid_pos) { + valid_pos -= 1; + } + &text[..valid_pos] +} + +/// Safely get the string slice from the given byte position to the end. +/// +/// Returns the slice `&text[pos..]` if `pos` is at a valid UTF-8 character boundary. +/// If `pos` is inside a multi-byte character, finds the nearest valid boundary +/// by searching forwards. +fn safe_slice_from(text: &str, pos: usize) -> &str { + if pos >= text.len() { + return ""; + } + if text.is_char_boundary(pos) { + return &text[pos..]; + } + // Find the nearest valid boundary by searching forwards + let mut valid_pos = pos; + while valid_pos < text.len() && !text.is_char_boundary(valid_pos) { + valid_pos += 1; + } + &text[valid_pos..] +} + /// A parsed agent mention from user input. #[derive(Debug, Clone, PartialEq, Eq)] pub struct AgentMention { @@ -108,10 +148,10 @@ pub fn extract_mention_and_text( ) -> Option<(AgentMention, String)> { let mention = find_first_valid_mention(text, valid_agents)?; - // Remove the mention from text + // Remove the mention from text, using safe slicing for UTF-8 boundaries let mut remaining = String::with_capacity(text.len()); - remaining.push_str(&text[..mention.start]); - remaining.push_str(&text[mention.end..]); + remaining.push_str(safe_slice_up_to(text, mention.start)); + remaining.push_str(safe_slice_from(text, mention.end)); // Trim and normalize whitespace let remaining = remaining.trim().to_string(); @@ -123,7 +163,8 @@ pub fn extract_mention_and_text( pub fn starts_with_mention(text: &str, valid_agents: &[&str]) -> bool { let text = text.trim(); if let Some(mention) = find_first_valid_mention(text, valid_agents) { - mention.start == 0 || text[..mention.start].trim().is_empty() + // Use safe slicing to handle UTF-8 boundaries + mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty() } else { false } @@ -196,8 +237,8 @@ pub fn parse_message_for_agent(text: &str, valid_agents: &[&str]) -> ParsedAgent // Check if message starts with @agent if let Some((mention, remaining)) = extract_mention_and_text(text, valid_agents) { - // Only trigger if mention is at the start - if mention.start == 0 || text[..mention.start].trim().is_empty() { + // Only trigger if mention is at the start, using safe slicing for UTF-8 boundaries + if mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty() { return ParsedAgentMessage::for_agent(mention.agent_name, remaining, text.to_string()); } } @@ -318,4 +359,99 @@ mod tests { assert_eq!(mentions[0].agent_name, "my-agent"); assert_eq!(mentions[1].agent_name, "my_agent"); } + + // UTF-8 boundary safety tests + #[test] + fn test_safe_slice_up_to_ascii() { + let text = "hello world"; + assert_eq!(safe_slice_up_to(text, 5), "hello"); + assert_eq!(safe_slice_up_to(text, 0), ""); + assert_eq!(safe_slice_up_to(text, 100), "hello world"); + } + + #[test] + fn test_safe_slice_up_to_multibyte() { + // "こんにちは" - each character is 3 bytes + let text = "こんにちは"; + assert_eq!(safe_slice_up_to(text, 3), "こ"); // Valid boundary + assert_eq!(safe_slice_up_to(text, 6), "こん"); // Valid boundary + // Position 4 is inside the second character, should return "こ" + assert_eq!(safe_slice_up_to(text, 4), "こ"); + assert_eq!(safe_slice_up_to(text, 5), "こ"); + } + + #[test] + fn test_safe_slice_from_multibyte() { + let text = "こんにちは"; + assert_eq!(safe_slice_from(text, 3), "んにちは"); // Valid boundary + // Position 4 is inside second character, should skip to position 6 + assert_eq!(safe_slice_from(text, 4), "にちは"); + assert_eq!(safe_slice_from(text, 5), "にちは"); + } + + #[test] + fn test_extract_mention_with_multibyte_prefix() { + let valid = vec!["general"]; + + // Multi-byte characters before mention + let result = extract_mention_and_text("日本語 @general search files", &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + // The prefix should be preserved without panicking + assert!(remaining.contains("search files")); + } + + #[test] + fn test_starts_with_mention_multibyte() { + let valid = vec!["general"]; + + // Whitespace with multi-byte characters should not cause panic + assert!(starts_with_mention(" @general task", &valid)); + + // Multi-byte characters before mention - should return false, not panic + assert!(!starts_with_mention("日本語 @general task", &valid)); + } + + #[test] + fn test_parse_message_for_agent_multibyte() { + let valid = vec!["general"]; + + // Multi-byte prefix - should not panic + let parsed = parse_message_for_agent("日本語 @general find files", &valid); + // Since mention is not at the start, should not invoke task + assert!(!parsed.should_invoke_task); + + // Multi-byte in the prompt (after mention) + let parsed = parse_message_for_agent("@general 日本語を検索", &valid); + assert!(parsed.should_invoke_task); + assert_eq!(parsed.agent, Some("general".to_string())); + assert_eq!(parsed.prompt, "日本語を検索"); + } + + #[test] + fn test_extract_mention_with_emoji() { + let valid = vec!["general"]; + + // Emojis are 4 bytes each + let result = extract_mention_and_text("🎉 @general celebrate", &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + assert!(remaining.contains("celebrate")); + } + + #[test] + fn test_mixed_multibyte_and_ascii() { + let valid = vec!["general"]; + + // Mix of ASCII, CJK, and emoji + let text = "Hello 世界 🌍 @general search for 日本語"; + let result = extract_mention_and_text(text, &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + // Should not panic and produce valid output + assert!(!remaining.is_empty()); + } } diff --git a/src/cortex-cli/src/import_cmd.rs b/src/cortex-cli/src/import_cmd.rs index 696d93ae..38b25f86 100644 --- a/src/cortex-cli/src/import_cmd.rs +++ b/src/cortex-cli/src/import_cmd.rs @@ -357,31 +357,47 @@ fn validate_export_messages(messages: &[ExportMessage]) -> Result<()> { for (idx, message) in messages.iter().enumerate() { // Check for base64-encoded image data in content // Common pattern: "data:image/png;base64,..." or "data:image/jpeg;base64,..." - if let Some(data_uri_start) = message.content.find("data:image/") - && let Some(base64_marker) = message.content[data_uri_start..].find(";base64,") - { - let base64_start = data_uri_start + base64_marker + 8; // 8 = len(";base64,") - let remaining = &message.content[base64_start..]; - - // Find end of base64 data (could end with quote, whitespace, or end of string) - let base64_end = remaining - .find(['"', '\'', ' ', '\n', ')']) - .unwrap_or(remaining.len()); - let base64_data = &remaining[..base64_end]; - - // Validate the base64 data - if !base64_data.is_empty() { - let engine = base64::engine::general_purpose::STANDARD; - if let Err(e) = engine.decode(base64_data) { - bail!( - "Invalid base64 encoding in message {} (role: '{}'): {}\n\ - The image data starting at position {} has invalid base64 encoding.\n\ - Please ensure all embedded images use valid base64 encoding.", - idx + 1, - message.role, - e, - data_uri_start - ); + if let Some(data_uri_start) = message.content.find("data:image/") { + // Use safe slicing with .get() to avoid panics on multi-byte UTF-8 boundaries + let content_after_start = match message.content.get(data_uri_start..) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this message + }; + + if let Some(base64_marker) = content_after_start.find(";base64,") { + let base64_start = data_uri_start + base64_marker + 8; // 8 = len(";base64,") + + // Safe slicing for the remaining content after base64 marker + let remaining = match message.content.get(base64_start..) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this message + }; + + // Find end of base64 data (could end with quote, whitespace, or end of string) + let base64_end = remaining + .find(['"', '\'', ' ', '\n', ')']) + .unwrap_or(remaining.len()); + + // Safe slicing for the base64 data + let base64_data = match remaining.get(..base64_end) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this message + }; + + // Validate the base64 data + if !base64_data.is_empty() { + let engine = base64::engine::general_purpose::STANDARD; + if let Err(e) = engine.decode(base64_data) { + bail!( + "Invalid base64 encoding in message {} (role: '{}'): {}\n\ + The image data starting at position {} has invalid base64 encoding.\n\ + Please ensure all embedded images use valid base64 encoding.", + idx + 1, + message.role, + e, + data_uri_start + ); + } } } } @@ -395,13 +411,24 @@ fn validate_export_messages(messages: &[ExportMessage]) -> Result<()> { // Try to find and validate any base64 in the arguments for (pos, _) in args_str.match_indices(";base64,") { let base64_start = pos + 8; - let remaining = &args_str[base64_start..]; + + // Safe slicing for the remaining content after base64 marker + let remaining = match args_str.get(base64_start..) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this occurrence + }; + let base64_end = remaining .find(|c: char| { c == '"' || c == '\'' || c == ' ' || c == '\n' || c == ')' }) .unwrap_or(remaining.len()); - let base64_data = &remaining[..base64_end]; + + // Safe slicing for the base64 data + let base64_data = match remaining.get(..base64_end) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this occurrence + }; if !base64_data.is_empty() { let engine = base64::engine::general_purpose::STANDARD; diff --git a/src/cortex-cli/src/lock_cmd.rs b/src/cortex-cli/src/lock_cmd.rs index dc652cad..1caa3d3c 100644 --- a/src/cortex-cli/src/lock_cmd.rs +++ b/src/cortex-cli/src/lock_cmd.rs @@ -114,6 +114,15 @@ fn validate_session_id(session_id: &str) -> Result<()> { ) } +/// Safely get a string prefix by character count, not byte count. +/// This avoids panics on multi-byte UTF-8 characters. +fn safe_char_prefix(s: &str, max_chars: usize) -> &str { + match s.char_indices().nth(max_chars) { + Some((byte_idx, _)) => &s[..byte_idx], + None => s, // String has fewer than max_chars characters + } +} + /// Get the lock file path. fn get_lock_file_path() -> PathBuf { dirs::home_dir() @@ -156,7 +165,7 @@ pub fn is_session_locked(session_id: &str) -> bool { match load_lock_file() { Ok(lock_file) => lock_file.locked_sessions.iter().any(|entry| { entry.session_id == session_id - || session_id.starts_with(&entry.session_id[..8.min(entry.session_id.len())]) + || session_id.starts_with(safe_char_prefix(&entry.session_id, 8)) }), Err(_) => false, } @@ -308,7 +317,7 @@ async fn run_list(args: LockListArgs) -> Result<()> { println!("{}", "-".repeat(60)); for entry in &lock_file.locked_sessions { - let short_id = &entry.session_id[..8.min(entry.session_id.len())]; + let short_id = safe_char_prefix(&entry.session_id, 8); println!(" {} - locked at {}", short_id, entry.locked_at); if let Some(ref reason) = entry.reason { println!(" Reason: {}", reason); @@ -332,7 +341,7 @@ async fn run_check(args: LockCheckArgs) -> Result<()> { e.session_id == args.session_id || args .session_id - .starts_with(&e.session_id[..8.min(e.session_id.len())]) + .starts_with(safe_char_prefix(&e.session_id, 8)) }); if is_locked { @@ -342,7 +351,7 @@ async fn run_check(args: LockCheckArgs) -> Result<()> { e.session_id == args.session_id || args .session_id - .starts_with(&e.session_id[..8.min(e.session_id.len())]) + .starts_with(safe_char_prefix(&e.session_id, 8)) }) && let Some(ref reason) = entry.reason { println!("Reason: {}", reason); @@ -508,4 +517,39 @@ mod tests { let path_str = path.to_string_lossy(); assert!(path_str.contains(".cortex")); } + + #[test] + fn test_safe_char_prefix_ascii() { + // ASCII strings should work correctly + assert_eq!(safe_char_prefix("abcdefghij", 8), "abcdefgh"); + assert_eq!(safe_char_prefix("abc", 8), "abc"); + assert_eq!(safe_char_prefix("", 8), ""); + assert_eq!(safe_char_prefix("12345678", 8), "12345678"); + } + + #[test] + fn test_safe_char_prefix_utf8_multibyte() { + // Multi-byte UTF-8 characters should not panic + // Each emoji is 4 bytes, so 8 chars = 32 bytes + let emoji_id = "🔥🎉🚀💡🌟✨🎯🔮extra"; + assert_eq!(safe_char_prefix(emoji_id, 8), "🔥🎉🚀💡🌟✨🎯🔮"); + + // Mixed ASCII and multi-byte + let mixed = "ab🔥cd🎉ef"; + assert_eq!(safe_char_prefix(mixed, 4), "ab🔥c"); + assert_eq!(safe_char_prefix(mixed, 8), "ab🔥cd🎉ef"); + + // Chinese characters (3 bytes each) + let chinese = "中文测试会话标识符"; + assert_eq!(safe_char_prefix(chinese, 4), "中文测试"); + } + + #[test] + fn test_safe_char_prefix_boundary() { + // Edge cases + assert_eq!(safe_char_prefix("a", 0), ""); + assert_eq!(safe_char_prefix("a", 1), "a"); + assert_eq!(safe_char_prefix("🔥", 1), "🔥"); + assert_eq!(safe_char_prefix("🔥", 0), ""); + } } diff --git a/src/cortex-cli/src/utils/notification.rs b/src/cortex-cli/src/utils/notification.rs index 4656e223..8edd2c93 100644 --- a/src/cortex-cli/src/utils/notification.rs +++ b/src/cortex-cli/src/utils/notification.rs @@ -63,7 +63,14 @@ pub fn send_task_notification(session_id: &str, success: bool) -> Result<()> { "Cortex Task Failed" }; - let short_id = &session_id[..8.min(session_id.len())]; + // Use safe UTF-8 slicing - find the last valid char boundary at or before position 8 + let short_id = session_id + .char_indices() + .take_while(|(idx, _)| *idx < 8) + .map(|(idx, ch)| idx + ch.len_utf8()) + .last() + .and_then(|end| session_id.get(..end)) + .unwrap_or(session_id); let body = format!("Session: {}", short_id); let urgency = if success { diff --git a/src/cortex-resume/src/resume_picker.rs b/src/cortex-resume/src/resume_picker.rs index 9bf8832a..7b0ee9a6 100644 --- a/src/cortex-resume/src/resume_picker.rs +++ b/src/cortex-resume/src/resume_picker.rs @@ -153,12 +153,15 @@ fn format_relative_time(time: &chrono::DateTime) -> String { } } -/// Truncate string to fit width. +/// Truncate string to fit width, handling multi-byte UTF-8 safely. fn truncate_string(s: &str, width: usize) -> String { - if s.len() <= width { + // Count actual character width, not byte length + let char_count = s.chars().count(); + if char_count <= width { s.to_string() } else if width > 3 { - format!("{}...", &s[..width - 3]) + let truncated: String = s.chars().take(width - 3).collect(); + format!("{}...", truncated) } else { s.chars().take(width).collect() } @@ -176,4 +179,40 @@ mod tests { let hour_ago = now - chrono::Duration::hours(2); assert_eq!(format_relative_time(&hour_ago), "2h ago"); } + + #[test] + fn test_truncate_string_ascii() { + // Short string, no truncation needed + assert_eq!(truncate_string("hello", 10), "hello"); + + // Exact fit + assert_eq!(truncate_string("hello", 5), "hello"); + + // Needs truncation + assert_eq!(truncate_string("hello world", 8), "hello..."); + + // Very short width + assert_eq!(truncate_string("hello", 3), "hel"); + assert_eq!(truncate_string("hello", 2), "he"); + } + + #[test] + fn test_truncate_string_utf8() { + // UTF-8 multi-byte characters (Japanese) + let japanese = "こんにちは世界"; // 7 chars + assert_eq!(truncate_string(japanese, 10), japanese); // No truncation + assert_eq!(truncate_string(japanese, 7), japanese); // Exact fit + assert_eq!(truncate_string(japanese, 6), "こんに..."); // Truncated (3 chars + ...) + + // UTF-8 with emoji + let emoji = "Hello 🌍🌎🌏"; // 9 chars: H,e,l,l,o, ,🌍,🌎,🌏 + assert_eq!(truncate_string(emoji, 20), emoji); // No truncation + assert_eq!(truncate_string(emoji, 9), emoji); // Exact fit (9 chars) + assert_eq!(truncate_string(emoji, 8), "Hello..."); // Truncated (5 chars + ...) + + // Mixed UTF-8 and ASCII + let mixed = "路径/path/文件"; // 11 chars + assert_eq!(truncate_string(mixed, 20), mixed); // No truncation + assert_eq!(truncate_string(mixed, 8), "路径/pa..."); // Truncated + } }