Skip to content

Commit

Permalink
feat: derive Copy trait for messages where possible (#950)
Browse files Browse the repository at this point in the history
* feat: derive Copy trait for messages where possible

Rust primitive types can be copied by simply copying the bits. Rust structs can also have this property by deriving the Copy trait.

Automatically derive Copy for:
- messages that only have fields with primitive types
- the Rust enum for one-of fields
- messages whose field type are messages that also implement Copy

Generated code for Protobuf enums already derives Copy.

* fix: Remove clone call when copy is implemented

Clippy reports: warning: using `clone` on type `Timestamp` which implements the `Copy` trait
  • Loading branch information
caspermeijn authored May 20, 2024
1 parent d42c85e commit 85c698a
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 24 deletions.
15 changes: 13 additions & 2 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,12 @@ impl<'a> CodeGenerator<'a> {
self.buf
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");
self.buf.push_str(&format!(
"#[derive(Clone, PartialEq, {}::Message)]\n",
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
if self.message_graph.can_message_derive_copy(&fq_message_name) {
"Copy, "
} else {
""
},
prost_path(self.config)
));
self.append_skip_debug(&fq_message_name);
Expand Down Expand Up @@ -613,8 +618,14 @@ impl<'a> CodeGenerator<'a> {
self.push_indent();
self.buf
.push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n");

let can_oneof_derive_copy = fields.iter().map(|(field, _idx)| field).all(|field| {
self.message_graph
.can_field_derive_copy(fq_message_name, field)
});
self.buf.push_str(&format!(
"#[derive(Clone, PartialEq, {}::Oneof)]\n",
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
if can_oneof_derive_copy { "Copy, " } else { "" },
prost_path(self.config)
));
self.append_skip_debug(fq_message_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ pub struct Foo {
pub foo: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag="1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Qux {
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ pub struct Foo {
pub foo: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Qux {}
59 changes: 55 additions & 4 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@ use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;
use petgraph::Graph;

use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto};
use prost_types::{
field_descriptor_proto::{Label, Type},
DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
};

/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
/// The goal is to recognize when message types are recursively nested, so
/// that fields can be boxed when necessary.
pub struct MessageGraph {
index: HashMap<String, NodeIndex>,
graph: Graph<String, ()>,
messages: HashMap<String, DescriptorProto>,
}

impl MessageGraph {
Expand All @@ -21,6 +25,7 @@ impl MessageGraph {
let mut msg_graph = MessageGraph {
index: HashMap::new(),
graph: Graph::new(),
messages: HashMap::new(),
};

for file in files {
Expand All @@ -41,6 +46,7 @@ impl MessageGraph {
let MessageGraph {
ref mut index,
ref mut graph,
..
} = *self;
assert_eq!(b'.', msg_name.as_bytes()[0]);
*index
Expand All @@ -58,13 +64,12 @@ impl MessageGraph {
let msg_index = self.get_or_insert_index(msg_name.clone());

for field in &msg.field {
if field.r#type() == field_descriptor_proto::Type::Message
&& field.label() != field_descriptor_proto::Label::Repeated
{
if field.r#type() == Type::Message && field.label() != Label::Repeated {
let field_index = self.get_or_insert_index(field.type_name.clone().unwrap());
self.graph.add_edge(msg_index, field_index, ());
}
}
self.messages.insert(msg_name.clone(), msg.clone());

for msg in &msg.nested_type {
self.add_message(&msg_name, msg);
Expand All @@ -84,4 +89,50 @@ impl MessageGraph {

has_path_connecting(&self.graph, outer, inner, None)
}

/// Returns `true` if this message can automatically derive Copy trait.
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
let msg = self.messages.get(fq_message_name).unwrap();
msg.field
.iter()
.all(|field| self.can_field_derive_copy(fq_message_name, field))
}

/// Returns `true` if the type of this field allows deriving the Copy trait.
pub fn can_field_derive_copy(
&self,
fq_message_name: &str,
field: &FieldDescriptorProto,
) -> bool {
assert_eq!(".", &fq_message_name[..1]);

if field.label() == Label::Repeated {
false
} else if field.r#type() == Type::Message {
if self.is_nested(field.type_name(), fq_message_name) {
false
} else {
self.can_message_derive_copy(field.type_name())
}
} else {
matches!(
field.r#type(),
Type::Float
| Type::Double
| Type::Int32
| Type::Int64
| Type::Uint32
| Type::Uint64
| Type::Sint32
| Type::Sint64
| Type::Fixed32
| Type::Fixed64
| Type::Sfixed32
| Type::Sfixed64
| Type::Bool
| Type::Enum
)
}
}
}
2 changes: 1 addition & 1 deletion prost-types/src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ mod tests {
};
assert_eq!(
expected,
format!("{}", DateTime::from(timestamp.clone())),
format!("{}", DateTime::from(timestamp)),
"timestamp: {:?}",
timestamp
);
Expand Down
6 changes: 3 additions & 3 deletions prost-types/src/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl TryFrom<Duration> for time::Duration {

impl fmt::Display for Duration {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut d = self.clone();
let mut d = *self;
d.normalize();
if self.seconds < 0 && self.nanos < 0 {
write!(f, "-")?;
Expand Down Expand Up @@ -193,7 +193,7 @@ mod tests {
Ok(duration) => duration,
Err(_) => return Err(TestCaseError::reject("duration out of range")),
};
prop_assert_eq!(time::Duration::try_from(prost_duration.clone()).unwrap(), std_duration);
prop_assert_eq!(time::Duration::try_from(prost_duration).unwrap(), std_duration);

if std_duration != time::Duration::default() {
let neg_prost_duration = Duration {
Expand All @@ -220,7 +220,7 @@ mod tests {
Ok(duration) => duration,
Err(_) => return Err(TestCaseError::reject("duration out of range")),
};
prop_assert_eq!(time::Duration::try_from(prost_duration.clone()).unwrap(), std_duration);
prop_assert_eq!(time::Duration::try_from(prost_duration).unwrap(), std_duration);

if std_duration != time::Duration::default() {
let neg_prost_duration = Duration {
Expand Down
8 changes: 4 additions & 4 deletions prost-types/src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub mod descriptor_proto {
/// fields or extension ranges in the same message. Reserved ranges may
/// not overlap.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct ReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -360,7 +360,7 @@ pub mod enum_descriptor_proto {
/// is inclusive such that it can appropriately represent the entire int32
/// domain.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct EnumReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -1853,7 +1853,7 @@ pub struct Mixin {
/// be expressed in JSON format as "3.000000001s", and 3 seconds and 1
/// microsecond should be expressed in JSON format as "3.000001s".
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Duration {
/// Signed seconds of the span of time. Must be from -315,576,000,000
/// to +315,576,000,000 inclusive. Note: these bounds are computed from:
Expand Down Expand Up @@ -2293,7 +2293,7 @@ impl NullValue {
/// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use
/// the Joda Time's [`ISODateTimeFormat.dateTime()`](<http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D>) to obtain a formatter capable of generating timestamps in this format.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
pub struct Timestamp {
/// Represents seconds of UTC time since Unix epoch
/// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
Expand Down
11 changes: 5 additions & 6 deletions prost-types/src/timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl Timestamp {
///
/// [1]: https://github.com/google/protobuf/blob/v3.3.2/src/google/protobuf/util/time_util.cc#L59-L77
pub fn try_normalize(mut self) -> Result<Timestamp, Timestamp> {
let before = self.clone();
let before = self;
self.normalize();
// If the seconds value has changed, and is either i64::MIN or i64::MAX, then the timestamp
// normalization overflowed.
Expand Down Expand Up @@ -201,7 +201,7 @@ impl TryFrom<Timestamp> for std::time::SystemTime {
type Error = TimestampError;

fn try_from(mut timestamp: Timestamp) -> Result<std::time::SystemTime, Self::Error> {
let orig_timestamp = timestamp.clone();
let orig_timestamp = timestamp;
timestamp.normalize();

let system_time = if timestamp.seconds >= 0 {
Expand All @@ -211,8 +211,7 @@ impl TryFrom<Timestamp> for std::time::SystemTime {
timestamp
.seconds
.checked_neg()
.ok_or_else(|| TimestampError::OutOfSystemRange(timestamp.clone()))?
as u64,
.ok_or(TimestampError::OutOfSystemRange(timestamp))? as u64,
))
};

Expand All @@ -234,7 +233,7 @@ impl FromStr for Timestamp {

impl fmt::Display for Timestamp {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
datetime::DateTime::from(self.clone()).fmt(f)
datetime::DateTime::from(*self).fmt(f)
}
}
#[cfg(test)]
Expand Down Expand Up @@ -262,7 +261,7 @@ mod tests {
) {
let mut timestamp = Timestamp { seconds, nanos };
timestamp.normalize();
if let Ok(system_time) = SystemTime::try_from(timestamp.clone()) {
if let Ok(system_time) = SystemTime::try_from(timestamp) {
prop_assert_eq!(Timestamp::from(system_time), timestamp);
}
}
Expand Down
4 changes: 4 additions & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ fn main() {
.compile_protos(&[src.join("deprecated_field.proto")], includes)
.unwrap();

config
.compile_protos(&[src.join("derive_copy.proto")], includes)
.unwrap();

config
.compile_protos(&[src.join("default_string_escape.proto")], includes)
.unwrap();
Expand Down
51 changes: 51 additions & 0 deletions tests/src/derive_copy.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
syntax = "proto3";

import "google/protobuf/timestamp.proto";

package derive_copy;

message EmptyMsg {}

message IntegerMsg {
int32 field1 = 1;
int64 field2 = 2;
uint32 field3 = 3;
uint64 field4 = 4;
sint32 field5 = 5;
sint64 field6 = 6;
fixed32 field7 = 7;
fixed64 field8 = 8;
sfixed32 field9 = 9;
sfixed64 field10 = 10;
}

message FloatMsg {
double field1 = 1;
float field2 = 2;
}

message BoolMsg { bool field1 = 1; }

enum AnEnum {
A = 0;
B = 1;
};

message EnumMsg { AnEnum field1 = 1; }

message OneOfMsg {
oneof data {
int32 field1 = 1;
int64 field2 = 2;
}
}

message ComposedMsg {
IntegerMsg field1 = 1;
EnumMsg field2 = 2;
OneOfMsg field3 = 3;
}

message WellKnownMsg {
google.protobuf.Timestamp timestamp = 1;
}
21 changes: 21 additions & 0 deletions tests/src/derive_copy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
include!(concat!(env!("OUT_DIR"), "/derive_copy.rs"));

trait TestCopyIsImplemented: Copy {}

impl TestCopyIsImplemented for EmptyMsg {}

impl TestCopyIsImplemented for IntegerMsg {}

impl TestCopyIsImplemented for FloatMsg {}

impl TestCopyIsImplemented for BoolMsg {}

impl TestCopyIsImplemented for AnEnum {}

impl TestCopyIsImplemented for EnumMsg {}

impl TestCopyIsImplemented for OneOfMsg {}

impl TestCopyIsImplemented for ComposedMsg {}

impl TestCopyIsImplemented for WellKnownMsg {}
2 changes: 2 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ mod debug;
#[cfg(test)]
mod deprecated_field;
#[cfg(test)]
mod derive_copy;
#[cfg(test)]
mod enum_keyword_variant;
#[cfg(test)]
mod generic_derive;
Expand Down

0 comments on commit 85c698a

Please sign in to comment.