From 535de7bbbf2e12f7e08fa7eb87bc53db259f4a28 Mon Sep 17 00:00:00 2001 From: echobt Date: Wed, 4 Feb 2026 17:00:24 +0000 Subject: [PATCH] fix: consolidated bug fixes and security improvements This PR consolidates all bug fixes and security improvements from PRs #69-88 into a single cohesive change. ## Categories ### Security Fixes - Path traversal prevention in MCP and session storage - Shell injection prevention in restore scripts - Secure random temp files for external editor - TOCTOU race condition fixes ### TUI Improvements - Overflow prevention for u16 conversions - Cursor positioning fixes in selection lists - Unicode width handling for popups - Empty section handling in help browser ### Error Handling - Graceful semaphore and init failure handling - Improved error propagation in middleware - Better client access error handling - SystemTime operation safety ### Memory and Storage - Cache size limits to prevent unbounded growth - File lock cleanup for memory leak prevention - fsync after critical writes for durability - Bounded ToolResponseStore with automatic cleanup ### Protocol Robustness - Buffer size limits for StreamProcessor - ToolState transition validation - State machine documentation ### Numeric Safety - Saturating operations to prevent overflow/underflow - Safe UTF-8 string slicing throughout codebase ### Tools - Parameter alias support for backward compatibility - Handler name consistency fixes ## Files Modified Multiple files across cortex-tui, cortex-engine, cortex-exec, cortex-common, cortex-protocol, cortex-storage, cortex-mcp-server, and other crates. 
Closes #69, #70, #71, #73, #75, #80, #82, #87, #88 --- Cargo.lock | 1 + src/cortex-agents/src/mention.rs | 148 ++++- src/cortex-app-server/src/auth.rs | 6 +- src/cortex-app-server/src/config.rs | 10 + src/cortex-app-server/src/middleware.rs | 3 +- src/cortex-app-server/src/storage.rs | 3 - src/cortex-apply-patch/src/hunk.rs | 11 - src/cortex-cli/src/import_cmd.rs | 81 ++- src/cortex-cli/src/lock_cmd.rs | 52 +- src/cortex-cli/src/utils/notification.rs | 9 +- src/cortex-cli/src/utils/paths.rs | 65 ++- src/cortex-common/src/file_locking.rs | 35 ++ src/cortex-common/src/http_client.rs | 48 ++ src/cortex-compact/src/compactor.rs | 5 +- src/cortex-engine/src/async_utils.rs | 57 +- .../src/config/config_discovery.rs | 28 +- src/cortex-engine/src/git_info.rs | 16 +- src/cortex-engine/src/ratelimit.rs | 14 +- src/cortex-engine/src/streaming.rs | 29 +- src/cortex-engine/src/tokenizer.rs | 22 +- .../src/tools/handlers/file_ops.rs | 4 +- src/cortex-engine/src/tools/handlers/glob.rs | 1 + src/cortex-engine/src/tools/handlers/grep.rs | 1 + src/cortex-engine/src/tools/mod.rs | 6 + src/cortex-engine/src/tools/response_store.rs | 537 ++++++++++++++++++ src/cortex-engine/src/validation.rs | 134 ++++- src/cortex-exec/src/runner.rs | 41 +- src/cortex-mcp-client/src/transport.rs | 103 +--- src/cortex-mcp-server/src/server.rs | 15 +- src/cortex-plugins/src/registry.rs | 21 +- src/cortex-protocol/Cargo.toml | 1 + .../src/protocol/message_parts.rs | 48 ++ src/cortex-resume/src/resume_picker.rs | 45 +- src/cortex-resume/src/session_store.rs | 19 + src/cortex-shell-snapshot/src/snapshot.rs | 52 +- src/cortex-storage/src/sessions/storage.rs | 53 +- src/cortex-tui-components/src/dropdown.rs | 4 +- src/cortex-tui-components/src/scroll.rs | 6 +- .../src/selection_list.rs | 8 +- src/cortex-tui/Cargo.toml | 1 + src/cortex-tui/src/cards/commands.rs | 5 +- src/cortex-tui/src/cards/models.rs | 5 +- src/cortex-tui/src/cards/sessions.rs | 4 +- src/cortex-tui/src/external_editor.rs | 45 +- 
src/cortex-tui/src/interactive/renderer.rs | 8 +- src/cortex-tui/src/mcp_storage.rs | 82 ++- src/cortex-tui/src/session/storage.rs | 130 ++++- src/cortex-tui/src/widgets/autocomplete.rs | 19 +- .../src/widgets/help_browser/render.rs | 4 +- .../src/widgets/help_browser/state.rs | 9 +- .../src/widgets/help_browser/tests.rs | 22 +- src/cortex-tui/src/widgets/mention_popup.rs | 11 +- .../src/widgets/scrollable_dropdown.rs | 19 +- 53 files changed, 1804 insertions(+), 302 deletions(-) create mode 100644 src/cortex-engine/src/tools/response_store.rs diff --git a/Cargo.lock b/Cargo.lock index de5594c..130925f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1502,6 +1502,7 @@ dependencies = [ "serde_json", "strum_macros", "tempfile", + "tracing", "uuid", ] diff --git a/src/cortex-agents/src/mention.rs b/src/cortex-agents/src/mention.rs index 81d9d71..59bb955 100644 --- a/src/cortex-agents/src/mention.rs +++ b/src/cortex-agents/src/mention.rs @@ -17,6 +17,46 @@ use regex::Regex; use std::sync::LazyLock; +/// Safely get the string slice up to the given byte position. +/// +/// Returns the slice `&text[..pos]` if `pos` is at a valid UTF-8 character boundary. +/// If `pos` is inside a multi-byte character, finds the nearest valid boundary +/// by searching backwards. +fn safe_slice_up_to(text: &str, pos: usize) -> &str { + if pos >= text.len() { + return text; + } + if text.is_char_boundary(pos) { + return &text[..pos]; + } + // Find the nearest valid boundary by searching backwards + let mut valid_pos = pos; + while valid_pos > 0 && !text.is_char_boundary(valid_pos) { + valid_pos -= 1; + } + &text[..valid_pos] +} + +/// Safely get the string slice from the given byte position to the end. +/// +/// Returns the slice `&text[pos..]` if `pos` is at a valid UTF-8 character boundary. +/// If `pos` is inside a multi-byte character, finds the nearest valid boundary +/// by searching forwards. 
+fn safe_slice_from(text: &str, pos: usize) -> &str { + if pos >= text.len() { + return ""; + } + if text.is_char_boundary(pos) { + return &text[pos..]; + } + // Find the nearest valid boundary by searching forwards + let mut valid_pos = pos; + while valid_pos < text.len() && !text.is_char_boundary(valid_pos) { + valid_pos += 1; + } + &text[valid_pos..] +} + /// A parsed agent mention from user input. #[derive(Debug, Clone, PartialEq, Eq)] pub struct AgentMention { @@ -108,10 +148,10 @@ pub fn extract_mention_and_text( ) -> Option<(AgentMention, String)> { let mention = find_first_valid_mention(text, valid_agents)?; - // Remove the mention from text + // Remove the mention from text, using safe slicing for UTF-8 boundaries let mut remaining = String::with_capacity(text.len()); - remaining.push_str(&text[..mention.start]); - remaining.push_str(&text[mention.end..]); + remaining.push_str(safe_slice_up_to(text, mention.start)); + remaining.push_str(safe_slice_from(text, mention.end)); // Trim and normalize whitespace let remaining = remaining.trim().to_string(); @@ -123,7 +163,8 @@ pub fn extract_mention_and_text( pub fn starts_with_mention(text: &str, valid_agents: &[&str]) -> bool { let text = text.trim(); if let Some(mention) = find_first_valid_mention(text, valid_agents) { - mention.start == 0 || text[..mention.start].trim().is_empty() + // Use safe slicing to handle UTF-8 boundaries + mention.start == 0 || safe_slice_up_to(text, mention.start).trim().is_empty() } else { false } @@ -196,8 +237,8 @@ pub fn parse_message_for_agent(text: &str, valid_agents: &[&str]) -> ParsedAgent // Check if message starts with @agent if let Some((mention, remaining)) = extract_mention_and_text(text, valid_agents) { - // Only trigger if mention is at the start - if mention.start == 0 || text[..mention.start].trim().is_empty() { + // Only trigger if mention is at the start, using safe slicing for UTF-8 boundaries + if mention.start == 0 || safe_slice_up_to(text, 
mention.start).trim().is_empty() { return ParsedAgentMessage::for_agent(mention.agent_name, remaining, text.to_string()); } } @@ -318,4 +359,99 @@ mod tests { assert_eq!(mentions[0].agent_name, "my-agent"); assert_eq!(mentions[1].agent_name, "my_agent"); } + + // UTF-8 boundary safety tests + #[test] + fn test_safe_slice_up_to_ascii() { + let text = "hello world"; + assert_eq!(safe_slice_up_to(text, 5), "hello"); + assert_eq!(safe_slice_up_to(text, 0), ""); + assert_eq!(safe_slice_up_to(text, 100), "hello world"); + } + + #[test] + fn test_safe_slice_up_to_multibyte() { + // "こんにちは" - each character is 3 bytes + let text = "こんにちは"; + assert_eq!(safe_slice_up_to(text, 3), "こ"); // Valid boundary + assert_eq!(safe_slice_up_to(text, 6), "こん"); // Valid boundary + // Position 4 is inside the second character, should return "こ" + assert_eq!(safe_slice_up_to(text, 4), "こ"); + assert_eq!(safe_slice_up_to(text, 5), "こ"); + } + + #[test] + fn test_safe_slice_from_multibyte() { + let text = "こんにちは"; + assert_eq!(safe_slice_from(text, 3), "んにちは"); // Valid boundary + // Position 4 is inside second character, should skip to position 6 + assert_eq!(safe_slice_from(text, 4), "にちは"); + assert_eq!(safe_slice_from(text, 5), "にちは"); + } + + #[test] + fn test_extract_mention_with_multibyte_prefix() { + let valid = vec!["general"]; + + // Multi-byte characters before mention + let result = extract_mention_and_text("日本語 @general search files", &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + // The prefix should be preserved without panicking + assert!(remaining.contains("search files")); + } + + #[test] + fn test_starts_with_mention_multibyte() { + let valid = vec!["general"]; + + // Whitespace with multi-byte characters should not cause panic + assert!(starts_with_mention(" @general task", &valid)); + + // Multi-byte characters before mention - should return false, not panic + 
assert!(!starts_with_mention("日本語 @general task", &valid)); + } + + #[test] + fn test_parse_message_for_agent_multibyte() { + let valid = vec!["general"]; + + // Multi-byte prefix - should not panic + let parsed = parse_message_for_agent("日本語 @general find files", &valid); + // Since mention is not at the start, should not invoke task + assert!(!parsed.should_invoke_task); + + // Multi-byte in the prompt (after mention) + let parsed = parse_message_for_agent("@general 日本語を検索", &valid); + assert!(parsed.should_invoke_task); + assert_eq!(parsed.agent, Some("general".to_string())); + assert_eq!(parsed.prompt, "日本語を検索"); + } + + #[test] + fn test_extract_mention_with_emoji() { + let valid = vec!["general"]; + + // Emojis are 4 bytes each + let result = extract_mention_and_text("🎉 @general celebrate", &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + assert!(remaining.contains("celebrate")); + } + + #[test] + fn test_mixed_multibyte_and_ascii() { + let valid = vec!["general"]; + + // Mix of ASCII, CJK, and emoji + let text = "Hello 世界 🌍 @general search for 日本語"; + let result = extract_mention_and_text(text, &valid); + assert!(result.is_some()); + let (mention, remaining) = result.unwrap(); + assert_eq!(mention.agent_name, "general"); + // Should not panic and produce valid output + assert!(!remaining.is_empty()); + } } diff --git a/src/cortex-app-server/src/auth.rs b/src/cortex-app-server/src/auth.rs index 414f36f..4f240c3 100644 --- a/src/cortex-app-server/src/auth.rs +++ b/src/cortex-app-server/src/auth.rs @@ -45,7 +45,7 @@ impl Claims { pub fn new(user_id: impl Into, expiry_seconds: u64) -> Self { let now = SystemTime::now() .duration_since(UNIX_EPOCH) - .unwrap() + .unwrap_or_default() .as_secs(); Self { @@ -75,7 +75,7 @@ impl Claims { pub fn is_expired(&self) -> bool { let now = SystemTime::now() .duration_since(UNIX_EPOCH) - .unwrap() + .unwrap_or_default() .as_secs(); self.exp < now 
} @@ -187,7 +187,7 @@ impl AuthService { pub async fn cleanup_revoked_tokens(&self) { let now = SystemTime::now() .duration_since(UNIX_EPOCH) - .unwrap() + .unwrap_or_default() .as_secs(); let mut revoked = self.revoked_tokens.write().await; diff --git a/src/cortex-app-server/src/config.rs b/src/cortex-app-server/src/config.rs index 35ac75b..92be050 100644 --- a/src/cortex-app-server/src/config.rs +++ b/src/cortex-app-server/src/config.rs @@ -49,12 +49,18 @@ pub struct ServerConfig { pub max_body_size: usize, /// Request timeout in seconds (applies to full request lifecycle). + /// + /// See `cortex_common::http_client` module documentation for the complete + /// timeout hierarchy across Cortex services. #[serde(default = "default_request_timeout")] pub request_timeout: u64, /// Read timeout for individual chunks in seconds. /// Applies to chunked transfer encoding to prevent indefinite hangs /// when clients disconnect without sending the terminal chunk. + /// + /// See `cortex_common::http_client` module documentation for the complete + /// timeout hierarchy across Cortex services. #[serde(default = "default_read_timeout")] pub read_timeout: u64, @@ -71,12 +77,16 @@ pub struct ServerConfig { pub cors_origins: Vec, /// Graceful shutdown timeout in seconds. + /// + /// See `cortex_common::http_client` module documentation for the complete + /// timeout hierarchy across Cortex services. 
#[serde(default = "default_shutdown_timeout")] pub shutdown_timeout: u64, } fn default_shutdown_timeout() -> u64 { 30 // 30 seconds for graceful shutdown + // See cortex_common::http_client for timeout hierarchy documentation } fn default_listen_addr() -> String { diff --git a/src/cortex-app-server/src/middleware.rs b/src/cortex-app-server/src/middleware.rs index a997157..45d4406 100644 --- a/src/cortex-app-server/src/middleware.rs +++ b/src/cortex-app-server/src/middleware.rs @@ -40,7 +40,8 @@ pub async fn request_id_middleware(mut request: Request, next: Next) -> Response let mut response = next.run(request).await; response.headers_mut().insert( REQUEST_ID_HEADER, - HeaderValue::from_str(&request_id).unwrap(), + HeaderValue::from_str(&request_id) + .unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")), ); response diff --git a/src/cortex-app-server/src/storage.rs b/src/cortex-app-server/src/storage.rs index 6c5d44e..1aa617f 100644 --- a/src/cortex-app-server/src/storage.rs +++ b/src/cortex-app-server/src/storage.rs @@ -47,8 +47,6 @@ pub struct StoredToolCall { /// Session storage manager. pub struct SessionStorage { - #[allow(dead_code)] - base_dir: PathBuf, sessions_dir: PathBuf, history_dir: PathBuf, } @@ -66,7 +64,6 @@ impl SessionStorage { info!("Session storage initialized at {:?}", base_dir); Ok(Self { - base_dir, sessions_dir, history_dir, }) diff --git a/src/cortex-apply-patch/src/hunk.rs b/src/cortex-apply-patch/src/hunk.rs index ea67a97..ab5b1f1 100644 --- a/src/cortex-apply-patch/src/hunk.rs +++ b/src/cortex-apply-patch/src/hunk.rs @@ -250,9 +250,6 @@ pub struct SearchReplace { pub search: String, /// The text to replace with. pub replace: String, - /// Replace all occurrences (true) or just the first (false). 
- #[allow(dead_code)] - pub replace_all: bool, } impl SearchReplace { @@ -266,16 +263,8 @@ impl SearchReplace { path: path.into(), search: search.into(), replace: replace.into(), - replace_all: false, } } - - /// Set whether to replace all occurrences. - #[allow(dead_code)] - pub fn with_replace_all(mut self, replace_all: bool) -> Self { - self.replace_all = replace_all; - self - } } #[cfg(test)] diff --git a/src/cortex-cli/src/import_cmd.rs b/src/cortex-cli/src/import_cmd.rs index 696d93a..38b25f8 100644 --- a/src/cortex-cli/src/import_cmd.rs +++ b/src/cortex-cli/src/import_cmd.rs @@ -357,31 +357,47 @@ fn validate_export_messages(messages: &[ExportMessage]) -> Result<()> { for (idx, message) in messages.iter().enumerate() { // Check for base64-encoded image data in content // Common pattern: "data:image/png;base64,..." or "data:image/jpeg;base64,..." - if let Some(data_uri_start) = message.content.find("data:image/") - && let Some(base64_marker) = message.content[data_uri_start..].find(";base64,") - { - let base64_start = data_uri_start + base64_marker + 8; // 8 = len(";base64,") - let remaining = &message.content[base64_start..]; - - // Find end of base64 data (could end with quote, whitespace, or end of string) - let base64_end = remaining - .find(['"', '\'', ' ', '\n', ')']) - .unwrap_or(remaining.len()); - let base64_data = &remaining[..base64_end]; - - // Validate the base64 data - if !base64_data.is_empty() { - let engine = base64::engine::general_purpose::STANDARD; - if let Err(e) = engine.decode(base64_data) { - bail!( - "Invalid base64 encoding in message {} (role: '{}'): {}\n\ - The image data starting at position {} has invalid base64 encoding.\n\ - Please ensure all embedded images use valid base64 encoding.", - idx + 1, - message.role, - e, - data_uri_start - ); + if let Some(data_uri_start) = message.content.find("data:image/") { + // Use safe slicing with .get() to avoid panics on multi-byte UTF-8 boundaries + let content_after_start = match 
message.content.get(data_uri_start..) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this message + }; + + if let Some(base64_marker) = content_after_start.find(";base64,") { + let base64_start = data_uri_start + base64_marker + 8; // 8 = len(";base64,") + + // Safe slicing for the remaining content after base64 marker + let remaining = match message.content.get(base64_start..) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this message + }; + + // Find end of base64 data (could end with quote, whitespace, or end of string) + let base64_end = remaining + .find(['"', '\'', ' ', '\n', ')']) + .unwrap_or(remaining.len()); + + // Safe slicing for the base64 data + let base64_data = match remaining.get(..base64_end) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this message + }; + + // Validate the base64 data + if !base64_data.is_empty() { + let engine = base64::engine::general_purpose::STANDARD; + if let Err(e) = engine.decode(base64_data) { + bail!( + "Invalid base64 encoding in message {} (role: '{}'): {}\n\ + The image data starting at position {} has invalid base64 encoding.\n\ + Please ensure all embedded images use valid base64 encoding.", + idx + 1, + message.role, + e, + data_uri_start + ); + } } } } @@ -395,13 +411,24 @@ fn validate_export_messages(messages: &[ExportMessage]) -> Result<()> { // Try to find and validate any base64 in the arguments for (pos, _) in args_str.match_indices(";base64,") { let base64_start = pos + 8; - let remaining = &args_str[base64_start..]; + + // Safe slicing for the remaining content after base64 marker + let remaining = match args_str.get(base64_start..) 
{ + Some(s) => s, + None => continue, // Invalid byte offset, skip this occurrence + }; + let base64_end = remaining .find(|c: char| { c == '"' || c == '\'' || c == ' ' || c == '\n' || c == ')' }) .unwrap_or(remaining.len()); - let base64_data = &remaining[..base64_end]; + + // Safe slicing for the base64 data + let base64_data = match remaining.get(..base64_end) { + Some(s) => s, + None => continue, // Invalid byte offset, skip this occurrence + }; if !base64_data.is_empty() { let engine = base64::engine::general_purpose::STANDARD; diff --git a/src/cortex-cli/src/lock_cmd.rs b/src/cortex-cli/src/lock_cmd.rs index dc652ca..1caa3d3 100644 --- a/src/cortex-cli/src/lock_cmd.rs +++ b/src/cortex-cli/src/lock_cmd.rs @@ -114,6 +114,15 @@ fn validate_session_id(session_id: &str) -> Result<()> { ) } +/// Safely get a string prefix by character count, not byte count. +/// This avoids panics on multi-byte UTF-8 characters. +fn safe_char_prefix(s: &str, max_chars: usize) -> &str { + match s.char_indices().nth(max_chars) { + Some((byte_idx, _)) => &s[..byte_idx], + None => s, // String has fewer than max_chars characters + } +} + /// Get the lock file path. 
fn get_lock_file_path() -> PathBuf { dirs::home_dir() @@ -156,7 +165,7 @@ pub fn is_session_locked(session_id: &str) -> bool { match load_lock_file() { Ok(lock_file) => lock_file.locked_sessions.iter().any(|entry| { entry.session_id == session_id - || session_id.starts_with(&entry.session_id[..8.min(entry.session_id.len())]) + || session_id.starts_with(safe_char_prefix(&entry.session_id, 8)) }), Err(_) => false, } @@ -308,7 +317,7 @@ async fn run_list(args: LockListArgs) -> Result<()> { println!("{}", "-".repeat(60)); for entry in &lock_file.locked_sessions { - let short_id = &entry.session_id[..8.min(entry.session_id.len())]; + let short_id = safe_char_prefix(&entry.session_id, 8); println!(" {} - locked at {}", short_id, entry.locked_at); if let Some(ref reason) = entry.reason { println!(" Reason: {}", reason); @@ -332,7 +341,7 @@ async fn run_check(args: LockCheckArgs) -> Result<()> { e.session_id == args.session_id || args .session_id - .starts_with(&e.session_id[..8.min(e.session_id.len())]) + .starts_with(safe_char_prefix(&e.session_id, 8)) }); if is_locked { @@ -342,7 +351,7 @@ async fn run_check(args: LockCheckArgs) -> Result<()> { e.session_id == args.session_id || args .session_id - .starts_with(&e.session_id[..8.min(e.session_id.len())]) + .starts_with(safe_char_prefix(&e.session_id, 8)) }) && let Some(ref reason) = entry.reason { println!("Reason: {}", reason); @@ -508,4 +517,39 @@ mod tests { let path_str = path.to_string_lossy(); assert!(path_str.contains(".cortex")); } + + #[test] + fn test_safe_char_prefix_ascii() { + // ASCII strings should work correctly + assert_eq!(safe_char_prefix("abcdefghij", 8), "abcdefgh"); + assert_eq!(safe_char_prefix("abc", 8), "abc"); + assert_eq!(safe_char_prefix("", 8), ""); + assert_eq!(safe_char_prefix("12345678", 8), "12345678"); + } + + #[test] + fn test_safe_char_prefix_utf8_multibyte() { + // Multi-byte UTF-8 characters should not panic + // Each emoji is 4 bytes, so 8 chars = 32 bytes + let emoji_id = 
"🔥🎉🚀💡🌟✨🎯🔮extra"; + assert_eq!(safe_char_prefix(emoji_id, 8), "🔥🎉🚀💡🌟✨🎯🔮"); + + // Mixed ASCII and multi-byte + let mixed = "ab🔥cd🎉ef"; + assert_eq!(safe_char_prefix(mixed, 4), "ab🔥c"); + assert_eq!(safe_char_prefix(mixed, 8), "ab🔥cd🎉ef"); + + // Chinese characters (3 bytes each) + let chinese = "中文测试会话标识符"; + assert_eq!(safe_char_prefix(chinese, 4), "中文测试"); + } + + #[test] + fn test_safe_char_prefix_boundary() { + // Edge cases + assert_eq!(safe_char_prefix("a", 0), ""); + assert_eq!(safe_char_prefix("a", 1), "a"); + assert_eq!(safe_char_prefix("🔥", 1), "🔥"); + assert_eq!(safe_char_prefix("🔥", 0), ""); + } } diff --git a/src/cortex-cli/src/utils/notification.rs b/src/cortex-cli/src/utils/notification.rs index 4656e22..8edd2c9 100644 --- a/src/cortex-cli/src/utils/notification.rs +++ b/src/cortex-cli/src/utils/notification.rs @@ -63,7 +63,14 @@ pub fn send_task_notification(session_id: &str, success: bool) -> Result<()> { "Cortex Task Failed" }; - let short_id = &session_id[..8.min(session_id.len())]; + // Use safe UTF-8 slicing - find the last valid char boundary at or before position 8 + let short_id = session_id + .char_indices() + .take_while(|(idx, _)| *idx < 8) + .map(|(idx, ch)| idx + ch.len_utf8()) + .last() + .and_then(|end| session_id.get(..end)) + .unwrap_or(session_id); let body = format!("Session: {}", short_id); let urgency = if success { diff --git a/src/cortex-cli/src/utils/paths.rs b/src/cortex-cli/src/utils/paths.rs index 8cdf03e..c9654ff 100644 --- a/src/cortex-cli/src/utils/paths.rs +++ b/src/cortex-cli/src/utils/paths.rs @@ -34,10 +34,16 @@ pub fn get_cortex_home() -> PathBuf { /// // Returns: /home/user/documents/file.txt /// ``` pub fn expand_tilde(path: &str) -> String { - if path.starts_with("~/") - && let Some(home) = dirs::home_dir() - { - return home.join(&path[2..]).to_string_lossy().to_string(); + if path == "~" { + // Handle bare "~" - return home directory + if let Some(home) = dirs::home_dir() { + return 
home.to_string_lossy().to_string(); + } + } else if let Some(suffix) = path.strip_prefix("~/") { + // Handle "~/" prefix - expand to home directory + rest of path + if let Some(home) = dirs::home_dir() { + return home.join(suffix).to_string_lossy().to_string(); + } } path.to_string() } @@ -58,8 +64,12 @@ pub fn expand_tilde(path: &str) -> String { pub fn validate_path_safety(path: &Path, base_dir: Option<&Path>) -> Result<(), String> { let path_str = path.to_string_lossy(); - // Check for path traversal attempts - if path_str.contains("..") { + // Check for path traversal attempts by examining path components + // This correctly handles filenames containing ".." like "file..txt" + if path + .components() + .any(|c| matches!(c, std::path::Component::ParentDir)) + { return Err("Path contains traversal sequence '..'".to_string()); } @@ -257,8 +267,15 @@ mod tests { #[test] fn test_expand_tilde_with_tilde_only() { - // Test tilde alone - should remain unchanged (not "~/") - assert_eq!(expand_tilde("~"), "~"); + // Test bare "~" - should expand to home directory + let result = expand_tilde("~"); + if let Some(home) = dirs::home_dir() { + let expected = home.to_string_lossy().to_string(); + assert_eq!(result, expected); + } else { + // If no home dir, original is returned + assert_eq!(result, "~"); + } } #[test] @@ -320,20 +337,30 @@ mod tests { #[test] fn test_validate_path_safety_detects_various_traversal_patterns() { - // Different traversal patterns - let patterns = ["foo/../bar", "...", "foo/bar/../baz", "./foo/../../../etc"]; + // Patterns that ARE path traversal (contain ".." as a component) + let traversal_patterns = ["foo/../bar", "foo/bar/../baz", "./foo/../../../etc", ".."]; - for pattern in patterns { + for pattern in traversal_patterns { let path = Path::new(pattern); let result = validate_path_safety(path, None); - // Only patterns containing ".." 
should fail - if pattern.contains("..") { - assert!( - result.is_err(), - "Expected traversal detection for: {}", - pattern - ); - } + assert!( + result.is_err(), + "Expected traversal detection for: {}", + pattern + ); + } + + // Patterns that are NOT path traversal (contain ".." in filenames only) + let safe_patterns = ["file..txt", "..hidden", "test...file", "foo/bar..baz/file"]; + + for pattern in safe_patterns { + let path = Path::new(pattern); + let result = validate_path_safety(path, None); + assert!( + result.is_ok(), + "False positive: '{}' should not be detected as traversal", + pattern + ); } } diff --git a/src/cortex-common/src/file_locking.rs b/src/cortex-common/src/file_locking.rs index f9b78db..d2b4f73 100644 --- a/src/cortex-common/src/file_locking.rs +++ b/src/cortex-common/src/file_locking.rs @@ -557,6 +557,9 @@ pub async fn atomic_write_async( .map_err(|e| FileLockError::AtomicWriteFailed(format!("spawn_blocking failed: {}", e)))? } +/// Maximum number of lock entries before triggering cleanup. +const MAX_LOCK_ENTRIES: usize = 10_000; + /// A file lock manager for coordinating access across multiple operations. /// /// This is useful when you need to perform multiple operations on a file @@ -577,15 +580,47 @@ impl FileLockManager { /// /// This is in addition to the filesystem-level advisory lock and helps /// coordinate access within the same process. + /// + /// Automatically cleans up stale lock entries when the map grows too large. pub fn get_lock(&self, path: impl AsRef) -> Arc> { let path = path.as_ref().to_path_buf(); let mut locks = self.locks.lock().unwrap(); + + // Clean up stale entries if the map is getting large + if locks.len() >= MAX_LOCK_ENTRIES { + Self::cleanup_stale_entries(&mut locks); + } + locks .entry(path) .or_insert_with(|| Arc::new(std::sync::Mutex::new(()))) .clone() } + /// Remove lock entries that are no longer in use. 
+ /// + /// An entry is considered stale when only the HashMap holds a reference + /// to it (strong_count == 1), meaning no caller is currently using the lock. + fn cleanup_stale_entries( + locks: &mut std::collections::HashMap>>, + ) { + locks.retain(|_, arc| Arc::strong_count(arc) > 1); + } + + /// Manually trigger cleanup of stale lock entries. + /// + /// This removes entries where no external reference exists (only the + /// manager holds the Arc). Useful for periodic maintenance. + pub fn cleanup(&self) { + let mut locks = self.locks.lock().unwrap(); + Self::cleanup_stale_entries(&mut locks); + } + + /// Returns the current number of lock entries in the manager. + pub fn lock_count(&self) -> usize { + self.locks.lock().unwrap().len() + } + /// Execute an operation with both process-local and file-system locks. pub fn with_lock(&self, path: impl AsRef, mode: LockMode, f: F) -> FileLockResult where diff --git a/src/cortex-common/src/http_client.rs b/src/cortex-common/src/http_client.rs index b181ac8..3b290ff 100644 --- a/src/cortex-common/src/http_client.rs +++ b/src/cortex-common/src/http_client.rs @@ -9,6 +9,54 @@ //! //! DNS caching is configured with reasonable TTL to allow failover and load //! balancer updates (#2177). +//! +//! # Timeout Configuration Guide +//! +//! This section documents the timeout hierarchy across the Cortex codebase. Use this +//! as a reference when configuring timeouts for new features or debugging timeout issues. +//! +//! ## Timeout Hierarchy +//! +//! | Use Case | Timeout | Constant/Location | Rationale | +//! |-----------------------------|---------|--------------------------------------------|-----------------------------------------| +//! | Health checks | 5s | `HEALTH_CHECK_TIMEOUT` (this module) | Quick validation of service status | +//! | Standard HTTP requests | 30s | `DEFAULT_TIMEOUT` (this module) | Normal API calls with reasonable margin | +//! 
| Per-chunk read (streaming) | 30s | `read_timeout` (cortex-app-server/config) | Individual chunk timeout during stream | +//! | Pool idle timeout | 60s | `POOL_IDLE_TIMEOUT` (this module) | DNS re-resolution for failover | +//! | LLM Request (non-streaming) | 120s | `DEFAULT_REQUEST_TIMEOUT_SECS` (cortex-exec/runner) | Model inference takes time | +//! | LLM Streaming total | 300s | `STREAMING_TIMEOUT` (this module) | Long-running streaming responses | +//! | Server request lifecycle | 300s | `request_timeout` (cortex-app-server/config) | Full HTTP request/response cycle | +//! | Entire exec session | 600s | `DEFAULT_TIMEOUT_SECS` (cortex-exec/runner) | Multi-turn conversation limit | +//! | Graceful shutdown | 30s | `shutdown_timeout` (cortex-app-server/config) | Time for cleanup on shutdown | +//! +//! ## Module-Specific Timeouts +//! +//! ### cortex-common (this module) +//! - `DEFAULT_TIMEOUT` (30s): Use for standard API calls. +//! - `STREAMING_TIMEOUT` (300s): Use for LLM streaming endpoints. +//! - `HEALTH_CHECK_TIMEOUT` (5s): Use for health/readiness checks. +//! - `POOL_IDLE_TIMEOUT` (60s): Connection pool cleanup for DNS freshness. +//! +//! ### cortex-exec (runner.rs) +//! - `DEFAULT_TIMEOUT_SECS` (600s): Maximum duration for entire exec session. +//! - `DEFAULT_REQUEST_TIMEOUT_SECS` (120s): Single LLM request timeout. +//! +//! ### cortex-app-server (config.rs) +//! - `request_timeout` (300s): Full request lifecycle timeout. +//! - `read_timeout` (30s): Per-chunk timeout for streaming reads. +//! - `shutdown_timeout` (30s): Graceful shutdown duration. +//! +//! ### cortex-engine (api_client.rs) +//! - Re-exports constants from this module for consistency. +//! +//! ## Recommendations +//! +//! When adding new timeout configurations: +//! 1. Use constants from this module when possible for consistency. +//! 2. Document any new timeout constants with their rationale. +//! 3. 
Consider the timeout hierarchy - inner timeouts should be shorter than outer ones. +//! 4. For LLM operations, use longer timeouts (120s-300s) to accommodate model inference. +//! 5. For health checks and quick validations, use short timeouts (5s-10s). use reqwest::Client; use std::time::Duration; diff --git a/src/cortex-compact/src/compactor.rs b/src/cortex-compact/src/compactor.rs index f5cbeb1..00fe23d 100644 --- a/src/cortex-compact/src/compactor.rs +++ b/src/cortex-compact/src/compactor.rs @@ -106,7 +106,10 @@ impl Compactor { }]; new_items.extend(items.into_iter().skip(preserved_start)); - let tokens_after = current_tokens - tokens_in_compacted + summary_tokens; + // Use saturating arithmetic to prevent underflow if tokens_in_compacted > current_tokens + let tokens_after = current_tokens + .saturating_sub(tokens_in_compacted) + .saturating_add(summary_tokens); let result = CompactionResult::success(summary, current_tokens, tokens_after, items_removed); diff --git a/src/cortex-engine/src/async_utils.rs b/src/cortex-engine/src/async_utils.rs index f7b0490..ed63a6f 100644 --- a/src/cortex-engine/src/async_utils.rs +++ b/src/cortex-engine/src/async_utils.rs @@ -147,13 +147,17 @@ impl ConcurrencyLimiter { } /// Execute with limit. - pub async fn execute(&self, f: F) -> T + /// + /// Returns an error if the semaphore is closed. + pub async fn execute(&self, f: F) -> Result where F: FnOnce() -> Fut, Fut: Future, { - let _permit = self.semaphore.acquire().await.unwrap(); - f().await + let _permit = self.semaphore.acquire().await.map_err(|_| { + CortexError::Internal("concurrency limiter semaphore closed unexpectedly".into()) + })?; + Ok(f().await) } /// Get available permits. @@ -178,26 +182,36 @@ impl AsyncOnce { } /// Get or initialize. - pub async fn get_or_init(&self, init: F) -> T + /// + /// Returns an error if the internal state is inconsistent (value missing after init flag set). 
+ pub async fn get_or_init(&self, init: F) -> Result where F: FnOnce() -> Fut, Fut: Future, { // Fast path if *self.initialized.read().await { - return self.value.read().await.clone().unwrap(); + return self.value.read().await.clone().ok_or_else(|| { + CortexError::Internal( + "AsyncOnce: value missing despite initialized flag being set".into(), + ) + }); } // Slow path let mut initialized = self.initialized.write().await; if *initialized { - return self.value.read().await.clone().unwrap(); + return self.value.read().await.clone().ok_or_else(|| { + CortexError::Internal( + "AsyncOnce: value missing despite initialized flag being set".into(), + ) + }); } let value = init().await; *self.value.write().await = Some(value.clone()); *initialized = true; - value + Ok(value) } /// Check if initialized. @@ -399,7 +413,12 @@ impl AsyncCache { } /// Run futures concurrently with limit. -pub async fn concurrent(items: impl IntoIterator, limit: usize) -> Vec +/// +/// Returns an error if the semaphore is closed unexpectedly. +pub async fn concurrent( + items: impl IntoIterator, + limit: usize, +) -> Result> where F: FnOnce() -> Fut, Fut: Future, @@ -410,12 +429,17 @@ where for item in items { let sem = semaphore.clone(); handles.push(async move { - let _permit = sem.acquire().await.unwrap(); - item().await + let _permit = sem.acquire().await.map_err(|_| { + CortexError::Internal("concurrent execution semaphore closed unexpectedly".into()) + })?; + Ok(item().await) }); } - futures::future::join_all(handles).await + futures::future::join_all(handles) + .await + .into_iter() + .collect() } /// Select the first future to complete. 
@@ -503,7 +527,10 @@ mod tests { })); } - futures::future::join_all(handles).await; + let results: Vec<_> = futures::future::join_all(handles).await; + for result in results { + assert!(result.is_ok()); + } assert_eq!(*counter.lock().await, 5); } @@ -511,8 +538,8 @@ mod tests { async fn test_async_once() { let once: AsyncOnce = AsyncOnce::new(); - let v1 = once.get_or_init(|| async { 42 }).await; - let v2 = once.get_or_init(|| async { 100 }).await; + let v1 = once.get_or_init(|| async { 42 }).await.unwrap(); + let v2 = once.get_or_init(|| async { 100 }).await.unwrap(); assert_eq!(v1, 42); assert_eq!(v2, 42); @@ -560,7 +587,7 @@ mod tests { Box::new(|| Box::pin(async { 2 })), Box::new(|| Box::pin(async { 3 })), ]; - let results = concurrent(items, 2).await; + let results = concurrent(items, 2).await.unwrap(); assert_eq!(results.len(), 3); } diff --git a/src/cortex-engine/src/config/config_discovery.rs b/src/cortex-engine/src/config/config_discovery.rs index 86e3c64..7e5b97c 100644 --- a/src/cortex-engine/src/config/config_discovery.rs +++ b/src/cortex-engine/src/config/config_discovery.rs @@ -4,20 +4,36 @@ //! with caching support for performance in monorepo environments. use std::collections::HashMap; +use std::hash::Hash; use std::path::{Path, PathBuf}; use std::sync::{LazyLock, RwLock}; use tracing::{debug, trace}; +/// Maximum number of entries in each cache to prevent unbounded memory growth. +const MAX_CACHE_SIZE: usize = 1000; + /// Cache for discovered config paths. /// Key is the start directory, value is the found config path (or None). static CONFIG_CACHE: LazyLock>>> = - LazyLock::new(|| RwLock::new(HashMap::new())); + LazyLock::new(|| RwLock::new(HashMap::with_capacity(MAX_CACHE_SIZE))); /// Cache for project roots. /// Key is the start directory, value is the project root path. 
static PROJECT_ROOT_CACHE: LazyLock>>> = - LazyLock::new(|| RwLock::new(HashMap::new())); + LazyLock::new(|| RwLock::new(HashMap::with_capacity(MAX_CACHE_SIZE))); + +/// Insert a key-value pair into the cache with eviction when full. +/// When the cache reaches MAX_CACHE_SIZE, removes an arbitrary entry before inserting. +fn insert_with_eviction(cache: &mut HashMap, key: K, value: V) { + if cache.len() >= MAX_CACHE_SIZE { + // Remove first entry (simple eviction strategy) + if let Some(k) = cache.keys().next().cloned() { + cache.remove(&k); + } + } + cache.insert(key, value); +} /// Markers that indicate a project root directory. const PROJECT_ROOT_MARKERS: &[&str] = &[ @@ -57,9 +73,9 @@ pub fn find_up(start_dir: &Path, filename: &str) -> Option { let result = find_up_uncached(start_dir, filename); - // Store in cache + // Store in cache with eviction when full if let Ok(mut cache) = CONFIG_CACHE.write() { - cache.insert(cache_key, result.clone()); + insert_with_eviction(&mut cache, cache_key, result.clone()); } result @@ -169,9 +185,9 @@ pub fn find_project_root(start_dir: &Path) -> Option { let result = find_project_root_uncached(start_dir); - // Store in cache + // Store in cache with eviction when full if let Ok(mut cache) = PROJECT_ROOT_CACHE.write() { - cache.insert(start_dir.to_path_buf(), result.clone()); + insert_with_eviction(&mut cache, start_dir.to_path_buf(), result.clone()); } result diff --git a/src/cortex-engine/src/git_info.rs b/src/cortex-engine/src/git_info.rs index 84d29c2..475b5ac 100644 --- a/src/cortex-engine/src/git_info.rs +++ b/src/cortex-engine/src/git_info.rs @@ -67,9 +67,17 @@ impl GitInfo { let status = git_command(&root, &["status", "--porcelain"]).unwrap_or_default(); let is_dirty = !status.is_empty(); - // Count changes - let changes = status.lines().filter(|l| !l.starts_with("??")).count() as u32; - let untracked = status.lines().filter(|l| l.starts_with("??")).count() as u32; + // Count changes (saturating to u32::MAX to prevent 
truncation) + let changes = status + .lines() + .filter(|l| !l.starts_with("??")) + .count() + .min(u32::MAX as usize) as u32; + let untracked = status + .lines() + .filter(|l| l.starts_with("??")) + .count() + .min(u32::MAX as usize) as u32; // Get tags let tags = git_command(&root, &["tag", "--points-at", "HEAD"]) @@ -385,7 +393,7 @@ impl GitStash { let message = parts.get(2).unwrap_or(&"").to_string(); stashes.push(GitStash { - index: i as u32, + index: i.min(u32::MAX as usize) as u32, message, branch, }); diff --git a/src/cortex-engine/src/ratelimit.rs b/src/cortex-engine/src/ratelimit.rs index 5423512..e15ea4a 100644 --- a/src/cortex-engine/src/ratelimit.rs +++ b/src/cortex-engine/src/ratelimit.rs @@ -341,9 +341,13 @@ impl ConcurrencyLimiter { } /// Acquire a permit. - pub async fn acquire(&self) -> ConcurrencyPermit { - let permit = self.semaphore.clone().acquire_owned().await.unwrap(); - ConcurrencyPermit { _permit: permit } + /// + /// Returns an error if the semaphore is closed. + pub async fn acquire(&self) -> Result { + let permit = self.semaphore.clone().acquire_owned().await.map_err(|_| { + CortexError::Internal("concurrency limiter semaphore closed unexpectedly".into()) + })?; + Ok(ConcurrencyPermit { _permit: permit }) } /// Try to acquire a permit. 
@@ -595,8 +599,8 @@ mod tests { async fn test_concurrency_limiter() { let limiter = ConcurrencyLimiter::new(2); - let _p1 = limiter.acquire().await; - let _p2 = limiter.acquire().await; + let _p1 = limiter.acquire().await.unwrap(); + let _p2 = limiter.acquire().await.unwrap(); // Third should fail immediately assert!(limiter.try_acquire().is_none()); diff --git a/src/cortex-engine/src/streaming.rs b/src/cortex-engine/src/streaming.rs index 35bfcef..ef7135a 100644 --- a/src/cortex-engine/src/streaming.rs +++ b/src/cortex-engine/src/streaming.rs @@ -15,6 +15,10 @@ use futures::Stream; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; +/// Maximum number of events to buffer before dropping old ones. +/// Prevents unbounded memory growth if drain_events() is not called regularly. +const MAX_BUFFER_SIZE: usize = 10_000; + /// Token usage for streaming. #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct StreamTokenUsage { @@ -26,12 +30,25 @@ pub struct StreamTokenUsage { pub total_tokens: u32, } +/// Safely convert an i64 token count to u32 with saturation. +/// Negative values clamp to 0, values > u32::MAX clamp to u32::MAX. 
+#[inline] +fn saturating_i64_to_u32(value: i64) -> u32 { + if value <= 0 { + 0 + } else if value > u32::MAX as i64 { + u32::MAX + } else { + value as u32 + } +} + impl From for StreamTokenUsage { fn from(usage: crate::client::TokenUsage) -> Self { Self { - prompt_tokens: usage.input_tokens as u32, - completion_tokens: usage.output_tokens as u32, - total_tokens: usage.total_tokens as u32, + prompt_tokens: saturating_i64_to_u32(usage.input_tokens), + completion_tokens: saturating_i64_to_u32(usage.output_tokens), + total_tokens: saturating_i64_to_u32(usage.total_tokens), } } } @@ -213,7 +230,7 @@ impl StreamProcessor { Self { state: StreamState::Idle, content: StreamContent::new(), - buffer: VecDeque::new(), + buffer: VecDeque::with_capacity(1024), // Pre-allocate reasonable capacity start_time: None, first_token_time: None, last_event_time: None, @@ -284,6 +301,10 @@ impl StreamProcessor { } } + // Enforce buffer size limit to prevent unbounded memory growth + if self.buffer.len() >= MAX_BUFFER_SIZE { + self.buffer.pop_front(); + } self.buffer.push_back(event); } diff --git a/src/cortex-engine/src/tokenizer.rs b/src/cortex-engine/src/tokenizer.rs index 793f5e2..b8aeadc 100644 --- a/src/cortex-engine/src/tokenizer.rs +++ b/src/cortex-engine/src/tokenizer.rs @@ -3,9 +3,25 @@ //! Provides token counting and text tokenization for various models. use std::collections::HashMap; +use std::hash::Hash; use serde::{Deserialize, Serialize}; +/// Maximum number of entries in the token cache to prevent unbounded memory growth. +const MAX_CACHE_SIZE: usize = 1000; + +/// Insert a key-value pair into the cache with eviction when full. +/// When the cache reaches MAX_CACHE_SIZE, removes an arbitrary entry before inserting. 
+fn insert_with_eviction(cache: &mut HashMap, key: K, value: V) { + if cache.len() >= MAX_CACHE_SIZE { + // Remove first entry (simple eviction strategy) + if let Some(k) = cache.keys().next().cloned() { + cache.remove(&k); + } + } + cache.insert(key, value); +} + /// Tokenizer type. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -58,7 +74,7 @@ impl TokenizerType { pub struct TokenCounter { /// Tokenizer type. tokenizer: TokenizerType, - /// Cache. + /// Cache with bounded size to prevent unbounded memory growth. cache: HashMap, } @@ -67,7 +83,7 @@ impl TokenCounter { pub fn new(tokenizer: TokenizerType) -> Self { Self { tokenizer, - cache: HashMap::new(), + cache: HashMap::with_capacity(MAX_CACHE_SIZE), } } @@ -85,7 +101,7 @@ impl TokenCounter { } let count = self.count_uncached(text); - self.cache.insert(hash, count); + insert_with_eviction(&mut self.cache, hash, count); count } diff --git a/src/cortex-engine/src/tools/handlers/file_ops.rs b/src/cortex-engine/src/tools/handlers/file_ops.rs index ef9f044..430cb24 100644 --- a/src/cortex-engine/src/tools/handlers/file_ops.rs +++ b/src/cortex-engine/src/tools/handlers/file_ops.rs @@ -204,7 +204,7 @@ impl Default for WriteFileHandler { #[async_trait] impl ToolHandler for WriteFileHandler { fn name(&self) -> &str { - "Write" + "Create" } async fn execute(&self, arguments: Value, context: &ToolContext) -> Result { @@ -445,7 +445,7 @@ impl Default for SearchFilesHandler { #[async_trait] impl ToolHandler for SearchFilesHandler { fn name(&self) -> &str { - "search_files" + "SearchFiles" } async fn execute(&self, arguments: Value, context: &ToolContext) -> Result { diff --git a/src/cortex-engine/src/tools/handlers/glob.rs b/src/cortex-engine/src/tools/handlers/glob.rs index f743126..f9c486d 100644 --- a/src/cortex-engine/src/tools/handlers/glob.rs +++ b/src/cortex-engine/src/tools/handlers/glob.rs @@ -17,6 +17,7 @@ pub struct GlobHandler; #[derive(Debug, 
Deserialize)] struct GlobArgs { patterns: Vec, + #[serde(alias = "folder")] directory: Option, #[serde(default)] exclude_patterns: Vec, diff --git a/src/cortex-engine/src/tools/handlers/grep.rs b/src/cortex-engine/src/tools/handlers/grep.rs index 26d2561..ecef2d9 100644 --- a/src/cortex-engine/src/tools/handlers/grep.rs +++ b/src/cortex-engine/src/tools/handlers/grep.rs @@ -29,6 +29,7 @@ struct GrepArgs { glob_pattern: Option, #[serde(default = "default_output_mode")] output_mode: String, + #[serde(alias = "head_limit")] max_results: Option, #[serde(default)] multiline: bool, diff --git a/src/cortex-engine/src/tools/mod.rs b/src/cortex-engine/src/tools/mod.rs index 9f5b6e1..f693894 100644 --- a/src/cortex-engine/src/tools/mod.rs +++ b/src/cortex-engine/src/tools/mod.rs @@ -30,6 +30,7 @@ pub mod artifacts; pub mod context; pub mod handlers; pub mod registry; +pub mod response_store; pub mod router; pub mod spec; pub mod unified_executor; @@ -45,6 +46,11 @@ pub use artifacts::{ pub use context::ToolContext; pub use handlers::*; pub use registry::{PluginTool, ToolRegistry}; +pub use response_store::{ + CLEANUP_INTERVAL, DEFAULT_TTL, MAX_STORE_SIZE, StoreInfo, StoreStats, StoredResponse, + ToolResponseStore, ToolResponseStoreConfig, create_shared_store, + create_shared_store_with_config, +}; pub use router::ToolRouter; pub use spec::{ToolCall, ToolDefinition, ToolHandler, ToolResult}; pub use unified_executor::{ExecutorConfig, UnifiedToolExecutor}; diff --git a/src/cortex-engine/src/tools/response_store.rs b/src/cortex-engine/src/tools/response_store.rs new file mode 100644 index 0000000..9220c86 --- /dev/null +++ b/src/cortex-engine/src/tools/response_store.rs @@ -0,0 +1,537 @@ +//! Tool response storage with bounded capacity and automatic cleanup. +//! +//! This module provides a bounded storage for tool execution results that: +//! - Limits maximum number of stored responses to prevent unbounded memory growth +//! 
- Removes entries when they are consumed (read and taken)
+//! - Periodically cleans up stale entries based on TTL
+//!
+//! Fixes #5292 (unbounded growth) and #5293 (missing removal on read).
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use tokio::sync::RwLock;
+use tracing::debug;
+
+use crate::tools::spec::ToolResult;
+
+/// Maximum number of responses to store before eviction.
+/// This prevents unbounded memory growth from accumulated tool responses.
+pub const MAX_STORE_SIZE: usize = 500;
+
+/// Default time-to-live for stored responses (5 minutes).
+pub const DEFAULT_TTL: Duration = Duration::from_secs(300);
+
+/// Interval for periodic cleanup of stale entries (1 minute).
+pub const CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
+
+/// A stored tool response with metadata.
+#[derive(Debug, Clone)]
+pub struct StoredResponse {
+    /// The tool execution result.
+    pub result: ToolResult,
+    /// Tool name that produced this result.
+    pub tool_name: String,
+    /// When the response was stored.
+    pub stored_at: Instant,
+    /// Whether this response has been read (but not yet consumed).
+    pub read: bool,
+}
+
+impl StoredResponse {
+    /// Create a new stored response.
+    pub fn new(tool_name: impl Into<String>, result: ToolResult) -> Self {
+        Self {
+            result,
+            tool_name: tool_name.into(),
+            stored_at: Instant::now(),
+            read: false,
+        }
+    }
+
+    /// Check if the response has expired.
+    pub fn is_expired(&self, ttl: Duration) -> bool {
+        self.stored_at.elapsed() > ttl
+    }
+
+    /// Get the age of this response.
+    pub fn age(&self) -> Duration {
+        self.stored_at.elapsed()
+    }
+}
+
+/// Configuration for the tool response store.
+#[derive(Debug, Clone)]
+pub struct ToolResponseStoreConfig {
+    /// Maximum number of responses to store.
+    pub max_size: usize,
+    /// Time-to-live for stored responses.
+    pub ttl: Duration,
+    /// Whether to remove entries on read (peek vs consume).
+ pub remove_on_read: bool, +} + +impl Default for ToolResponseStoreConfig { + fn default() -> Self { + Self { + max_size: MAX_STORE_SIZE, + ttl: DEFAULT_TTL, + remove_on_read: true, + } + } +} + +impl ToolResponseStoreConfig { + /// Create a config with custom max size. + pub fn with_max_size(mut self, max_size: usize) -> Self { + self.max_size = max_size; + self + } + + /// Create a config with custom TTL. + pub fn with_ttl(mut self, ttl: Duration) -> Self { + self.ttl = ttl; + self + } + + /// Set whether to remove entries on read. + pub fn with_remove_on_read(mut self, remove: bool) -> Self { + self.remove_on_read = remove; + self + } +} + +/// Bounded storage for tool execution responses. +/// +/// This store prevents unbounded memory growth by: +/// 1. Enforcing a maximum number of stored responses +/// 2. Removing entries when they are consumed +/// 3. Periodically cleaning up stale entries +/// +/// # Thread Safety +/// +/// The store uses `RwLock` for interior mutability and is safe to share +/// across threads via `Arc`. +#[derive(Debug)] +pub struct ToolResponseStore { + /// Stored responses keyed by tool call ID. + responses: RwLock>, + /// Configuration. + config: ToolResponseStoreConfig, + /// Last cleanup time. + last_cleanup: RwLock, + /// Statistics. + stats: RwLock, +} + +impl ToolResponseStore { + /// Create a new tool response store with default configuration. + pub fn new() -> Self { + Self::with_config(ToolResponseStoreConfig::default()) + } + + /// Create a tool response store with custom configuration. + pub fn with_config(config: ToolResponseStoreConfig) -> Self { + Self { + responses: RwLock::new(HashMap::new()), + config, + last_cleanup: RwLock::new(Instant::now()), + stats: RwLock::new(StoreStats::default()), + } + } + + /// Store a tool response. + /// + /// If the store is at capacity, the oldest entry will be evicted. + /// Returns `true` if an entry was evicted to make room. 
+ pub async fn store( + &self, + call_id: impl Into, + tool_name: impl Into, + result: ToolResult, + ) -> bool { + let call_id = call_id.into(); + let tool_name = tool_name.into(); + let mut evicted = false; + + // Perform periodic cleanup if needed + self.maybe_cleanup().await; + + let mut responses = self.responses.write().await; + + // Evict oldest entry if at capacity + if responses.len() >= self.config.max_size { + if let Some(oldest_key) = self.find_oldest_key(&responses) { + responses.remove(&oldest_key); + evicted = true; + debug!( + evicted_key = %oldest_key, + "Evicted oldest response to make room" + ); + } + } + + let response = StoredResponse::new(tool_name, result); + responses.insert(call_id.clone(), response); + + // Update stats + let mut stats = self.stats.write().await; + stats.total_stored += 1; + if evicted { + stats.evictions += 1; + } + + evicted + } + + /// Get a response without removing it (peek). + /// + /// Marks the response as read but does not consume it. + pub async fn get(&self, call_id: &str) -> Option { + let mut responses = self.responses.write().await; + + if let Some(response) = responses.get_mut(call_id) { + response.read = true; + let mut stats = self.stats.write().await; + stats.reads += 1; + Some(response.result.clone()) + } else { + None + } + } + + /// Take (consume) a response, removing it from the store. + /// + /// This is the primary method for retrieving responses as it ensures + /// entries are cleaned up after being consumed (#5293). + pub async fn take(&self, call_id: &str) -> Option { + let mut responses = self.responses.write().await; + + if let Some(response) = responses.remove(call_id) { + let mut stats = self.stats.write().await; + stats.takes += 1; + Some(response.result) + } else { + None + } + } + + /// Check if a response exists for the given call ID. 
+ pub async fn contains(&self, call_id: &str) -> bool { + self.responses.read().await.contains_key(call_id) + } + + /// Get the current number of stored responses. + pub async fn len(&self) -> usize { + self.responses.read().await.len() + } + + /// Check if the store is empty. + pub async fn is_empty(&self) -> bool { + self.responses.read().await.is_empty() + } + + /// Remove all expired entries. + /// + /// Returns the number of entries removed. + pub async fn cleanup_expired(&self) -> usize { + let mut responses = self.responses.write().await; + let ttl = self.config.ttl; + let before = responses.len(); + + responses.retain(|_, v| !v.is_expired(ttl)); + + let removed = before - responses.len(); + if removed > 0 { + debug!(removed, "Cleaned up expired responses"); + let mut stats = self.stats.write().await; + stats.expired_cleanups += removed as u64; + } + + removed + } + + /// Remove all read entries that haven't been consumed. + /// + /// This is useful for cleaning up entries that were peeked but never taken. + pub async fn cleanup_read(&self) -> usize { + let mut responses = self.responses.write().await; + let before = responses.len(); + + responses.retain(|_, v| !v.read); + + let removed = before - responses.len(); + if removed > 0 { + debug!(removed, "Cleaned up read-but-not-consumed responses"); + } + + removed + } + + /// Clear all stored responses. + pub async fn clear(&self) { + self.responses.write().await.clear(); + } + + /// Get store statistics. + pub async fn stats(&self) -> StoreStats { + self.stats.read().await.clone() + } + + /// Get detailed store info including current size and config. 
+ pub async fn info(&self) -> StoreInfo { + let responses = self.responses.read().await; + let stats = self.stats.read().await; + + StoreInfo { + current_size: responses.len(), + max_size: self.config.max_size, + ttl_secs: self.config.ttl.as_secs(), + oldest_age_secs: responses + .values() + .map(|r| r.age().as_secs()) + .max() + .unwrap_or(0), + stats: stats.clone(), + } + } + + // Internal helpers + + /// Find the key of the oldest entry. + fn find_oldest_key(&self, responses: &HashMap) -> Option { + responses + .iter() + .min_by_key(|(_, v)| v.stored_at) + .map(|(k, _)| k.clone()) + } + + /// Perform cleanup if enough time has passed since last cleanup. + async fn maybe_cleanup(&self) { + let should_cleanup = { + let last = self.last_cleanup.read().await; + last.elapsed() > CLEANUP_INTERVAL + }; + + if should_cleanup { + *self.last_cleanup.write().await = Instant::now(); + let removed = self.cleanup_expired().await; + if removed > 0 { + debug!(removed, "Periodic cleanup removed expired entries"); + } + } + } +} + +impl Default for ToolResponseStore { + fn default() -> Self { + Self::new() + } +} + +/// Statistics for the tool response store. +#[derive(Debug, Clone, Default)] +pub struct StoreStats { + /// Total responses stored. + pub total_stored: u64, + /// Number of get (peek) operations. + pub reads: u64, + /// Number of take (consume) operations. + pub takes: u64, + /// Number of evictions due to capacity limit. + pub evictions: u64, + /// Number of entries removed by TTL cleanup. + pub expired_cleanups: u64, +} + +/// Detailed store information. +#[derive(Debug, Clone)] +pub struct StoreInfo { + /// Current number of stored responses. + pub current_size: usize, + /// Maximum allowed size. + pub max_size: usize, + /// TTL in seconds. + pub ttl_secs: u64, + /// Age of oldest entry in seconds. + pub oldest_age_secs: u64, + /// Store statistics. + pub stats: StoreStats, +} + +/// Create a shared tool response store. 
+pub fn create_shared_store() -> Arc { + Arc::new(ToolResponseStore::new()) +} + +/// Create a shared tool response store with custom configuration. +pub fn create_shared_store_with_config(config: ToolResponseStoreConfig) -> Arc { + Arc::new(ToolResponseStore::with_config(config)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_store_and_take() { + let store = ToolResponseStore::new(); + + let result = ToolResult::success("test output"); + store.store("call-1", "Read", result.clone()).await; + + assert!(store.contains("call-1").await); + assert_eq!(store.len().await, 1); + + let taken = store.take("call-1").await; + assert!(taken.is_some()); + assert_eq!(taken.unwrap().output, "test output"); + + // After take, entry should be gone + assert!(!store.contains("call-1").await); + assert_eq!(store.len().await, 0); + } + + #[tokio::test] + async fn test_store_and_get() { + let store = ToolResponseStore::new(); + + let result = ToolResult::success("test output"); + store.store("call-1", "Read", result).await; + + // Get should return result but not remove it + let got = store.get("call-1").await; + assert!(got.is_some()); + assert!(store.contains("call-1").await); + + // Second get should still work + let got2 = store.get("call-1").await; + assert!(got2.is_some()); + } + + #[tokio::test] + async fn test_capacity_eviction() { + let config = ToolResponseStoreConfig::default().with_max_size(3); + let store = ToolResponseStore::with_config(config); + + // Fill to capacity + store + .store("call-1", "Read", ToolResult::success("1")) + .await; + store + .store("call-2", "Read", ToolResult::success("2")) + .await; + store + .store("call-3", "Read", ToolResult::success("3")) + .await; + + assert_eq!(store.len().await, 3); + + // Add one more, should evict oldest + let evicted = store + .store("call-4", "Read", ToolResult::success("4")) + .await; + assert!(evicted); + assert_eq!(store.len().await, 3); + + // call-1 should be evicted (oldest) + 
assert!(!store.contains("call-1").await); + assert!(store.contains("call-4").await); + } + + #[tokio::test] + async fn test_expired_cleanup() { + let config = ToolResponseStoreConfig::default().with_ttl(Duration::from_millis(50)); + let store = ToolResponseStore::with_config(config); + + store + .store("call-1", "Read", ToolResult::success("1")) + .await; + assert_eq!(store.len().await, 1); + + // Wait for expiration + tokio::time::sleep(Duration::from_millis(100)).await; + + let removed = store.cleanup_expired().await; + assert_eq!(removed, 1); + assert_eq!(store.len().await, 0); + } + + #[tokio::test] + async fn test_cleanup_read() { + let store = ToolResponseStore::new(); + + store + .store("call-1", "Read", ToolResult::success("1")) + .await; + store + .store("call-2", "Read", ToolResult::success("2")) + .await; + + // Read one entry + store.get("call-1").await; + + // Cleanup read entries + let removed = store.cleanup_read().await; + assert_eq!(removed, 1); + assert_eq!(store.len().await, 1); + assert!(!store.contains("call-1").await); + assert!(store.contains("call-2").await); + } + + #[tokio::test] + async fn test_stats() { + let store = ToolResponseStore::new(); + + store + .store("call-1", "Read", ToolResult::success("1")) + .await; + store.get("call-1").await; + store.take("call-1").await; + + let stats = store.stats().await; + assert_eq!(stats.total_stored, 1); + assert_eq!(stats.reads, 1); + assert_eq!(stats.takes, 1); + } + + #[tokio::test] + async fn test_nonexistent_key() { + let store = ToolResponseStore::new(); + + assert!(store.get("nonexistent").await.is_none()); + assert!(store.take("nonexistent").await.is_none()); + assert!(!store.contains("nonexistent").await); + } + + #[tokio::test] + async fn test_clear() { + let store = ToolResponseStore::new(); + + store + .store("call-1", "Read", ToolResult::success("1")) + .await; + store + .store("call-2", "Read", ToolResult::success("2")) + .await; + + assert_eq!(store.len().await, 2); + + 
store.clear().await; + assert_eq!(store.len().await, 0); + } + + #[tokio::test] + async fn test_info() { + let store = ToolResponseStore::new(); + + store + .store("call-1", "Read", ToolResult::success("1")) + .await; + + let info = store.info().await; + assert_eq!(info.current_size, 1); + assert_eq!(info.max_size, MAX_STORE_SIZE); + } +} diff --git a/src/cortex-engine/src/validation.rs b/src/cortex-engine/src/validation.rs index a8afbec..5ff4d9e 100644 --- a/src/cortex-engine/src/validation.rs +++ b/src/cortex-engine/src/validation.rs @@ -269,6 +269,33 @@ pub struct CommandValidator { pub allow_shell_operators: bool, } +/// Normalize a command string for consistent validation. +/// +/// This function handles bypass attempts such as: +/// - Extra whitespace: "rm -rf" → "rm -rf" +/// - Quoted parts: "'rm' -rf" → "rm -rf" +/// - Path variants: "/bin/rm -rf" → "rm -rf" +fn normalize_command(cmd: &str) -> String { + cmd.split_whitespace() + .enumerate() + .map(|(idx, part)| { + // Remove surrounding quotes (single and double) + let unquoted = part.trim_matches(|c| c == '\'' || c == '"'); + + // For the first part (command), extract basename to handle path variants + if idx == 0 { + Path::new(unquoted) + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or(unquoted) + } else { + unquoted + } + }) + .collect::>() + .join(" ") +} + impl CommandValidator { /// Create a new validator. 
pub fn new() -> Self { @@ -332,9 +359,12 @@ impl CommandValidator { )); } - // Check allowed list + // Normalize the command for consistent validation + let normalized = normalize_command(command); + + // Check allowed list using normalized command if let Some(ref allowed) = self.allowed { - let cmd = command.split_whitespace().next().unwrap_or(""); + let cmd = normalized.split_whitespace().next().unwrap_or(""); if !allowed.contains(cmd) { result.add_error(ValidationError::new( "command", @@ -343,9 +373,10 @@ impl CommandValidator { } } - // Check blocked commands + // Check blocked commands against normalized form for blocked in &self.blocked { - if command.contains(blocked) { + let normalized_blocked = normalize_command(blocked); + if normalized.contains(&normalized_blocked) { result.add_error(ValidationError::new( "command", "Command contains blocked pattern", @@ -354,9 +385,9 @@ impl CommandValidator { } } - // Check blocked patterns + // Check blocked patterns against both original and normalized for pattern in &self.blocked_patterns { - if command.contains(pattern) { + if command.contains(pattern) || normalized.contains(pattern) { result.add_error(ValidationError::new( "command", "Command contains dangerous pattern", @@ -700,6 +731,97 @@ mod tests { assert!(result.valid); } + #[test] + fn test_command_validation_whitespace_bypass() { + let validator = CommandValidator::new(); + + // Extra whitespace should not bypass validation + let result = validator.validate("rm -rf /"); + assert!( + !result.valid, + "Extra whitespace should not bypass blocked command" + ); + + let result = validator.validate("rm -rf /"); + assert!( + !result.valid, + "Multiple spaces should not bypass blocked command" + ); + } + + #[test] + fn test_command_validation_quote_bypass() { + let validator = CommandValidator::new(); + + // Quoted commands should not bypass validation + let result = validator.validate("'rm' -rf /"); + assert!( + !result.valid, + "Single quotes should not bypass 
blocked command" + ); + + let result = validator.validate("\"rm\" -rf /"); + assert!( + !result.valid, + "Double quotes should not bypass blocked command" + ); + + let result = validator.validate("'rm' '-rf' '/'"); + assert!( + !result.valid, + "Fully quoted command should not bypass blocked command" + ); + } + + #[test] + fn test_command_validation_path_bypass() { + let validator = CommandValidator::new(); + + // Path variants should not bypass validation + let result = validator.validate("/bin/rm -rf /"); + assert!( + !result.valid, + "Absolute path should not bypass blocked command" + ); + + let result = validator.validate("/usr/bin/rm -rf /"); + assert!(!result.valid, "Full path should not bypass blocked command"); + + let result = validator.validate("./rm -rf /"); + assert!( + !result.valid, + "Relative path should not bypass blocked command" + ); + } + + #[test] + fn test_command_validation_combined_bypass() { + let validator = CommandValidator::new(); + + // Combined bypass attempts + let result = validator.validate("'/bin/rm' -rf /"); + assert!( + !result.valid, + "Combined path and whitespace should not bypass" + ); + + let result = validator.validate("\"/usr/bin/rm\" '-rf' '/'"); + assert!( + !result.valid, + "Combined quotes, path, and whitespace should not bypass" + ); + } + + #[test] + fn test_normalize_command() { + // Test the normalize function directly + assert_eq!(normalize_command("rm -rf /"), "rm -rf /"); + assert_eq!(normalize_command("rm -rf /"), "rm -rf /"); + assert_eq!(normalize_command("'rm' -rf /"), "rm -rf /"); + assert_eq!(normalize_command("/bin/rm -rf /"), "rm -rf /"); + assert_eq!(normalize_command("'/usr/bin/rm' '-rf' '/'"), "rm -rf /"); + } + #[test] fn test_url_validation() { let validator = UrlValidator::new(); diff --git a/src/cortex-exec/src/runner.rs b/src/cortex-exec/src/runner.rs index e831324..fdd33f9 100644 --- a/src/cortex-exec/src/runner.rs +++ b/src/cortex-exec/src/runner.rs @@ -27,11 +27,24 @@ use 
cortex_protocol::ConversationId;

 use crate::output::{OutputFormat, OutputWriter};

 /// Default timeout for the entire execution (10 minutes).
+///
+/// This is the maximum duration for a multi-turn exec session.
+/// See `cortex_common::http_client` module documentation for the complete
+/// timeout hierarchy across Cortex services.
 const DEFAULT_TIMEOUT_SECS: u64 = 600;

 /// Default timeout for a single LLM request (2 minutes).
+///
+/// Allows sufficient time for model inference while preventing indefinite hangs.
+/// See `cortex_common::http_client` module documentation for the complete
+/// timeout hierarchy across Cortex services.
 const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 120;

+/// Per-chunk timeout during streaming responses.
+/// Prevents indefinite hangs when connections stall mid-stream.
+/// See `cortex_common::http_client` for the timeout hierarchy documentation.
+const STREAMING_CHUNK_TIMEOUT_SECS: u64 = 30;
+
 /// Maximum retries for transient errors.
 const MAX_RETRIES: usize = 3;

@@ -187,7 +200,10 @@ impl ExecRunner {
             self.client = Some(client);
         }

-        Ok(self.client.as_ref().unwrap().as_ref())
+        self.client
+            .as_ref()
+            .map(|c| c.as_ref())
+            .ok_or_else(|| CortexError::Internal("LLM client not initialized".to_string()))
     }

     /// Get filtered tool definitions based on options.
@@ -555,7 +571,28 @@ impl ExecRunner { let mut partial_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); - while let Some(event) = stream.next().await { + loop { + // Apply per-chunk timeout to prevent indefinite hangs when connections stall + let event = match tokio::time::timeout( + Duration::from_secs(STREAMING_CHUNK_TIMEOUT_SECS), + stream.next(), + ) + .await + { + Ok(Some(event)) => event, + Ok(None) => break, // Stream ended normally + Err(_) => { + tracing::warn!( + "Stream chunk timeout after {}s", + STREAMING_CHUNK_TIMEOUT_SECS + ); + return Err(CortexError::Provider(format!( + "Streaming timeout: no response chunk received within {}s", + STREAMING_CHUNK_TIMEOUT_SECS + ))); + } + }; + match event? { ResponseEvent::Delta(delta) => { if self.options.streaming { diff --git a/src/cortex-mcp-client/src/transport.rs b/src/cortex-mcp-client/src/transport.rs index 22152cf..0ee141d 100644 --- a/src/cortex-mcp-client/src/transport.rs +++ b/src/cortex-mcp-client/src/transport.rs @@ -20,8 +20,7 @@ use cortex_mcp_types::{ use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::process::{Child, Command}; use tokio::sync::{Mutex, RwLock}; -use tokio::time::sleep; -use tracing::{debug, error, info, warn}; +use tracing::{debug, info, warn}; // ============================================================================ // Transport Trait @@ -199,61 +198,6 @@ impl StdioTransport { Ok(()) } - /// Reconnect with exponential backoff. - /// - /// Properly cleans up existing connections before each attempt to prevent - /// file descriptor leaks (#2198). 
- #[allow(dead_code)] - async fn reconnect(&self) -> Result<()> { - if !self.reconnect_config.enabled { - return Err(anyhow!("Reconnection disabled")); - } - - let mut attempt = 0; - let mut delay = self.reconnect_config.initial_delay; - - while attempt < self.reconnect_config.max_attempts { - attempt += 1; - info!( - attempt, - max = self.reconnect_config.max_attempts, - "Attempting reconnection" - ); - - // Clean up any existing connection before attempting reconnect - // This prevents file descriptor leaks on repeated failures (#2198) - { - let mut process_guard = self.process.lock().await; - if let Some(mut child) = process_guard.take() { - // Kill the process and wait for it to clean up - let _ = child.kill().await; - // Wait a short time for resources to be released - drop(child); - } - self.connected.store(false, Ordering::SeqCst); - } - - // Clear any stale pending responses - self.pending_responses.write().await.clear(); - - match self.connect().await { - Ok(()) => { - info!("Reconnection successful"); - return Ok(()); - } - Err(e) => { - error!(error = %e, attempt, "Reconnection failed"); - if attempt < self.reconnect_config.max_attempts { - sleep(delay).await; - delay = (delay * 2).min(self.reconnect_config.max_delay); - } - } - } - } - - Err(anyhow!("Failed to reconnect after {} attempts", attempt)) - } - /// Send a request and wait for response. async fn send_request(&self, request: JsonRpcRequest) -> Result { // Ensure connected @@ -516,51 +460,6 @@ impl HttpTransport { fn next_request_id(&self) -> RequestId { RequestId::Number(self.request_id.fetch_add(1, Ordering::SeqCst) as i64) } - - /// Test connection. - #[allow(dead_code)] - async fn test_connection(&self) -> Result<()> { - let request = JsonRpcRequest::new(self.next_request_id(), methods::PING); - self.send_request(request).await?; - Ok(()) - } - - /// Reconnect with exponential backoff. 
- #[allow(dead_code)] - async fn reconnect(&self) -> Result<()> { - if !self.reconnect_config.enabled { - return Err(anyhow!("Reconnection disabled")); - } - - let mut attempt = 0; - let mut delay = self.reconnect_config.initial_delay; - - while attempt < self.reconnect_config.max_attempts { - attempt += 1; - info!( - attempt, - max = self.reconnect_config.max_attempts, - "Attempting HTTP reconnection" - ); - - match self.test_connection().await { - Ok(()) => { - info!("HTTP reconnection successful"); - self.connected.store(true, Ordering::SeqCst); - return Ok(()); - } - Err(e) => { - error!(error = %e, attempt, "HTTP reconnection failed"); - if attempt < self.reconnect_config.max_attempts { - sleep(delay).await; - delay = (delay * 2).min(self.reconnect_config.max_delay); - } - } - } - } - - Err(anyhow!("Failed to reconnect after {} attempts", attempt)) - } } #[async_trait] diff --git a/src/cortex-mcp-server/src/server.rs b/src/cortex-mcp-server/src/server.rs index 96fb8d8..266b9e4 100644 --- a/src/cortex-mcp-server/src/server.rs +++ b/src/cortex-mcp-server/src/server.rs @@ -222,14 +222,17 @@ impl McpServer { } async fn handle_initialize(&self, params: Option) -> Result { - // Check state - let current_state = *self.state.read().await; - if current_state != ServerState::Uninitialized { - return Err(JsonRpcError::invalid_request("Server already initialized")); + // Atomic check-and-transition: hold write lock during entire state check and modification + // to prevent TOCTOU race conditions where multiple concurrent initialize requests + // could both pass the uninitialized check before either sets the state + { + let mut state_guard = self.state.write().await; + if *state_guard != ServerState::Uninitialized { + return Err(JsonRpcError::invalid_request("Server already initialized")); + } + *state_guard = ServerState::Initializing; } - *self.state.write().await = ServerState::Initializing; - // Parse params let init_params: InitializeParams = params 
.map(serde_json::from_value) diff --git a/src/cortex-plugins/src/registry.rs b/src/cortex-plugins/src/registry.rs index 79f961e..f9bb235 100644 --- a/src/cortex-plugins/src/registry.rs +++ b/src/cortex-plugins/src/registry.rs @@ -674,18 +674,21 @@ impl PluginRegistry { let info = plugin.info().clone(); let id = info.id.clone(); - { - let plugins = self.plugins.read().await; - if plugins.contains_key(&id) { - return Err(PluginError::AlreadyExists(id)); - } - } - + // Use entry API to atomically check-and-insert within a single write lock + // to prevent TOCTOU race conditions where multiple concurrent registrations + // could both pass the contains_key check before either inserts let handle = PluginHandle::new(plugin); - { let mut plugins = self.plugins.write().await; - plugins.insert(id.clone(), handle); + use std::collections::hash_map::Entry; + match plugins.entry(id.clone()) { + Entry::Occupied(_) => { + return Err(PluginError::AlreadyExists(id)); + } + Entry::Vacant(entry) => { + entry.insert(handle); + } + } } { diff --git a/src/cortex-protocol/Cargo.toml b/src/cortex-protocol/Cargo.toml index f162dff..489e953 100644 --- a/src/cortex-protocol/Cargo.toml +++ b/src/cortex-protocol/Cargo.toml @@ -20,6 +20,7 @@ uuid = { workspace = true, features = ["serde", "v4"] } chrono = { workspace = true } strum_macros = "0.27" base64 = { workspace = true } +tracing = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } diff --git a/src/cortex-protocol/src/protocol/message_parts.rs b/src/cortex-protocol/src/protocol/message_parts.rs index 67d238f..e929676 100644 --- a/src/cortex-protocol/src/protocol/message_parts.rs +++ b/src/cortex-protocol/src/protocol/message_parts.rs @@ -182,6 +182,44 @@ pub enum ToolState { }, } +impl ToolState { + /// Check if transitioning to the given state is valid. 
+ /// + /// Valid transitions: + /// - Pending -> Running, Completed, Error + /// - Running -> Completed, Error + /// - Completed -> (terminal, no transitions) + /// - Error -> (terminal, no transitions) + /// + /// State machine: + /// ```text + /// Pending -> Running -> Completed + /// | | + /// | +-> Error + /// +-> Completed + /// +-> Error + /// ``` + pub fn can_transition_to(&self, target: &ToolState) -> bool { + match (self, target) { + // From Pending, can go to any non-Pending state + (ToolState::Pending { .. }, ToolState::Running { .. }) => true, + (ToolState::Pending { .. }, ToolState::Completed { .. }) => true, + (ToolState::Pending { .. }, ToolState::Error { .. }) => true, + + // From Running, can go to Completed or Error + (ToolState::Running { .. }, ToolState::Completed { .. }) => true, + (ToolState::Running { .. }, ToolState::Error { .. }) => true, + + // Terminal states cannot transition + (ToolState::Completed { .. }, _) => false, + (ToolState::Error { .. }, _) => false, + + // Any other transition is invalid + _ => false, + } + } +} + /// Subtask execution status. #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -552,6 +590,8 @@ impl MessageWithParts { } /// Update a tool state by call ID. + /// + /// Logs a warning if the state transition is invalid (e.g., from a terminal state). 
pub fn update_tool_state(&mut self, call_id: &str, new_state: ToolState) -> bool { for part in &mut self.parts { if let MessagePart::Tool { @@ -561,6 +601,14 @@ impl MessageWithParts { } = &mut part.part { if cid == call_id { + if !state.can_transition_to(&new_state) { + tracing::warn!( + "Invalid ToolState transition from {:?} to {:?} for call_id {}", + state, + new_state, + call_id + ); + } *state = new_state; return true; } diff --git a/src/cortex-resume/src/resume_picker.rs b/src/cortex-resume/src/resume_picker.rs index 9bf8832..7b0ee9a 100644 --- a/src/cortex-resume/src/resume_picker.rs +++ b/src/cortex-resume/src/resume_picker.rs @@ -153,12 +153,15 @@ fn format_relative_time(time: &chrono::DateTime) -> String { } } -/// Truncate string to fit width. +/// Truncate string to fit width, handling multi-byte UTF-8 safely. fn truncate_string(s: &str, width: usize) -> String { - if s.len() <= width { + // Count actual character width, not byte length + let char_count = s.chars().count(); + if char_count <= width { s.to_string() } else if width > 3 { - format!("{}...", &s[..width - 3]) + let truncated: String = s.chars().take(width - 3).collect(); + format!("{}...", truncated) } else { s.chars().take(width).collect() } @@ -176,4 +179,40 @@ mod tests { let hour_ago = now - chrono::Duration::hours(2); assert_eq!(format_relative_time(&hour_ago), "2h ago"); } + + #[test] + fn test_truncate_string_ascii() { + // Short string, no truncation needed + assert_eq!(truncate_string("hello", 10), "hello"); + + // Exact fit + assert_eq!(truncate_string("hello", 5), "hello"); + + // Needs truncation + assert_eq!(truncate_string("hello world", 8), "hello..."); + + // Very short width + assert_eq!(truncate_string("hello", 3), "hel"); + assert_eq!(truncate_string("hello", 2), "he"); + } + + #[test] + fn test_truncate_string_utf8() { + // UTF-8 multi-byte characters (Japanese) + let japanese = "こんにちは世界"; // 7 chars + assert_eq!(truncate_string(japanese, 10), japanese); // No truncation 
+ assert_eq!(truncate_string(japanese, 7), japanese); // Exact fit + assert_eq!(truncate_string(japanese, 6), "こんに..."); // Truncated (3 chars + ...) + + // UTF-8 with emoji + let emoji = "Hello 🌍🌎🌏"; // 9 chars: H,e,l,l,o, ,🌍,🌎,🌏 + assert_eq!(truncate_string(emoji, 20), emoji); // No truncation + assert_eq!(truncate_string(emoji, 9), emoji); // Exact fit (9 chars) + assert_eq!(truncate_string(emoji, 8), "Hello..."); // Truncated (5 chars + ...) + + // Mixed UTF-8 and ASCII + let mixed = "路径/path/文件"; // 11 chars + assert_eq!(truncate_string(mixed, 20), mixed); // No truncation + assert_eq!(truncate_string(mixed, 8), "路径/pa..."); // Truncated + } } diff --git a/src/cortex-resume/src/session_store.rs b/src/cortex-resume/src/session_store.rs index 48ed5b9..04a1f9f 100644 --- a/src/cortex-resume/src/session_store.rs +++ b/src/cortex-resume/src/session_store.rs @@ -13,15 +13,34 @@ use tokio::fs; use tokio::sync::{Mutex as AsyncMutex, RwLock}; use tracing::{debug, info}; +/// Maximum number of lock entries before triggering cleanup. +const MAX_LOCK_ENTRIES: usize = 10_000; + /// Global file lock manager for session store operations. /// Prevents concurrent modifications to the same file within the process. static FILE_LOCKS: once_cell::sync::Lazy>>>> = once_cell::sync::Lazy::new(|| std::sync::Mutex::new(HashMap::new())); +/// Remove lock entries that are no longer in use. +/// +/// An entry is considered stale when only the HashMap holds a reference +/// to it (strong_count == 1), meaning no caller is currently using the lock. +fn cleanup_stale_file_locks(locks: &mut HashMap>>) { + locks.retain(|_, arc| Arc::strong_count(arc) > 1); +} + /// Acquire an async lock for a specific file path. +/// +/// Automatically cleans up stale lock entries when the map grows too large. 
fn get_file_lock(path: &Path) -> Arc> { let canonical = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); let mut locks = FILE_LOCKS.lock().unwrap(); + + // Clean up stale entries if the map is getting large + if locks.len() >= MAX_LOCK_ENTRIES { + cleanup_stale_file_locks(&mut locks); + } + locks .entry(canonical) .or_insert_with(|| Arc::new(AsyncMutex::new(()))) diff --git a/src/cortex-shell-snapshot/src/snapshot.rs b/src/cortex-shell-snapshot/src/snapshot.rs index 904ed8a..ab38912 100644 --- a/src/cortex-shell-snapshot/src/snapshot.rs +++ b/src/cortex-shell-snapshot/src/snapshot.rs @@ -115,12 +115,13 @@ impl ShellSnapshot { } /// Generate a restore script that sources this snapshot. + /// + /// The path is properly escaped to prevent shell injection attacks. + /// Paths containing single quotes are escaped using shell-safe quoting. pub fn restore_script(&self) -> String { let header = scripts::restore_header(self.metadata.shell_type); - format!( - "{header}\n# Source snapshot\nsource '{}'\n", - self.path.display() - ) + let escaped_path = shell_escape_path(&self.path); + format!("{header}\n# Source snapshot\nsource {escaped_path}\n") } /// Save the snapshot to disk. @@ -197,6 +198,27 @@ impl Drop for ShellSnapshot { } } +/// Escape a path for safe use in shell commands. +/// +/// This function handles paths containing single quotes by using the +/// shell-safe escaping technique: 'path'"'"'with'"'"'quotes' +/// +/// For paths without single quotes, simple single-quoting is used. 
+fn shell_escape_path(path: &Path) -> String { + let path_str = path.display().to_string(); + + if !path_str.contains('\'') { + // Simple case: no single quotes, just wrap in single quotes + format!("'{}'", path_str) + } else { + // Complex case: escape single quotes using '"'"' technique + // This closes the single-quoted string, adds a double-quoted single quote, + // and reopens the single-quoted string + let escaped = path_str.replace('\'', "'\"'\"'"); + format!("'{}'", escaped) + } +} + #[cfg(test)] mod tests { use super::*; @@ -221,4 +243,26 @@ mod tests { "snapshot_12345678-1234-1234-1234-123456789012.zsh" ); } + + #[test] + fn test_shell_escape_path_simple() { + let path = Path::new("/tmp/test/snapshot.sh"); + let escaped = shell_escape_path(path); + assert_eq!(escaped, "'/tmp/test/snapshot.sh'"); + } + + #[test] + fn test_shell_escape_path_with_single_quotes() { + let path = Path::new("/tmp/test's/snap'shot.sh"); + let escaped = shell_escape_path(path); + // Single quotes should be escaped using '"'"' technique + assert_eq!(escaped, "'/tmp/test'\"'\"'s/snap'\"'\"'shot.sh'"); + } + + #[test] + fn test_shell_escape_path_spaces() { + let path = Path::new("/tmp/test path/snapshot.sh"); + let escaped = shell_escape_path(path); + assert_eq!(escaped, "'/tmp/test path/snapshot.sh'"); + } } diff --git a/src/cortex-storage/src/sessions/storage.rs b/src/cortex-storage/src/sessions/storage.rs index dd750b0..7f3d38c 100644 --- a/src/cortex-storage/src/sessions/storage.rs +++ b/src/cortex-storage/src/sessions/storage.rs @@ -124,20 +124,67 @@ impl SessionStorage { } /// Save a session to disk. + /// + /// This function ensures data durability by calling sync_all() (fsync) + /// after writing to prevent data loss on crash or forceful termination. 
pub async fn save_session(&self, session: &StoredSession) -> Result<()> { let path = self.paths.session_path(&session.id); let content = serde_json::to_string_pretty(session)?; - fs::write(&path, content).await?; + + // Write content to file + let file = fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(&path) + .await?; + + use tokio::io::AsyncWriteExt; + let mut file = file; + file.write_all(content.as_bytes()).await?; + file.flush().await?; + + // Ensure data is durably written to disk (fsync) to prevent data loss on crash + file.sync_all().await?; + + // Sync parent directory on Unix for crash safety (ensures directory entry is persisted) + #[cfg(unix)] + { + if let Some(parent) = path.parent() { + if let Ok(dir) = fs::File::open(parent).await { + let _ = dir.sync_all().await; + } + } + } + debug!(session_id = %session.id, "Session saved"); Ok(()) } /// Save a session synchronously. + /// + /// This function ensures data durability by calling sync_all() (fsync) + /// after writing to prevent data loss on crash or forceful termination. 
pub fn save_session_sync(&self, session: &StoredSession) -> Result<()> { let path = self.paths.session_path(&session.id); let file = std::fs::File::create(&path)?; - let writer = BufWriter::new(file); - serde_json::to_writer_pretty(writer, session)?; + let mut writer = BufWriter::new(file); + serde_json::to_writer_pretty(&mut writer, session)?; + writer.flush()?; + + // Ensure data is durably written to disk (fsync) to prevent data loss on crash + writer.get_ref().sync_all()?; + + // Sync parent directory on Unix for crash safety (ensures directory entry is persisted) + #[cfg(unix)] + { + if let Some(parent) = path.parent() { + if let Ok(dir) = std::fs::File::open(parent) { + let _ = dir.sync_all(); + } + } + } + debug!(session_id = %session.id, "Session saved"); Ok(()) } diff --git a/src/cortex-tui-components/src/dropdown.rs b/src/cortex-tui-components/src/dropdown.rs index fbae13f..3a2560d 100644 --- a/src/cortex-tui-components/src/dropdown.rs +++ b/src/cortex-tui-components/src/dropdown.rs @@ -99,7 +99,7 @@ impl DropdownState { /// Select the next item. pub fn select_next(&mut self) { - if self.items.is_empty() { + if self.items.is_empty() || self.max_visible == 0 { return; } self.selected = (self.selected + 1) % self.items.len(); @@ -108,7 +108,7 @@ impl DropdownState { /// Select the previous item. pub fn select_prev(&mut self) { - if self.items.is_empty() { + if self.items.is_empty() || self.max_visible == 0 { return; } self.selected = if self.selected == 0 { diff --git a/src/cortex-tui-components/src/scroll.rs b/src/cortex-tui-components/src/scroll.rs index b80145e..38497b8 100644 --- a/src/cortex-tui-components/src/scroll.rs +++ b/src/cortex-tui-components/src/scroll.rs @@ -119,12 +119,16 @@ impl ScrollState { /// /// Adjusts offset if necessary to make the item visible. 
pub fn ensure_visible(&mut self, index: usize) { + // Guard against zero visible items to prevent underflow + if self.visible == 0 { + return; + } if index < self.offset { // Item is above visible area - scroll up self.offset = index; } else if index >= self.offset + self.visible { // Item is below visible area - scroll down - self.offset = index.saturating_sub(self.visible - 1); + self.offset = index.saturating_sub(self.visible.saturating_sub(1)); } self.clamp_offset(); } diff --git a/src/cortex-tui-components/src/selection_list.rs b/src/cortex-tui-components/src/selection_list.rs index f25ab1b..bd08658 100644 --- a/src/cortex-tui-components/src/selection_list.rs +++ b/src/cortex-tui-components/src/selection_list.rs @@ -572,7 +572,7 @@ impl SelectionList { && let Some(reason) = &item.disabled_reason { let reason_str = format!(" {}", reason); - let reason_x = x + width - reason_str.len() as u16 - 1; + let reason_x = x.saturating_add(width.saturating_sub(reason_str.len() as u16 + 1)); if reason_x > col + 2 { buf.set_string( reason_x, @@ -651,7 +651,11 @@ impl SelectionList { buf.set_string(x + 2, area.y, &display_text, text_style); - let cursor_x = x + 2 + self.search_query.len() as u16; + // Use character count for cursor position, and account for truncation + let query_char_count = self.search_query.chars().count(); + let display_char_count = display_text.chars().count(); + let cursor_offset = display_char_count.min(query_char_count) as u16; + let cursor_x = x + 2 + cursor_offset; if cursor_x < area.right().saturating_sub(1) { buf[(cursor_x, area.y)].set_bg(self.colors.accent); buf[(cursor_x, area.y)].set_fg(self.colors.void); diff --git a/src/cortex-tui/Cargo.toml b/src/cortex-tui/Cargo.toml index b75ee4b..9874145 100644 --- a/src/cortex-tui/Cargo.toml +++ b/src/cortex-tui/Cargo.toml @@ -65,6 +65,7 @@ walkdir = { workspace = true } # External editor which = { workspace = true } +tempfile = { workspace = true } # Audio notifications rodio = { workspace = true } 
diff --git a/src/cortex-tui/src/cards/commands.rs b/src/cortex-tui/src/cards/commands.rs index b777c5a..25226b4 100644 --- a/src/cortex-tui/src/cards/commands.rs +++ b/src/cortex-tui/src/cards/commands.rs @@ -225,8 +225,9 @@ impl CardView for CommandsCard { fn desired_height(&self, max_height: u16, _width: u16) -> u16 { // Base height for list items + search bar + some padding - let command_count = self.commands.len() as u16; - let content_height = command_count + 2; // +2 for search bar and padding + // Use saturating conversion to prevent overflow when count > u16::MAX + let command_count = u16::try_from(self.commands.len()).unwrap_or(u16::MAX); + let content_height = command_count.saturating_add(2); // +2 for search bar and padding // Clamp between min 5 and max 14, respecting max_height content_height.clamp(5, 14).min(max_height) diff --git a/src/cortex-tui/src/cards/models.rs b/src/cortex-tui/src/cards/models.rs index a5d0e48..a7abf75 100644 --- a/src/cortex-tui/src/cards/models.rs +++ b/src/cortex-tui/src/cards/models.rs @@ -147,8 +147,9 @@ impl CardView for ModelsCard { fn desired_height(&self, max_height: u16, _width: u16) -> u16 { // Base height for list items + search bar + some padding - let model_count = self.models.len() as u16; - let content_height = model_count + 2; // +2 for search bar and padding + // Use saturating conversion to prevent overflow when count > u16::MAX + let model_count = u16::try_from(self.models.len()).unwrap_or(u16::MAX); + let content_height = model_count.saturating_add(2); // +2 for search bar and padding // Clamp between min 5 and max 12, respecting max_height content_height.clamp(5, 12).min(max_height) diff --git a/src/cortex-tui/src/cards/sessions.rs b/src/cortex-tui/src/cards/sessions.rs index 76c67a0..b856f91 100644 --- a/src/cortex-tui/src/cards/sessions.rs +++ b/src/cortex-tui/src/cards/sessions.rs @@ -207,7 +207,9 @@ impl CardView for SessionsCard { fn desired_height(&self, max_height: u16, _width: u16) -> u16 { // Base 
height: sessions + header + search bar + padding - let content_height = self.sessions.len() as u16 + 3; + // Use saturating conversion to prevent overflow when count > u16::MAX + let session_count = u16::try_from(self.sessions.len()).unwrap_or(u16::MAX); + let content_height = session_count.saturating_add(3); let min_height = 5; let max_desired = 15; content_height diff --git a/src/cortex-tui/src/external_editor.rs b/src/cortex-tui/src/external_editor.rs index 7aa545b..7b94574 100644 --- a/src/cortex-tui/src/external_editor.rs +++ b/src/cortex-tui/src/external_editor.rs @@ -162,17 +162,25 @@ pub async fn open_external_editor(initial_content: &str) -> Result = editor_cmd.split_whitespace().collect(); let (editor, args) = match parts.split_first() { @@ -219,17 +227,25 @@ pub fn open_external_editor_sync(initial_content: &str) -> Result = editor_cmd.split_whitespace().collect(); let (editor, args) = match parts.split_first() { @@ -264,12 +280,13 @@ pub fn open_external_editor_sync(initial_content: &str) -> Result PathBuf { let temp_dir = std::env::temp_dir(); - temp_dir.join(format!("cortex_prompt_{}.md", std::process::id())) + temp_dir.join("cortex_prompt_XXXXXXXXXXXXXXXX.md") } // ============================================================ diff --git a/src/cortex-tui/src/interactive/renderer.rs b/src/cortex-tui/src/interactive/renderer.rs index d10598a..763989a 100644 --- a/src/cortex-tui/src/interactive/renderer.rs +++ b/src/cortex-tui/src/interactive/renderer.rs @@ -109,7 +109,13 @@ impl<'a> InteractiveWidget<'a> { let hints_height = 1; let border_height = 2; - (items_count as u16) + header_height + search_height + hints_height + border_height + // Use saturating conversion to prevent overflow when items_count exceeds u16::MAX + let items_height = u16::try_from(items_count).unwrap_or(u16::MAX); + items_height + .saturating_add(header_height) + .saturating_add(search_height) + .saturating_add(hints_height) + .saturating_add(border_height) } } diff --git 
a/src/cortex-tui/src/mcp_storage.rs b/src/cortex-tui/src/mcp_storage.rs index 19b247b..a67ff1d 100644 --- a/src/cortex-tui/src/mcp_storage.rs +++ b/src/cortex-tui/src/mcp_storage.rs @@ -15,6 +15,37 @@ use anyhow::{Context, Result}; use cortex_common::AppDirs; use serde::{Deserialize, Serialize}; +// ============================================================ +// SECURITY HELPERS +// ============================================================ + +/// Sanitize a server name to prevent path traversal attacks. +/// +/// Only allows alphanumeric characters, hyphens, and underscores. +/// Any other characters (including path separators) are replaced with underscores. +fn sanitize_server_name(name: &str) -> String { + name.chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '-' || c == '_' { + c + } else { + '_' + } + }) + .collect() +} + +/// Validate a server name for safe filesystem use. +/// +/// Returns true if the name contains only safe characters. +#[allow(dead_code)] +pub fn validate_server_name(name: &str) -> bool { + !name.is_empty() + && name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') +} + /// MCP transport type #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] @@ -158,8 +189,11 @@ impl McpStorage { } /// Get the path to a server's config file + /// + /// The server name is sanitized to prevent path traversal attacks. 
fn server_path(&self, name: &str) -> PathBuf { - self.mcps_dir.join(format!("{}.json", name)) + let sanitized_name = sanitize_server_name(name); + self.mcps_dir.join(format!("{}.json", sanitized_name)) } /// Save an MCP server configuration @@ -372,4 +406,50 @@ mod tests { let result = storage.load_server("nonexistent").unwrap(); assert!(result.is_none()); } + + #[test] + fn test_sanitize_server_name() { + // Normal names stay the same + assert_eq!(sanitize_server_name("my-server"), "my-server"); + assert_eq!(sanitize_server_name("server_123"), "server_123"); + + // Path traversal attempts get sanitized - verify no path separators remain + // and that result ends with "etc" (exact underscore count may vary by platform) + let sanitized = sanitize_server_name("../../../etc"); + assert!(!sanitized.contains('/')); + assert!(!sanitized.contains('\\')); + assert!(!sanitized.contains("..")); + assert!(sanitized.ends_with("etc")); + + assert_eq!(sanitize_server_name("test/subdir"), "test_subdir"); + assert_eq!(sanitize_server_name("test\\windows"), "test_windows"); + } + + #[test] + fn test_validate_server_name() { + // Valid names + assert!(validate_server_name("my-server")); + assert!(validate_server_name("server_123")); + assert!(validate_server_name("ABC")); + + // Invalid names + assert!(!validate_server_name("../../../etc")); + assert!(!validate_server_name("test/subdir")); + assert!(!validate_server_name("")); + assert!(!validate_server_name("name with spaces")); + } + + #[test] + fn test_server_path_traversal() { + let (storage, tmp) = test_storage(); + let base_dir = tmp.path().to_path_buf(); + + // Attempt path traversal + let malicious_name = "../../../etc/passwd"; + let result_path = storage.server_path(malicious_name); + + // The result should still be under mcps_dir + assert!(result_path.starts_with(base_dir.join("mcps"))); + assert!(!result_path.to_string_lossy().contains("..")); + } } diff --git a/src/cortex-tui/src/session/storage.rs 
b/src/cortex-tui/src/session/storage.rs index 7e1621e..b5a04de 100644 --- a/src/cortex-tui/src/session/storage.rs +++ b/src/cortex-tui/src/session/storage.rs @@ -21,6 +21,37 @@ const META_FILE: &str = "meta.json"; /// History file name. const HISTORY_FILE: &str = "history.jsonl"; +// ============================================================ +// SECURITY HELPERS +// ============================================================ + +/// Sanitize a session ID to prevent path traversal attacks. +/// +/// Only allows alphanumeric characters, hyphens, and underscores. +/// Any other characters (including path separators) are replaced with underscores. +fn sanitize_session_id(session_id: &str) -> String { + session_id + .chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '-' || c == '_' { + c + } else { + '_' + } + }) + .collect() +} + +/// Validate a session ID for safe filesystem use. +/// +/// Returns true if the session_id contains only safe characters. +pub fn validate_session_id(session_id: &str) -> bool { + !session_id.is_empty() + && session_id + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') +} + // ============================================================ // SESSION STORAGE // ============================================================ @@ -49,8 +80,21 @@ impl SessionStorage { } /// Gets the directory for a specific session. + /// + /// # Security + /// + /// The session_id is validated to prevent path traversal attacks. + /// Only alphanumeric characters, hyphens, and underscores are allowed. + /// + /// # Panics + /// + /// This function does not panic but will return an invalid path if + /// the session_id contains disallowed characters. Use `validate_session_id` + /// before calling this function for untrusted input. 
pub fn session_dir(&self, session_id: &str) -> PathBuf { - self.base_dir.join(session_id) + // Sanitize session_id to prevent path traversal + let sanitized_id = sanitize_session_id(session_id); + self.base_dir.join(&sanitized_id) } /// Gets the metadata file path for a session. @@ -87,6 +131,9 @@ impl SessionStorage { // ======================================================================== /// Saves session metadata. + /// + /// Uses atomic write (temp file + rename) with fsync for durability. + /// This prevents data loss on crash or forceful termination. pub fn save_meta(&self, meta: &SessionMeta) -> Result<()> { self.ensure_session_dir(&meta.id)?; @@ -94,13 +141,35 @@ impl SessionStorage { let content = serde_json::to_string_pretty(meta).context("Failed to serialize session metadata")?; - // Atomic write: write to temp file then rename + // Atomic write: write to temp file, fsync, then rename let temp_path = path.with_extension("json.tmp"); - fs::write(&temp_path, &content) + + // Write and sync temp file + let file = File::create(&temp_path) + .with_context(|| format!("Failed to create temp metadata file: {:?}", temp_path))?; + let mut writer = BufWriter::new(file); + writer + .write_all(content.as_bytes()) .with_context(|| format!("Failed to write temp metadata file: {:?}", temp_path))?; + writer.flush()?; + + // Ensure data is durably written to disk (fsync) before rename + writer.get_ref().sync_all().with_context(|| { + format!("Failed to sync temp metadata file to disk: {:?}", temp_path) + })?; + + // Rename temp file to final path fs::rename(&temp_path, &path) .with_context(|| format!("Failed to rename metadata file: {:?}", path))?; + // Sync parent directory on Unix for crash safety (ensures directory entry is persisted) + #[cfg(unix)] + if let Some(parent) = path.parent() + && let Ok(dir) = File::open(parent) + { + let _ = dir.sync_all(); + } + Ok(()) } @@ -212,6 +281,14 @@ impl SessionStorage { fs::rename(&temp_path, &path) .with_context(|| 
format!("Failed to rename history file: {:?}", path))?; + // Sync parent directory on Unix for crash safety (ensures directory entry is persisted) + #[cfg(unix)] + if let Some(parent) = path.parent() + && let Ok(dir) = File::open(parent) + { + let _ = dir.sync_all(); + } + Ok(()) } @@ -429,4 +506,51 @@ mod tests { let loaded = storage.load_meta(&session_id).unwrap(); assert!(loaded.archived); } + + #[test] + fn test_validate_session_id() { + // Valid IDs + assert!(validate_session_id("abc-123")); + assert!(validate_session_id("test_session")); + assert!(validate_session_id("ABC123")); + + // Invalid IDs - path traversal attempts + assert!(!validate_session_id("../../../etc")); + assert!(!validate_session_id("..")); + assert!(!validate_session_id("test/../passwd")); + assert!(!validate_session_id("test/subdir")); + assert!(!validate_session_id("")); + } + + #[test] + fn test_sanitize_session_id() { + // Normal ID stays the same + assert_eq!(sanitize_session_id("abc-123"), "abc-123"); + assert_eq!(sanitize_session_id("test_session"), "test_session"); + + // Path traversal gets sanitized - verify no path separators remain + // and that result ends with "etc" (exact underscore count may vary by platform) + let sanitized = sanitize_session_id("../../../etc"); + assert!(!sanitized.contains('/')); + assert!(!sanitized.contains('\\')); + assert!(!sanitized.contains("..")); + assert!(sanitized.ends_with("etc")); + + assert_eq!(sanitize_session_id("test/subdir"), "test_subdir"); + assert_eq!(sanitize_session_id("test\x00evil"), "test_evil"); + } + + #[test] + fn test_session_dir_path_traversal() { + let (storage, temp) = create_test_storage(); + let base_dir = temp.path().to_path_buf(); + + // Attempt path traversal - should be sanitized + let malicious_id = "../../../etc/passwd"; + let result_path = storage.session_dir(malicious_id); + + // The result should still be under base_dir, not escaping it + assert!(result_path.starts_with(&base_dir)); + 
assert!(!result_path.to_string_lossy().contains("..")); + } } diff --git a/src/cortex-tui/src/widgets/autocomplete.rs b/src/cortex-tui/src/widgets/autocomplete.rs index 77ecce3..3eefe69 100644 --- a/src/cortex-tui/src/widgets/autocomplete.rs +++ b/src/cortex-tui/src/widgets/autocomplete.rs @@ -77,10 +77,11 @@ impl<'a> AutocompletePopup<'a> { let item_count = self.state.visible_items().len() as u16; let height = item_count * ITEM_HEIGHT + 2; // +2 for borders - // Calculate width based on content + // Calculate width based on visible/filtered items only (not all items) + // This prevents the popup from being too wide when the filtered list is smaller let content_width = self .state - .items + .visible_items() .iter() .map(|item| { let icon_width = if item.icon != '\0' { 2 } else { 0 }; @@ -204,11 +205,19 @@ impl Widget for AutocompletePopup<'_> { let (width, height) = self.calculate_dimensions(); - // Position the popup above the input area - // We assume `area` is positioned where the popup should appear + // Position the popup above the input area if there's room, otherwise below + // This prevents the popup from going off-screen at the top + let y = if area.y >= height { + // Enough room above - position popup above the input + area.y.saturating_sub(height) + } else { + // Not enough room above - position popup below the input + area.bottom() + }; + let popup_area = Rect { x: area.x, - y: area.y.saturating_sub(height), + y, width: width.min(area.width), height, }; diff --git a/src/cortex-tui/src/widgets/help_browser/render.rs b/src/cortex-tui/src/widgets/help_browser/render.rs index ae0d035..4f11d99 100644 --- a/src/cortex-tui/src/widgets/help_browser/render.rs +++ b/src/cortex-tui/src/widgets/help_browser/render.rs @@ -152,7 +152,9 @@ impl<'a> HelpBrowser<'a> { /// Renders the content pane. 
fn render_content(&self, area: Rect, buf: &mut Buffer) { - let section = self.state.current_section(); + let Some(section) = self.state.current_section() else { + return; + }; let mut y = area.y; let scroll = self.state.content_scroll; let mut line_idx = 0; diff --git a/src/cortex-tui/src/widgets/help_browser/state.rs b/src/cortex-tui/src/widgets/help_browser/state.rs index ebfe52c..da95936 100644 --- a/src/cortex-tui/src/widgets/help_browser/state.rs +++ b/src/cortex-tui/src/widgets/help_browser/state.rs @@ -148,8 +148,13 @@ impl HelpBrowserState { } /// Returns the currently selected section. - pub fn current_section(&self) -> &HelpSection { - &self.sections[self.selected_section] + /// + /// Returns `None` if the sections vector is empty. + pub fn current_section(&self) -> Option<&HelpSection> { + if self.sections.is_empty() { + return None; + } + self.sections.get(self.selected_section) } /// Handles character input for search. diff --git a/src/cortex-tui/src/widgets/help_browser/tests.rs b/src/cortex-tui/src/widgets/help_browser/tests.rs index ed8a583..8772517 100644 --- a/src/cortex-tui/src/widgets/help_browser/tests.rs +++ b/src/cortex-tui/src/widgets/help_browser/tests.rs @@ -43,7 +43,10 @@ mod tests { #[test] fn test_help_browser_state_with_topic() { let state = HelpBrowserState::new().with_topic(Some("keyboard")); - assert_eq!(state.current_section().id, "keyboard"); + assert_eq!( + state.current_section().expect("should have section").id, + "keyboard" + ); } #[test] @@ -220,10 +223,16 @@ mod tests { #[test] fn test_current_section() { let mut state = HelpBrowserState::new(); - assert_eq!(state.current_section().id, "getting-started"); + assert_eq!( + state.current_section().expect("should have section").id, + "getting-started" + ); state.select_next(); - assert_eq!(state.current_section().id, "keyboard"); + assert_eq!( + state.current_section().expect("should have section").id, + "keyboard" + ); } #[test] @@ -232,4 +241,11 @@ mod tests { 
assert!(!state.sections.is_empty()); assert_eq!(state.selected_section, 0); } + + #[test] + fn test_current_section_empty_sections() { + let mut state = HelpBrowserState::new(); + state.sections.clear(); + assert!(state.current_section().is_none()); + } } diff --git a/src/cortex-tui/src/widgets/mention_popup.rs b/src/cortex-tui/src/widgets/mention_popup.rs index 7c68640..65fb12c 100644 --- a/src/cortex-tui/src/widgets/mention_popup.rs +++ b/src/cortex-tui/src/widgets/mention_popup.rs @@ -85,12 +85,12 @@ impl<'a> MentionPopup<'a> { let item_count = self.state.visible_results().len() as u16; let height = (item_count + 2).min(MAX_HEIGHT + 2); // +2 for borders - // Calculate width based on content + // Calculate width based on content (use chars().count() for Unicode support) let content_width = self .state .results() .iter() - .map(|p| p.to_string_lossy().len()) + .map(|p| p.to_string_lossy().chars().count()) .max() .unwrap_or(20) as u16; @@ -195,8 +195,11 @@ impl Widget for MentionPopup<'_> { let (width, height) = self.calculate_dimensions(area); - // Position the popup - let popup_area = if self.above { + // Position the popup - check if it fits above, otherwise render below + let fits_above = area.y >= height; + let render_above = self.above && fits_above; + + let popup_area = if render_above { Rect::new(area.x, area.y.saturating_sub(height), width, height) } else { Rect::new( diff --git a/src/cortex-tui/src/widgets/scrollable_dropdown.rs b/src/cortex-tui/src/widgets/scrollable_dropdown.rs index e4c07fd..7f7f49b 100644 --- a/src/cortex-tui/src/widgets/scrollable_dropdown.rs +++ b/src/cortex-tui/src/widgets/scrollable_dropdown.rs @@ -254,9 +254,12 @@ impl<'a> ScrollableDropdown<'a> { /// Returns visible items slice. 
fn visible_items(&self) -> &[DropdownItem] { - let start = self.scroll_offset; + if self.max_visible == 0 || self.items.is_empty() { + return &[]; + } + let start = self.scroll_offset.min(self.items.len()); let end = (start + self.max_visible).min(self.items.len()); - &self.items[start..end] + self.items.get(start..end).unwrap_or(&[]) } /// Renders a single item. @@ -459,7 +462,7 @@ pub fn calculate_scroll_offset( max_visible: usize, total_items: usize, ) -> usize { - if total_items <= max_visible { + if max_visible == 0 || total_items <= max_visible { return 0; } @@ -468,7 +471,7 @@ pub fn calculate_scroll_offset( selected } else if selected >= current_offset + max_visible { // Selected item is below visible area - scroll down - selected.saturating_sub(max_visible - 1) + selected.saturating_sub(max_visible.saturating_sub(1)) } else { // Selected item is visible - no change needed current_offset @@ -482,7 +485,7 @@ pub fn select_prev( max_visible: usize, total_items: usize, ) -> (usize, usize) { - if total_items == 0 { + if total_items == 0 || max_visible == 0 { return (0, 0); } @@ -511,7 +514,7 @@ pub fn select_next( max_visible: usize, total_items: usize, ) -> (usize, usize) { - if total_items == 0 { + if total_items == 0 || max_visible == 0 { return (0, 0); } @@ -521,8 +524,8 @@ pub fn select_next( // Wrapped to start 0 } else if new_selected >= scroll_offset + max_visible { - // Need to scroll down - new_selected.saturating_sub(max_visible - 1) + // Need to scroll down - use saturating_sub to prevent underflow + new_selected.saturating_sub(max_visible.saturating_sub(1)) } else { scroll_offset };