Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@ jobs:

TEST=$(extism call test/host_function.wasm count_vowels --link extism:host/user=test/host_function_host.wasm --input "this is a test" --set-config='{"thing": "1", "a": "b"}')
echo $TEST | grep '"count":40'

# Test unsafe extern with safe fn (Rust 1.82+ syntax)
TEST=$(extism call test/host_function_safe.wasm count_vowels --link extism:host/user=test/host_function_host.wasm --input "this is a test" --set-config='{"thing": "1", "a": "b"}')
echo $TEST | grep '"count":40'
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ plugins:
cargo build --release --example http_headers
cargo build --release --example host_function
cargo build --release --example host_function_host
cargo build --release --example host_function_safe
cp target/wasm32-unknown-unknown/release/examples/count_vowels.wasm test/code.wasm
cp target/wasm32-unknown-unknown/release/examples/http.wasm test/http.wasm
cp target/wasm32-unknown-unknown/release/examples/http_headers.wasm test/http_headers.wasm
cp target/wasm32-unknown-unknown/release/examples/host_function.wasm test/host_function.wasm
cp target/wasm32-unknown-unknown/release/examples/host_function_host.wasm test/host_function_host.wasm
cp target/wasm32-unknown-unknown/release/examples/host_function_safe.wasm test/host_function_safe.wasm
266 changes: 202 additions & 64 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,23 @@ pub fn shared_fn(
}

/// `host_fn` is used to import a host function from an `extern` block
///
/// ## Rust 1.82+ / Edition 2024
///
/// Starting with Rust 1.82, extern blocks can be marked `unsafe` and functions
/// within can be marked `safe`. When using `unsafe extern "ExtismHost"` blocks,
/// you can use `safe fn` to indicate that a host function is safe to call:
///
/// ```rust,ignore
/// #[host_fn]
/// unsafe extern "ExtismHost" {
/// // Safe to call - generates a safe wrapper function
/// safe fn get_config(key: String) -> String;
///
/// // Unsafe to call (implicit) - generates unsafe wrapper
/// fn dangerous_operation(data: Vec<u8>) -> Vec<u8>;
/// }
/// ```
#[proc_macro_attribute]
pub fn host_fn(
attr: proc_macro::TokenStream,
Expand All @@ -264,91 +281,212 @@ pub fn host_fn(
};

let item = parse_macro_input!(item as ItemForeignMod);
if item.abi.name.is_none() || item.abi.name.unwrap().value() != "ExtismHost" {
panic!("Expected `extern \"ExtismHost\"` block");
if item.abi.name.is_none() || item.abi.name.as_ref().unwrap().value() != "ExtismHost" {
panic!("Expected `extern \"ExtismHost\"` or `unsafe extern \"ExtismHost\"` block");
}

// Track if this is an `unsafe extern` block (Rust 1.82+)
let is_unsafe_extern = item.unsafety.is_some();
let functions = item.items;

let mut gen = quote!();

for function in functions {
if let syn::ForeignItem::Fn(function) = function {
let name = &function.sig.ident;
let original_inputs = function.sig.inputs.clone();
let output = &function.sig.output;

let vis = &function.vis;
let generics = &function.sig.generics;
let mut into_inputs = vec![];
let mut converted_inputs = vec![];

let (output_is_ptr, converted_output) = match output {
syn::ReturnType::Default => (false, quote!(())),
syn::ReturnType::Type(_, _) => (true, quote!(u64)),
// Handle regular ForeignItem::Fn (normal fn or unsafe fn)
if let syn::ForeignItem::Fn(ref function) = function {
// In non-unsafe extern blocks, all functions are unsafe
// In unsafe extern blocks, unmarked fn is implicitly unsafe
let wrapper = generate_host_fn_wrapper(&namespace, function, false);
gen = quote! {
#gen
#wrapper
};
continue;
}

for input in &original_inputs {
match input {
syn::FnArg::Typed(t) => {
let mut input = t.clone();
input.ty = Box::new(syn::Type::Verbatim(quote!(u64)));
converted_inputs.push(syn::FnArg::Typed(input));
match &*t.pat {
syn::Pat::Ident(i) => {
into_inputs
.push(quote!(
extism_pdk::ManagedMemory::from(extism_pdk::ToMemory::to_memory(&&#i)?).offset()
));
}
_ => panic!("invalid host function argument"),
}
}
_ => panic!("self arguments are not permitted in host functions"),
// Handle ForeignItem::Verbatim which syn uses for `safe fn` in unsafe extern blocks
if let syn::ForeignItem::Verbatim(ref tokens) = function {
if is_unsafe_extern {
if let Some(wrapper) = parse_safe_fn_verbatim(&namespace, tokens) {
gen = quote! {
#gen
#wrapper
};
continue;
}
}
// If we can't parse it, ignore or panic
panic!("Unsupported item in extern block");
}
}

gen.into()
}

/// Generates a wrapper function for a host function
fn generate_host_fn_wrapper(
namespace: &str,
function: &syn::ForeignItemFn,
is_safe_fn: bool,
) -> proc_macro2::TokenStream {
let name = &function.sig.ident;
let original_inputs = function.sig.inputs.clone();
let output = &function.sig.output;

let impl_name = syn::Ident::new(&format!("{name}_impl"), name.span());
let link_name = name.to_string();
let link_name = link_name.as_str();
let vis = &function.vis;
let generics = &function.sig.generics;
let mut into_inputs = vec![];
let mut converted_inputs = vec![];

let (output_is_ptr, converted_output) = match output {
syn::ReturnType::Default => (false, quote!(())),
syn::ReturnType::Type(_, _) => (true, quote!(u64)),
};

let impl_block = quote! {
#[link(wasm_import_module = #namespace)]
extern "C" {
#[link_name = #link_name]
fn #impl_name(#(#converted_inputs),*) -> #converted_output;
for input in &original_inputs {
match input {
syn::FnArg::Typed(t) => {
let mut input = t.clone();
input.ty = Box::new(syn::Type::Verbatim(quote!(u64)));
converted_inputs.push(syn::FnArg::Typed(input));
match &*t.pat {
syn::Pat::Ident(i) => {
into_inputs
.push(quote!(
extism_pdk::ManagedMemory::from(extism_pdk::ToMemory::to_memory(&&#i)?).offset()
));
}
_ => panic!("invalid host function argument"),
}
};
}
_ => panic!("self arguments are not permitted in host functions"),
}
}

let output = match output {
syn::ReturnType::Default => quote!(()),
syn::ReturnType::Type(_, ty) => quote!(#ty),
};
let impl_name = syn::Ident::new(&format!("{name}_impl"), name.span());
let link_name = name.to_string();
let link_name = link_name.as_str();

if output_is_ptr {
gen = quote! {
#gen
let impl_block = quote! {
#[link(wasm_import_module = #namespace)]
extern "C" {
#[link_name = #link_name]
fn #impl_name(#(#converted_inputs),*) -> #converted_output;
}
};

#impl_block
let output = match output {
syn::ReturnType::Default => quote!(()),
syn::ReturnType::Type(_, ty) => quote!(#ty),
};

#vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
let res = extism_pdk::Memory::from(#impl_name(#(#into_inputs),*));
<#output as extism_pdk::FromBytes>::from_bytes(&res.to_vec())
}
};
} else {
gen = quote! {
#gen
// For safe functions, we generate a safe wrapper that uses unsafe internally
// For unsafe functions, we generate an unsafe wrapper
if is_safe_fn {
if output_is_ptr {
quote! {
#impl_block

#vis fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
// SAFETY: The caller of the macro has asserted this host function is safe
// by marking it with `safe fn` in an `unsafe extern` block.
let res = unsafe { extism_pdk::Memory::from(#impl_name(#(#into_inputs),*)) };
<#output as extism_pdk::FromBytes>::from_bytes(&res.to_vec())
}
}
} else {
quote! {
#impl_block

#vis fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
// SAFETY: The caller of the macro has asserted this host function is safe
// by marking it with `safe fn` in an `unsafe extern` block.
let res = unsafe { #impl_name(#(#into_inputs),*) };
core::result::Result::Ok(res)
}
}
}
} else if output_is_ptr {
quote! {
#impl_block

#impl_block
#vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
let res = extism_pdk::Memory::from(#impl_name(#(#into_inputs),*));
<#output as extism_pdk::FromBytes>::from_bytes(&res.to_vec())
}
}
} else {
quote! {
#impl_block

#vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
let res = #impl_name(#(#into_inputs),*);
core::result::Result::Ok(res)
}
};
#vis unsafe fn #name #generics (#original_inputs) -> core::result::Result<#output, extism_pdk::Error> {
let res = #impl_name(#(#into_inputs),*);
core::result::Result::Ok(res)
}
}
}
}

gen.into()
/// Attempts to parse a `safe fn` from verbatim tokens
/// Returns Some(wrapper) if successful, None if the tokens don't represent a safe fn
fn parse_safe_fn_verbatim(namespace: &str, tokens: &proc_macro2::TokenStream) -> Option<proc_macro2::TokenStream> {
use syn::parse::{Parse, Parser};

// Try to parse: [visibility] safe fn name(args) [-> ReturnType];
let parser = |input: syn::parse::ParseStream| -> syn::Result<syn::ForeignItemFn> {
let attrs = input.call(syn::Attribute::parse_outer)?;
let vis: syn::Visibility = input.parse()?;

// Check for `safe` keyword
let safe_ident: syn::Ident = input.parse()?;
if safe_ident != "safe" {
return Err(syn::Error::new(safe_ident.span(), "expected `safe`"));
}

// Parse `fn`
let fn_token: syn::token::Fn = input.parse()?;

// Parse the rest of the signature
let ident: syn::Ident = input.parse()?;
let generics: syn::Generics = input.parse()?;

let content;
let paren_token = syn::parenthesized!(content in input);
let inputs = content.parse_terminated(syn::FnArg::parse, syn::Token![,])?;

let output: syn::ReturnType = input.parse()?;

let where_clause: Option<syn::WhereClause> = input.parse()?;
let mut generics = generics;
generics.where_clause = where_clause;

let semi_token: syn::Token![;] = input.parse()?;

Ok(syn::ForeignItemFn {
attrs,
vis,
sig: syn::Signature {
constness: None,
asyncness: None,
unsafety: None,
abi: None,
fn_token,
ident,
generics,
paren_token,
inputs,
variadic: None,
output,
},
semi_token,
})
};

match parser.parse2(tokens.clone()) {
Ok(function) => {
// It's a safe fn, generate a safe wrapper
Some(generate_host_fn_wrapper(namespace, &function, true))
}
Err(_) => None,
}
}
40 changes: 40 additions & 0 deletions examples/host_function_safe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#![no_main]

//! This example demonstrates the Rust 1.82+ `unsafe extern` syntax with `safe fn`.
//!
//! In `unsafe extern` blocks, you can mark individual functions as `safe` to indicate
//! they are safe to call without an unsafe block. Functions without the `safe` qualifier
//! are implicitly unsafe.

use extism_pdk::*;
use serde::{Deserialize, Serialize};

const VOWELS: &[char] = &['a', 'A', 'e', 'E', 'i', 'I', 'o', 'O', 'u', 'U'];

#[derive(Serialize, Deserialize, ToBytes, FromBytes)]
#[encoding(Json)]
struct Output {
pub count: i32,
}

// Using Rust 1.82+ unsafe extern syntax with safe fn
#[host_fn("extism:host/user")]
unsafe extern "ExtismHost" {
// This function is marked safe - the generated wrapper is safe to call
safe fn hello_world(count: Output) -> Output;
}

#[plugin_fn]
pub fn count_vowels<'a>(input: String) -> FnResult<Output> {
let mut count = 0;
for ch in input.chars() {
if VOWELS.contains(&ch) {
count += 1;
}
}

let output = Output { count };
// No unsafe block needed because hello_world is marked as `safe fn`
let output = hello_world(output)?;
Ok(output)
}
Binary file modified test/code.wasm
Binary file not shown.
Binary file modified test/host_function.wasm
Binary file not shown.
Binary file added test/host_function_host.wasm
Binary file not shown.
Binary file added test/host_function_safe.wasm
Binary file not shown.
Binary file modified test/http.wasm
Binary file not shown.
Binary file modified test/http_headers.wasm
Binary file not shown.
Loading