diff --git a/obli-transpiler-framework/backend/src/codegen.rs b/obli-transpiler-framework/backend/src/codegen.rs index 7ad3ab9..b0e95da 100644 --- a/obli-transpiler-framework/backend/src/codegen.rs +++ b/obli-transpiler-framework/backend/src/codegen.rs @@ -10,6 +10,57 @@ use crate::error::Error; use crate::oir::*; use std::fmt::Write; +/// Validate that an identifier is safe for code generation +/// +/// Returns an error if the identifier contains characters that could +/// enable code injection attacks. +fn validate_identifier(name: &str) -> Result<(), Error> { + if name.is_empty() { + return Err(Error::codegen("empty identifier")); + } + + // First character must be letter or underscore + let first = name.chars().next().unwrap(); + if !first.is_ascii_alphabetic() && first != '_' { + return Err(Error::codegen(format!( + "invalid identifier '{}': must start with letter or underscore", + name + ))); + } + + // Rest must be alphanumeric or underscore + for c in name.chars() { + if !c.is_ascii_alphanumeric() && c != '_' { + return Err(Error::codegen(format!( + "invalid identifier '{}': contains forbidden character '{}'", + name, c + ))); + } + } + + // Check for Rust reserved keywords that could cause issues + const DANGEROUS_NAMES: &[&str] = &[ + "unsafe", "asm", "extern", "mod", "crate", "self", "super", + "macro_rules", "include", "include_str", "include_bytes", + ]; + if DANGEROUS_NAMES.contains(&name) { + return Err(Error::codegen(format!( + "identifier '{}' is reserved and cannot be used", + name + ))); + } + + Ok(()) +} + +/// Sanitize an identifier for safe code generation +/// +/// Validates and returns the identifier, or returns an error. 
+fn sanitize_ident(name: &str) -> Result<&str, Error> { + validate_identifier(name)?; + Ok(name) +} + /// Code generator state pub struct CodeGenerator { indent: usize, @@ -264,7 +315,10 @@ impl CodeGenerator { match expr { Expr::Lit(lit) => self.emit_literal(lit)?, - Expr::Var(name) => write!(self.output, "{}", name)?, + Expr::Var(name) => { + let safe_name = sanitize_ident(name)?; + write!(self.output, "{}", safe_name)?; + } Expr::Binop(op, lhs, rhs) => { write!(self.output, "(")?; @@ -280,7 +334,8 @@ impl CodeGenerator { } Expr::Call(name, args) => { - write!(self.output, "{}(", name)?; + let safe_name = sanitize_ident(name)?; + write!(self.output, "{}(", safe_name)?; for (i, arg) in args.iter().enumerate() { if i > 0 { write!(self.output, ", ")?; @@ -298,8 +353,9 @@ impl CodeGenerator { } Expr::Field(obj, field) => { + let safe_field = sanitize_ident(field)?; self.emit_expr(obj)?; - write!(self.output, ".{}", field)?; + write!(self.output, ".{}", safe_field)?; } Expr::Cmov(cond, then_val, else_val) => { @@ -324,12 +380,14 @@ impl CodeGenerator { } Expr::Struct(name, fields) => { - write!(self.output, "{} {{", name)?; + let safe_name = sanitize_ident(name)?; + write!(self.output, "{} {{", safe_name)?; for (i, (fname, fval)) in fields.iter().enumerate() { + let safe_fname = sanitize_ident(fname)?; if i > 0 { write!(self.output, ",")?; } - write!(self.output, " {}: ", fname)?; + write!(self.output, " {}: ", safe_fname)?; self.emit_expr(fval)?; } write!(self.output, " }}")?; diff --git a/obli-transpiler-framework/driver/src/pipeline.rs b/obli-transpiler-framework/driver/src/pipeline.rs index caf7e84..9d66a67 100644 --- a/obli-transpiler-framework/driver/src/pipeline.rs +++ b/obli-transpiler-framework/driver/src/pipeline.rs @@ -4,10 +4,74 @@ //! 
Compilation pipeline implementation use crate::error::Error; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::process::Command; use tempfile::TempDir; +/// Validate that a path is safe for use (no path traversal) +/// +/// Ensures the path doesn't contain ".." or other traversal attempts. +fn validate_path(path: &Path) -> Result<(), Error> { + // Check for path traversal attempts + for component in path.components() { + if let std::path::Component::ParentDir = component { + return Err(Error::InvalidInput(format!( + "path '{}' contains parent directory reference (..)", + path.display() + ))); + } + } + + // Check the path string for other dangerous patterns + let path_str = path.to_string_lossy(); + if path_str.contains('\0') { + return Err(Error::InvalidInput("path contains null byte".to_string())); + } + + Ok(()) +} + +/// Validate and normalize an input file path +fn validate_input_path(path: &Path) -> Result { + validate_path(path)?; + + // Verify the file exists + if !path.exists() { + return Err(Error::InputNotFound(path.display().to_string())); + } + + // Use canonicalize to resolve to absolute path + path.canonicalize() + .map_err(|e| Error::InvalidInput(format!("cannot resolve path '{}': {}", path.display(), e))) +} + +/// Validate an output path (doesn't need to exist, but must be safe) +fn validate_output_path(path: &Path) -> Result { + validate_path(path)?; + + // If parent exists, canonicalize parent and join filename + if let Some(parent) = path.parent() { + if parent.as_os_str().is_empty() { + // No parent means current directory - that's fine + Ok(path.to_path_buf()) + } else if parent.exists() { + let canonical_parent = parent.canonicalize() + .map_err(|e| Error::InvalidInput(format!( + "cannot resolve parent directory '{}': {}", + parent.display(), e + )))?; + Ok(canonical_parent.join(path.file_name().unwrap_or_default())) + } else { + Err(Error::InvalidInput(format!( + "parent directory '{}' does not exist", + parent.display() + 
))) + } + } else { + Ok(path.to_path_buf()) + } +} + /// Configuration for compile command pub struct CompileConfig { pub input: PathBuf, @@ -80,16 +144,18 @@ fn find_backend() -> Result { /// Compile .obl to .rs pub fn compile(config: CompileConfig) -> Result<(), Error> { - if !config.input.exists() { - return Err(Error::InputNotFound(config.input.display().to_string())); - } + // Validate input path (security: prevent path traversal) + let input = validate_input_path(&config.input)?; let frontend = find_frontend()?; let backend = find_backend()?; - // Determine output paths - let oir_path = config.input.with_extension("oir.json"); - let rs_path = config.output.unwrap_or_else(|| config.input.with_extension("rs")); + // Determine and validate output paths + let oir_path = validate_output_path(&input.with_extension("oir.json"))?; + let rs_path = match config.output { + Some(ref p) => validate_output_path(p)?, + None => validate_output_path(&input.with_extension("rs"))?, + }; if config.verbose { eprintln!("Using frontend: {}", frontend.display()); diff --git a/obli-transpiler-framework/frontend/lib/typecheck.ml b/obli-transpiler-framework/frontend/lib/typecheck.ml index f62178a..5c1494d 100644 --- a/obli-transpiler-framework/frontend/lib/typecheck.ml +++ b/obli-transpiler-framework/frontend/lib/typecheck.ml @@ -410,11 +410,18 @@ and check_stmt state env stmt = | SAssign (lhs, rhs) -> let lhs_at = check_expr state env lhs in let rhs_at = check_expr state env rhs in + (* Check type compatibility *) if not (types_equal lhs_at.typ rhs_at.typ) then report state.diags (type_mismatch ~expected:(type_to_string lhs_at.typ) ~found:(type_to_string rhs_at.typ) - rhs.expr_loc) + rhs.expr_loc); + (* CRITICAL: Check security label - cannot assign high to low (information leak) *) + if not (security_leq rhs_at.security lhs_at.security) then + report state.diags (information_leak + ~from_label:(security_to_string rhs_at.security) + ~to_label:(security_to_string lhs_at.security) + 
stmt.stmt_loc) | SOramWrite (arr, idx, value) -> let arr_at = check_expr state env arr in diff --git a/obli-transpiler-framework/runtime/src/crypto.rs b/obli-transpiler-framework/runtime/src/crypto.rs index cbbf4d9..dec921b 100644 --- a/obli-transpiler-framework/runtime/src/crypto.rs +++ b/obli-transpiler-framework/runtime/src/crypto.rs @@ -11,6 +11,7 @@ use aes_gcm::{ }; use rand::RngCore; use sha2::{Digest, Sha256}; +use std::sync::atomic::{AtomicU64, Ordering}; use zeroize::{Zeroize, ZeroizeOnDrop}; /// Encryption key size (256 bits) @@ -22,6 +23,10 @@ pub const NONCE_SIZE: usize = 12; /// Authentication tag size pub const TAG_SIZE: usize = 16; +/// Maximum nonce counter before key rotation required +/// Conservative cap of 2^48 encryptions per Encryptor; its counter-based nonces cannot repeat below this cap +pub const MAX_NONCE_COUNTER: u64 = 1u64 << 48; + /// A secret key that zeroizes on drop #[derive(Clone, Zeroize, ZeroizeOnDrop)] pub struct SecretKey([u8; KEY_SIZE]); @@ -45,10 +50,74 @@ impl SecretKey { } } -/// Encrypt a block of data using AES-256-GCM +/// Stateful encryptor with counter-based nonces +/// +/// Uses a monotonic counter for nonces to prevent reuse. +/// Each encryptor instance should be used with a single key.
+pub struct Encryptor { + cipher: Aes256Gcm, + /// Random prefix for nonce (4 bytes) + nonce_prefix: [u8; 4], + /// Monotonic counter (8 bytes) - ensures unique nonces + counter: AtomicU64, +} + +impl Encryptor { + /// Create a new encryptor with the given key + pub fn new(key: &SecretKey) -> Self { + let cipher = Aes256Gcm::new(key.0.as_ref().into()); + let mut nonce_prefix = [0u8; 4]; + OsRng.fill_bytes(&mut nonce_prefix); + Encryptor { + cipher, + nonce_prefix, + counter: AtomicU64::new(0), + } + } + + /// Encrypt data with automatic nonce generation + /// + /// # Errors + /// Returns CryptoError::NonceExhausted once MAX_NONCE_COUNTER nonces have been issued (requires key rotation) + pub fn encrypt(&self, plaintext: &[u8]) -> Result, CryptoError> { + let counter = self.counter.fetch_add(1, Ordering::SeqCst); + if counter >= MAX_NONCE_COUNTER { + return Err(CryptoError::NonceExhausted); + } + + let mut nonce_bytes = [0u8; NONCE_SIZE]; + nonce_bytes[..4].copy_from_slice(&self.nonce_prefix); + nonce_bytes[4..12].copy_from_slice(&counter.to_le_bytes()); + let nonce = Nonce::from_slice(&nonce_bytes); + + let ciphertext = self.cipher + .encrypt(nonce, plaintext) + .map_err(|_| CryptoError::EncryptionFailed)?; + + let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len()); + result.extend_from_slice(&nonce_bytes); + result.extend_from_slice(&ciphertext); + Ok(result) + } + + /// Decrypt data (expects the nonce prepended, as produced by encrypt) + pub fn decrypt(&self, ciphertext: &[u8]) -> Result, CryptoError> { + decrypt_with_cipher(&self.cipher, ciphertext) + } + + /// Get the number of encrypt calls made so far, including rejected ones (for diagnostics) + pub fn nonce_count(&self) -> u64 { + self.counter.load(Ordering::SeqCst) + } +} + +/// Encrypt a block of data using AES-256-GCM (stateless, random nonce) +/// +/// WARNING: For high-volume encryption, use Encryptor instead to prevent +/// nonce reuse via birthday attack. This function is safe for occasional use. /// /// Returns ciphertext with nonce prepended.
-pub fn encrypt(key: &SecretKey, plaintext: &[u8]) -> Vec { +pub fn encrypt(key: &SecretKey, plaintext: &[u8]) -> Result, CryptoError> { let cipher = Aes256Gcm::new(key.0.as_ref().into()); let mut nonce_bytes = [0u8; NONCE_SIZE]; @@ -57,23 +126,20 @@ pub fn encrypt(key: &SecretKey, plaintext: &[u8]) -> Vec { let ciphertext = cipher .encrypt(nonce, plaintext) - .expect("encryption should not fail"); + .map_err(|_| CryptoError::EncryptionFailed)?; let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len()); result.extend_from_slice(&nonce_bytes); result.extend_from_slice(&ciphertext); - result + Ok(result) } -/// Decrypt a block of data using AES-256-GCM -/// -/// Expects nonce prepended to ciphertext. -pub fn decrypt(key: &SecretKey, ciphertext: &[u8]) -> Result, CryptoError> { +/// Internal decrypt function using pre-created cipher +fn decrypt_with_cipher(cipher: &Aes256Gcm, ciphertext: &[u8]) -> Result, CryptoError> { if ciphertext.len() < NONCE_SIZE + TAG_SIZE { return Err(CryptoError::InvalidCiphertext); } - let cipher = Aes256Gcm::new(key.0.as_ref().into()); let nonce = Nonce::from_slice(&ciphertext[..NONCE_SIZE]); let ct = &ciphertext[NONCE_SIZE..]; @@ -82,6 +148,14 @@ pub fn decrypt(key: &SecretKey, ciphertext: &[u8]) -> Result, CryptoErro .map_err(|_| CryptoError::DecryptionFailed) } +/// Decrypt a block of data using AES-256-GCM +/// +/// Expects nonce prepended to ciphertext. 
+pub fn decrypt(key: &SecretKey, ciphertext: &[u8]) -> Result, CryptoError> { + let cipher = Aes256Gcm::new(key.0.as_ref().into()); + decrypt_with_cipher(&cipher, ciphertext) +} + /// Compute SHA-256 hash pub fn sha256(data: &[u8]) -> [u8; 32] { let mut hasher = Sha256::new(); @@ -116,7 +190,10 @@ pub fn prf(key: &SecretKey, input: u64) -> u64 { pub enum CryptoError { InvalidCiphertext, DecryptionFailed, + EncryptionFailed, InvalidKeyLength, + /// Nonce counter exhausted - key rotation required + NonceExhausted, } impl std::fmt::Display for CryptoError { @@ -124,7 +201,9 @@ impl std::fmt::Display for CryptoError { match self { CryptoError::InvalidCiphertext => write!(f, "invalid ciphertext"), CryptoError::DecryptionFailed => write!(f, "decryption failed"), + CryptoError::EncryptionFailed => write!(f, "encryption failed"), CryptoError::InvalidKeyLength => write!(f, "invalid key length"), + CryptoError::NonceExhausted => write!(f, "nonce counter exhausted - key rotation required"), } } } @@ -139,11 +218,31 @@ mod tests { fn test_encrypt_decrypt() { let key = SecretKey::generate(); let plaintext = b"hello, ORAM world!"; - let ciphertext = encrypt(&key, plaintext); + let ciphertext = encrypt(&key, plaintext).unwrap(); let decrypted = decrypt(&key, &ciphertext).unwrap(); assert_eq!(decrypted, plaintext); } + #[test] + fn test_encryptor_stateful() { + let key = SecretKey::generate(); + let encryptor = Encryptor::new(&key); + + let plaintext = b"test message"; + let ct1 = encryptor.encrypt(plaintext).unwrap(); + let ct2 = encryptor.encrypt(plaintext).unwrap(); + + // Same plaintext should produce different ciphertext (different nonces) + assert_ne!(ct1, ct2); + + // Both should decrypt correctly + assert_eq!(encryptor.decrypt(&ct1).unwrap(), plaintext); + assert_eq!(encryptor.decrypt(&ct2).unwrap(), plaintext); + + // Counter should have advanced + assert_eq!(encryptor.nonce_count(), 2); + } + #[test] fn test_decrypt_wrong_key() { let key1 = SecretKey::generate(); diff 
--git a/obli-transpiler-framework/runtime/src/oram/bucket.rs b/obli-transpiler-framework/runtime/src/oram/bucket.rs index dcd6794..09d4a96 100644 --- a/obli-transpiler-framework/runtime/src/oram/bucket.rs +++ b/obli-transpiler-framework/runtime/src/oram/bucket.rs @@ -7,7 +7,20 @@ use super::OramBlock; use crate::constant_time::{ct_lookup, ct_store}; -use subtle::{Choice, ConditionallySelectable}; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; + +/// Constant-time equality check for u64 +/// Returns 1 if equal, 0 otherwise (the 0/1 form that Choice::from accepts) +#[inline] +fn ct_eq_u64(a: u64, b: u64) -> u8 { + // XOR gives 0 exactly when a == b (done by hand; subtle's ct_eq is not called here) + let diff = a ^ b; + // diff | -diff has its top bit set exactly when diff != 0 (branch-free) + let is_zero = diff | diff.wrapping_neg(); + let high_bit = (is_zero >> 63) as u8; + // If diff was 0, high_bit is 0, so we return 1; else return 0 + high_bit ^ 1 +} /// Number of blocks per bucket (Z parameter in Path ORAM) pub const BUCKET_SIZE: usize = 4; @@ -86,26 +99,35 @@ impl Bucket { /// Read and remove entry with given address (constant-time) /// + /// This operation is constant-time: it always accesses all entries + /// and uses conditional selection to hide which entry matched. + /// /// Returns the data if found, None otherwise. /// The entry is marked as empty.
pub fn read_and_remove(&mut self, addr: u64) -> Option where T: ConditionallySelectable + Clone, { - let mut found = false; let mut result = T::default(); + let mut found_mask = 0u8; + // Process ALL entries to maintain constant-time behavior for entry in &mut self.entries { - let matches = entry.addr == addr; - if matches { - found = true; - result = entry.data.clone(); - entry.addr = u64::MAX; - entry.data = T::default(); - } + // Constant-time comparison using XOR and OR reduction + let addr_match = ct_eq_u64(entry.addr, addr); + let choice = Choice::from(addr_match); + + // Conditionally select result (only updates if match) + result.conditional_assign(&entry.data, choice); + found_mask |= addr_match; + + // Conditionally clear entry + let empty_addr = u64::conditional_select(&entry.addr, &u64::MAX, choice); + entry.addr = empty_addr; + entry.data.conditional_assign(&T::default(), choice); } - if found { + if found_mask != 0 { Some(result) } else { None @@ -113,16 +135,29 @@ impl Bucket { } /// Read entry with given address without removing (constant-time) + /// + /// This operation is constant-time: it always accesses all entries + /// and uses conditional selection to hide which entry matched. 
pub fn read(&self, addr: u64) -> Option where T: ConditionallySelectable + Clone, { + let mut result = T::default(); + let mut found_mask = 0u8; + + // Process ALL entries to maintain constant-time behavior for entry in &self.entries { - if entry.addr == addr { - return Some(entry.data.clone()); - } + let addr_match = ct_eq_u64(entry.addr, addr); + let choice = Choice::from(addr_match); + result.conditional_assign(&entry.data, choice); + found_mask |= addr_match; + } + + if found_mask != 0 { + Some(result) + } else { + None } - None } /// Get entries as slice diff --git a/obli-transpiler-framework/runtime/src/oram/path.rs b/obli-transpiler-framework/runtime/src/oram/path.rs index 22e45c8..080dcc3 100644 --- a/obli-transpiler-framework/runtime/src/oram/path.rs +++ b/obli-transpiler-framework/runtime/src/oram/path.rs @@ -30,14 +30,31 @@ pub struct PathOram { capacity: u64, } +/// Maximum supported ORAM capacity (2^32 blocks) +/// Limited to prevent integer overflow in tree size calculations +pub const MAX_ORAM_CAPACITY: u64 = 1u64 << 32; + impl PathOram { /// Create a new Path ORAM with given capacity + /// + /// # Panics + /// Panics if capacity is 0 or exceeds MAX_ORAM_CAPACITY pub fn new(capacity: u64, key: SecretKey) -> Self { + // Validate capacity to prevent overflow + assert!(capacity > 0, "ORAM capacity must be > 0"); + assert!( + capacity <= MAX_ORAM_CAPACITY, + "ORAM capacity {} exceeds maximum {} (risk of integer overflow)", + capacity, + MAX_ORAM_CAPACITY + ); + // Calculate tree depth (ceil(log2(capacity))) let depth = (64 - capacity.leading_zeros()) as usize; let num_leaves = 1u64 << depth; // Total nodes in complete binary tree: 2^(depth+1) - 1 + // Safe on 64-bit: capacity <= 2^32 gives depth <= 33, so 2^34 - 1 fits in usize let num_nodes = (1usize << (depth + 1)) - 1; // Initialize empty tree @@ -57,10 +74,21 @@ impl PathOram { } /// Access (read or write) a block + /// + /// # Panics + /// Panics if addr >= capacity fn access(&mut self, addr: u64, op: AccessOp) -> T where
T: Clone, { + // Bounds check to prevent corruption + assert!( + addr < self.capacity, + "ORAM address {} out of bounds (capacity: {})", + addr, + self.capacity + ); + // 1. Look up position and remap let (old_leaf, new_leaf) = self.position_map.get_and_remap(addr); diff --git a/obli-transpiler-framework/runtime/src/oram/stash.rs b/obli-transpiler-framework/runtime/src/oram/stash.rs index 516c738..7d96bd9 100644 --- a/obli-transpiler-framework/runtime/src/oram/stash.rs +++ b/obli-transpiler-framework/runtime/src/oram/stash.rs @@ -40,13 +40,23 @@ impl Stash { } /// Add a block to the stash + /// + /// # Panics + /// Panics if stash exceeds MAX_STASH_SIZE - this indicates ORAM security breach pub fn add(&mut self, addr: u64, leaf: u64, data: T) { - self.entries.push(StashEntry { addr, leaf, data }); - if self.entries.len() > MAX_STASH_SIZE { - // In production, this would be a security failure - // For now, just warn (the stash overflow bound proof ensures this is negligible) - log::warn!("Stash overflow: {} entries", self.entries.len()); + if self.entries.len() >= MAX_STASH_SIZE { + // CRITICAL: Stash overflow breaks ORAM security guarantees + // This should never happen with correct parameters (probability negligible) + // If it does, the only safe option is to abort to prevent information leak + panic!( + "FATAL: Stash overflow ({} >= {}) - ORAM security compromised! \ + This indicates either incorrect ORAM parameters or an attack. \ + Aborting to prevent information leakage.", + self.entries.len() + 1, + MAX_STASH_SIZE + ); } + self.entries.push(StashEntry { addr, leaf, data }); } /// Find and remove a block by address @@ -145,14 +155,20 @@ fn path_overlaps(leaf1: u64, leaf2: u64, depth: usize) -> bool { } /// Calculate the deepest level where two paths overlap +/// +/// Returns the deepest tree level (0 = root, depth-1 = leaf) where +/// the paths to leaf1 and leaf2 share a common ancestor. 
pub fn path_overlap_level(leaf1: u64, leaf2: u64, depth: usize) -> usize { - for level in (0..depth).rev() { + // Scan from the root downwards for the first level where the prefixes differ + for level in 0..depth { let shift = depth - level - 1; - if (leaf1 >> shift) == (leaf2 >> shift) { - return level; + if (leaf1 >> shift) != (leaf2 >> shift) { + // Paths diverge at this level, so overlap is at parent (level - 1, clamped to 0) + return level.saturating_sub(1); } } - 0 + // No divergence found (same leaf, or depth == 0) + depth.saturating_sub(1) } #[cfg(test)]