Skip to content

Commit

Permalink
feat: check argument count and types on attribute function callback (#…
Browse files Browse the repository at this point in the history
…5921)

# Description

## Problem

Resolves #5903

## Summary

Now the compiler will check that attribute function callbacks have at
least one argument, that that argument's type matches the corresponding
type, and that remaining arguments also match the types given.

Also previously errors on these callbacks were shown on the function
that had the attribute, instead of on the attribute, likely because
attributes didn't have a Span attached to them: this PR adds that too.

## Additional Context

The error message is still a bit strange because if you have code like
this:

```rust
#[attr]
fn foo() {}

fn attr() {}

fn main() {}
```

You get this:

```
error: Expected 0 arguments, but 1 was provided
  ┌─ src/main.nr:1:3
  │
1 │ #[attr]
  │   ---- Too many arguments
```

which kind of makes sense, because 1 implicit argument was provided but
0 are expected in the callback, but maybe the error should point out
that the callback actually needs one argument. Let me know if you think
we should improve the error message here... but at least it doesn't
error anymore.

Oh, I remember why I didn't improve that error message: the error should
likely be on the callback function, but it should point out that the
error happens because of a given attribute, so we need two different
locations for the error, which I think we currently doesn't support.

## Documentation

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
asterite authored Sep 4, 2024
1 parent d2caa5b commit 91f693d
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 63 deletions.
4 changes: 2 additions & 2 deletions aztec_macros/src/utils/ast_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ pub fn check_trait_method_implemented(trait_impl: &NoirTraitImpl, method_name: &

/// Checks if an attribute is a custom attribute with a specific name
pub fn is_custom_attribute(attr: &SecondaryAttribute, attribute_name: &str) -> bool {
if let SecondaryAttribute::Custom(custom_attr) = attr {
custom_attr.as_str() == attribute_name
if let SecondaryAttribute::Custom(custom_attribute) = attr {
custom_attribute.contents.as_str() == attribute_name
} else {
false
}
Expand Down
10 changes: 7 additions & 3 deletions compiler/noirc_driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,13 @@ fn compile_contract_inner(
.attributes
.secondary
.iter()
.filter_map(
|attr| if let SecondaryAttribute::Custom(tag) = attr { Some(tag) } else { None },
)
.filter_map(|attr| {
if let SecondaryAttribute::Custom(attribute) = attr {
Some(&attribute.contents)
} else {
None
}
})
.cloned()
.collect();

Expand Down
4 changes: 4 additions & 0 deletions compiler/noirc_errors/src/position.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ impl Span {
let other_distance = other.end() - other.start();
self_distance < other_distance
}

pub fn shift_by(&self, offset: u32) -> Span {
Self::from(self.start() + offset..self.end() + offset)
}
}

impl From<Span> for Range<usize> {
Expand Down
94 changes: 71 additions & 23 deletions compiler/noirc_frontend/src/elaborator/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{
},
node_interner::{DefinitionKind, DependencyId, FuncId, TraitId},
parser::{self, TopLevelStatement},
Type, TypeBindings,
Type, TypeBindings, UnificationError,
};

use super::{Elaborator, FunctionContext, ResolverMeta};
Expand Down Expand Up @@ -96,10 +96,14 @@ impl<'context> Elaborator<'context> {
generated_items: &mut CollectedItems,
) {
for attribute in attributes {
if let SecondaryAttribute::Custom(name) = attribute {
if let Err(error) =
self.run_comptime_attribute_on_item(name, item.clone(), span, generated_items)
{
if let SecondaryAttribute::Custom(attribute) = attribute {
if let Err(error) = self.run_comptime_attribute_on_item(
&attribute.contents,
item.clone(),
span,
attribute.contents_span,
generated_items,
) {
self.errors.push(error);
}
}
Expand All @@ -111,10 +115,11 @@ impl<'context> Elaborator<'context> {
attribute: &str,
item: Value,
span: Span,
attribute_span: Span,
generated_items: &mut CollectedItems,
) -> Result<(), (CompilationError, FileId)> {
let location = Location::new(span, self.file);
let Some((function, arguments)) = Self::parse_attribute(attribute, self.file)? else {
let location = Location::new(attribute_span, self.file);
let Some((function, arguments)) = Self::parse_attribute(attribute, location)? else {
// Do not issue an error if the attribute is unknown
return Ok(());
};
Expand All @@ -141,12 +146,17 @@ impl<'context> Elaborator<'context> {
};

let mut interpreter = self.setup_interpreter();
let mut arguments =
Self::handle_attribute_arguments(&mut interpreter, function, arguments, location)
.map_err(|error| {
let file = error.get_location().file;
(error.into(), file)
})?;
let mut arguments = Self::handle_attribute_arguments(
&mut interpreter,
&item,
function,
arguments,
location,
)
.map_err(|error| {
let file = error.get_location().file;
(error.into(), file)
})?;

arguments.insert(0, (item, location));

Expand All @@ -170,33 +180,62 @@ impl<'context> Elaborator<'context> {
#[allow(clippy::type_complexity)]
pub(crate) fn parse_attribute(
annotation: &str,
file: FileId,
location: Location,
) -> Result<Option<(Expression, Vec<Expression>)>, (CompilationError, FileId)> {
let (tokens, mut lexing_errors) = Lexer::lex(annotation);
if !lexing_errors.is_empty() {
return Err((lexing_errors.swap_remove(0).into(), file));
return Err((lexing_errors.swap_remove(0).into(), location.file));
}

let expression = parser::expression()
.parse(tokens)
.map_err(|mut errors| (errors.swap_remove(0).into(), file))?;
.map_err(|mut errors| (errors.swap_remove(0).into(), location.file))?;

Ok(match expression.kind {
ExpressionKind::Call(call) => Some((*call.func, call.arguments)),
ExpressionKind::Variable(_) => Some((expression, Vec::new())),
_ => None,
})
let (mut func, mut arguments) = match expression.kind {
ExpressionKind::Call(call) => (*call.func, call.arguments),
ExpressionKind::Variable(_) => (expression, Vec::new()),
_ => return Ok(None),
};

func.span = func.span.shift_by(location.span.start());

for argument in &mut arguments {
argument.span = argument.span.shift_by(location.span.start());
}

Ok(Some((func, arguments)))
}

fn handle_attribute_arguments(
interpreter: &mut Interpreter,
item: &Value,
function: FuncId,
arguments: Vec<Expression>,
location: Location,
) -> Result<Vec<(Value, Location)>, InterpreterError> {
let meta = interpreter.elaborator.interner.function_meta(&function);

let mut parameters = vecmap(&meta.parameters.0, |(_, typ, _)| typ.clone());

if parameters.is_empty() {
return Err(InterpreterError::ArgumentCountMismatch {
expected: 0,
actual: arguments.len() + 1,
location,
});
}

let expected_type = item.get_type();
let expected_type = expected_type.as_ref();

if &parameters[0] != expected_type {
return Err(InterpreterError::TypeMismatch {
expected: parameters[0].clone(),
actual: expected_type.clone(),
location,
});
}

// Remove the initial parameter for the comptime item since that is not included
// in `arguments` at this point.
parameters.remove(0);
Expand All @@ -213,6 +252,7 @@ impl<'context> Elaborator<'context> {
let mut varargs = im::Vector::new();

for (i, arg) in arguments.into_iter().enumerate() {
let arg_location = Location::new(arg.span, location.file);
let param_type = parameters.get(i).or(varargs_elem_type).unwrap_or(&Type::Error);

let mut push_arg = |arg| {
Expand All @@ -233,9 +273,17 @@ impl<'context> Elaborator<'context> {
}?;
push_arg(Value::TraitDefinition(trait_id));
} else {
let expr_id = interpreter.elaborator.elaborate_expression(arg).0;
let (expr_id, expr_type) = interpreter.elaborator.elaborate_expression(arg);
push_arg(interpreter.evaluate(expr_id)?);
}

if let Err(UnificationError) = expr_type.unify(param_type) {
return Err(InterpreterError::TypeMismatch {
expected: param_type.clone(),
actual: expr_type,
location: arg_location,
});
}
};
}

if is_varargs {
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::{
node_interner::{
DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, ReferenceId, TraitId, TypeAliasId,
},
token::CustomAtrribute,
Shared, Type, TypeVariable,
};
use crate::{
Expand Down Expand Up @@ -819,7 +820,7 @@ impl<'context> Elaborator<'context> {
let attributes = func.secondary_attributes().iter();
let attributes =
attributes.filter_map(|secondary_attribute| secondary_attribute.as_custom());
let attributes = attributes.map(|str| str.to_string()).collect();
let attributes: Vec<CustomAtrribute> = attributes.cloned().collect();

let meta = FuncMeta {
name: name_ident,
Expand Down
8 changes: 4 additions & 4 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::{
TraitImplKind, TraitMethodId,
},
Generics, Kind, ResolvedGeneric, Type, TypeBinding, TypeBindings, TypeVariable,
TypeVariableKind,
TypeVariableKind, UnificationError,
};

use super::{lints, Elaborator};
Expand Down Expand Up @@ -713,9 +713,9 @@ impl<'context> Elaborator<'context> {
expected: &Type,
make_error: impl FnOnce() -> TypeCheckError,
) {
let mut errors = Vec::new();
actual.unify(expected, &mut errors, make_error);
self.errors.extend(errors.into_iter().map(|error| (error.into(), self.file)));
if let Err(UnificationError) = actual.unify(expected) {
self.errors.push((make_error().into(), self.file));
}
}

/// Wrapper of Type::unify_with_coercions using self.errors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1609,7 +1609,7 @@ fn function_def_has_named_attribute(
let name = name.iter().map(|token| token.to_string()).collect::<Vec<_>>().join("");

for attribute in attributes {
let parse_result = Elaborator::parse_attribute(attribute, location.file);
let parse_result = Elaborator::parse_attribute(&attribute.contents, location);
let Ok(Some((function, _arguments))) = parse_result else {
continue;
};
Expand Down
3 changes: 2 additions & 1 deletion compiler/noirc_frontend/src/hir_def/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::graph::CrateId;
use crate::hir::def_map::LocalModuleId;
use crate::macros_api::{BlockExpression, StructId};
use crate::node_interner::{ExprId, NodeInterner, TraitId, TraitImplId};
use crate::token::CustomAtrribute;
use crate::{ResolvedGeneric, Type};

/// A Hir function is a block expression with a list of statements.
Expand Down Expand Up @@ -166,7 +167,7 @@ pub struct FuncMeta {
pub self_type: Option<Type>,

/// Custom attributes attached to this function.
pub custom_attributes: Vec<String>,
pub custom_attributes: Vec<CustomAtrribute>,
}

#[derive(Debug, Clone)]
Expand Down
18 changes: 5 additions & 13 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1467,21 +1467,13 @@ impl Type {
/// equal to the other type in the process. When comparing types, unification
/// (including try_unify) are almost always preferred over Type::eq as unification
/// will correctly handle generic types.
pub fn unify(
&self,
expected: &Type,
errors: &mut Vec<TypeCheckError>,
make_error: impl FnOnce() -> TypeCheckError,
) {
pub fn unify(&self, expected: &Type) -> Result<(), UnificationError> {
let mut bindings = TypeBindings::new();

match self.try_unify(expected, &mut bindings) {
Ok(()) => {
// Commit any type bindings on success
Self::apply_type_bindings(bindings);
}
Err(UnificationError) => errors.push(make_error()),
}
self.try_unify(expected, &mut bindings).map(|()| {
// Commit any type bindings on success
Self::apply_type_bindings(bindings);
})
}

/// `try_unify` is a bit of a misnomer since although errors are not committed,
Expand Down
19 changes: 14 additions & 5 deletions compiler/noirc_frontend/src/lexer/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,12 @@ impl<'a> Lexer<'a> {
}
self.next_char();

let contents_start = self.position + 1;

let word = self.eat_while(None, |ch| ch != ']');

let contents_end = self.position;

if !self.peek_char_is(']') {
return Err(LexerErrorKind::UnexpectedCharacter {
span: Span::single_char(self.position),
Expand All @@ -308,7 +312,10 @@ impl<'a> Lexer<'a> {

let end = self.position;

let attribute = Attribute::lookup_attribute(&word, Span::inclusive(start, end))?;
let span = Span::inclusive(start, end);
let contents_span = Span::inclusive(contents_start, contents_end);

let attribute = Attribute::lookup_attribute(&word, span, contents_span)?;

Ok(attribute.into_span(start, end))
}
Expand Down Expand Up @@ -682,7 +689,7 @@ mod tests {
use iter_extended::vecmap;

use super::*;
use crate::token::{FunctionAttribute, SecondaryAttribute, TestScope};
use crate::token::{CustomAtrribute, FunctionAttribute, SecondaryAttribute, TestScope};

#[test]
fn test_single_double_char() {
Expand Down Expand Up @@ -810,9 +817,11 @@ mod tests {
let token = lexer.next_token().unwrap();
assert_eq!(
token.token(),
&Token::Attribute(Attribute::Secondary(SecondaryAttribute::Custom(
"custom(hello)".to_string()
)))
&Token::Attribute(Attribute::Secondary(SecondaryAttribute::Custom(CustomAtrribute {
contents: "custom(hello)".to_string(),
span: Span::from(0..16),
contents_span: Span::from(2..15)
})))
);
}

Expand Down
Loading

0 comments on commit 91f693d

Please sign in to comment.