From 907c1d8e387b386db81274529a62f29c82a414a9 Mon Sep 17 00:00:00 2001 From: BDaws04 Date: Fri, 23 Jan 2026 00:38:42 +0000 Subject: [PATCH 1/6] add init refactor (more aggressive inlining, removed ok_or_else and replaced w/ ok_or and hoist scanner.skip constants) --- sje_derive/src/lib.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/sje_derive/src/lib.rs b/sje_derive/src/lib.rs index cad5de0..f2c6b11 100644 --- a/sje_derive/src/lib.rs +++ b/sje_derive/src/lib.rs @@ -180,6 +180,7 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA let mut key_len = field_name.to_string().len(); let mut val_len = None; let mut ty_override = None; + let skip_const = format_ident!("SKIP_{}", field_name.to_string().to_uppercase()); if let Some(sje_attr) = field.attrs.iter().find(|attr| attr.path().is_ident("sje")) { let sje_field = sje_attr.parse_args::().expect("unable to parse"); if let Some(name) = sje_field.name { @@ -202,8 +203,9 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA let next = Ident::new(&format!("next_{}_with_known_len", type_str), field_name.span()); let field_name_string = field_name.to_string(); quote! { - scanner.skip(#key_len); - let (offset, len) = scanner.#next(#known_len).ok_or_else(|| sje::error::Error::MissingField(#field_name_string))?; + const #skip_const: usize = #key_len; + scanner.skip(#skip_const); + let (offset, len) = scanner.#next(#known_len).ok_or(sje::error::Error::MissingField(#field_name_string))?; let #field_name = sje::LazyField::from_bytes(unsafe { bytes.get_unchecked(offset..offset + len) }); } } @@ -212,14 +214,16 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA let field_name_string = field_name.to_string(); if type_str == "array" { quote! { - scanner.skip(#key_len); - let (offset, len, count) = scanner.#next().ok_or_else(|| sje::error::Error::MissingField(#field_name_string))?; + const #skip_const: usize = #key_len; + scanner.skip(#skip_const); + let (offset, len, count) = scanner.#next().ok_or(sje::error::Error::MissingField(#field_name_string))?; let #field_name = (unsafe { bytes.get_unchecked(offset..offset + len) }, count); } } else { quote! { - scanner.skip(#key_len); - let (offset, len) = scanner.#next().ok_or_else(|| sje::error::Error::MissingField(#field_name_string))?; + const #skip_const: usize = #key_len; + scanner.skip(#skip_const); + let (offset, len) = scanner.#next().ok_or(sje::error::Error::MissingField(#field_name_string))?; let #field_name = sje::LazyField::from_bytes(unsafe { bytes.get_unchecked(offset..offset + len) }); } } @@ -279,15 +283,15 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA if path.path.segments.last().map(|seg| seg.ident == "Vec").unwrap_or(false) { let array_count = Ident::new(&format!("{}_count", field_name.as_ref().unwrap()), field_name.span()); generated.extend(quote! { - #[inline] + #[inline(always)] pub const fn #as_slice(&self) -> &[u8] { self.#field_name.0 } - #[inline] + #[inline(always)] pub const fn #as_str(&self) -> &str { unsafe { std::str::from_utf8_unchecked(self.#as_slice()) } } - #[inline] + #[inline(always)] pub const fn #array_count(&self) -> usize { self.#field_name.1 } @@ -296,15 +300,15 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA let as_lazy_field = Ident::new(&format!("{}_as_lazy_field", field_name.as_ref().unwrap()), field_name.span()); generated.extend(quote! { - #[inline] + #[inline(always)] pub const fn #as_slice(&self) -> &[u8] { self.#field_name.as_slice() } - #[inline] + #[inline(always)] pub const fn #as_str(&self) -> &str { self.#field_name.as_str() } - #[inline] + #[inline(always)] pub const fn #as_lazy_field(&self) -> &sje::LazyField<'a, #field_type> { &self.#field_name } @@ -393,7 +397,7 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA } impl ExactSizeIterator for #iterator_name<'_> { - #[inline] + #[inline(always)] fn len(&self) -> usize { self.remaining } From 7dba3761cf429e0e7d1a60eb8133def29f243e8d Mon Sep 17 00:00:00 2001 From: BDaws04 Date: Fri, 23 Jan 2026 01:04:09 +0000 Subject: [PATCH 2/6] change resolve_type() to return an enum --- sje_derive/src/lib.rs | 45 +++++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/sje_derive/src/lib.rs b/sje_derive/src/lib.rs index f2c6b11..9a28887 100644 --- a/sje_derive/src/lib.rs +++ b/sje_derive/src/lib.rs @@ -200,7 +200,7 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA key_len += 4; match val_len { Some(known_len) => { - let next = Ident::new(&format!("next_{}_with_known_len", type_str), field_name.span()); + let next = Ident::new(&format!("next_{}_with_known_len", type_str.as_str()), field_name.span()); let field_name_string = field_name.to_string(); quote! { const #skip_const: usize = #key_len; @@ -210,9 +210,9 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA } } None => { - let next = Ident::new(&format!("next_{}", type_str), field_name.span()); + let next = Ident::new(&format!("next_{}", type_str.as_str()), field_name.span()); let field_name_string = field_name.to_string(); - if type_str == "array" { + if type_str == JsonKind::Array { quote! { const #skip_const: usize = #key_len; scanner.skip(#skip_const); @@ -512,9 +512,34 @@ fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeA generated.into() } -fn resolve_type(ty: &Type, ty_override: Option) -> syn::Result<&'static str> { +#[derive(PartialEq, Eq, Clone, Debug)] +enum JsonKind { + Number, + String, + Boolean, + Array, +} +impl JsonKind { + #[inline(always)] + const fn as_str(&self) -> &'static str { + match self { + JsonKind::Number => "number", + JsonKind::String => "string", + JsonKind::Boolean => "boolean", + JsonKind::Array => "array", + } + } +} + +fn resolve_type(ty: &Type, ty_override: Option) -> syn::Result { if let Some(ty_override) = ty_override { - return Ok(ty_override.leak()); + return match ty_override.as_str() { + "number" => Ok(JsonKind::Number), + "string" => Ok(JsonKind::String), + "boolean" => Ok(JsonKind::Boolean), + "array" => Ok(JsonKind::Array), + _ => Err(Error::new(Span::call_site(), format!("unknown ty override `{}`", ty_override))), + }; } match ty { @@ -524,11 +549,11 @@ fn resolve_type(ty: &Type, ty_override: Option) -> syn::Result<&'static match ident.as_str() { // Primitive number types "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" | "f32" | "f64" => { - Ok("number") + Ok(JsonKind::Number) } - "String" => Ok("string"), - "bool" => Ok("boolean"), - "Vec" => Ok("array"), + "String" => Ok(JsonKind::String), + "bool" => Ok(JsonKind::Boolean), + "Vec" => Ok(JsonKind::Array), _ => Err(Error::new(Span::call_site(), "Only primitives, String, and Vec are allowed")), } } @@ -668,7 +693,7 @@ mod tests { let result = resolve_type(&parsed_ty, ty_override.map(String::from)); match (result.clone(), expected) { - (Ok(actual), Ok(expected_str)) => assert_eq!(actual, expected_str), + (Ok(actual), Ok(expected_str)) => assert_eq!(actual.as_str(), expected_str), (Err(err), Err(expected_err)) => assert!(err.to_string().contains(expected_err)), _ => panic!("Unexpected result: {:?}", result), } From 40eee479aff2c5218dd1e3d8fccb09e0b512ea60 Mon Sep 17 00:00:00 2001 From: BDaws04 Date: Tue, 27 Jan 2026 23:16:38 +0000 Subject: [PATCH 3/6] add changes --- sje_derive/src/attribute.rs | 118 +++++++ sje_derive/src/enums.rs | 18 + sje_derive/src/lib.rs | 658 +----------------------------------- sje_derive/src/sje_types.rs | 78 +++++ sje_derive/src/structs.rs | 451 ++++++++++++++++++++++++ 5 files changed, 677 insertions(+), 646 deletions(-) create mode 100644 sje_derive/src/attribute.rs create mode 100644 sje_derive/src/enums.rs create mode 100644 sje_derive/src/sje_types.rs create mode 100644 sje_derive/src/structs.rs diff --git a/sje_derive/src/attribute.rs b/sje_derive/src/attribute.rs new file mode 100644 index 0000000..763689b --- /dev/null +++ b/sje_derive/src/attribute.rs @@ -0,0 +1,118 @@ +use syn::{parse::{Parse, ParseStream}, Ident, LitInt, LitStr, LitBool, Token}; +use std::str::FromStr; +use proc_macro2::Span; + +#[derive(Debug, Copy, Clone)] +#[allow(dead_code)] +pub enum SjeType { + Object, + Array, + Tuple, + Union, +} + +impl FromStr for SjeType { + type Err = syn::Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "object" => Ok(SjeType::Object), + "array" => Ok(SjeType::Array), + "tuple" => Ok(SjeType::Tuple), + "union" => Ok(SjeType::Union), + _ => Err(syn::Error::new(Span::call_site(), "expected 'object', 'array', 'tuple' or 'union'")), + } + } +} + +#[derive(Copy, Clone)] +#[allow(dead_code)] +pub struct SjeAttribute { + pub sje_type: SjeType, +} + +impl Parse for SjeAttribute { + fn parse(input: ParseStream) -> syn::Result { + let ident: Ident = input.parse()?; + let sje_type = ident.to_string().parse()?; + Ok(SjeAttribute { sje_type }) + } +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct SjeFieldAttribute { + #[allow(dead_code)] + /// value length + pub len: Option, + /// field name override + pub name: Option, + /// json type name override + pub ty: Option, + /// additional conversion method + pub also_as: Option, + /// offset at which value begins + pub offset: usize, + pub decoder: bool, +} + +impl Parse for SjeFieldAttribute { + fn parse(input: ParseStream) -> syn::Result { + let mut len = None; + let mut name = None; + let mut ty = None; + let mut also_as = None; + let mut offset = 0; + let mut decoder = false; + + while !input.is_empty() { + let lookahead = input.lookahead1(); + if lookahead.peek(Ident) { + let ident: Ident = input.parse()?; + if ident == "len" { + input.parse::()?; + let len_lit: LitInt = input.parse()?; + len = Some(len_lit.base10_parse()?); + } else if ident == "rename" { + input.parse::()?; + let ref_lit: LitStr = input.parse()?; + name = Some(ref_lit.value()); + } else if ident == "ty" { + input.parse::()?; + let ty_lit: LitStr = input.parse()?; + ty = Some(ty_lit.value()); + } else if ident == "also_as" { + input.parse::()?; + let as_lit: LitStr = input.parse()?; + also_as = Some(as_lit.value()); + } else if ident == "offset" { + input.parse::()?; + let offset_lit: LitInt = input.parse()?; + offset = offset_lit.base10_parse()?; + } else if ident == "decoder" { + input.parse::()?; + let decoder_lit: LitBool = input.parse()?; + decoder = decoder_lit.value(); + } else { + return Err(syn::Error::new_spanned(ident, "expected ['len' | 'rename' | 'ty']")); + } + } else { + return Err(lookahead.error()); + } + + // Optional comma + if input.peek(Token![,]) { + input.parse::()?; + } + } + + Ok(SjeFieldAttribute { + len, + name, + ty, + also_as, + offset, + decoder, + }) + } +} \ No newline at end of file diff --git a/sje_derive/src/enums.rs b/sje_derive/src/enums.rs new file mode 100644 index 0000000..3f2742a --- /dev/null +++ b/sje_derive/src/enums.rs @@ -0,0 +1,18 @@ +use proc_macro2::TokenStream; +use syn::DataEnum; +use quote::quote; + +pub fn handle_enum(name: &syn::Ident, data_enum: DataEnum) -> TokenStream { + let variants = data_enum.variants.iter().map(|v| &v.ident); + let generated = quote! { + impl From<&[u8]> for #name { + fn from(bytes: &[u8]) -> Self { + match std::str::from_utf8(bytes).unwrap() { + #( stringify!(#variants) => #name::#variants, )* + _ => panic!("unrecognized side"), + } + } + } + }; + generated.into() +} \ No newline at end of file diff --git a/sje_derive/src/lib.rs b/sje_derive/src/lib.rs index 9a28887..ecbead0 100644 --- a/sje_derive/src/lib.rs +++ b/sje_derive/src/lib.rs @@ -1,126 +1,18 @@ -use heck::{ToSnakeCase, ToUpperCamelCase}; use proc_macro::TokenStream; -use proc_macro2::Span; -use quote::{format_ident, quote}; -use std::str::FromStr; -use syn::parse::{Parse, ParseStream}; -use syn::spanned::Spanned; use syn::{ - Data, DataEnum, DataStruct, DeriveInput, Error, Fields, Ident, LitBool, LitInt, LitStr, PathArguments, PathSegment, - Token, Type, TypePath, parse_macro_input, + Data,DeriveInput, Error,parse_macro_input }; -#[derive(Debug, Copy, Clone)] -enum SjeType { - Object, - Array, - Tuple, - Union, -} - -impl FromStr for SjeType { - type Err = syn::Error; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "object" => Ok(SjeType::Object), - "array" => Ok(SjeType::Array), - "tuple" => Ok(SjeType::Tuple), - "union" => Ok(SjeType::Union), - _ => Err(syn::Error::new(Span::call_site(), "expected 'object', 'array', 'tuple' or 'union'")), - } - } -} - -#[derive(Copy, Clone)] -struct SjeAttribute { - sje_type: SjeType, -} +mod attribute; +mod sje_types; +mod structs; +mod enums; -impl Parse for SjeAttribute { - fn parse(input: ParseStream) -> syn::Result { - let ident: Ident = input.parse()?; - let sje_type = ident.to_string().parse()?; - Ok(SjeAttribute { sje_type }) - } -} +use crate::attribute::*; +use crate::sje_types::*; +use crate::structs::*; +use crate::enums::*; -#[derive(Debug, Clone)] -struct SjeFieldAttribute { - #[allow(dead_code)] - /// value length - len: Option, - /// field name override - name: Option, - /// json type name override - ty: Option, - /// additional conversion method - also_as: Option, - /// offset at which value begins - offset: usize, - decoder: bool, -} - -impl Parse for SjeFieldAttribute { - fn parse(input: ParseStream) -> syn::Result { - let mut len = None; - let mut name = None; - let mut ty = None; - let mut also_as = None; - let mut offset = 0; - let mut decoder = false; - - while !input.is_empty() { - let lookahead = input.lookahead1(); - if lookahead.peek(Ident) { - let ident: Ident = input.parse()?; - if ident == "len" { - input.parse::()?; - let len_lit: LitInt = input.parse()?; - len = Some(len_lit.base10_parse()?); - } else if ident == "rename" { - input.parse::()?; - let ref_lit: LitStr = input.parse()?; - name = Some(ref_lit.value()); - } else if ident == "ty" { - input.parse::()?; - let ty_lit: LitStr = input.parse()?; - ty = Some(ty_lit.value()); - } else if ident == "also_as" { - input.parse::()?; - let as_lit: LitStr = input.parse()?; - also_as = Some(as_lit.value()); - } else if ident == "offset" { - input.parse::()?; - let offset_lit: LitInt = input.parse()?; - offset = offset_lit.base10_parse()?; - } else if ident == "decoder" { - input.parse::()?; - let decoder_lit: LitBool = input.parse()?; - decoder = decoder_lit.value(); - } else { - return Err(syn::Error::new_spanned(ident, "expected ['len' | 'rename' | 'ty']")); - } - } else { - return Err(lookahead.error()); - } - - // Optional comma - if input.peek(Token![,]) { - input.parse::()?; - } - } - - Ok(SjeFieldAttribute { - len, - name, - ty, - also_as, - offset, - decoder, - }) - } -} #[proc_macro_derive(Decoder, attributes(sje))] pub fn decoder_derive(input: TokenStream) -> TokenStream { @@ -135,7 +27,7 @@ pub fn decoder_derive(input: TokenStream) -> TokenStream { .expect("Failed to parse 'sje' attribute"); match ast.data { - Data::Enum(data_enum) => handle_enum(&ast.ident, data_enum), + Data::Enum(data_enum) => handle_enum(&ast.ident, data_enum).into(), Data::Struct(data_struct) => { handle_struct(&ast.ident, data_struct, sje_attr.expect("sje attribute must be present")) } @@ -143,538 +35,12 @@ pub fn decoder_derive(input: TokenStream) -> TokenStream { } } -fn handle_enum(name: &syn::Ident, data_enum: DataEnum) -> TokenStream { - let variants = data_enum.variants.iter().map(|v| &v.ident); - let generated = quote! { - impl From<&[u8]> for #name { - fn from(bytes: &[u8]) -> Self { - match std::str::from_utf8(bytes).unwrap() { - #( stringify!(#variants) => #name::#variants, )* - _ => panic!("unrecognized side"), - } - } - } - }; - generated.into() -} - -fn handle_struct(name: &syn::Ident, data_struct: DataStruct, sje_attr: SjeAttribute) -> TokenStream { - match sje_attr.sje_type { - SjeType::Object => handle_sje_object(name, data_struct, sje_attr), - SjeType::Array => unimplemented!("array not supported"), - SjeType::Tuple => unimplemented!("tuple not supported"), - SjeType::Union => unimplemented!("union not supported"), - } -} - -fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeAttribute) -> TokenStream { - let struct_name = Ident::new(&format!("{}Decoder", name), name.span()); - - let fields = match data_struct.fields { - Fields::Named(fields) => fields.named, - _ => return quote! { compile_error!("Decoder can only be derived for structs with named fields."); }.into(), - }; - - let field_initializations = fields.iter().map(|field| { - let field_name = field.ident.as_ref().unwrap(); - let mut key_len = field_name.to_string().len(); - let mut val_len = None; - let mut ty_override = None; - let skip_const = format_ident!("SKIP_{}", field_name.to_string().to_uppercase()); - if let Some(sje_attr) = field.attrs.iter().find(|attr| attr.path().is_ident("sje")) { - let sje_field = sje_attr.parse_args::().expect("unable to parse"); - if let Some(name) = sje_field.name { - key_len = name.len(); - } - if let Some(len) = sje_field.len { - val_len = Some(len); - } - if let Some(ty) = sje_field.ty { - ty_override = Some(ty); - } - key_len += sje_field.offset; - } - - match resolve_type(&field.ty, ty_override) { - Ok(type_str) => { - key_len += 4; - match val_len { - Some(known_len) => { - let next = Ident::new(&format!("next_{}_with_known_len", type_str.as_str()), field_name.span()); - let field_name_string = field_name.to_string(); - quote! { - const #skip_const: usize = #key_len; - scanner.skip(#skip_const); - let (offset, len) = scanner.#next(#known_len).ok_or(sje::error::Error::MissingField(#field_name_string))?; - let #field_name = sje::LazyField::from_bytes(unsafe { bytes.get_unchecked(offset..offset + len) }); - } - } - None => { - let next = Ident::new(&format!("next_{}", type_str.as_str()), field_name.span()); - let field_name_string = field_name.to_string(); - if type_str == JsonKind::Array { - quote! { - const #skip_const: usize = #key_len; - scanner.skip(#skip_const); - let (offset, len, count) = scanner.#next().ok_or(sje::error::Error::MissingField(#field_name_string))?; - let #field_name = (unsafe { bytes.get_unchecked(offset..offset + len) }, count); - } - } else { - quote! { - const #skip_const: usize = #key_len; - scanner.skip(#skip_const); - let (offset, len) = scanner.#next().ok_or(sje::error::Error::MissingField(#field_name_string))?; - let #field_name = sje::LazyField::from_bytes(unsafe { bytes.get_unchecked(offset..offset + len) }); - } - } - } - } - } - Err(e) => e.to_compile_error(), - } - }); - - let field_assignments = fields.iter().map(|field| { - let field_name = &field.ident; - quote! { - #field_name, - } - }); - - let from_field_assignments = fields.iter().map(|field| { - let field_name = &field.ident; - quote! { - #field_name: decoder.#field_name().into(), - } - }); - - let from_impl = quote! { - impl From<#struct_name<'_>> for #name { - fn from(decoder: #struct_name<'_>) -> Self { - Self { - #(#from_field_assignments)* - } - } - } - }; - - let decode_impl = quote! { - impl <'a> #struct_name<'a> { - #[inline] - pub fn decode(bytes: &'a [u8]) -> Result { - let mut scanner = sje::scanner::JsonScanner::wrap(bytes); - #(#field_initializations)* - Ok(Self { - #(#field_assignments)* - }) - } - } - }; - - let accessor_methods = fields.iter().map(|field| { - let field_name = &field.ident; - let as_slice = Ident::new(&format!("{}_as_slice", field_name.as_ref().unwrap()), field_name.span()); - let as_str = Ident::new(&format!("{}_as_str", field_name.as_ref().unwrap()), field_name.span()); - - let mut generated = quote! {}; - - let field_type = &field.ty; - if let syn::Type::Path(path) = field_type { - if path.path.segments.last().map(|seg| seg.ident == "Vec").unwrap_or(false) { - let array_count = Ident::new(&format!("{}_count", field_name.as_ref().unwrap()), field_name.span()); - generated.extend(quote! { - #[inline(always)] - pub const fn #as_slice(&self) -> &[u8] { - self.#field_name.0 - } - #[inline(always)] - pub const fn #as_str(&self) -> &str { - unsafe { std::str::from_utf8_unchecked(self.#as_slice()) } - } - #[inline(always)] - pub const fn #array_count(&self) -> usize { - self.#field_name.1 - } - }) - } else { - let as_lazy_field = - Ident::new(&format!("{}_as_lazy_field", field_name.as_ref().unwrap()), field_name.span()); - generated.extend(quote! { - #[inline(always)] - pub const fn #as_slice(&self) -> &[u8] { - self.#field_name.as_slice() - } - #[inline(always)] - pub const fn #as_str(&self) -> &str { - self.#field_name.as_str() - } - #[inline(always)] - pub const fn #as_lazy_field(&self) -> &sje::LazyField<'a, #field_type> { - &self.#field_name - } - }) - } - } - - if let Some(sje_attr) = field.attrs.iter().find(|attr| attr.path().is_ident("sje")) { - let sje_field = sje_attr.parse_args::().expect("unable to parse"); - if let Some(also_as) = sje_field.also_as { - let type_name = also_as.split("::").last().map(|s| s.to_string()).unwrap(); - let type_name_ident: syn::Path = syn::parse_str(&also_as).unwrap(); - let also_as = Ident::new( - &format!("{}_as_{}", field_name.as_ref().unwrap(), type_name.to_snake_case()), - field_name.span(), - ); - generated.extend(quote! { - - #[inline] - pub fn #also_as(&self) -> #type_name_ident { - self.#as_str().parse().unwrap() - } - }); - } - } - - generated - }); - - let new_fields = fields.iter().map(|field| { - let field_name = &field.ident; - let field_type = &field.ty; - if let syn::Type::Path(path) = field_type { - if path.path.segments.last().map(|seg| seg.ident == "Vec").unwrap_or(false) { - quote! { - #field_name: (&'a [u8], usize), - } - } else { - quote! { - #field_name: sje::LazyField<'a, #field_type>, - } - } - } else { - quote! {} - } - }); - - let iterators = fields.iter().map(|field| { - let mut decoder = false; - if let Some(sje_attr) = field.attrs.iter().find(|attr| attr.path().is_ident("sje")) { - let sje_field = sje_attr.parse_args::().expect("unable to parse"); - decoder = sje_field.decoder - } - - let field_name = &field.ident; - let field_type = &field.ty; - - if let syn::Type::Path(path) = field_type { - if path.path.segments.last().map(|seg| seg.ident == "Vec").unwrap_or(false) { - if let Some(segment) = path.path.segments.last() { - if let PathArguments::AngleBracketed(args) = &segment.arguments { - if let Some(syn::GenericArgument::Type(arg_type)) = args.args.first() { - let array_struct_name = - format_ident!("{}", field_name.as_ref().unwrap().to_string().to_upper_camel_case()); - let array_fn_name = format_ident!("{}", field_name.as_ref().unwrap().to_string()); - let iterator_name = - format_ident!("{}Iter", field_name.as_ref().unwrap().to_string().to_upper_camel_case()); - let next_impl = iterator_next_impl(arg_type, decoder); - - let mut code = quote! { - #[derive(Debug)] - pub struct #array_struct_name<'a> { - bytes: &'a [u8], - remaining: usize, - } - - impl #struct_name<'_> { - #[inline] - pub const fn #array_fn_name(&self) -> #array_struct_name { - #array_struct_name { bytes: self.#array_fn_name.0, remaining: self.#array_fn_name.1 } - } - } - pub struct #iterator_name<'a> { - scanner: sje::scanner::JsonScanner<'a>, - remaining: usize, - } - impl ExactSizeIterator for #iterator_name<'_> { - - #[inline(always)] - fn len(&self) -> usize { - self.remaining - } - } - }; - - if decoder { - let arg_type_decoder = format_ident!("{}Decoder", type_to_ident(arg_type).unwrap()); - code.extend(quote! { - impl <'a> From<#array_struct_name<'a>> for Vec<#arg_type_decoder<'a>> { - fn from(value: #array_struct_name<'a>) -> Self { - value.into_iter().collect() - } - } - - impl<'a> IntoIterator for #array_struct_name<'a> { - type Item = #arg_type_decoder<'a>; - type IntoIter = #iterator_name<'a>; - fn into_iter(self) -> Self::IntoIter { - #iterator_name { - scanner: sje::scanner::JsonScanner::wrap(self.bytes), - remaining: self.remaining - } - } - } - impl <'a> Iterator for #iterator_name<'a> { - type Item = #arg_type_decoder<'a>; - #[inline] - fn next(&mut self) -> Option { - #next_impl - } - #[inline] - fn size_hint(&self) -> (usize, Option) { - (self.remaining, Some(self.remaining)) - } - } - impl From<#array_struct_name<'_>> for Vec<#arg_type> { - fn from(value: #array_struct_name<'_>) -> Self { - value.into_iter().map(|decoder| decoder.into()).collect() - } - } - }); - } else { - code.extend(quote! { - impl From<#array_struct_name<'_>> for Vec<#arg_type> { - fn from(value: #array_struct_name) -> Self { - value.into_iter().collect() - } - } - - impl<'a> IntoIterator for #array_struct_name<'a> { - type Item = #arg_type; - type IntoIter = #iterator_name<'a>; - - fn into_iter(self) -> Self::IntoIter { - #iterator_name { - scanner: sje::scanner::JsonScanner::wrap(self.bytes), - remaining: self.remaining - } - } - } - - impl Iterator for #iterator_name<'_> { - type Item = #arg_type; - - #[inline] - fn next(&mut self) -> Option { - #next_impl - } - #[inline] - fn size_hint(&self) -> (usize, Option) { - (self.remaining, Some(self.remaining)) - } - } - }); - } - return code; - } - } - } - } else { - return quote! { - impl #struct_name<'_> { - #[inline] - pub fn #field_name(&self) -> #field_type { - self.#field_name.get().unwrap() - } - } - }; - } - } - quote! {} - }); - - let generated = quote! { - #[derive(Debug)] - pub struct #struct_name<'a> { - #(#new_fields)* - } - - #from_impl - - #decode_impl - - impl <'a> #struct_name<'a> { - #(#accessor_methods)* - } - - #(#iterators)* - }; - - generated.into() -} - -#[derive(PartialEq, Eq, Clone, Debug)] -enum JsonKind { - Number, - String, - Boolean, - Array, -} -impl JsonKind { - #[inline(always)] - const fn as_str(&self) -> &'static str { - match self { - JsonKind::Number => "number", - JsonKind::String => "string", - JsonKind::Boolean => "boolean", - JsonKind::Array => "array", - } - } -} - -fn resolve_type(ty: &Type, ty_override: Option) -> syn::Result { - if let Some(ty_override) = ty_override { - return match ty_override.as_str() { - "number" => Ok(JsonKind::Number), - "string" => Ok(JsonKind::String), - "boolean" => Ok(JsonKind::Boolean), - "array" => Ok(JsonKind::Array), - _ => Err(Error::new(Span::call_site(), format!("unknown ty override `{}`", ty_override))), - }; - } - - match ty { - Type::Path(type_path) => { - let ident = type_path.path.segments.last().unwrap().ident.to_string(); - - match ident.as_str() { - // Primitive number types - "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" | "f32" | "f64" => { - Ok(JsonKind::Number) - } - "String" => Ok(JsonKind::String), - "bool" => Ok(JsonKind::Boolean), - "Vec" => Ok(JsonKind::Array), - _ => Err(Error::new(Span::call_site(), "Only primitives, String, and Vec are allowed")), - } - } - _ => Err(Error::new(Span::call_site(), "Unsupported type: Only primitives, String, and Vec are allowed")), - } -} - -fn iterator_next_impl(ty: &Type, decoder: bool) -> proc_macro2::TokenStream { - match ty { - Type::Path(path) => { - let mut code = quote! {}; - let mut p = path.path.clone(); - if let Some(last) = p.segments.last_mut() { - // last.ident = format_ident!("{}Decoder", last.ident); - let ident = match decoder { - true => format_ident!("{}Decoder", last.ident.clone()), - false => format_ident!("{}", last.ident.clone()), - }; - - let last = match decoder { - true => quote! { - Some(#ident::decode(bytes).unwrap()) - }, - false => quote! { - let s = unsafe { std::str::from_utf8_unchecked(bytes) }; - Some(#ident::from_str(s).unwrap()) - }, - }; - - code.extend(quote! { - if self.scanner.position() + 1 == self.scanner.bytes().len() { - return None; - } - self.scanner.skip(1); - let (offset, len) = self.scanner.next_object()?; - self.remaining -= 1; - - let bytes = &self.scanner.bytes()[offset..offset + len]; - let bytes = unsafe { std::slice::from_raw_parts(bytes.as_ptr(), bytes.len()) }; - #last - // Some(#ident::decode(bytes).unwrap()) - }); - } - code - } - Type::Tuple(tuple) => { - // Generate code for processing each element - let mut code = quote! {}; - let mut tuple_values = Vec::new(); - - code.extend(quote! { - if self.scanner.position() + 1 == self.scanner.bytes().len() { - return None; - } - self.scanner.skip(1); - let (offset, len) = self.scanner.next_tuple()?; - let mut tuple_scanner = unsafe { sje::scanner::JsonScanner::wrap(self.scanner.bytes().get_unchecked(offset..offset + len)) }; - }); - - // Iterate over the tuple elements and generate code for each element - for (i, _) in tuple.elems.iter().enumerate() { - // Dynamically generate a variable name based on the index - let var_name = format_ident!("val_{i}"); - - // Generate the code for processing this element - code.extend(quote! { - tuple_scanner.skip(1); - let (offset, len) = tuple_scanner.next_string()?; - let str = unsafe { std::str::from_utf8_unchecked(tuple_scanner.bytes().get_unchecked(offset..offset + len)) }; - let #var_name = str.parse().unwrap(); - }); - - // Add the variable to the tuple values vector for dynamic construction - tuple_values.push(quote! { #var_name }); - } - - // Combine the generated code and the `Some(...)` expression - code.extend(quote! { - self.remaining -= 1; - Some((#(#tuple_values),*)) - }); - - code - } - _ => { - // If it's not a tuple, return an empty TokenStream - quote! {} - } - } -} - -#[allow(dead_code)] -fn is_integer_type(ty: &Type) -> bool { - if let Type::Path(type_path) = ty { - if let Some(PathSegment { ident, .. }) = type_path.path.segments.last() { - return matches!( - ident.to_string().as_str(), - "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32" | "i64" | "isize" - ); - } - } - false -} - -/// Try to extract the bare `Ident` from a `&Type::Path`. -fn type_to_ident(ty: &Type) -> Option { - if let Type::Path(TypePath { qself: None, path }) = ty { - // if it's something like `Foo` or `my::crate::Bar`, - // `.segments.last()` is the `Bar` segment - path.segments.last().map(|seg| seg.ident.clone()) - } else { - None - } -} - #[cfg(test)] mod tests { - use syn::{Attribute, parse_quote, parse_str}; + use syn::{Attribute, parse_quote, parse_str, Type}; + struct Price; use super::*; - #[test] fn should_parse_sje_field_attribute() { let attr: Attribute = parse_quote! { diff --git a/sje_derive/src/sje_types.rs b/sje_derive/src/sje_types.rs new file mode 100644 index 0000000..1c7fde2 --- /dev/null +++ b/sje_derive/src/sje_types.rs @@ -0,0 +1,78 @@ +use proc_macro2::Span; +use syn::{Type, Ident, TypePath, PathSegment}; +use crate::Error; + +// Json type mapping + Type resolution + +#[derive(PartialEq, Eq, Clone, Debug)] +pub enum JsonKind { + Number, + String, + Boolean, + Array, +} +impl JsonKind { + #[inline(always)] + pub const fn as_str(&self) -> &'static str { + match self { + JsonKind::Number => "number", + JsonKind::String => "string", + JsonKind::Boolean => "boolean", + JsonKind::Array => "array", + } + } +} + +pub fn resolve_type(ty: &Type, ty_override: Option) -> syn::Result { + if let Some(ty_override) = ty_override { + return match ty_override.as_str() { + "number" => Ok(JsonKind::Number), + "string" => Ok(JsonKind::String), + "boolean" => Ok(JsonKind::Boolean), + "array" => Ok(JsonKind::Array), + _ => Err(Error::new(Span::call_site(), format!("unknown ty override `{}`", ty_override))), + }; + } + + match ty { + Type::Path(type_path) => { + let ident = type_path.path.segments.last().unwrap().ident.to_string(); + + match ident.as_str() { + // Primitive number types + "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128" | "f32" | "f64" => { + Ok(JsonKind::Number) + } + "String" => Ok(JsonKind::String), + "bool" => Ok(JsonKind::Boolean), + "Vec" => Ok(JsonKind::Array), + _ => Err(Error::new(Span::call_site(), "Only primitives, String, and Vec are allowed")), + } + } + _ => Err(Error::new(Span::call_site(), "Unsupported type: Only primitives, String, and Vec are allowed")), + } +} + +#[allow(dead_code)] +pub fn is_integer_type(ty: &Type) -> bool { + if let Type::Path(type_path) = ty { + if let Some(PathSegment { ident, .. }) = type_path.path.segments.last() { + return matches!( + ident.to_string().as_str(), + "u8" | "u16" | "u32" | "u64" | "usize" | "i8" | "i16" | "i32" | "i64" | "isize" + ); + } + } + false +} + +/// Try to extract the bare `Ident` from a `&Type::Path`. +pub fn type_to_ident(ty: &Type) -> Option { + if let Type::Path(TypePath { qself: None, path }) = ty { + // if it's something like `Foo` or `my::crate::Bar`, + // `.segments.last()` is the `Bar` segment + path.segments.last().map(|seg| seg.ident.clone()) + } else { + None + } +} \ No newline at end of file diff --git a/sje_derive/src/structs.rs b/sje_derive/src/structs.rs new file mode 100644 index 0000000..cf25f50 --- /dev/null +++ b/sje_derive/src/structs.rs @@ -0,0 +1,451 @@ +use heck::{ToSnakeCase, ToUpperCamelCase}; +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::spanned::Spanned; +use syn::{ + DataStruct,Fields, Ident, PathArguments, Type, +}; + +use crate::{ + SjeType, SjeAttribute, SjeFieldAttribute, resolve_type, JsonKind, type_to_ident +}; + +/// Entry point for handling structs based on SJE type +pub fn handle_struct(name: &syn::Ident, data_struct: DataStruct, sje_attr: SjeAttribute) -> TokenStream { + match sje_attr.sje_type { + SjeType::Object => handle_sje_object(name, data_struct, sje_attr), + SjeType::Array => unimplemented!("array not supported"), + SjeType::Tuple => unimplemented!("tuple not supported"), + SjeType::Union => unimplemented!("union not supported"), + } +} + +/// Handle struct with SjeType::Object +fn handle_sje_object(name: &syn::Ident, data_struct: DataStruct, _sje_attr: SjeAttribute) -> TokenStream { + let struct_name = Ident::new(&format!("{}Decoder", name), name.span()); + + let fields = match data_struct.fields { + Fields::Named(fields) => fields.named, + _ => return quote! { compile_error!("Decoder can only be derived for structs with named fields."); }.into(), + }; + + let field_initializations = fields.iter().map(|field| { + let field_name = field.ident.as_ref().unwrap(); + let mut key_len = field_name.to_string().len(); + let mut val_len = None; + let mut ty_override = None; + let skip_const = format_ident!("SKIP_{}", field_name.to_string().to_uppercase()); + if let Some(sje_attr) = field.attrs.iter().find(|attr| attr.path().is_ident("sje")) { + let sje_field = sje_attr.parse_args::().expect("unable to parse"); + if let Some(name) = sje_field.name { + key_len = name.len(); + } + if let Some(len) = sje_field.len { + val_len = Some(len); + } + if let Some(ty) = sje_field.ty { + ty_override = Some(ty); + } + key_len += sje_field.offset; + } + + match resolve_type(&field.ty, ty_override) { + Ok(type_str) => { + key_len += 4; + match val_len { + Some(known_len) => { + let next = Ident::new(&format!("next_{}_with_known_len", type_str.as_str()), field_name.span()); + let field_name_string = field_name.to_string(); + quote! { + const #skip_const: usize = #key_len; + scanner.skip(#skip_const); + let (offset, len) = scanner.#next(#known_len).ok_or(sje::error::Error::MissingField(#field_name_string))?; + let #field_name = sje::LazyField::from_bytes(unsafe { bytes.get_unchecked(offset..offset + len) }); + } + } + None => { + let next = Ident::new(&format!("next_{}", type_str.as_str()), field_name.span()); + let field_name_string = field_name.to_string(); + if type_str == JsonKind::Array { + quote! { + const #skip_const: usize = #key_len; + scanner.skip(#skip_const); + let (offset, len, count) = scanner.#next().ok_or(sje::error::Error::MissingField(#field_name_string))?; + let #field_name = (unsafe { bytes.get_unchecked(offset..offset + len) }, count); + } + } else { + quote! { + const #skip_const: usize = #key_len; + scanner.skip(#skip_const); + let (offset, len) = scanner.#next().ok_or(sje::error::Error::MissingField(#field_name_string))?; + let #field_name = sje::LazyField::from_bytes(unsafe { bytes.get_unchecked(offset..offset + len) }); + } + } + } + } + } + Err(e) => e.to_compile_error(), + } + }); + + let field_assignments = fields.iter().map(|field| { + let field_name = &field.ident; + quote! { + #field_name, + } + }); + + let from_field_assignments = fields.iter().map(|field| { + let field_name = &field.ident; + quote! { + #field_name: decoder.#field_name().into(), + } + }); + + let from_impl = quote! { + impl From<#struct_name<'_>> for #name { + fn from(decoder: #struct_name<'_>) -> Self { + Self { + #(#from_field_assignments)* + } + } + } + }; + + let decode_impl = quote! { + impl <'a> #struct_name<'a> { + #[inline] + pub fn decode(bytes: &'a [u8]) -> Result { + let mut scanner = sje::scanner::JsonScanner::wrap(bytes); + #(#field_initializations)* + Ok(Self { + #(#field_assignments)* + }) + } + } + }; + + let accessor_methods = fields.iter().map(|field| { + let field_name = &field.ident; + let as_slice = Ident::new(&format!("{}_as_slice", field_name.as_ref().unwrap()), field_name.span()); + let as_str = Ident::new(&format!("{}_as_str", field_name.as_ref().unwrap()), field_name.span()); + + let mut generated = quote! {}; + + let field_type = &field.ty; + if let syn::Type::Path(path) = field_type { + if path.path.segments.last().map(|seg| seg.ident == "Vec").unwrap_or(false) { + let array_count = Ident::new(&format!("{}_count", field_name.as_ref().unwrap()), field_name.span()); + generated.extend(quote! { + #[inline(always)] + pub const fn #as_slice(&self) -> &[u8] { + self.#field_name.0 + } + #[inline(always)] + pub const fn #as_str(&self) -> &str { + unsafe { std::str::from_utf8_unchecked(self.#as_slice()) } + } + #[inline(always)] + pub const fn #array_count(&self) -> usize { + self.#field_name.1 + } + }) + } else { + let as_lazy_field = + Ident::new(&format!("{}_as_lazy_field", field_name.as_ref().unwrap()), field_name.span()); + generated.extend(quote! { + #[inline(always)] + pub const fn #as_slice(&self) -> &[u8] { + self.#field_name.as_slice() + } + #[inline(always)] + pub const fn #as_str(&self) -> &str { + self.#field_name.as_str() + } + #[inline(always)] + pub const fn #as_lazy_field(&self) -> &sje::LazyField<'a, #field_type> { + &self.#field_name + } + }) + } + } + + if let Some(sje_attr) = field.attrs.iter().find(|attr| attr.path().is_ident("sje")) { + let sje_field = sje_attr.parse_args::().expect("unable to parse"); + if let Some(also_as) = sje_field.also_as { + let type_name = also_as.split("::").last().map(|s| s.to_string()).unwrap(); + let type_name_ident: syn::Path = syn::parse_str(&also_as).unwrap(); + let also_as = Ident::new( + &format!("{}_as_{}", field_name.as_ref().unwrap(), type_name.to_snake_case()), + field_name.span(), + ); + generated.extend(quote! { + + #[inline] + pub fn #also_as(&self) -> #type_name_ident { + self.#as_str().parse().unwrap() + } + }); + } + } + + generated + }); + + let new_fields = fields.iter().map(|field| { + let field_name = &field.ident; + let field_type = &field.ty; + if let syn::Type::Path(path) = field_type { + if path.path.segments.last().map(|seg| seg.ident == "Vec").unwrap_or(false) { + quote! { + #field_name: (&'a [u8], usize), + } + } else { + quote! { + #field_name: sje::LazyField<'a, #field_type>, + } + } + } else { + quote! {} + } + }); + + let iterators = fields.iter().map(|field| { + let mut decoder = false; + if let Some(sje_attr) = field.attrs.iter().find(|attr| attr.path().is_ident("sje")) { + let sje_field = sje_attr.parse_args::().expect("unable to parse"); + decoder = sje_field.decoder + } + + let field_name = &field.ident; + let field_type = &field.ty; + + if let syn::Type::Path(path) = field_type { + if path.path.segments.last().map(|seg| seg.ident == "Vec").unwrap_or(false) { + if let Some(segment) = path.path.segments.last() { + if let PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(arg_type)) = args.args.first() { + let array_struct_name = + format_ident!("{}", field_name.as_ref().unwrap().to_string().to_upper_camel_case()); + let array_fn_name = format_ident!("{}", field_name.as_ref().unwrap().to_string()); + let iterator_name = + format_ident!("{}Iter", field_name.as_ref().unwrap().to_string().to_upper_camel_case()); + let next_impl = iterator_next_impl(arg_type, decoder); + + let mut code = quote! { + #[derive(Debug)] + pub struct #array_struct_name<'a> { + bytes: &'a [u8], + remaining: usize, + } + + impl #struct_name<'_> { + #[inline] + pub const fn #array_fn_name(&self) -> #array_struct_name { + #array_struct_name { bytes: self.#array_fn_name.0, remaining: self.#array_fn_name.1 } + } + } + pub struct #iterator_name<'a> { + scanner: sje::scanner::JsonScanner<'a>, + remaining: usize, + } + impl ExactSizeIterator for #iterator_name<'_> { + + #[inline(always)] + fn len(&self) -> usize { + self.remaining + } + } + }; + + if decoder { + let arg_type_decoder = format_ident!("{}Decoder", type_to_ident(arg_type).unwrap()); + code.extend(quote! { + impl <'a> From<#array_struct_name<'a>> for Vec<#arg_type_decoder<'a>> { + fn from(value: #array_struct_name<'a>) -> Self { + value.into_iter().collect() + } + } + + impl<'a> IntoIterator for #array_struct_name<'a> { + type Item = #arg_type_decoder<'a>; + type IntoIter = #iterator_name<'a>; + fn into_iter(self) -> Self::IntoIter { + #iterator_name { + scanner: sje::scanner::JsonScanner::wrap(self.bytes), + remaining: self.remaining + } + } + } + impl <'a> Iterator for #iterator_name<'a> { + type Item = #arg_type_decoder<'a>; + #[inline] + fn next(&mut self) -> Option { + #next_impl + } + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } + } + impl From<#array_struct_name<'_>> for Vec<#arg_type> { + fn from(value: #array_struct_name<'_>) -> Self { + value.into_iter().map(|decoder| decoder.into()).collect() + } + } + }); + } else { + code.extend(quote! { + impl From<#array_struct_name<'_>> for Vec<#arg_type> { + fn from(value: #array_struct_name) -> Self { + value.into_iter().collect() + } + } + + impl<'a> IntoIterator for #array_struct_name<'a> { + type Item = #arg_type; + type IntoIter = #iterator_name<'a>; + + fn into_iter(self) -> Self::IntoIter { + #iterator_name { + scanner: sje::scanner::JsonScanner::wrap(self.bytes), + remaining: self.remaining + } + } + } + + impl Iterator for #iterator_name<'_> { + type Item = #arg_type; + + #[inline] + fn next(&mut self) -> Option { + #next_impl + } + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } + } + }); + } + return code; + } + } + } + } else { + return quote! { + impl #struct_name<'_> { + #[inline] + pub fn #field_name(&self) -> #field_type { + self.#field_name.get().unwrap() + } + } + }; + } + } + quote! {} + }); + + let generated = quote! { + #[derive(Debug)] + pub struct #struct_name<'a> { + #(#new_fields)* + } + + #from_impl + + #decode_impl + + impl <'a> #struct_name<'a> { + #(#accessor_methods)* + } + + #(#iterators)* + }; + + generated.into() +} + +fn iterator_next_impl(ty: &Type, decoder: bool) -> proc_macro2::TokenStream { + match ty { + Type::Path(path) => { + let mut code = quote! {}; + let mut p = path.path.clone(); + if let Some(last) = p.segments.last_mut() { + // last.ident = format_ident!("{}Decoder", last.ident); + let ident = match decoder { + true => format_ident!("{}Decoder", last.ident.clone()), + false => format_ident!("{}", last.ident.clone()), + }; + + let last = match decoder { + true => quote! { + Some(#ident::decode(bytes).unwrap()) + }, + false => quote! { + let s = unsafe { std::str::from_utf8_unchecked(bytes) }; + Some(#ident::from_str(s).unwrap()) + }, + }; + + code.extend(quote! { + if self.scanner.position() + 1 == self.scanner.bytes().len() { + return None; + } + self.scanner.skip(1); + let (offset, len) = self.scanner.next_object()?; + self.remaining -= 1; + + let bytes = &self.scanner.bytes()[offset..offset + len]; + let bytes = unsafe { std::slice::from_raw_parts(bytes.as_ptr(), bytes.len()) }; + #last + // Some(#ident::decode(bytes).unwrap()) + }); + } + code + } + Type::Tuple(tuple) => { + // Generate code for processing each element + let mut code = quote! {}; + let mut tuple_values = Vec::new(); + + code.extend(quote! { + if self.scanner.position() + 1 == self.scanner.bytes().len() { + return None; + } + self.scanner.skip(1); + let (offset, len) = self.scanner.next_tuple()?; + let mut tuple_scanner = unsafe { sje::scanner::JsonScanner::wrap(self.scanner.bytes().get_unchecked(offset..offset + len)) }; + }); + + // Iterate over the tuple elements and generate code for each element + for (i, _) in tuple.elems.iter().enumerate() { + // Dynamically generate a variable name based on the index + let var_name = format_ident!("val_{i}"); + + // Generate the code for processing this element + code.extend(quote! { + tuple_scanner.skip(1); + let (offset, len) = tuple_scanner.next_string()?; + let str = unsafe { std::str::from_utf8_unchecked(tuple_scanner.bytes().get_unchecked(offset..offset + len)) }; + let #var_name = str.parse().unwrap(); + }); + + // Add the variable to the tuple values vector for dynamic construction + tuple_values.push(quote! { #var_name }); + } + + // Combine the generated code and the `Some(...)` expression + code.extend(quote! { + self.remaining -= 1; + Some((#(#tuple_values),*)) + }); + + code + } + _ => { + // If it's not a tuple, return an empty TokenStream + quote! {} + } + } +} \ No newline at end of file From 8db7d8c843cbfcddf99a00bacbb260ca47ab28d0 Mon Sep 17 00:00:00 2001 From: BDaws04 Date: Tue, 27 Jan 2026 23:29:55 +0000 Subject: [PATCH 4/6] add refactor + more tests for macro --- sje_derive/src/attribute.rs | 9 +++++--- sje_derive/src/enums.rs | 6 +++--- sje_derive/src/lib.rs | 42 ++++++++++++++++++++++++++++++------- sje_derive/src/sje_types.rs | 9 +++++--- sje_derive/src/structs.rs | 10 +++------ 5 files changed, 52 insertions(+), 24 deletions(-) diff --git a/sje_derive/src/attribute.rs b/sje_derive/src/attribute.rs index 763689b..aed00d1 100644 --- a/sje_derive/src/attribute.rs +++ b/sje_derive/src/attribute.rs @@ -1,6 +1,9 @@ -use syn::{parse::{Parse, ParseStream}, Ident, LitInt, LitStr, LitBool, Token}; -use std::str::FromStr; use proc_macro2::Span; +use std::str::FromStr; +use syn::{ + Ident, LitBool, LitInt, LitStr, Token, + parse::{Parse, ParseStream}, +}; #[derive(Debug, Copy, Clone)] #[allow(dead_code)] @@ -115,4 +118,4 @@ impl Parse for SjeFieldAttribute { decoder, }) } -} \ No newline at end of file +} diff --git a/sje_derive/src/enums.rs b/sje_derive/src/enums.rs index 3f2742a..f300a94 100644 --- a/sje_derive/src/enums.rs +++ b/sje_derive/src/enums.rs @@ -1,6 +1,6 @@ use proc_macro2::TokenStream; -use syn::DataEnum; use quote::quote; +use syn::DataEnum; pub fn handle_enum(name: &syn::Ident, data_enum: DataEnum) -> TokenStream { let variants = data_enum.variants.iter().map(|v| &v.ident); @@ -14,5 +14,5 @@ pub fn handle_enum(name: &syn::Ident, data_enum: DataEnum) -> TokenStream { } } }; - generated.into() -} \ No newline at end of file + generated +} diff --git a/sje_derive/src/lib.rs b/sje_derive/src/lib.rs index ecbead0..b5a3836 100644 --- a/sje_derive/src/lib.rs +++ b/sje_derive/src/lib.rs @@ -1,18 +1,15 @@ use proc_macro::TokenStream; -use syn::{ - Data,DeriveInput, Error,parse_macro_input -}; +use syn::{Data, DeriveInput, Error, parse_macro_input}; mod attribute; +mod enums; mod sje_types; mod structs; -mod enums; use crate::attribute::*; +use crate::enums::*; use crate::sje_types::*; use crate::structs::*; -use crate::enums::*; - #[proc_macro_derive(Decoder, attributes(sje))] pub fn decoder_derive(input: TokenStream) -> TokenStream { @@ -37,8 +34,7 @@ pub fn decoder_derive(input: TokenStream) -> TokenStream { #[cfg(test)] mod tests { - use syn::{Attribute, parse_quote, parse_str, Type}; - struct Price; + use syn::{Attribute, Type, parse_quote, parse_str}; use super::*; #[test] @@ -80,4 +76,34 @@ mod tests { check_type("Option", None, Err("Only primitives, String, and Vec are allowed")); check_type("Result", None, Err("Only primitives, String, and Vec are allowed")); } + + #[test] + fn should_parse_sje_field_attribute_with_ty_and_offset() { + let attr: Attribute = parse_quote! { + #[sje(rename = "bar", len = 8, ty = "u64", offset = 3)] + }; + let field: SjeFieldAttribute = attr.parse_args().unwrap(); + assert_eq!(Some("bar".to_string()), field.name); + assert_eq!(Some(8), field.len); + assert_eq!(Some("u64".to_string()), field.ty); + assert_eq!(3, field.offset); + } + + #[test] + fn should_handle_nested_vec_types() { + check_type("Vec>", None, Ok("array")); + check_type("Vec>", None, Ok("array")); + } + + #[test] + fn should_handle_tuple_inside_vec() { + check_type("Vec<(u64, String)>", None, Ok("array")); + check_type("Vec<(Price, Quantity)>", None, Ok("array")); + } + + #[test] + fn should_fail_for_non_vec_complex_types() { + check_type("HashMap", None, Err("Only primitives, String, and Vec are allowed")); + check_type("BTreeSet", None, Err("Only primitives, String, and Vec are allowed")); + } } diff --git a/sje_derive/src/sje_types.rs b/sje_derive/src/sje_types.rs index 1c7fde2..ec07d9b 100644 --- a/sje_derive/src/sje_types.rs +++ b/sje_derive/src/sje_types.rs @@ -1,6 +1,6 @@ -use proc_macro2::Span; -use syn::{Type, Ident, TypePath, PathSegment}; use crate::Error; +use proc_macro2::Span; +use syn::{Ident, PathSegment, Type, TypePath}; // Json type mapping + Type resolution @@ -10,6 +10,7 @@ pub enum JsonKind { String, Boolean, Array, + Object, } impl JsonKind { #[inline(always)] @@ -19,6 +20,7 @@ impl JsonKind { JsonKind::String => "string", JsonKind::Boolean => "boolean", JsonKind::Array => "array", + JsonKind::Object => "object", } } } @@ -30,6 +32,7 @@ pub fn resolve_type(ty: &Type, ty_override: Option) -> syn::Result Ok(JsonKind::String), "boolean" => Ok(JsonKind::Boolean), "array" => Ok(JsonKind::Array), + "object" => Ok(JsonKind::Object), _ => Err(Error::new(Span::call_site(), format!("unknown ty override `{}`", ty_override))), }; } @@ -75,4 +78,4 @@ pub fn type_to_ident(ty: &Type) -> Option { } else { None } -} \ No newline at end of file +} diff --git a/sje_derive/src/structs.rs b/sje_derive/src/structs.rs index cf25f50..32c277e 100644 --- a/sje_derive/src/structs.rs +++ b/sje_derive/src/structs.rs @@ -2,13 +2,9 @@ use heck::{ToSnakeCase, ToUpperCamelCase}; use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::spanned::Spanned; -use syn::{ - DataStruct,Fields, Ident, PathArguments, Type, -}; +use syn::{DataStruct, Fields, Ident, PathArguments, Type}; -use crate::{ - SjeType, SjeAttribute, SjeFieldAttribute, resolve_type, JsonKind, type_to_ident -}; +use crate::{JsonKind, SjeAttribute, SjeFieldAttribute, SjeType, resolve_type, type_to_ident}; /// Entry point for handling structs based on SJE type pub fn handle_struct(name: &syn::Ident, data_struct: DataStruct, sje_attr: SjeAttribute) -> TokenStream { @@ -448,4 +444,4 @@ fn iterator_next_impl(ty: &Type, decoder: bool) -> proc_macro2::TokenStream { quote! {} } } -} \ No newline at end of file +} From a8f2cf8df3ea41301c872f9cb6a26f81710dcd54 Mon Sep 17 00:00:00 2001 From: BDaws04 Date: Wed, 28 Jan 2026 00:56:29 +0000 Subject: [PATCH 5/6] add tests --- sje/tests/array_of_objects.rs | 48 +++++++++++++++++++++ sje/tests/custom.rs | 54 ++++++++++++++++++++---- sje/tests/decoder.rs | 78 ++++++++++++++++++++++++++++++++++- sje/tests/iter.rs | 43 +++++++++++++++++++ 4 files changed, 215 insertions(+), 8 deletions(-) diff --git a/sje/tests/array_of_objects.rs b/sje/tests/array_of_objects.rs index 5b15bf1..99bb98b 100644 --- a/sje/tests/array_of_objects.rs +++ b/sje/tests/array_of_objects.rs @@ -39,3 +39,51 @@ fn should_decode_array_of_objects() { assert!(positions.next().is_none()); } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_handle_empty_updates() { + let json = r#"{"t":1746699621,"u":[]}"#; + let update = PositionUpdateDecoder::decode(json.as_bytes()).unwrap(); + assert_eq!(0, update.updates_count()); + let mut positions = update.updates().into_iter(); + assert!(positions.next().is_none()); + } + + #[test] + fn should_handle_single_update() { + let json = r#"{"t":1746699621,"u":[{"s":"bnbusdt","a":50}]}"#; + let update = PositionUpdateDecoder::decode(json.as_bytes()).unwrap(); + assert_eq!(1, update.updates_count()); + + let mut positions = update.updates().into_iter(); + let position = positions.next().unwrap(); + assert_eq!("bnbusdt", position.symbol_as_str()); + assert_eq!(50, position.amount()); + assert!(positions.next().is_none()); + } + + #[test] + fn should_decode_array_of_objects() { + let json = r#"{"t":1746699621,"u":[{"s":"btcusdt","a":100},{"s":"ethusdt","a":200}]}"#; + + let update = PositionUpdateDecoder::decode(json.as_bytes()).unwrap(); + assert_eq!(2, update.updates_count()); + + let mut positions = update.updates().into_iter(); + + let position = positions.next().unwrap(); + assert_eq!("btcusdt", position.symbol_as_str()); + assert_eq!(100, position.amount()); + + let position = positions.next().unwrap(); + assert_eq!("ethusdt", position.symbol_as_str()); + assert_eq!(200, position.amount()); + + assert!(positions.next().is_none()); + } +} + diff --git a/sje/tests/custom.rs b/sje/tests/custom.rs index bb8a8f2..4b1de36 100644 --- a/sje/tests/custom.rs +++ b/sje/tests/custom.rs @@ -20,11 +20,51 @@ pub struct Trade { price: Price, } -#[test] -fn should_parse_custom_field() { - let json = r#"{"p":"12345"}"#; - let trade = TradeDecoder::decode(json.as_bytes()).unwrap(); - assert_eq!(&Price(12345), trade.price_as_lazy_field().get_ref().unwrap()); - assert_eq!(Price(12345), trade.price_as_lazy_field().get().unwrap()); - assert_eq!(Price(12345), trade.price()); +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_parse_custom_field() { + let json = r#"{"p":"12345"}"#; + let trade = TradeDecoder::decode(json.as_bytes()).unwrap(); + assert_eq!(&Price(12345), trade.price_as_lazy_field().get_ref().unwrap()); + assert_eq!(Price(12345), trade.price_as_lazy_field().get().unwrap()); + assert_eq!(Price(12345), trade.price()); + } + + #[test] + fn should_parse_zero_price() { + let json = r#"{"p":"0"}"#; + let trade = TradeDecoder::decode(json.as_bytes()).unwrap(); + assert_eq!(Price(0), trade.price()); + } + + #[test] + fn should_parse_large_price() { + let json = r#"{"p":"9876543210"}"#; + let trade = TradeDecoder::decode(json.as_bytes()).unwrap(); + assert_eq!(Price(9876543210), trade.price()); + } + + #[test] + fn should_handle_multiple_trades() { + let trades_json = [ + r#"{"p":"100"}"#, + r#"{"p":"200"}"#, + r#"{"p":"300"}"#, + ]; + + for (i, json) in trades_json.iter().enumerate() { + let trade = TradeDecoder::decode(json.as_bytes()).unwrap(); + assert_eq!(Price(((i + 1) * 100) as u64), trade.price()); + } + } + + #[test] + fn should_parse_price_with_leading_zeros() { + let json = r#"{"p":"000012345"}"#; + let trade = TradeDecoder::decode(json.as_bytes()).unwrap(); + assert_eq!(Price(12345), trade.price()); + } } diff --git a/sje/tests/decoder.rs b/sje/tests/decoder.rs index 3d5f7b2..d9f436f 100644 --- a/sje/tests/decoder.rs +++ b/sje/tests/decoder.rs @@ -26,6 +26,32 @@ pub struct Trade { is_buyer_maker: bool, } +#[derive(Decoder)] +#[sje(object)] +#[allow(dead_code)] +pub struct AggTrade { + #[sje(rename = "e", len = 8)] + event_type: String, + #[sje(rename = "E", len = 13)] + event_time: u64, + #[sje(rename = "s")] + symbol: String, + #[sje(rename = "t", len = 10)] + trade_id: u64, + #[sje(rename = "p")] + price: String, + #[sje(rename = "q")] + quantity: String, + #[sje(rename = "b", len = 11)] + buyer_order_id: u64, + #[sje(rename = "a", len = 11)] + seller_order_id: u64, + #[sje(rename = "T", len = 13)] + transaction_time: u64, + #[sje(rename = "m")] + is_buyer_maker: bool, +} + #[derive(Decoder, Debug)] #[sje(object)] #[allow(dead_code)] @@ -40,8 +66,10 @@ struct ListenKeyExpired { #[cfg(test)] mod tests { - use crate::{ListenKeyExpiredDecoder, Trade, TradeDecoder}; + use crate::{AggTradeDecoder, ListenKeyExpiredDecoder, Trade, TradeDecoder}; use std::str::from_utf8_unchecked; + use sje_derive::Decoder; + #[test] fn should_decode_trade() { @@ -71,4 +99,52 @@ mod tests { assert_eq!(1743606297156, listen_key_expired.event_time()); assert_eq!("FdffIUjdfd343DtLMw2tKS87iL2HpYRniDWpkoxWCb4fwP2yzJXalBlBNnz471cE", listen_key_expired.listen_key()); } + + #[test] + fn should_decode_agg_trade() { + let agg_trade = AggTradeDecoder::decode(br#"{"e":"aggTrade","E":1705085312570,"s":"ETHUSDT","a":12345678,"p":"1850.00000000","q":"0.00500000","f":1000001,"l":1000001,"T":1705085312570,"m":false,"M":true}"#).unwrap(); + assert_eq!("aggTrade", agg_trade.event_type()); + assert_eq!("ETHUSDT", agg_trade.symbol()); + } + + #[test] + fn should_decode_kline_event() { + #[derive(Decoder)] + #[sje(object)] + #[allow(dead_code)] + struct Kline { + #[sje(rename = "e")] event_type: String, + #[sje(rename = "E")] event_time: u64, + #[sje(rename = "s")] symbol: String, + #[sje(rename = "k")] kline: String, + } + + let kline_msg = KlineDecoder::decode(br#"{"e":"kline","E":1705085313000,"s":"BNBUSDT","k":"simplified"}"#).unwrap(); + assert_eq!("kline", kline_msg.event_type()); + assert_eq!("BNBUSDT", kline_msg.symbol()); + assert_eq!("simplified", kline_msg.kline()); + } + + #[test] + fn should_decode_24hr_ticker() { + #[derive(Decoder)] + #[sje(object)] + #[allow(dead_code)] + struct Ticker24hr { + #[sje(rename = "e")] event_type: String, + #[sje(rename = "E")] event_time: u64, + #[sje(rename = "s")] symbol: String, + #[sje(rename = "c")] close_price: String, + #[sje(rename = "h")] high_price: String, + #[sje(rename = "l")] low_price: String, + } + + let ticker_msg = Ticker24hrDecoder::decode(br#"{"e":"24hrTicker","E":1705085314000,"s":"BNBUSDT","c":"350.50","h":"355.00","l":"345.00"}"#).unwrap(); + assert_eq!("24hrTicker", ticker_msg.event_type()); + assert_eq!("BNBUSDT", ticker_msg.symbol()); + assert_eq!("350.50", ticker_msg.close_price()); + assert_eq!("355.00", ticker_msg.high_price()); + assert_eq!("345.00", ticker_msg.low_price()); + } } + diff --git a/sje/tests/iter.rs b/sje/tests/iter.rs index 603a9ee..618a3ff 100644 --- a/sje/tests/iter.rs +++ b/sje/tests/iter.rs @@ -114,4 +114,47 @@ mod tests { assert_eq!(Some((Price(2.6468), Quantity(22540.8))), asks.next()); assert_eq!(None, asks.next()); } + + + #[test] + fn should_handle_empty_bids_or_asks() { + // Empty bids + let json = br#"{"e":"depthUpdate","b":[],"a":[["2.5","100"]]}"#; + let update = L2UpdateDecoder::decode(json).unwrap(); + assert_eq!(0, update.bids_count()); + assert_eq!(1, update.asks_count()); + + let mut bids = update.bids().into_iter(); + assert_eq!(None, bids.next()); + + let mut asks = update.asks().into_iter(); + assert_eq!(Some((Price(2.5), Quantity(100.0))), asks.next()); + assert_eq!(None, asks.next()); + + // Empty asks + let json = br#"{"e":"depthUpdate","b":[["1.5","50"]],"a":[]}"#; + let update = L2UpdateDecoder::decode(json).unwrap(); + assert_eq!(1, update.bids_count()); + assert_eq!(0, update.asks_count()); + + let mut bids = update.bids().into_iter(); + assert_eq!(Some((Price(1.5), Quantity(50.0))), bids.next()); + assert_eq!(None, bids.next()); + + let mut asks = update.asks().into_iter(); + assert_eq!(None, asks.next()); + } + + #[test] + fn should_convert_to_owned_with_empty_arrays() { + let json = br#"{"e":"depthUpdate","b":[],"a":[]}"#; + let update: L2Update = L2UpdateDecoder::decode(json).unwrap().into(); + + assert_eq!("depthUpdate", update.event_type); + assert!(update.bids.is_empty()); + assert!(update.asks.is_empty()); + } + + + } From d7f5c4e3b0f2094f2eed9e4c41d72016df80b80a Mon Sep 17 00:00:00 2001 From: BDaws04 Date: Wed, 28 Jan 2026 01:00:58 +0000 Subject: [PATCH 6/6] cargo fmt --- sje/tests/array_of_objects.rs | 1 - sje/tests/custom.rs | 6 +---- sje/tests/decoder.rs | 42 ++++++++++++++++++++++------------- sje/tests/iter.rs | 4 ---- 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/sje/tests/array_of_objects.rs b/sje/tests/array_of_objects.rs index 99bb98b..9fdf1a5 100644 --- a/sje/tests/array_of_objects.rs +++ b/sje/tests/array_of_objects.rs @@ -86,4 +86,3 @@ mod tests { assert!(positions.next().is_none()); } } - diff --git a/sje/tests/custom.rs b/sje/tests/custom.rs index 4b1de36..baf1f0c 100644 --- a/sje/tests/custom.rs +++ b/sje/tests/custom.rs @@ -49,11 +49,7 @@ mod tests { #[test] fn should_handle_multiple_trades() { - let trades_json = [ - r#"{"p":"100"}"#, - r#"{"p":"200"}"#, - r#"{"p":"300"}"#, - ]; + let trades_json = [r#"{"p":"100"}"#, r#"{"p":"200"}"#, r#"{"p":"300"}"#]; for (i, json) in trades_json.iter().enumerate() { let trade = TradeDecoder::decode(json.as_bytes()).unwrap(); diff --git a/sje/tests/decoder.rs b/sje/tests/decoder.rs index d9f436f..e5a3c5d 100644 --- a/sje/tests/decoder.rs +++ b/sje/tests/decoder.rs @@ -67,9 +67,8 @@ struct ListenKeyExpired { #[cfg(test)] mod tests { use crate::{AggTradeDecoder, ListenKeyExpiredDecoder, Trade, TradeDecoder}; - use std::str::from_utf8_unchecked; use sje_derive::Decoder; - + use std::str::from_utf8_unchecked; #[test] fn should_decode_trade() { @@ -113,13 +112,18 @@ mod tests { #[sje(object)] #[allow(dead_code)] struct Kline { - #[sje(rename = "e")] event_type: String, - #[sje(rename = "E")] event_time: u64, - #[sje(rename = "s")] symbol: String, - #[sje(rename = "k")] kline: String, + #[sje(rename = "e")] + event_type: String, + #[sje(rename = "E")] + event_time: u64, + #[sje(rename = "s")] + symbol: String, + #[sje(rename = "k")] + kline: String, } - let kline_msg = KlineDecoder::decode(br#"{"e":"kline","E":1705085313000,"s":"BNBUSDT","k":"simplified"}"#).unwrap(); + let kline_msg = + KlineDecoder::decode(br#"{"e":"kline","E":1705085313000,"s":"BNBUSDT","k":"simplified"}"#).unwrap(); assert_eq!("kline", kline_msg.event_type()); assert_eq!("BNBUSDT", kline_msg.symbol()); assert_eq!("simplified", kline_msg.kline()); @@ -131,15 +135,24 @@ mod tests { #[sje(object)] #[allow(dead_code)] struct Ticker24hr { - #[sje(rename = "e")] event_type: String, - #[sje(rename = "E")] event_time: u64, - #[sje(rename = "s")] symbol: String, - #[sje(rename = "c")] close_price: String, - #[sje(rename = "h")] high_price: String, - #[sje(rename = "l")] low_price: String, + #[sje(rename = "e")] + event_type: String, + #[sje(rename = "E")] + event_time: u64, + #[sje(rename = "s")] + symbol: String, + #[sje(rename = "c")] + close_price: String, + #[sje(rename = "h")] + high_price: String, + #[sje(rename = "l")] + low_price: String, } - let ticker_msg = Ticker24hrDecoder::decode(br#"{"e":"24hrTicker","E":1705085314000,"s":"BNBUSDT","c":"350.50","h":"355.00","l":"345.00"}"#).unwrap(); + let ticker_msg = Ticker24hrDecoder::decode( + br#"{"e":"24hrTicker","E":1705085314000,"s":"BNBUSDT","c":"350.50","h":"355.00","l":"345.00"}"#, + ) + .unwrap(); assert_eq!("24hrTicker", ticker_msg.event_type()); assert_eq!("BNBUSDT", ticker_msg.symbol()); assert_eq!("350.50", ticker_msg.close_price()); @@ -147,4 +160,3 @@ mod tests { assert_eq!("345.00", ticker_msg.low_price()); } } - diff --git a/sje/tests/iter.rs b/sje/tests/iter.rs index 618a3ff..73b7980 100644 --- a/sje/tests/iter.rs +++ b/sje/tests/iter.rs @@ -115,7 +115,6 @@ mod tests { assert_eq!(None, asks.next()); } - #[test] fn should_handle_empty_bids_or_asks() { // Empty bids @@ -154,7 +153,4 @@ mod tests { assert!(update.bids.is_empty()); assert!(update.asks.is_empty()); } - - - }