Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

[Unreleased]: https://github.com/trussed-dev/serde-indexed/compare/0.1.1...HEAD

- Add support for `#[serde(with)]` ([#16][])
- Add support for `#[serde(skip)]` ([#14][])
- Add support for generics ([#11][])
- `skip_serializing_if` no longer incorrectly affects deserialization (fixes [#2][])
Expand All @@ -12,6 +13,7 @@
[#2]: https://github.com/trussed-dev/serde-indexed/issues/2
[#11]: https://github.com/trussed-dev/serde-indexed/pull/11
[#14]: https://github.com/trussed-dev/serde-indexed/pull/14
[#16]: https://github.com/trussed-dev/serde-indexed/pull/16
[#19]: https://github.com/trussed-dev/serde-indexed/pull/19

## [v0.1.1][] (2024-04-03)
Expand Down
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ syn = "2.0"
heapless = { version = "0.7.16", default-features = false, features = ["serde"] }
hex-literal = "0.4.1"
serde = { version = "1" }
serde-byte-array = "0.1.2"
serde_bytes = { version = "0.11.12", default-features = false }
serde_bytes = { version = "0.11.15" }
serde_cbor = { version = "0.11.0" }
serde_test = "1.0.176"
141 changes: 127 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,65 @@ mod parse;
use parse::Skip;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{format_ident, quote};
use syn::{parse_macro_input, Lifetime, LifetimeParam, TypeParamBound};
use quote::{format_ident, quote, quote_spanned};
use syn::{
parse_macro_input, ImplGenerics, Lifetime, LifetimeParam, TypeGenerics, TypeParamBound,
WhereClause,
};

use crate::parse::Input;

fn serialize_fields(fields: &[parse::Field], offset: usize) -> Vec<proc_macro2::TokenStream> {
fn serialize_fields(
fields: &[parse::Field],
offset: usize,
impl_generics_serialize: ImplGenerics<'_>,
ty_generics_serialize: TypeGenerics<'_>,
ty_generics: &TypeGenerics<'_>,
where_clause: Option<&WhereClause>,
ident: &Ident,
) -> Vec<proc_macro2::TokenStream> {
fields
.iter()
.filter_map(|field| {
let index = field.index + offset;
let member = &field.member;
let serialize_member = match &field.serialize_with {
None => quote!(&self.#member),
Some(f) => {
let ty = &field.ty;
quote!({
struct __InternalSerdeIndexedSerializeWith #impl_generics_serialize {
value: &'__serde_indexed_lifetime #ty,
phantom: ::core::marker::PhantomData<#ident #ty_generics>,
}

impl #impl_generics_serialize serde::Serialize for __InternalSerdeIndexedSerializeWith #ty_generics_serialize #where_clause {
fn serialize<__S>(
&self,
__s: __S,
) -> ::core::result::Result<__S::Ok, __S::Error>
where
__S: serde::Serializer,
{
#f(self.value, __s)
}
}

&__InternalSerdeIndexedSerializeWith { value: &self.#member, phantom: ::core::marker::PhantomData::<#ident #ty_generics> }
})
}
};

// println!("field {:?} index {:?}", &field.label, field.index);
match &field.skip_serializing_if {
Skip::If(path) => Some(quote! {
if !#path(&self.#member) {
map.serialize_entry(&#index, &self.#member)?;
map.serialize_entry(&#index, #serialize_member)?;
}
}),
Skip::Always => None,
Skip::Never => Some(quote! {
map.serialize_entry(&#index, &self.#member)?;
map.serialize_entry(&#index, #serialize_member)?;
}),
}
})
Expand Down Expand Up @@ -85,7 +123,6 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as Input);
let ident = input.ident;
let num_fields = count_serialized_fields(&input.fields);
let serialize_fields = serialize_fields(&input.fields, input.attrs.offset);
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let mut generics_cl = input.generics.clone();
generics_cl.type_params_mut().for_each(|t| {
Expand All @@ -94,6 +131,26 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
});
let (impl_generics, _, _) = generics_cl.split_for_impl();

let mut generics_cl2 = generics_cl.clone();

generics_cl2
.params
.push(syn::GenericParam::Lifetime(LifetimeParam::new(
Lifetime::new("'__serde_indexed_lifetime", Span::call_site()),
)));

let (impl_generics_serialize, ty_generics_serialize, _) = generics_cl2.split_for_impl();

let serialize_fields = serialize_fields(
&input.fields,
input.attrs.offset,
impl_generics_serialize,
ty_generics_serialize,
&ty_generics,
where_clause,
&ident,
);

TokenStream::from(quote! {
#[automatically_derived]
impl #impl_generics serde::Serialize for #ident #ty_generics #where_clause {
Expand All @@ -119,7 +176,8 @@ fn none_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
.filter(|f| !f.skip_serializing_if.is_always())
.map(|field| {
let ident = format_ident!("{}", &field.label);
quote! {
let span = field.original_span;
quote_spanned! { span =>
let mut #ident = None;
}
})
Expand All @@ -132,11 +190,12 @@ fn unwrap_expected_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStre
.map(|field| {
let label = field.label.clone();
let ident = format_ident!("{}", &field.label);
let span = field.original_span;
match field.skip_serializing_if {
Skip::Never => quote! {
let #ident = #ident.ok_or_else(|| serde::de::Error::missing_field(#label))?;
},
Skip::If(_) => quote! {
Skip::If(_) => quote_spanned! { span =>
let #ident = #ident.unwrap_or_default();
},
Skip::Always => quote! {
Expand All @@ -147,20 +206,64 @@ fn unwrap_expected_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStre
.collect()
}

fn match_fields(fields: &[parse::Field], offset: usize) -> Vec<proc_macro2::TokenStream> {
fn match_fields(
fields: &[parse::Field],
offset: usize,
impl_generics_with_de: &ImplGenerics<'_>,
ty_generics: &TypeGenerics<'_>,
ty_generics_with_de: &TypeGenerics<'_>,
where_clause: Option<&WhereClause>,
struct_ident: &Ident,
) -> Vec<proc_macro2::TokenStream> {
fields
.iter()
.filter(|f| !f.skip_serializing_if.is_always())
.map(|field| {
let label = field.label.clone();
let ident = format_ident!("{}", &field.label);
let index = field.index + offset;
quote! {
let span = field.original_span;

let next_value = match &field.deserialize_with {
Some(f) => {
let ty = &field.ty;
quote_spanned!(span => {
struct __InternalSerdeIndexedDeserializeWith #impl_generics_with_de {
value: #ty,
phantom: ::core::marker::PhantomData<#struct_ident #ty_generics>,
lifetime: ::core::marker::PhantomData<&'de ()>,
}
impl #impl_generics_with_de serde::Deserialize<'de> for __InternalSerdeIndexedDeserializeWith #ty_generics_with_de #where_clause {
fn deserialize<__D>(
__deserializer: __D,
) -> Result<Self, __D::Error>
where
__D: serde::Deserializer<'de>,
{

Ok(__InternalSerdeIndexedDeserializeWith {
value: #f(__deserializer)?,
phantom: ::core::marker::PhantomData,
lifetime: ::core::marker::PhantomData,
})
}
}

let __InternalSerdeIndexedDeserializeWith { value, lifetime: _, phantom: _ } = map.next_value()?;
value
}
)
}
None => quote_spanned!(span => map.next_value()?),
};

quote_spanned!{ span =>
#index => {
if #ident.is_some() {
return Err(serde::de::Error::duplicate_field(#label));
}
#ident = Some(map.next_value()?);
let next_value = #next_value;
#ident = Some(next_value);
},
}
})
Expand All @@ -172,7 +275,8 @@ fn all_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
.iter()
.map(|field| {
let ident = format_ident!("{}", &field.label);
quote! {
let span = field.original_span;
quote_spanned! { span =>
#ident
}
})
Expand All @@ -185,7 +289,6 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
let ident = input.ident;
let none_fields = none_fields(&input.fields);
let unwrap_expected_fields = unwrap_expected_fields(&input.fields);
let match_fields = match_fields(&input.fields, input.attrs.offset);
let all_fields = all_fields(&input.fields);

let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
Expand All @@ -212,7 +315,17 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
.push_value(TypeParamBound::Verbatim(quote!(serde::Deserialize<'de>)));
});

let (impl_generics_with_de, _, _) = generics_cl.split_for_impl();
let (impl_generics_with_de, ty_generics_with_de, _) = generics_cl.split_for_impl();

let match_fields = match_fields(
&input.fields,
input.attrs.offset,
&impl_generics_with_de,
&ty_generics,
&ty_generics_with_de,
where_clause,
&ident,
);

let the_loop = if !input.fields.is_empty() {
// NB: In the previous "none_fields", we use the actual struct's
Expand Down
54 changes: 53 additions & 1 deletion src/parse.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use proc_macro2::Span;
use syn::meta::ParseNestedMeta;
use syn::parse::{Error, Parse, ParseStream, Result};
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Fields, Generics, Ident, LitInt, LitStr, Token};

pub struct Input {
Expand Down Expand Up @@ -36,6 +37,10 @@ pub struct Field {
pub member: syn::Member,
pub index: usize,
pub skip_serializing_if: Skip,
pub serialize_with: Option<syn::ExprPath>,
pub deserialize_with: Option<syn::ExprPath>,
pub ty: syn::Type,
pub original_span: Span,
}

fn parse_meta(attrs: &mut StructAttrs, meta: ParseNestedMeta) -> Result<()> {
Expand Down Expand Up @@ -115,9 +120,24 @@ fn fields_from_ast(
index += 1;

let mut skip_serializing_if = Skip::Never;
let mut deserialize_with = None;
let mut serialize_with = None;

for attr in &field.attrs {
if attr.path().is_ident("serde") {
attr.parse_nested_meta(|meta| {
let parse_value = |attribute: &mut Option<_>, attribute_name: &str| {
let litstr: LitStr = meta.value()?.parse()?;
let tokens = syn::parse_str(&litstr.value())?;
if attribute.is_some() {
return Err(
meta.error(format!("Multiple attributes for {attribute_name}"))
);
}
*attribute = Some(syn::parse2(tokens)?);
Ok(())
};

if meta.path.is_ident("skip_serializing_if") {
let litstr: LitStr = meta.value()?.parse()?;
let tokens = syn::parse_str(&litstr.value())?;
Expand Down Expand Up @@ -145,9 +165,37 @@ fn fields_from_ast(
.error("Multiple attributes for skip_serializing_if or skip"));
}
skip_serializing_if = Skip::Always;
Ok(())
} else if meta.path.is_ident("deserialize_with") {
parse_value(&mut deserialize_with, "deserialize_with")
} else if meta.path.is_ident("serialize_with") {
parse_value(&mut serialize_with, "serialize_with")
} else if meta.path.is_ident("with") {
let litstr: LitStr = meta.value()?.parse()?;
if serialize_with.is_some() {
return Err(meta.error(
"Using `with` when `serialize_with` is already used"
.to_string(),
));
}
if deserialize_with.is_some() {
return Err(meta.error(
"Using `with` when `deserialize_with` is already used"
.to_string(),
));
}

let serialize_tokens =
syn::parse_str(&format!("{}::serialize", litstr.value()))?;
let deserialize_tokens =
syn::parse_str(&format!("{}::deserialize", litstr.value()))?;

serialize_with = Some(syn::parse2(serialize_tokens)?);
deserialize_with = Some(syn::parse2(deserialize_tokens)?);

Ok(())
} else {
Err(meta.error("Unkown field attribute"))
return Err(meta.error("Unkown field attribute"));
}
})?;
}
Expand All @@ -169,7 +217,11 @@ fn fields_from_ast(
},
index: current_index,
// TODO: make this... more concise? handle errors? the thing with the spans?
ty: field.ty.clone(),
skip_serializing_if,
serialize_with,
deserialize_with,
original_span: field.span(),
})
})
.collect()
Expand Down
Loading