diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index a50e96e4b..41f0a8a93 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -189,7 +189,12 @@ impl<'a> CodeGenerator<'a> { self.buf .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); self.buf.push_str(&format!( - "#[derive(Clone, PartialEq, {}::Message)]\n", + "#[derive(Clone, {}PartialEq, {}::Message)]\n", + if self.message_graph.can_message_derive_copy(&fq_message_name) { + "Copy, " + } else { + "" + }, self.config.prost_path.as_deref().unwrap_or("::prost") )); self.append_skip_debug(&fq_message_name); @@ -597,8 +602,14 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); + + let can_oneof_derive_copy = fields.iter().map(|(field, _idx)| field).all(|field| { + self.message_graph + .can_field_derive_copy(fq_message_name, field) + }); self.buf.push_str(&format!( - "#[derive(Clone, PartialEq, {}::Oneof)]\n", + "#[derive(Clone, {}PartialEq, {}::Oneof)]\n", + if can_oneof_derive_copy { "Copy, " } else { "" }, self.config.prost_path.as_deref().unwrap_or("::prost") )); self.append_skip_debug(&fq_message_name); diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs index 04860e63d..0b0433719 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs @@ -22,12 +22,12 @@ pub struct Foo { pub foo: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag="1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Qux { } diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs index 8c329f902..d585c286a 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs @@ -22,11 +22,11 @@ pub struct Foo { pub foo: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag = "1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Qux {} diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index ac0ad1523..9cc40f975 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -4,7 +4,10 @@ use petgraph::algo::has_path_connecting; use petgraph::graph::NodeIndex; use petgraph::Graph; -use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto}; +use prost_types::{ + field_descriptor_proto::{Label, Type}, + DescriptorProto, FieldDescriptorProto, FileDescriptorProto, +}; /// `MessageGraph` builds a graph of messages whose edges correspond to nesting. /// The goal is to recognize when message types are recursively nested, so @@ -12,6 +15,7 @@ use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto}; pub struct MessageGraph { index: HashMap, graph: Graph, + messages: HashMap, } impl MessageGraph { @@ -21,6 +25,7 @@ impl MessageGraph { let mut msg_graph = MessageGraph { index: HashMap::new(), graph: Graph::new(), + messages: HashMap::new(), }; for file in files { @@ -41,6 +46,7 @@ impl MessageGraph { let MessageGraph { ref mut index, ref mut graph, + .. } = *self; assert_eq!(b'.', msg_name.as_bytes()[0]); *index @@ -58,13 +64,12 @@ impl MessageGraph { let msg_index = self.get_or_insert_index(msg_name.clone()); for field in &msg.field { - if field.r#type() == field_descriptor_proto::Type::Message - && field.label() != field_descriptor_proto::Label::Repeated - { + if field.r#type() == Type::Message && field.label() != Label::Repeated { let field_index = self.get_or_insert_index(field.type_name.clone().unwrap()); self.graph.add_edge(msg_index, field_index, ()); } } + self.messages.insert(msg_name.clone(), msg.clone()); for msg in &msg.nested_type { self.add_message(&msg_name, msg); @@ -84,4 +89,50 @@ impl MessageGraph { has_path_connecting(&self.graph, outer, inner, None) } + + /// Returns `true` if this message can automatically derive Copy trait. + pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool { + assert_eq!(".", &fq_message_name[..1]); + let msg = self.messages.get(fq_message_name).unwrap(); + msg.field + .iter() + .all(|field| self.can_field_derive_copy(fq_message_name, field)) + } + + /// Returns `true` if the type of this field allows deriving the Copy trait. + pub fn can_field_derive_copy( + &self, + fq_message_name: &str, + field: &FieldDescriptorProto, + ) -> bool { + assert_eq!(".", &fq_message_name[..1]); + + if field.label() == Label::Repeated { + false + } else if field.r#type() == Type::Message { + if self.is_nested(field.type_name(), fq_message_name) { + false + } else { + self.can_message_derive_copy(field.type_name()) + } + } else { + matches!( + field.r#type(), + Type::Float + | Type::Double + | Type::Int32 + | Type::Int64 + | Type::Uint32 + | Type::Uint64 + | Type::Sint32 + | Type::Sint64 + | Type::Fixed32 + | Type::Fixed64 + | Type::Sfixed32 + | Type::Sfixed64 + | Type::Bool + | Type::Enum + ) + } + } } diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index fcbe430df..e47e16cea 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -93,7 +93,7 @@ pub mod descriptor_proto { /// fields or extension ranges in the same message. Reserved ranges may /// not overlap. #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -359,7 +359,7 @@ pub mod enum_descriptor_proto { /// is inclusive such that it can appropriately represent the entire int32 /// domain. #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct EnumReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -1852,7 +1852,7 @@ pub struct Mixin { /// be expressed in JSON format as "3.000000001s", and 3 seconds and 1 /// microsecond should be expressed in JSON format as "3.000001s". #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Duration { /// Signed seconds of the span of time. Must be from -315,576,000,000 /// to +315,576,000,000 inclusive. Note: these bounds are computed from: @@ -2292,7 +2292,7 @@ impl NullValue { /// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use /// the Joda Time's [`ISODateTimeFormat.dateTime()`]() to obtain a formatter capable of generating timestamps in this format. #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Timestamp { /// Represents seconds of UTC time since Unix epoch /// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to diff --git a/tests/src/build.rs b/tests/src/build.rs index aec3f6458..660063c67 100644 --- a/tests/src/build.rs +++ b/tests/src/build.rs @@ -87,6 +87,10 @@ fn main() { .compile_protos(&[src.join("deprecated_field.proto")], includes) .unwrap(); + config + .compile_protos(&[src.join("derive_copy.proto")], includes) + .unwrap(); + config .compile_protos(&[src.join("default_string_escape.proto")], includes) .unwrap(); diff --git a/tests/src/derive_copy.proto b/tests/src/derive_copy.proto new file mode 100644 index 000000000..d2a472bf8 --- /dev/null +++ b/tests/src/derive_copy.proto @@ -0,0 +1,51 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +package derive_copy; + +message EmptyMsg {} + +message IntegerMsg { + int32 field1 = 1; + int64 field2 = 2; + uint32 field3 = 3; + uint64 field4 = 4; + sint32 field5 = 5; + sint64 field6 = 6; + fixed32 field7 = 7; + fixed64 field8 = 8; + sfixed32 field9 = 9; + sfixed64 field10 = 10; +} + +message FloatMsg { + double field1 = 1; + float field2 = 2; +} + +message BoolMsg { bool field1 = 1; } + +enum AnEnum { + A = 0; + B = 1; +}; + +message EnumMsg { AnEnum field1 = 1; } + +message OneOfMsg { + oneof data { + int32 field1 = 1; + int64 field2 = 2; + } +} + +message ComposedMsg { + IntegerMsg field1 = 1; + EnumMsg field2 = 2; + OneOfMsg field3 = 3; +} + +message WellKnownMsg { + google.protobuf.Timestamp timestamp = 1; +} diff --git a/tests/src/derive_copy.rs b/tests/src/derive_copy.rs new file mode 100644 index 000000000..33b4fc84f --- /dev/null +++ b/tests/src/derive_copy.rs @@ -0,0 +1,21 @@ +include!(concat!(env!("OUT_DIR"), "/derive_copy.rs")); + +trait TestCopyIsImplemented: Copy {} + +impl TestCopyIsImplemented for EmptyMsg {} + +impl TestCopyIsImplemented for IntegerMsg {} + +impl TestCopyIsImplemented for FloatMsg {} + +impl TestCopyIsImplemented for BoolMsg {} + +impl TestCopyIsImplemented for AnEnum {} + +impl TestCopyIsImplemented for EnumMsg {} + +impl TestCopyIsImplemented for OneOfMsg {} + +impl TestCopyIsImplemented for ComposedMsg {} + +impl TestCopyIsImplemented for WellKnownMsg {} diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 954c6b367..83aab2ca8 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -37,6 +37,8 @@ mod debug; #[cfg(test)] mod deprecated_field; #[cfg(test)] +mod derive_copy; +#[cfg(test)] mod generic_derive; #[cfg(test)] mod message_encoding;