diff options
author | Casper <casperneo@uchicago.edu> | 2020-10-19 11:40:03 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-19 11:40:03 -0700 |
commit | 9fa1d27059a69149856c6e003da8c9723fec7506 (patch) | |
tree | da75d0032019f07250ebbd9ae1db0c9d82d17667 /src/idl_gen_rust.cpp | |
parent | a402b3abaea6490d8aad1fe90d8bafe2a6396fe8 (diff) | |
download | flatbuffers-9fa1d27059a69149856c6e003da8c9723fec7506.tar.gz flatbuffers-9fa1d27059a69149856c6e003da8c9723fec7506.tar.bz2 flatbuffers-9fa1d27059a69149856c6e003da8c9723fec7506.zip |
Rework enums in rust. (#6098)
* Rework enums in rust.
They're now a unit struct, rather than an enum. This is a
backwards incompatible change but the previous version had UB
and was also backwards incompatible so...
* Update and test sample rust flatbuffers
* Use bitflags crate to properly support rust enums.
Previously, the bitflags attribute was just ignored. This is a breaking change
as the bitflgs API is not like a normal rust enum (duh).
* variant_name() -> Option<_>
* repr transparent
* Reexport bitflags from flatbuffers
* Make bitflags constants CamelCase, matching normal enums
* Deprecate c-style associated enum constants
Co-authored-by: Casper Neo <cneo@google.com>
Diffstat (limited to 'src/idl_gen_rust.cpp')
-rw-r--r-- | src/idl_gen_rust.cpp | 278 |
1 files changed, 169 insertions, 109 deletions
diff --git a/src/idl_gen_rust.cpp b/src/idl_gen_rust.cpp index 4e157827..ad46bac6 100644 --- a/src/idl_gen_rust.cpp +++ b/src/idl_gen_rust.cpp @@ -175,6 +175,14 @@ std::string AddUnwrapIfRequired(std::string s, bool required) { } } +bool IsBitFlagsEnum(const EnumDef &enum_def) { + return enum_def.attributes.Lookup("bit_flags") != nullptr; +} +bool IsBitFlagsEnum(const FieldDef &field) { + EnumDef* ed = field.value.type.enum_def; + return ed && IsBitFlagsEnum(*ed); +} + namespace rust { class RustGenerator : public BaseGenerator { @@ -215,7 +223,10 @@ class RustGenerator : public BaseGenerator { // the future. as a result, we proactively block these out as reserved // words. "follow", "push", "size", "alignment", "to_little_endian", - "from_little_endian", nullptr + "from_little_endian", nullptr, + + // used by Enum constants + "ENUM_MAX", "ENUM_MIN", "ENUM_VALUES", }; for (auto kw = keywords; *kw; kw++) keywords_.insert(*kw); } @@ -508,11 +519,28 @@ class RustGenerator : public BaseGenerator { } } - std::string GetEnumValUse(const EnumDef &enum_def, + std::string GetEnumValue(const EnumDef &enum_def, const EnumVal &enum_val) const { return Name(enum_def) + "::" + Name(enum_val); } + // 1 suffix since old C++ can't figure out the overload. + void ForAllEnumValues1(const EnumDef &enum_def, + std::function<void(const EnumVal&)> cb) { + for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) { + const auto &ev = **it; + code_.SetValue("VARIANT", Name(ev)); + code_.SetValue("VALUE", enum_def.ToString(ev)); + cb(ev); + } + } + void ForAllEnumValues(const EnumDef &enum_def, std::function<void()> cb) { + std::function<void(const EnumVal&)> wrapped = [&](const EnumVal& unused) { + (void) unused; + cb(); + }; + ForAllEnumValues1(enum_def, wrapped); + } // Generate an enum declaration, // an enum string lookup table, // an enum match function, @@ -520,132 +548,162 @@ class RustGenerator : public BaseGenerator { void GenEnum(const EnumDef &enum_def) { code_.SetValue("ENUM_NAME", Name(enum_def)); code_.SetValue("BASE_TYPE", GetEnumTypeForDecl(enum_def.underlying_type)); - - GenComment(enum_def.doc_comment); - code_ += "#[allow(non_camel_case_types)]"; - code_ += "#[repr({{BASE_TYPE}})]"; - code_ += - "#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]"; - code_ += "pub enum " + Name(enum_def) + " {"; - - for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) { - const auto &ev = **it; - - GenComment(ev.doc_comment, " "); - code_.SetValue("KEY", Name(ev)); - code_.SetValue("VALUE", enum_def.ToString(ev)); - code_ += " {{KEY}} = {{VALUE}},"; - } + code_.SetValue("ENUM_NAME_SNAKE", MakeSnakeCase(Name(enum_def))); + code_.SetValue("ENUM_NAME_CAPS", MakeUpper(MakeSnakeCase(Name(enum_def)))); const EnumVal *minv = enum_def.MinValue(); const EnumVal *maxv = enum_def.MaxValue(); FLATBUFFERS_ASSERT(minv && maxv); + code_.SetValue("ENUM_MIN_BASE_VALUE", enum_def.ToString(*minv)); + code_.SetValue("ENUM_MAX_BASE_VALUE", enum_def.ToString(*maxv)); + + if (IsBitFlagsEnum(enum_def)) { + // Defer to the convenient and canonical bitflags crate. We declare it in a + // module to #allow camel case constants in a smaller scope. This matches + // Flatbuffers c-modeled enums where variants are associated constants but + // in camel case. + code_ += "#[allow(non_upper_case_globals)]"; + code_ += "mod bitflags_{{ENUM_NAME_SNAKE}} {"; + code_ += " flatbuffers::bitflags::bitflags! {"; + GenComment(enum_def.doc_comment, " "); + code_ += " pub struct {{ENUM_NAME}}: {{BASE_TYPE}} {"; + ForAllEnumValues1(enum_def, [&](const EnumVal &ev){ + this->GenComment(ev.doc_comment, " "); + code_ += " const {{VARIANT}} = {{VALUE}};"; + }); + code_ += " }"; + code_ += " }"; + code_ += "}"; + code_ += "pub use self::bitflags_{{ENUM_NAME_SNAKE}}::{{ENUM_NAME}};"; + code_ += ""; + + // Generate Follow and Push so we can serialize and stuff. + code_ += "impl<'a> flatbuffers::Follow<'a> for {{ENUM_NAME}} {"; + code_ += " type Inner = Self;"; + code_ += " #[inline]"; + code_ += " fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {"; + code_ += " let bits = flatbuffers::read_scalar_at::<{{BASE_TYPE}}>(buf, loc);"; + code_ += " unsafe { Self::from_bits_unchecked(bits) }"; + code_ += " }"; + code_ += "}"; + code_ += ""; + code_ += "impl flatbuffers::Push for {{ENUM_NAME}} {"; + code_ += " type Output = {{ENUM_NAME}};"; + code_ += " #[inline]"; + code_ += " fn push(&self, dst: &mut [u8], _rest: &[u8]) {"; + code_ += " flatbuffers::emplace_scalar::<{{BASE_TYPE}}>" + "(dst, self.bits());"; + code_ += " }"; + code_ += "}"; + code_ += ""; + code_ += "impl flatbuffers::EndianScalar for {{ENUM_NAME}} {"; + code_ += " #[inline]"; + code_ += " fn to_little_endian(self) -> Self {"; + code_ += " let bits = {{BASE_TYPE}}::to_le(self.bits());"; + code_ += " unsafe { Self::from_bits_unchecked(bits) }"; + code_ += " }"; + code_ += " #[inline]"; + code_ += " fn from_little_endian(self) -> Self {"; + code_ += " let bits = {{BASE_TYPE}}::from_le(self.bits());"; + code_ += " unsafe { Self::from_bits_unchecked(bits) }"; + code_ += " }"; + code_ += "}"; + code_ += ""; + return; + } + // Deprecated associated constants; + code_ += "#[deprecated(since = \"1.13\", note = \"Use associated constants" + " instead. This will no longer be generated in 2021.\")]"; + code_ += "pub const ENUM_MIN_{{ENUM_NAME_CAPS}}: {{BASE_TYPE}}" + " = {{ENUM_MIN_BASE_VALUE}};"; + code_ += "#[deprecated(since = \"1.13\", note = \"Use associated constants" + " instead. This will no longer be generated in 2021.\")]"; + code_ += "pub const ENUM_MAX_{{ENUM_NAME_CAPS}}: {{BASE_TYPE}}" + " = {{ENUM_MAX_BASE_VALUE}};"; + auto num_fields = NumToString(enum_def.size()); + code_ += "#[deprecated(since = \"1.13\", note = \"Use associated constants" + " instead. This will no longer be generated in 2021.\")]"; + code_ += "#[allow(non_camel_case_types)]"; + code_ += "pub const ENUM_VALUES_{{ENUM_NAME_CAPS}}: [{{ENUM_NAME}}; " + + num_fields + "] = ["; + ForAllEnumValues1(enum_def, [&](const EnumVal &ev){ + code_ += " " + GetEnumValue(enum_def, ev) + ","; + }); + code_ += "];"; code_ += ""; - code_ += "}"; + + GenComment(enum_def.doc_comment); + code_ += + "#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]"; + code_ += "#[repr(transparent)]"; + code_ += "pub struct {{ENUM_NAME}}(pub {{BASE_TYPE}});"; + code_ += "#[allow(non_upper_case_globals)]"; + code_ += "impl {{ENUM_NAME}} {"; + ForAllEnumValues1(enum_def, [&](const EnumVal &ev){ + this->GenComment(ev.doc_comment, " "); + code_ += " pub const {{VARIANT}}: Self = Self({{VALUE}});"; + }); code_ += ""; + // Generate Associated constants + code_ += " pub const ENUM_MIN: {{BASE_TYPE}} = {{ENUM_MIN_BASE_VALUE}};"; + code_ += " pub const ENUM_MAX: {{BASE_TYPE}} = {{ENUM_MAX_BASE_VALUE}};"; + code_ += " pub const ENUM_VALUES: &'static [Self] = &["; + ForAllEnumValues(enum_def, [&](){ + code_ += " Self::{{VARIANT}},"; + }); + code_ += " ];"; + code_ += " /// Returns the variant's name or \"\" if unknown."; + code_ += " pub fn variant_name(self) -> Option<&'static str> {"; + code_ += " match self {"; + ForAllEnumValues(enum_def, [&](){ + code_ += " Self::{{VARIANT}} => Some(\"{{VARIANT}}\"),"; + }); + code_ += " _ => None,"; + code_ += " }"; + code_ += " }"; + code_ += "}"; - code_.SetValue("ENUM_NAME", Name(enum_def)); - code_.SetValue("ENUM_NAME_SNAKE", MakeSnakeCase(Name(enum_def))); - code_.SetValue("ENUM_NAME_CAPS", MakeUpper(MakeSnakeCase(Name(enum_def)))); - code_.SetValue("ENUM_MIN_BASE_VALUE", enum_def.ToString(*minv)); - code_.SetValue("ENUM_MAX_BASE_VALUE", enum_def.ToString(*maxv)); + // Generate Debug. Unknown variants are printed like "<UNKNOWN 42>". + code_ += "impl std::fmt::Debug for {{ENUM_NAME}} {"; + code_ += " fn fmt(&self, f: &mut std::fmt::Formatter) ->" + " std::fmt::Result {"; + code_ += " if let Some(name) = self.variant_name() {"; + code_ += " f.write_str(name)"; + code_ += " } else {"; + code_ += " f.write_fmt(format_args!(\"<UNKNOWN {:?}>\", self.0))"; + code_ += " }"; + code_ += " }"; + code_ += "}"; - // Generate enum constants, and impls for Follow, EndianScalar, and Push. - code_ += "pub const ENUM_MIN_{{ENUM_NAME_CAPS}}: {{BASE_TYPE}} = \\"; - code_ += "{{ENUM_MIN_BASE_VALUE}};"; - code_ += "pub const ENUM_MAX_{{ENUM_NAME_CAPS}}: {{BASE_TYPE}} = \\"; - code_ += "{{ENUM_MAX_BASE_VALUE}};"; - code_ += ""; + // Generate Follow and Push so we can serialize and stuff. code_ += "impl<'a> flatbuffers::Follow<'a> for {{ENUM_NAME}} {"; code_ += " type Inner = Self;"; code_ += " #[inline]"; code_ += " fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {"; - code_ += " flatbuffers::read_scalar_at::<Self>(buf, loc)"; + code_ += " Self(flatbuffers::read_scalar_at::<{{BASE_TYPE}}>(buf, loc))"; code_ += " }"; code_ += "}"; code_ += ""; + code_ += "impl flatbuffers::Push for {{ENUM_NAME}} {"; + code_ += " type Output = {{ENUM_NAME}};"; + code_ += " #[inline]"; + code_ += " fn push(&self, dst: &mut [u8], _rest: &[u8]) {"; + code_ += " flatbuffers::emplace_scalar::<{{BASE_TYPE}}>" + "(dst, self.0);"; + code_ += " }"; + code_ += "}"; + code_ += ""; code_ += "impl flatbuffers::EndianScalar for {{ENUM_NAME}} {"; code_ += " #[inline]"; code_ += " fn to_little_endian(self) -> Self {"; - code_ += " let n = {{BASE_TYPE}}::to_le(self as {{BASE_TYPE}});"; - code_ += " let p = &n as *const {{BASE_TYPE}} as *const {{ENUM_NAME}};"; - code_ += " unsafe { *p }"; + code_ += " Self({{BASE_TYPE}}::to_le(self.0))"; code_ += " }"; code_ += " #[inline]"; code_ += " fn from_little_endian(self) -> Self {"; - code_ += " let n = {{BASE_TYPE}}::from_le(self as {{BASE_TYPE}});"; - code_ += " let p = &n as *const {{BASE_TYPE}} as *const {{ENUM_NAME}};"; - code_ += " unsafe { *p }"; + code_ += " Self({{BASE_TYPE}}::from_le(self.0))"; code_ += " }"; code_ += "}"; code_ += ""; - code_ += "impl flatbuffers::Push for {{ENUM_NAME}} {"; - code_ += " type Output = {{ENUM_NAME}};"; - code_ += " #[inline]"; - code_ += " fn push(&self, dst: &mut [u8], _rest: &[u8]) {"; - code_ += - " flatbuffers::emplace_scalar::<{{ENUM_NAME}}>" - "(dst, *self);"; - code_ += " }"; - code_ += "}"; - code_ += ""; - - // Generate an array of all enumeration values. - auto num_fields = NumToString(enum_def.size()); - code_ += "#[allow(non_camel_case_types)]"; - code_ += "pub const ENUM_VALUES_{{ENUM_NAME_CAPS}}: [{{ENUM_NAME}}; " + - num_fields + "] = ["; - for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) { - const auto &ev = **it; - auto value = GetEnumValUse(enum_def, ev); - auto suffix = *it != enum_def.Vals().back() ? "," : ""; - code_ += " " + value + suffix; - } - code_ += "];"; - code_ += ""; - - // Generate a string table for enum values. - // Problem is, if values are very sparse that could generate really big - // tables. Ideally in that case we generate a map lookup instead, but for - // the moment we simply don't output a table at all. - auto range = enum_def.Distance(); - // Average distance between values above which we consider a table - // "too sparse". Change at will. - static const uint64_t kMaxSparseness = 5; - if (range / static_cast<uint64_t>(enum_def.size()) < kMaxSparseness) { - code_ += "#[allow(non_camel_case_types)]"; - code_ += "pub const ENUM_NAMES_{{ENUM_NAME_CAPS}}: [&str; " + - NumToString(range + 1) + "] = ["; - - auto val = enum_def.Vals().front(); - for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); - ++it) { - auto ev = *it; - for (auto k = enum_def.Distance(val, ev); k > 1; --k) { - code_ += " \"\","; - } - val = ev; - auto suffix = *it != enum_def.Vals().back() ? "," : ""; - code_ += " \"" + Name(*ev) + "\"" + suffix; - } - code_ += "];"; - code_ += ""; - - code_ += - "pub fn enum_name_{{ENUM_NAME_SNAKE}}(e: {{ENUM_NAME}}) -> " - "&'static str {"; - - code_ += " let index = e as {{BASE_TYPE}}\\"; - if (enum_def.MinValue()->IsNonZero()) { - auto vals = GetEnumValUse(enum_def, *enum_def.MinValue()); - code_ += " - " + vals + " as {{BASE_TYPE}}\\"; - } - code_ += ";"; - - code_ += " ENUM_NAMES_{{ENUM_NAME_CAPS}}[index as usize]"; - code_ += "}"; - code_ += ""; - } if (enum_def.is_union) { // Generate tyoesafe offset(s) for unions @@ -677,7 +735,7 @@ class RustGenerator : public BaseGenerator { auto ev = field.value.type.enum_def->FindByValue(field.value.constant); assert(ev); return WrapInNameSpace(field.value.type.enum_def->defined_namespace, - GetEnumValUse(*field.value.type.enum_def, *ev)); + GetEnumValue(*field.value.type.enum_def, *ev)); } // All pointer-ish types have a default value of None, because they are @@ -1027,9 +1085,8 @@ class RustGenerator : public BaseGenerator { } case ftUnionKey: case ftEnumKey: { - const auto underlying_typname = GetTypeBasic(type); //<- never used - const auto typname = WrapInNameSpace(*type.enum_def); - const auto default_value = GetDefaultScalarValue(field); + const std::string typname = WrapInNameSpace(*type.enum_def); + const std::string default_value = GetDefaultScalarValue(field); if (field.optional) { return "self._tab.get::<" + typname + ">(" + offset_name + ", None)"; } else { @@ -1302,7 +1359,7 @@ class RustGenerator : public BaseGenerator { code_.SetValue( "U_ELEMENT_ENUM_TYPE", - WrapInNameSpace(u->defined_namespace, GetEnumValUse(*u, ev))); + WrapInNameSpace(u->defined_namespace, GetEnumValue(*u, ev))); code_.SetValue("U_ELEMENT_TABLE_TYPE", table_init_type); code_.SetValue("U_ELEMENT_NAME", MakeSnakeCase(Name(ev))); @@ -1763,6 +1820,9 @@ class RustGenerator : public BaseGenerator { } void GenNamespaceImports(const int white_spaces) { + if (white_spaces == 0) { + code_ += "#![allow(unused_imports, dead_code)]"; + } std::string indent = std::string(white_spaces, ' '); code_ += ""; if (!parser_.opts.generate_all) { |