diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index df6c17c..0000000 --- a/.gitmodules +++ /dev/null @@ -1,9 +0,0 @@ -[submodule "obli-transpiler-framework"] - path = obli-transpiler-framework - url = git@github.com:hyperpolymath/obli-transpiler-framework.git -[submodule "obli-riscv-dev-kit"] - path = obli-riscv-dev-kit - url = git@github.com:hyperpolymath/obli-riscv-dev-kit.git -[submodule "obli-fs"] - path = obli-fs - url = git@github.com:hyperpolymath/obli-fs.git diff --git a/obli-transpiler-framework b/obli-transpiler-framework deleted file mode 160000 index 5e199fb..0000000 --- a/obli-transpiler-framework +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5e199fb210b40fb4817fa8a4eb3996b407e3b8c2 diff --git a/obli-transpiler-framework/.gitignore b/obli-transpiler-framework/.gitignore new file mode 100644 index 0000000..bcc0770 --- /dev/null +++ b/obli-transpiler-framework/.gitignore @@ -0,0 +1,34 @@ +# Rust +/target/ +Cargo.lock + +# OCaml +_build/ +*.install +*.merlin +.merlin + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Build artifacts +*.o +*.a +*.so +*.dylib + +# Generated files +*.oir.json +*.generated.rs + +# Test artifacts +*.log +coverage/ + +# OS +.DS_Store +Thumbs.db diff --git a/obli-transpiler-framework/Cargo.toml b/obli-transpiler-framework/Cargo.toml new file mode 100644 index 0000000..7867d8c --- /dev/null +++ b/obli-transpiler-framework/Cargo.toml @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: MIT OR Palimpsest-0.8 +# Copyright (c) 2024 Hyperpolymath + +[workspace] +resolver = "2" +members = [ + "backend", + "runtime", + "driver", +] + +[workspace.package] +version = "0.1.0" +edition = "2021" +authors = ["Hyperpolymath"] +license = "MIT OR Palimpsest-0.8" +repository = "https://github.com/hyperpolymath/oblibeny" + +[workspace.dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "1.0" +clap = { version = "4.0", features = ["derive"] } +log = "0.4" +env_logger = "0.10" +subtle = "2.5" 
+zeroize = { version = "1.7", features = ["derive"] } +rand = "0.8" diff --git a/obli-transpiler-framework/README.md b/obli-transpiler-framework/README.md new file mode 100644 index 0000000..5dd2242 --- /dev/null +++ b/obli-transpiler-framework/README.md @@ -0,0 +1,115 @@ +# Oblibeny Transpiler Framework + +The compiler and runtime for the Oblibeny oblivious computing language. + +## Architecture + +``` +┌──────────────────────────────────────────────────────────────────────────┐ +│ Oblibeny Compiler │ +├──────────────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │ +│ │ Source │ │ OIR │ │ Generated Rust │ │ +│ │ (.obl) │─────▶│ (JSON) │─────▶│ + Runtime │ │ +│ └─────────────┘ └─────────────┘ └─────────────────────────┘ │ +│ │ │ │ │ +│ ▼ ▼ ▼ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │ +│ │ OCaml │ │ Rust │ │ oblibeny-runtime │ │ +│ │ Frontend │ │ Backend │ │ (ORAM + Crypto) │ │ +│ └─────────────┘ └─────────────┘ └─────────────────────────┘ │ +│ │ +└──────────────────────────────────────────────────────────────────────────┘ +``` + +## Components + +### Frontend (OCaml) + +The frontend parses `.obl` source files and performs: +- Lexing and parsing +- Type checking with security labels (@low/@high) +- Obliviousness verification (no secret-dependent branches/indices) +- OIR (Oblivious Intermediate Representation) emission + +### Backend (Rust) + +The backend consumes OIR and generates: +- Rust code using the oblibeny-runtime +- Calls to constant-time primitives (cmov, cswap) +- ORAM operations (oread, owrite) + +### Runtime (Rust) + +The runtime library provides: +- **Constant-time primitives**: cmov, cswap, ct_lookup +- **Path ORAM**: O(log N) oblivious memory access +- **Oblivious collections**: OArray, OStack, OQueue, OMap +- **Cryptographic utilities**: AES-GCM, SHA-256, BLAKE3 + +### Driver + +The unified `oblibeny` CLI that orchestrates the pipeline. 
+ +## Building + +Requires: +- OCaml 4.14+ with opam +- Rust 1.70+ +- just (command runner) + +```bash +# Install OCaml dependencies +opam install dune menhir sedlex yojson ppx_deriving ppx_deriving_yojson + +# Build everything +just build + +# Run tests +just test + +# Install to ~/.local/bin +just install +``` + +## Usage + +```bash +# Compile to Rust +oblibeny compile program.obl + +# Type-check only +oblibeny check program.obl + +# Compile and build executable +oblibeny build program.obl +``` + +## Example + +``` +// hello.obl - Oblivious array access + +@oblivious +fn secret_lookup(arr: oarray, @high idx: int) -> @high int { + return oread(arr, idx); +} + +fn main() { + let data: oarray = oarray_new(100); + + // Initialize with public indices + for i in 0..100 { + owrite(data, i, i * 10); + } + + // Look up with secret index - access pattern hidden! + let secret_idx: @high int = get_secret(); + let value: @high int = secret_lookup(data, secret_idx); +} +``` + +## License + +MIT OR Palimpsest-0.8 diff --git a/obli-transpiler-framework/backend/Cargo.toml b/obli-transpiler-framework/backend/Cargo.toml new file mode 100644 index 0000000..ec24bb2 --- /dev/null +++ b/obli-transpiler-framework/backend/Cargo.toml @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: MIT OR Palimpsest-0.8 +# Copyright (c) 2024 Hyperpolymath + +[package] +name = "oblibeny-backend" +version = "0.1.0" +edition = "2021" +authors = ["Hyperpolymath"] +description = "Oblibeny language backend - OIR to Rust code generator" +license = "MIT OR Palimpsest-0.8" +repository = "https://github.com/hyperpolymath/oblibeny" + +[dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "1.0" +clap = { version = "4.0", features = ["derive"] } +log = "0.4" +env_logger = "0.10" + +[dev-dependencies] +pretty_assertions = "1.0" diff --git a/obli-transpiler-framework/backend/src/codegen.rs b/obli-transpiler-framework/backend/src/codegen.rs new file mode 100644 index 
0000000..7ad3ab9 --- /dev/null +++ b/obli-transpiler-framework/backend/src/codegen.rs @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! Code generation from OIR to Rust +//! +//! This module generates Rust code that uses the oblibeny-runtime crate +//! for ORAM operations and constant-time primitives. + +use crate::error::Error; +use crate::oir::*; +use std::fmt::Write; + +/// Code generator state +pub struct CodeGenerator { + indent: usize, + output: String, + inline_runtime: bool, +} + +impl CodeGenerator { + pub fn new() -> Self { + CodeGenerator { + indent: 0, + output: String::new(), + inline_runtime: false, + } + } + + pub fn set_inline_runtime(&mut self, inline: bool) { + self.inline_runtime = inline; + } + + /// Generate Rust code from an OIR module + pub fn generate(&mut self, module: &Module) -> Result { + self.output.clear(); + + // File header + self.emit_header(module)?; + + // Imports + self.emit_imports()?; + + // Struct definitions + for struct_def in &module.structs { + self.emit_struct(struct_def)?; + } + + // External function declarations + for ext in &module.externs { + self.emit_extern(ext)?; + } + + // Function definitions + for func in &module.functions { + self.emit_function(func)?; + } + + Ok(std::mem::take(&mut self.output)) + } + + fn emit_header(&mut self, module: &Module) -> Result<(), Error> { + writeln!(self.output, "// SPDX-License-Identifier: MIT OR Palimpsest-0.8")?; + writeln!(self.output, "// Copyright (c) 2024 Hyperpolymath")?; + writeln!(self.output)?; + writeln!(self.output, "//! Generated by oblibeny-backend")?; + if let Some(name) = &module.name { + writeln!(self.output, "//! 
Module: {}", name)?; + } + writeln!(self.output)?; + writeln!(self.output, "#![allow(unused_variables)]")?; + writeln!(self.output, "#![allow(dead_code)]")?; + writeln!(self.output)?; + Ok(()) + } + + fn emit_imports(&mut self) -> Result<(), Error> { + if self.inline_runtime { + // Inline the essential runtime code + writeln!(self.output, "// Inline runtime")?; + writeln!(self.output, "mod runtime {{")?; + writeln!(self.output, " pub use subtle::{{Choice, ConditionallySelectable}};")?; + writeln!(self.output)?; + writeln!(self.output, " /// Constant-time conditional move")?; + writeln!(self.output, " #[inline]")?; + writeln!(self.output, " pub fn cmov(cond: bool, a: T, b: T) -> T {{")?; + writeln!(self.output, " T::conditional_select(&b, &a, Choice::from(cond as u8))")?; + writeln!(self.output, " }}")?; + writeln!(self.output, "}}")?; + } else { + writeln!(self.output, "use oblibeny_runtime::prelude::*;")?; + } + writeln!(self.output)?; + Ok(()) + } + + fn emit_struct(&mut self, s: &StructDef) -> Result<(), Error> { + writeln!(self.output, "#[derive(Debug, Clone)]")?; + writeln!(self.output, "pub struct {} {{", s.name)?; + self.indent += 1; + for (name, at) in &s.fields { + self.emit_indent()?; + writeln!(self.output, "pub {}: {},", name, at.typ.to_rust())?; + } + self.indent -= 1; + writeln!(self.output, "}}")?; + writeln!(self.output)?; + Ok(()) + } + + fn emit_extern(&mut self, ext: &ExternFunc) -> Result<(), Error> { + writeln!(self.output, "extern \"C\" {{")?; + self.indent += 1; + self.emit_indent()?; + write!(self.output, "fn {}(", ext.name)?; + for (i, param) in ext.params.iter().enumerate() { + if i > 0 { + write!(self.output, ", ")?; + } + write!(self.output, "arg{}: {}", i, param.typ.to_rust())?; + } + writeln!(self.output, ") -> {};", ext.return_type.typ.to_rust())?; + self.indent -= 1; + writeln!(self.output, "}}")?; + writeln!(self.output)?; + Ok(()) + } + + fn emit_function(&mut self, func: &Function) -> Result<(), Error> { + // Documentation + if 
func.is_oblivious { + writeln!(self.output, "/// Oblivious function - access patterns hide secrets")?; + } + if func.is_constant_time { + writeln!(self.output, "/// Constant-time function - no secret-dependent branches")?; + } + + // Attributes + if func.is_constant_time { + writeln!(self.output, "#[inline(never)]")?; + } + + // Function signature + write!(self.output, "pub fn {}(", func.name)?; + for (i, (name, at)) in func.params.iter().enumerate() { + if i > 0 { + write!(self.output, ", ")?; + } + write!(self.output, "{}: {}", name, at.typ.to_rust())?; + } + writeln!(self.output, ") -> {} {{", func.return_type.typ.to_rust())?; + + // Function body + self.indent += 1; + self.emit_block(&func.body)?; + self.indent -= 1; + writeln!(self.output, "}}")?; + writeln!(self.output)?; + Ok(()) + } + + fn emit_block(&mut self, block: &Block) -> Result<(), Error> { + for instr in block { + self.emit_instr(instr)?; + } + Ok(()) + } + + fn emit_instr(&mut self, instr: &Instr) -> Result<(), Error> { + match instr { + Instr::Let(name, at, expr) => { + self.emit_indent()?; + write!(self.output, "let {}: {} = ", name, at.typ.to_rust())?; + self.emit_expr(expr)?; + writeln!(self.output, ";")?; + } + + Instr::Assign(lhs, rhs) => { + self.emit_indent()?; + self.emit_expr(lhs)?; + write!(self.output, " = ")?; + self.emit_expr(rhs)?; + writeln!(self.output, ";")?; + } + + Instr::OramWrite(arr, idx, val) => { + self.emit_indent()?; + self.emit_expr(arr)?; + write!(self.output, ".oram_write(")?; + self.emit_expr(idx)?; + write!(self.output, ", ")?; + self.emit_expr(val)?; + writeln!(self.output, ");")?; + } + + Instr::If(cond, then_block, else_block) => { + self.emit_indent()?; + write!(self.output, "if ")?; + self.emit_expr(cond)?; + writeln!(self.output, " {{")?; + self.indent += 1; + self.emit_block(then_block)?; + self.indent -= 1; + if !else_block.is_empty() { + self.emit_indent()?; + writeln!(self.output, "}} else {{")?; + self.indent += 1; + self.emit_block(else_block)?; + 
self.indent -= 1; + } + self.emit_indent()?; + writeln!(self.output, "}}")?; + } + + Instr::While(cond, body) => { + self.emit_indent()?; + write!(self.output, "while ")?; + self.emit_expr(cond)?; + writeln!(self.output, " {{")?; + self.indent += 1; + self.emit_block(body)?; + self.indent -= 1; + self.emit_indent()?; + writeln!(self.output, "}}")?; + } + + Instr::For(var, start, end, body) => { + self.emit_indent()?; + write!(self.output, "for {} in ", var)?; + self.emit_expr(start)?; + write!(self.output, "..")?; + self.emit_expr(end)?; + writeln!(self.output, " {{")?; + self.indent += 1; + self.emit_block(body)?; + self.indent -= 1; + self.emit_indent()?; + writeln!(self.output, "}}")?; + } + + Instr::Return(expr) => { + self.emit_indent()?; + match expr { + Some(e) => { + write!(self.output, "return ")?; + self.emit_expr(e)?; + writeln!(self.output, ";")?; + } + None => { + writeln!(self.output, "return;")?; + } + } + } + + Instr::Expr(e) => { + self.emit_indent()?; + self.emit_expr(e)?; + writeln!(self.output, ";")?; + } + } + Ok(()) + } + + fn emit_expr(&mut self, expr: &Expr) -> Result<(), Error> { + match expr { + Expr::Lit(lit) => self.emit_literal(lit)?, + + Expr::Var(name) => write!(self.output, "{}", name)?, + + Expr::Binop(op, lhs, rhs) => { + write!(self.output, "(")?; + self.emit_expr(lhs)?; + write!(self.output, " {} ", op.to_rust())?; + self.emit_expr(rhs)?; + write!(self.output, ")")?; + } + + Expr::Unop(op, operand) => { + write!(self.output, "{}", op.to_rust())?; + self.emit_expr(operand)?; + } + + Expr::Call(name, args) => { + write!(self.output, "{}(", name)?; + for (i, arg) in args.iter().enumerate() { + if i > 0 { + write!(self.output, ", ")?; + } + self.emit_expr(arg)?; + } + write!(self.output, ")")?; + } + + Expr::Index(arr, idx) => { + self.emit_expr(arr)?; + write!(self.output, "[")?; + self.emit_expr(idx)?; + write!(self.output, "]")?; + } + + Expr::Field(obj, field) => { + self.emit_expr(obj)?; + write!(self.output, ".{}", field)?; + } 
+ + Expr::Cmov(cond, then_val, else_val) => { + if self.inline_runtime { + write!(self.output, "runtime::cmov(")?; + } else { + write!(self.output, "cmov(")?; + } + self.emit_expr(cond)?; + write!(self.output, ", ")?; + self.emit_expr(then_val)?; + write!(self.output, ", ")?; + self.emit_expr(else_val)?; + write!(self.output, ")")?; + } + + Expr::OramRead(arr, idx) => { + self.emit_expr(arr)?; + write!(self.output, ".oram_read(")?; + self.emit_expr(idx)?; + write!(self.output, ")")?; + } + + Expr::Struct(name, fields) => { + write!(self.output, "{} {{", name)?; + for (i, (fname, fval)) in fields.iter().enumerate() { + if i > 0 { + write!(self.output, ",")?; + } + write!(self.output, " {}: ", fname)?; + self.emit_expr(fval)?; + } + write!(self.output, " }}")?; + } + } + Ok(()) + } + + fn emit_literal(&mut self, lit: &Literal) -> Result<(), Error> { + match lit { + Literal::Int(n) => write!(self.output, "{}", n)?, + Literal::Bool(b) => write!(self.output, "{}", b)?, + Literal::Unit => write!(self.output, "()")?, + } + Ok(()) + } + + fn emit_indent(&mut self) -> Result<(), Error> { + for _ in 0..self.indent { + write!(self.output, " ")?; + } + Ok(()) + } +} + +impl Default for CodeGenerator { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simple_function() { + let module = Module { + name: Some("test".to_string()), + structs: vec![], + externs: vec![], + functions: vec![Function { + name: "add".to_string(), + params: vec![ + ("a".to_string(), AnnotatedType { + typ: Type::Prim(PrimType::I64), + security: Security::Low, + }), + ("b".to_string(), AnnotatedType { + typ: Type::Prim(PrimType::I64), + security: Security::Low, + }), + ], + return_type: AnnotatedType { + typ: Type::Prim(PrimType::I64), + security: Security::Low, + }, + body: vec![ + Instr::Return(Some(Expr::Binop( + BinOp::Add, + Box::new(Expr::Var("a".to_string())), + Box::new(Expr::Var("b".to_string())), + ))), + ], + is_oblivious: false, + 
is_constant_time: false, + }], + }; + + let mut gen = CodeGenerator::new(); + let code = gen.generate(&module).unwrap(); + assert!(code.contains("pub fn add(a: i64, b: i64) -> i64")); + assert!(code.contains("return (a + b);")); + } +} diff --git a/obli-transpiler-framework/backend/src/error.rs b/obli-transpiler-framework/backend/src/error.rs new file mode 100644 index 0000000..63506cb --- /dev/null +++ b/obli-transpiler-framework/backend/src/error.rs @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! Error types for the Oblibeny backend + +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON parse error: {0}")] + Json(#[from] serde_json::Error), + + #[error("Code generation error: {0}")] + CodeGen(String), + + #[error("Invalid OIR: {0}")] + InvalidOir(String), + + #[error("Unsupported feature: {0}")] + Unsupported(String), +} + +impl Error { + pub fn codegen(msg: impl Into) -> Self { + Error::CodeGen(msg.into()) + } + + pub fn invalid_oir(msg: impl Into) -> Self { + Error::InvalidOir(msg.into()) + } + + pub fn unsupported(msg: impl Into) -> Self { + Error::Unsupported(msg.into()) + } +} diff --git a/obli-transpiler-framework/backend/src/main.rs b/obli-transpiler-framework/backend/src/main.rs new file mode 100644 index 0000000..bfc35af --- /dev/null +++ b/obli-transpiler-framework/backend/src/main.rs @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! Oblibeny Backend +//! +//! This is the Rust backend for the Oblibeny oblivious computing language. +//! It consumes OIR (Oblivious Intermediate Representation) from the OCaml +//! frontend and generates Rust code that uses the ORAM runtime. 
+ +mod oir; +mod codegen; +mod error; + +use clap::Parser; +use std::fs; +use std::path::PathBuf; + +#[derive(Parser, Debug)] +#[command(name = "oblibeny-backend")] +#[command(author = "Hyperpolymath")] +#[command(version = "0.1.0")] +#[command(about = "Oblibeny backend - generates Rust from OIR")] +struct Args { + /// Input OIR file (.oir.json) + #[arg(required = true)] + input: PathBuf, + + /// Output Rust file (default: .rs) + #[arg(short, long)] + output: Option, + + /// Generate inline runtime (don't require external crate) + #[arg(long)] + inline_runtime: bool, + + /// Verbose output + #[arg(short, long)] + verbose: bool, +} + +fn main() -> Result<(), error::Error> { + env_logger::init(); + let args = Args::parse(); + + if args.verbose { + eprintln!("Reading OIR from {:?}...", args.input); + } + + // Read and parse OIR + let oir_json = fs::read_to_string(&args.input)?; + let module: oir::Module = serde_json::from_str(&oir_json)?; + + if args.verbose { + eprintln!("Parsed module: {:?}", module.name); + eprintln!(" {} structs", module.structs.len()); + eprintln!(" {} externs", module.externs.len()); + eprintln!(" {} functions", module.functions.len()); + } + + // Generate Rust code + let mut generator = codegen::CodeGenerator::new(); + generator.set_inline_runtime(args.inline_runtime); + let rust_code = generator.generate(&module)?; + + // Determine output path + let output_path = args.output.unwrap_or_else(|| { + let mut path = args.input.clone(); + path.set_extension("rs"); + path + }); + + if args.verbose { + eprintln!("Writing Rust to {:?}...", output_path); + } + + fs::write(&output_path, rust_code)?; + + if args.verbose { + eprintln!("Done."); + } + + Ok(()) +} diff --git a/obli-transpiler-framework/backend/src/oir.rs b/obli-transpiler-framework/backend/src/oir.rs new file mode 100644 index 0000000..a9d3d57 --- /dev/null +++ b/obli-transpiler-framework/backend/src/oir.rs @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright 
(c) 2024 Hyperpolymath + +//! OIR (Oblivious Intermediate Representation) types +//! +//! These types mirror the OCaml frontend's OIR definitions and are +//! deserialized from JSON. + +use serde::{Deserialize, Serialize}; + +/// Security label for information flow +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum Security { + Low, + High, +} + +/// Primitive types +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PrimType { + I8, + I16, + I32, + I64, + U8, + U16, + U32, + U64, + Bool, + Unit, +} + +/// Type representation +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum Type { + Prim(PrimType), + Array(Box, Option), + OArray(Box, Option), + Ref(Box), + Struct(String), + Fn(Vec, Box), +} + +/// Type with security annotation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnnotatedType { + pub typ: Type, + pub security: Security, +} + +/// Binary operators +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum BinOp { + Add, + Sub, + Mul, + Div, + Mod, + Eq, + Ne, + Lt, + Le, + Gt, + Ge, + And, + Or, + BitAnd, + BitOr, + BitXor, + Shl, + Shr, +} + +/// Unary operators +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum UnOp { + Neg, + Not, + BitNot, +} + +/// Literal values +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Literal { + Int(i64), + Bool(bool), + Unit, +} + +/// Variable identifier +pub type VarId = String; + +/// Expressions +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Expr { + Lit(Literal), + Var(VarId), + Binop(BinOp, Box, Box), + Unop(UnOp, Box), + Call(String, Vec), + Index(Box, Box), + Field(Box, String), + Cmov(Box, Box, Box), + OramRead(Box, Box), + Struct(String, Vec<(String, Expr)>), +} + +/// Instructions +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Instr { + Let(VarId, AnnotatedType, Expr), + Assign(Expr, Expr), + OramWrite(Expr, Expr, Expr), + If(Expr, Block, 
Block), + While(Expr, Block), + For(VarId, Expr, Expr, Block), + Return(Option), + Expr(Expr), +} + +/// A block of instructions +pub type Block = Vec; + +/// Function definition +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Function { + pub name: String, + pub params: Vec<(VarId, AnnotatedType)>, + pub return_type: AnnotatedType, + pub body: Block, + pub is_oblivious: bool, + pub is_constant_time: bool, +} + +/// Struct definition +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StructDef { + pub name: String, + pub fields: Vec<(String, AnnotatedType)>, +} + +/// External function declaration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExternFunc { + pub name: String, + pub params: Vec, + pub return_type: AnnotatedType, +} + +/// A complete module +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Module { + pub name: Option, + pub structs: Vec, + pub externs: Vec, + pub functions: Vec, +} + +impl Type { + /// Convert type to Rust type string + pub fn to_rust(&self) -> String { + match self { + Type::Prim(p) => p.to_rust().to_string(), + Type::Array(elem, size) => match size { + Some(n) => format!("[{}; {}]", elem.to_rust(), n), + None => format!("Vec<{}>", elem.to_rust()), + }, + Type::OArray(elem, size) => match size { + Some(n) => format!("OArray<{}, {}>", elem.to_rust(), n), + None => format!("OArray<{}>", elem.to_rust()), + }, + Type::Ref(inner) => format!("&mut {}", inner.to_rust()), + Type::Struct(name) => name.clone(), + Type::Fn(params, ret) => { + let params_str = params.iter().map(|p| p.to_rust()).collect::>().join(", "); + format!("fn({}) -> {}", params_str, ret.to_rust()) + } + } + } +} + +impl PrimType { + /// Convert primitive type to Rust type string + pub fn to_rust(&self) -> &'static str { + match self { + PrimType::I8 => "i8", + PrimType::I16 => "i16", + PrimType::I32 => "i32", + PrimType::I64 => "i64", + PrimType::U8 => "u8", + PrimType::U16 => "u16", + PrimType::U32 => "u32", + 
PrimType::U64 => "u64", + PrimType::Bool => "bool", + PrimType::Unit => "()", + } + } +} + +impl BinOp { + /// Convert binary operator to Rust operator string + pub fn to_rust(&self) -> &'static str { + match self { + BinOp::Add => "+", + BinOp::Sub => "-", + BinOp::Mul => "*", + BinOp::Div => "/", + BinOp::Mod => "%", + BinOp::Eq => "==", + BinOp::Ne => "!=", + BinOp::Lt => "<", + BinOp::Le => "<=", + BinOp::Gt => ">", + BinOp::Ge => ">=", + BinOp::And => "&&", + BinOp::Or => "||", + BinOp::BitAnd => "&", + BinOp::BitOr => "|", + BinOp::BitXor => "^", + BinOp::Shl => "<<", + BinOp::Shr => ">>", + } + } +} + +impl UnOp { + /// Convert unary operator to Rust operator string + pub fn to_rust(&self) -> &'static str { + match self { + UnOp::Neg => "-", + UnOp::Not => "!", + UnOp::BitNot => "!", + } + } +} diff --git a/obli-transpiler-framework/driver/Cargo.toml b/obli-transpiler-framework/driver/Cargo.toml new file mode 100644 index 0000000..1160bb9 --- /dev/null +++ b/obli-transpiler-framework/driver/Cargo.toml @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: MIT OR Palimpsest-0.8 +# Copyright (c) 2024 Hyperpolymath + +[package] +name = "oblibeny" +version = "0.1.0" +edition = "2021" +authors = ["Hyperpolymath"] +description = "Oblibeny language compiler - oblivious computing made safe" +license = "MIT OR Palimpsest-0.8" +repository = "https://github.com/hyperpolymath/oblibeny" +default-run = "oblibeny" + +[[bin]] +name = "oblibeny" +path = "src/main.rs" + +[dependencies] +clap = { version = "4.0", features = ["derive"] } +thiserror = "1.0" +log = "0.4" +env_logger = "0.10" +which = "6.0" +tempfile = "3.10" + +[dev-dependencies] +assert_cmd = "2.0" +predicates = "3.0" diff --git a/obli-transpiler-framework/driver/src/error.rs b/obli-transpiler-framework/driver/src/error.rs new file mode 100644 index 0000000..77b8a89 --- /dev/null +++ b/obli-transpiler-framework/driver/src/error.rs @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 
Hyperpolymath + +//! Error types for the driver + +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("Frontend not found: {0}")] + FrontendNotFound(String), + + #[error("Backend not found: {0}")] + BackendNotFound(String), + + #[error("Frontend failed: {0}")] + FrontendFailed(String), + + #[error("Backend failed: {0}")] + BackendFailed(String), + + #[error("Rust compiler failed: {0}")] + RustcFailed(String), + + #[error("Input file not found: {0}")] + InputNotFound(String), + + #[error("Invalid input: {0}")] + InvalidInput(String), +} diff --git a/obli-transpiler-framework/driver/src/main.rs b/obli-transpiler-framework/driver/src/main.rs new file mode 100644 index 0000000..48b073d --- /dev/null +++ b/obli-transpiler-framework/driver/src/main.rs @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! Oblibeny Compiler Driver +//! +//! This is the main entry point for the Oblibeny compiler. It coordinates +//! the OCaml frontend and Rust backend to compile .obl source files. +//! +//! Pipeline: +//! source.obl → [OCaml Frontend] → source.oir.json → [Rust Backend] → source.rs +//! +//! The driver handles: +//! - Finding and invoking the frontend/backend executables +//! - Managing intermediate files +//! - Providing a unified CLI experience + +use clap::{Parser, Subcommand}; +use std::path::PathBuf; +use std::process::{Command, ExitCode}; +use tempfile::TempDir; + +mod error; +mod pipeline; + +use error::Error; + +#[derive(Parser, Debug)] +#[command(name = "oblibeny")] +#[command(author = "Hyperpolymath")] +#[command(version = "0.1.0")] +#[command(about = "Oblibeny - Oblivious computing language compiler")] +#[command(long_about = r#" +Oblibeny is a language for writing programs with hidden access patterns. 
+It compiles to Rust code that uses ORAM (Oblivious RAM) to prevent +side-channel attacks based on memory access patterns. + +Examples: + oblibeny compile source.obl Compile to Rust + oblibeny compile source.obl -o out.rs Compile with custom output + oblibeny check source.obl Type-check and verify obliviousness + oblibeny build source.obl Compile and build executable +"#)] +struct Args { + #[command(subcommand)] + command: Commands, + + /// Verbose output + #[arg(short, long, global = true)] + verbose: bool, +} + +#[derive(Subcommand, Debug)] +enum Commands { + /// Compile .obl source to Rust + Compile { + /// Input .obl file + input: PathBuf, + + /// Output .rs file (default: .rs) + #[arg(short, long)] + output: Option, + + /// Keep intermediate OIR file + #[arg(long)] + keep_oir: bool, + + /// Inline runtime (don't require oblibeny-runtime crate) + #[arg(long)] + inline_runtime: bool, + }, + + /// Type-check and verify obliviousness without generating code + Check { + /// Input .obl file + input: PathBuf, + }, + + /// Compile and build executable + Build { + /// Input .obl file + input: PathBuf, + + /// Output executable (default: without extension) + #[arg(short, long)] + output: Option, + + /// Build in release mode + #[arg(long)] + release: bool, + }, + + /// Show compiler version and paths + Info, +} + +fn main() -> ExitCode { + env_logger::init(); + let args = Args::parse(); + + match run(args) { + Ok(()) => ExitCode::SUCCESS, + Err(e) => { + eprintln!("error: {}", e); + ExitCode::FAILURE + } + } +} + +fn run(args: Args) -> Result<(), Error> { + match args.command { + Commands::Compile { + input, + output, + keep_oir, + inline_runtime, + } => { + let config = pipeline::CompileConfig { + input, + output, + keep_oir, + inline_runtime, + verbose: args.verbose, + }; + pipeline::compile(config) + } + + Commands::Check { input } => { + let config = pipeline::CheckConfig { + input, + verbose: args.verbose, + }; + pipeline::check(config) + } + + Commands::Build { + 
input, + output, + release, + } => { + let config = pipeline::BuildConfig { + input, + output, + release, + verbose: args.verbose, + }; + pipeline::build(config) + } + + Commands::Info => { + println!("Oblibeny Compiler v0.1.0"); + println!(); + println!("Frontend: oblibeny-frontend (OCaml)"); + println!("Backend: oblibeny-backend (Rust)"); + println!("Runtime: oblibeny-runtime (Rust)"); + println!(); + + // Try to find components + match which::which("oblibeny-frontend") { + Ok(path) => println!("Frontend path: {}", path.display()), + Err(_) => println!("Frontend path: not found in PATH"), + } + match which::which("oblibeny-backend") { + Ok(path) => println!("Backend path: {}", path.display()), + Err(_) => println!("Backend path: not found in PATH"), + } + + Ok(()) + } + } +} diff --git a/obli-transpiler-framework/driver/src/pipeline.rs b/obli-transpiler-framework/driver/src/pipeline.rs new file mode 100644 index 0000000..caf7e84 --- /dev/null +++ b/obli-transpiler-framework/driver/src/pipeline.rs @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! 
Compilation pipeline implementation + +use crate::error::Error; +use std::path::PathBuf; +use std::process::Command; +use tempfile::TempDir; + +/// Configuration for compile command +pub struct CompileConfig { + pub input: PathBuf, + pub output: Option, + pub keep_oir: bool, + pub inline_runtime: bool, + pub verbose: bool, +} + +/// Configuration for check command +pub struct CheckConfig { + pub input: PathBuf, + pub verbose: bool, +} + +/// Configuration for build command +pub struct BuildConfig { + pub input: PathBuf, + pub output: Option, + pub release: bool, + pub verbose: bool, +} + +/// Find the frontend executable +fn find_frontend() -> Result { + // Try several locations + let candidates = [ + // In PATH + which::which("oblibeny-frontend").ok(), + // Relative to driver (for development) + std::env::current_exe() + .ok() + .and_then(|p| p.parent().map(|p| p.join("oblibeny-frontend"))), + // In frontend/_build + Some(PathBuf::from("frontend/_build/default/bin/main.exe")), + ]; + + for candidate in candidates.into_iter().flatten() { + if candidate.exists() { + return Ok(candidate); + } + } + + Err(Error::FrontendNotFound( + "oblibeny-frontend not found. Build the frontend first.".to_string(), + )) +} + +/// Find the backend executable +fn find_backend() -> Result { + let candidates = [ + which::which("oblibeny-backend").ok(), + std::env::current_exe() + .ok() + .and_then(|p| p.parent().map(|p| p.join("oblibeny-backend"))), + Some(PathBuf::from("backend/target/release/oblibeny-backend")), + Some(PathBuf::from("backend/target/debug/oblibeny-backend")), + ]; + + for candidate in candidates.into_iter().flatten() { + if candidate.exists() { + return Ok(candidate); + } + } + + Err(Error::BackendNotFound( + "oblibeny-backend not found. 
Build the backend first.".to_string(), + )) +} + +/// Compile .obl to .rs +pub fn compile(config: CompileConfig) -> Result<(), Error> { + if !config.input.exists() { + return Err(Error::InputNotFound(config.input.display().to_string())); + } + + let frontend = find_frontend()?; + let backend = find_backend()?; + + // Determine output paths + let oir_path = config.input.with_extension("oir.json"); + let rs_path = config.output.unwrap_or_else(|| config.input.with_extension("rs")); + + if config.verbose { + eprintln!("Using frontend: {}", frontend.display()); + eprintln!("Using backend: {}", backend.display()); + eprintln!("Input: {}", config.input.display()); + eprintln!("OIR: {}", oir_path.display()); + eprintln!("Output: {}", rs_path.display()); + } + + // Run frontend + if config.verbose { + eprintln!("\n=== Running frontend ==="); + } + + let frontend_status = Command::new(&frontend) + .arg(&config.input) + .arg("-o") + .arg(&oir_path) + .args(if config.verbose { vec!["-v"] } else { vec![] }) + .status()?; + + if !frontend_status.success() { + return Err(Error::FrontendFailed(format!( + "exit code: {:?}", + frontend_status.code() + ))); + } + + // Run backend + if config.verbose { + eprintln!("\n=== Running backend ==="); + } + + let mut backend_cmd = Command::new(&backend); + backend_cmd.arg(&oir_path).arg("-o").arg(&rs_path); + + if config.inline_runtime { + backend_cmd.arg("--inline-runtime"); + } + if config.verbose { + backend_cmd.arg("-v"); + } + + let backend_status = backend_cmd.status()?; + + if !backend_status.success() { + return Err(Error::BackendFailed(format!( + "exit code: {:?}", + backend_status.code() + ))); + } + + // Clean up OIR if not keeping + if !config.keep_oir && oir_path.exists() { + std::fs::remove_file(&oir_path)?; + } + + if config.verbose { + eprintln!("\nCompilation successful: {}", rs_path.display()); + } + + Ok(()) +} + +/// Type-check without code generation +pub fn check(config: CheckConfig) -> Result<(), Error> { + if 
!config.input.exists() { + return Err(Error::InputNotFound(config.input.display().to_string())); + } + + let frontend = find_frontend()?; + + if config.verbose { + eprintln!("Using frontend: {}", frontend.display()); + eprintln!("Checking: {}", config.input.display()); + } + + let status = Command::new(&frontend) + .arg(&config.input) + .arg("--check") + .args(if config.verbose { vec!["-v"] } else { vec![] }) + .status()?; + + if !status.success() { + return Err(Error::FrontendFailed(format!( + "check failed with exit code: {:?}", + status.code() + ))); + } + + println!("Check passed: {}", config.input.display()); + Ok(()) +} + +/// Compile and build executable +pub fn build(config: BuildConfig) -> Result<(), Error> { + // First compile to Rust + let rs_path = config.input.with_extension("rs"); + + compile(CompileConfig { + input: config.input.clone(), + output: Some(rs_path.clone()), + keep_oir: false, + inline_runtime: true, // Inline for standalone build + verbose: config.verbose, + })?; + + // Determine output executable name + let exe_path = config.output.unwrap_or_else(|| { + let stem = config.input.file_stem().unwrap_or_default(); + PathBuf::from(stem) + }); + + if config.verbose { + eprintln!("\n=== Building executable ==="); + } + + // Compile with rustc + let mut rustc_cmd = Command::new("rustc"); + rustc_cmd + .arg(&rs_path) + .arg("-o") + .arg(&exe_path) + .arg("--edition=2021"); + + if config.release { + rustc_cmd.arg("-O"); + } + + // Add runtime dependencies + rustc_cmd + .arg("--extern") + .arg("subtle=libsubtle.rlib") + .arg("--extern") + .arg("zeroize=libzeroize.rlib"); + + let status = rustc_cmd.status()?; + + if !status.success() { + // Try with cargo instead + if config.verbose { + eprintln!("Direct rustc failed, trying with cargo..."); + } + + // Create a temporary Cargo project + let temp_dir = TempDir::new()?; + let project_dir = temp_dir.path(); + + // Create Cargo.toml + let cargo_toml = format!( + r#"[package] +name = "oblibeny_output" 
+version = "0.1.0" +edition = "2021" + +[dependencies] +subtle = "2.5" +zeroize = "1.7" + +[[bin]] +name = "output" +path = "src/main.rs" +"# + ); + + std::fs::create_dir_all(project_dir.join("src"))?; + std::fs::write(project_dir.join("Cargo.toml"), cargo_toml)?; + std::fs::copy(&rs_path, project_dir.join("src/main.rs"))?; + + // Build with cargo + let cargo_status = Command::new("cargo") + .current_dir(project_dir) + .arg("build") + .args(if config.release { + vec!["--release"] + } else { + vec![] + }) + .status()?; + + if !cargo_status.success() { + return Err(Error::RustcFailed("cargo build failed".to_string())); + } + + // Copy the built executable + let build_mode = if config.release { "release" } else { "debug" }; + let built_exe = project_dir.join(format!("target/{}/output", build_mode)); + std::fs::copy(built_exe, &exe_path)?; + } + + if config.verbose { + eprintln!("\nBuild successful: {}", exe_path.display()); + } + + Ok(()) +} diff --git a/obli-transpiler-framework/examples/secret_lookup.obl b/obli-transpiler-framework/examples/secret_lookup.obl new file mode 100644 index 0000000..8f84aa9 --- /dev/null +++ b/obli-transpiler-framework/examples/secret_lookup.obl @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +// Example: Oblivious array lookup +// +// This demonstrates how to use ORAM arrays to hide access patterns +// when indexing with secret values. + +// A struct for storing records +struct Record { + id: @low int, + value: @high int, +} + +// Oblivious lookup function +// The @oblivious attribute ensures the compiler verifies +// that no secret-dependent branches or regular array accesses occur. 
@oblivious
fn lookup(data: oarray, @high index: int) -> @high int {
    // oread() performs an ORAM access that hides which index was accessed
    return oread(data, index);
}

// Oblivious conditional update
//
// Always performs exactly one oread() followed by one owrite() on `index`;
// `should_update` only selects which value is written back.
@oblivious
@constant_time
fn conditional_update(
    data: oarray,
    @high index: int,
    @high should_update: bool,
    @high new_value: int
) {
    // Read current value
    let current: @high int = oread(data, index);

    // Conditionally select new or old value using cmov
    // This doesn't branch on the secret condition
    let final_value: @high int = cmov(should_update, new_value, current);

    // Write back (ORAM hides which location)
    owrite(data, index, final_value);
}

// Binary search with hidden access pattern
//
// Runs a fixed 32 iterations regardless of `size` or `target`, issuing one
// oread() per iteration. Returns the `mid` of the last iteration whose element
// equalled `target`, or -1 if no iteration matched.
@oblivious
fn oblivious_binary_search(
    sorted_data: oarray,
    size: int,
    @high target: int
) -> @high int {
    let low: @high int = 0;
    let high: @high int = size - 1;
    let result: @high int = -1;

    // Fixed number of iterations (log2 of max size)
    // The loop variable is not used in the body; it only fixes the trip count.
    for i in 0..32 {
        let mid: @high int = (low + high) / 2;
        let mid_val: @high int = oread(sorted_data, mid);

        // All comparisons use cmov, no branches on secrets
        let found: @high bool = mid_val == target;
        let go_left: @high bool = mid_val > target;

        // Each bound is rewritten every iteration; cmov keeps the old value
        // for whichever side is not being narrowed.
        result = cmov(found, mid, result);
        high = cmov(go_left, mid - 1, high);
        low = cmov(go_left, low, mid + 1);
    }

    return result;
}

// Example driver: fills an oblivious array with public data, then performs a
// lookup and a conditional update using secret (@high) values.
fn main() {
    // Create an oblivious array
    let data: oarray = oarray_new(1000);

    // Initialize with public data
    for i in 0..1000 {
        owrite(data, i, i * 7);
    }

    // Secret index - the access pattern will be hidden
    let secret_idx: @high int = 42;

    // This lookup hides which index was accessed
    let value: @high int = lookup(data, secret_idx);

    // Conditional update without leaking the condition
    let should_update: @high bool = true;
    conditional_update(data, secret_idx, should_update, 999);
}
(* SPDX-License-Identifier: MIT OR Palimpsest-0.8 *)
(* Copyright (c) 2024 Hyperpolymath *)

(** Oblibeny Frontend CLI

    Parses .obl source files, performs type checking and obliviousness
    analysis, then emits OIR (Oblivious Intermediate Representation)
    for the Rust backend.
*)

open Oblibeny_frontend

let version = "0.1.0"

(** Command line options *)
type options = {
  mutable input_file: string option;
  mutable output_file: string option;
  mutable dump_ast: bool;
  mutable dump_oir: bool;
  mutable check_only: bool;
  mutable verbose: bool;
}

let default_options () = {
  input_file = None;
  output_file = None;
  dump_ast = false;
  dump_oir = false;
  check_only = false;
  verbose = false;
}

let usage_msg = "oblibeny-frontend [OPTIONS] <input.obl>"

(** Parse the command line into an [options] record.

    Exactly one positional argument (the input file) is accepted; a second
    one is rejected through [Arg.Bad] instead of silently replacing the first. *)
let parse_args () =
  let opts = default_options () in
  let specs = [
    ("-o", Arg.String (fun s -> opts.output_file <- Some s),
     "<file> Output OIR file (default: <input>.oir.json)");
    ("--dump-ast", Arg.Unit (fun () -> opts.dump_ast <- true),
     " Dump parsed AST to stderr");
    ("--dump-oir", Arg.Unit (fun () -> opts.dump_oir <- true),
     " Dump OIR to stderr");
    ("--check", Arg.Unit (fun () -> opts.check_only <- true),
     " Only type-check, don't emit OIR");
    ("-v", Arg.Unit (fun () -> opts.verbose <- true),
     " Verbose output");
    ("--verbose", Arg.Unit (fun () -> opts.verbose <- true),
     " Verbose output");
    ("--version", Arg.Unit (fun () ->
       Printf.printf "oblibeny-frontend %s\n" version;
       exit 0),
     " Print version and exit");
  ] in
  let set_input path =
    match opts.input_file with
    | None -> opts.input_file <- Some path
    | Some first ->
      raise (Arg.Bad (Printf.sprintf
                        "multiple input files given (%s and %s)" first path))
  in
  Arg.parse specs set_input usage_msg;
  opts

(** Parse source file

    Returns [Ok program] or [Error message], where [message] is a
    ready-to-print diagnostic. The input channel is closed on every path,
    including unexpected exceptions. *)
let parse_file filename =
  let format_pos pos =
    Printf.sprintf "%s:%d:%d"
      pos.Lexing.pos_fname
      pos.Lexing.pos_lnum
      (pos.Lexing.pos_cnum - pos.Lexing.pos_bol)
  in
  match open_in filename with
  | exception Sys_error msg ->
    (* Missing or unreadable input: report it instead of dying on an
       uncaught exception. *)
    Error (Printf.sprintf "error: %s" msg)
  | ic ->
    Fun.protect
      ~finally:(fun () -> close_in_noerr ic)
      (fun () ->
         let lexbuf = Lexing.from_channel ic in
         lexbuf.Lexing.lex_curr_p <- { lexbuf.Lexing.lex_curr_p with
                                       Lexing.pos_fname = filename;
                                     };
         try Ok (Parser.program Lexer.token lexbuf) with
         | Lexer.Lexer_error (msg, pos) ->
           Error (Printf.sprintf "%s: lexer error: %s" (format_pos pos) msg)
         (* The parser is generated by Menhir, which signals syntax errors
            with [Parser.Error]; [Parsing.Parse_error] (ocamlyacc) is kept
            for compatibility. *)
         | Parser.Error | Parsing.Parse_error ->
           Error (Printf.sprintf "%s: syntax error"
                    (format_pos lexbuf.Lexing.lex_curr_p)))

(** Main compilation pipeline

    Parse, type-check, check obliviousness, then emit OIR. Terminates the
    process: exit code 0 on success, 1 on any reported error. *)
let compile opts =
  let input_file = match opts.input_file with
    | Some f -> f
    | None ->
      prerr_endline "error: no input file";
      exit 1
  in

  let output_file = match opts.output_file with
    | Some f -> f
    | None ->
      let base = Filename.remove_extension input_file in
      base ^ ".oir.json"
  in

  if opts.verbose then
    Printf.eprintf "Parsing %s...\n%!" input_file;

  (* Parse *)
  let program = match parse_file input_file with
    | Ok p -> p
    | Error msg ->
      prerr_endline msg;
      exit 1
  in

  if opts.dump_ast then begin
    prerr_endline "=== AST ===";
    prerr_endline (Ast.show_program program)
  end;

  if opts.verbose then
    Printf.eprintf "Type checking...\n%!";

  (* Type check *)
  let type_diags = Typecheck.check_program program in
  if Errors.has_errors type_diags then begin
    Errors.print_diagnostics type_diags;
    exit 1
  end;

  if opts.verbose then
    Printf.eprintf "Checking obliviousness...\n%!";

  (* Obliviousness check *)
  let (obli_diags, violations) = Oblicheck.check_program program in
  if violations > 0 then begin
    Errors.print_diagnostics obli_diags;
    let result = Oblicheck.analyze_violations obli_diags in
    Printf.eprintf "\nObliviousness violations: %d\n" result.total_violations;
    Printf.eprintf " Secret branches: %d\n" result.secret_branches;
    Printf.eprintf " Secret indices: %d\n" result.secret_indices;
    Printf.eprintf " Secret loop bounds: %d\n" result.secret_loops;
    Printf.eprintf " Information leaks: %d\n" result.info_leaks;
    exit 1
  end;

  (* Print warnings *)
  Errors.print_diagnostics type_diags;
  Errors.print_diagnostics obli_diags;

  if opts.check_only then begin
    if opts.verbose then
      Printf.eprintf "Check passed.\n%!";
    exit 0
  end;

  if opts.verbose then
    Printf.eprintf "Emitting OIR to %s...\n%!" output_file;

  (* Emit OIR *)
  let oir_module = Emit_oir.emit_module program in

  if opts.dump_oir then begin
    prerr_endline "=== OIR ===";
    prerr_endline (Emit_oir.to_json oir_module)
  end;

  Emit_oir.write_oir output_file oir_module;

  if opts.verbose then
    Printf.eprintf "Done.\n%!";

  exit 0

let () =
  let opts = parse_args () in
  compile opts
+Produces OIR (Oblivious Intermediate Representation) for the Rust backend.") + (depends + (ocaml (>= 4.14)) + (dune (>= 3.0)) + menhir + sedlex + yojson + ppx_deriving + ppx_deriving_yojson)) diff --git a/obli-transpiler-framework/frontend/lib/ast.ml b/obli-transpiler-framework/frontend/lib/ast.ml new file mode 100644 index 0000000..213ed89 --- /dev/null +++ b/obli-transpiler-framework/frontend/lib/ast.ml @@ -0,0 +1,189 @@ +(* SPDX-License-Identifier: MIT OR Palimpsest-0.8 *) +(* Copyright (c) 2024 Hyperpolymath *) + +(** Abstract Syntax Tree for Oblibeny language *) + +open Location + +(** Security labels for information flow *) +type security_label = + | Low (** Public data *) + | High (** Secret data *) + [@@deriving show, yojson] + +(** Primitive types *) +type prim_type = + | TInt of int option (** Integer with optional bit width *) + | TUint of int option (** Unsigned integer with optional bit width *) + | TBool + | TUnit + | TByte + [@@deriving show, yojson] + +(** Type expressions *) +type typ = + | TPrim of prim_type + | TArray of typ * security_label (** Regular array *) + | TOArray of typ (** Oblivious array (ORAM-backed) *) + | TRef of typ * security_label (** Reference with security label *) + | TFun of typ list * typ (** Function type *) + | TStruct of string (** Named struct type *) + | TGeneric of string * typ list (** Generic type application *) + | TVar of string (** Type variable *) + [@@deriving show, yojson] + +(** Annotated type with security label *) +type annotated_type = { + typ: typ; + security: security_label; + loc: Location.t; +} [@@deriving show, yojson] + +(** Binary operators *) +type binop = + | Add | Sub | Mul | Div | Mod + | Eq | Neq | Lt | Le | Gt | Ge + | And | Or + | BitAnd | BitOr | BitXor + | Shl | Shr + [@@deriving show, yojson] + +(** Unary operators *) +type unop = + | Neg | Not | BitNot + [@@deriving show, yojson] + +(** Literals *) +type literal = + | LInt of int64 + | LUint of int64 + | LBool of bool + | LByte of char 
+ | LUnit + [@@deriving show, yojson] + +(** Pattern for matching *) +type pattern = + | PWildcard + | PVar of string + | PLiteral of literal + | PTuple of pattern list + | PStruct of string * (string * pattern) list + [@@deriving show, yojson] + +(** Expressions *) +type expr = { + expr_desc: expr_desc; + expr_loc: Location.t; + mutable expr_type: annotated_type option; (** Filled during type checking *) +} [@@deriving show, yojson] + +and expr_desc = + | ELiteral of literal + | EVar of string + | EBinop of binop * expr * expr + | EUnop of unop * expr + | ECall of expr * expr list + | EIndex of expr * expr (** Array indexing *) + | EOramRead of expr * expr (** Explicit ORAM read: oread(arr, idx) *) + | EField of expr * string (** Struct field access *) + | EIf of expr * expr * expr (** Conditional expression *) + | EBlock of stmt list * expr option (** Block with optional final expression *) + | ELambda of (string * annotated_type) list * expr (** Anonymous function *) + | ETuple of expr list + | EStruct of string * (string * expr) list (** Struct construction *) + | ECmov of expr * expr * expr (** Constant-time conditional move *) + [@@deriving show, yojson] + +(** Statements *) +and stmt = { + stmt_desc: stmt_desc; + stmt_loc: Location.t; +} [@@deriving show, yojson] + +and stmt_desc = + | SLet of pattern * annotated_type option * expr (** Let binding *) + | SAssign of expr * expr (** Assignment *) + | SOramWrite of expr * expr * expr (** ORAM write: owrite(arr, idx, val) *) + | SExpr of expr (** Expression statement *) + | SIf of expr * stmt list * stmt list (** If statement *) + | SWhile of expr * stmt list (** While loop *) + | SFor of string * expr * expr * stmt list (** For loop: for i in start..end *) + | SReturn of expr option (** Return statement *) + | SBreak + | SContinue + [@@deriving show, yojson] + +(** Top-level declarations *) +type decl = { + decl_desc: decl_desc; + decl_loc: Location.t; +} [@@deriving show, yojson] + +and decl_desc = + | 
; SPDX-License-Identifier: MIT OR Palimpsest-0.8
; Copyright (c) 2024 Hyperpolymath

; Frontend library: AST, lexer, parser, checkers and OIR emission.
(library
 (name oblibeny_frontend)
 (public_name oblibeny.frontend)
 (libraries str yojson)
 ; The `show` deriver is provided by the ppx_deriving.show sub-library.
 (preprocess
  (pps ppx_deriving.show ppx_deriving_yojson sedlex.ppx)))

; `ocamllex` and `menhir` are top-level stanzas, not fields of `library`.
(ocamllex lexer)

; NOTE: the menhir stanza requires the menhir extension to be enabled with
; `(using menhir <version>)` in dune-project.
(menhir
 (modules parser))
--git a/obli-transpiler-framework/frontend/lib/emit_oir.ml b/obli-transpiler-framework/frontend/lib/emit_oir.ml new file mode 100644 index 0000000..c22481b --- /dev/null +++ b/obli-transpiler-framework/frontend/lib/emit_oir.ml @@ -0,0 +1,316 @@ +(* SPDX-License-Identifier: MIT OR Palimpsest-0.8 *) +(* Copyright (c) 2024 Hyperpolymath *) + +(** OIR (Oblivious Intermediate Representation) emission + + This module transforms the typed AST into OIR, which is then + serialized to JSON/MessagePack for the Rust backend. +*) + +open Ast + +(** OIR types - these mirror the Rust OIR definitions *) + +module Oir = struct + type security = Low | High [@@deriving yojson] + + type prim_type = + | I8 | I16 | I32 | I64 + | U8 | U16 | U32 | U64 + | Bool | Unit + [@@deriving yojson] + + type typ = + | Prim of prim_type + | Array of typ * int option (* element type, optional size *) + | OArray of typ * int option (* oblivious array *) + | Ref of typ + | Struct of string + | Fn of typ list * typ + [@@deriving yojson] + + type annotated_type = { + typ: typ; + security: security; + } [@@deriving yojson] + + type binop = + | Add | Sub | Mul | Div | Mod + | Eq | Ne | Lt | Le | Gt | Ge + | And | Or + | BitAnd | BitOr | BitXor | Shl | Shr + [@@deriving yojson] + + type unop = Neg | Not | BitNot [@@deriving yojson] + + type literal = + | Int of int64 + | Bool of bool + | Unit + [@@deriving yojson] + + type var_id = string [@@deriving yojson] + + type expr = + | Lit of literal + | Var of var_id + | Binop of binop * expr * expr + | Unop of unop * expr + | Call of string * expr list + | Index of expr * expr + | Field of expr * string + | Cmov of expr * expr * expr (* condition, true_val, false_val *) + | OramRead of expr * expr (* array, index *) + | Struct of string * (string * expr) list + [@@deriving yojson] + + type instr = + | Let of var_id * annotated_type * expr + | Assign of expr * expr + | OramWrite of expr * expr * expr (* array, index, value *) + | If of expr * block * block + | 
While of expr * block + | For of var_id * expr * expr * block (* var, start, end, body *) + | Return of expr option + | Expr of expr + [@@deriving yojson] + + and block = instr list [@@deriving yojson] + + type func = { + name: string; + params: (var_id * annotated_type) list; + return_type: annotated_type; + body: block; + is_oblivious: bool; + is_constant_time: bool; + } [@@deriving yojson] + + type struct_def = { + name: string; + fields: (string * annotated_type) list; + } [@@deriving yojson] + + type extern_func = { + name: string; + params: annotated_type list; + return_type: annotated_type; + } [@@deriving yojson] + + type module_def = { + name: string option; + structs: struct_def list; + externs: extern_func list; + functions: func list; + } [@@deriving yojson] +end + +(** Conversion utilities *) + +let convert_security = function + | Low -> Oir.Low + | High -> Oir.High + +let rec convert_prim_type = function + | TInt None -> Oir.I64 + | TInt (Some 8) -> Oir.I8 + | TInt (Some 16) -> Oir.I16 + | TInt (Some 32) -> Oir.I32 + | TInt (Some 64) -> Oir.I64 + | TInt (Some _) -> Oir.I64 (* Default to I64 for other widths *) + | TUint None -> Oir.U64 + | TUint (Some 8) -> Oir.U8 + | TUint (Some 16) -> Oir.U16 + | TUint (Some 32) -> Oir.U32 + | TUint (Some 64) -> Oir.U64 + | TUint (Some _) -> Oir.U64 + | TBool -> Oir.Bool + | TByte -> Oir.U8 + | TUnit -> Oir.Unit + +and convert_type = function + | TPrim p -> Oir.Prim (convert_prim_type p) + | TArray (elem, _) -> Oir.Array (convert_type elem, None) + | TOArray elem -> Oir.OArray (convert_type elem, None) + | TRef (elem, _) -> Oir.Ref (convert_type elem) + | TStruct name -> Oir.Struct name + | TFun (params, ret) -> Oir.Fn (List.map convert_type params, convert_type ret) + | TGeneric (name, _) -> Oir.Struct name (* Simplified: treat generics as structs *) + | TVar _ -> Oir.Prim Oir.Unit (* Type variables shouldn't reach emission *) + +let convert_annotated_type at = + { Oir.typ = convert_type at.typ; security = 
convert_security at.security } + +let convert_binop = function + | Add -> Oir.Add | Sub -> Oir.Sub | Mul -> Oir.Mul + | Div -> Oir.Div | Mod -> Oir.Mod + | Eq -> Oir.Eq | Neq -> Oir.Ne + | Lt -> Oir.Lt | Le -> Oir.Le | Gt -> Oir.Gt | Ge -> Oir.Ge + | And -> Oir.And | Or -> Oir.Or + | BitAnd -> Oir.BitAnd | BitOr -> Oir.BitOr | BitXor -> Oir.BitXor + | Shl -> Oir.Shl | Shr -> Oir.Shr + +let convert_unop = function + | Neg -> Oir.Neg + | Not -> Oir.Not + | BitNot -> Oir.BitNot + +let convert_literal = function + | LInt n -> Oir.Int n + | LUint n -> Oir.Int n + | LBool b -> Oir.Bool b + | LByte c -> Oir.Int (Int64.of_int (Char.code c)) + | LUnit -> Oir.Unit + +(** Name generation for temporaries *) +let temp_counter = ref 0 +let fresh_temp () = + let n = !temp_counter in + incr temp_counter; + Printf.sprintf "_t%d" n + +(** Expression emission *) +let rec emit_expr expr = + match expr.expr_desc with + | ELiteral lit -> Oir.Lit (convert_literal lit) + | EVar name -> Oir.Var name + | EBinop (op, lhs, rhs) -> + Oir.Binop (convert_binop op, emit_expr lhs, emit_expr rhs) + | EUnop (op, operand) -> + Oir.Unop (convert_unop op, emit_expr operand) + | ECall (func, args) -> + let func_name = match func.expr_desc with + | EVar name -> name + | _ -> "_anon_fn" (* Indirect calls need special handling *) + in + Oir.Call (func_name, List.map emit_expr args) + | EIndex (arr, idx) -> + Oir.Index (emit_expr arr, emit_expr idx) + | EOramRead (arr, idx) -> + Oir.OramRead (emit_expr arr, emit_expr idx) + | EField (obj, field) -> + Oir.Field (emit_expr obj, field) + | EIf (cond, then_expr, else_expr) -> + (* Convert if-expression to cmov *) + Oir.Cmov (emit_expr cond, emit_expr then_expr, emit_expr else_expr) + | EBlock (stmts, final) -> + (* Blocks in expressions need special handling - simplified here *) + (match final with + | Some e -> emit_expr e + | None -> Oir.Lit Oir.Unit) + | ELambda _ -> + (* Lambdas should be lifted to top-level *) + Oir.Lit Oir.Unit (* TODO: Lambda lifting *) 
+ | ETuple exprs -> + (* Tuples should be converted to structs *) + Oir.Struct ("_tuple", List.mapi (fun i e -> (Printf.sprintf "_%d" i, emit_expr e)) exprs) + | EStruct (name, fields) -> + Oir.Struct (name, List.map (fun (n, e) -> (n, emit_expr e)) fields) + | ECmov (cond, then_val, else_val) -> + Oir.Cmov (emit_expr cond, emit_expr then_val, emit_expr else_val) + +(** Statement emission *) +let rec emit_stmt stmt : Oir.instr list = + match stmt.stmt_desc with + | SLet (pattern, type_annot, init) -> + let var_name = match pattern with + | PVar name -> name + | _ -> fresh_temp () (* Pattern matching needs expansion *) + in + let at = match type_annot with + | Some t -> convert_annotated_type t + | None -> match init.expr_type with + | Some t -> convert_annotated_type t + | None -> { Oir.typ = Oir.Prim Oir.Unit; security = Oir.Low } + in + [Oir.Let (var_name, at, emit_expr init)] + + | SAssign (lhs, rhs) -> + [Oir.Assign (emit_expr lhs, emit_expr rhs)] + + | SOramWrite (arr, idx, value) -> + [Oir.OramWrite (emit_expr arr, emit_expr idx, emit_expr value)] + + | SExpr e -> + [Oir.Expr (emit_expr e)] + + | SIf (cond, then_stmts, else_stmts) -> + let then_block = List.concat_map emit_stmt then_stmts in + let else_block = List.concat_map emit_stmt else_stmts in + [Oir.If (emit_expr cond, then_block, else_block)] + + | SWhile (cond, body) -> + let body_block = List.concat_map emit_stmt body in + [Oir.While (emit_expr cond, body_block)] + + | SFor (var, start_expr, end_expr, body) -> + let body_block = List.concat_map emit_stmt body in + [Oir.For (var, emit_expr start_expr, emit_expr end_expr, body_block)] + + | SReturn expr_opt -> + [Oir.Return (Option.map emit_expr expr_opt)] + + | SBreak -> + [] (* TODO: Need break instruction in OIR *) + + | SContinue -> + [] (* TODO: Need continue instruction in OIR *) + +(** Declaration emission *) +let emit_function decl = + match decl.decl_desc with + | DFunction { name; params; return_type; body; attributes; _ } -> + let 
(** Write OIR to file

    Serializes [module_def] with [to_json] and writes it to [filename],
    replacing any existing file. The channel is closed on every path; write
    errors still surface as [Sys_error]. *)
let write_oir filename module_def =
  let json = to_json module_def in
  let oc = open_out filename in
  Fun.protect
    ~finally:(fun () -> close_out_noerr oc)
    (fun () ->
       output_string oc json;
       (* Flush explicitly: [close_out_noerr] would swallow a failed flush. *)
       flush oc)
[@@deriving show] + +type error_kind = + (* Lexer errors *) + | Unexpected_character of char + | Unterminated_comment + | Unterminated_string + | Invalid_escape of char + + (* Parser errors *) + | Syntax_error of string + | Unexpected_token of string + + (* Type errors *) + | Type_mismatch of { expected: string; found: string } + | Unknown_identifier of string + | Unknown_type of string + | Duplicate_definition of string + | Invalid_operation of { op: string; typ: string } + | Arity_mismatch of { expected: int; found: int } + | Not_a_function of string + | Field_not_found of { struct_name: string; field: string } + | Cannot_infer_type + | Recursive_type + + (* Obliviousness errors *) + | Secret_dependent_branch + | Secret_array_index of string + | Secret_loop_bound + | Non_oblivious_operation of string + | Information_leak of { from_label: string; to_label: string } + + (* Other errors *) + | Internal_error of string + [@@deriving show] + +type diagnostic = { + severity: severity; + kind: error_kind; + loc: Location.t; + message: string; + suggestion: string option; + related: (Location.t * string) list; +} [@@deriving show] + +let make_error kind loc message = { + severity = Error; + kind; + loc; + message; + suggestion = None; + related = []; +} + +let make_warning kind loc message = { + severity = Warning; + kind; + loc; + message; + suggestion = None; + related = []; +} + +let with_suggestion suggestion diag = + { diag with suggestion = Some suggestion } + +let with_related related diag = + { diag with related } + +(** Diagnostics accumulator *) +type diagnostics = { + mutable errors: diagnostic list; + mutable warnings: diagnostic list; +} + +let create_diagnostics () = { + errors = []; + warnings = []; +} + +let report diags diag = + match diag.severity with + | Error -> diags.errors <- diag :: diags.errors + | Warning | Note -> diags.warnings <- diag :: diags.warnings + +let has_errors diags = diags.errors <> [] + +let get_errors diags = List.rev 
diags.errors +let get_warnings diags = List.rev diags.warnings + +(** Pretty printing *) +let severity_to_string = function + | Error -> "error" + | Warning -> "warning" + | Note -> "note" + +let format_diagnostic diag = + let sev = severity_to_string diag.severity in + let loc = Location.to_string diag.loc in + let main = Printf.sprintf "%s: %s: %s" loc sev diag.message in + let suggestion = match diag.suggestion with + | Some s -> Printf.sprintf "\n suggestion: %s" s + | None -> "" + in + let related = diag.related + |> List.map (fun (loc, msg) -> + Printf.sprintf "\n %s: note: %s" (Location.to_string loc) msg) + |> String.concat "" + in + main ^ suggestion ^ related + +let print_diagnostics diags = + List.iter (fun d -> prerr_endline (format_diagnostic d)) (get_errors diags); + List.iter (fun d -> prerr_endline (format_diagnostic d)) (get_warnings diags) + +(** Convenience functions for common errors *) +let type_mismatch ~expected ~found loc = + make_error + (Type_mismatch { expected; found }) + loc + (Printf.sprintf "type mismatch: expected `%s`, found `%s`" expected found) + +let unknown_identifier name loc = + make_error + (Unknown_identifier name) + loc + (Printf.sprintf "unknown identifier `%s`" name) + +let unknown_type name loc = + make_error + (Unknown_type name) + loc + (Printf.sprintf "unknown type `%s`" name) + +let secret_branch loc = + make_error + Secret_dependent_branch + loc + "branch condition depends on secret data" + |> with_suggestion "use cmov() or oblivious selection instead" + +let secret_index array_name loc = + make_error + (Secret_array_index array_name) + loc + (Printf.sprintf "array `%s` indexed with secret value" array_name) + |> with_suggestion "use oarray with oread()/owrite() for oblivious access" + +let secret_loop_bound loc = + make_error + Secret_loop_bound + loc + "loop bound depends on secret data" + |> with_suggestion "use fixed iteration count or oblivious loop" + +let information_leak ~from_label ~to_label loc = + 
make_error + (Information_leak { from_label; to_label }) + loc + (Printf.sprintf "information flow from @%s to @%s" from_label to_label) diff --git a/obli-transpiler-framework/frontend/lib/lexer.mll b/obli-transpiler-framework/frontend/lib/lexer.mll new file mode 100644 index 0000000..8302d25 --- /dev/null +++ b/obli-transpiler-framework/frontend/lib/lexer.mll @@ -0,0 +1,155 @@ +(* SPDX-License-Identifier: MIT OR Palimpsest-0.8 *) +(* Copyright (c) 2024 Hyperpolymath *) + +{ + open Parser + + exception Lexer_error of string * Lexing.position + + let keywords = Hashtbl.create 50 + let () = List.iter (fun (kw, tok) -> Hashtbl.add keywords kw tok) [ + (* Types *) + ("int", INT_T); + ("uint", UINT_T); + ("bool", BOOL_T); + ("byte", BYTE_T); + ("unit", UNIT_T); + ("array", ARRAY_T); + ("oarray", OARRAY_T); + ("ref", REF_T); + + (* Security labels *) + ("low", LOW); + ("high", HIGH); + + (* Keywords *) + ("fn", FN); + ("let", LET); + ("mut", MUT); + ("if", IF); + ("else", ELSE); + ("while", WHILE); + ("for", FOR); + ("in", IN); + ("return", RETURN); + ("break", BREAK); + ("continue", CONTINUE); + ("struct", STRUCT); + ("const", CONST); + ("extern", EXTERN); + ("import", IMPORT); + ("true", TRUE); + ("false", FALSE); + ("and", AND); + ("or", OR); + ("not", NOT); + + (* ORAM operations *) + ("oread", OREAD); + ("owrite", OWRITE); + ("cmov", CMOV); + ] + + let newline lexbuf = + let pos = lexbuf.Lexing.lex_curr_p in + lexbuf.Lexing.lex_curr_p <- { pos with + Lexing.pos_lnum = pos.Lexing.pos_lnum + 1; + Lexing.pos_bol = pos.Lexing.pos_cnum; + } +} + +let digit = ['0'-'9'] +let hex_digit = ['0'-'9' 'a'-'f' 'A'-'F'] +let alpha = ['a'-'z' 'A'-'Z'] +let ident_start = alpha | '_' +let ident_char = alpha | digit | '_' + +let integer = digit+ +let hex_integer = "0x" hex_digit+ +let identifier = ident_start ident_char* + +let whitespace = [' ' '\t']+ +let newline = '\r'? 
'\n' + +rule token = parse + | whitespace { token lexbuf } + | newline { newline lexbuf; token lexbuf } + + (* Comments *) + | "//" [^ '\n']* { token lexbuf } + | "/*" { block_comment lexbuf; token lexbuf } + + (* Delimiters *) + | '(' { LPAREN } + | ')' { RPAREN } + | '{' { LBRACE } + | '}' { RBRACE } + | '[' { LBRACK } + | ']' { RBRACK } + | '<' { LT } + | '>' { GT } + | ',' { COMMA } + | ';' { SEMI } + | ':' { COLON } + | '.' { DOT } + | ".." { DOTDOT } + | "->" { ARROW } + | "=>" { FAT_ARROW } + | '@' { AT } + + (* Operators *) + | '+' { PLUS } + | '-' { MINUS } + | '*' { STAR } + | '/' { SLASH } + | '%' { PERCENT } + | '=' { EQ } + | "==" { EQEQ } + | "!=" { NEQ } + | "<=" { LE } + | ">=" { GE } + | "<<" { SHL } + | ">>" { SHR } + | '&' { AMP } + | '|' { PIPE } + | '^' { CARET } + | '~' { TILDE } + | '!' { BANG } + | "&&" { AMPAMP } + | "||" { PIPEPIPE } + + (* Literals *) + | integer as n { INT_LIT (Int64.of_string n) } + | hex_integer as n { INT_LIT (Int64.of_string n) } + | "0b" (['0' '1']+ as n) { INT_LIT (Int64.of_string ("0b" ^ n)) } + + (* Byte literals *) + | '\'' ([^ '\\' '\''] as c) '\'' { BYTE_LIT c } + | "'\\" (['n' 't' 'r' '\\' '\''] as c) '\'' { + let c' = match c with + | 'n' -> '\n' + | 't' -> '\t' + | 'r' -> '\r' + | '\\' -> '\\' + | '\'' -> '\'' + | _ -> assert false + in BYTE_LIT c' + } + + (* Identifiers and keywords *) + | identifier as id { + try Hashtbl.find keywords id + with Not_found -> IDENT id + } + + | eof { EOF } + + | _ as c { + raise (Lexer_error (Printf.sprintf "Unexpected character: %c" c, lexbuf.Lexing.lex_curr_p)) + } + +and block_comment = parse + | "*/" { () } + | newline { newline lexbuf; block_comment lexbuf } + | _ { block_comment lexbuf } + | eof { raise (Lexer_error ("Unterminated block comment", lexbuf.Lexing.lex_curr_p)) } diff --git a/obli-transpiler-framework/frontend/lib/location.ml b/obli-transpiler-framework/frontend/lib/location.ml new file mode 100644 index 0000000..3e231db --- /dev/null +++ 
b/obli-transpiler-framework/frontend/lib/location.ml @@ -0,0 +1,58 @@ +(* SPDX-License-Identifier: MIT OR Palimpsest-0.8 *) +(* Copyright (c) 2024 Hyperpolymath *) + +(** Source location tracking *) + +type position = { + line: int; + column: int; + offset: int; +} [@@deriving show, yojson] + +type t = { + start_pos: position; + end_pos: position; + filename: string; +} [@@deriving show, yojson] + +let dummy = { + start_pos = { line = 0; column = 0; offset = 0 }; + end_pos = { line = 0; column = 0; offset = 0 }; + filename = ""; +} + +let make ~filename ~start_line ~start_col ~end_line ~end_col = { + start_pos = { line = start_line; column = start_col; offset = 0 }; + end_pos = { line = end_line; column = end_col; offset = 0 }; + filename; +} + +let from_lexbuf filename lexbuf = + let open Lexing in + let start_p = lexbuf.lex_start_p in + let end_p = lexbuf.lex_curr_p in + { + start_pos = { + line = start_p.pos_lnum; + column = start_p.pos_cnum - start_p.pos_bol; + offset = start_p.pos_cnum; + }; + end_pos = { + line = end_p.pos_lnum; + column = end_p.pos_cnum - end_p.pos_bol; + offset = end_p.pos_cnum; + }; + filename; + } + +let merge loc1 loc2 = { + start_pos = loc1.start_pos; + end_pos = loc2.end_pos; + filename = loc1.filename; +} + +let to_string loc = + Printf.sprintf "%s:%d:%d-%d:%d" + loc.filename + loc.start_pos.line loc.start_pos.column + loc.end_pos.line loc.end_pos.column diff --git a/obli-transpiler-framework/frontend/lib/oblicheck.ml b/obli-transpiler-framework/frontend/lib/oblicheck.ml new file mode 100644 index 0000000..71eefd7 --- /dev/null +++ b/obli-transpiler-framework/frontend/lib/oblicheck.ml @@ -0,0 +1,249 @@ +(* SPDX-License-Identifier: MIT OR Palimpsest-0.8 *) +(* Copyright (c) 2024 Hyperpolymath *) + +(** Obliviousness checking pass for Oblibeny + + This pass verifies that programs do not leak secret information + through their access patterns. It enforces: + + 1. No branching on secret values (use cmov instead) + 2. 
No array indexing with secret indices (use oarray with oread/owrite) + 3. No secret-dependent loop bounds (use fixed iteration) + 4. Information flow constraints (high cannot flow to low) +*) + +open Ast +open Errors + +(** Security context tracking *) +type context = { + in_secret_branch: bool; (** Inside a branch dependent on secrets *) + branch_security: security_label; (** Security of current branch condition *) + loop_depth: int; (** Current loop nesting depth *) + oblivious_function: bool; (** Inside @oblivious function *) +} + +let initial_context = { + in_secret_branch = false; + branch_security = Low; + loop_depth = 0; + oblivious_function = false; +} +(** Enter a branch guarded by a condition labelled [security]. Callers pass every branch condition here, so the context is flagged secret-dependent only when the condition is High or an enclosing branch already was. *) +let enter_secret_branch ctx security = { + ctx with + in_secret_branch = ctx.in_secret_branch || security = High; + branch_security = security_join ctx.branch_security security; +} + +let enter_loop ctx = { + ctx with loop_depth = ctx.loop_depth + 1; +} + +let enter_oblivious_function ctx = { + ctx with oblivious_function = true; +} + +(** State for obliviousness checker *) +type state = { + diags: diagnostics; + mutable violations: int; +} + +let create_state () = { + diags = create_diagnostics (); + violations = 0; +} + +(** Get security label of expression (requires prior type checking) *) +let get_security expr = + match expr.expr_type with + | Some at -> at.security + | None -> Low (* Default if not type-checked yet *) + +(** Check if type is oblivious array *) +let is_oarray typ = + match typ with + | TOArray _ -> true + | _ -> false + +(** Check expression for obliviousness violations *) +let rec check_expr state ctx expr = + match expr.expr_desc with + | ELiteral _ | EVar _ -> () + + | EBinop (_, lhs, rhs) -> + check_expr state ctx lhs; + check_expr state ctx rhs + + | EUnop (_, operand) -> + check_expr state ctx operand + + | ECall (func, args) -> + check_expr state ctx func; + List.iter (check_expr state ctx) args + + | EIndex (arr, idx) -> + check_expr state ctx arr; + check_expr state ctx idx; + (* Check for secret indexing into
non-oblivious array *) + let idx_security = get_security idx in + let arr_type = match arr.expr_type with + | Some at -> at.typ + | None -> TPrim TUnit + in + if idx_security = High && not (is_oarray arr_type) then begin (* name the array in the diagnostic when it is a plain variable; otherwise keep the generic label *) + report state.diags (secret_index (match arr.expr_desc with EVar name -> name | _ -> "array") expr.expr_loc); + state.violations <- state.violations + 1 + end + + | EOramRead (arr, idx) -> + check_expr state ctx arr; + check_expr state ctx idx + (* ORAM operations are safe by construction *) + + | EField (obj, _) -> + check_expr state ctx obj + + | EIf (cond, then_expr, else_expr) -> + check_expr state ctx cond; + let cond_security = get_security cond in + if cond_security = High && ctx.oblivious_function then begin + (* In oblivious function, secret branches are violations *) + report state.diags (secret_branch cond.expr_loc); + state.violations <- state.violations + 1 + end; + let new_ctx = enter_secret_branch ctx cond_security in + check_expr state new_ctx then_expr; + check_expr state new_ctx else_expr + + | EBlock (stmts, expr_opt) -> + List.iter (check_stmt state ctx) stmts; + Option.iter (check_expr state ctx) expr_opt + + | ELambda (_, body) -> + check_expr state ctx body + + | ETuple exprs -> + List.iter (check_expr state ctx) exprs + + | EStruct (_, fields) -> + List.iter (fun (_, e) -> check_expr state ctx e) fields + + | ECmov (cond, then_val, else_val) -> + (* cmov is safe for oblivious selection *) + check_expr state ctx cond; + check_expr state ctx then_val; + check_expr state ctx else_val + +(** Check statement for obliviousness violations *) +and check_stmt state ctx stmt = + match stmt.stmt_desc with + | SLet (_, _, init) -> + check_expr state ctx init + + | SAssign (lhs, rhs) -> + check_expr state ctx lhs; + check_expr state ctx rhs; + (* Check information flow: cannot assign high to low *) + let lhs_security = get_security lhs in + let rhs_security = get_security rhs in + if rhs_security = High && lhs_security = Low && ctx.oblivious_function then begin + report state.diags
(information_leak ~from_label:"high" ~to_label:"low" stmt.stmt_loc); + state.violations <- state.violations + 1 + end + + | SOramWrite (arr, idx, value) -> + check_expr state ctx arr; + check_expr state ctx idx; + check_expr state ctx value + (* ORAM operations are safe *) + + | SExpr e -> + check_expr state ctx e + + | SIf (cond, then_stmts, else_stmts) -> + check_expr state ctx cond; + let cond_security = get_security cond in + if cond_security = High && ctx.oblivious_function then begin + report state.diags (secret_branch cond.expr_loc); + state.violations <- state.violations + 1 + end; + let new_ctx = enter_secret_branch ctx cond_security in + List.iter (check_stmt state new_ctx) then_stmts; + List.iter (check_stmt state new_ctx) else_stmts + + | SWhile (cond, body) -> + check_expr state ctx cond; + let cond_security = get_security cond in + if cond_security = High && ctx.oblivious_function then begin + report state.diags (secret_loop_bound cond.expr_loc); + state.violations <- state.violations + 1 + end; + let new_ctx = enter_loop (enter_secret_branch ctx cond_security) in + List.iter (check_stmt state new_ctx) body + + | SFor (_, start_expr, end_expr, body) -> + check_expr state ctx start_expr; + check_expr state ctx end_expr; + let start_security = get_security start_expr in + let end_security = get_security end_expr in + let bound_security = security_join start_security end_security in + if bound_security = High && ctx.oblivious_function then begin + report state.diags (secret_loop_bound start_expr.expr_loc); + state.violations <- state.violations + 1 + end; + let new_ctx = enter_loop ctx in + List.iter (check_stmt state new_ctx) body + + | SReturn expr_opt -> + Option.iter (check_expr state ctx) expr_opt + + | SBreak | SContinue -> () + +(** Check declaration *) +let check_decl state decl = + match decl.decl_desc with + | DFunction { body; attributes; _ } -> + let is_oblivious = List.exists (fun a -> a = AOblivious || a = AConstantTime) attributes in + let 
ctx = if is_oblivious then enter_oblivious_function initial_context else initial_context in + List.iter (check_stmt state ctx) body + + | DStruct _ -> () + + | DConst { value; _ } -> + check_expr state initial_context value + + | DExtern _ | DImport _ -> () + +(** Check a complete program for obliviousness *) +let check_program program = + let state = create_state () in + List.iter (check_decl state) program.declarations; + (state.diags, state.violations) + +(** Summary of obliviousness analysis *) +type analysis_result = { + total_violations: int; + secret_branches: int; + secret_indices: int; + secret_loops: int; + info_leaks: int; +} + +let analyze_violations diags = + let errs = get_errors diags in + let count kind = List.length (List.filter (fun d -> + match d.kind with k when k = kind -> true | _ -> false + ) errs) in + { + total_violations = List.length errs; + secret_branches = count Secret_dependent_branch; + secret_indices = List.length (List.filter (fun d -> + match d.kind with Secret_array_index _ -> true | _ -> false + ) errs); + secret_loops = count Secret_loop_bound; + info_leaks = List.length (List.filter (fun d -> + match d.kind with Information_leak _ -> true | _ -> false + ) errs); + } diff --git a/obli-transpiler-framework/frontend/lib/parser.mly b/obli-transpiler-framework/frontend/lib/parser.mly new file mode 100644 index 0000000..a803a76 --- /dev/null +++ b/obli-transpiler-framework/frontend/lib/parser.mly @@ -0,0 +1,461 @@ +/* SPDX-License-Identifier: MIT OR Palimpsest-0.8 */ +/* Copyright (c) 2024 Hyperpolymath */ + +/* Oblibeny Language Parser */ + +%{ + open Ast + open Location + + let loc () = + let startpos = Parsing.symbol_start_pos () in + let endpos = Parsing.symbol_end_pos () in + { + start_pos = { line = startpos.Lexing.pos_lnum; + column = startpos.Lexing.pos_cnum - startpos.Lexing.pos_bol; + offset = startpos.Lexing.pos_cnum }; + end_pos = { line = endpos.Lexing.pos_lnum; + column = endpos.Lexing.pos_cnum - endpos.Lexing.pos_bol; 
+ offset = endpos.Lexing.pos_cnum }; + filename = startpos.Lexing.pos_fname; + } +%} + +/* Tokens (payload types match the values lexer.mll constructs) */ +%token <int64> INT_LIT +%token <char> BYTE_LIT +%token <string> IDENT +%token TRUE FALSE + +/* Types */ +%token INT_T UINT_T BOOL_T BYTE_T UNIT_T ARRAY_T OARRAY_T REF_T + +/* Security labels */ +%token LOW HIGH + +/* Keywords */ +%token FN LET MUT IF ELSE WHILE FOR IN RETURN BREAK CONTINUE +%token STRUCT CONST EXTERN IMPORT +%token AND OR NOT +%token OREAD OWRITE CMOV + +/* Delimiters */ +%token LPAREN RPAREN LBRACE RBRACE LBRACK RBRACK +%token LT GT COMMA SEMI COLON DOT DOTDOT ARROW FAT_ARROW AT + +/* Operators */ +%token PLUS MINUS STAR SLASH PERCENT +%token EQ EQEQ NEQ LE GE +%token SHL SHR AMP PIPE CARET TILDE BANG +%token AMPAMP PIPEPIPE + +%token EOF + +/* Precedence (lowest to highest) */ +%left PIPEPIPE OR +%left AMPAMP AND +%left PIPE +%left CARET +%left AMP +%left EQEQ NEQ +%left LT LE GT GE +%left SHL SHR +%left PLUS MINUS +%left STAR SLASH PERCENT +%right BANG NOT TILDE UMINUS +%left DOT LBRACK +/* NOTE(review): start symbol type assumed to be Ast.program (record with module_name/declarations) - confirm against ast.ml */ +%start program +%type <Ast.program> program + +%% + +program: + | module_header declarations EOF + { { module_name = $1; declarations = $2 } } + | declarations EOF + { { module_name = None; declarations = $1 } } +; + +module_header: + | IMPORT path SEMI { Some (String.concat "."
$2) } +; + +path: + | IDENT { [$1] } + | path DOT IDENT { $1 @ [$3] } +; + +declarations: + | /* empty */ { [] } + | declaration declarations { $1 :: $2 } +; + +declaration: + | function_decl { $1 } + | struct_decl { $1 } + | const_decl { $1 } + | extern_decl { $1 } + | import_decl { $1 } +; + +attributes: + | /* empty */ { [] } + | attribute attributes { $1 :: $2 } +; + +attribute: + | AT IDENT { + match $2 with + | "oblivious" -> AOblivious + | "inline" -> AInline + | "no_optimize" -> ANoOptimize + | "constant_time" -> AConstantTime + | "public" -> APublic + | name -> ACustom (name, None) + } + | AT IDENT LPAREN IDENT RPAREN { ACustom ($2, Some $4) } +; + +function_decl: + | attributes FN IDENT type_params LPAREN params RPAREN ARROW annotated_type block + { mk_decl (loc ()) (DFunction { + name = $3; + type_params = $4; + params = $6; + return_type = $9; + body = $10; + attributes = $1; + }) + } + | attributes FN IDENT type_params LPAREN params RPAREN block + { mk_decl (loc ()) (DFunction { + name = $3; + type_params = $4; + params = $6; + return_type = mk_atype (loc ()) Low (TPrim TUnit); + body = $8; + attributes = $1; + }) + } +; + +type_params: + | /* empty */ { [] } + | LT type_param_list GT { $2 } +; + +type_param_list: + | IDENT { [$1] } + | IDENT COMMA type_param_list { $1 :: $3 } +; + +params: + | /* empty */ { [] } + | param_list { $1 } +; + +param_list: + | param { [$1] } + | param COMMA param_list { $1 :: $3 } +; + +param: + | IDENT COLON annotated_type { ($1, $3) } +; + +struct_decl: + | attributes STRUCT IDENT type_params LBRACE struct_fields RBRACE + { mk_decl (loc ()) (DStruct { + name = $3; + type_params = $4; + fields = $6; + attributes = $1; + }) + } +; + +struct_fields: + | /* empty */ { [] } + | struct_field struct_fields { $1 :: $2 } +; + +struct_field: + | IDENT COLON annotated_type COMMA { ($1, $3) } + | IDENT COLON annotated_type { ($1, $3) } +; + +const_decl: + | CONST IDENT COLON annotated_type EQ expr SEMI + { mk_decl (loc ()) (DConst { 
+ name = $2; + typ = $4; + value = $6; + }) + } +; + +extern_decl: + | attributes EXTERN FN IDENT LPAREN params RPAREN ARROW annotated_type SEMI + { mk_decl (loc ()) (DExtern { + name = $4; + typ = mk_atype (loc ()) Low (TFun (List.map snd $6 |> List.map (fun at -> at.typ), $9.typ)); + attributes = $1; + }) + } +; + +import_decl: + | IMPORT path SEMI + { mk_decl (loc ()) (DImport $2) } +; + +annotated_type: + | security_label typ { mk_atype (loc ()) $1 $2 } + | typ { mk_atype (loc ()) Low $1 } +; + +security_label: + | AT LOW { Low } + | AT HIGH { High } +; + +typ: + | prim_type { TPrim $1 } + | ARRAY_T LT typ GT { TArray ($3, Low) } + | ARRAY_T LT typ COMMA security_label GT { TArray ($3, $5) } + | OARRAY_T LT typ GT { TOArray $3 } + | REF_T LT typ GT { TRef ($3, Low) } + | REF_T LT typ COMMA security_label GT { TRef ($3, $5) } + | LPAREN type_list RPAREN ARROW typ { TFun ($2, $5) } + | IDENT { TStruct $1 } + | IDENT LT type_args GT { TGeneric ($1, $4) } +; + +prim_type: + | INT_T { TInt None } + | INT_T LT INT_LIT GT { TInt (Some (Int64.to_int $3)) } + | UINT_T { TUint None } + | UINT_T LT INT_LIT GT { TUint (Some (Int64.to_int $3)) } + | BOOL_T { TBool } + | BYTE_T { TByte } + | UNIT_T { TUnit } +; + +type_list: + | /* empty */ { [] } + | typ { [$1] } + | typ COMMA type_list { $1 :: $3 } +; + +type_args: + | typ { [$1] } + | typ COMMA type_args { $1 :: $3 } +; + +block: + | LBRACE statements RBRACE { $2 } +; + +statements: + | /* empty */ { [] } + | statement statements { $1 :: $2 } +; + +statement: + | LET pattern type_annotation EQ expr SEMI + { mk_stmt (loc ()) (SLet ($2, $3, $5)) } + | lvalue EQ expr SEMI + { mk_stmt (loc ()) (SAssign ($1, $3)) } + | OWRITE LPAREN expr COMMA expr COMMA expr RPAREN SEMI + { mk_stmt (loc ()) (SOramWrite ($3, $5, $7)) } + | expr SEMI + { mk_stmt (loc ()) (SExpr $1) } + | IF expr block + { mk_stmt (loc ()) (SIf ($2, $3, [])) } + | IF expr block ELSE block + { mk_stmt (loc ()) (SIf ($2, $3, $5)) } + | IF expr block ELSE statement 
+ { mk_stmt (loc ()) (SIf ($2, $3, [$5])) } + | WHILE expr block + { mk_stmt (loc ()) (SWhile ($2, $3)) } + | FOR IDENT IN expr DOTDOT expr block + { mk_stmt (loc ()) (SFor ($2, $4, $6, $7)) } + | RETURN SEMI + { mk_stmt (loc ()) (SReturn None) } + | RETURN expr SEMI + { mk_stmt (loc ()) (SReturn (Some $2)) } + | BREAK SEMI + { mk_stmt (loc ()) SBreak } + | CONTINUE SEMI + { mk_stmt (loc ()) SContinue } +; + +type_annotation: + | /* empty */ { None } + | COLON annotated_type { Some $2 } +; + +pattern: + | IDENT { PVar $1 } + | LPAREN pattern_list RPAREN { PTuple $2 } +; + +pattern_list: + | pattern { [$1] } + | pattern COMMA pattern_list { $1 :: $3 } +; + +lvalue: + | IDENT { mk_expr (loc ()) (EVar $1) } + | lvalue DOT IDENT { mk_expr (loc ()) (EField ($1, $3)) } + | lvalue LBRACK expr RBRACK { mk_expr (loc ()) (EIndex ($1, $3)) } +; + +expr: + | expr_or { $1 } +; + +expr_or: + | expr_and { $1 } + | expr_or PIPEPIPE expr_and { mk_expr (loc ()) (EBinop (Or, $1, $3)) } + | expr_or OR expr_and { mk_expr (loc ()) (EBinop (Or, $1, $3)) } +; + +expr_and: + | expr_bitor { $1 } + | expr_and AMPAMP expr_bitor { mk_expr (loc ()) (EBinop (And, $1, $3)) } + | expr_and AND expr_bitor { mk_expr (loc ()) (EBinop (And, $1, $3)) } +; + +expr_bitor: + | expr_bitxor { $1 } + | expr_bitor PIPE expr_bitxor { mk_expr (loc ()) (EBinop (BitOr, $1, $3)) } +; + +expr_bitxor: + | expr_bitand { $1 } + | expr_bitxor CARET expr_bitand { mk_expr (loc ()) (EBinop (BitXor, $1, $3)) } +; + +expr_bitand: + | expr_eq { $1 } + | expr_bitand AMP expr_eq { mk_expr (loc ()) (EBinop (BitAnd, $1, $3)) } +; + +expr_eq: + | expr_cmp { $1 } + | expr_eq EQEQ expr_cmp { mk_expr (loc ()) (EBinop (Eq, $1, $3)) } + | expr_eq NEQ expr_cmp { mk_expr (loc ()) (EBinop (Neq, $1, $3)) } +; + +expr_cmp: + | expr_shift { $1 } + | expr_cmp LT expr_shift { mk_expr (loc ()) (EBinop (Lt, $1, $3)) } + | expr_cmp LE expr_shift { mk_expr (loc ()) (EBinop (Le, $1, $3)) } + | expr_cmp GT expr_shift { mk_expr (loc ()) (EBinop (Gt, 
$1, $3)) } + | expr_cmp GE expr_shift { mk_expr (loc ()) (EBinop (Ge, $1, $3)) } +; + +expr_shift: + | expr_add { $1 } + | expr_shift SHL expr_add { mk_expr (loc ()) (EBinop (Shl, $1, $3)) } + | expr_shift SHR expr_add { mk_expr (loc ()) (EBinop (Shr, $1, $3)) } +; + +expr_add: + | expr_mul { $1 } + | expr_add PLUS expr_mul { mk_expr (loc ()) (EBinop (Add, $1, $3)) } + | expr_add MINUS expr_mul { mk_expr (loc ()) (EBinop (Sub, $1, $3)) } +; + +expr_mul: + | expr_unary { $1 } + | expr_mul STAR expr_unary { mk_expr (loc ()) (EBinop (Mul, $1, $3)) } + | expr_mul SLASH expr_unary { mk_expr (loc ()) (EBinop (Div, $1, $3)) } + | expr_mul PERCENT expr_unary { mk_expr (loc ()) (EBinop (Mod, $1, $3)) } +; + +expr_unary: + | expr_postfix { $1 } + | MINUS expr_unary %prec UMINUS { mk_expr (loc ()) (EUnop (Neg, $2)) } + | BANG expr_unary { mk_expr (loc ()) (EUnop (Not, $2)) } + | NOT expr_unary { mk_expr (loc ()) (EUnop (Not, $2)) } + | TILDE expr_unary { mk_expr (loc ()) (EUnop (BitNot, $2)) } +; + +expr_postfix: + | expr_primary { $1 } + | expr_postfix DOT IDENT { mk_expr (loc ()) (EField ($1, $3)) } + | expr_postfix LBRACK expr RBRACK { mk_expr (loc ()) (EIndex ($1, $3)) } + | expr_postfix LPAREN args RPAREN { mk_expr (loc ()) (ECall ($1, $3)) } +; + +expr_primary: + | literal { mk_expr (loc ()) (ELiteral $1) } + | IDENT { mk_expr (loc ()) (EVar $1) } + | LPAREN expr RPAREN { $2 } + | LPAREN expr COMMA expr_list RPAREN { mk_expr (loc ()) (ETuple ($2 :: $4)) } + | LBRACE statements expr RBRACE { mk_expr (loc ()) (EBlock ($2, Some $3)) } + | LBRACE statements RBRACE { mk_expr (loc ()) (EBlock ($2, None)) } + | IF expr LBRACE expr RBRACE ELSE LBRACE expr RBRACE + { mk_expr (loc ()) (EIf ($2, $4, $8)) } + | OREAD LPAREN expr COMMA expr RPAREN { mk_expr (loc ()) (EOramRead ($3, $5)) } + | CMOV LPAREN expr COMMA expr COMMA expr RPAREN { mk_expr (loc ()) (ECmov ($3, $5, $7)) } + | IDENT LBRACE field_inits RBRACE { mk_expr (loc ()) (EStruct ($1, $3)) } + | FN LPAREN lambda_params 
RPAREN FAT_ARROW expr + { mk_expr (loc ()) (ELambda ($3, $6)) } +; + +literal: + | INT_LIT { LInt $1 } + | TRUE { LBool true } + | FALSE { LBool false } + | BYTE_LIT { LByte $1 } + | LPAREN RPAREN { LUnit } +; + +args: + | /* empty */ { [] } + | arg_list { $1 } +; + +arg_list: + | expr { [$1] } + | expr COMMA arg_list { $1 :: $3 } +; + +expr_list: + | expr { [$1] } + | expr COMMA expr_list { $1 :: $3 } +; + +field_inits: + | /* empty */ { [] } + | field_init_list { $1 } +; + +field_init_list: + | field_init { [$1] } + | field_init COMMA field_init_list { $1 :: $3 } +; + +field_init: + | IDENT COLON expr { ($1, $3) } +; + +lambda_params: + | /* empty */ { [] } + | lambda_param_list { $1 } +; + +lambda_param_list: + | lambda_param { [$1] } + | lambda_param COMMA lambda_param_list { $1 :: $3 } +; + +lambda_param: + | IDENT COLON annotated_type { ($1, $3) } +; + +%% diff --git a/obli-transpiler-framework/frontend/lib/typecheck.ml b/obli-transpiler-framework/frontend/lib/typecheck.ml new file mode 100644 index 0000000..f62178a --- /dev/null +++ b/obli-transpiler-framework/frontend/lib/typecheck.ml @@ -0,0 +1,528 @@ +(* SPDX-License-Identifier: MIT OR Palimpsest-0.8 *) +(* Copyright (c) 2024 Hyperpolymath *) + +(** Type checking pass for Oblibeny *) + +open Ast +open Errors + +(** Type environment *) +module Env = struct + type binding = + | VarBinding of annotated_type + | FunBinding of { params: annotated_type list; ret: annotated_type } + | TypeBinding of typ + | StructBinding of { fields: (string * annotated_type) list } + + type t = { + bindings: (string, binding) Hashtbl.t; + parent: t option; + in_function: annotated_type option; (* Return type if inside function *) + } + + let create ?parent () = { + bindings = Hashtbl.create 16; + parent; + in_function = None; + } + + let enter_function ret_type env = { + bindings = Hashtbl.create 16; + parent = Some env; + in_function = Some ret_type; + } + + let enter_scope env = { + bindings = Hashtbl.create 16; + parent = 
Some env; + in_function = env.in_function; + } + + let rec lookup name env = + match Hashtbl.find_opt env.bindings name with + | Some b -> Some b + | None -> + match env.parent with + | Some p -> lookup name p + | None -> None + + let add name binding env = + Hashtbl.replace env.bindings name binding + + let add_var name typ env = + add name (VarBinding typ) env + + let add_fun name params ret env = + add name (FunBinding { params; ret }) env + + let add_struct name fields env = + add name (StructBinding { fields }) env + + let return_type env = env.in_function +end + +(** Type representation utilities *) +let rec type_to_string = function + | TPrim (TInt None) -> "int" + | TPrim (TInt (Some n)) -> Printf.sprintf "int<%d>" n + | TPrim (TUint None) -> "uint" + | TPrim (TUint (Some n)) -> Printf.sprintf "uint<%d>" n + | TPrim TBool -> "bool" + | TPrim TByte -> "byte" + | TPrim TUnit -> "unit" + | TArray (t, _) -> Printf.sprintf "array<%s>" (type_to_string t) + | TOArray t -> Printf.sprintf "oarray<%s>" (type_to_string t) + | TRef (t, _) -> Printf.sprintf "ref<%s>" (type_to_string t) + | TFun (args, ret) -> + let args_str = String.concat ", " (List.map type_to_string args) in + Printf.sprintf "(%s) -> %s" args_str (type_to_string ret) + | TStruct name -> name + | TGeneric (name, args) -> + let args_str = String.concat ", " (List.map type_to_string args) in + Printf.sprintf "%s<%s>" name args_str + | TVar name -> "'" ^ name + +let security_to_string = function + | Low -> "low" + | High -> "high" + +let annotated_type_to_string at = + Printf.sprintf "@%s %s" (security_to_string at.security) (type_to_string at.typ) + +(** Type equality *) +let rec types_equal t1 t2 = + match t1, t2 with + | TPrim p1, TPrim p2 -> p1 = p2 + | TArray (e1, _), TArray (e2, _) -> types_equal e1 e2 + | TOArray e1, TOArray e2 -> types_equal e1 e2 + | TRef (e1, _), TRef (e2, _) -> types_equal e1 e2 + | TFun (a1, r1), TFun (a2, r2) -> + List.length a1 = List.length a2 && + List.for_all2 
types_equal a1 a2 && + types_equal r1 r2 + | TStruct n1, TStruct n2 -> n1 = n2 + | TGeneric (n1, a1), TGeneric (n2, a2) -> + n1 = n2 && List.length a1 = List.length a2 && + List.for_all2 types_equal a1 a2 + | TVar n1, TVar n2 -> n1 = n2 + | _ -> false + +(** Security label lattice *) +let security_join s1 s2 = + match s1, s2 with + | High, _ | _, High -> High + | Low, Low -> Low + +let security_leq s1 s2 = + match s1, s2 with + | Low, _ -> true + | High, High -> true + | High, Low -> false + +(** Type checker state *) +type state = { + diags: diagnostics; + env: Env.t; +} + +let create_state () = { + diags = create_diagnostics (); + env = Env.create (); +} + +(** Check binary operator types: reports a diagnostic on operand mismatch and returns the result type of [op] *) +let check_binop state op lhs_type rhs_type loc = + (* int/uint of any declared width and byte are numeric; width agreement is still enforced by the types_equal checks below *) + let is_numeric t = match t with TPrim (TInt _ | TUint _ | TByte) -> true | _ -> false in + let is_bool t = types_equal t (TPrim TBool) in + let is_int t = match t with TPrim (TInt _ | TUint _) -> true | _ -> false in + + match op with + | Add | Sub | Mul | Div | Mod -> + if not (is_numeric lhs_type && types_equal lhs_type rhs_type) then + report state.diags (make_error + (Invalid_operation { op = show_binop op; typ = type_to_string lhs_type }) + loc "arithmetic operation requires matching numeric types"); + lhs_type + + | Eq | Neq -> + if not (types_equal lhs_type rhs_type) then + report state.diags (type_mismatch + ~expected:(type_to_string lhs_type) + ~found:(type_to_string rhs_type) + loc); + TPrim TBool + + | Lt | Le | Gt | Ge -> + if not (is_numeric lhs_type && types_equal lhs_type rhs_type) then + report state.diags (make_error + (Invalid_operation { op = show_binop op; typ = type_to_string lhs_type }) + loc "comparison requires matching numeric types"); + TPrim TBool + + | And | Or -> + if not (is_bool lhs_type && is_bool rhs_type) then + report state.diags (make_error + (Invalid_operation { op = show_binop op; typ = type_to_string lhs_type }) + loc "logical
operation requires boolean operands"); + TPrim TBool + + | BitAnd | BitOr | BitXor -> + if not (is_int lhs_type && types_equal lhs_type rhs_type) then + report state.diags (make_error + (Invalid_operation { op = show_binop op; typ = type_to_string lhs_type }) + loc "bitwise operation requires matching integer types"); + lhs_type + + | Shl | Shr -> + if not (is_int lhs_type && is_int rhs_type) then + report state.diags (make_error + (Invalid_operation { op = show_binop op; typ = type_to_string lhs_type }) + loc "shift operation requires integer operands"); + lhs_type + +(** Check unary operator types *) +let check_unop state op operand_type loc = + match op with + | Neg -> + (match operand_type with + | TPrim (TInt _ | TUint _) -> operand_type + | _ -> + report state.diags (make_error + (Invalid_operation { op = "negation"; typ = type_to_string operand_type }) + loc "negation requires numeric type"); + operand_type) + + | Not -> + if not (types_equal operand_type (TPrim TBool)) then + report state.diags (make_error + (Invalid_operation { op = "not"; typ = type_to_string operand_type }) + loc "logical not requires boolean operand"); + TPrim TBool + + | BitNot -> + (match operand_type with + | TPrim (TInt _ | TUint _) -> operand_type + | _ -> + report state.diags (make_error + (Invalid_operation { op = "bitwise not"; typ = type_to_string operand_type }) + loc "bitwise not requires integer type"); + operand_type) + +(** Type check expression *) +let rec check_expr state env expr = + let (typ, security) = check_expr_desc state env expr.expr_desc expr.expr_loc in + let atype = mk_atype expr.expr_loc security typ in + expr.expr_type <- Some atype; + atype + +and check_expr_desc state env desc loc = + match desc with + | ELiteral lit -> + let typ = match lit with + | LInt _ -> TPrim (TInt None) + | LUint _ -> TPrim (TUint None) + | LBool _ -> TPrim TBool + | LByte _ -> TPrim TByte + | LUnit -> TPrim TUnit + in + (typ, Low) + + | EVar name -> + (match Env.lookup name env 
with + | Some (Env.VarBinding at) -> (at.typ, at.security) + | Some (Env.FunBinding { params; ret }) -> + (TFun (List.map (fun at -> at.typ) params, ret.typ), Low) + | _ -> + report state.diags (unknown_identifier name loc); + (TPrim TUnit, Low)) + + | EBinop (op, lhs, rhs) -> + let lhs_at = check_expr state env lhs in + let rhs_at = check_expr state env rhs in + let result_type = check_binop state op lhs_at.typ rhs_at.typ loc in + let result_security = security_join lhs_at.security rhs_at.security in + (result_type, result_security) + + | EUnop (op, operand) -> + let operand_at = check_expr state env operand in + let result_type = check_unop state op operand_at.typ loc in + (result_type, operand_at.security) + + | ECall (func, args) -> + let func_at = check_expr state env func in + (match func_at.typ with + | TFun (param_types, ret_type) -> + if List.length args <> List.length param_types then + report state.diags (make_error + (Arity_mismatch { expected = List.length param_types; found = List.length args }) + loc "wrong number of arguments"); + let arg_security = List.fold_left (fun acc arg -> + let at = check_expr state env arg in + security_join acc at.security + ) Low args in + (ret_type, arg_security) + | _ -> + report state.diags (make_error + (Not_a_function (type_to_string func_at.typ)) + loc "called expression is not a function"); + (TPrim TUnit, Low)) + + | EIndex (arr, idx) -> + let arr_at = check_expr state env arr in + let idx_at = check_expr state env idx in + let elem_type = match arr_at.typ with + | TArray (elem, _) -> elem + | TOArray elem -> elem + | _ -> + report state.diags (make_error + (Invalid_operation { op = "index"; typ = type_to_string arr_at.typ }) + loc "indexing requires array type"); + TPrim TUnit + in + (elem_type, security_join arr_at.security idx_at.security) + + | EOramRead (arr, idx) -> + let arr_at = check_expr state env arr in + let idx_at = check_expr state env idx in + let elem_type = match arr_at.typ with + | TOArray elem 
-> elem + | _ -> + report state.diags (make_error + (Invalid_operation { op = "oread"; typ = type_to_string arr_at.typ }) + loc "oread requires oarray type"); + TPrim TUnit + in + (* ORAM read result is high security because it's designed for secret indices *) + (elem_type, security_join High idx_at.security) + + | EField (obj, field) -> + let obj_at = check_expr state env obj in + (match obj_at.typ with + | TStruct name -> + (match Env.lookup name env with + | Some (Env.StructBinding { fields }) -> + (match List.assoc_opt field fields with + | Some ft -> (ft.typ, security_join obj_at.security ft.security) + | None -> + report state.diags (make_error + (Field_not_found { struct_name = name; field }) + loc (Printf.sprintf "field `%s` not found in struct `%s`" field name)); + (TPrim TUnit, Low)) + | _ -> + report state.diags (unknown_type name loc); + (TPrim TUnit, Low)) + | _ -> + report state.diags (make_error + (Invalid_operation { op = "field access"; typ = type_to_string obj_at.typ }) + loc "field access requires struct type"); + (TPrim TUnit, Low)) + + | EIf (cond, then_expr, else_expr) -> + let cond_at = check_expr state env cond in + if not (types_equal cond_at.typ (TPrim TBool)) then + report state.diags (type_mismatch ~expected:"bool" ~found:(type_to_string cond_at.typ) loc); + let then_at = check_expr state env then_expr in + let else_at = check_expr state env else_expr in + if not (types_equal then_at.typ else_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string then_at.typ) + ~found:(type_to_string else_at.typ) + else_expr.expr_loc); + let result_security = security_join cond_at.security (security_join then_at.security else_at.security) in + (then_at.typ, result_security) + + | EBlock (stmts, final_expr) -> + let block_env = Env.enter_scope env in + List.iter (check_stmt state block_env) stmts; + (match final_expr with + | Some e -> check_expr state block_env e |> (fun at -> (at.typ, at.security)) + | None -> (TPrim TUnit, Low)) + 
+ | ELambda (params, body) -> + let lambda_env = Env.enter_scope env in + List.iter (fun (name, at) -> Env.add_var name at lambda_env) params; + let body_at = check_expr state lambda_env body in + (TFun (List.map (fun (_, at) -> at.typ) params, body_at.typ), Low) + + | ETuple exprs -> + (* For simplicity, we don't have tuple types yet - treat as struct *) + let _ = List.map (check_expr state env) exprs in + (TPrim TUnit, Low) (* TODO: Add proper tuple types *) + + | EStruct (name, fields) -> + (match Env.lookup name env with + | Some (Env.StructBinding { fields = expected_fields }) -> + let field_security = List.fold_left (fun acc (fname, fexpr) -> + let fat = check_expr state env fexpr in + match List.assoc_opt fname expected_fields with + | Some expected_at -> + if not (types_equal fat.typ expected_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string expected_at.typ) + ~found:(type_to_string fat.typ) + fexpr.expr_loc); + security_join acc fat.security + | None -> + report state.diags (make_error + (Field_not_found { struct_name = name; field = fname }) + fexpr.expr_loc (Printf.sprintf "unknown field `%s`" fname)); + acc + ) Low fields in + (TStruct name, field_security) + | _ -> + report state.diags (unknown_type name loc); + (TPrim TUnit, Low)) + + | ECmov (cond, then_val, else_val) -> + let cond_at = check_expr state env cond in + if not (types_equal cond_at.typ (TPrim TBool)) then + report state.diags (type_mismatch ~expected:"bool" ~found:(type_to_string cond_at.typ) loc); + let then_at = check_expr state env then_val in + let else_at = check_expr state env else_val in + if not (types_equal then_at.typ else_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string then_at.typ) + ~found:(type_to_string else_at.typ) + else_val.expr_loc); + let result_security = security_join cond_at.security (security_join then_at.security else_at.security) in + (then_at.typ, result_security) + +(** Type check statement *) +and 
check_stmt state env stmt = + match stmt.stmt_desc with + | SLet (pattern, type_annot, init) -> + let init_at = check_expr state env init in + let bound_type = match type_annot with + | Some annot -> + if not (types_equal annot.typ init_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string annot.typ) + ~found:(type_to_string init_at.typ) + init.expr_loc); + annot + | None -> init_at + in + (match pattern with + | PVar name -> Env.add_var name bound_type env + | _ -> () (* TODO: Handle other patterns *)) + + | SAssign (lhs, rhs) -> + let lhs_at = check_expr state env lhs in + let rhs_at = check_expr state env rhs in + if not (types_equal lhs_at.typ rhs_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string lhs_at.typ) + ~found:(type_to_string rhs_at.typ) + rhs.expr_loc) + + | SOramWrite (arr, idx, value) -> + let arr_at = check_expr state env arr in + let _idx_at = check_expr state env idx in + let value_at = check_expr state env value in + (match arr_at.typ with + | TOArray elem_type -> + if not (types_equal elem_type value_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string elem_type) + ~found:(type_to_string value_at.typ) + value.expr_loc) + | _ -> + report state.diags (make_error + (Invalid_operation { op = "owrite"; typ = type_to_string arr_at.typ }) + stmt.stmt_loc "owrite requires oarray type")) + + | SExpr e -> + let _ = check_expr state env e in () + + | SIf (cond, then_stmts, else_stmts) -> + let cond_at = check_expr state env cond in + if not (types_equal cond_at.typ (TPrim TBool)) then + report state.diags (type_mismatch ~expected:"bool" ~found:(type_to_string cond_at.typ) cond.expr_loc); + let then_env = Env.enter_scope env in + List.iter (check_stmt state then_env) then_stmts; + let else_env = Env.enter_scope env in + List.iter (check_stmt state else_env) else_stmts + + | SWhile (cond, body) -> + let cond_at = check_expr state env cond in + if not (types_equal cond_at.typ (TPrim 
TBool)) then + report state.diags (type_mismatch ~expected:"bool" ~found:(type_to_string cond_at.typ) cond.expr_loc); + let body_env = Env.enter_scope env in + List.iter (check_stmt state body_env) body + + | SFor (var, start_expr, end_expr, body) -> + let start_at = check_expr state env start_expr in + let end_at = check_expr state env end_expr in + let iter_type = match start_at.typ with + | TPrim (TInt _ | TUint _) -> start_at.typ + | _ -> + report state.diags (make_error + (Invalid_operation { op = "for loop"; typ = type_to_string start_at.typ }) + start_expr.expr_loc "for loop range requires integer type"); + TPrim (TInt None) + in + if not (types_equal start_at.typ end_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string start_at.typ) + ~found:(type_to_string end_at.typ) + end_expr.expr_loc); + let body_env = Env.enter_scope env in + Env.add_var var (mk_atype stmt.stmt_loc (security_join start_at.security end_at.security) iter_type) body_env; + List.iter (check_stmt state body_env) body + + | SReturn expr_opt -> + (match Env.return_type env, expr_opt with + | Some ret_type, Some expr -> + let expr_at = check_expr state env expr in + if not (types_equal ret_type.typ expr_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string ret_type.typ) + ~found:(type_to_string expr_at.typ) + expr.expr_loc) + | Some ret_type, None -> + if not (types_equal ret_type.typ (TPrim TUnit)) then + report state.diags (type_mismatch + ~expected:(type_to_string ret_type.typ) + ~found:"unit" + stmt.stmt_loc) + | None, _ -> + report state.diags (make_error + (Internal_error "return outside function") + stmt.stmt_loc "return statement outside of function")) + + | SBreak | SContinue -> () + +(** Type check declaration *) +let check_decl state env decl = + match decl.decl_desc with + | DFunction { name; params; return_type; body; _ } -> + let param_types = List.map snd params in + Env.add_fun name param_types return_type env; + let fn_env = 
Env.enter_function return_type env in + List.iter (fun (pname, ptype) -> Env.add_var pname ptype fn_env) params; + List.iter (check_stmt state fn_env) body + + | DStruct { name; fields; _ } -> + Env.add_struct name fields env + + | DConst { name; typ; value } -> + let value_at = check_expr state env value in + if not (types_equal typ.typ value_at.typ) then + report state.diags (type_mismatch + ~expected:(type_to_string typ.typ) + ~found:(type_to_string value_at.typ) + value.expr_loc); + Env.add_var name typ env + + | DExtern { name; typ; _ } -> + Env.add_var name typ env + + | DImport _ -> () + +(** Type check a complete program *) +let check_program program = + let state = create_state () in + (* Add built-in types and functions *) + List.iter (check_decl state state.env) program.declarations; + state.diags diff --git a/obli-transpiler-framework/justfile b/obli-transpiler-framework/justfile new file mode 100644 index 0000000..47aad61 --- /dev/null +++ b/obli-transpiler-framework/justfile @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: MIT OR Palimpsest-0.8 +# Copyright (c) 2024 Hyperpolymath + +# Oblibeny Transpiler Framework Build System + +default: build + +# Build everything +build: build-frontend build-backend build-runtime build-driver + +# Build the OCaml frontend +build-frontend: + cd frontend && dune build + +# Build the Rust backend +build-backend: + cargo build -p oblibeny-backend --release + +# Build the ORAM runtime +build-runtime: + cargo build -p oblibeny-runtime --release + +# Build the driver CLI +build-driver: + cargo build -p oblibeny --release + +# Run all tests +test: test-frontend test-backend test-runtime + +# Test the OCaml frontend +test-frontend: + cd frontend && dune runtest + +# Test the Rust backend +test-backend: + cargo test -p oblibeny-backend + +# Test the runtime +test-runtime: + cargo test -p oblibeny-runtime + +# Clean all build artifacts +clean: + cd frontend && dune clean + cargo clean + +# Format all code +fmt: + cd frontend && 
dune fmt + cargo fmt + +# Lint all code +lint: + cd frontend && dune build @check + cargo clippy -- -D warnings + +# Install the compiler to ~/.local/bin +install: build + install -m755 target/release/oblibeny ~/.local/bin/ + install -m755 target/release/oblibeny-backend ~/.local/bin/ + install -m755 frontend/_build/default/bin/main.exe ~/.local/bin/oblibeny-frontend + +# Run benchmarks +bench: + cargo bench -p oblibeny-runtime + +# Generate documentation +doc: + cd frontend && dune build @doc + cargo doc --workspace --no-deps + +# Compile an example file +example FILE: + ./target/release/oblibeny compile {{FILE}} + +# Check an example file +check FILE: + ./target/release/oblibeny check {{FILE}} diff --git a/obli-transpiler-framework/runtime/Cargo.toml b/obli-transpiler-framework/runtime/Cargo.toml new file mode 100644 index 0000000..b5c19bc --- /dev/null +++ b/obli-transpiler-framework/runtime/Cargo.toml @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: MIT OR Palimpsest-0.8 +# Copyright (c) 2024 Hyperpolymath + +[package] +name = "oblibeny-runtime" +version = "0.1.0" +edition = "2021" +authors = ["Hyperpolymath"] +description = "Oblibeny ORAM runtime library - oblivious data structures and constant-time primitives" +license = "MIT OR Palimpsest-0.8" +repository = "https://github.com/hyperpolymath/oblibeny" + +[features] +default = ["std"] +std = [] +hardware-aes = ["aes/force-soft"] + +[dependencies] +subtle = "2.5" +zeroize = { version = "1.7", features = ["derive"] } +rand = "0.8" +rand_chacha = "0.3" +aes-gcm = "0.10" +aes = "0.8" +sha2 = "0.10" +blake3 = "1.5" + +[dev-dependencies] +criterion = "0.5" +proptest = "1.4" + +[[bench]] +name = "oram_bench" +harness = false diff --git a/obli-transpiler-framework/runtime/benches/oram_bench.rs b/obli-transpiler-framework/runtime/benches/oram_bench.rs new file mode 100644 index 0000000..8d1ef7b --- /dev/null +++ b/obli-transpiler-framework/runtime/benches/oram_bench.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT OR 
Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! ORAM benchmarks + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use oblibeny_runtime::prelude::*; + +fn bench_oarray_read(c: &mut Criterion) { + let mut group = c.benchmark_group("OArray Read"); + + for size in [100, 1000, 10000] { + let mut arr: OArray = OArray::new(size); + + // Initialize + for i in 0..size { + arr.write(i, i * 10); + } + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let idx = size / 2; + b.iter(|| black_box(arr.read(black_box(idx)))); + }); + } + + group.finish(); +} + +fn bench_oarray_write(c: &mut Criterion) { + let mut group = c.benchmark_group("OArray Write"); + + for size in [100, 1000, 10000] { + let mut arr: OArray = OArray::new(size); + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| { + let idx = size / 2; + b.iter(|| arr.write(black_box(idx), black_box(12345))); + }); + } + + group.finish(); +} + +fn bench_constant_time_ops(c: &mut Criterion) { + let mut group = c.benchmark_group("Constant Time"); + + group.bench_function("cmov u64", |b| { + b.iter(|| cmov(black_box(true), black_box(42u64), black_box(0u64))); + }); + + group.bench_function("cswap u64", |b| { + let mut a = 1u64; + let mut x = 2u64; + b.iter(|| { + cswap(black_box(true), &mut a, &mut x); + }); + }); + + let array = [1u64, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + group.bench_function("ct_lookup [10]", |b| { + b.iter(|| ct_lookup(&array, black_box(5))); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_oarray_read, + bench_oarray_write, + bench_constant_time_ops +); +criterion_main!(benches); diff --git a/obli-transpiler-framework/runtime/src/collections.rs b/obli-transpiler-framework/runtime/src/collections.rs new file mode 100644 index 0000000..d2e9056 --- /dev/null +++ b/obli-transpiler-framework/runtime/src/collections.rs @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// 
Copyright (c) 2024 Hyperpolymath + +//! Oblivious collections +//! +//! These collections hide access patterns using ORAM. + +use crate::crypto::SecretKey; +use crate::oram::{OArray, OramBlock}; +use subtle::ConditionallySelectable; + +/// Oblivious stack +/// +/// A stack where push/pop operations hide which element is accessed. +pub struct OStack { + data: OArray, + size: u64, + capacity: u64, +} + +impl OStack { + /// Create a new oblivious stack with given capacity + pub fn new(capacity: u64) -> Self { + OStack { + data: OArray::new(capacity), + size: 0, + capacity, + } + } + + /// Push a value onto the stack + pub fn push(&mut self, value: T) -> bool { + if self.size >= self.capacity { + return false; + } + self.data.write(self.size, value); + self.size += 1; + true + } + + /// Pop a value from the stack + pub fn pop(&mut self) -> Option { + if self.size == 0 { + // Perform dummy access to maintain constant access pattern + let _ = self.data.read(0); + return None; + } + self.size -= 1; + Some(self.data.read(self.size)) + } + + /// Peek at the top value + pub fn peek(&mut self) -> Option { + if self.size == 0 { + let _ = self.data.read(0); + return None; + } + Some(self.data.read(self.size - 1)) + } + + /// Get the current size + pub fn len(&self) -> u64 { + self.size + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + self.size == 0 + } +} + +/// Oblivious queue +/// +/// A queue where enqueue/dequeue hide which element is accessed. 
+pub struct OQueue { + data: OArray, + head: u64, + tail: u64, + size: u64, + capacity: u64, +} + +impl OQueue { + /// Create a new oblivious queue with given capacity + pub fn new(capacity: u64) -> Self { + OQueue { + data: OArray::new(capacity), + head: 0, + tail: 0, + size: 0, + capacity, + } + } + + /// Enqueue a value + pub fn enqueue(&mut self, value: T) -> bool { + if self.size >= self.capacity { + return false; + } + self.data.write(self.tail, value); + self.tail = (self.tail + 1) % self.capacity; + self.size += 1; + true + } + + /// Dequeue a value + pub fn dequeue(&mut self) -> Option { + if self.size == 0 { + let _ = self.data.read(0); + return None; + } + let value = self.data.read(self.head); + self.head = (self.head + 1) % self.capacity; + self.size -= 1; + Some(value) + } + + /// Peek at the front value + pub fn peek(&mut self) -> Option { + if self.size == 0 { + let _ = self.data.read(0); + return None; + } + Some(self.data.read(self.head)) + } + + /// Get the current size + pub fn len(&self) -> u64 { + self.size + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + self.size == 0 + } +} + +/// Oblivious map (simple linear scan implementation) +/// +/// For small maps, uses linear scan with ORAM backing. +/// For large maps, a tree-based structure would be more efficient. 
+pub struct OMap { + keys: OArray, + values: OArray, + size: u64, + capacity: u64, +} + +impl OMap { + /// Create a new oblivious map with given capacity + pub fn new(capacity: u64) -> Self { + OMap { + keys: OArray::new(capacity), + values: OArray::new(capacity), + size: 0, + capacity, + } + } + + /// Insert or update a key-value pair + pub fn insert(&mut self, key: K, value: V) -> bool { + // First, try to find existing key + for i in 0..self.size { + let k = self.keys.read(i); + if k == key { + self.values.write(i, value); + return true; + } + } + + // Key not found, insert new + if self.size >= self.capacity { + return false; + } + self.keys.write(self.size, key); + self.values.write(self.size, value); + self.size += 1; + true + } + + /// Get a value by key + pub fn get(&mut self, key: &K) -> Option + where + K: Clone, + { + for i in 0..self.size { + let k = self.keys.read(i); + if &k == key { + return Some(self.values.read(i)); + } + } + // Dummy access for constant pattern + if self.size < self.capacity { + let _ = self.values.read(0); + } + None + } + + /// Check if key exists + pub fn contains(&mut self, key: &K) -> bool + where + K: Clone, + { + for i in 0..self.size { + let k = self.keys.read(i); + if &k == key { + return true; + } + } + false + } + + /// Get the current size + pub fn len(&self) -> u64 { + self.size + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + self.size == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ostack() { + let mut stack: OStack = OStack::new(10); + assert!(stack.push(1)); + assert!(stack.push(2)); + assert!(stack.push(3)); + assert_eq!(stack.pop(), Some(3)); + assert_eq!(stack.pop(), Some(2)); + assert_eq!(stack.pop(), Some(1)); + assert_eq!(stack.pop(), None); + } + + #[test] + fn test_oqueue() { + let mut queue: OQueue = OQueue::new(10); + assert!(queue.enqueue(1)); + assert!(queue.enqueue(2)); + assert!(queue.enqueue(3)); + assert_eq!(queue.dequeue(), Some(1)); + 
assert_eq!(queue.dequeue(), Some(2)); + assert_eq!(queue.dequeue(), Some(3)); + assert_eq!(queue.dequeue(), None); + } + + #[test] + fn test_omap() { + let mut map: OMap = OMap::new(10); + assert!(map.insert(1, 100)); + assert!(map.insert(2, 200)); + assert_eq!(map.get(&1), Some(100)); + assert_eq!(map.get(&2), Some(200)); + assert_eq!(map.get(&3), None); + + // Update existing + assert!(map.insert(1, 150)); + assert_eq!(map.get(&1), Some(150)); + } +} diff --git a/obli-transpiler-framework/runtime/src/constant_time.rs b/obli-transpiler-framework/runtime/src/constant_time.rs new file mode 100644 index 0000000..dc13500 --- /dev/null +++ b/obli-transpiler-framework/runtime/src/constant_time.rs @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! Constant-time primitives for side-channel resistance +//! +//! These primitives ensure that execution time does not depend on +//! secret values, preventing timing attacks. + +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess}; + +/// Constant-time conditional move +/// +/// Returns `a` if `cond` is true, `b` otherwise. +/// The selection is done in constant time. +#[inline] +pub fn cmov(cond: bool, a: T, b: T) -> T { + T::conditional_select(&b, &a, Choice::from(cond as u8)) +} + +/// Constant-time conditional swap +/// +/// Swaps `a` and `b` if `cond` is true, otherwise leaves them unchanged. +/// The swap is done in constant time. +#[inline] +pub fn cswap(cond: bool, a: &mut T, b: &mut T) { + T::conditional_swap(a, b, Choice::from(cond as u8)); +} + +/// Constant-time equality comparison +/// +/// Returns true if `a == b` in constant time. +#[inline] +pub fn ct_eq(a: &T, b: &T) -> bool { + a.ct_eq(b).into() +} + +/// Constant-time less-than comparison +/// +/// Returns true if `a < b` in constant time. 
/// Constant-time absolute value for signed integers
///
/// Branch-free: `mask` is all ones for negative `x` and zero otherwise, so
/// `(x ^ mask) - mask` is a conditional two's-complement negation.
///
/// `i64::MIN` has no positive counterpart; it is returned unchanged (the same
/// result as `i64::wrapping_abs`) instead of overflowing.
#[inline]
pub fn ct_abs_i64(x: i64) -> i64 {
    let mask = x >> 63;
    // Wrapping subtraction: a plain `-` overflows for `i64::MIN`, which would
    // panic in builds with overflow checks.
    (x ^ mask).wrapping_sub(mask)
}

/// Constant-time sign extraction
///
/// Returns `-1` for negative input, `0` for zero and `1` for positive input
/// (the same values as `i64::signum`), computed without branches.
#[inline]
pub fn ct_sign_i64(x: i64) -> i64 {
    // Arithmetic shift: -1 when x < 0, otherwise 0.
    let negative = x >> 63;
    // The top bit of `-x` is set exactly when x > 0 (and for i64::MIN, where
    // `negative` is already -1 and the OR below leaves it at -1). The `as u64`
    // cast only reinterprets the bits so the shift is logical.
    let positive = ((x.wrapping_neg() as u64) >> 63) as i64;
    negative | positive
}

/// Constant-time byte array equality
///
/// Slices of different length compare unequal immediately, so the length is
/// not hidden. For equal lengths every byte pair is examined and the
/// differences are OR-ed together; there is no early exit on the first
/// mismatch.
pub fn ct_bytes_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}

/// Constant-time byte array copy based on condition
///
/// Overwrites `dst` with `src` when `cond` is true and leaves `dst` unchanged
/// otherwise. Every byte of `dst` is rewritten in both cases.
///
/// # Panics
///
/// Panics if `dst` and `src` have different lengths.
pub fn ct_copy_if(cond: bool, dst: &mut [u8], src: &[u8]) {
    assert_eq!(dst.len(), src.len());
    // 0xFF when `cond` is true, 0x00 otherwise, derived arithmetically so the
    // mask itself does not come from a branch on `cond`.
    let mask = (cond as u8).wrapping_neg();
    for (d, s) in dst.iter_mut().zip(src.iter()) {
        *d = (*d & !mask) | (*s & mask);
    }
}
+ +use aes_gcm::{ + aead::{Aead, KeyInit, OsRng}, + Aes256Gcm, Nonce, +}; +use rand::RngCore; +use sha2::{Digest, Sha256}; +use zeroize::{Zeroize, ZeroizeOnDrop}; + +/// Encryption key size (256 bits) +pub const KEY_SIZE: usize = 32; + +/// Nonce size for AES-GCM +pub const NONCE_SIZE: usize = 12; + +/// Authentication tag size +pub const TAG_SIZE: usize = 16; + +/// A secret key that zeroizes on drop +#[derive(Clone, Zeroize, ZeroizeOnDrop)] +pub struct SecretKey([u8; KEY_SIZE]); + +impl SecretKey { + /// Generate a new random key + pub fn generate() -> Self { + let mut key = [0u8; KEY_SIZE]; + OsRng.fill_bytes(&mut key); + SecretKey(key) + } + + /// Create from bytes (takes ownership) + pub fn from_bytes(bytes: [u8; KEY_SIZE]) -> Self { + SecretKey(bytes) + } + + /// Get key bytes (be careful with this!) + pub fn as_bytes(&self) -> &[u8; KEY_SIZE] { + &self.0 + } +} + +/// Encrypt a block of data using AES-256-GCM +/// +/// Returns ciphertext with nonce prepended. +pub fn encrypt(key: &SecretKey, plaintext: &[u8]) -> Vec { + let cipher = Aes256Gcm::new(key.0.as_ref().into()); + + let mut nonce_bytes = [0u8; NONCE_SIZE]; + OsRng.fill_bytes(&mut nonce_bytes); + let nonce = Nonce::from_slice(&nonce_bytes); + + let ciphertext = cipher + .encrypt(nonce, plaintext) + .expect("encryption should not fail"); + + let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len()); + result.extend_from_slice(&nonce_bytes); + result.extend_from_slice(&ciphertext); + result +} + +/// Decrypt a block of data using AES-256-GCM +/// +/// Expects nonce prepended to ciphertext. 
+pub fn decrypt(key: &SecretKey, ciphertext: &[u8]) -> Result, CryptoError> { + if ciphertext.len() < NONCE_SIZE + TAG_SIZE { + return Err(CryptoError::InvalidCiphertext); + } + + let cipher = Aes256Gcm::new(key.0.as_ref().into()); + let nonce = Nonce::from_slice(&ciphertext[..NONCE_SIZE]); + let ct = &ciphertext[NONCE_SIZE..]; + + cipher + .decrypt(nonce, ct) + .map_err(|_| CryptoError::DecryptionFailed) +} + +/// Compute SHA-256 hash +pub fn sha256(data: &[u8]) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(data); + hasher.finalize().into() +} + +/// Compute BLAKE3 hash +pub fn blake3(data: &[u8]) -> [u8; 32] { + blake3::hash(data).into() +} + +/// Derive a key from a master key and path +pub fn derive_key(master: &SecretKey, path: &[u8]) -> SecretKey { + let mut hasher = Sha256::new(); + hasher.update(master.as_bytes()); + hasher.update(path); + SecretKey::from_bytes(hasher.finalize().into()) +} + +/// Pseudo-random function (PRF) for position map +pub fn prf(key: &SecretKey, input: u64) -> u64 { + let mut hasher = Sha256::new(); + hasher.update(key.as_bytes()); + hasher.update(input.to_le_bytes()); + let hash: [u8; 32] = hasher.finalize().into(); + u64::from_le_bytes(hash[..8].try_into().unwrap()) +} + +/// Cryptographic errors +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CryptoError { + InvalidCiphertext, + DecryptionFailed, + InvalidKeyLength, +} + +impl std::fmt::Display for CryptoError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CryptoError::InvalidCiphertext => write!(f, "invalid ciphertext"), + CryptoError::DecryptionFailed => write!(f, "decryption failed"), + CryptoError::InvalidKeyLength => write!(f, "invalid key length"), + } + } +} + +impl std::error::Error for CryptoError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encrypt_decrypt() { + let key = SecretKey::generate(); + let plaintext = b"hello, ORAM world!"; + let ciphertext = encrypt(&key, 
plaintext); + let decrypted = decrypt(&key, &ciphertext).unwrap(); + assert_eq!(decrypted, plaintext); + } + + #[test] + fn test_decrypt_wrong_key() { + let key1 = SecretKey::generate(); + let key2 = SecretKey::generate(); + let ciphertext = encrypt(&key1, b"secret data"); + assert!(decrypt(&key2, &ciphertext).is_err()); + } + + #[test] + fn test_prf_deterministic() { + let key = SecretKey::generate(); + assert_eq!(prf(&key, 42), prf(&key, 42)); + assert_ne!(prf(&key, 42), prf(&key, 43)); + } +} diff --git a/obli-transpiler-framework/runtime/src/lib.rs b/obli-transpiler-framework/runtime/src/lib.rs new file mode 100644 index 0000000..dae3338 --- /dev/null +++ b/obli-transpiler-framework/runtime/src/lib.rs @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! Oblibeny Runtime Library +//! +//! This crate provides the runtime support for Oblibeny compiled programs, +//! including: +//! +//! - **Constant-time primitives**: `cmov`, `cswap`, and other operations +//! that don't leak information through timing +//! - **ORAM implementations**: Path ORAM for oblivious memory access +//! - **Oblivious collections**: Maps, vectors, stacks with hidden access patterns +//! - **Cryptographic utilities**: Encryption, hashing, key derivation + +#![cfg_attr(not(feature = "std"), no_std)] + +#[cfg(not(feature = "std"))] +extern crate alloc; + +pub mod constant_time; +pub mod oram; +pub mod crypto; +pub mod collections; + +/// Prelude module for common imports +pub mod prelude { + pub use crate::constant_time::*; + pub use crate::oram::{OArray, PathOram, OramAccess}; + pub use crate::collections::*; +} diff --git a/obli-transpiler-framework/runtime/src/oram.rs b/obli-transpiler-framework/runtime/src/oram.rs new file mode 100644 index 0000000..70b323a --- /dev/null +++ b/obli-transpiler-framework/runtime/src/oram.rs @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! 
/// Trait for types that can be stored in ORAM
///
/// A block must be cloneable, have a default value, and convert to and from
/// a byte representation.
pub trait OramBlock: Clone + Default + Sized {
    /// Size of the block in bytes
    const SIZE: usize;

    /// Serialize to bytes
    fn to_bytes(&self) -> Vec<u8>;

    /// Deserialize from bytes
    fn from_bytes(bytes: &[u8]) -> Self;
}

/// Implement OramBlock for primitive types
///
/// Integers are encoded little-endian. `from_bytes` expects exactly `SIZE`
/// bytes; a slice of any other length decodes as zero rather than panicking.
macro_rules! impl_oram_block_primitive {
    ($($t:ty),*) => {
        $(
            impl OramBlock for $t {
                const SIZE: usize = std::mem::size_of::<$t>();

                fn to_bytes(&self) -> Vec<u8> {
                    self.to_le_bytes().to_vec()
                }

                fn from_bytes(bytes: &[u8]) -> Self {
                    // Slice-to-array conversion fails on a length mismatch;
                    // the all-zero fallback keeps this function total.
                    let arr: [u8; std::mem::size_of::<$t>()] =
                        std::convert::TryFrom::try_from(bytes)
                            .unwrap_or([0; std::mem::size_of::<$t>()]);
                    <$t>::from_le_bytes(arr)
                }
            }
        )*
    };
}

impl_oram_block_primitive!(u8, u16, u32, u64, u128, i8, i16, i32, i64, i128);

/// ORAM access trait
///
/// `T` is the block type read from and written to the store.
pub trait OramAccess<T> {
    /// Read a value at the given logical address
    fn oram_read(&mut self, addr: u64) -> T;

    /// Write a value at the given logical address
    fn oram_write(&mut self, addr: u64, value: T);

    /// Get the capacity (number of blocks)
    fn capacity(&self) -> u64;
}
index + #[inline] + pub fn read(&mut self, index: u64) -> T { + self.oram.oram_read(index) + } + + /// Write a value at the given index + #[inline] + pub fn write(&mut self, index: u64, value: T) { + self.oram.oram_write(index, value); + } + + /// Get the capacity + pub fn len(&self) -> u64 { + self.oram.capacity() + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +// Convenience methods matching the OIR codegen expectations +impl OArray { + /// ORAM read (matches codegen output) + #[inline] + pub fn oram_read(&mut self, index: u64) -> T { + self.read(index) + } + + /// ORAM write (matches codegen output) + #[inline] + pub fn oram_write(&mut self, index: u64, value: T) { + self.write(index, value); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_oarray_basic() { + let mut arr: OArray = OArray::new(100); + arr.write(42, 12345); + assert_eq!(arr.read(42), 12345); + } + + #[test] + fn test_oarray_multiple_writes() { + let mut arr: OArray = OArray::new(100); + for i in 0..10 { + arr.write(i, i * 100); + } + for i in 0..10 { + assert_eq!(arr.read(i), i * 100); + } + } +} diff --git a/obli-transpiler-framework/runtime/src/oram/bucket.rs b/obli-transpiler-framework/runtime/src/oram/bucket.rs new file mode 100644 index 0000000..dcd6794 --- /dev/null +++ b/obli-transpiler-framework/runtime/src/oram/bucket.rs @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! ORAM bucket implementation +//! +//! A bucket is a fixed-size container of blocks in the ORAM tree. 
+
+use super::OramBlock;
+use crate::constant_time::{ct_lookup, ct_store};
+use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
+
+/// Number of blocks per bucket (Z parameter in Path ORAM)
+pub const BUCKET_SIZE: usize = 4;
+
+/// A single entry in a bucket
+///
+/// `Copy` is derived (available whenever `T: Copy`) because
+/// `subtle::ConditionallySelectable` has `Copy` as a supertrait; without it
+/// the impl below cannot compile.
+#[derive(Clone, Copy)]
+pub struct BucketEntry<T: OramBlock> {
+    /// The logical address (u64::MAX means empty/dummy)
+    pub addr: u64,
+    /// The data block
+    pub data: T,
+}
+
+impl<T: OramBlock> Default for BucketEntry<T> {
+    fn default() -> Self {
+        BucketEntry {
+            addr: u64::MAX, // Empty marker
+            data: T::default(),
+        }
+    }
+}
+
+impl<T: OramBlock> ConditionallySelectable for BucketEntry<T>
+where
+    T: ConditionallySelectable,
+{
+    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
+        BucketEntry {
+            addr: u64::conditional_select(&a.addr, &b.addr, choice),
+            data: T::conditional_select(&a.data, &b.data, choice),
+        }
+    }
+}
+
+/// A bucket containing multiple entries
+#[derive(Clone)]
+pub struct Bucket<T: OramBlock> {
+    entries: [BucketEntry<T>; BUCKET_SIZE],
+}
+
+impl<T: OramBlock> Default for Bucket<T> {
+    fn default() -> Self {
+        Bucket {
+            entries: std::array::from_fn(|_| BucketEntry::default()),
+        }
+    }
+}
+
+impl<T: OramBlock> Bucket<T> {
+    /// Create a new empty bucket
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Check if the bucket is full
+    pub fn is_full(&self) -> bool {
+        self.entries.iter().all(|e| e.addr != u64::MAX)
+    }
+
+    /// Count non-empty entries
+    pub fn count(&self) -> usize {
+        self.entries.iter().filter(|e| e.addr != u64::MAX).count()
+    }
+
+    /// Try to add an entry (returns false if full)
+    ///
+    /// The block goes into the first empty slot. This scan stops early and
+    /// is not constant-time.
+    pub fn try_add(&mut self, addr: u64, data: T) -> bool {
+        for entry in &mut self.entries {
+            if entry.addr == u64::MAX {
+                entry.addr = addr;
+                entry.data = data;
+                return true;
+            }
+        }
+        false
+    }
+
+    /// Read and remove entry with given address (constant-time scan)
+    ///
+    /// Every slot is compared with `ct_eq` and updated through masked
+    /// assignments, so the scan does the same work whichever slot (if any)
+    /// matches. Only the final conversion to `Option` branches on whether a
+    /// match was found.
+    ///
+    /// Returns the data if found, None otherwise.
+    /// The entry is marked as empty. If several slots hold `addr`, all of
+    /// them are cleared and the data of the last one is returned.
+    pub fn read_and_remove(&mut self, addr: u64) -> Option<T>
+    where
+        T: ConditionallySelectable + Clone,
+    {
+        let mut found = Choice::from(0u8);
+        let mut result = T::default();
+        let vacant = T::default();
+
+        for entry in &mut self.entries {
+            let matches = entry.addr.ct_eq(&addr);
+            result.conditional_assign(&entry.data, matches);
+            entry.addr.conditional_assign(&u64::MAX, matches);
+            entry.data.conditional_assign(&vacant, matches);
+            found |= matches;
+        }
+
+        if bool::from(found) {
+            Some(result)
+        } else {
+            None
+        }
+    }
+
+    /// Read entry with given address without removing (constant-time scan)
+    ///
+    /// All slots are always visited; the data of the first slot holding
+    /// `addr` is returned, or `None` when no slot matches.
+    pub fn read(&self, addr: u64) -> Option<T>
+    where
+        T: ConditionallySelectable + Clone,
+    {
+        let mut found = Choice::from(0u8);
+        let mut result = T::default();
+
+        for entry in &self.entries {
+            // Mask out later duplicates so the first match wins.
+            let take = entry.addr.ct_eq(&addr) & !found;
+            result.conditional_assign(&entry.data, take);
+            found |= take;
+        }
+
+        if bool::from(found) {
+            Some(result)
+        } else {
+            None
+        }
+    }
+
+    /// Get entries as slice
+    pub fn entries(&self) -> &[BucketEntry<T>; BUCKET_SIZE] {
+        &self.entries
+    }
+
+    /// Get mutable entries
+    pub fn entries_mut(&mut self) -> &mut [BucketEntry<T>; BUCKET_SIZE] {
+        &mut self.entries
+    }
+
+    /// Drain all real (non-dummy) entries
+    ///
+    /// Returns `(addr, data)` pairs in slot order and leaves every drained
+    /// slot empty.
+    pub fn drain_real(&mut self) -> Vec<(u64, T)> {
+        let mut result = Vec::new();
+        for entry in &mut self.entries {
+            if entry.addr != u64::MAX {
+                // Move the block out instead of cloning it; `take` leaves
+                // the default value behind.
+                result.push((entry.addr, std::mem::take(&mut entry.data)));
+                entry.addr = u64::MAX;
+            }
+        }
+        result
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_bucket_add_read() {
+        let mut bucket: Bucket<u64> = Bucket::new();
+        assert!(bucket.try_add(10, 100));
+        assert!(bucket.try_add(20, 200));
+        assert_eq!(bucket.read(10), Some(100));
+        assert_eq!(bucket.read(20), Some(200));
+        assert_eq!(bucket.read(30), None);
+    }
+
+    #[test]
+    fn test_bucket_full() {
+        let mut bucket: Bucket<u64> = Bucket::new();
+        for i in 0..BUCKET_SIZE {
+            assert!(bucket.try_add(i as u64, i as u64 * 10));
+        }
+        assert!(!bucket.try_add(100, 1000));
+    }
+
+    #[test]
+    fn test_bucket_read_and_remove() {
+        let mut bucket: Bucket<u64> = Bucket::new();
+        bucket.try_add(10, 100);
+        assert_eq!(bucket.read_and_remove(10), Some(100));
+        assert_eq!(bucket.read_and_remove(10), None);
+    }
+}
diff
--git a/obli-transpiler-framework/runtime/src/oram/path.rs b/obli-transpiler-framework/runtime/src/oram/path.rs new file mode 100644 index 0000000..22e45c8 --- /dev/null +++ b/obli-transpiler-framework/runtime/src/oram/path.rs @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! Path ORAM implementation +//! +//! Path ORAM provides O(log N) bandwidth overhead per access with +//! O(log N) client storage. This implementation follows the original +//! Path ORAM paper by Stefanov et al. + +use super::bucket::{Bucket, BUCKET_SIZE}; +use super::position::{PositionMap, SimplePositionMap}; +use super::stash::{path_overlap_level, Stash, StashEntry}; +use super::{OramAccess, OramBlock}; +use crate::crypto::SecretKey; +use subtle::ConditionallySelectable; + +/// Path ORAM implementation +pub struct PathOram { + /// The binary tree of buckets (stored as array) + tree: Vec>, + /// Position map: addr -> leaf + position_map: SimplePositionMap, + /// Stash for overflow blocks + stash: Stash, + /// Tree depth (log2 of capacity) + depth: usize, + /// Number of leaves + num_leaves: u64, + /// Logical capacity + capacity: u64, +} + +impl PathOram { + /// Create a new Path ORAM with given capacity + pub fn new(capacity: u64, key: SecretKey) -> Self { + // Calculate tree depth (ceil(log2(capacity))) + let depth = (64 - capacity.leading_zeros()) as usize; + let num_leaves = 1u64 << depth; + + // Total nodes in complete binary tree: 2^(depth+1) - 1 + let num_nodes = (1usize << (depth + 1)) - 1; + + // Initialize empty tree + let tree: Vec> = (0..num_nodes).map(|_| Bucket::new()).collect(); + + // Initialize position map + let position_map = SimplePositionMap::new(capacity, num_leaves, &key); + + PathOram { + tree, + position_map, + stash: Stash::new(), + depth, + num_leaves, + capacity, + } + } + + /// Access (read or write) a block + fn access(&mut self, addr: u64, op: AccessOp) -> T + where + T: Clone, + { + // 1. 
Look up position and remap + let (old_leaf, new_leaf) = self.position_map.get_and_remap(addr); + + // 2. Read path from root to old leaf into stash + self.read_path(old_leaf); + + // 3. Find block in stash and update + let result = if let Some((_, data)) = self.stash.remove(addr) { + data + } else { + T::default() + }; + + // 4. Prepare new data based on operation + let new_data = match op { + AccessOp::Read => result.clone(), + AccessOp::Write(data) => data, + }; + + // 5. Add block back to stash with new leaf + self.stash.add(addr, new_leaf, new_data); + + // 6. Evict: write path back + self.write_path(old_leaf); + + result + } + + /// Read a path from root to leaf into the stash + fn read_path(&mut self, leaf: u64) + where + T: Clone, + { + for level in 0..=self.depth { + let node_idx = self.path_node(leaf, level); + let bucket = &mut self.tree[node_idx]; + + // Move all real blocks from bucket to stash + for entry in bucket.entries_mut() { + if entry.addr != u64::MAX { + // Get the leaf for this block from position map + let block_leaf = self.position_map.get(entry.addr); + self.stash.add(entry.addr, block_leaf, entry.data.clone()); + entry.addr = u64::MAX; + entry.data = T::default(); + } + } + } + } + + /// Write path back from stash + fn write_path(&mut self, leaf: u64) + where + T: Clone, + { + // For each level from leaf to root + for level in (0..=self.depth).rev() { + let node_idx = self.path_node(leaf, level); + let bucket = &mut self.tree[node_idx]; + + // Find blocks in stash that can be placed at this level + let mut placed = 0; + let mut to_remove = Vec::new(); + + for (i, entry) in self.stash.entries().iter().enumerate() { + if placed >= BUCKET_SIZE { + break; + } + + // Check if this block's path passes through this node + let overlap = path_overlap_level(entry.leaf, leaf, self.depth + 1); + if overlap >= level { + to_remove.push(i); + placed += 1; + } + } + + // Remove from stash and add to bucket + let removed: Vec> = 
self.stash.remove_indices(to_remove); + for entry in removed { + bucket.try_add(entry.addr, entry.data); + } + } + } + + /// Get the node index for a given level on the path to leaf + fn path_node(&self, leaf: u64, level: usize) -> usize { + // Level 0 is root, level depth is leaf + // Node index in level-order traversal + let leaf_offset = leaf as usize; + let level_start = (1 << level) - 1; + let node_in_level = leaf_offset >> (self.depth - level); + level_start + node_in_level + } +} + +/// Access operation type +enum AccessOp { + Read, + Write(T), +} + +impl OramAccess for PathOram { + fn oram_read(&mut self, addr: u64) -> T { + self.access(addr, AccessOp::Read) + } + + fn oram_write(&mut self, addr: u64, value: T) { + self.access(addr, AccessOp::Write(value)); + } + + fn capacity(&self) -> u64 { + self.capacity + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_path_oram_basic() { + let key = SecretKey::generate(); + let mut oram: PathOram = PathOram::new(100, key); + + // Write and read back + oram.oram_write(42, 12345); + assert_eq!(oram.oram_read(42), 12345); + } + + #[test] + fn test_path_oram_multiple() { + let key = SecretKey::generate(); + let mut oram: PathOram = PathOram::new(100, key); + + for i in 0..20 { + oram.oram_write(i, i * 100); + } + + for i in 0..20 { + assert_eq!(oram.oram_read(i), i * 100); + } + } + + #[test] + fn test_path_oram_overwrite() { + let key = SecretKey::generate(); + let mut oram: PathOram = PathOram::new(100, key); + + oram.oram_write(10, 100); + oram.oram_write(10, 200); + assert_eq!(oram.oram_read(10), 200); + } + + #[test] + fn test_path_node_calculation() { + let key = SecretKey::generate(); + let oram: PathOram = PathOram::new(8, key); + + // For depth 3 (8 leaves), tree has 15 nodes + // Root is node 0 + assert_eq!(oram.path_node(0, 0), 0); // Root for any leaf + assert_eq!(oram.path_node(7, 0), 0); // Root for any leaf + + // Level 1 has nodes 1, 2 + assert_eq!(oram.path_node(0, 1), 1); // Left 
child + assert_eq!(oram.path_node(4, 1), 2); // Right child + } +} diff --git a/obli-transpiler-framework/runtime/src/oram/position.rs b/obli-transpiler-framework/runtime/src/oram/position.rs new file mode 100644 index 0000000..6128ba8 --- /dev/null +++ b/obli-transpiler-framework/runtime/src/oram/position.rs @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT OR Palimpsest-0.8 +// Copyright (c) 2024 Hyperpolymath + +//! Position map for ORAM +//! +//! Maps logical addresses to random leaf positions in the ORAM tree. +//! For small ORAMs, uses a simple array. For large ORAMs, this would +//! itself be stored in a recursive ORAM. + +use crate::crypto::{prf, SecretKey}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha20Rng; + +/// Position map interface +pub trait PositionMap { + /// Get the current position for an address + fn get(&self, addr: u64) -> u64; + + /// Update position and return the old position + fn update(&mut self, addr: u64, new_pos: u64) -> u64; + + /// Get and update to a new random position + fn get_and_remap(&mut self, addr: u64) -> (u64, u64); + + /// Number of leaves in the tree + fn num_leaves(&self) -> u64; +} + +/// Simple in-memory position map (for small ORAMs) +pub struct SimplePositionMap { + positions: Vec, + num_leaves: u64, + rng: ChaCha20Rng, +} + +impl SimplePositionMap { + /// Create a new position map + pub fn new(capacity: u64, num_leaves: u64, key: &SecretKey) -> Self { + // Derive RNG seed from key + let seed = crate::crypto::sha256(key.as_bytes()); + + let mut rng = ChaCha20Rng::from_seed(seed); + + // Initialize all positions randomly + let positions: Vec = (0..capacity) + .map(|_| rng.gen_range(0..num_leaves)) + .collect(); + + SimplePositionMap { + positions, + num_leaves, + rng, + } + } + + /// Get a new random leaf position + fn random_leaf(&mut self) -> u64 { + self.rng.gen_range(0..self.num_leaves) + } +} + +impl PositionMap for SimplePositionMap { + fn get(&self, addr: u64) -> u64 { + self.positions.get(addr as 
usize).copied().unwrap_or(0) + } + + fn update(&mut self, addr: u64, new_pos: u64) -> u64 { + let idx = addr as usize; + if idx < self.positions.len() { + let old = self.positions[idx]; + self.positions[idx] = new_pos; + old + } else { + 0 + } + } + + fn get_and_remap(&mut self, addr: u64) -> (u64, u64) { + let old_pos = self.get(addr); + let new_pos = self.random_leaf(); + self.update(addr, new_pos); + (old_pos, new_pos) + } + + fn num_leaves(&self) -> u64 { + self.num_leaves + } +} + +/// PRF-based position map (for use with recursive ORAM) +/// +/// Uses a PRF to deterministically compute positions, avoiding +/// the need to store positions explicitly (at the cost of no +/// position updates - used for read-only scenarios or as base case). +pub struct PrfPositionMap { + key: SecretKey, + num_leaves: u64, +} + +impl PrfPositionMap { + pub fn new(key: SecretKey, num_leaves: u64) -> Self { + PrfPositionMap { key, num_leaves } + } +} + +impl PositionMap for PrfPositionMap { + fn get(&self, addr: u64) -> u64 { + prf(&self.key, addr) % self.num_leaves + } + + fn update(&mut self, _addr: u64, _new_pos: u64) -> u64 { + // PRF-based map doesn't support updates + panic!("PrfPositionMap does not support updates") + } + + fn get_and_remap(&mut self, _addr: u64) -> (u64, u64) { + panic!("PrfPositionMap does not support remapping") + } + + fn num_leaves(&self) -> u64 { + self.num_leaves + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simple_position_map() { + let key = SecretKey::generate(); + let mut pm = SimplePositionMap::new(100, 16, &key); + + // Get initial position + let pos1 = pm.get(42); + assert!(pos1 < 16); + + // Update and verify + let old = pm.update(42, 7); + assert_eq!(old, pos1); + assert_eq!(pm.get(42), 7); + + // Remap + let (old, new) = pm.get_and_remap(42); + assert_eq!(old, 7); + assert!(new < 16); + } + + #[test] + fn test_prf_position_map() { + let key = SecretKey::generate(); + let pm = PrfPositionMap::new(key.clone(), 16); + 
+        // PRF should be deterministic
+        let pos1 = pm.get(42);
+        let pos2 = pm.get(42);
+        assert_eq!(pos1, pos2);
+        assert!(pos1 < 16);
+
+        // Different addresses should (usually) have different positions
+        let pos3 = pm.get(43);
+        // Not guaranteed but very likely
+        assert!(pos1 < 16 && pos3 < 16);
+    }
+}
diff --git a/obli-transpiler-framework/runtime/src/oram/stash.rs b/obli-transpiler-framework/runtime/src/oram/stash.rs
new file mode 100644
index 0000000..516c738
--- /dev/null
+++ b/obli-transpiler-framework/runtime/src/oram/stash.rs
@@ -0,0 +1,194 @@
+// SPDX-License-Identifier: MIT OR Palimpsest-0.8
+// Copyright (c) 2024 Hyperpolymath
+
+//! ORAM stash implementation
+//!
+//! The stash is a temporary storage for blocks that cannot fit
+//! in the ORAM tree during eviction.
+
+use super::OramBlock;
+use subtle::ConditionallySelectable;
+
+/// Maximum stash size (should be O(log N) for security)
+pub const MAX_STASH_SIZE: usize = 128;
+
+/// Entry in the stash
+#[derive(Clone)]
+pub struct StashEntry<T: OramBlock> {
+    pub addr: u64,
+    pub leaf: u64, // Target leaf in the tree
+    pub data: T,
+}
+
+/// The stash for temporary block storage
+pub struct Stash<T: OramBlock> {
+    entries: Vec<StashEntry<T>>,
+}
+
+impl<T: OramBlock> Default for Stash<T> {
+    fn default() -> Self {
+        Stash::new()
+    }
+}
+
+impl<T: OramBlock> Stash<T> {
+    /// Create a new empty stash
+    pub fn new() -> Self {
+        Stash {
+            entries: Vec::with_capacity(MAX_STASH_SIZE),
+        }
+    }
+
+    /// Add a block to the stash
+    ///
+    /// The block is always stored; exceeding `MAX_STASH_SIZE` only logs a
+    /// warning.
+    pub fn add(&mut self, addr: u64, leaf: u64, data: T) {
+        self.entries.push(StashEntry { addr, leaf, data });
+        if self.entries.len() > MAX_STASH_SIZE {
+            // In production, this would be a security failure
+            // For now, just warn (the stash overflow bound proof ensures this is negligible)
+            log::warn!("Stash overflow: {} entries", self.entries.len());
+        }
+    }
+
+    /// Find and remove a block by address
+    ///
+    /// Returns `(leaf, data)` of the first entry holding `addr`.
+    pub fn remove(&mut self, addr: u64) -> Option<(u64, T)>
+    where
+        T: Clone,
+    {
+        let idx = self.entries.iter().position(|e| e.addr == addr)?;
+        let entry = self.entries.remove(idx);
+        Some((entry.leaf, entry.data))
+    }
+
+    /// Check if address is in stash
+    pub fn contains(&self, addr: u64) -> bool {
+        self.entries.iter().any(|e| e.addr == addr)
+    }
+
+    /// Get block by address (without removing)
+    pub fn get(&self, addr: u64) -> Option<&T> {
+        self.entries
+            .iter()
+            .find(|e| e.addr == addr)
+            .map(|e| &e.data)
+    }
+
+    /// Update a block in the stash
+    ///
+    /// Returns false when `addr` is not in the stash.
+    pub fn update(&mut self, addr: u64, data: T) -> bool
+    where
+        T: Clone,
+    {
+        match self.entries.iter_mut().find(|e| e.addr == addr) {
+            Some(entry) => {
+                entry.data = data;
+                true
+            }
+            None => false,
+        }
+    }
+
+    /// Update the target leaf for an address
+    ///
+    /// Returns false when `addr` is not in the stash.
+    pub fn update_leaf(&mut self, addr: u64, new_leaf: u64) -> bool {
+        match self.entries.iter_mut().find(|e| e.addr == addr) {
+            Some(entry) => {
+                entry.leaf = new_leaf;
+                true
+            }
+            None => false,
+        }
+    }
+
+    /// Get the indices of all entries that can be placed on the path to a
+    /// given leaf
+    ///
+    /// `depth` is the tree depth (leaves are numbered `0..2^depth`).
+    pub fn entries_for_path(&self, leaf: u64, depth: usize) -> Vec<usize> {
+        self.entries
+            .iter()
+            .enumerate()
+            .filter(|(_, entry)| path_overlaps(entry.leaf, leaf, depth))
+            .map(|(i, _)| i)
+            .collect()
+    }
+
+    /// Remove entries at given indices
+    ///
+    /// The indices may arrive in any order. They are sorted descending and
+    /// de-duplicated here so that each removal leaves the remaining indices
+    /// valid. Entries are returned in descending index order.
+    ///
+    /// # Panics
+    ///
+    /// Panics if an index is out of bounds.
+    pub fn remove_indices(&mut self, mut indices: Vec<usize>) -> Vec<StashEntry<T>> {
+        indices.sort_unstable_by(|a, b| b.cmp(a)); // Sort descending
+        indices.dedup();
+        indices
+            .into_iter()
+            .map(|i| self.entries.remove(i))
+            .collect()
+    }
+
+    /// Current stash size
+    pub fn len(&self) -> usize {
+        self.entries.len()
+    }
+
+    /// Check if stash is empty
+    pub fn is_empty(&self) -> bool {
+        self.entries.is_empty()
+    }
+
+    /// Get all entries (for debugging/testing)
+    pub fn entries(&self) -> &[StashEntry<T>] {
+        &self.entries
+    }
+}
+
+/// Check if the paths to two leaves share a node at any level
+///
+/// `depth` is the tree depth: the root is level 0 and the leaves
+/// (`0..2^depth`) are level `depth`. Level 0 is included, so any two
+/// in-range leaves overlap at least at the root.
+fn path_overlaps(leaf1: u64, leaf2: u64, depth: usize) -> bool {
+    (0..=depth).any(|level| {
+        // At `level`, a path is identified by the leaf's top `level` bits.
+        // `checked_shr` treats a shift of 64 or more as "no bits left".
+        let shift = (depth - level) as u32;
+        leaf1.checked_shr(shift).unwrap_or(0) == leaf2.checked_shr(shift).unwrap_or(0)
+    })
+}
+
+/// Calculate the deepest level where two paths overlap
+///
+/// Unlike `path_overlaps`, `depth` here is the number of levels (tree depth
+/// plus one), so level `depth - 1` is the leaf level. Returns 0 (the root)
+/// when no deeper level matches.
+pub fn path_overlap_level(leaf1: u64, leaf2: u64, depth: usize) -> usize {
+    for level in (0..depth).rev() {
+        let shift = depth - level - 1;
+        if (leaf1 >> shift) == (leaf2 >> shift) {
+            return level;
+        }
+    }
+    0
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_stash_basic() {
+        let mut stash: Stash<u64> = Stash::new();
+        stash.add(10, 5, 100);
+        stash.add(20, 3, 200);
+
+        assert!(stash.contains(10));
+        assert!(stash.contains(20));
+        assert!(!stash.contains(30));
+
+        assert_eq!(stash.get(10), Some(&100));
+        assert_eq!(stash.remove(10), Some((5, 100)));
+        assert!(!stash.contains(10));
+    }
+
+    #[test]
+    fn test_path_overlaps() {
+        // With depth 4, leaves 0-15
+        // Leaf 5 (0101) and leaf 7 (0111) share prefix at level 1 (both start with 0)
+        assert!(path_overlaps(5, 7, 4));
+
+        // Leaf 0 (0000) and leaf 8 (1000) only share root
+        assert!(path_overlaps(0, 8, 4));
+    }
+
+    #[test]
+    fn test_path_overlap_level() {
+        // Five levels: root (0) down to the 16 leaves (4)
+        assert_eq!(path_overlap_level(5, 5, 5), 4); // Same leaf
+        assert_eq!(path_overlap_level(5, 7, 5), 2); // 0101 vs 0111 share "01"
+        assert_eq!(path_overlap_level(0, 8, 5), 0); // Only the root
+    }
+
+    #[test]
+    fn test_update() {
+        let mut stash: Stash<u64> = Stash::new();
+        stash.add(10, 5, 100);
+        assert!(stash.update(10, 999));
+        assert_eq!(stash.get(10), Some(&999));
+    }
+}