Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement impl specialization #3087

Merged
merged 7 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,15 +394,12 @@ fn collect_impls(
let module = &mut def_maps.get_mut(&crate_id).unwrap().modules[type_module.0];

for (_, method_id, method) in &unresolved.functions {
let result = module.declare_function(method.name_ident().clone(), *method_id);

if let Err((first_def, second_def)) = result {
let error = DefCollectorErrorKind::Duplicate {
typ: DuplicateType::Function,
first_def,
second_def,
};
errors.push((error.into(), unresolved.file_id));
// If this method was already declared, remove it from the module so it cannot
// be accessed with the `TypeName::method` syntax. We'll check later whether the
// object types in each method overlap or not. If they do, we issue an error.
// If not, that is specialization which is allowed.
if module.declare_function(method.name_ident().clone(), *method_id).is_err() {
module.remove_function(method.name_ident());
}
}
// Prohibit defining impls for primitive types if we're not in the stdlib
Expand Down
5 changes: 5 additions & 0 deletions compiler/noirc_frontend/src/hir/def_map/item_scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,9 @@ impl ItemScope {
pub fn values(&self) -> &HashMap<Ident, (ModuleDefId, Visibility)> {
&self.values
}

pub fn remove_definition(&mut self, name: &Ident) {
self.types.remove(name);
self.values.remove(name);
}
}
5 changes: 5 additions & 0 deletions compiler/noirc_frontend/src/hir/def_map/module_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ impl ModuleData {
self.declare(name, id.into())
}

pub fn remove_function(&mut self, name: &Ident) {
self.scope.remove_definition(name);
self.definitions.remove_definition(name);
}

pub fn declare_global(&mut self, name: Ident, id: StmtId) -> Result<(), (Ident, Ident)> {
self.declare(name, id.into())
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,8 @@ impl<'interner> TypeChecker<'interner> {
) -> Option<HirMethodReference> {
match object_type {
Type::Struct(typ, _args) => {
match self.interner.lookup_method(typ.borrow().id, method_name) {
let id = typ.borrow().id;
vezenovm marked this conversation as resolved.
Show resolved Hide resolved
match self.interner.lookup_method(object_type, id, method_name, false) {
Some(method_id) => Some(HirMethodReference::FuncId(method_id)),
None => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ impl Type {

/// `try_unify` is a bit of a misnomer since although errors are not committed,
/// any unified bindings are on success.
fn try_unify(&self, other: &Type) -> Result<(), UnificationError> {
pub fn try_unify(&self, other: &Type) -> Result<(), UnificationError> {
use Type::*;
use TypeVariableKind as Kind;

Expand Down Expand Up @@ -995,7 +995,7 @@ impl Type {
/// Instantiate this type, replacing any type variables it is quantified
/// over with fresh type variables. If this type is not a Type::Forall,
/// it is unchanged.
pub fn instantiate(&self, interner: &mut NodeInterner) -> (Type, TypeBindings) {
pub fn instantiate(&self, interner: &NodeInterner) -> (Type, TypeBindings) {
match self {
Type::Forall(typevars, typ) => {
let replacements = typevars
Expand Down
92 changes: 71 additions & 21 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ pub struct NodeInterner {
// Each trait definition is possibly shared across multiple type nodes.
// It is also mutated through the RefCell during name resolution to append
// methods from impls to the type.
//
// TODO: We may be able to remove the Shared wrapper once traits are no longer types.
// We'd just lookup their methods as needed through the NodeInterner.
traits: HashMap<TraitId, Trait>,

// Trait implementation map
Expand All @@ -108,10 +105,15 @@ pub struct NodeInterner {

globals: HashMap<StmtId, GlobalInfo>, // NOTE: currently only used for checking repeat globals and restricting their scope to a module

next_type_variable_id: usize,
next_type_variable_id: std::cell::Cell<usize>,

/// A map from a struct type and method name to a function id for the method.
struct_methods: HashMap<(StructId, String), FuncId>,
/// This can resolve to potentially multiple methods if the same method name is
/// specialized for different generics on the same type. E.g. for `Struct<T>`, we
/// may have both `impl Struct<u32> { fn foo(){} }` and `impl Struct<u8> { fn foo(){} }`.
/// If this happens, the returned Vec will have 2 entries and we'll need to further
/// disambiguate them by checking the type of each function.
struct_methods: HashMap<(StructId, String), Vec<FuncId>>,

/// Methods on primitive types defined in the stdlib.
primitive_methods: HashMap<(TypeMethodKey, String), FuncId>,
Expand Down Expand Up @@ -381,7 +383,7 @@ impl Default for NodeInterner {
trait_implementations: HashMap::new(),
instantiation_bindings: HashMap::new(),
field_indices: HashMap::new(),
next_type_variable_id: 0,
next_type_variable_id: std::cell::Cell::new(0),
globals: HashMap::new(),
struct_methods: HashMap::new(),
primitive_methods: HashMap::new(),
Expand Down Expand Up @@ -829,13 +831,13 @@ impl NodeInterner {
*old = Node::Expression(new);
}

pub fn next_type_variable_id(&mut self) -> TypeVariableId {
let id = self.next_type_variable_id;
self.next_type_variable_id += 1;
pub fn next_type_variable_id(&self) -> TypeVariableId {
let id = self.next_type_variable_id.get();
self.next_type_variable_id.set(id + 1);
TypeVariableId(id)
}

pub fn next_type_variable(&mut self) -> Type {
pub fn next_type_variable(&self) -> Type {
Type::type_variable(self.next_type_variable_id())
}

Expand Down Expand Up @@ -863,9 +865,10 @@ impl NodeInterner {
self.function_definition_ids[&function]
}

/// Add a method to a type.
/// This will panic for non-struct types currently as we do not support methods
/// for primitives. We could allow this in the future however.
/// Adds a non-trait method to a type.
///
/// Returns `Some(duplicate)` if a matching method was already defined.
/// Returns `None` otherwise.
pub fn add_method(
&mut self,
self_type: &Type,
Expand All @@ -874,8 +877,15 @@ impl NodeInterner {
) -> Option<FuncId> {
match self_type {
Type::Struct(struct_type, _generics) => {
let key = (struct_type.borrow().id, method_name);
self.struct_methods.insert(key, method_id)
let id = struct_type.borrow().id;

if let Some(existing) = self.lookup_method(self_type, id, &method_name, true) {
return Some(existing);
}

let key = (id, method_name);
self.struct_methods.entry(key).or_default().push(method_id);
None
}
Type::Error => None,

Expand All @@ -899,11 +909,10 @@ impl NodeInterner {
) -> bool {
self.trait_implementations.insert(key.clone(), trait_impl.clone());
match &key.typ {
Type::Struct(struct_type, _generics) => {
Type::Struct(..) => {
for func_id in &trait_impl.borrow().methods {
let method_name = self.function_name(func_id).to_owned();
let key = (struct_type.borrow().id, method_name);
self.struct_methods.insert(key, *func_id);
self.add_method(&key.typ, method_name, *func_id);
}
true
}
Expand Down Expand Up @@ -938,9 +947,50 @@ impl NodeInterner {
}
}

/// Search by name for a method on the given struct
pub fn lookup_method(&self, id: StructId, method_name: &str) -> Option<FuncId> {
self.struct_methods.get(&(id, method_name.to_owned())).copied()
/// Search by name for a method on the given struct.
///
/// If `check_type` is true, this will force `lookup_method` to check the type
/// of each candidate instead of returning only the first candidate if there is exactly one.
/// This is generally only desired when declaring new methods to check if they overlap any
/// existing methods.
///
/// Another detail is that this method does not handle auto-dereferencing through `&mut T`.
/// So if an object is of type `self : &mut T` but a method only accepts `self: T` (or
/// vice-versa), the call will not be selected. If this is ever implemented into this method,
/// we can remove the `methods.len() == 1` check and the `check_type` early return.
pub fn lookup_method(
&self,
typ: &Type,
id: StructId,
method_name: &str,
check_type: bool,
) -> Option<FuncId> {
let methods = self.struct_methods.get(&(id, method_name.to_owned()))?;

// If there is only one method, just return it immediately.
// It will still be typechecked later.
if !check_type && methods.len() == 1 {
return Some(methods[0]);
}

// When adding methods we always check they do not overlap, so there should be
// at most 1 matching method in this list.
for method in methods {
match self.function_meta(method).typ.instantiate(self).0 {
Type::Function(args, _, _) => {
if let Some(object) = args.get(0) {
// TODO #3089: This is dangerous! try_unify may commit type bindings even on failure
if object.try_unify(typ).is_ok() {
return Some(*method);
}
}
}
Type::Error => (),
other => unreachable!("Expected function type, found {other}"),
}
}

None
}

/// Looks up a given method name on the given primitive type.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "specialization"
type = "bin"
authors = [""]
compiler_version = "0.16.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
struct Foo<T> {}

impl Foo<u32> {
fn foo(_self: Self) -> Field { 1 }
}

impl Foo<u64> {
fn foo(_self: Self) -> Field { 2 }
}

fn main() {
let f1: Foo<u32> = Foo{};
let f2: Foo<u64> = Foo{};
assert(f1.foo() + f2.foo() == 3);
}
Loading