diff --git a/docs/specs/architecture/compiler-architecture.adoc b/docs/specs/architecture/compiler-architecture.adoc new file mode 100644 index 0000000..8c91039 --- /dev/null +++ b/docs/specs/architecture/compiler-architecture.adoc @@ -0,0 +1,491 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + += Oblibeny Compiler Architecture: OCaml Frontend + Rust Backend +:author: Oblibeny Project +:revdate: 2024 +:toc: left +:toclevels: 4 +:sectnums: +:stem: latexmath + +== Overview + +Oblibeny uses a split-compiler architecture: + +* **OCaml Frontend**: Parsing, type checking, security analysis, IR generation +* **Rust Backend**: Code generation, optimization, runtime, ORAM implementation + +This leverages OCaml's strengths in symbolic manipulation and Rust's strengths +in systems programming and performance. + +== Architecture Diagram + +[source] +---- + OBLIBENY COMPILER ARCHITECTURE + ══════════════════════════════ + +┌─────────────────────────────────────────────────────────────────────────────┐ +│ OCaml Frontend │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ Source Code (.obl) │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │ │ +│ │ │ Lexer │───▶│ Parser │───▶│ Abstract Syntax Tree │ │ │ +│ │ │ (ocamllex) │ │ (Menhir) │ │ (AST) │ │ │ +│ │ └─────────────┘ └─────────────┘ └───────────┬─────────────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ Type Checker │ │ │ +│ │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │ │ │ +│ │ │ │ Base Types │ │ Security │ │ Obliviousness │ │ │ │ +│ │ │ │ Checker │ │ Levels │ │ Analysis │ │ │ │ +│ │ │ └─────────────┘ └─────────────┘ └─────────────────────┘ │ │ │ +│ │ └───────────────────────────┬─────────────────────────────────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ Typed 
AST (TAST) │ │ │ +│ │ └───────────────────────────┬─────────────────────────────────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ Security & Obliviousness Checker │ │ │ +│ │ │ • Information flow analysis │ │ │ +│ │ │ • Access pattern leak detection │ │ │ +│ │ │ • Transformation suggestions │ │ │ +│ │ └───────────────────────────┬─────────────────────────────────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ IR Generator │ │ │ +│ │ │ • Lower TAST to OIR (Oblivious IR) │ │ │ +│ │ │ • Insert ORAM operations │ │ │ +│ │ │ • Mark oblivious regions │ │ │ +│ │ └───────────────────────────┬─────────────────────────────────┘ │ │ +│ │ │ │ │ +│ └───────────────────────────────┼───────────────────────────────────────┘ │ +│ │ │ +└──────────────────────────────────┼───────────────────────────────────────────┘ + │ + │ OIR (MessagePack/JSON) + │ +┌──────────────────────────────────┼───────────────────────────────────────────┐ +│ ▼ │ +│ Rust Backend │ +│ ┌───────────────────────────────────────────────────────────────────────┐ │ +│ │ │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ IR Deserializer │ │ │ +│ │ │ • Parse OIR from OCaml frontend │ │ │ +│ │ │ • Validate IR structure │ │ │ +│ │ └───────────────────────────┬─────────────────────────────────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ Optimizer │ │ │ +│ │ │ • Batch ORAM accesses │ │ │ +│ │ │ • Dead code elimination │ │ │ +│ │ │ • Inline oblivious primitives │ │ │ +│ │ └───────────────────────────┬─────────────────────────────────┘ │ │ +│ │ │ │ │ +│ │ ▼ │ │ +│ │ ┌─────────────────────────────────────────────────────────────┐ │ │ +│ │ │ Code Generator │ │ │ +│ │ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │ │ +│ │ │ │ Rust │ │ RISC-V │ │ WASM │ │ C │ │ │ │ +│ │ │ │ Output │ │ Output │ │ Output │ │ 
Output │ │ │ │ +│ │ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │ │ +│ │ └───────────────────────────┬─────────────────────────────────┘ │ │ +│ │ │ │ │ +│ └───────────────────────────────┼────────────────────────────────────────┘ │ +│ │ │ +│ ┌───────────────────────────────┼────────────────────────────────────────┐ │ +│ │ ▼ │ │ +│ │ ORAM Runtime │ │ +│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ +│ │ │ Path ORAM │ │ Stash │ │ Position │ │ Crypto │ │ │ +│ │ │ Core │ │ Manager │ │ Map │ │ Layer │ │ │ +│ │ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ │ │ +│ │ │ │ +│ │ ┌─────────────────────────────────────────────────────────────────┐ │ │ +│ │ │ Oblivious Standard Library │ │ │ +│ │ │ OArray │ OMap │ OVec │ OSort │ OSearch │ OQueue │ │ │ +│ │ └─────────────────────────────────────────────────────────────────┘ │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────────────────────────┘ │ +│ │ +└───────────────────────────────────────────────────────────────────────────────┘ +---- + +== Component Responsibilities + +=== OCaml Frontend + +[cols="1,3"] +|=== +| Component | Responsibility + +| **Lexer** +| Tokenize source code, handle string literals, comments + +| **Parser** +| Build AST from token stream, handle operator precedence + +| **Type Checker** +| Hindley-Milner type inference + security level inference + +| **Security Analyzer** +| Information flow analysis, detect implicit flows + +| **Obliviousness Checker** +| Identify non-oblivious memory accesses + +| **IR Generator** +| Lower typed AST to OIR, insert ORAM calls +|=== + +=== Rust Backend + +[cols="1,3"] +|=== +| Component | Responsibility + +| **IR Parser** +| Deserialize OIR from OCaml frontend + +| **Optimizer** +| ORAM-aware optimizations (batching, caching hints) + +| **Code Generator** +| Emit target code (Rust, RISC-V, WASM, C) + +| **ORAM Runtime** +| Path ORAM implementation, position map, stash + +| **Stdlib** +| Oblivious 
data structures (OArray, OMap, etc.) + +| **Crypto** +| AES-GCM, SHA-256, BLAKE3, Merkle trees +|=== + +== Communication Protocol + +=== IR Serialization + +The OCaml frontend produces OIR (Oblivious Intermediate Representation) +serialized as MessagePack (binary) or JSON (debug). + +[source,json] +---- +{ + "version": "1.0.0", + "module": "main", + "functions": [ + { + "name": "lookup", + "params": [ + {"name": "arr", "type": {"oarray": "i64"}, "security": "low"}, + {"name": "idx", "type": "i64", "security": "high"} + ], + "return_type": {"type": "i64", "security": "high"}, + "body": [ + { + "kind": "oram_read", + "array": {"var": "arr"}, + "index": {"var": "idx"}, + "result": "tmp0" + }, + { + "kind": "return", + "value": {"var": "tmp0"} + } + ] + } + ] +} +---- + +=== Build Integration + +[source,bash] +---- +# Full compilation pipeline +obli-frontend source.obl -o source.oir # OCaml +obli-backend source.oir -o source.rs # Rust +rustc source.rs -L obli-runtime -o binary # Standard Rust +---- + +Or unified: +[source,bash] +---- +oblic source.obl -o binary # Driver invokes both +---- + +== Directory Structure + +[source] +---- +oblibeny/ +├── obli-transpiler-framework/ +│ ├── frontend/ # OCaml +│ │ ├── dune-project +│ │ ├── dune +│ │ ├── bin/ +│ │ │ └── main.ml # CLI entry point +│ │ ├── lib/ +│ │ │ ├── lexer.mll # ocamllex +│ │ │ ├── parser.mly # Menhir +│ │ │ ├── ast.ml # AST types +│ │ │ ├── types.ml # Type definitions +│ │ │ ├── typecheck.ml # Type checker +│ │ │ ├── security.ml # Security analysis +│ │ │ ├── oblivious.ml # Obliviousness checker +│ │ │ ├── ir.ml # OIR types +│ │ │ ├── lower.ml # AST → OIR +│ │ │ └── emit.ml # OIR serialization +│ │ └── test/ +│ │ └── *.ml +│ │ +│ ├── backend/ # Rust +│ │ ├── Cargo.toml +│ │ ├── src/ +│ │ │ ├── main.rs # CLI entry point +│ │ │ ├── ir/ +│ │ │ │ ├── mod.rs +│ │ │ │ ├── parse.rs # OIR deserializer +│ │ │ │ └── types.rs # OIR types (mirror OCaml) +│ │ │ ├── opt/ +│ │ │ │ ├── mod.rs +│ │ │ │ ├── batch.rs # ORAM 
batching +│ │ │ │ └── inline.rs # Primitive inlining +│ │ │ ├── codegen/ +│ │ │ │ ├── mod.rs +│ │ │ │ ├── rust.rs # → Rust output +│ │ │ │ ├── riscv.rs # → RISC-V output +│ │ │ │ └── wasm.rs # → WASM output +│ │ │ └── lib.rs +│ │ └── tests/ +│ │ +│ ├── runtime/ # Rust runtime library +│ │ ├── Cargo.toml +│ │ └── src/ +│ │ ├── lib.rs +│ │ ├── oram/ +│ │ │ ├── mod.rs +│ │ │ ├── path.rs # Path ORAM +│ │ │ ├── position.rs # Position map +│ │ │ ├── stash.rs # Stash management +│ │ │ └── bucket.rs # Bucket operations +│ │ ├── crypto/ +│ │ │ ├── mod.rs +│ │ │ ├── aead.rs # AES-GCM +│ │ │ ├── hash.rs # SHA-256, BLAKE3 +│ │ │ └── merkle.rs # Merkle tree +│ │ └── collections/ +│ │ ├── mod.rs +│ │ ├── oarray.rs # Oblivious array +│ │ ├── omap.rs # Oblivious map +│ │ ├── ovec.rs # Oblivious vector +│ │ └── osort.rs # Oblivious sorting +│ │ +│ └── driver/ # Unified CLI (Rust) +│ ├── Cargo.toml +│ └── src/ +│ └── main.rs # Invokes frontend + backend +│ +├── obli-riscv-dev-kit/ # (separate submodule) +├── obli-fs/ # (separate submodule) +└── docs/ +---- + +== Language Specification Preview + +=== Source Language Syntax (.obl files) + +[source] +---- +// Type declarations with security annotations +type SecretIndex = int @high +type PublicData = int @low + +// Oblivious array type +type Database = oarray<int> + +// Function with security-typed parameters +fn lookup(db: Database, idx: SecretIndex) -> PublicData @high { + // Compiler automatically uses ORAM for this access + // because idx has @high security level + db[idx] +} + +// Explicit oblivious block +fn process(data: array<int>, secret: bool @high) -> int { + oblivious { + // All memory accesses in this block are oblivious + if secret { + data[0] + } else { + data[1] + } + } +} + +// Oblivious conditional (no branching leak) +fn oselect<T>(cond: bool @high, a: T, b: T) -> T @high { + cmov(cond, a, b) // Compiles to constant-time select +} +---- + +=== Type System + +[source] +---- +Types τ ::= + | int | bool | unit (* base types *) + | 
τ₁ → τ₂ (* functions *) + | τ₁ × τ₂ (* tuples *) + | array<τ> (* regular array *) + | oarray<τ> (* oblivious array *) + | ref<τ> (* mutable reference *) + | oref<τ> (* oblivious reference *) + +Security ℓ ::= + | @low (* public *) + | @high (* secret *) + | @ℓ₁ ⊔ ℓ₂ (* join *) + +Labeled Types σ ::= τ @ℓ +---- + +== Build System + +=== Prerequisites + +[source,bash] +---- +# OCaml toolchain +opam install dune menhir ppx_deriving yojson msgpck + +# Rust toolchain +rustup default stable +cargo install cargo-watch +---- + +=== Build Commands + +[source,bash] +---- +# Build everything +just build + +# Build frontend only +cd obli-transpiler-framework/frontend && dune build + +# Build backend only +cd obli-transpiler-framework/backend && cargo build --release + +# Run tests +just test + +# Format code +just fmt +---- + +=== Justfile + +[source,just] +---- +# Build all components +build: + cd obli-transpiler-framework/frontend && dune build + cd obli-transpiler-framework/backend && cargo build --release + cd obli-transpiler-framework/runtime && cargo build --release + cd obli-transpiler-framework/driver && cargo build --release + +# Run all tests +test: + cd obli-transpiler-framework/frontend && dune test + cd obli-transpiler-framework/backend && cargo test + cd obli-transpiler-framework/runtime && cargo test + +# Format all code +fmt: + cd obli-transpiler-framework/frontend && dune fmt + cd obli-transpiler-framework/backend && cargo fmt + cd obli-transpiler-framework/runtime && cargo fmt + +# Clean build artifacts +clean: + cd obli-transpiler-framework/frontend && dune clean + cd obli-transpiler-framework/backend && cargo clean + cd obli-transpiler-framework/runtime && cargo clean +---- + +== Testing Strategy + +=== Unit Tests + +* OCaml: Each module has `_test.ml` companion +* Rust: Inline `#[cfg(test)]` modules + +=== Integration Tests + +[source] +---- +tests/ +├── compile/ # Source → IR → Binary +│ ├── basic.obl +│ ├── security.obl +│ └── oblivious.obl +├── runtime/ 
# ORAM correctness +│ ├── path_oram.rs +│ └── stash.rs +└── security/ # Side-channel tests + ├── timing.rs + └── pattern.rs +---- + +=== Property-Based Tests + +[source,rust] +---- +#[test] +fn prop_oram_correctness() { + proptest!(|(ops: Vec<OramOp>)| { + let mut oram = PathOram::new(1024); + let mut reference = HashMap::new(); + + for op in ops { + match op { + OramOp::Write(k, v) => { + oram.write(k, v); + reference.insert(k, v); + } + OramOp::Read(k) => { + assert_eq!(oram.read(k), reference.get(&k).copied()); + } + } + } + }); +} +---- + +== Next Steps + +1. **Phase 1**: Implement minimal OCaml frontend (lexer, parser, basic types) +2. **Phase 2**: Implement Rust ORAM runtime +3. **Phase 3**: Connect via OIR format +4. **Phase 4**: Add security type system +5. **Phase 5**: Add optimizations + +== References + +* Real World OCaml: https://dev.realworldocaml.org/ +* Menhir Manual: http://gallium.inria.fr/~fpottier/menhir/ +* Rust Book: https://doc.rust-lang.org/book/ diff --git a/docs/specs/backend/rust-backend.adoc b/docs/specs/backend/rust-backend.adoc new file mode 100644 index 0000000..d3734b7 --- /dev/null +++ b/docs/specs/backend/rust-backend.adoc @@ -0,0 +1,1123 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + += Rust Backend Specification +:author: Oblibeny Project +:revdate: 2024 +:toc: left +:toclevels: 4 +:sectnums: +:stem: latexmath + +== Overview + +The Rust backend is responsible for: + +1. Parsing OIR from the OCaml frontend +2. Optimizing ORAM operations +3. Generating target code (Rust, RISC-V, WASM) +4. 
Providing the ORAM runtime library + +== Project Structure + +[source] +---- +backend/ +├── Cargo.toml +├── src/ +│ ├── main.rs # CLI entry point +│ ├── lib.rs # Library root +│ ├── ir/ +│ │ ├── mod.rs +│ │ ├── types.rs # OIR type definitions +│ │ ├── parse.rs # JSON/MessagePack parser +│ │ └── validate.rs # IR validation +│ ├── opt/ +│ │ ├── mod.rs +│ │ ├── batch.rs # ORAM access batching +│ │ ├── inline.rs # Primitive inlining +│ │ ├── dce.rs # Dead code elimination +│ │ └── const_prop.rs # Constant propagation +│ ├── codegen/ +│ │ ├── mod.rs +│ │ ├── rust.rs # Rust code generation +│ │ ├── riscv.rs # RISC-V assembly +│ │ └── wasm.rs # WebAssembly +│ └── error.rs # Error types +└── tests/ + ├── ir_tests.rs + ├── opt_tests.rs + └── codegen_tests.rs + +runtime/ +├── Cargo.toml +├── src/ +│ ├── lib.rs +│ ├── oram/ +│ │ ├── mod.rs +│ │ ├── path.rs # Path ORAM implementation +│ │ ├── position.rs # Position map +│ │ ├── stash.rs # Stash management +│ │ ├── bucket.rs # Bucket operations +│ │ └── tree.rs # Tree structure +│ ├── crypto/ +│ │ ├── mod.rs +│ │ ├── aead.rs # AES-256-GCM +│ │ ├── hash.rs # SHA-256, BLAKE3 +│ │ ├── merkle.rs # Merkle tree +│ │ └── random.rs # Secure RNG +│ ├── collections/ +│ │ ├── mod.rs +│ │ ├── oarray.rs # Oblivious array +│ │ ├── omap.rs # Oblivious map +│ │ ├── ovec.rs # Oblivious vector +│ │ └── osort.rs # Oblivious sorting +│ ├── primitives/ +│ │ ├── mod.rs +│ │ ├── cmov.rs # Constant-time select +│ │ ├── cswap.rs # Constant-time swap +│ │ └── cmp.rs # Constant-time compare +│ └── storage/ +│ ├── mod.rs +│ ├── memory.rs # In-memory backend +│ ├── file.rs # File-based backend +│ └── remote.rs # Network backend +└── benches/ + ├── oram_bench.rs + └── crypto_bench.rs +---- + +== Cargo Configuration + +=== backend/Cargo.toml + +[source,toml] +---- +[package] +name = "obli-backend" +version = "0.1.0" +edition = "2021" +license = "MIT OR Palimpsest-0.8" +description = "Oblibeny compiler backend" + +[dependencies] +obli-runtime = { path = 
"../runtime" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +rmp-serde = "1.1" +thiserror = "1.0" +clap = { version = "4.0", features = ["derive"] } +tracing = "0.1" +tracing-subscriber = "0.3" + +[dev-dependencies] +proptest = "1.0" +criterion = "0.5" +tempfile = "3.0" + +[[bin]] +name = "obli-backend" +path = "src/main.rs" +---- + +=== runtime/Cargo.toml + +[source,toml] +---- +[package] +name = "obli-runtime" +version = "0.1.0" +edition = "2021" +license = "MIT OR Palimpsest-0.8" +description = "Oblibeny ORAM runtime library" + +[dependencies] +aes-gcm = "0.10" +sha2 = "0.10" +blake3 = "1.5" +rand = "0.8" +rand_chacha = "0.3" +subtle = "2.5" # Constant-time operations +zeroize = "1.7" # Secure memory wiping +parking_lot = "0.12" # Better mutexes +thiserror = "1.0" +tracing = "0.1" + +[dev-dependencies] +proptest = "1.0" +criterion = "0.5" + +[features] +default = ["std"] +std = [] +no_std = [] # For embedded/WASM + +[[bench]] +name = "oram_bench" +harness = false +---- + +== IR Parser + +[source,rust] +---- +// ir/parse.rs + +use crate::ir::types::*; +use serde::de::DeserializeOwned; +use std::io::Read; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ParseError { + #[error("JSON parse error: {0}")] + Json(#[from] serde_json::Error), + + #[error("MessagePack parse error: {0}")] + MsgPack(#[from] rmp_serde::decode::Error), + + #[error("Invalid magic bytes")] + InvalidMagic, + + #[error("Unsupported version: {0}")] + UnsupportedVersion(String), + + #[error("Checksum mismatch")] + ChecksumMismatch, + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +pub type Result = std::result::Result; + +const MAGIC: &[u8; 4] = b"OIR\0"; +const SUPPORTED_VERSIONS: &[&str] = &["1.0.0"]; + +pub fn parse_json(input: &str) -> Result { + let module: Module = serde_json::from_str(input)?; + validate_version(&module.version)?; + Ok(module) +} + +pub fn parse_msgpack(mut reader: R) -> Result { + // Read and verify magic + let mut magic 
= [0u8; 4]; + reader.read_exact(&mut magic)?; + if &magic != MAGIC { + return Err(ParseError::InvalidMagic); + } + + // Read version + let mut version = [0u8; 4]; + reader.read_exact(&mut version)?; + let version = u32::from_le_bytes(version); + + // Read payload length + let mut length = [0u8; 8]; + reader.read_exact(&mut length)?; + let length = u64::from_le_bytes(length) as usize; + + // Read payload + let mut payload = vec![0u8; length]; + reader.read_exact(&mut payload)?; + + // Read and verify checksum + let mut checksum = [0u8; 32]; + reader.read_exact(&mut checksum)?; + let computed = blake3::hash(&payload); + if computed.as_bytes() != &checksum { + return Err(ParseError::ChecksumMismatch); + } + + // Deserialize + let module: Module = rmp_serde::from_slice(&payload)?; + validate_version(&module.version)?; + Ok(module) +} + +fn validate_version(version: &str) -> Result<()> { + if SUPPORTED_VERSIONS.contains(&version) { + Ok(()) + } else { + Err(ParseError::UnsupportedVersion(version.to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_simple_json() { + let json = r#"{ + "version": "1.0.0", + "name": "test", + "imports": [], + "types": [], + "globals": [], + "functions": [], + "entry": null, + "metadata": { + "compiler_version": "0.1.0", + "timestamp": "2024-01-01T00:00:00Z", + "options": {} + } + }"#; + + let module = parse_json(json).unwrap(); + assert_eq!(module.name, "test"); + } +} +---- + +== ORAM Runtime + +=== Path ORAM Implementation + +[source,rust] +---- +// oram/path.rs + +use crate::crypto::{Aead, Rng}; +use crate::oram::{Bucket, PositionMap, Stash}; +use std::marker::PhantomData; +use subtle::ConstantTimeEq; +use zeroize::Zeroize; + +/// Path ORAM configuration +#[derive(Debug, Clone)] +pub struct Config { + /// Number of data blocks + pub block_count: usize, + /// Size of each block in bytes + pub block_size: usize, + /// Number of blocks per bucket + pub bucket_size: usize, + /// Maximum stash size 
before error + pub max_stash_size: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + block_count: 1 << 20, // 1M blocks + block_size: 4096, // 4KB + bucket_size: 4, // Z = 4 + max_stash_size: 256, // O(λ) + } + } +} + +/// Block identifier +pub type BlockId = u64; + +/// Leaf identifier (position in tree) +pub type LeafId = u64; + +/// Encrypted block with metadata +#[derive(Clone, Zeroize)] +pub struct EncryptedBlock { + /// Block ID (encrypted) + id: BlockId, + /// Leaf assignment (encrypted) + leaf: LeafId, + /// Encrypted data + data: Vec, + /// Authentication tag + tag: [u8; 16], +} + +/// Path ORAM implementation +pub struct PathOram { + config: Config, + /// Tree storage backend + storage: S, + /// Client-side position map + position_map: PositionMap, + /// Client-side stash + stash: Stash, + /// Encryption key + key: [u8; 32], + /// Access counter (for nonce derivation) + access_counter: u64, + /// Random number generator + rng: Rng, +} + +impl PathOram { + /// Create new Path ORAM instance + pub fn new(config: Config, storage: S) -> Self { + let mut rng = Rng::new(); + let key = rng.gen_key(); + + let tree_height = (config.block_count as f64).log2().ceil() as usize; + let leaf_count = 1 << tree_height; + + // Initialize position map with random positions + let position_map = PositionMap::new_random(config.block_count, leaf_count, &mut rng); + + Self { + config, + storage, + position_map, + stash: Stash::new(), + key, + access_counter: 0, + rng, + } + } + + /// Read a block obliviously + pub fn read(&mut self, block_id: BlockId) -> Result, OramError> { + self.access(Operation::Read, block_id, None) + } + + /// Write a block obliviously + pub fn write(&mut self, block_id: BlockId, data: Vec) -> Result, OramError> { + if data.len() != self.config.block_size { + return Err(OramError::InvalidBlockSize); + } + self.access(Operation::Write, block_id, Some(data)) + } + + /// Core access operation + fn access( + &mut self, + op: Operation, 
+ block_id: BlockId, + data: Option>, + ) -> Result, OramError> { + // Step 1: Get old position and assign new random position + let old_leaf = self.position_map.get(block_id); + let new_leaf = self.rng.gen_leaf(self.leaf_count()); + self.position_map.set(block_id, new_leaf); + + // Step 2: Read entire path from root to leaf into stash + let path = self.read_path(old_leaf)?; + for bucket in path { + for block in bucket.blocks() { + if !block.is_dummy() { + let decrypted = self.decrypt_block(&block)?; + self.stash.insert(decrypted.id, decrypted); + } + } + } + + // Step 3: Perform the actual operation + let result = match op { + Operation::Read => { + self.stash + .get(block_id) + .map(|b| b.data.clone()) + .ok_or(OramError::BlockNotFound)? + } + Operation::Write => { + let old_data = self.stash + .get(block_id) + .map(|b| b.data.clone()) + .unwrap_or_else(|| vec![0u8; self.config.block_size]); + + self.stash.insert(block_id, Block { + id: block_id, + leaf: new_leaf, + data: data.unwrap(), + }); + + old_data + } + }; + + // Step 4: Eviction - write blocks back along path + self.evict_path(old_leaf)?; + + // Check stash overflow + if self.stash.len() > self.config.max_stash_size { + return Err(OramError::StashOverflow); + } + + self.access_counter += 1; + Ok(result) + } + + /// Read all buckets along path to leaf + fn read_path(&self, leaf: LeafId) -> Result, OramError> { + let height = self.tree_height(); + let mut path = Vec::with_capacity(height + 1); + + for level in 0..=height { + let node = self.path_node(leaf, level); + let bucket = self.storage.read_bucket(node)?; + path.push(bucket); + } + + Ok(path) + } + + /// Evict blocks from stash back to path + fn evict_path(&mut self, leaf: LeafId) -> Result<(), OramError> { + let height = self.tree_height(); + + // Process from leaves to root + for level in (0..=height).rev() { + let node = self.path_node(leaf, level); + + // Collect blocks that can go to this bucket + let mut bucket_blocks = 
Vec::with_capacity(self.config.bucket_size); + + // Find blocks whose path includes this node + let to_evict: Vec<BlockId> = self.stash + .iter() + .filter(|(_, block)| self.can_reside_at(block.leaf, node)) + .take(self.config.bucket_size) + .map(|(id, _)| *id) + .collect(); + + for id in to_evict { + if let Some(block) = self.stash.remove(&id) { + bucket_blocks.push(block); + } + } + + // Pad with dummy blocks + while bucket_blocks.len() < self.config.bucket_size { + bucket_blocks.push(Block::dummy(self.config.block_size)); + } + + // Encrypt and write bucket + let encrypted: Vec<EncryptedBlock> = bucket_blocks + .into_iter() + .map(|b| self.encrypt_block(&b)) + .collect(); + + self.storage.write_bucket(node, Bucket::new(encrypted))?; + } + + Ok(()) + } + + /// Check if a block with given leaf can reside at node + fn can_reside_at(&self, leaf: LeafId, node: u64) -> bool { + let height = self.tree_height(); + let level = self.node_level(node); + + // Node is on path from root to leaf + self.path_node(leaf, level) == node + } + + /// Get node index at given level on path to leaf + fn path_node(&self, leaf: LeafId, level: usize) -> u64 { + let height = self.tree_height(); + // At level 0 (root), node = 0 + // At level height (leaves), node = 2^height - 1 + leaf + let leaf_start = (1u64 << height) - 1; + let leaf_node = leaf_start + leaf; + + // Traverse up from leaf in 1-based heap numbering, then convert back to 0-based + ((leaf_node + 1) >> (height - level)) - 1 + } + + fn tree_height(&self) -> usize { + (self.config.block_count as f64).log2().ceil() as usize + } + + fn leaf_count(&self) -> u64 { + 1 << self.tree_height() + } + + fn node_level(&self, node: u64) -> usize { + // Level 0 = root, Level H = leaves + ((node + 1) as f64).log2().floor() as usize + } + + fn encrypt_block(&self, block: &Block) -> EncryptedBlock { + let nonce = self.derive_nonce(block.id, self.access_counter); + let (ciphertext, tag) = Aead::encrypt(&self.key, &nonce, &block.data); + EncryptedBlock { + id: block.id, + leaf: block.leaf, + data: ciphertext, + tag, 
+ } + } + + fn decrypt_block(&self, block: &EncryptedBlock) -> Result<Block, OramError> { + // Try decryption (timing-safe) + let nonce = self.derive_nonce(block.id, self.access_counter); + let data = Aead::decrypt(&self.key, &nonce, &block.data, &block.tag) + .map_err(|_| OramError::DecryptionFailed)?; + + Ok(Block { + id: block.id, + leaf: block.leaf, + data, + }) + } + + fn derive_nonce(&self, block_id: BlockId, counter: u64) -> [u8; 12] { + let mut nonce = [0u8; 12]; + nonce[0..8].copy_from_slice(&block_id.to_le_bytes()); + nonce[8..12].copy_from_slice(&(counter as u32).to_le_bytes()); + nonce + } +} + +impl Drop for PathOram { + fn drop(&mut self) { + self.key.zeroize(); + self.stash.zeroize(); + } +} +---- + +=== Constant-Time Primitives + +[source,rust] +---- +// primitives/cmov.rs + +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; + +/// Constant-time conditional move +/// +/// Returns `a` if `condition` is true, `b` otherwise. +/// Timing is independent of `condition`. +#[inline] +pub fn cmov<T: ConditionallySelectable>(condition: bool, a: T, b: T) -> T { + T::conditional_select(&b, &a, Choice::from(condition as u8)) +} + +/// Constant-time conditional swap +/// +/// If `condition` is true, swaps `a` and `b`. +/// Timing is independent of `condition`. +#[inline] +pub fn cswap<T: ConditionallySelectable>(condition: bool, a: &mut T, b: &mut T) { + let choice = Choice::from(condition as u8); + T::conditional_swap(a, b, choice); +} + +/// Constant-time equality comparison +#[inline] +pub fn ct_eq<T: ConstantTimeEq>(a: &T, b: &T) -> bool { + a.ct_eq(b).into() +} + +/// Constant-time less-than comparison for u64 +#[inline] +pub fn ct_lt(a: u64, b: u64) -> bool { + // a < b iff computing a - b borrows; the borrow bit is derived branch-free from the operands + let diff = a.wrapping_sub(b); + (((!a & b) | (!(a ^ b) & diff)) >> 63) == 1 +} + +/// Constant-time array lookup +/// +/// Returns `arr[index]` but accesses all elements to hide the index. 
+pub fn ct_lookup(arr: &[T], index: usize) -> T { + let mut result = T::default(); + + for (i, item) in arr.iter().enumerate() { + let is_target = ct_eq_usize(i, index); + result = cmov(is_target, *item, result); + } + + result +} + +/// Constant-time array write +/// +/// Sets `arr[index] = value` but touches all elements to hide the index. +pub fn ct_write(arr: &mut [T], index: usize, value: T) { + for (i, item) in arr.iter_mut().enumerate() { + let is_target = ct_eq_usize(i, index); + *item = cmov(is_target, value, *item); + } +} + +fn ct_eq_usize(a: usize, b: usize) -> bool { + (a ^ b) == 0 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cmov() { + assert_eq!(cmov(true, 10u64, 20u64), 10); + assert_eq!(cmov(false, 10u64, 20u64), 20); + } + + #[test] + fn test_ct_lookup() { + let arr = [10, 20, 30, 40, 50]; + assert_eq!(ct_lookup(&arr, 0), 10); + assert_eq!(ct_lookup(&arr, 2), 30); + assert_eq!(ct_lookup(&arr, 4), 50); + } +} +---- + +== Code Generator + +=== Rust Code Generation + +[source,rust] +---- +// codegen/rust.rs + +use crate::ir::*; +use std::fmt::Write; + +pub struct RustCodegen { + output: String, + indent: usize, +} + +impl RustCodegen { + pub fn new() -> Self { + Self { + output: String::new(), + indent: 0, + } + } + + pub fn generate(&mut self, module: &Module) -> String { + // Header + self.emit_line("// Generated by Oblibeny compiler"); + self.emit_line("// DO NOT EDIT"); + self.emit_line(""); + self.emit_line("use obli_runtime::prelude::*;"); + self.emit_line(""); + + // Generate each function + for func in &module.functions { + self.generate_function(func); + self.emit_line(""); + } + + std::mem::take(&mut self.output) + } + + fn generate_function(&mut self, func: &Function) { + // Function signature + let params: Vec = func.params.iter() + .map(|p| format!("{}: {}", p.name, self.type_to_rust(&p.ty.ty))) + .collect(); + + let ret_type = self.type_to_rust(&func.return_type.ty); + + self.emit_line(&format!( + "pub fn {}({}) 
-> {} {{", + func.name, + params.join(", "), + ret_type + )); + + self.indent += 1; + + // Locals + for local in &func.locals { + self.emit_line(&format!( + "let mut {}: {};", + local.name, + self.type_to_rust(&local.ty.ty) + )); + } + + if !func.locals.is_empty() { + self.emit_line(""); + } + + // Body + for instr in &func.body { + self.generate_instr(instr); + } + + self.indent -= 1; + self.emit_line("}"); + } + + fn generate_instr(&mut self, instr: &Instr) { + match instr { + Instr::Let { name, value } => { + self.emit_line(&format!( + "let {} = {};", + name, + self.expr_to_rust(value) + )); + } + + Instr::Assign { target, value } => { + self.emit_line(&format!( + "{} = {};", + self.lvalue_to_rust(target), + self.expr_to_rust(value) + )); + } + + Instr::OramRead { array, index, result } => { + self.emit_line(&format!( + "{} = {}.oram_read({});", + result, + self.expr_to_rust(array), + self.expr_to_rust(index) + )); + } + + Instr::OramWrite { array, index, value } => { + self.emit_line(&format!( + "{}.oram_write({}, {});", + self.expr_to_rust(array), + self.expr_to_rust(index), + self.expr_to_rust(value) + )); + } + + Instr::Cmov { cond, true_val, false_val, result } => { + self.emit_line(&format!( + "{} = cmov({}, {}, {});", + result, + self.expr_to_rust(cond), + self.expr_to_rust(true_val), + self.expr_to_rust(false_val) + )); + } + + Instr::OIf { cond, then_, else_ } => { + // Oblivious if: execute both branches, select result + self.emit_line("{"); + self.indent += 1; + + self.emit_line(&format!( + "let __cond = {};", + self.expr_to_rust(cond) + )); + + // Execute "then" branch + self.emit_line("let __then_result = {"); + self.indent += 1; + for i in then_ { + self.generate_instr(i); + } + self.indent -= 1; + self.emit_line("};"); + + // Execute "else" branch + self.emit_line("let __else_result = {"); + self.indent += 1; + for i in else_ { + self.generate_instr(i); + } + self.indent -= 1; + self.emit_line("};"); + + // Select based on condition 
(constant-time) + self.emit_line("cmov(__cond, __then_result, __else_result)"); + + self.indent -= 1; + self.emit_line("}"); + } + + Instr::Return(Some(value)) => { + self.emit_line(&format!("return {};", self.expr_to_rust(value))); + } + + Instr::Return(None) => { + self.emit_line("return;"); + } + + _ => { + self.emit_line("// TODO: unimplemented instruction"); + } + } + } + + fn expr_to_rust(&self, expr: &Expr) -> String { + match expr { + Expr::Unit => "()".to_string(), + Expr::Bool(b) => b.to_string(), + Expr::Int { value, ty } => { + format!("{}_{}", value, self.prim_to_rust(ty)) + } + Expr::Float { value, ty } => { + format!("{}_{}", value, self.prim_to_rust(ty)) + } + Expr::Var(name) => name.clone(), + Expr::Global(name) => format!("GLOBAL_{}", name.to_uppercase()), + Expr::Add(a, b) => format!("({} + {})", self.expr_to_rust(a), self.expr_to_rust(b)), + Expr::Sub(a, b) => format!("({} - {})", self.expr_to_rust(a), self.expr_to_rust(b)), + Expr::Mul(a, b) => format!("({} * {})", self.expr_to_rust(a), self.expr_to_rust(b)), + Expr::Div(a, b) => format!("({} / {})", self.expr_to_rust(a), self.expr_to_rust(b)), + Expr::Eq(a, b) => format!("({} == {})", self.expr_to_rust(a), self.expr_to_rust(b)), + Expr::Lt(a, b) => format!("({} < {})", self.expr_to_rust(a), self.expr_to_rust(b)), + Expr::Call { func, args } => { + let args_str: Vec = args.iter().map(|a| self.expr_to_rust(a)).collect(); + format!("{}({})", func, args_str.join(", ")) + } + _ => "/* unimplemented */".to_string(), + } + } + + fn lvalue_to_rust(&self, lv: &LValue) -> String { + match lv { + LValue::Var(name) => name.clone(), + LValue::Index { array, index } => { + format!("{}[{}]", self.expr_to_rust(array), self.expr_to_rust(index)) + } + LValue::Field { strct, field } => { + format!("{}.{}", self.expr_to_rust(strct), field) + } + } + } + + fn type_to_rust(&self, ty: &Type) -> String { + match ty { + Type::Prim(p) => self.prim_to_rust(p), + Type::Array { elem, size } => { + match size { + Some(n) 
=> format!("[{}; {}]", self.type_to_rust(elem), n), + None => format!("Vec<{}>", self.type_to_rust(elem)), + } + } + Type::OArray { elem, .. } => { + format!("OArray<{}>", self.type_to_rust(elem)) + } + Type::Tuple(elems) => { + let parts: Vec<String> = elems.iter().map(|e| self.type_to_rust(e)).collect(); + format!("({})", parts.join(", ")) + } + Type::Func { params, ret } => { + let params_str: Vec<String> = params.iter().map(|p| self.type_to_rust(p)).collect(); + format!("fn({}) -> {}", params_str.join(", "), self.type_to_rust(ret)) + } + Type::Named(name) => name.clone(), + _ => "/* unknown type */".to_string(), + } + } + + fn prim_to_rust(&self, prim: &PrimType) -> &'static str { + match prim { + PrimType::Unit => "()", + PrimType::Bool => "bool", + PrimType::I8 => "i8", + PrimType::I16 => "i16", + PrimType::I32 => "i32", + PrimType::I64 => "i64", + PrimType::U8 => "u8", + PrimType::U16 => "u16", + PrimType::U32 => "u32", + PrimType::U64 => "u64", + PrimType::F32 => "f32", + PrimType::F64 => "f64", + } + } + + fn emit_line(&mut self, line: &str) { + for _ in 0..self.indent { + self.output.push_str(" "); + } + self.output.push_str(line); + self.output.push('\n'); + } +} +---- + +== CLI Entry Point + +[source,rust] +---- +// main.rs + +use clap::Parser; +use obli_backend::{codegen, ir, opt}; +use std::path::PathBuf; +use tracing::info; + +#[derive(Parser)] +#[command(name = "obli-backend")] +#[command(about = "Oblibeny compiler backend")] +struct Cli { + /// Input OIR file + input: PathBuf, + + /// Output file + #[arg(short, long, default_value = "out.rs")] + output: PathBuf, + + /// Output format + #[arg(short, long, value_enum, default_value = "rust")] + format: OutputFormat, + + /// Optimization level + #[arg(short = 'O', long, default_value = "1")] + opt_level: u8, + + /// Enable debug output + #[arg(short, long)] + debug: bool, +} + +#[derive(Clone, Copy, Debug, clap::ValueEnum)] +enum OutputFormat { + Rust, + RiscV, + Wasm, + C, +} + +fn main() -> anyhow::Result<()> { + let cli 
= Cli::parse(); + + tracing_subscriber::fmt() + .with_max_level(if cli.debug { + tracing::Level::DEBUG + } else { + tracing::Level::INFO + }) + .init(); + + // Parse input + info!("Parsing {:?}", cli.input); + let input = std::fs::read_to_string(&cli.input)?; + let module = ir::parse_json(&input)?; + + // Validate + info!("Validating IR"); + ir::validate(&module)?; + + // Optimize + let module = if cli.opt_level > 0 { + info!("Optimizing (level {})", cli.opt_level); + opt::optimize(module, cli.opt_level) + } else { + module + }; + + // Generate code + info!("Generating {:?} code", cli.format); + let output = match cli.format { + OutputFormat::Rust => codegen::rust::generate(&module), + OutputFormat::RiscV => codegen::riscv::generate(&module), + OutputFormat::Wasm => codegen::wasm::generate(&module), + OutputFormat::C => codegen::c::generate(&module), + }; + + // Write output + info!("Writing {:?}", cli.output); + std::fs::write(&cli.output, output)?; + + info!("Done"); + Ok(()) +} +---- + +== Testing + +[source,rust] +---- +// tests/integration_tests.rs + +use obli_runtime::collections::OArray; +use obli_runtime::oram::PathOram; +use obli_runtime::primitives::cmov; + +#[test] +fn test_oarray_basic() { + let mut arr: OArray = OArray::new(1024); + + // Write some values + arr.oram_write(0, 100); + arr.oram_write(1, 200); + arr.oram_write(100, 12345); + + // Read them back + assert_eq!(arr.oram_read(0), 100); + assert_eq!(arr.oram_read(1), 200); + assert_eq!(arr.oram_read(100), 12345); +} + +#[test] +fn test_cmov() { + let secret = true; + let a = 10i64; + let b = 20i64; + + let result = cmov(secret, a, b); + assert_eq!(result, 10); + + let result = cmov(!secret, a, b); + assert_eq!(result, 20); +} + +#[test] +fn test_path_oram_correctness() { + use proptest::prelude::*; + + proptest!(|(ops in prop::collection::vec(any::<(bool, u64, i64)>(), 1..100))| { + let mut oram = PathOram::new(Default::default()); + let mut reference = std::collections::HashMap::new(); + + for 
(is_write, key, value) in ops { + let key = key % 1000; // Limit key space + if is_write { + oram.write(key, value.to_le_bytes().to_vec()).unwrap(); + reference.insert(key, value); + } else { + let oram_val = oram.read(key).ok(); + let ref_val = reference.get(&key).map(|v| v.to_le_bytes().to_vec()); + assert_eq!(oram_val, ref_val); + } + } + }); +} +---- + +== Benchmarks + +[source,rust] +---- +// benches/oram_bench.rs + +use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId}; +use obli_runtime::oram::PathOram; + +fn bench_oram_access(c: &mut Criterion) { + let mut group = c.benchmark_group("ORAM Access"); + + for size in [1024, 4096, 16384, 65536] { + group.bench_with_input( + BenchmarkId::new("read", size), + &size, + |b, &size| { + let mut oram = PathOram::new(size, 4096); + let data = vec![0u8; 4096]; + oram.write(0, data.clone()).unwrap(); + + b.iter(|| { + black_box(oram.read(black_box(0)).unwrap()) + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new("write", size), + &size, + |b, &size| { + let mut oram = PathOram::new(size, 4096); + let data = vec![42u8; 4096]; + + b.iter(|| { + oram.write(black_box(0), black_box(data.clone())).unwrap() + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_oram_access); +criterion_main!(benches); +---- diff --git a/docs/specs/frontend/ocaml-frontend.adoc b/docs/specs/frontend/ocaml-frontend.adoc new file mode 100644 index 0000000..ee029ec --- /dev/null +++ b/docs/specs/frontend/ocaml-frontend.adoc @@ -0,0 +1,959 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + += OCaml Frontend Specification +:author: Oblibeny Project +:revdate: 2024 +:toc: left +:toclevels: 4 +:sectnums: +:stem: latexmath + +== Overview + +The OCaml frontend is responsible for: + +1. Parsing source code to AST +2. Type checking with security annotations +3. Obliviousness analysis +4. 
Lowering to OIR + +== Project Structure + +[source] +---- +frontend/ +├── dune-project +├── dune +├── bin/ +│ └── main.ml # CLI entry point +├── lib/ +│ ├── dune +│ ├── syntax/ +│ │ ├── lexer.mll # Lexer (ocamllex) +│ │ ├── parser.mly # Parser (Menhir) +│ │ ├── ast.ml # AST definition +│ │ └── location.ml # Source locations +│ ├── typing/ +│ │ ├── types.ml # Type definitions +│ │ ├── env.ml # Typing environment +│ │ ├── infer.ml # Type inference +│ │ ├── security.ml # Security level checking +│ │ └── tast.ml # Typed AST +│ ├── analysis/ +│ │ ├── oblivious.ml # Obliviousness checker +│ │ ├── flow.ml # Information flow analysis +│ │ └── escape.ml # Escape analysis +│ ├── ir/ +│ │ ├── oir.ml # OIR types +│ │ ├── lower.ml # TAST → OIR +│ │ └── emit.ml # OIR serialization +│ ├── driver/ +│ │ ├── config.ml # Compiler configuration +│ │ ├── errors.ml # Error handling +│ │ └── pipeline.ml # Compilation pipeline +│ └── oblc.ml # Library entry point +└── test/ + ├── lexer_test.ml + ├── parser_test.ml + ├── typing_test.ml + └── integration/ + └── *.obl +---- + +== Dune Configuration + +=== dune-project + +[source,dune] +---- +(lang dune 3.0) +(name obli-frontend) +(version 0.1.0) + +(generate_opam_files true) + +(package + (name obli-frontend) + (synopsis "Oblibeny language frontend") + (description "OCaml frontend for the Oblibeny oblivious computing language") + (depends + (ocaml (>= 4.14.0)) + (dune (>= 3.0)) + (menhir (>= 20220210)) + (ppx_deriving (>= 5.2)) + (yojson (>= 2.0)) + (msgpck (>= 1.7)) + (cmdliner (>= 1.1)) + (alcotest (and (>= 1.6) :with-test)))) +---- + +=== lib/dune + +[source,dune] +---- +(library + (name oblc) + (public_name obli-frontend) + (libraries str yojson msgpck) + (preprocess (pps ppx_deriving.show ppx_deriving.eq ppx_deriving.ord))) + +(ocamllex syntax/lexer) + +(menhir + (modules syntax/parser) + (flags --explain --table)) +---- + +== Source Language Grammar + +=== Lexical Structure + +[source,ocaml] +---- +(* lexer.mll *) + +{ +open Parser + 
+exception Lexer_error of string * Lexing.position + +let keywords = Hashtbl.create 32 +let () = List.iter (fun (k, v) -> Hashtbl.add keywords k v) [ + "fn", FN; + "let", LET; + "mut", MUT; + "if", IF; + "else", ELSE; + "while", WHILE; + "for", FOR; + "in", IN; + "return", RETURN; + "break", BREAK; + "continue", CONTINUE; + "true", TRUE; + "false", FALSE; + "type", TYPE; + "struct", STRUCT; + "enum", ENUM; + "impl", IMPL; + "pub", PUB; + "oblivious", OBLIVIOUS; + "oarray", OARRAY; + "oref", OREF; + "cmov", CMOV; + "oswap", OSWAP; +] +} + +let digit = ['0'-'9'] +let alpha = ['a'-'z' 'A'-'Z' '_'] +let alnum = alpha | digit +let ident = alpha alnum* + +let int_lit = digit+ | "0x" ['0'-'9' 'a'-'f' 'A'-'F']+ +let float_lit = digit+ '.' digit* (['e' 'E'] ['+' '-']? digit+)? + +let whitespace = [' ' '\t']+ +let newline = '\r' | '\n' | "\r\n" + +rule token = parse + | whitespace { token lexbuf } + | newline { Lexing.new_line lexbuf; token lexbuf } + | "//" { line_comment lexbuf } + | "/*" { block_comment 0 lexbuf } + + (* Operators *) + | '+' { PLUS } + | '-' { MINUS } + | '*' { STAR } + | '/' { SLASH } + | '%' { PERCENT } + | '&' { AMP } + | '|' { PIPE } + | '^' { CARET } + | '~' { TILDE } + | "<<" { LSHIFT } + | ">>" { RSHIFT } + | "==" { EQEQ } + | "!=" { BANGEQ } + | '<' { LT } + | "<=" { LTEQ } + | '>' { GT } + | ">=" { GTEQ } + | "&&" { AMPAMP } + | "||" { PIPEPIPE } + | '!' { BANG } + | '=' { EQ } + | "->" { ARROW } + | "=>" { FATARROW } + | '@' { AT } + + (* Delimiters *) + | '(' { LPAREN } + | ')' { RPAREN } + | '[' { LBRACKET } + | ']' { RBRACKET } + | '{' { LBRACE } + | '}' { RBRACE } + | ',' { COMMA } + | ':' { COLON } + | ';' { SEMI } + | '.' 
{ DOT } + + (* Literals *) + | int_lit as n { INT (Int64.of_string n) } + | float_lit as f { FLOAT (float_of_string f) } + | '"' { string (Buffer.create 32) lexbuf } + + (* Identifiers and keywords *) + | ident as id { + try Hashtbl.find keywords id + with Not_found -> IDENT id + } + + | eof { EOF } + | _ as c { raise (Lexer_error (Printf.sprintf "Unexpected character: %c" c, + lexbuf.Lexing.lex_curr_p)) } + +and line_comment = parse + | newline { Lexing.new_line lexbuf; token lexbuf } + | eof { EOF } + | _ { line_comment lexbuf } + +and block_comment depth = parse + | "*/" { if depth = 0 then token lexbuf else block_comment (depth - 1) lexbuf } + | "/*" { block_comment (depth + 1) lexbuf } + | newline { Lexing.new_line lexbuf; block_comment depth lexbuf } + | eof { raise (Lexer_error ("Unterminated block comment", lexbuf.Lexing.lex_curr_p)) } + | _ { block_comment depth lexbuf } + +and string buf = parse + | '"' { STRING (Buffer.contents buf) } + | "\\n" { Buffer.add_char buf '\n'; string buf lexbuf } + | "\\t" { Buffer.add_char buf '\t'; string buf lexbuf } + | "\\\\" { Buffer.add_char buf '\\'; string buf lexbuf } + | "\\\"" { Buffer.add_char buf '"'; string buf lexbuf } + | [^ '"' '\\']+ as s { Buffer.add_string buf s; string buf lexbuf } + | eof { raise (Lexer_error ("Unterminated string", lexbuf.Lexing.lex_curr_p)) } +---- + +=== Parser Grammar + +[source,ocaml] +---- +(* parser.mly *) + +%{ +open Ast + +let make_loc startpos endpos = Location.make startpos endpos +%} + +%token <string> IDENT STRING +%token <int64> INT +%token <float> FLOAT +%token TRUE FALSE +%token FN LET MUT IF ELSE WHILE FOR IN RETURN BREAK CONTINUE +%token TYPE STRUCT ENUM IMPL PUB +%token OBLIVIOUS OARRAY OREF CMOV OSWAP +%token PLUS MINUS STAR SLASH PERCENT +%token AMP PIPE CARET TILDE LSHIFT RSHIFT +%token EQEQ BANGEQ LT LTEQ GT GTEQ AMPAMP PIPEPIPE BANG +%token EQ ARROW FATARROW AT +%token LPAREN RPAREN LBRACKET RBRACKET LBRACE RBRACE +%token COMMA COLON SEMI DOT +%token EOF + +%left PIPEPIPE +%left AMPAMP 
+%left PIPE +%left CARET +%left AMP +%left EQEQ BANGEQ +%left LT LTEQ GT GTEQ +%left LSHIFT RSHIFT +%left PLUS MINUS +%left STAR SLASH PERCENT +%right BANG TILDE +%left DOT LBRACKET + +%start program + +%% + +program: + | items = list(item) EOF { items } + +item: + | fn_def { $1 } + | type_def { $1 } + | struct_def { $1 } + +fn_def: + | PUB? FN name = IDENT + LPAREN params = separated_list(COMMA, param) RPAREN + ret = option(preceded(ARROW, typ_annot)) + body = block + { Item_fn { + name; + params; + return_type = ret; + body; + is_public = Option.is_some $1; + loc = make_loc $startpos $endpos + } + } + +param: + | name = IDENT COLON ty = typ_annot + { { param_name = name; param_type = ty; param_loc = make_loc $startpos $endpos } } + +typ_annot: + | ty = typ sec = option(preceded(AT, security_level)) + { { ty; security = Option.value sec ~default:Security_infer } } + +typ: + | IDENT { Ty_named $1 } + | LBRACKET ty = typ RBRACKET { Ty_array ty } + | OARRAY LT ty = typ GT { Ty_oarray ty } + | OREF LT ty = typ GT { Ty_oref ty } + | LPAREN tys = separated_list(COMMA, typ) RPAREN { Ty_tuple tys } + | ty = typ ARROW ret = typ { Ty_fn ([ty], ret) } + +security_level: + | IDENT { + match $1 with + | "low" -> Security_low + | "high" -> Security_high + | s -> Security_named s + } + +block: + | LBRACE stmts = list(stmt) RBRACE { stmts } + +stmt: + | LET MUT? 
name = IDENT ty = option(preceded(COLON, typ_annot)) EQ value = expr SEMI + { Stmt_let { name; ty; value; is_mut = Option.is_some $2; loc = make_loc $startpos $endpos } } + | lhs = lvalue EQ rhs = expr SEMI + { Stmt_assign { lhs; rhs; loc = make_loc $startpos $endpos } } + | IF cond = expr then_branch = block else_branch = option(preceded(ELSE, else_block)) + { Stmt_if { cond; then_branch; else_branch; loc = make_loc $startpos $endpos } } + | OBLIVIOUS IF cond = expr then_branch = block else_branch = option(preceded(ELSE, else_block)) + { Stmt_oif { cond; then_branch; else_branch; loc = make_loc $startpos $endpos } } + | WHILE cond = expr body = block + { Stmt_while { cond; body; loc = make_loc $startpos $endpos } } + | RETURN value = option(expr) SEMI + { Stmt_return { value; loc = make_loc $startpos $endpos } } + | BREAK SEMI { Stmt_break (make_loc $startpos $endpos) } + | CONTINUE SEMI { Stmt_continue (make_loc $startpos $endpos) } + | e = expr SEMI { Stmt_expr { expr = e; loc = make_loc $startpos $endpos } } + +else_block: + | block { $1 } + | IF cond = expr then_branch = block else_branch = option(preceded(ELSE, else_block)) + { [Stmt_if { cond; then_branch; else_branch; loc = make_loc $startpos $endpos }] } + +lvalue: + | IDENT { Lv_var $1 } + | lv = lvalue LBRACKET idx = expr RBRACKET { Lv_index (lv, idx) } + | lv = lvalue DOT field = IDENT { Lv_field (lv, field) } + +expr: + | primary_expr { $1 } + | e1 = expr PLUS e2 = expr { Expr_binop (Op_add, e1, e2, make_loc $startpos $endpos) } + | e1 = expr MINUS e2 = expr { Expr_binop (Op_sub, e1, e2, make_loc $startpos $endpos) } + | e1 = expr STAR e2 = expr { Expr_binop (Op_mul, e1, e2, make_loc $startpos $endpos) } + | e1 = expr SLASH e2 = expr { Expr_binop (Op_div, e1, e2, make_loc $startpos $endpos) } + | e1 = expr PERCENT e2 = expr { Expr_binop (Op_mod, e1, e2, make_loc $startpos $endpos) } + | e1 = expr EQEQ e2 = expr { Expr_binop (Op_eq, e1, e2, make_loc $startpos $endpos) } + | e1 = expr BANGEQ e2 = expr { 
Expr_binop (Op_ne, e1, e2, make_loc $startpos $endpos) } + | e1 = expr LT e2 = expr { Expr_binop (Op_lt, e1, e2, make_loc $startpos $endpos) } + | e1 = expr LTEQ e2 = expr { Expr_binop (Op_le, e1, e2, make_loc $startpos $endpos) } + | e1 = expr GT e2 = expr { Expr_binop (Op_gt, e1, e2, make_loc $startpos $endpos) } + | e1 = expr GTEQ e2 = expr { Expr_binop (Op_ge, e1, e2, make_loc $startpos $endpos) } + | e1 = expr AMPAMP e2 = expr { Expr_binop (Op_and, e1, e2, make_loc $startpos $endpos) } + | e1 = expr PIPEPIPE e2 = expr { Expr_binop (Op_or, e1, e2, make_loc $startpos $endpos) } + | BANG e = expr { Expr_unop (Op_not, e, make_loc $startpos $endpos) } + | MINUS e = expr %prec BANG { Expr_unop (Op_neg, e, make_loc $startpos $endpos) } + | e = expr LBRACKET idx = expr RBRACKET + { Expr_index (e, idx, make_loc $startpos $endpos) } + | CMOV LPAREN cond = expr COMMA e1 = expr COMMA e2 = expr RPAREN + { Expr_cmov (cond, e1, e2, make_loc $startpos $endpos) } + | name = IDENT LPAREN args = separated_list(COMMA, expr) RPAREN + { Expr_call (name, args, make_loc $startpos $endpos) } + +primary_expr: + | INT { Expr_int ($1, make_loc $startpos $endpos) } + | FLOAT { Expr_float ($1, make_loc $startpos $endpos) } + | TRUE { Expr_bool (true, make_loc $startpos $endpos) } + | FALSE { Expr_bool (false, make_loc $startpos $endpos) } + | STRING { Expr_string ($1, make_loc $startpos $endpos) } + | IDENT { Expr_var ($1, make_loc $startpos $endpos) } + | LPAREN e = expr RPAREN { e } + | LPAREN es = separated_list(COMMA, expr) RPAREN { Expr_tuple (es, make_loc $startpos $endpos) } +---- + +== AST Definition + +[source,ocaml] +---- +(* ast.ml *) + +type location = Location.t + +type security = + | Security_low + | Security_high + | Security_named of string + | Security_infer (* To be inferred *) + [@@deriving show, eq] + +type typ = + | Ty_named of string + | Ty_array of typ + | Ty_oarray of typ (* Oblivious array *) + | Ty_oref of typ (* Oblivious reference *) + | Ty_tuple of typ list + | 
Ty_fn of typ list * typ + [@@deriving show, eq] + +type typ_annot = { + ty: typ; + security: security; +} [@@deriving show, eq] + +type binop = + | Op_add | Op_sub | Op_mul | Op_div | Op_mod + | Op_band | Op_bor | Op_bxor | Op_shl | Op_shr + | Op_eq | Op_ne | Op_lt | Op_le | Op_gt | Op_ge + | Op_and | Op_or + [@@deriving show, eq] + +type unop = + | Op_neg | Op_bnot | Op_not + [@@deriving show, eq] + +type expr = + | Expr_int of int64 * location + | Expr_float of float * location + | Expr_bool of bool * location + | Expr_string of string * location + | Expr_var of string * location + | Expr_tuple of expr list * location + | Expr_binop of binop * expr * expr * location + | Expr_unop of unop * expr * location + | Expr_index of expr * expr * location + | Expr_field of expr * string * location + | Expr_call of string * expr list * location + | Expr_cmov of expr * expr * expr * location + | Expr_if of expr * expr * expr * location + [@@deriving show] + +type lvalue = + | Lv_var of string + | Lv_index of lvalue * expr + | Lv_field of lvalue * string + [@@deriving show] + +type stmt = + | Stmt_let of { name: string; ty: typ_annot option; value: expr; + is_mut: bool; loc: location } + | Stmt_assign of { lhs: lvalue; rhs: expr; loc: location } + | Stmt_if of { cond: expr; then_branch: stmt list; + else_branch: stmt list option; loc: location } + | Stmt_oif of { cond: expr; then_branch: stmt list; + else_branch: stmt list option; loc: location } + | Stmt_while of { cond: expr; body: stmt list; loc: location } + | Stmt_return of { value: expr option; loc: location } + | Stmt_break of location + | Stmt_continue of location + | Stmt_expr of { expr: expr; loc: location } + [@@deriving show] + +type param = { + param_name: string; + param_type: typ_annot; + param_loc: location; +} [@@deriving show] + +type fn_def = { + name: string; + params: param list; + return_type: typ_annot option; + body: stmt list; + is_public: bool; + loc: location; +} [@@deriving show] + +type item = + | 
Item_fn of fn_def + | Item_type of { name: string; def: typ; loc: location } + | Item_struct of { name: string; fields: (string * typ_annot) list; loc: location } + [@@deriving show] + +type program = item list [@@deriving show] +---- + +== Type Checker + +[source,ocaml] +---- +(* typing/infer.ml - Type inference with security levels *) + +open Types +open Tast + +type env = { + vars: (string * typed_type) list; + fns: (string * fn_sig) list; + types: (string * typ) list; +} + +let empty_env = { vars = []; fns = []; types = [] } + +let lookup_var env name = + List.assoc_opt name env.vars + +let extend_var env name ty = + { env with vars = (name, ty) :: env.vars } + +(* Security level operations *) +let join_security s1 s2 = + match s1, s2 with + | Low, Low -> Low + | _, _ -> High + +let check_flow ~from ~to_ = + match from, to_ with + | High, Low -> Error "Cannot flow high-security value to low-security location" + | _ -> Ok () + +(* Type inference *) +let rec infer_expr env expr = + match expr with + | Ast.Expr_int (n, loc) -> + Ok { texpr = TExpr_int n; ty = Prim I64; sec = Low; loc } + + | Ast.Expr_bool (b, loc) -> + Ok { texpr = TExpr_bool b; ty = Prim Bool; sec = Low; loc } + + | Ast.Expr_var (name, loc) -> + (match lookup_var env name with + | Some tty -> Ok { texpr = TExpr_var name; ty = tty.ty; sec = tty.sec; loc } + | None -> Error (Printf.sprintf "Unbound variable: %s" name)) + + | Ast.Expr_index (arr, idx, loc) -> + let* tarr = infer_expr env arr in + let* tidx = infer_expr env idx in + (match tarr.ty with + | Array elem_ty -> + (* Regular array: access pattern leaks if index is high *) + if tidx.sec = High then + Error "Array index with high-security index leaks access pattern. Use oarray." 
+ else + Ok { texpr = TExpr_index (tarr, tidx); ty = elem_ty; + sec = tarr.sec; loc } + | OArray elem_ty -> + (* Oblivious array: safe with any index *) + Ok { texpr = TExpr_oindex (tarr, tidx); ty = elem_ty; + sec = join_security tarr.sec tidx.sec; loc } + | _ -> Error "Cannot index non-array type") + + | Ast.Expr_cmov (cond, e1, e2, loc) -> + let* tcond = infer_expr env cond in + let* te1 = infer_expr env e1 in + let* te2 = infer_expr env e2 in + if te1.ty <> te2.ty then + Error "cmov branches must have same type" + else + Ok { texpr = TExpr_cmov (tcond, te1, te2); + ty = te1.ty; + sec = join_security tcond.sec (join_security te1.sec te2.sec); + loc } + + | Ast.Expr_binop (op, e1, e2, loc) -> + let* te1 = infer_expr env e1 in + let* te2 = infer_expr env e2 in + let ty = infer_binop_type op te1.ty te2.ty in + Ok { texpr = TExpr_binop (op, te1, te2); + ty; + sec = join_security te1.sec te2.sec; + loc } + + | _ -> Error "Not yet implemented" +---- + +== Obliviousness Checker + +[source,ocaml] +---- +(* analysis/oblivious.ml *) + +open Tast + +type access_pattern = + | Constant (* Same address every time *) + | Public (* Depends only on public data *) + | Secret (* Depends on secret data - LEAK! 
*) + | Oblivious (* Using ORAM, safe *) + +type violation = { + loc: Location.t; + kind: violation_kind; + suggestion: string; +} + +and violation_kind = + | Secret_array_index of string (* arr[secret] without ORAM *) + | Secret_branch of string (* if (secret) without oblivious *) + | Secret_loop_bound of string (* while (secret) *) + +let check_obliviousness (prog : typed_program) : violation list = + let violations = ref [] in + + let report v = violations := v :: !violations in + + let rec check_expr expr = + match expr.texpr with + | TExpr_index (arr, idx) when idx.sec = High -> + report { + loc = expr.loc; + kind = Secret_array_index "index"; + suggestion = "Use oarray instead of array, or declassify the index" + } + + | TExpr_oindex (_, _) -> + () (* Oblivious access, safe *) + + | TExpr_binop (_, e1, e2) -> + check_expr e1; + check_expr e2 + + | _ -> () + in + + let rec check_stmt stmt = + match stmt with + | TStmt_if { cond; then_branch; else_branch; _ } when cond.sec = High -> + report { + loc = cond.loc; + kind = Secret_branch "condition"; + suggestion = "Use 'oblivious if' or cmov for secret-dependent branching" + }; + List.iter check_stmt then_branch; + Option.iter (List.iter check_stmt) else_branch + + | TStmt_oif { then_branch; else_branch; _ } -> + (* Oblivious if is safe, but check children *) + List.iter check_stmt then_branch; + Option.iter (List.iter check_stmt) else_branch + + | TStmt_while { cond; body; _ } when cond.sec = High -> + report { + loc = cond.loc; + kind = Secret_loop_bound "condition"; + suggestion = "Loop bounds must not depend on secrets (would leak iteration count)" + } + + | TStmt_let { value; _ } -> + check_expr value + + | TStmt_assign { rhs; _ } -> + check_expr rhs + + | TStmt_expr { expr; _ } -> + check_expr expr + + | _ -> () + in + + List.iter (fun fn -> + List.iter check_stmt fn.tbody + ) prog.functions; + + List.rev !violations +---- + +== IR Lowering + +[source,ocaml] +---- +(* ir/lower.ml - Lower TAST to OIR *) + 
+open Tast +open Oir + +let fresh_var = + let counter = ref 0 in + fun () -> + incr counter; + Printf.sprintf "tmp%d" !counter + +let rec lower_expr (expr : typed_expr) : Oir.expr * Oir.instr list = + match expr.texpr with + | TExpr_int n -> + (EInt (n, to_oir_prim expr.ty), []) + + | TExpr_bool b -> + (EBool b, []) + + | TExpr_var name -> + (EVar name, []) + + | TExpr_oindex (arr, idx) -> + let arr_e, arr_instrs = lower_expr arr in + let idx_e, idx_instrs = lower_expr idx in + let result = fresh_var () in + let instr = IOramRead { array = arr_e; index = idx_e; result } in + (EVar result, arr_instrs @ idx_instrs @ [instr]) + + | TExpr_cmov (cond, e1, e2) -> + let cond_e, cond_instrs = lower_expr cond in + let e1_e, e1_instrs = lower_expr e1 in + let e2_e, e2_instrs = lower_expr e2 in + let result = fresh_var () in + let instr = ICmov { cond = cond_e; true_val = e1_e; false_val = e2_e; result } in + (EVar result, cond_instrs @ e1_instrs @ e2_instrs @ [instr]) + + | TExpr_binop (op, e1, e2) -> + let e1_e, e1_instrs = lower_expr e1 in + let e2_e, e2_instrs = lower_expr e2 in + let oir_op = lower_binop op in + (EBinop (oir_op, e1_e, e2_e), e1_instrs @ e2_instrs) + + | _ -> failwith "Not yet implemented" + +let rec lower_stmt (stmt : typed_stmt) : Oir.instr list = + match stmt with + | TStmt_let { name; value; _ } -> + let value_e, instrs = lower_expr value in + instrs @ [ILet (name, value_e)] + + | TStmt_assign { lhs; rhs; _ } -> + let rhs_e, instrs = lower_expr rhs in + let lhs_lv = lower_lvalue lhs in + (match lhs with + | TLv_oindex (arr, idx) -> + let arr_e, arr_instrs = lower_expr arr in + let idx_e, idx_instrs = lower_expr idx in + arr_instrs @ idx_instrs @ instrs @ + [IOramWrite { array = arr_e; index = idx_e; value = rhs_e }] + | _ -> + instrs @ [IAssign (lhs_lv, rhs_e)]) + + | TStmt_oif { cond; then_branch; else_branch; _ } -> + let cond_e, cond_instrs = lower_expr cond in + let then_instrs = List.concat_map lower_stmt then_branch in + let else_instrs = match 
else_branch with + | Some stmts -> List.concat_map lower_stmt stmts + | None -> [] + in + cond_instrs @ [IOIf { cond = cond_e; then_ = then_instrs; else_ = else_instrs }] + + | TStmt_return { value; _ } -> + (match value with + | Some e -> + let e_e, instrs = lower_expr e in + instrs @ [IReturn (Some e_e)] + | None -> + [IReturn None]) + + | _ -> failwith "Not yet implemented" + +let lower_function (fn : typed_fn) : Oir.func = + let body = List.concat_map lower_stmt fn.tbody in + { + name = fn.tname; + params = List.map (fun p -> (p.tparam_name, lower_typed_type p.tparam_type)) fn.tparams; + return_type = lower_typed_type fn.treturn_type; + locals = []; (* Collected during lowering *) + body; + attributes = if fn.is_oblivious then [Oblivious] else []; + } + +let lower_program (prog : typed_program) : Oir.module_ = + { + version = "1.0.0"; + name = prog.name; + imports = []; + types = []; + globals = []; + functions = List.map lower_function prog.functions; + entry = prog.entry; + metadata = { + source_file = Some prog.source_file; + source_map = []; + compiler_version = "0.1.0"; + timestamp = ""; + options = []; + }; + } +---- + +== CLI Entry Point + +[source,ocaml] +---- +(* bin/main.ml *) + +open Cmdliner + +let compile input_file output_file debug = + try + (* Read source *) + let source = In_channel.with_open_bin input_file In_channel.input_all in + + (* Parse *) + let lexbuf = Lexing.from_string source in + lexbuf.lex_curr_p <- { lexbuf.lex_curr_p with pos_fname = input_file }; + let ast = Parser.program Lexer.token lexbuf in + + if debug then + Printf.eprintf "AST:\n%s\n" (Ast.show_program ast); + + (* Type check *) + let tast = Oblc.Typing.check_program ast in + + (* Obliviousness check *) + let violations = Oblc.Analysis.Oblivious.check_obliviousness tast in + List.iter (fun v -> + Printf.eprintf "Warning: %s at %s\n Suggestion: %s\n" + (Oblc.Analysis.Oblivious.show_violation_kind v.kind) + (Location.show v.loc) + v.suggestion + ) violations; + + (* Lower 
to IR *) + let ir = Oblc.Ir.Lower.lower_program tast in + + (* Emit *) + let json = Oblc.Ir.Emit.to_json ir in + Out_channel.with_open_bin output_file (fun oc -> + Out_channel.output_string oc (Yojson.Safe.pretty_to_string json) + ); + + `Ok () + with + | Lexer.Lexer_error (msg, pos) -> + Printf.eprintf "Lexer error at %s:%d:%d: %s\n" + pos.pos_fname pos.pos_lnum (pos.pos_cnum - pos.pos_bol) msg; + `Error (false, "Lexer error") + | Parser.Error -> + Printf.eprintf "Parse error\n"; + `Error (false, "Parse error") + | Failure msg -> + Printf.eprintf "Error: %s\n" msg; + `Error (false, msg) + +let input_file = + let doc = "Input source file (.obl)" in + Arg.(required & pos 0 (some file) None & info [] ~docv:"INPUT" ~doc) + +let output_file = + let doc = "Output IR file (.oir.json)" in + Arg.(value & opt string "out.oir.json" & info ["o"; "output"] ~docv:"OUTPUT" ~doc) + +let debug = + let doc = "Enable debug output" in + Arg.(value & flag & info ["d"; "debug"] ~doc) + +let cmd = + let doc = "Compile Oblibeny source to OIR" in + let info = Cmd.info "obli-frontend" ~version:"0.1.0" ~doc in + Cmd.v info Term.(ret (const compile $ input_file $ output_file $ debug)) + +let () = exit (Cmd.eval cmd) +---- + +== Testing + +[source,ocaml] +---- +(* test/typing_test.ml *) + +open Alcotest +open Oblc + +let test_simple_function () = + let src = {| + fn add(x: i64, y: i64) -> i64 { + x + y + } + |} in + let ast = parse_string src in + let tast = Typing.check_program ast in + check int "one function" 1 (List.length tast.functions) + +let test_security_inference () = + let src = {| + fn lookup(arr: oarray<i64>, idx: i64 @high) -> i64 @high { + arr[idx] + } + |} in + let ast = parse_string src in + let tast = Typing.check_program ast in + let fn = List.hd tast.functions in + check bool "return is high" true (fn.treturn_type.sec = High) + +let test_oblivious_violation () = + let src = {| + fn bad(arr: [i64], idx: i64 @high) -> i64 { + arr[idx] + } + |} in + let ast = parse_string src in + 
let tast = Typing.check_program ast in + let violations = Analysis.Oblivious.check_obliviousness tast in + check bool "has violation" true (List.length violations > 0) + +let () = + run "Frontend" [ + "typing", [ + test_case "simple function" `Quick test_simple_function; + test_case "security inference" `Quick test_security_inference; + ]; + "oblivious", [ + test_case "violation detection" `Quick test_oblivious_violation; + ]; + ] +---- + +== Error Messages + +=== Example Error Output + +[source] +---- +$ obli-frontend bad.obl -o out.oir.json + +Error at bad.obl:5:10 + | +5 | arr[secret_idx] + | ^^^^^^^^^^ + | +Array access with secret index leaks access pattern. + +Suggestion: Use oarray instead of array: + + Before: arr: [i64] + After: arr: oarray + +Or declassify the index if the leak is acceptable: + + arr[declassify(secret_idx)] +---- + +== Dependencies + +[source] +---- +opam install \ + dune \ + menhir \ + ppx_deriving \ + yojson \ + msgpck \ + cmdliner \ + alcotest +---- diff --git a/docs/specs/ir/oir-specification.adoc b/docs/specs/ir/oir-specification.adoc new file mode 100644 index 0000000..750c767 --- /dev/null +++ b/docs/specs/ir/oir-specification.adoc @@ -0,0 +1,680 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + += OIR: Oblivious Intermediate Representation Specification +:author: Oblibeny Project +:revdate: 2024 +:toc: left +:toclevels: 4 +:sectnums: +:stem: latexmath + +== Overview + +OIR (Oblivious Intermediate Representation) is the boundary format between +the OCaml frontend and Rust backend. It is a typed, security-annotated IR +designed for oblivious program transformation. + +== Design Principles + +1. **Self-contained**: All information needed for code generation +2. **Typed**: Every expression carries its type and security level +3. **Explicit**: ORAM operations are explicit, not implicit +4. **Serializable**: MessagePack (binary) or JSON (debug) +5. 
**Versionable**: Schema versioning for compatibility + +== File Format + +=== Binary Format (Production) + +MessagePack serialization with the following structure: + +[source] +---- +OIR File: + magic: [0x4F, 0x49, 0x52, 0x00] # "OIR\0" + version: u32 # Schema version + length: u64 # Payload length + payload: msgpack(Module) # MessagePack-encoded module + checksum: [u8; 32] # BLAKE3 hash of payload +---- + +=== Text Format (Debug) + +JSON with `.oir.json` extension for debugging. + +== Schema + +=== Module (Top-Level) + +[source] +---- +Module = { + version: string, # "1.0.0" + name: string, # Module name + imports: [Import], # External dependencies + types: [TypeDef], # Type definitions + globals: [Global], # Global variables + functions: [Function], # Function definitions + entry: string?, # Entry point function name + metadata: Metadata # Debug info, source maps +} +---- + +=== Types + +[source] +---- +Type = + | { "prim": PrimType } + | { "array": Type, "size": int? } + | { "oarray": Type, "size": int? 
} # Oblivious array + | { "ref": Type } + | { "oref": Type } # Oblivious reference + | { "tuple": [Type] } + | { "func": { "params": [Type], "ret": Type } } + | { "named": string } # Reference to TypeDef + +PrimType = "unit" | "bool" | "i8" | "i16" | "i32" | "i64" + | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" + +SecurityLevel = "low" | "high" | { "join": [SecurityLevel] } + +TypedType = { + type: Type, + security: SecurityLevel +} +---- + +=== Functions + +[source] +---- +Function = { + name: string, + params: [Param], + return_type: TypedType, + locals: [Local], + body: [Instruction], + attributes: [Attribute] +} + +Param = { + name: string, + type: TypedType +} + +Local = { + name: string, + type: TypedType +} + +Attribute = "inline" | "noinline" | "oblivious" | "constant_time" +---- + +=== Instructions + +[source] +---- +Instruction = + (* Variables *) + | { "let": { "name": string, "value": Expr } } + | { "assign": { "target": LValue, "value": Expr } } + + (* Control flow *) + | { "if": { "cond": Expr, "then": [Instruction], "else": [Instruction] } } + | { "oif": { "cond": Expr, "then": [Instruction], "else": [Instruction] } } # Oblivious if + | { "loop": { "body": [Instruction] } } + | { "break": null } + | { "continue": null } + | { "return": Expr? } + + (* Oblivious operations *) + | { "oram_read": { "array": Expr, "index": Expr, "result": string } } + | { "oram_write": { "array": Expr, "index": Expr, "value": Expr } } + | { "cmov": { "cond": Expr, "true_val": Expr, "false_val": Expr, "result": string } } + | { "oswap": { "cond": Expr, "a": LValue, "b": LValue } } # Oblivious swap + + (* Memory *) + | { "alloc": { "name": string, "type": Type, "size": Expr? } } + | { "oalloc": { "name": string, "type": Type, "size": Expr? } } # Oblivious alloc + | { "free": { "target": string } } + + (* Function calls *) + | { "call": { "func": string, "args": [Expr], "result": string? 
} } + + (* Debugging *) + | { "assert": { "cond": Expr, "msg": string } } + | { "debug": { "msg": string, "values": [Expr] } } + +LValue = + | { "var": string } + | { "index": { "array": Expr, "index": Expr } } + | { "field": { "struct": Expr, "field": string } } +---- + +=== Expressions + +[source] +---- +Expr = + (* Literals *) + | { "unit": null } + | { "bool": bool } + | { "int": { "value": int, "type": PrimType } } + | { "float": { "value": float, "type": PrimType } } + + (* Variables *) + | { "var": string } + | { "global": string } + + (* Arithmetic *) + | { "add": [Expr, Expr] } + | { "sub": [Expr, Expr] } + | { "mul": [Expr, Expr] } + | { "div": [Expr, Expr] } + | { "mod": [Expr, Expr] } + | { "neg": Expr } + + (* Bitwise *) + | { "band": [Expr, Expr] } + | { "bor": [Expr, Expr] } + | { "bxor": [Expr, Expr] } + | { "bnot": Expr } + | { "shl": [Expr, Expr] } + | { "shr": [Expr, Expr] } + + (* Comparison *) + | { "eq": [Expr, Expr] } + | { "ne": [Expr, Expr] } + | { "lt": [Expr, Expr] } + | { "le": [Expr, Expr] } + | { "gt": [Expr, Expr] } + | { "ge": [Expr, Expr] } + + (* Logical *) + | { "and": [Expr, Expr] } + | { "or": [Expr, Expr] } + | { "not": Expr } + + (* Memory access (non-oblivious) *) + | { "load": { "ptr": Expr } } + | { "index": { "array": Expr, "index": Expr } } + + (* Oblivious access (results from oram_read stored in var) *) + | { "oload": { "oref": Expr } } + + (* Type operations *) + | { "cast": { "value": Expr, "to": Type } } + | { "sizeof": Type } + + (* Tuples/Structs *) + | { "tuple": [Expr] } + | { "field": { "tuple": Expr, "index": int } } + | { "struct": { "type": string, "fields": { string: Expr } } } + + (* Function *) + | { "call": { "func": string, "args": [Expr] } } + + (* Security *) + | { "classify": { "value": Expr, "to": SecurityLevel } } # Raise security + | { "declassify": { "value": Expr } } # Lower security (unsafe!) 
+---- + +=== Metadata + +[source] +---- +Metadata = { + source_file: string?, + source_map: [SourceMapping]?, + compiler_version: string, + timestamp: string, + options: { string: string } +} + +SourceMapping = { + ir_range: [int, int], # Instruction range in IR + source_range: { # Position in source + file: string, + start_line: int, + start_col: int, + end_line: int, + end_col: int + } +} +---- + +== Example + +=== Source Code + +[source] +---- +fn secret_lookup(db: oarray, idx: i64 @high) -> i64 @high { + db[idx] +} + +fn conditional_access(arr: oarray, secret: bool @high) -> i64 @high { + if secret { + arr[0] + } else { + arr[1] + } +} +---- + +=== Generated OIR + +[source,json] +---- +{ + "version": "1.0.0", + "name": "example", + "imports": [], + "types": [], + "globals": [], + "functions": [ + { + "name": "secret_lookup", + "params": [ + {"name": "db", "type": {"type": {"oarray": {"prim": "i64"}}, "security": "low"}}, + {"name": "idx", "type": {"type": {"prim": "i64"}, "security": "high"}} + ], + "return_type": {"type": {"prim": "i64"}, "security": "high"}, + "locals": [ + {"name": "tmp0", "type": {"type": {"prim": "i64"}, "security": "high"}} + ], + "body": [ + { + "oram_read": { + "array": {"var": "db"}, + "index": {"var": "idx"}, + "result": "tmp0" + } + }, + {"return": {"var": "tmp0"}} + ], + "attributes": ["oblivious"] + }, + { + "name": "conditional_access", + "params": [ + {"name": "arr", "type": {"type": {"oarray": {"prim": "i64"}}, "security": "low"}}, + {"name": "secret", "type": {"type": {"prim": "bool"}, "security": "high"}} + ], + "return_type": {"type": {"prim": "i64"}, "security": "high"}, + "locals": [ + {"name": "tmp0", "type": {"type": {"prim": "i64"}, "security": "high"}}, + {"name": "tmp1", "type": {"type": {"prim": "i64"}, "security": "high"}}, + {"name": "result", "type": {"type": {"prim": "i64"}, "security": "high"}} + ], + "body": [ + { + "oram_read": { + "array": {"var": "arr"}, + "index": {"int": {"value": 0, "type": "i64"}}, + 
"result": "tmp0" + } + }, + { + "oram_read": { + "array": {"var": "arr"}, + "index": {"int": {"value": 1, "type": "i64"}}, + "result": "tmp1" + } + }, + { + "cmov": { + "cond": {"var": "secret"}, + "true_val": {"var": "tmp0"}, + "false_val": {"var": "tmp1"}, + "result": "result" + } + }, + {"return": {"var": "result"}} + ], + "attributes": ["oblivious", "constant_time"] + } + ], + "entry": null, + "metadata": { + "source_file": "example.obl", + "compiler_version": "0.1.0", + "timestamp": "2024-01-01T00:00:00Z", + "options": {} + } +} +---- + +== OCaml Type Definitions + +[source,ocaml] +---- +(* ir.ml - OIR types in OCaml *) + +type prim_type = + | Unit | Bool + | I8 | I16 | I32 | I64 + | U8 | U16 | U32 | U64 + | F32 | F64 + +type security_level = + | Low + | High + | Join of security_level list + +type typ = + | Prim of prim_type + | Array of typ * int option + | OArray of typ * int option (* Oblivious array *) + | Ref of typ + | ORef of typ (* Oblivious reference *) + | Tuple of typ list + | Func of typ list * typ + | Named of string + +type typed_type = { + ty: typ; + sec: security_level; +} + +type expr = + | EUnit + | EBool of bool + | EInt of int64 * prim_type + | EFloat of float * prim_type + | EVar of string + | EGlobal of string + | EBinop of binop * expr * expr + | EUnop of unop * expr + | ELoad of expr + | EIndex of expr * expr + | EOLoad of expr + | ECast of expr * typ + | ESizeof of typ + | ETuple of expr list + | EField of expr * int + | EStruct of string * (string * expr) list + | ECall of string * expr list + | EClassify of expr * security_level + | EDeclassify of expr + +and binop = + | Add | Sub | Mul | Div | Mod + | Band | Bor | Bxor | Shl | Shr + | Eq | Ne | Lt | Le | Gt | Ge + | And | Or + +and unop = Neg | Bnot | Not + +type lvalue = + | LVar of string + | LIndex of expr * expr + | LField of expr * string + +type instr = + | ILet of string * expr + | IAssign of lvalue * expr + | IIf of expr * instr list * instr list + | IOIf of expr * instr 
list * instr list (* Oblivious if *) + | ILoop of instr list + | IBreak + | IContinue + | IReturn of expr option + | IOramRead of { array: expr; index: expr; result: string } + | IOramWrite of { array: expr; index: expr; value: expr } + | ICmov of { cond: expr; true_val: expr; false_val: expr; result: string } + | IOSwap of { cond: expr; a: lvalue; b: lvalue } + | IAlloc of string * typ * expr option + | IOAlloc of string * typ * expr option + | IFree of string + | ICall of string * expr list * string option + | IAssert of expr * string + | IDebug of string * expr list + +type attribute = Inline | NoInline | Oblivious | ConstantTime + +type func = { + name: string; + params: (string * typed_type) list; + return_type: typed_type; + locals: (string * typed_type) list; + body: instr list; + attributes: attribute list; +} + +type import = { + module_name: string; + items: string list; +} + +type type_def = { + name: string; + definition: typ; +} + +type global = { + name: string; + typ: typed_type; + init: expr option; +} + +type source_pos = { + file: string; + start_line: int; + start_col: int; + end_line: int; + end_col: int; +} + +type source_mapping = { + ir_range: int * int; + source_range: source_pos; +} + +type metadata = { + source_file: string option; + source_map: source_mapping list; + compiler_version: string; + timestamp: string; + options: (string * string) list; +} + +type module_ = { + version: string; + name: string; + imports: import list; + types: type_def list; + globals: global list; + functions: func list; + entry: string option; + metadata: metadata; +} +---- + +== Rust Type Definitions + +[source,rust] +---- +// ir/types.rs - OIR types in Rust + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PrimType { + Unit, Bool, + I8, I16, I32, I64, + U8, U16, U32, U64, + F32, F64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = 
"snake_case")] +pub enum SecurityLevel { + Low, + High, + Join(Vec<SecurityLevel>), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Type { + Prim(PrimType), + Array { elem: Box<Type>, size: Option<usize> }, + OArray { elem: Box<Type>, size: Option<usize> }, + Ref(Box<Type>), + ORef(Box<Type>), + Tuple(Vec<Type>), + Func { params: Vec<Type>, ret: Box<Type> }, + Named(String), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypedType { + #[serde(rename = "type")] + pub ty: Type, + pub security: SecurityLevel, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Expr { + Unit, + Bool(bool), + Int { value: i64, #[serde(rename = "type")] ty: PrimType }, + Float { value: f64, #[serde(rename = "type")] ty: PrimType }, + Var(String), + Global(String), + Add(Box<Expr>, Box<Expr>), + Sub(Box<Expr>, Box<Expr>), + Mul(Box<Expr>, Box<Expr>), + Div(Box<Expr>, Box<Expr>), + // ... other operations + Call { func: String, args: Vec<Expr> }, + Classify { value: Box<Expr>, to: SecurityLevel }, + Declassify { value: Box<Expr> }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LValue { + Var(String), + Index { array: Expr, index: Expr }, + Field { #[serde(rename = "struct")] strct: Expr, field: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Instr { + Let { name: String, value: Expr }, + Assign { target: LValue, value: Expr }, + If { cond: Expr, then_: Vec<Instr>, else_: Vec<Instr> }, + OIf { cond: Expr, then_: Vec<Instr>, else_: Vec<Instr> }, + Loop { body: Vec<Instr> }, + Break, + Continue, + Return(Option<Expr>), + OramRead { array: Expr, index: Expr, result: String }, + OramWrite { array: Expr, index: Expr, value: Expr }, + Cmov { cond: Expr, true_val: Expr, false_val: Expr, result: String }, + OSwap { cond: Expr, a: LValue, b: LValue }, + Alloc { name: String, #[serde(rename = "type")] ty: Type, size: Option<Expr> }, + OAlloc { name: String, #[serde(rename = "type")] ty: Type, size: Option<Expr> }, + Free { target: String }, + Call {
func: String, args: Vec<Expr>, result: Option<String> }, + Assert { cond: Expr, msg: String }, + Debug { msg: String, values: Vec<Expr> }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Attribute { + Inline, + Noinline, + Oblivious, + ConstantTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Param { + pub name: String, + #[serde(rename = "type")] + pub ty: TypedType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Local { + pub name: String, + #[serde(rename = "type")] + pub ty: TypedType, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Function { + pub name: String, + pub params: Vec<Param>, + pub return_type: TypedType, + pub locals: Vec<Local>, + pub body: Vec<Instr>, + pub attributes: Vec<Attribute>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Module { + pub version: String, + pub name: String, + pub imports: Vec<Import>, + pub types: Vec<TypeDef>, + pub globals: Vec<Global>, + pub functions: Vec<Function>, + pub entry: Option<String>, + pub metadata: Metadata, +} +---- + +== Validation Rules + +=== Well-Formedness + +1. All referenced types must be defined +2. All referenced variables must be in scope +3. All function calls must reference defined functions +4. Oblivious operations only on oblivious types + +=== Security Typing + +1. High-security values cannot flow to low-security locations +2. Branching on high-security values requires oblivious if (`oif`) +3. `declassify` requires explicit annotation (unsafe) + +=== ORAM Constraints + +1. `oram_read` target must be `oarray` or `oref` +2. `oram_write` target must be `oarray` or `oref` +3.
`cmov` must have matching types for both branches + +== Versioning + +=== Version Format + +`MAJOR.MINOR.PATCH` + +* MAJOR: Breaking changes to IR structure +* MINOR: Backward-compatible additions +* PATCH: Bug fixes, documentation + +=== Compatibility + +* Rust backend supports current major version ± 1 +* Old IR files should produce warnings, not errors + +== Extensions + +Reserved for future use: + +* `{ "parallel": [...] }` - Parallel execution block +* `{ "atomic": [...] }` - Atomic transaction +* `{ "simd": {...} }` - SIMD operations +* `{ "gpu": {...} }` - GPU offload hints