Skip to content

Commit

Permalink
feat: Implement impl specialization (#3087)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher authored Oct 11, 2023
1 parent 35f7a9d commit 44716fa
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 33 deletions.
15 changes: 6 additions & 9 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,15 +394,12 @@ fn collect_impls(
let module = &mut def_maps.get_mut(&crate_id).unwrap().modules[type_module.0];

for (_, method_id, method) in &unresolved.functions {
let result = module.declare_function(method.name_ident().clone(), *method_id);

if let Err((first_def, second_def)) = result {
let error = DefCollectorErrorKind::Duplicate {
typ: DuplicateType::Function,
first_def,
second_def,
};
errors.push((error.into(), unresolved.file_id));
// If this method was already declared, remove it from the module so it cannot
// be accessed with the `TypeName::method` syntax. We'll check later whether the
// object types in each method overlap or not. If they do, we issue an error.
// If not, that is specialization which is allowed.
if module.declare_function(method.name_ident().clone(), *method_id).is_err() {
module.remove_function(method.name_ident());
}
}
// Prohibit defining impls for primitive types if we're not in the stdlib
Expand Down
5 changes: 5 additions & 0 deletions compiler/noirc_frontend/src/hir/def_map/item_scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,9 @@ impl ItemScope {
pub fn values(&self) -> &HashMap<Ident, (ModuleDefId, Visibility)> {
&self.values
}

pub fn remove_definition(&mut self, name: &Ident) {
self.types.remove(name);
self.values.remove(name);
}
}
5 changes: 5 additions & 0 deletions compiler/noirc_frontend/src/hir/def_map/module_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ impl ModuleData {
self.declare(name, id.into())
}

pub fn remove_function(&mut self, name: &Ident) {
self.scope.remove_definition(name);
self.definitions.remove_definition(name);
}

pub fn declare_global(&mut self, name: Ident, id: StmtId) -> Result<(), (Ident, Ident)> {
self.declare(name, id.into())
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,8 @@ impl<'interner> TypeChecker<'interner> {
) -> Option<HirMethodReference> {
match object_type {
Type::Struct(typ, _args) => {
match self.interner.lookup_method(typ.borrow().id, method_name) {
let id = typ.borrow().id;
match self.interner.lookup_method(object_type, id, method_name, false) {
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ impl Type {

/// `try_unify` is a bit of a misnomer since although errors are not committed,
/// any unified bindings are on success.
fn try_unify(&self, other: &Type) -> Result<(), UnificationError> {
pub fn try_unify(&self, other: &Type) -> Result<(), UnificationError> {
use Type::*;
use TypeVariableKind as Kind;

Expand Down Expand Up @@ -995,7 +995,7 @@ impl Type {
/// Instantiate this type, replacing any type variables it is quantified
/// over with fresh type variables. If this type is not a Type::Forall,
/// it is unchanged.
pub fn instantiate(&self, interner: &mut NodeInterner) -> (Type, TypeBindings) {
pub fn instantiate(&self, interner: &NodeInterner) -> (Type, TypeBindings) {
match self {
Type::Forall(typevars, typ) => {
let replacements = typevars
Expand Down
92 changes: 71 additions & 21 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ pub struct NodeInterner {
// Each trait definition is possibly shared across multiple type nodes.
// It is also mutated through the RefCell during name resolution to append
// methods from impls to the type.
//
// TODO: We may be able to remove the Shared wrapper once traits are no longer types.
// We'd just lookup their methods as needed through the NodeInterner.
traits: HashMap<TraitId, Trait>,

// Trait implementation map
Expand All @@ -108,10 +105,15 @@ pub struct NodeInterner {

globals: HashMap<StmtId, GlobalInfo>, // NOTE: currently only used for checking repeat globals and restricting their scope to a module

next_type_variable_id: usize,
next_type_variable_id: std::cell::Cell<usize>,

/// A map from a struct type and method name to a function id for the method.
struct_methods: HashMap<(StructId, String), FuncId>,
/// This can resolve to potentially multiple methods if the same method name is
/// specialized for different generics on the same type. E.g. for `Struct<T>`, we
/// may have both `impl Struct<u32> { fn foo(){} }` and `impl Struct<u8> { fn foo(){} }`.
/// If this happens, the returned Vec will have 2 entries and we'll need to further
/// disambiguate them by checking the type of each function.
struct_methods: HashMap<(StructId, String), Vec<FuncId>>,

/// Methods on primitive types defined in the stdlib.
primitive_methods: HashMap<(TypeMethodKey, String), FuncId>,
Expand Down Expand Up @@ -381,7 +383,7 @@ impl Default for NodeInterner {
trait_implementations: HashMap::new(),
instantiation_bindings: HashMap::new(),
field_indices: HashMap::new(),
next_type_variable_id: 0,
next_type_variable_id: std::cell::Cell::new(0),
globals: HashMap::new(),
struct_methods: HashMap::new(),
primitive_methods: HashMap::new(),
Expand Down Expand Up @@ -829,13 +831,13 @@ impl NodeInterner {
*old = Node::Expression(new);
}

pub fn next_type_variable_id(&mut self) -> TypeVariableId {
let id = self.next_type_variable_id;
self.next_type_variable_id += 1;
pub fn next_type_variable_id(&self) -> TypeVariableId {
let id = self.next_type_variable_id.get();
self.next_type_variable_id.set(id + 1);
TypeVariableId(id)
}

pub fn next_type_variable(&mut self) -> Type {
pub fn next_type_variable(&self) -> Type {
Type::type_variable(self.next_type_variable_id())
}

Expand Down Expand Up @@ -863,9 +865,10 @@ impl NodeInterner {
self.function_definition_ids[&function]
}

/// Add a method to a type.
/// This will panic for non-struct types currently as we do not support methods
/// for primitives. We could allow this in the future however.
/// Adds a non-trait method to a type.
///
/// Returns `Some(duplicate)` if a matching method was already defined.
/// Returns `None` otherwise.
pub fn add_method(
&mut self,
self_type: &Type,
Expand All @@ -874,8 +877,15 @@ impl NodeInterner {
) -> Option<FuncId> {
match self_type {
Type::Struct(struct_type, _generics) => {
let key = (struct_type.borrow().id, method_name);
self.struct_methods.insert(key, method_id)
let id = struct_type.borrow().id;

if let Some(existing) = self.lookup_method(self_type, id, &method_name, true) {
return Some(existing);
}

let key = (id, method_name);
self.struct_methods.entry(key).or_default().push(method_id);
None
}
Type::Error => None,

Expand All @@ -899,11 +909,10 @@ impl NodeInterner {
) -> bool {
self.trait_implementations.insert(key.clone(), trait_impl.clone());
match &key.typ {
Type::Struct(struct_type, _generics) => {
Type::Struct(..) => {
for func_id in &trait_impl.borrow().methods {
let method_name = self.function_name(func_id).to_owned();
let key = (struct_type.borrow().id, method_name);
self.struct_methods.insert(key, *func_id);
self.add_method(&key.typ, method_name, *func_id);
}
true
}
Expand Down Expand Up @@ -938,9 +947,50 @@ impl NodeInterner {
}
}

/// Search by name for a method on the given struct
pub fn lookup_method(&self, id: StructId, method_name: &str) -> Option<FuncId> {
self.struct_methods.get(&(id, method_name.to_owned())).copied()
/// Search by name for a method on the given struct.
///
/// If `check_type` is true, this will force `lookup_method` to check the type
/// of each candidate instead of returning only the first candidate if there is exactly one.
/// This is generally only desired when declaring new methods to check if they overlap any
/// existing methods.
///
/// Another detail is that this method does not handle auto-dereferencing through `&mut T`.
/// So if an object is of type `self : &mut T` but a method only accepts `self: T` (or
/// vice-versa), the call will not be selected. If this is ever implemented into this method,
/// we can remove the `methods.len() == 1` check and the `check_type` early return.
pub fn lookup_method(
&self,
typ: &Type,
id: StructId,
method_name: &str,
check_type: bool,
) -> Option<FuncId> {
let methods = self.struct_methods.get(&(id, method_name.to_owned()))?;

// If there is only one method, just return it immediately.
// It will still be typechecked later.
if !check_type && methods.len() == 1 {
return Some(methods[0]);
}

// When adding methods we always check they do not overlap, so there should be
// at most 1 matching method in this list.
for method in methods {
match self.function_meta(method).typ.instantiate(self).0 {
Type::Function(args, _, _) => {
if let Some(object) = args.get(0) {
// TODO #3089: This is dangerous! try_unify may commit type bindings even on failure
if object.try_unify(typ).is_ok() {
return Some(*method);
}
}
}
Type::Error => (),
other => unreachable!("Expected function type, found {other}"),
}
}

None
}

/// Looks up a given method name on the given primitive type.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "specialization"
type = "bin"
authors = [""]
compiler_version = "0.16.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
struct Foo<T> {}

impl Foo<u32> {
fn foo(_self: Self) -> Field { 1 }
}

impl Foo<u64> {
fn foo(_self: Self) -> Field { 2 }
}

fn main() {
let f1: Foo<u32> = Foo{};
let f2: Foo<u64> = Foo{};
assert(f1.foo() + f2.foo() == 3);
}

0 comments on commit 44716fa

Please sign in to comment.