diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0d70992..2dffecd 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -35,3 +35,7 @@ jobs:
 
           TEST=$(extism call test/host_function.wasm count_vowels --link extism:host/user=test/host_function_host.wasm --input "this is a test" --set-config='{"thing": "1", "a": "b"}')
           echo $TEST | grep '"count":40'
+
+          # Test unsafe extern with safe fn (Rust 1.82+ syntax)
+          TEST=$(extism call test/host_function_safe.wasm count_vowels --link extism:host/user=test/host_function_host.wasm --input "this is a test" --set-config='{"thing": "1", "a": "b"}')
+          echo $TEST | grep '"count":40'
diff --git a/Makefile b/Makefile
index 23e2d2f..cba0e33 100644
--- a/Makefile
+++ b/Makefile
@@ -5,8 +5,10 @@ plugins:
 	cargo build --release --example http_headers
 	cargo build --release --example host_function
 	cargo build --release --example host_function_host
+	cargo build --release --example host_function_safe
 	cp target/wasm32-unknown-unknown/release/examples/count_vowels.wasm test/code.wasm
 	cp target/wasm32-unknown-unknown/release/examples/http.wasm test/http.wasm
 	cp target/wasm32-unknown-unknown/release/examples/http_headers.wasm test/http_headers.wasm
 	cp target/wasm32-unknown-unknown/release/examples/host_function.wasm test/host_function.wasm
 	cp target/wasm32-unknown-unknown/release/examples/host_function_host.wasm test/host_function_host.wasm
+	cp target/wasm32-unknown-unknown/release/examples/host_function_safe.wasm test/host_function_safe.wasm
diff --git a/derive/src/lib.rs b/derive/src/lib.rs
index e4548cd..f8014df 100644
--- a/derive/src/lib.rs
+++ b/derive/src/lib.rs
@@ -252,6 +252,23 @@ pub fn shared_fn(
 }
 
 /// `host_fn` is used to import a host function from an `extern` block
+///
+/// ## Rust 1.82+ / Edition 2024
+///
+/// Starting with Rust 1.82, extern blocks can be marked `unsafe` and functions
+/// within can be marked `safe`. When using `unsafe extern "ExtismHost"` blocks,
+/// you can use `safe fn` to indicate that a host function is safe to call:
+///
+/// ```rust,ignore
+/// #[host_fn]
+/// unsafe extern "ExtismHost" {
+///     // Safe to call - generates a safe wrapper function
+///     safe fn get_config(key: String) -> String;
+///
+///     // Unsafe to call (implicit) - generates unsafe wrapper
+///     fn dangerous_operation(data: Vec<u8>) -> Vec<u8>;
+/// }
+/// ```
 #[proc_macro_attribute]
 pub fn host_fn(
     attr: proc_macro::TokenStream,
@@ -264,91 +281,212 @@ pub fn host_fn(
     };
 
     let item = parse_macro_input!(item as ItemForeignMod);
-    if item.abi.name.is_none() || item.abi.name.unwrap().value() != "ExtismHost" {
-        panic!("Expected `extern \"ExtismHost\"` block");
+    if item.abi.name.is_none() || item.abi.name.as_ref().unwrap().value() != "ExtismHost" {
+        panic!("Expected `extern \"ExtismHost\"` or `unsafe extern \"ExtismHost\"` block");
     }
+
+    // Track if this is an `unsafe extern` block (Rust 1.82+)
+    let is_unsafe_extern = item.unsafety.is_some();
     let functions = item.items;
 
     let mut gen = quote!();
 
     for function in functions {
-        if let syn::ForeignItem::Fn(function) = function {
-            let name = &function.sig.ident;
-            let original_inputs = function.sig.inputs.clone();
-            let output = &function.sig.output;
-
-            let vis = &function.vis;
-            let generics = &function.sig.generics;
-            let mut into_inputs = vec![];
-            let mut converted_inputs = vec![];
-
-            let (output_is_ptr, converted_output) = match output {
-                syn::ReturnType::Default => (false, quote!(())),
-                syn::ReturnType::Type(_, _) => (true, quote!(u64)),
+        // Handle regular ForeignItem::Fn (normal fn or unsafe fn)
+        if let syn::ForeignItem::Fn(ref function) = function {
+            // In non-unsafe extern blocks, all functions are unsafe
+            // In unsafe extern blocks, unmarked fn is implicitly unsafe
+            let wrapper = generate_host_fn_wrapper(&namespace, function, false);
+            gen = quote! {
+                #gen
+                #wrapper
             };
+            continue;
+        }
 
-            for input in &original_inputs {
-                match input {
-                    syn::FnArg::Typed(t) => {
-                        let mut input = t.clone();
-                        input.ty = Box::new(syn::Type::Verbatim(quote!(u64)));
-                        converted_inputs.push(syn::FnArg::Typed(input));
-                        match &*t.pat {
-                            syn::Pat::Ident(i) => {
-                                into_inputs
-                                    .push(quote!(
-                                        extism_pdk::ManagedMemory::from(extism_pdk::ToMemory::to_memory(&&#i)?).offset()
-                                    ));
-                            }
-                            _ => panic!("invalid host function argument"),
-                        }
-                    }
-                    _ => panic!("self arguments are not permitted in host functions"),
+        // Handle ForeignItem::Verbatim which syn uses for `safe fn` in unsafe extern blocks
+        if let syn::ForeignItem::Verbatim(ref tokens) = function {
+            if is_unsafe_extern {
+                if let Some(wrapper) = parse_safe_fn_verbatim(&namespace, tokens) {
+                    gen = quote! {
+                        #gen
+                        #wrapper
+                    };
+                    continue;
                 }
             }
+            // If we can't parse it, ignore or panic
+            panic!("Unsupported item in extern block");
+        }
+    }
+
+    gen.into()
+}
+
+/// Generates a wrapper function for a host function
+fn generate_host_fn_wrapper(
+    namespace: &str,
+    function: &syn::ForeignItemFn,
+    is_safe_fn: bool,
+) -> proc_macro2::TokenStream {
+    let name = &function.sig.ident;
+    let original_inputs = function.sig.inputs.clone();
+    let output = &function.sig.output;
 
-            let impl_name = syn::Ident::new(&format!("{name}_impl"), name.span());
-            let link_name = name.to_string();
-            let link_name = link_name.as_str();
+    let vis = &function.vis;
+    let generics = &function.sig.generics;
+    let mut into_inputs = vec![];
+    let mut converted_inputs = vec![];
+
+    let (output_is_ptr, converted_output) = match output {
+        syn::ReturnType::Default => (false, quote!(())),
+        syn::ReturnType::Type(_, _) => (true, quote!(u64)),
+    };
 
-            let impl_block = quote! {
-                #[link(wasm_import_module = #namespace)]
-                extern "C" {
-                    #[link_name = #link_name]
-                    fn #impl_name(#(#converted_inputs),*) -> #converted_output;
+    for input in &original_inputs {
+        match input {
+            syn::FnArg::Typed(t) => {
+                let mut input = t.clone();
+                input.ty = Box::new(syn::Type::Verbatim(quote!(u64)));
+                converted_inputs.push(syn::FnArg::Typed(input));
+                match &*t.pat {
+                    syn::Pat::Ident(i) => {
+                        into_inputs
+                            .push(quote!(
+                                extism_pdk::ManagedMemory::from(extism_pdk::ToMemory::to_memory(&&#i)?).offset()
+                            ));
+                    }
+                    _ => panic!("invalid host function argument"),
                 }
-            };
+            }
+            _ => panic!("self arguments are not permitted in host functions"),
+        }
+    }
 
-            let output = match output {
-                syn::ReturnType::Default => quote!(()),
-                syn::ReturnType::Type(_, ty) => quote!(#ty),
-            };
+    let impl_name = syn::Ident::new(&format!("{name}_impl"), name.span());
+    let link_name = name.to_string();
+    let link_name = link_name.as_str();
 
-            if output_is_ptr {
-                gen = quote! {
-                    #gen
+    let impl_block = quote! {
+        #[link(wasm_import_module = #namespace)]
+        extern "C" {
+            #[link_name = #link_name]
+            fn #impl_name(#(#converted_inputs),*) -> #converted_output;
+        }
+    };
 
-                    #impl_block
+    let output = match output {
+        syn::ReturnType::Default => quote!(()),
+        syn::ReturnType::Type(_, ty) => quote!(#ty),
+    };
 
-                    #vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
-                        let res = extism_pdk::Memory::from(#impl_name(#(#into_inputs),*));
-                        <#output as extism_pdk::FromBytes>::from_bytes(&res.to_vec())
-                    }
-                };
-            } else {
-                gen = quote! {
-                    #gen
+    // For safe functions, we generate a safe wrapper that uses unsafe internally
+    // For unsafe functions, we generate an unsafe wrapper
+    if is_safe_fn {
+        if output_is_ptr {
+            quote! {
+                #impl_block
+
+                #vis fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
+                    // SAFETY: The caller of the macro has asserted this host function is safe
+                    // by marking it with `safe fn` in an `unsafe extern` block.
+                    let res = unsafe { extism_pdk::Memory::from(#impl_name(#(#into_inputs),*)) };
+                    <#output as extism_pdk::FromBytes>::from_bytes(&res.to_vec())
+                }
+            }
+        } else {
+            quote! {
+                #impl_block
+
+                #vis fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
+                    // SAFETY: The caller of the macro has asserted this host function is safe
+                    // by marking it with `safe fn` in an `unsafe extern` block.
+                    let res = unsafe { #impl_name(#(#into_inputs),*) };
+                    core::result::Result::Ok(res)
+                }
+            }
+        }
+    } else if output_is_ptr {
+        quote! {
+            #impl_block
 
-                    #impl_block
+            #vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
+                let res = extism_pdk::Memory::from(#impl_name(#(#into_inputs),*));
+                <#output as extism_pdk::FromBytes>::from_bytes(&res.to_vec())
+            }
+        }
+    } else {
+        quote! {
+            #impl_block
 
-                    #vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
-                        let res = #impl_name(#(#into_inputs),*);
-                        core::result::Result::Ok(res)
-                    }
-                };
+            #vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
+                let res = #impl_name(#(#into_inputs),*);
+                core::result::Result::Ok(res)
             }
         }
     }
+}
 
-    gen.into()
+/// Attempts to parse a `safe fn` from verbatim tokens
+/// Returns Some(wrapper) if successful, None if the tokens don't represent a safe fn
+fn parse_safe_fn_verbatim(namespace: &str, tokens: &proc_macro2::TokenStream) -> Option<proc_macro2::TokenStream> {
+    use syn::parse::{Parse, Parser};
+
+    // Try to parse: [visibility] safe fn name(args) [-> ReturnType];
+    let parser = |input: syn::parse::ParseStream| -> syn::Result<syn::ForeignItemFn> {
+        let attrs = input.call(syn::Attribute::parse_outer)?;
+        let vis: syn::Visibility = input.parse()?;
+
+        // Check for `safe` keyword
+        let safe_ident: syn::Ident = input.parse()?;
+        if safe_ident != "safe" {
+            return Err(syn::Error::new(safe_ident.span(), "expected `safe`"));
+        }
+
+        // Parse `fn`
+        let fn_token: syn::token::Fn = input.parse()?;
+
+        // Parse the rest of the signature
+        let ident: syn::Ident = input.parse()?;
+        let generics: syn::Generics = input.parse()?;
+
+        let content;
+        let paren_token = syn::parenthesized!(content in input);
+        let inputs = content.parse_terminated(syn::FnArg::parse, syn::Token![,])?;
+
+        let output: syn::ReturnType = input.parse()?;
+
+        let where_clause: Option<syn::WhereClause> = input.parse()?;
+        let mut generics = generics;
+        generics.where_clause = where_clause;
+
+        let semi_token: syn::Token![;] = input.parse()?;
+
+        Ok(syn::ForeignItemFn {
+            attrs,
+            vis,
+            sig: syn::Signature {
+                constness: None,
+                asyncness: None,
+                unsafety: None,
+                abi: None,
+                fn_token,
+                ident,
+                generics,
+                paren_token,
+                inputs,
+                variadic: None,
+                output,
+            },
+            semi_token,
+        })
+    };
+
+    match parser.parse2(tokens.clone()) {
+        Ok(function) => {
+            // It's a safe fn, generate a safe wrapper
+            Some(generate_host_fn_wrapper(namespace, &function, true))
+        }
+        Err(_) => None,
+    }
 }
diff --git a/examples/host_function_safe.rs b/examples/host_function_safe.rs
new file mode 100644
index 0000000..15f0ea0
--- /dev/null
+++ b/examples/host_function_safe.rs
@@ -0,0 +1,40 @@
+#![no_main]
+
+//! This example demonstrates the Rust 1.82+ `unsafe extern` syntax with `safe fn`.
+//!
+//! In `unsafe extern` blocks, you can mark individual functions as `safe` to indicate
+//! they are safe to call without an unsafe block. Functions without the `safe` qualifier
+//! are implicitly unsafe.
+
+use extism_pdk::*;
+use serde::{Deserialize, Serialize};
+
+const VOWELS: &[char] = &['a', 'A', 'e', 'E', 'i', 'I', 'o', 'O', 'u', 'U'];
+
+#[derive(Serialize, Deserialize, ToBytes, FromBytes)]
+#[encoding(Json)]
+struct Output {
+    pub count: i32,
+}
+
+// Using Rust 1.82+ unsafe extern syntax with safe fn
+#[host_fn("extism:host/user")]
+unsafe extern "ExtismHost" {
+    // This function is marked safe - the generated wrapper is safe to call
+    safe fn hello_world(count: Output) -> Output;
+}
+
+#[plugin_fn]
+pub fn count_vowels<'a>(input: String) -> FnResult<Output> {
+    let mut count = 0;
+    for ch in input.chars() {
+        if VOWELS.contains(&ch) {
+            count += 1;
+        }
+    }
+
+    let output = Output { count };
+    // No unsafe block needed because hello_world is marked as `safe fn`
+    let output = hello_world(output)?;
+    Ok(output)
+}
diff --git a/test/code.wasm b/test/code.wasm
index b9396b2..3111b80 100755
Binary files a/test/code.wasm and b/test/code.wasm differ
diff --git a/test/host_function.wasm b/test/host_function.wasm
index c45ed5e..d2b3c03 100755
Binary files a/test/host_function.wasm and b/test/host_function.wasm differ
diff --git a/test/host_function_host.wasm b/test/host_function_host.wasm
new file mode 100755
index 0000000..60c00e1
Binary files /dev/null and b/test/host_function_host.wasm differ
diff --git a/test/host_function_safe.wasm b/test/host_function_safe.wasm
new file mode 100755
index 0000000..fcfff4b
Binary files /dev/null and b/test/host_function_safe.wasm differ
diff --git a/test/http.wasm b/test/http.wasm
index 723263a..6e3f1f1 100755
Binary files a/test/http.wasm and b/test/http.wasm differ
diff --git a/test/http_headers.wasm b/test/http_headers.wasm
index b3f8539..8ac8e58 100755
Binary files a/test/http_headers.wasm and b/test/http_headers.wasm differ