diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 827d6168be..98d972d6cd 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -46,6 +46,7 @@ libdd-dogstatsd-client @DataDog/apm-common-components-core libdd-library-config*/ @DataDog/apm-sdk-capabilities libdd-log*/ @DataDog/apm-common-components-core libdd-profiling*/ @DataDog/libdatadog-profiling +libdd-proto-codec @DataDog/apm-common-components-core libdd-telemetry*/ @DataDog/apm-common-components-core libdd-tinybytes @DataDog/apm-common-components-core libdd-trace-normalization @DataDog/serverless @DataDog/libdatadog-apm diff --git a/Cargo.lock b/Cargo.lock index f5e51d432b..392be4b7d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1055,7 +1055,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8599749b6667e2f0c910c1d0dff6901163ff698a52d5a39720f61b5be4b20d3" dependencies = [ "futures-core", - "prost", + "prost 0.14.3", "prost-types", "tonic", "tonic-prost", @@ -1075,7 +1075,7 @@ dependencies = [ "hdrhistogram", "humantime", "hyper-util", - "prost", + "prost 0.14.3", "prost-types", "serde", "serde_json", @@ -1514,7 +1514,7 @@ dependencies = [ "clap", "libdd-profiling", "libdd-profiling-protobuf", - "prost", + "prost 0.14.3", "sysinfo", ] @@ -3131,7 +3131,7 @@ dependencies = [ name = "libdd-ddsketch" version = "1.0.1" dependencies = [ - "prost", + "prost 0.14.3", "prost-build", "protoc-bin-vendored", ] @@ -3235,7 +3235,7 @@ dependencies = [ "mime", "parking_lot", "proptest", - "prost", + "prost 0.14.3", "rand 0.8.5", "reqwest", "rustc-hash 1.1.0", @@ -3286,7 +3286,18 @@ version = "1.0.0" dependencies = [ "bolero", "libdd-profiling-protobuf", - "prost", + "prost 0.14.3", +] + +[[package]] +name = "libdd-proto-codec" +version = "28.0.2" +dependencies = [ + "arbitrary", + "bytes", + "criterion", + "prost 0.13.5", + "rand 0.9.0", ] [[package]] @@ -3375,7 +3386,7 @@ dependencies = [ name = "libdd-trace-protobuf" version = "1.1.0" dependencies = [ - "prost", + "prost 0.14.3", "prost-build", "protoc-bin-vendored", "serde", @@ -3419,7 +3430,7 @@ dependencies = [ "libdd-trace-normalization", "libdd-trace-protobuf", "libdd-trace-utils", - "prost", + "prost 0.14.3", "rand 0.8.5", "rmp", "rmp-serde", @@ -4415,6 +4426,16 @@ dependencies = [ "unarray", ] +[[package]] +name = "prost" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +dependencies = [ + "bytes", + "prost-derive 0.13.5", +] + [[package]] name = "prost" version = "0.14.3" @@ -4422,7 +4443,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.14.3", ] [[package]] @@ -4437,13 +4458,26 @@ dependencies = [ "multimap", "petgraph", "prettyplease", - "prost", + "prost 0.14.3", "prost-types", "regex", "syn 2.0.87", "tempfile", ] +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools 0.11.0", + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "prost-derive" version = "0.14.3" @@ -4463,7 +4497,7 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" dependencies = [ - "prost", + "prost 0.14.3", ] [[package]] @@ -6209,7 +6243,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66bd50ad6ce1252d87ef024b3d64fe4c3cf54a86fb9ef4c631fdd0ded7aeaa67" dependencies = [ "bytes", - "prost", + "prost 0.14.3", "tonic", ] diff --git a/Cargo.toml b/Cargo.toml index 23cd1b87fc..f28f51bc8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ members = [ "libdd-tinybytes", "libdd-dogstatsd-client", "libdd-log", - "libdd-log-ffi", + "libdd-log-ffi", "libdd-proto-codec", ] # https://doc.rust-lang.org/cargo/reference/resolver.html diff --git a/libdd-proto-codec/Cargo.toml b/libdd-proto-codec/Cargo.toml new file mode 100644 index 0000000000..ec6d4a855c --- /dev/null +++ b/libdd-proto-codec/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "libdd-proto-codec" +rust-version.workspace = true +edition.workspace = true +version.workspace = true +license.workspace = true +authors.workspace = true + +[dependencies] +bytes = { version = "1.4" } + +[dev-dependencies] +prost = { version = "0.13", features = ["derive"] } +arbitrary = { version = "1", features = ["derive"] } +rand = "0.9" +criterion = { version = "0.5", features = ["html_reports"] } diff --git a/libdd-proto-codec/src/constants.rs b/libdd-proto-codec/src/constants.rs new file mode 100644 index 0000000000..b6510be460 --- /dev/null +++ b/libdd-proto-codec/src/constants.rs @@ -0,0 +1,33 @@ +use core::hash; + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, hash::Hash, PartialOrd, Ord)] +pub enum WireType { + Varint = 0, + Fixed64 = 1, + LengthDelimited = 2, + StartGroup = 3, // Deprecated in proto3, but still used in proto2. + EndGroup = 4, // Deprecated in proto3, but still used in proto2. + Fixed32 = 5, +} + +impl WireType { + #[inline] + #[allow(unused)] + pub(crate) const fn from_u32(value: u32) -> Option { + match value { + 0 => Some(WireType::Varint), + 1 => Some(WireType::Fixed64), + 2 => Some(WireType::LengthDelimited), + 3 => Some(WireType::StartGroup), + 4 => Some(WireType::EndGroup), + 5 => Some(WireType::Fixed32), + _ => None, + } + } + + #[inline] + pub(crate) const fn to_u32(self) -> u32 { + self as u32 + } +} diff --git a/libdd-proto-codec/src/encoder.rs b/libdd-proto-codec/src/encoder.rs new file mode 100644 index 0000000000..dcbe0df67f --- /dev/null +++ b/libdd-proto-codec/src/encoder.rs @@ -0,0 +1,511 @@ +use crate::constants::WireType; +use alloc::vec::Vec; +use core::ops::DerefMut; + +pub const MAP_KEY_FIELD_NUM: u32 = 1; +pub const MAP_VALUE_FIELD_NUM: u32 = 2; + +pub trait BufMut: DerefMut { + fn put_u8(&mut self, v: u8); + fn put_slice(&mut self, slice: &[u8]); + fn truncate(&mut self, new_len: usize); +} + +impl BufMut for Vec { + fn put_u8(&mut self, v: u8) { + self.push(v); + } + + fn put_slice(&mut self, slice: &[u8]) { + self.extend_from_slice(slice); + } + + fn truncate(&mut self, new_len: usize) { + self.truncate(new_len); + } +} + +#[derive(Default)] +pub struct TopLevelEncoder { + data: B, +} + +impl TopLevelEncoder { + pub fn encoder(&mut self) -> Encoder<'_, B> { + Encoder { + data: &mut self.data, + } + } + + pub fn finish(self) -> B { + self.data + } +} + +pub struct NestedEncoder<'a, B: BufMut> { + tag_position: usize, + size_position: usize, + write_empty: bool, + encoder: Encoder<'a, B>, +} + +impl NestedEncoder<'_, B> { + pub fn encoder(&mut self) -> Encoder<'_, B> { + Encoder { + data: self.encoder.data, + } + } +} + +impl Drop for NestedEncoder<'_, B> { + fn drop(&mut self) { + let size = self.encoder.data.len() - self.size_position - 5; + if !self.write_empty && size == 0 { + // If the message is empty, we need to remove the tag and size + self.encoder.data.truncate(self.tag_position); + return; + } + + let size_placeholder: &mut [u8; 5] = (&mut self.encoder.data + [self.size_position..self.size_position + 5]) + .try_into() + .unwrap(); + write_varint_max(size as u64, size_placeholder); + } +} + +trait ScalarEncoder { + type Input; + const WIRE_TYPE: WireType; + + fn encode(input: Self::Input, data: &mut B); +} + +#[inline(always)] +const fn append_varint u64>(f: F) -> impl FnOnce(T, &mut B) { + move |input: T, data: &mut B| { + let v = f(input); + encode_varint(v, data) + } +} + +macro_rules! impl_scalar_encode { + ($ty:ident, $input_ty:ty, $write_fn:expr, $wire_ty:expr) => { + struct $ty; + impl ScalarEncoder for $ty { + type Input = $input_ty; + const WIRE_TYPE: WireType = $wire_ty; + + #[inline(always)] + fn encode(input: Self::Input, data: &mut B) { + $write_fn(input, data); + } + } + }; +} + +macro_rules! impl_scalar_encode_varint { + ($ty:ident, $input_ty:ty, $to_varint_fn:expr) => { + impl_scalar_encode!( + $ty, + $input_ty, + append_varint($to_varint_fn), + WireType::Varint + ); + }; +} + +impl_scalar_encode_varint!(UInt64Encoder, u64, |v| v); +impl_scalar_encode_varint!(UInt32Encoder, u32, |v| v as u64); +impl_scalar_encode_varint!(Int64Encoder, i64, |v| v as u64); +impl_scalar_encode_varint!(Int32Encoder, i32, |v| v as u64); +impl_scalar_encode_varint!(SInt64Encoder, i64, |v| ((v << 1) ^ (v >> 63)) as u64); +impl_scalar_encode_varint!(SInt32Encoder, i32, |v| ((v << 1) ^ (v >> 31)) as u32 as u64); +impl_scalar_encode_varint!(BoolEncoder, bool, |v| v as u64); +impl_scalar_encode!( + Fixed64Encoder, + u64, + |v: u64, data: &mut B| { + data.put_slice(&v.to_le_bytes()); + }, + WireType::Fixed64 +); +impl_scalar_encode!( + Fixed32Encoder, + u32, + |v: u32, data: &mut B| { + data.put_slice(&v.to_le_bytes()); + }, + WireType::Fixed32 +); +impl_scalar_encode!( + SFixed64Encoder, + i64, + |v: i64, data: &mut B| { + data.put_slice(&v.to_le_bytes()); + }, + WireType::Fixed64 +); +impl_scalar_encode!( + SFixed32Encoder, + i32, + |v: i32, data: &mut B| { + data.put_slice(&v.to_le_bytes()); + }, + WireType::Fixed32 +); +impl_scalar_encode!( + F64Encoder, + f64, + |v: f64, data: &mut B| { + let bits = v.to_bits(); + data.put_slice(&bits.to_le_bytes()); + }, + WireType::Fixed64 +); +impl_scalar_encode!( + F32Encoder, + f32, + |v: f32, data: &mut B| { + let bits = v.to_bits(); + data.put_slice(&bits.to_le_bytes()); + }, + WireType::Fixed32 +); + +struct StringEncoder<'a>(core::marker::PhantomData<&'a ()>); + +impl<'a> ScalarEncoder for StringEncoder<'a> { + type Input = &'a str; + const WIRE_TYPE: WireType = WireType::LengthDelimited; + + #[inline(always)] + fn encode(input: Self::Input, data: &mut B) { + BytesEncoder::encode(input.as_bytes(), data); + } +} + +struct BytesEncoder<'a>(core::marker::PhantomData<&'a ()>); + +impl<'a> ScalarEncoder for BytesEncoder<'a> { + type Input = &'a [u8]; + const WIRE_TYPE: WireType = WireType::LengthDelimited; + + #[inline(always)] + fn encode(input: Self::Input, data: &mut B) { + let size = input.len(); + encode_varint(size as u64, data); + data.put_slice(input); + } +} + +fn encode_packed, B: BufMut>( + values: I, + data: &mut B, +) { + let size_position = data.len(); + data.put_slice(&[0; 5]); // Placeholder for size + for value in values { + E::encode(value, data); + } + let size = data.len() - size_position - 5; + let size_placeholder: &mut [u8; 5] = (&mut data[size_position..size_position + 5]) + .try_into() + .unwrap(); + write_varint_max(size as u64, size_placeholder); +} + +#[derive(Debug)] +pub struct Encoder<'a, B: BufMut> { + data: &'a mut B, +} + +impl Encoder<'_, B> { + /// returns an Encoder for a nested message. + /// + /// ```rust + /// use libdd_proto_codec::encoder::{TopLevelEncoder, Encoder, BufMut}; + /// + /// struct Bar { + /// baz: i32, + /// } + /// + /// fn encode_bar(e: &mut Encoder<'_, B>, bar: &Bar) { + /// e.write_sint32(1, bar.baz); + /// } + /// + /// struct Foo { + /// bar: Bar, + /// } + /// + /// fn encode_foo(e: &mut Encoder<'_, B>, foo: &Foo) { + /// encode_bar(&mut e.write_message(1).encoder(), &foo.bar); + /// } + /// + /// let mut e = TopLevelEncoder::>::default(); + /// encode_foo(&mut e.encoder(), &Foo { bar: Bar { baz: -1 } } ); + /// dbg!(e.finish()); + /// ``` + pub fn write_message(&mut self, field_number: u32) -> NestedEncoder<'_, B> { + let tag_position = self.data.len(); + encode_tagged(field_number, WireType::LengthDelimited, self.data); + let size_position = self.data.len(); + self.data.put_slice(&[0; 5]); // Placeholder for size + NestedEncoder { + tag_position, + write_empty: false, + size_position, + encoder: Encoder { data: self.data }, + } + } + + /// returns an Encoder for a nested message. + /// + /// If the nested message has a zero value (all fields are default or missing) + /// it will still be encoded into the buffer + pub fn write_message_opt(&mut self, field_number: u32) -> NestedEncoder<'_, B> { + encode_tagged(field_number, WireType::LengthDelimited, self.data); + let size_position = self.data.len(); + self.data.put_slice(&[0; 5]); // Placeholder for size + NestedEncoder { + // not needed + tag_position: 0, + write_empty: true, + size_position, + encoder: Encoder { data: self.data }, + } + } + + /// returns an Encoder for a nested message with repeated annotation + /// + /// ```rust + /// use libdd_proto_codec::encoder::{TopLevelEncoder, Encoder, BufMut}; + /// + /// struct Bar { + /// baz: i32, + /// } + /// + /// fn encode_bar(e: &mut Encoder<'_, B>, bar: &Bar) { + /// e.write_sint32(1, bar.baz); + /// } + /// + /// struct Foo { + /// bars: Vec, + /// } + /// + /// fn encode_foo(e: &mut Encoder<'_, B>, foo: &Foo) { + /// for bar in &foo.bars { + /// encode_bar(&mut e.write_message(1).encoder(), &bar); + /// } + /// } + /// + /// let mut e = TopLevelEncoder::>::default(); + /// encode_foo(&mut e.encoder(), &Foo { bars: vec![Bar { baz: -1 }, Bar { baz: 0 }] } ); + /// dbg!(e.finish()); + /// ``` + pub fn write_message_repeated(&mut self, field_number: u32) -> NestedEncoder<'_, B> { + self.write_message_opt(field_number) + } + + pub fn write_strings_repeated<'b, I: IntoIterator>( + &mut self, + field_number: u32, + v: I, + ) { + for value in v { + self.write_string_repeated(field_number, value); + } + } + + pub fn write_bytess_repeated<'b, I: IntoIterator>( + &mut self, + field_number: u32, + v: I, + ) { + for value in v { + self.write_bytes_repeated(field_number, value); + } + } + + /// returns a helper to encode protobufs maps + /// + /// ``` + /// use libdd_proto_codec::encoder::{Encoder, MapEncoder, BufMut, MAP_KEY_FIELD_NUM, MAP_VALUE_FIELD_NUM}; + /// + /// // message Example { + /// // map field = 3; + /// //} + /// + /// struct Example { + /// field: Vec<(String, i64)>, + /// } + /// fn encode_example(e: &mut Encoder<'_, Vec>, example: &Example) { + /// let map_encoder = e.write_map(3); + /// } + /// + /// fn encode_string_i64_map<'a, I: IntoIterator>(mut e: MapEncoder<'_, Vec>, map: I) { + /// for (k, v) in map { + /// encode_string_i64_map_entry(&mut e.write_map_entry() + /// .encoder(), k, *v); + /// } + /// } + /// + /// fn encode_string_i64_map_entry(e: &mut Encoder<'_, B>, key: &str, value: i64) { + /// e.write_string(MAP_KEY_FIELD_NUM, key); + /// e.write_int64(MAP_VALUE_FIELD_NUM, value); + /// } + /// ``` + pub fn write_map(&mut self, field_number: u32) -> MapEncoder<'_, B> { + MapEncoder { + data: self.data, + field_number, + } + } +} + +pub struct MapEncoder<'a, B: BufMut> { + data: &'a mut B, + field_number: u32, +} + +impl MapEncoder<'_, B> { + pub fn write_map_entry(&mut self) -> NestedEncoder<'_, B> { + let tag_position = self.data.len(); + encode_tagged(self.field_number, WireType::LengthDelimited, self.data); + let size_position = self.data.len(); + self.data.put_slice(&[0; 5]); // Placeholder for size + NestedEncoder { + tag_position, + write_empty: true, + size_position, + encoder: Encoder { data: self.data }, + } + } +} + +macro_rules! impl_scalar { + ($($fn_name:ident, $opt_fn_name:ident, $repeated_fn_name:ident, $($repeated_iter_fn_name:ident)?, $input_ty:ty, $encoder:ty,)*) => { + impl Encoder<'_, B> { + $( + pub fn $fn_name(&mut self, field_number: u32, v: $input_ty) { + if v == <$input_ty> :: default() { + return; + } + encode_tagged(field_number, <$encoder>::WIRE_TYPE, self.data); + <$encoder>::encode(v, self.data); + } + + pub fn $opt_fn_name(&mut self, field_number: u32, v: $input_ty) { + encode_tagged(field_number, <$encoder>::WIRE_TYPE, self.data); + <$encoder>::encode(v, self.data); + } + + pub fn $repeated_fn_name(&mut self, field_number: u32, v: $input_ty) { + self.$opt_fn_name(field_number, v); + } + + $( + pub fn $repeated_iter_fn_name>(&mut self, field_number: u32, v: I) { + for value in v { + self.$repeated_fn_name(field_number, value); + } + } + )? + )* + } + }; +} + +macro_rules! impl_packed { + ($($fn_name:ident, $input_ty:ty, $encoder:ty,)*) => { + impl Encoder<'_, B> { + $( + pub fn $fn_name>(&mut self, field_number: u32, v: I) { + let mut v = v.into_iter().peekable(); + if v.peek().is_none() { + return; + } + encode_tagged(field_number, WireType::LengthDelimited, self.data); + encode_packed::<$encoder, _, B>(v, self.data); + } + )* + } + }; +} + +impl_scalar! { + write_uint64, write_uint64_opt, write_uint64_repeated, write_uint64s_repeated, u64, UInt64Encoder, + write_uint32, write_uint32_opt, write_uint32_repeated, write_uint32s_repeated, u32, UInt32Encoder, + write_int64, write_int64_opt, write_int64_repeated, write_int64s_repeated, i64, Int64Encoder, + write_int32, write_int32_opt, write_int32_repeated, write_int32s_repeated, i32, Int32Encoder, + write_sint64, write_sint64_opt, write_sint64_repeated, write_sint64s_repeated, i64, SInt64Encoder, + write_sint32, write_sint32_opt, write_sint32_repeated, write_sint32s_repeated, i32, SInt32Encoder, + write_fixed64, write_fixed64_opt, write_fixed64_repeated, write_fixed64s_repeated, u64, Fixed64Encoder, + write_fixed32, write_fixed32_opt, write_fixed32_repeated, write_fixed32s_repeated, u32, Fixed32Encoder, + write_sfixed64, write_sfixed64_opt, write_sfixed64_repeated, write_sfixed64s_repeated, i64, SFixed64Encoder, + write_sfixed32, write_sfixed32_opt, write_sfixed32_repeated, write_sfixed32s_repeated, i32, SFixed32Encoder, + write_bool, write_bool_opt, write_bool_repeated, write_bools_repeated, bool, BoolEncoder, + write_f64, write_f64_opt, write_f64_repeated, write_f64s_repeated, f64, F64Encoder, + write_f32, write_f32_opt, write_f32_repeated, write_f32s_repeated, f32, F32Encoder, + write_string, write_string_opt, write_string_repeated, , &str, StringEncoder<'_>, + write_bytes, write_bytes_opt, write_bytes_repeated, , &[u8], BytesEncoder<'_>, +} + +impl_packed! { + write_uint64_packed, u64, UInt64Encoder, + write_uint32_packed, u32, UInt32Encoder, + write_int64_packed, i64, Int64Encoder, + write_int32_packed, i32, Int32Encoder, + write_sint64_packed, i64, SInt64Encoder, + write_sint32_packed, i32, SInt32Encoder, + write_fixed64_packed, u64, Fixed64Encoder, + write_fixed32_packed, u32, Fixed32Encoder, + write_sfixed64_packed, i64, SFixed64Encoder, + write_sfixed32_packed, i32, SFixed32Encoder, + write_bool_packed, bool, BoolEncoder, + write_f64_packed, f64, F64Encoder, + write_f32_packed, f32, F32Encoder, +} + +#[test] +fn test_encoding() { + let mut data = vec![]; + let mut encoder = Encoder { data: &mut data }; + + encoder.write_message(1).encoder().write_uint32(1, 2); + encoder.write_uint32(2, 3); + assert_eq!(data, &[10, 130, 128, 128, 128, 0, 8, 2, 16, 3]) +} + +#[inline] +fn encode_tagged(field_number: u32, wire_type: WireType, data: &mut B) { + let tag = (field_number << 3) | wire_type.to_u32(); + encode_varint(tag as u64, data); +} + +#[inline] +fn write_varint_max(mut v: u64, buf: &mut [u8; 5]) { + for (i, item) in buf.iter_mut().enumerate() { + *item = (v & 0x7F) as u8; + v >>= 7; + if i != 4 { + *item |= 0x80; + } + } +} + +/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer. +/// The buffer must have enough remaining space (maximum 10 bytes). +#[inline] +fn encode_varint(mut value: u64, buf: &mut B) { + // Varints are never more than 10 bytes + for _ in 0..10 { + if value < 0x80 { + buf.put_u8(value as u8); + break; + } else { + buf.put_u8(((value & 0x7F) | 0x80) as u8); + value >>= 7; + } + } +} diff --git a/libdd-proto-codec/src/lib.rs b/libdd-proto-codec/src/lib.rs new file mode 100644 index 0000000000..fec0bf8dd6 --- /dev/null +++ b/libdd-proto-codec/src/lib.rs @@ -0,0 +1,7 @@ +extern crate alloc; + +pub mod constants; +pub mod encoder; + +#[cfg(test)] +mod tests; diff --git a/libdd-proto-codec/src/tests.rs b/libdd-proto-codec/src/tests.rs new file mode 100644 index 0000000000..3c9ae2e11c --- /dev/null +++ b/libdd-proto-codec/src/tests.rs @@ -0,0 +1,144 @@ +use arbitrary::Arbitrary; +use prost::Message; + +use crate::encoder::{self, MAP_KEY_FIELD_NUM, MAP_VALUE_FIELD_NUM}; + +#[derive(PartialEq, prost::Message, arbitrary::Arbitrary)] +struct Bar { + #[prost(message, tag = "1", required)] + foo: Foo, + #[prost(string, repeated, tag = "2")] + string_repeated_field: Vec, + #[prost(sfixed32, repeated, tag = "3")] + i32_repeated_field: Vec, + #[prost(string, tag = "4")] + string_field: String, + #[prost(map = "string, sint64", tag = "5")] + map_field: std::collections::HashMap, +} + +#[derive(prost::Message, arbitrary::Arbitrary)] +struct Foo { + #[prost(uint64, tag = "1")] + u64_field: u64, + #[prost(uint32, tag = "2")] + u32_field: u32, + #[prost(int64, tag = "3")] + i64_field: i64, + #[prost(int32, tag = "4")] + i32_field: i32, + #[prost(sint64, tag = "5")] + si64_field: i64, + #[prost(sint32, tag = "6")] + si32_field: i32, + #[prost(bool, tag = "7")] + bool_field: bool, + #[prost(double, tag = "8")] + f64_field: f64, + #[prost(float, tag = "9")] + f32_field: f32, + #[prost(sint64, repeated, packed, tag = "10")] + packed_si64_packed_field: Vec, + #[prost(string, tag = "11")] + string_field: String, +} + +impl PartialEq for Foo { + fn eq(&self, other: &Self) -> bool { + self.u64_field == other.u64_field + && self.u32_field == other.u32_field + && self.i64_field == other.i64_field + && self.i32_field == other.i32_field + && self.si64_field == other.si64_field + && self.si32_field == other.si32_field + && self.bool_field == other.bool_field + && self.f64_field.total_cmp(&other.f64_field).is_eq() + && self.f32_field.total_cmp(&other.f32_field).is_eq() + && self.packed_si64_packed_field == other.packed_si64_packed_field + && self.string_field == other.string_field + } +} + +fn manual_encode_bar(e: &mut encoder::Encoder<'_, B>, bar: &Bar) { + manual_encode_foo(&mut e.write_message(1).encoder(), &bar.foo); + e.write_strings_repeated(2, bar.string_repeated_field.iter().map(|s| s.as_str())); + e.write_sfixed32_packed(3, bar.i32_repeated_field.iter().copied()); + e.write_string(4, &bar.string_field); + let mut map_enc = e.write_map(5); + for (k, v) in &bar.map_field { + let mut entry = map_enc.write_map_entry(); + let mut entry_enc = entry.encoder(); + entry_enc.write_string(MAP_KEY_FIELD_NUM, k); + entry_enc.write_sint64(MAP_VALUE_FIELD_NUM, *v); + } +} + +fn manual_bar_top_level_encoder(bar: &Bar) -> Vec { + let mut encoder = encoder::TopLevelEncoder::default(); + manual_encode_bar(&mut encoder.encoder(), bar); + encoder.finish() +} + +fn manual_encode_foo(e: &mut encoder::Encoder<'_, B>, foo: &Foo) { + e.write_uint64(1, foo.u64_field); + e.write_uint32(2, foo.u32_field); + e.write_int64(3, foo.i64_field); + e.write_int32(4, foo.i32_field); + e.write_sint64(5, foo.si64_field); + e.write_sint32(6, foo.si32_field); + e.write_bool(7, foo.bool_field); + e.write_f64(8, foo.f64_field); + e.write_f32(9, foo.f32_field); + e.write_sint64_packed(10, foo.packed_si64_packed_field.iter().copied()); + e.write_string(11, &foo.string_field); +} + +fn manual_foo_top_level_encoder(foo: &Foo) -> Vec { + let mut encoder = encoder::TopLevelEncoder::default(); + manual_encode_foo(&mut encoder.encoder(), foo); + encoder.finish() +} + +#[test] +fn test_roundtrip_bar() { + for _ in 0..100 { + let l = rand::random_range(0..255_usize); + let input: Vec = (0..l).map(|_| rand::random()).collect(); + let input_bar: Bar = Bar::arbitrary(&mut arbitrary::Unstructured::new(&input)).unwrap(); + + test_roundtrip_bar_inner(&input_bar); + } +} + +fn test_roundtrip_bar_inner(input_bar: &Bar) { + let manual_encoded = manual_bar_top_level_encoder(input_bar); + let prost_encoded = input_bar.encode_to_vec(); + + let prost_decoded_prost_encoded = Bar::decode(&*prost_encoded).unwrap(); + let prost_decoded_manual_encoded = Bar::decode(&*manual_encoded).unwrap(); + + assert_eq!(&prost_decoded_prost_encoded, input_bar); + assert_eq!(&prost_decoded_manual_encoded, input_bar); +} + +#[test] +fn test_roundtrip_foo() { + for _ in 0..100 { + let l = rand::random_range(0..255_usize); + let input: Vec = (0..l).map(|_| rand::random()).collect(); + let input_foo: Foo = Foo::arbitrary(&mut arbitrary::Unstructured::new(&input)).unwrap(); + test_roundtrip_foo_inner(&input_foo); + } +} + +fn test_roundtrip_foo_inner(input_foo: &Foo) { + let manual_encoded = manual_foo_top_level_encoder(input_foo); + let prost_encoded = input_foo.encode_to_vec(); + + let prost_decoded_prost_encoded = Foo::decode(&*prost_encoded).unwrap(); + + let prost_decoded_manual_encoded = Foo::decode(&*manual_encoded).unwrap(); + + assert_eq!(&prost_decoded_prost_encoded, input_foo); + assert_eq!(&prost_decoded_manual_encoded, input_foo); +}