diff options
Diffstat (limited to 'src/traits.rs')
-rw-r--r-- | src/traits.rs | 772 |
1 files changed, 772 insertions, 0 deletions
diff --git a/src/traits.rs b/src/traits.rs new file mode 100644 index 0000000..0f4e0a9 --- /dev/null +++ b/src/traits.rs @@ -0,0 +1,772 @@ +#![allow(unused_imports)] +use proc_macro2::{Ident, Span, TokenStream, TokenTree}; +use quote::{quote, quote_spanned, ToTokens}; +use syn::{ + parse::{Parse, ParseStream, Parser}, + punctuated::Punctuated, + spanned::Spanned, + Result, *, +}; + +macro_rules! bail { + ($msg:expr $(,)?) => { + return Err(Error::new(Span::call_site(), &$msg[..])) + }; + + ( $msg:expr => $span_to_blame:expr $(,)? ) => { + return Err(Error::new_spanned(&$span_to_blame, $msg)) + }; +} + +pub trait Derivable { + fn ident(input: &DeriveInput) -> Result<syn::Path>; + fn implies_trait() -> Option<TokenStream> { + None + } + fn asserts(_input: &DeriveInput) -> Result<TokenStream> { + Ok(quote!()) + } + fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> { + Ok(()) + } + fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { + Ok((quote!(), quote!())) + } + fn requires_where_clause() -> bool { + true + } +} + +pub struct Pod; + +impl Derivable for Pod { + fn ident(_: &DeriveInput) -> Result<syn::Path> { + Ok(syn::parse_quote!(::bytemuck::Pod)) + } + + fn asserts(input: &DeriveInput) -> Result<TokenStream> { + let repr = get_repr(&input.attrs)?; + + let completly_packed = + repr.packed == Some(1) || repr.repr == Repr::Transparent; + + if !completly_packed && !input.generics.params.is_empty() { + bail!("\ + Pod requires cannot be derived for non-packed types containing \ + generic parameters because the padding requirements can't be verified \ + for generic non-packed structs\ + " => input.generics.params.first().unwrap()); + } + + match &input.data { + Data::Struct(_) => { + let assert_no_padding = if !completly_packed { + Some(generate_assert_no_padding(input)?) + } else { + None + }; + let assert_fields_are_pod = + generate_fields_are_trait(input, Self::ident(input)?)?; + + Ok(quote!( + #assert_no_padding + #assert_fields_are_pod + )) + } + Data::Enum(_) => bail!("Deriving Pod is not supported for enums"), + Data::Union(_) => bail!("Deriving Pod is not supported for unions"), + } + } + + fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> { + let repr = get_repr(attributes)?; + match repr.repr { + Repr::C => Ok(()), + Repr::Transparent => Ok(()), + _ => { + bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]") + } + } + } +} + +pub struct AnyBitPattern; + +impl Derivable for AnyBitPattern { + fn ident(_: &DeriveInput) -> Result<syn::Path> { + Ok(syn::parse_quote!(::bytemuck::AnyBitPattern)) + } + + fn implies_trait() -> Option<TokenStream> { + Some(quote!(::bytemuck::Zeroable)) + } + + fn asserts(input: &DeriveInput) -> Result<TokenStream> { + match &input.data { + Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern` + Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input)?), + Data::Enum(_) => { + bail!("Deriving AnyBitPattern is not supported for enums") + } + } + } +} + +pub struct Zeroable; + +impl Derivable for Zeroable { + fn ident(_: &DeriveInput) -> Result<syn::Path> { + Ok(syn::parse_quote!(::bytemuck::Zeroable)) + } + + fn asserts(input: &DeriveInput) -> Result<TokenStream> { + match &input.data { + Data::Union(_) => Ok(quote!()), // unions are always `Zeroable` + Data::Struct(_) => generate_fields_are_trait(input, Self::ident(input)?), + Data::Enum(_) => bail!("Deriving Zeroable is not supported for enums"), + } + } +} + +pub struct NoUninit; + +impl Derivable for NoUninit { + fn ident(_: &DeriveInput) -> Result<syn::Path> { + Ok(syn::parse_quote!(::bytemuck::NoUninit)) + } + + fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> { + let repr = get_repr(attributes)?; + match ty { + Data::Struct(_) => match repr.repr { + Repr::C | Repr::Transparent => Ok(()), + _ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"), + }, + Data::Enum(_) => if repr.repr.is_integer() { + Ok(()) + } else { + bail!("NoUninit requires the enum to be an explicit #[repr(Int)]") + }, + Data::Union(_) => bail!("NoUninit can only be derived on enums and structs") + } + } + + fn asserts(input: &DeriveInput) -> Result<TokenStream> { + if !input.generics.params.is_empty() { + bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs"); + } + + match &input.data { + Data::Struct(DataStruct { .. }) => { + let assert_no_padding = generate_assert_no_padding(&input)?; + let assert_fields_are_no_padding = + generate_fields_are_trait(&input, Self::ident(input)?)?; + + Ok(quote!( + #assert_no_padding + #assert_fields_are_no_padding + )) + } + Data::Enum(DataEnum { variants, .. }) => { + if variants.iter().any(|variant| !variant.fields.is_empty()) { + bail!("Only fieldless enums are supported for NoUninit") + } else { + Ok(quote!()) + } + } + Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */ + } + } + + fn trait_impl(_input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { + Ok((quote!(), quote!())) + } +} + +pub struct CheckedBitPattern; + +impl Derivable for CheckedBitPattern { + fn ident(_: &DeriveInput) -> Result<syn::Path> { + Ok(syn::parse_quote!(::bytemuck::CheckedBitPattern)) + } + + fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> { + let repr = get_repr(attributes)?; + match ty { + Data::Struct(_) => match repr.repr { + Repr::C | Repr::Transparent => Ok(()), + _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"), + }, + Data::Enum(_) => if repr.repr.is_integer() { + Ok(()) + } else { + bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]") + }, + Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs") + } + } + + fn asserts(input: &DeriveInput) -> Result<TokenStream> { + if !input.generics.params.is_empty() { + bail!("CheckedBitPattern cannot be derived for structs containing generic parameters"); + } + + match &input.data { + Data::Struct(DataStruct { .. }) => { + let assert_fields_are_maybe_pod = + generate_fields_are_trait(&input, Self::ident(input)?)?; + + Ok(assert_fields_are_maybe_pod) + } + Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed + * OK by NoUninit */ + Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ + } + } + + fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { + match &input.data { + Data::Struct(DataStruct { fields, .. }) => { + generate_checked_bit_pattern_struct(&input.ident, fields, &input.attrs) + } + Data::Enum(_) => generate_checked_bit_pattern_enum(input), + Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */ + } + } +} + +pub struct TransparentWrapper; + +impl TransparentWrapper { + fn get_wrapper_type( + attributes: &[Attribute], fields: &Fields, + ) -> Option<TokenStream> { + let transparent_param = get_simple_attr(attributes, "transparent"); + transparent_param.map(|ident| ident.to_token_stream()).or_else(|| { + let mut types = get_field_types(&fields); + let first_type = types.next(); + if let Some(_) = types.next() { + // can't guess param type if there is more than one field + return None; + } else { + first_type.map(|ty| ty.to_token_stream()) + } + }) + } +} + +impl Derivable for TransparentWrapper { + fn ident(input: &DeriveInput) -> Result<syn::Path> { + let fields = get_struct_fields(input)?; + + let ty = match Self::get_wrapper_type(&input.attrs, &fields) { + Some(ty) => ty, + None => bail!( + "\ + when deriving TransparentWrapper for a struct with more than one field \ + you need to specify the transparent field using #[transparent(T)]\ + " + ), + }; + + Ok(syn::parse_quote!(::bytemuck::TransparentWrapper<#ty>)) + } + + fn asserts(input: &DeriveInput) -> Result<TokenStream> { + let fields = get_struct_fields(input)?; + let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) { + Some(wrapped_type) => wrapped_type.to_string(), + None => unreachable!(), /* other code will already reject this derive */ + }; + let mut wrapped_fields = fields + .iter() + .filter(|field| field.ty.to_token_stream().to_string() == wrapped_type); + if let None = wrapped_fields.next() { + bail!("TransparentWrapper must have one field of the wrapped type"); + }; + if let Some(_) = wrapped_fields.next() { + bail!("TransparentWrapper can only have one field of the wrapped type") + } else { + Ok(quote!()) + } + } + + fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> { + let repr = get_repr(attributes)?; + + match repr.repr { + Repr::Transparent => Ok(()), + _ => { + bail!( + "TransparentWrapper requires the struct to be #[repr(transparent)]" + ) + } + } + } + + fn requires_where_clause() -> bool { + false + } +} + +pub struct Contiguous; + +impl Derivable for Contiguous { + fn ident(_: &DeriveInput) -> Result<syn::Path> { + Ok(syn::parse_quote!(::bytemuck::Contiguous)) + } + + fn trait_impl(input: &DeriveInput) -> Result<(TokenStream, TokenStream)> { + let repr = get_repr(&input.attrs)?; + + let integer_ty = if let Some(integer_ty) = repr.repr.as_integer_type() { + integer_ty + } else { + bail!("Contiguous requires the enum to be #[repr(Int)]"); + }; + + let variants = get_enum_variants(input)?; + let mut variants_with_discriminator = + VariantDiscriminantIterator::new(variants); + + let (min, max, count) = variants_with_discriminator.try_fold( + (i64::max_value(), i64::min_value(), 0), + |(min, max, count), res| { + let discriminator = res?; + Ok::<_, Error>(( + i64::min(min, discriminator), + i64::max(max, discriminator), + count + 1, + )) + }, + )?; + + if max - min != count - 1 { + bail! { + "Contiguous requires the enum discriminants to be contiguous", + } + } + + let min_lit = LitInt::new(&format!("{}", min), input.span()); + let max_lit = LitInt::new(&format!("{}", max), input.span()); + + Ok(( + quote!(), + quote! { + type Int = #integer_ty; + const MIN_VALUE: #integer_ty = #min_lit; + const MAX_VALUE: #integer_ty = #max_lit; + }, + )) + } +} + +fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> { + if let Data::Struct(DataStruct { fields, .. }) = &input.data { + Ok(fields) + } else { + bail!("deriving this trait is only supported for structs") + } +} + +fn get_fields(input: &DeriveInput) -> Result<Fields> { + match &input.data { + Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()), + Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())), + Data::Enum(_) => bail!("deriving this trait is not supported for enums"), + } +} + +fn get_enum_variants<'a>( + input: &'a DeriveInput, +) -> Result<impl Iterator<Item = &'a Variant> + 'a> { + if let Data::Enum(DataEnum { variants, .. }) = &input.data { + Ok(variants.iter()) + } else { + bail!("deriving this trait is only supported for enums") + } +} + +fn get_field_types<'a>( + fields: &'a Fields, +) -> impl Iterator<Item = &'a Type> + 'a { + fields.iter().map(|field| &field.ty) +} + +fn generate_checked_bit_pattern_struct( + input_ident: &Ident, fields: &Fields, attrs: &[Attribute], +) -> Result<(TokenStream, TokenStream)> { + let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span()); + + let repr = get_repr(attrs)?; + + let field_names = fields + .iter() + .enumerate() + .map(|(i, field)| { + field.ident.clone().unwrap_or_else(|| { + Ident::new(&format!("field{}", i), input_ident.span()) + }) + }) + .collect::<Vec<_>>(); + let field_tys = fields.iter().map(|field| &field.ty).collect::<Vec<_>>(); + + let field_name = &field_names[..]; + let field_ty = &field_tys[..]; + + let derive_dbg = + quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]); + + Ok(( + quote! { + #repr + #[derive(Clone, Copy, ::bytemuck::AnyBitPattern)] + #derive_dbg + pub struct #bits_ty { + #(#field_name: <#field_ty as ::bytemuck::CheckedBitPattern>::Bits,)* + } + }, + quote! { + type Bits = #bits_ty; + + #[inline] + #[allow(clippy::double_comparisons)] + fn is_valid_bit_pattern(bits: &#bits_ty) -> bool { + #(<#field_ty as ::bytemuck::CheckedBitPattern>::is_valid_bit_pattern(&bits.#field_name) && )* true + } + }, + )) +} + +fn generate_checked_bit_pattern_enum( + input: &DeriveInput, +) -> Result<(TokenStream, TokenStream)> { + let span = input.span(); + let mut variants_with_discriminant = + VariantDiscriminantIterator::new(get_enum_variants(input)?); + + let (min, max, count) = variants_with_discriminant.try_fold( + (i64::max_value(), i64::min_value(), 0), + |(min, max, count), res| { + let discriminant = res?; + Ok::<_, Error>(( + i64::min(min, discriminant), + i64::max(max, discriminant), + count + 1, + )) + }, + )?; + + let check = if count == 0 { + quote_spanned!(span => false) + } else if max - min == count - 1 { + // contiguous range + let min_lit = LitInt::new(&format!("{}", min), span); + let max_lit = LitInt::new(&format!("{}", max), span); + + quote!(*bits >= #min_lit && *bits <= #max_lit) + } else { + // not contiguous range, check for each + let variant_lits = + VariantDiscriminantIterator::new(get_enum_variants(input)?) + .map(|res| { + let variant = res?; + Ok(LitInt::new(&format!("{}", variant), span)) + }) + .collect::<Result<Vec<_>>>()?; + + // count is at least 1 + let first = &variant_lits[0]; + let rest = &variant_lits[1..]; + + quote!(matches!(*bits, #first #(| #rest )*)) + }; + + let repr = get_repr(&input.attrs)?; + let integer_ty = repr.repr.as_integer_type().unwrap(); // should be checked in attr check already + Ok(( + quote!(), + quote! { + type Bits = #integer_ty; + + #[inline] + #[allow(clippy::double_comparisons)] + fn is_valid_bit_pattern(bits: &Self::Bits) -> bool { + #check + } + }, + )) +} + +/// Check that a struct has no padding by asserting that the size of the struct +/// is equal to the sum of the size of it's fields +fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> { + let struct_type = &input.ident; + let span = input.ident.span(); + let fields = get_fields(input)?; + + let mut field_types = get_field_types(&fields); + let size_sum = if let Some(first) = field_types.next() { + let size_first = quote_spanned!(span => ::core::mem::size_of::<#first>()); + let size_rest = + quote_spanned!(span => #( + ::core::mem::size_of::<#field_types>() )*); + + quote_spanned!(span => #size_first#size_rest) + } else { + quote_spanned!(span => 0) + }; + + Ok(quote_spanned! {span => const _: fn() = || { + struct TypeWithoutPadding([u8; #size_sum]); + let _ = ::core::mem::transmute::<#struct_type, TypeWithoutPadding>; + };}) +} + +/// Check that all fields implement a given trait +fn generate_fields_are_trait( + input: &DeriveInput, trait_: syn::Path, +) -> Result<TokenStream> { + let (impl_generics, _ty_generics, where_clause) = + input.generics.split_for_impl(); + let fields = get_fields(input)?; + let span = input.span(); + let field_types = get_field_types(&fields); + Ok(quote_spanned! {span => #(const _: fn() = || { + #[allow(clippy::missing_const_for_fn)] + fn check #impl_generics () #where_clause { + fn assert_impl<T: #trait_>() {} + assert_impl::<#field_types>(); + } + };)* + }) +} + +fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> { + match tokens.into_iter().next() { + Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()), + Some(TokenTree::Ident(ident)) => Some(ident), + _ => None, + } +} + +/// get a simple #[foo(bar)] attribute, returning "bar" +fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> { + for attr in attributes { + if let (AttrStyle::Outer, Some(outer_ident), Some(inner_ident)) = ( + &attr.style, + attr.path.get_ident(), + get_ident_from_stream(attr.tokens.clone()), + ) { + if outer_ident.to_string() == attr_name { + return Some(inner_ident); + } + } + } + + None +} + +fn get_repr(attributes: &[Attribute]) -> Result<Representation> { + attributes + .iter() + .filter_map(|attr| { + if attr.path.is_ident("repr") { + Some(attr.parse_args::<Representation>()) + } else { + None + } + }) + .try_fold(Representation::default(), |a, b| { + let b = b?; + Ok(Representation { + repr: match (a.repr, b.repr) { + (a, Repr::Rust) => a, + (Repr::Rust, b) => b, + _ => bail!("conflicting representation hints"), + }, + packed: match (a.packed, b.packed) { + (a, None) => a, + (None, b) => b, + _ => bail!("conflicting representation hints"), + }, + align: match (a.align, b.align) { + (a, None) => a, + (None, b) => b, + _ => bail!("conflicting representation hints"), + }, + }) + }) +} + +mk_repr! { + U8 => u8, + I8 => i8, + U16 => u16, + I16 => i16, + U32 => u32, + I32 => i32, + U64 => u64, + I64 => i64, + I128 => i128, + U128 => u128, +} +// where +macro_rules! mk_repr {( + $( + $Xn:ident => $xn:ident + ),* $(,)? +) => ( + #[derive(Clone, Copy, PartialEq)] + enum Repr { + Rust, + C, + Transparent, + $($Xn),* + } + + impl Repr { + fn is_integer(self) -> bool { + match self { + Repr::Rust | Repr::C | Repr::Transparent => false, + _ => true, + } + } + + fn as_integer_type(self) -> Option<TokenStream> { + match self { + Repr::Rust | Repr::C | Repr::Transparent => None, + $( + Repr::$Xn => Some(quote! { ::core::primitive::$xn }), + )* + } + } + } + + #[derive(Clone, Copy)] + struct Representation { + packed: Option<u32>, + align: Option<u32>, + repr: Repr, + } + + impl Default for Representation { + fn default() -> Self { + Self { packed: None, align: None, repr: Repr::Rust } + } + } + + impl Parse for Representation { + fn parse(input: ParseStream<'_>) -> Result<Representation> { + let mut ret = Representation::default(); + while !input.is_empty() { + let keyword = input.parse::<Ident>()?; + // preƫmptively call `.to_string()` *once* (rather than on `is_ident()`) + let keyword_str = keyword.to_string(); + let new_repr = match keyword_str.as_str() { + "C" => Repr::C, + "transparent" => Repr::Transparent, + "packed" => { + ret.packed = Some(if input.peek(token::Paren) { + let contents; parenthesized!(contents in input); + LitInt::base10_parse::<u32>(&contents.parse()?)? + } else { + 1 + }); + let _: Option<Token![,]> = input.parse()?; + continue; + }, + "align" => { + let contents; parenthesized!(contents in input); + ret.align = Some(LitInt::base10_parse::<u32>(&contents.parse()?)?); + let _: Option<Token![,]> = input.parse()?; + continue; + }, + $( + stringify!($xn) => Repr::$Xn, + )* + _ => return Err(input.error("unrecognized representation hint")) + }; + if ::core::mem::replace(&mut ret.repr, new_repr) != Repr::Rust { + input.error("duplicate representation hint"); + } + let _: Option<Token![,]> = input.parse()?; + } + Ok(ret) + } + } + + impl ToTokens for Representation { + fn to_tokens(&self, tokens: &mut TokenStream) { + let repr = match self.repr { + Repr::Rust => None, + Repr::C => Some(quote!(C)), + Repr::Transparent => Some(quote!(transparent)), + $( + Repr::$Xn => Some(quote!($xn)), + )* + }; + let packed = self.packed.map(|p| quote!(packed(#p))); + let comma = if packed.is_some() && repr.is_some() { + Some(quote!(,)) + } else { + None + }; + tokens.extend(quote!( + #[repr( #repr #comma #packed )] + )); + } + } +)} +use mk_repr; + +struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> { + inner: I, + last_value: i64, +} + +impl<'a, I: Iterator<Item = &'a Variant> + 'a> + VariantDiscriminantIterator<'a, I> +{ + fn new(inner: I) -> Self { + VariantDiscriminantIterator { inner, last_value: -1 } + } +} + +impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator + for VariantDiscriminantIterator<'a, I> +{ + type Item = Result<i64>; + + fn next(&mut self) -> Option<Self::Item> { + let variant = self.inner.next()?; + if !variant.fields.is_empty() { + return Some(Err(Error::new_spanned( + &variant.fields, + "Only fieldless enums are supported", + ))); + } + + if let Some((_, discriminant)) = &variant.discriminant { + let discriminant_value = match parse_int_expr(discriminant) { + Ok(value) => value, + Err(e) => return Some(Err(e)), + }; + self.last_value = discriminant_value; + } else { + self.last_value += 1; + } + + Some(Ok(self.last_value)) + } +} + +fn parse_int_expr(expr: &Expr) -> Result<i64> { + match expr { + Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => { + parse_int_expr(expr).map(|int| -int) + } + Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(), + Expr::Lit(ExprLit { lit: Lit::Byte(byte), .. }) => Ok(byte.value().into()), + _ => bail!("Not an integer expression"), + } +} |