From fe5cd31f6df838d0a83c6b0bd307f20f1155be2e Mon Sep 17 00:00:00 2001 From: Casper Meijn Date: Thu, 23 Nov 2023 14:25:24 +0100 Subject: [PATCH] feat: derive Copy trait for messages where possible Rust primitive types can be copied by simply copying the bits. Rust structs can also have this property by deriving the Copy trait. Automatically derive Copy for: - messages that only have fields with primitive types - the Rust enum for one-of fields - messages whose field type are messages that also implement Copy Generated code for Protobuf enums already derives Copy. --- prost-build/src/code_generator.rs | 15 ++++- .../_expected_field_attributes.rs | 4 +- .../_expected_field_attributes_formatted.rs | 4 +- prost-build/src/message_graph.rs | 59 +++++++++++++++++-- prost-types/src/protobuf.rs | 8 +-- tests/src/build.rs | 4 ++ tests/src/derive_copy.proto | 51 ++++++++++++++++ tests/src/derive_copy.rs | 21 +++++++ tests/src/lib.rs | 2 + 9 files changed, 154 insertions(+), 14 deletions(-) create mode 100644 tests/src/derive_copy.proto create mode 100644 tests/src/derive_copy.rs diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index ecd21852a..5dfb8245a 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -182,7 +182,12 @@ impl<'a> CodeGenerator<'a> { self.buf .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); self.buf.push_str(&format!( - "#[derive(Clone, PartialEq, {}::Message)]\n", + "#[derive(Clone, {}PartialEq, {}::Message)]\n", + if self.message_graph.can_message_derive_copy(&fq_message_name) { + "Copy, " + } else { + "" + }, self.config.prost_path.as_deref().unwrap_or("::prost") )); self.append_skip_debug(&fq_message_name); @@ -592,8 +597,14 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); + + let can_oneof_derive_copy = fields.iter().map(|(field, _idx)| field).all(|field| { + self.message_graph + .can_field_derive_copy(fq_message_name, field) + }); self.buf.push_str(&format!( - "#[derive(Clone, PartialEq, {}::Oneof)]\n", + "#[derive(Clone, {}PartialEq, {}::Oneof)]\n", + if can_oneof_derive_copy { "Copy, " } else { "" }, self.config.prost_path.as_deref().unwrap_or("::prost") )); self.append_skip_debug(&fq_message_name); diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs index 95fb05d86..f58bbb0ba 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs @@ -23,12 +23,12 @@ pub struct Foo { pub foo: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag="1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Qux { } diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs index f1eaee751..0aabf753f 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs @@ -23,11 +23,11 @@ pub struct Foo { pub foo: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag = "1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Qux {} diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index ac0ad1523..9cc40f975 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -4,7 +4,10 @@ use petgraph::algo::has_path_connecting; use petgraph::graph::NodeIndex; use petgraph::Graph; -use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto}; +use prost_types::{ + field_descriptor_proto::{Label, Type}, + DescriptorProto, FieldDescriptorProto, FileDescriptorProto, +}; /// `MessageGraph` builds a graph of messages whose edges correspond to nesting. /// The goal is to recognize when message types are recursively nested, so @@ -12,6 +15,7 @@ use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto}; pub struct MessageGraph { index: HashMap, graph: Graph, + messages: HashMap, } impl MessageGraph { @@ -21,6 +25,7 @@ impl MessageGraph { let mut msg_graph = MessageGraph { index: HashMap::new(), graph: Graph::new(), + messages: HashMap::new(), }; for file in files { @@ -41,6 +46,7 @@ impl MessageGraph { let MessageGraph { ref mut index, ref mut graph, + .. } = *self; assert_eq!(b'.', msg_name.as_bytes()[0]); *index @@ -58,13 +64,12 @@ impl MessageGraph { let msg_index = self.get_or_insert_index(msg_name.clone()); for field in &msg.field { - if field.r#type() == field_descriptor_proto::Type::Message - && field.label() != field_descriptor_proto::Label::Repeated - { + if field.r#type() == Type::Message && field.label() != Label::Repeated { let field_index = self.get_or_insert_index(field.type_name.clone().unwrap()); self.graph.add_edge(msg_index, field_index, ()); } } + self.messages.insert(msg_name.clone(), msg.clone()); for msg in &msg.nested_type { self.add_message(&msg_name, msg); @@ -84,4 +89,50 @@ impl MessageGraph { has_path_connecting(&self.graph, outer, inner, None) } + + /// Returns `true` if this message can automatically derive Copy trait. + pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool { + assert_eq!(".", &fq_message_name[..1]); + let msg = self.messages.get(fq_message_name).unwrap(); + msg.field + .iter() + .all(|field| self.can_field_derive_copy(fq_message_name, field)) + } + + /// Returns `true` if the type of this field allows deriving the Copy trait. + pub fn can_field_derive_copy( + &self, + fq_message_name: &str, + field: &FieldDescriptorProto, + ) -> bool { + assert_eq!(".", &fq_message_name[..1]); + + if field.label() == Label::Repeated { + false + } else if field.r#type() == Type::Message { + if self.is_nested(field.type_name(), fq_message_name) { + false + } else { + self.can_message_derive_copy(field.type_name()) + } + } else { + matches!( + field.r#type(), + Type::Float + | Type::Double + | Type::Int32 + | Type::Int64 + | Type::Uint32 + | Type::Uint64 + | Type::Sint32 + | Type::Sint64 + | Type::Fixed32 + | Type::Fixed64 + | Type::Sfixed32 + | Type::Sfixed64 + | Type::Bool + | Type::Enum + ) + } + } } diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index edc1361be..34de0ec88 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -94,7 +94,7 @@ pub mod descriptor_proto { /// fields or extension ranges in the same message. Reserved ranges may /// not overlap. #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -360,7 +360,7 @@ pub mod enum_descriptor_proto { /// is inclusive such that it can appropriately represent the entire int32 /// domain. #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct EnumReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -1853,7 +1853,7 @@ pub struct Mixin { /// be expressed in JSON format as "3.000000001s", and 3 seconds and 1 /// microsecond should be expressed in JSON format as "3.000001s". #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Duration { /// Signed seconds of the span of time. Must be from -315,576,000,000 /// to +315,576,000,000 inclusive. Note: these bounds are computed from: @@ -2293,7 +2293,7 @@ impl NullValue { /// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use /// the Joda Time's [`ISODateTimeFormat.dateTime()`]() to obtain a formatter capable of generating timestamps in this format. #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Timestamp { /// Represents seconds of UTC time since Unix epoch /// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to diff --git a/tests/src/build.rs b/tests/src/build.rs index e2b95e6c4..85e67d5ba 100644 --- a/tests/src/build.rs +++ b/tests/src/build.rs @@ -87,6 +87,10 @@ fn main() { .compile_protos(&[src.join("deprecated_field.proto")], includes) .unwrap(); + config + .compile_protos(&[src.join("derive_copy.proto")], includes) + .unwrap(); + config .compile_protos(&[src.join("default_string_escape.proto")], includes) .unwrap(); diff --git a/tests/src/derive_copy.proto b/tests/src/derive_copy.proto new file mode 100644 index 000000000..d2a472bf8 --- /dev/null +++ b/tests/src/derive_copy.proto @@ -0,0 +1,51 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +package derive_copy; + +message EmptyMsg {} + +message IntegerMsg { + int32 field1 = 1; + int64 field2 = 2; + uint32 field3 = 3; + uint64 field4 = 4; + sint32 field5 = 5; + sint64 field6 = 6; + fixed32 field7 = 7; + fixed64 field8 = 8; + sfixed32 field9 = 9; + sfixed64 field10 = 10; +} + +message FloatMsg { + double field1 = 1; + float field2 = 2; +} + +message BoolMsg { bool field1 = 1; } + +enum AnEnum { + A = 0; + B = 1; +}; + +message EnumMsg { AnEnum field1 = 1; } + +message OneOfMsg { + oneof data { + int32 field1 = 1; + int64 field2 = 2; + } +} + +message ComposedMsg { + IntegerMsg field1 = 1; + EnumMsg field2 = 2; + OneOfMsg field3 = 3; +} + +message WellKnownMsg { + google.protobuf.Timestamp timestamp = 1; +} diff --git a/tests/src/derive_copy.rs b/tests/src/derive_copy.rs new file mode 100644 index 000000000..33b4fc84f --- /dev/null +++ b/tests/src/derive_copy.rs @@ -0,0 +1,21 @@ +include!(concat!(env!("OUT_DIR"), "/derive_copy.rs")); + +trait TestCopyIsImplemented: Copy {} + +impl TestCopyIsImplemented for EmptyMsg {} + +impl TestCopyIsImplemented for IntegerMsg {} + +impl TestCopyIsImplemented for FloatMsg {} + +impl TestCopyIsImplemented for BoolMsg {} + +impl TestCopyIsImplemented for AnEnum {} + +impl TestCopyIsImplemented for EnumMsg {} + +impl TestCopyIsImplemented for OneOfMsg {} + +impl TestCopyIsImplemented for ComposedMsg {} + +impl TestCopyIsImplemented for WellKnownMsg {} diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 7d3d94867..48a150efc 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -37,6 +37,8 @@ mod debug; #[cfg(test)] mod deprecated_field; #[cfg(test)] +mod derive_copy; +#[cfg(test)] mod generic_derive; #[cfg(test)] mod message_encoding;