From fb977f4cc28a3f5c913b963a27b553cf845bb9d2 Mon Sep 17 00:00:00 2001 From: Donough Liu Date: Fri, 20 Sep 2024 14:09:42 +0800 Subject: [PATCH] fix(prost-build): Remove `derived(Copy)` on boxed fields (#1157) * fix(prost-build): Remove `derived(Copy)` on boxed fields * Add regression test --- prost-build/src/config.rs | 2 +- .../_expected_field_attributes.rs | 2 +- .../_expected_field_attributes_formatted.rs | 2 +- prost-build/src/message_graph.rs | 28 ++++++++++++++++--- prost-build/src/path.rs | 2 +- tests/src/boxed_field.proto | 10 +++++++ tests/src/build.rs | 5 ++++ tests/src/lib.rs | 4 +++ 8 files changed, 47 insertions(+), 8 deletions(-) create mode 100644 tests/src/boxed_field.proto diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 965d79c7b..c59358246 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -1080,7 +1080,7 @@ impl Config { let mut modules = HashMap::new(); let mut packages = HashMap::new(); - let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1)); + let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1), self.boxed.clone()); let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types) .map_err(|error| Error::new(ErrorKind::InvalidInput, error))?; diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs index dc420692c..bf1e8c517 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs @@ -19,7 +19,7 @@ pub struct Foo { #[prost(string, tag="1")] pub foo: ::prost::alloc::string::String, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag="1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs index b1955b955..c130aad2e 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs @@ -19,7 +19,7 @@ pub struct Foo { #[prost(string, tag = "1")] pub foo: ::prost::alloc::string::String, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag = "1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index 1f02ef352..e2bcad918 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -9,6 +9,8 @@ use prost_types::{ DescriptorProto, FieldDescriptorProto, FileDescriptorProto, }; +use crate::path::PathMap; + /// `MessageGraph` builds a graph of messages whose edges correspond to nesting. /// The goal is to recognize when message types are recursively nested, so /// that fields can be boxed when necessary. @@ -16,14 +18,19 @@ pub struct MessageGraph { index: HashMap, graph: Graph, messages: HashMap, + boxed: PathMap<()>, } impl MessageGraph { - pub fn new<'a>(files: impl Iterator) -> MessageGraph { + pub(crate) fn new<'a>( + files: impl Iterator, + boxed: PathMap<()>, + ) -> MessageGraph { let mut msg_graph = MessageGraph { index: HashMap::new(), graph: Graph::new(), messages: HashMap::new(), + boxed, }; for file in files { @@ -74,6 +81,11 @@ impl MessageGraph { } } + /// Try get a message descriptor from current message graph + pub fn get_message(&self, message: &str) -> Option<&DescriptorProto> { + self.messages.get(message) + } + /// Returns true if message type `inner` is nested in message type `outer`. pub fn is_nested(&self, outer: &str, inner: &str) -> bool { let outer = match self.index.get(outer) { @@ -91,8 +103,9 @@ impl MessageGraph { /// Returns `true` if this message can automatically derive Copy trait. pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool { assert_eq!(".", &fq_message_name[..1]); - let msg = self.messages.get(fq_message_name).unwrap(); - msg.field + self.get_message(fq_message_name) + .unwrap() + .field .iter() .all(|field| self.can_field_derive_copy(fq_message_name, field)) } @@ -105,10 +118,17 @@ impl MessageGraph { ) -> bool { assert_eq!(".", &fq_message_name[..1]); + // repeated field cannot derive Copy if field.label() == Label::Repeated { false } else if field.r#type() == Type::Message { - if self.is_nested(field.type_name(), fq_message_name) { + // nested and boxed messages cannot derive Copy + if self.is_nested(field.type_name(), fq_message_name) + || self + .boxed + .get_first_field(fq_message_name, field.name()) + .is_some() + { false } else { self.can_message_derive_copy(field.type_name()) diff --git a/prost-build/src/path.rs b/prost-build/src/path.rs index f6897005d..2c2d8e242 100644 --- a/prost-build/src/path.rs +++ b/prost-build/src/path.rs @@ -3,7 +3,7 @@ use std::iter; /// Maps a fully-qualified Protobuf path to a value using path matchers. -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub(crate) struct PathMap { // insertion order might actually matter (to avoid warning about legacy-derive-helpers) // see: https://doc.rust-lang.org/rustc/lints/listing/warn-by-default.html#legacy-derive-helpers diff --git a/tests/src/boxed_field.proto b/tests/src/boxed_field.proto new file mode 100644 index 000000000..17f543d92 --- /dev/null +++ b/tests/src/boxed_field.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package boxed_field; + +message Foo { + Bar bar = 1; +} + +message Bar { +} diff --git a/tests/src/build.rs b/tests/src/build.rs index 796157980..b707fb270 100644 --- a/tests/src/build.rs +++ b/tests/src/build.rs @@ -171,6 +171,11 @@ fn main() { .compile_protos(&[src.join("type_names.proto")], includes) .unwrap(); + prost_build::Config::new() + .boxed("Foo.bar") + .compile_protos(&[src.join("boxed_field.proto")], includes) + .unwrap(); + // Check that attempting to compile a .proto without a package declaration does not result in an error. config .compile_protos(&[src.join("no_package.proto")], includes) diff --git a/tests/src/lib.rs b/tests/src/lib.rs index bdeb90352..614baa969 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -140,6 +140,10 @@ pub mod invalid { } } +pub mod boxed_field { + include!(concat!(env!("OUT_DIR"), "/boxed_field.rs")); +} + pub mod default_string_escape { include!(concat!(env!("OUT_DIR"), "/default_string_escape.rs")); }