Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow a trait to be implemented multiple times for the same struct #3292

Merged
merged 4 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 39 additions & 53 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::dc_mod::collect_defs;
use super::errors::{DefCollectorErrorKind, DuplicateType};
use crate::graph::CrateId;
use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleDefId, ModuleId};
use crate::hir::def_map::{CrateDefMap, LocalModuleId, ModuleData, ModuleDefId, ModuleId};
use crate::hir::resolution::errors::ResolverError;
use crate::hir::resolution::import::PathResolutionError;
use crate::hir::resolution::path_resolver::PathResolver;
Expand Down Expand Up @@ -126,7 +126,7 @@
pub enum CompilationError {
ParseError(ParserError),
DefinitionError(DefCollectorErrorKind),
ResolveError(ResolverError),
ResolverError(ResolverError),
TypeError(TypeCheckError),
}

Expand All @@ -135,7 +135,7 @@
match value {
CompilationError::ParseError(error) => error.into(),
CompilationError::DefinitionError(error) => error.into(),
CompilationError::ResolveError(error) => error.into(),
CompilationError::ResolverError(error) => error.into(),
CompilationError::TypeError(error) => error.into(),
}
}
Expand All @@ -155,7 +155,7 @@

impl From<ResolverError> for CompilationError {
fn from(value: ResolverError) -> Self {
CompilationError::ResolveError(value)
CompilationError::ResolverError(value)
}
}
impl From<TypeCheckError> for CompilationError {
Expand Down Expand Up @@ -296,12 +296,6 @@
// globals will need to reference the struct type they're initialized to to ensure they are valid.
resolved_globals.extend(resolve_globals(context, other_globals, crate_id));

// Before we resolve any function symbols we must go through our impls and
// re-collect the methods within into their proper module. This cannot be
// done before resolution since we need to be able to resolve the type of the
// impl since that determines the module we should collect into.
errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls));

// Bind trait impls to their trait. Collect trait functions, that have a
// default implementation, which hasn't been overridden.
errors.extend(collect_trait_impls(
Expand All @@ -310,6 +304,15 @@
&mut def_collector.collected_traits_impls,
));

// Before we resolve any function symbols we must go through our impls and
// re-collect the methods within into their proper module. This cannot be
// done before resolution since we need to be able to resolve the type of the
// impl since that determines the module we should collect into.
//
// These are resolved after trait impls so that struct methods are chosen
// over trait methods if there are name conflicts.
errors.extend(collect_impls(context, crate_id, &def_collector.collected_impls));

// Lower each function in the crate. This is now possible since imports have been resolved
let file_func_ids = resolve_free_functions(
&mut context.def_interner,
Expand Down Expand Up @@ -377,7 +380,6 @@

if let Some(struct_type) = get_struct_type(&typ) {
let struct_type = struct_type.borrow();
let type_module = struct_type.id.local_module_id();

// `impl`s are only allowed on types defined within the current crate
if struct_type.id.krate() != crate_id {
Expand All @@ -391,7 +393,7 @@
// Grab the module defined by the struct type. Note that impls are a case
// where the module the methods are added to is not the same as the module
// they are resolved in.
let module = &mut def_maps.get_mut(&crate_id).unwrap().modules[type_module.0];
let module = get_module_mut(def_maps, struct_type.id.module_id());

for (_, method_id, method) in &unresolved.functions {
// If this method was already declared, remove it from the module so it cannot
Expand All @@ -413,6 +415,13 @@
errors
}

fn get_module_mut(
def_maps: &mut BTreeMap<CrateId, CrateDefMap>,
module: ModuleId,
) -> &mut ModuleData {
&mut def_maps.get_mut(&module.krate).unwrap().modules[module.local_id.0]
}

fn collect_trait_impl_methods(
interner: &mut NodeInterner,
def_maps: &BTreeMap<CrateId, CrateDefMap>,
Expand Down Expand Up @@ -494,25 +503,6 @@
errors
}

fn add_method_to_struct_namespace(
current_def_map: &mut CrateDefMap,
struct_type: &Shared<StructType>,
func_id: FuncId,
name_ident: &Ident,
trait_id: TraitId,
) -> Result<(), DefCollectorErrorKind> {
let struct_type = struct_type.borrow();
let type_module = struct_type.id.local_module_id();
let module = &mut current_def_map.modules[type_module.0];
module.declare_trait_function(name_ident.clone(), func_id, trait_id).map_err(
|(first_def, second_def)| DefCollectorErrorKind::Duplicate {
typ: DuplicateType::TraitImplementation,
first_def,
second_def,
},
)
}

fn collect_trait_impl(
context: &mut Context,
crate_id: CrateId,
Expand All @@ -535,28 +525,24 @@
if let Some(trait_id) = trait_impl.trait_id {
errors
.extend(collect_trait_impl_methods(interner, def_maps, crate_id, trait_id, trait_impl));
for (_, func_id, ast) in &trait_impl.methods.functions {
let file = def_maps[&crate_id].file_id(trait_impl.module_id);

let path_resolver = StandardPathResolver::new(module);
let mut resolver = Resolver::new(interner, &path_resolver, def_maps, file);
resolver.add_generics(&ast.def.generics);
let typ = resolver.resolve_type(unresolved_type.clone());

if let Some(struct_type) = get_struct_type(&typ) {
errors.extend(take_errors(trait_impl.file_id, resolver));
let current_def_map = def_maps.get_mut(&struct_type.borrow().id.krate()).unwrap();
match add_method_to_struct_namespace(
current_def_map,
struct_type,
*func_id,
ast.name_ident(),
trait_id,
) {
Ok(()) => {}
Err(err) => {
errors.push((err.into(), trait_impl.file_id));
}
let path_resolver = StandardPathResolver::new(module);
let file = def_maps[&crate_id].file_id(trait_impl.module_id);
let mut resolver = Resolver::new(interner, &path_resolver, def_maps, file);
let typ = resolver.resolve_type(unresolved_type);
errors.extend(take_errors(trait_impl.file_id, resolver));

if let Some(struct_type) = get_struct_type(&typ) {
let struct_type = struct_type.borrow();
let module = get_module_mut(def_maps, struct_type.id.module_id());

for (_, method_id, method) in &trait_impl.methods.functions {
// If this method was already declared, remove it from the module so it cannot
// be accessed with the `TypeName::method` syntax. We'll check later whether the
// object types in each method overlap or not. If they do, we issue an error.
// If not, that is specialization which is allowed.
if module.declare_function(method.name_ident().clone(), *method_id).is_err() {
module.remove_function(method.name_ident());
}
}
}
Expand Down Expand Up @@ -841,7 +827,7 @@
}

fn take_errors(file_id: FileId, resolver: Resolver<'_>) -> Vec<(CompilationError, FileId)> {
resolver.take_errors().iter().cloned().map(|e| (e.into(), file_id)).collect()
vecmap(resolver.take_errors(), |e| (e.into(), file_id))
}

/// Create the mappings from TypeId -> TraitType
Expand Down Expand Up @@ -1040,7 +1026,7 @@
methods
}

// TODO(vitkov): Move this out of here and into type_check

Check warning on line 1029 in compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (vitkov)
fn check_methods_signatures(
resolver: &mut Resolver,
impl_methods: &Vec<(FileId, FuncId)>,
Expand Down
15 changes: 13 additions & 2 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use crate::{
graph::CrateId,
hir::def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait},
node_interner::{TraitId, TypeAliasId},
node_interner::{FunctionModifiers, TraitId, TypeAliasId},
parser::{SortedModule, SortedSubModule},
FunctionDefinition, Ident, LetStatement, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl,
NoirTypeAlias, TraitImplItem, TraitItem, TypeImpl,
Expand Down Expand Up @@ -378,11 +378,22 @@
body,
} => {
let func_id = context.def_interner.push_empty_fn();
let modifiers = FunctionModifiers {
name: name.to_string(),
visibility: crate::FunctionVisibility::Public,
// TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629

Check warning on line 384 in compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (Maddiaa)
attributes: crate::token::Attributes::empty(),
is_unconstrained: false,
contract_function_type: None,
is_internal: None,
};

context.def_interner.push_function_definition(func_id, modifiers, id.0);

match self.def_collector.def_map.modules[id.0.local_id.0]
.declare_function(name.clone(), func_id)
{
Ok(()) => {
// TODO(Maddiaa): Investigate trait implementations with attributes see: https://github.com/noir-lang/noir/issues/2629
if let Some(body) = body {
let impl_method =
NoirFunction::normal(FunctionDefinition::normal(
Expand Down Expand Up @@ -426,7 +437,7 @@
}
}
TraitItem::Type { name } => {
// TODO(nickysn or alexvitkov): implement context.def_interner.push_empty_type_alias and get an id, instead of using TypeAliasId::dummy_id()

Check warning on line 440 in compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (nickysn)

Check warning on line 440 in compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (alexvitkov)
if let Err((first_def, second_def)) = self.def_collector.def_map.modules
[id.0.local_id.0]
.declare_type_alias(name.clone(), TypeAliasId::dummy_id())
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,10 @@ impl NodeInterner {
#[cfg(test)]
pub fn push_test_function_definition(&mut self, name: String) -> FuncId {
let id = self.push_fn(HirFunction::empty());
let modifiers = FunctionModifiers::new();
let mut modifiers = FunctionModifiers::new();
modifiers.name = name;
let module = ModuleId::dummy_id();
self.push_function_definition(name, id, modifiers, module);
self.push_function_definition(id, modifiers, module);
id
}

Expand All @@ -631,7 +632,6 @@ impl NodeInterner {
module: ModuleId,
) -> DefinitionId {
use ContractFunctionType::*;
let name = function.name.0.contents.clone();

// We're filling in contract_function_type and is_internal now, but these will be verified
// later during name resolution.
Expand All @@ -643,16 +643,16 @@ impl NodeInterner {
contract_function_type: Some(if function.is_open { Open } else { Secret }),
is_internal: Some(function.is_internal),
};
self.push_function_definition(name, id, modifiers, module)
self.push_function_definition(id, modifiers, module)
}

pub fn push_function_definition(
&mut self,
name: String,
func: FuncId,
modifiers: FunctionModifiers,
module: ModuleId,
) -> DefinitionId {
let name = modifiers.name.clone();
self.function_modifiers.insert(func, modifiers);
self.function_modules.insert(func, module);
self.push_definition(name, false, DefinitionKind::Function(func))
Expand Down
28 changes: 17 additions & 11 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ mod test {

for (err, _file_id) in errors {
match &err {
CompilationError::ResolveError(ResolverError::PathResolutionError(
CompilationError::ResolverError(ResolverError::PathResolutionError(
PathResolutionError::Unresolved(ident),
)) => {
assert_eq!(ident, "NotAType");
Expand Down Expand Up @@ -533,19 +533,24 @@ mod test {
}
}

fn main() {
}
fn main() {}
";
let errors = get_program_errors(src);
assert!(!has_parser_error(&errors));
assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors);
assert!(errors.len() == 2, "Expected 2 errors, got: {:?}", errors);
for (err, _file_id) in errors {
match &err {
CompilationError::DefinitionError(
DefCollectorErrorKind::TraitImplNotAllowedFor { trait_path, span: _ },
) => {
assert_eq!(trait_path.as_string(), "Default");
}
CompilationError::ResolverError(ResolverError::Expected {
expected, got, ..
}) => {
assert_eq!(expected, "type");
assert_eq!(got, "function");
}
_ => {
panic!("No other errors are expected! Found = {:?}", err);
}
Expand Down Expand Up @@ -810,7 +815,7 @@ mod test {
assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors);
// It should be regarding the unused variable
match &errors[0].0 {
CompilationError::ResolveError(ResolverError::UnusedVariable { ident }) => {
CompilationError::ResolverError(ResolverError::UnusedVariable { ident }) => {
assert_eq!(&ident.0.contents, "y");
}
_ => unreachable!("we should only have an unused var error"),
Expand All @@ -829,7 +834,7 @@ mod test {
assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors);
// It should be regarding the unresolved var `z` (Maybe change to undeclared and special case)
match &errors[0].0 {
CompilationError::ResolveError(ResolverError::VariableNotDeclared {
CompilationError::ResolverError(ResolverError::VariableNotDeclared {
name,
span: _,
}) => assert_eq!(name, "z"),
Expand All @@ -848,7 +853,7 @@ mod test {
assert!(errors.len() == 1, "Expected 1 error, got: {:?}", errors);
for (compilation_error, _file_id) in errors {
match compilation_error {
CompilationError::ResolveError(err) => {
CompilationError::ResolverError(err) => {
match err {
ResolverError::PathResolutionError(PathResolutionError::Unresolved(
name,
Expand Down Expand Up @@ -892,7 +897,7 @@ mod test {
// `foo::bar` does not exist
for (compilation_error, _file_id) in errors {
match compilation_error {
CompilationError::ResolveError(err) => {
CompilationError::ResolverError(err) => {
match err {
ResolverError::UnusedVariable { ident } => {
assert_eq!(&ident.0.contents, "z");
Expand Down Expand Up @@ -1069,12 +1074,13 @@ mod test {

for (err, _file_id) in errors {
match &err {
CompilationError::ResolveError(ResolverError::VariableNotDeclared {
name, ..
CompilationError::ResolverError(ResolverError::VariableNotDeclared {
name,
..
}) => {
assert_eq!(name, "i");
}
CompilationError::ResolveError(ResolverError::NumericConstantInFormatString {
CompilationError::ResolverError(ResolverError::NumericConstantInFormatString {
name,
..
}) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "trait_generics"
type = "bin"
authors = [""]
compiler_version = "0.10.5"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

struct Empty<T> {}

trait Foo {
fn foo(self) -> u32;
}

impl Foo for Empty<u32> {
fn foo(_self: Self) -> u32 { 32 }
}

impl Foo for Empty<u64> {
fn foo(_self: Self) -> u32 { 64 }
}

fn main() {
let x: Empty<u32> = Empty {};
let y: Empty<u64> = Empty {};
let z = Empty {};

assert(x.foo() == 32);
assert(y.foo() == 64);

// Types matching multiple impls will currently choose
// the first matching one instead of erroring
assert(z.foo() == 32);
}
Loading