diff --git a/asg/src/expression/call.rs b/asg/src/expression/call.rs index d96845655a..306301576c 100644 --- a/asg/src/expression/call.rs +++ b/asg/src/expression/call.rs @@ -29,7 +29,7 @@ use crate::{ Span, Type, }; -pub use leo_ast::BinaryOperation; +pub use leo_ast::{BinaryOperation, Node as AstNode}; use std::{ cell::RefCell, @@ -195,25 +195,33 @@ impl FromAst for CallExpression { )); } } - if value.arguments.len() != function.argument_types.len() { + if value.arguments.len() != function.arguments.len() { return Err(AsgConvertError::unexpected_call_argument_count( - function.argument_types.len(), + function.arguments.len(), value.arguments.len(), &value.span, )); } + let arguments = value + .arguments + .iter() + .zip(function.arguments.iter()) + .map(|(expr, argument)| { + let argument = argument.borrow(); + let converted = + Arc::::from_ast(scope, expr, Some(argument.type_.clone().strong().partial()))?; + if argument.const_ && !converted.is_consty() { + return Err(AsgConvertError::unexpected_nonconst(&expr.span())); + } + Ok(converted) + }) + .collect::, AsgConvertError>>()?; + Ok(CallExpression { parent: RefCell::new(None), span: Some(value.span.clone()), - arguments: value - .arguments - .iter() - .zip(function.argument_types.iter()) - .map(|(expr, argument)| { - Arc::::from_ast(scope, expr, Some(argument.clone().strong().partial())) - }) - .collect::, AsgConvertError>>()?, + arguments, function, target, }) diff --git a/asg/src/expression/variable_ref.rs b/asg/src/expression/variable_ref.rs index 7815429c6e..9bfdf6da3f 100644 --- a/asg/src/expression/variable_ref.rs +++ b/asg/src/expression/variable_ref.rs @@ -29,7 +29,6 @@ use crate::{ Statement, Type, Variable, - VariableDeclaration, }; use std::{ @@ -62,7 +61,7 @@ impl ExpressionNode for VariableRef { fn enforce_parents(&self, _expr: &Arc) {} fn get_type(&self) -> Option { - Some(self.variable.borrow().type_.clone()) + Some(self.variable.borrow().type_.clone().strong()) } fn is_mut_ref(&self) -> bool { @@ -104,7 +103,7 @@ impl ExpressionNode for VariableRef { fn is_consty(&self) -> bool { let variable = self.variable.borrow(); - if variable.declaration == VariableDeclaration::IterationDefinition { + if variable.const_ { return true; } if variable.mutable || variable.assignments.len() != 1 { diff --git a/asg/src/input.rs b/asg/src/input.rs index 3ee220da4d..b814927f2a 100644 --- a/asg/src/input.rs +++ b/asg/src/input.rs @@ -127,8 +127,9 @@ impl Input { container: Arc::new(RefCell::new(crate::InnerVariable { id: uuid::Uuid::new_v4(), name: Identifier::new("input".to_string()), - type_: Type::Circuit(container_circuit), + type_: Type::Circuit(container_circuit).weak(), mutable: false, + const_: false, declaration: crate::VariableDeclaration::Input, references: vec![], assignments: vec![], diff --git a/asg/src/program/function.rs b/asg/src/program/function.rs index 00f4cfae2d..daff1741f6 100644 --- a/asg/src/program/function.rs +++ b/asg/src/program/function.rs @@ -51,7 +51,7 @@ pub struct Function { pub name: RefCell, pub output: WeakType, pub has_input: bool, - pub argument_types: Vec, + pub arguments: Vec, pub circuit: RefCell>>, pub body: RefCell>, pub qualifier: FunctionQualifier, @@ -71,7 +71,6 @@ impl Eq for Function {} pub struct FunctionBody { pub span: Option, pub function: Arc, - pub arguments: Vec, pub body: Arc, pub scope: Scope, } @@ -94,7 +93,7 @@ impl Function { let mut qualifier = FunctionQualifier::Static; let mut has_input = false; - let mut argument_types = vec![]; + let mut arguments = vec![]; { for input in value.input.iter() { match input { @@ -107,8 +106,24 @@ impl Function { FunctionInput::MutSelfKeyword(_) => { qualifier = FunctionQualifier::MutSelfRef; } - FunctionInput::Variable(leo_ast::FunctionInputVariable { type_, .. }) => { - argument_types.push(scope.borrow().resolve_ast_type(&type_)?.into()); + FunctionInput::Variable(leo_ast::FunctionInputVariable { + identifier, + mutable, + const_, + type_, + span: _span, + }) => { + let variable = Arc::new(RefCell::new(crate::InnerVariable { + id: Uuid::new_v4(), + name: identifier.clone(), + type_: scope.borrow().resolve_ast_type(&type_)?.weak(), + mutable: *mutable, + const_: *const_, + declaration: crate::VariableDeclaration::Parameter, + references: vec![], + assignments: vec![], + })); + arguments.push(variable.clone()); } } } @@ -121,7 +136,7 @@ impl Function { name: RefCell::new(value.identifier.clone()), output: output.into(), has_input, - argument_types, + arguments, circuit: RefCell::new(None), body: RefCell::new(Weak::new()), qualifier, @@ -136,7 +151,6 @@ impl FunctionBody { function: Arc, ) -> Result { let new_scope = InnerScope::make_subscope(scope); - let mut arguments = vec![]; { let mut scope_borrow = new_scope.borrow_mut(); if function.qualifier != FunctionQualifier::Static { @@ -144,8 +158,9 @@ impl FunctionBody { let self_variable = Arc::new(RefCell::new(crate::InnerVariable { id: Uuid::new_v4(), name: Identifier::new("self".to_string()), - type_: Type::Circuit(circuit.as_ref().unwrap().upgrade().unwrap()), + type_: WeakType::Circuit(circuit.as_ref().unwrap().clone()), mutable: function.qualifier == FunctionQualifier::MutSelfRef, + const_: false, declaration: crate::VariableDeclaration::Parameter, references: vec![], assignments: vec![], @@ -153,30 +168,9 @@ impl FunctionBody { scope_borrow.variables.insert("self".to_string(), self_variable); } scope_borrow.function = Some(function.clone()); - for input in value.input.iter() { - match input { - FunctionInput::InputKeyword(_) => {} - FunctionInput::SelfKeyword(_) => {} - FunctionInput::MutSelfKeyword(_) => {} - FunctionInput::Variable(leo_ast::FunctionInputVariable { - identifier, - mutable, - type_, - span: _span, - }) => { - let variable = Arc::new(RefCell::new(crate::InnerVariable { - id: Uuid::new_v4(), - name: identifier.clone(), - type_: scope_borrow.resolve_ast_type(&type_)?, - mutable: *mutable, - declaration: crate::VariableDeclaration::Parameter, - references: vec![], - assignments: vec![], - })); - arguments.push(variable.clone()); - scope_borrow.variables.insert(identifier.name.clone(), variable); - } - } + for argument in function.arguments.iter() { + let name = argument.borrow().name.name.clone(); + scope_borrow.variables.insert(name, argument.clone()); } } let main_block = BlockStatement::from_ast(&new_scope, &value.block, None)?; @@ -200,7 +194,6 @@ impl FunctionBody { Ok(FunctionBody { span: Some(value.span.clone()), function, - arguments, body: Arc::new(Statement::Block(main_block)), scope: new_scope, }) @@ -211,14 +204,16 @@ impl Into for &Function { fn into(self) -> leo_ast::Function { let (input, body, span) = match self.body.borrow().upgrade() { Some(body) => ( - body.arguments + body.function + .arguments .iter() .map(|variable| { let variable = variable.borrow(); leo_ast::FunctionInput::Variable(leo_ast::FunctionInputVariable { identifier: variable.name.clone(), mutable: variable.mutable, - type_: (&variable.type_).into(), + const_: variable.const_, + type_: (&variable.type_.clone().strong()).into(), span: Span::default(), }) }) diff --git a/asg/src/statement/assign.rs b/asg/src/statement/assign.rs index 473989ffe5..10d7be7dc2 100644 --- a/asg/src/statement/assign.rs +++ b/asg/src/statement/assign.rs @@ -94,7 +94,7 @@ impl FromAst for Arc { if !variable.borrow().mutable { return Err(AsgConvertError::immutable_assignment(&name, &statement.span)); } - let mut target_type: Option = Some(variable.borrow().type_.clone().into()); + let mut target_type: Option = Some(variable.borrow().type_.clone().strong().into()); let mut target_accesses = vec![]; for access in statement.assignee.accesses.iter() { diff --git a/asg/src/statement/definition.rs b/asg/src/statement/definition.rs index 656a6ec9b8..438b26b29c 100644 --- a/asg/src/statement/definition.rs +++ b/asg/src/statement/definition.rs @@ -99,8 +99,10 @@ impl FromAst for Arc { id: uuid::Uuid::new_v4(), name: variable.identifier.clone(), type_: type_ - .ok_or_else(|| AsgConvertError::unresolved_type(&variable.identifier.name, &statement.span))?, + .ok_or_else(|| AsgConvertError::unresolved_type(&variable.identifier.name, &statement.span))? + .weak(), mutable: variable.mutable, + const_: false, declaration: crate::VariableDeclaration::Definition, references: vec![], assignments: vec![], @@ -145,7 +147,7 @@ impl Into for &DefinitionStatement { span: variable.name.span.clone(), }); if type_.is_none() { - type_ = Some((&variable.type_).into()); + type_ = Some((&variable.type_.clone().strong()).into()); } } diff --git a/asg/src/statement/iteration.rs b/asg/src/statement/iteration.rs index 13896249a1..bd19f0ab68 100644 --- a/asg/src/statement/iteration.rs +++ b/asg/src/statement/iteration.rs @@ -65,8 +65,10 @@ impl FromAst for Arc { name: statement.variable.clone(), type_: start .get_type() - .ok_or_else(|| AsgConvertError::unresolved_type(&statement.variable.name, &statement.span))?, + .ok_or_else(|| AsgConvertError::unresolved_type(&statement.variable.name, &statement.span))? + .weak(), mutable: false, + const_: true, declaration: crate::VariableDeclaration::IterationDefinition, references: vec![], assignments: vec![], diff --git a/asg/src/type_.rs b/asg/src/type_.rs index 426a93f063..078d9a6cf9 100644 --- a/asg/src/type_.rs +++ b/asg/src/type_.rs @@ -157,6 +157,10 @@ impl Type { self.into() } + pub fn weak(self) -> WeakType { + self.into() + } + pub fn is_unit(&self) -> bool { matches!(self, Type::Tuple(t) if t.is_empty()) } diff --git a/asg/src/variable.rs b/asg/src/variable.rs index b2914db73f..d519da53de 100644 --- a/asg/src/variable.rs +++ b/asg/src/variable.rs @@ -14,7 +14,7 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{Expression, Statement, Type}; +use crate::{Expression, Statement, WeakType}; use leo_ast::Identifier; use std::{ @@ -37,8 +37,9 @@ pub enum VariableDeclaration { pub struct InnerVariable { pub id: Uuid, pub name: Identifier, - pub type_: Type, + pub type_: WeakType, pub mutable: bool, + pub const_: bool, // only function arguments, const var definitions NOT included pub declaration: VariableDeclaration, pub references: Vec>, // all Expression::VariableRef or panic pub assignments: Vec>, // all Statement::Assign or panic -- must be 1 if not mutable, or 0 if declaration == input | parameter diff --git a/asg/tests/pass/function/mod.rs b/asg/tests/pass/function/mod.rs index 7965c39420..abdbd93b7c 100644 --- a/asg/tests/pass/function/mod.rs +++ b/asg/tests/pass/function/mod.rs @@ -28,6 +28,66 @@ fn test_iteration() { load_asg(program_string).unwrap(); } +#[test] +fn test_const_args() { + let program_string = r#" + function one(const value: u32) -> u32 { + return value + 1 + } + + function main() { + let mut a = 0u32; + + for i in 0..10 { + a += one(i); + } + + console.assert(a == 20u32); + } + "#; + load_asg(program_string).unwrap(); +} + +#[test] +fn test_const_args_used() { + let program_string = r#" + function index(arr: [u8; 3], const value: u32) -> u8 { + return arr[value] + } + + function main() { + let mut a = 0u8; + let arr = [1u8, 2, 3]; + + for i in 0..3 { + a += index(arr, i); + } + + console.assert(a == 6u8); + } + "#; + load_asg(program_string).unwrap(); +} + +#[test] +fn test_const_args_fail() { + let program_string = r#" + function index(arr: [u8; 3], const value: u32) -> u8 { + return arr[value] + } + + function main(x_value: u32) { + let mut a = 0u8; + let arr = [1u8, 2, 3]; + + a += index(arr, x_value); + + console.assert(a == 1u8); + } + "#; + load_asg(program_string).err().unwrap(); +} + #[test] fn test_iteration_repeated() { let program_string = include_str!("iteration_repeated.leo"); diff --git a/asg/tests/pass/mutability/mod.rs b/asg/tests/pass/mutability/mod.rs index 6244be256f..5da14da0ae 100644 --- a/asg/tests/pass/mutability/mod.rs +++ b/asg/tests/pass/mutability/mod.rs @@ -65,7 +65,6 @@ fn test_function_input_mut() { } #[test] -#[ignore] fn test_swap() { let program_string = include_str!("swap.leo"); load_asg(program_string).unwrap(); diff --git a/asg/tests/pass/mutability/swap.leo b/asg/tests/pass/mutability/swap.leo index 8234a3cfb7..0a209149ee 100644 --- a/asg/tests/pass/mutability/swap.leo +++ b/asg/tests/pass/mutability/swap.leo @@ -1,5 +1,5 @@ // Swap two elements of an array. -function swap(mut a: [u32; 2], i: u32, j: u32) -> [u32; 2] { +function swap(mut a: [u32; 2], const i: u32, const j: u32) -> [u32; 2] { let t = a[i]; a[i] = a[j]; a[j] = t; diff --git a/ast/src/functions/input/function_input.rs b/ast/src/functions/input/function_input.rs index 755892d9bd..8cece828ad 100644 --- a/ast/src/functions/input/function_input.rs +++ b/ast/src/functions/input/function_input.rs @@ -23,6 +23,7 @@ use std::fmt; #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct FunctionInputVariable { pub identifier: Identifier, + pub const_: bool, pub mutable: bool, pub type_: Type, pub span: Span, @@ -32,6 +33,7 @@ impl<'ast> From> for FunctionInputVariable { fn from(parameter: GrammarFunctionInput<'ast>) -> Self { FunctionInputVariable { identifier: Identifier::from(parameter.identifier), + const_: parameter.const_.is_some(), mutable: parameter.mutable.is_some(), type_: Type::from(parameter.type_), span: Span::from(parameter.span), @@ -42,6 +44,9 @@ impl<'ast> From> for FunctionInputVariable { impl FunctionInputVariable { fn format(&self, f: &mut fmt::Formatter) -> fmt::Result { // mut var: bool + if self.const_ { + write!(f, "const ")?; + } if self.mutable { write!(f, "mut ")?; } diff --git a/compiler/src/function/function.rs b/compiler/src/function/function.rs index 6df2c0a444..efa2ebf29b 100644 --- a/compiler/src/function/function.rs +++ b/compiler/src/function/function.rs @@ -52,7 +52,7 @@ impl> ConstrainedProgram { None }; - if function.arguments.len() != arguments.len() { + if function.function.arguments.len() != arguments.len() { return Err(FunctionError::input_not_found( "arguments length invalid".to_string(), function.span.clone().unwrap_or_default(), @@ -60,7 +60,7 @@ impl> ConstrainedProgram { } // Store input values as new variables in resolved program - for (variable, input_expression) in function.arguments.iter().zip(arguments.iter()) { + for (variable, input_expression) in function.function.arguments.iter().zip(arguments.iter()) { let input_value = self.enforce_expression(cs, input_expression)?; let variable = variable.borrow(); diff --git a/compiler/src/function/main_function.rs b/compiler/src/function/main_function.rs index f1003f8c6c..e858db1b73 100644 --- a/compiler/src/function/main_function.rs +++ b/compiler/src/function/main_function.rs @@ -64,7 +64,7 @@ impl> ConstrainedProgram { let mut arguments = vec![]; - for input_variable in function.arguments.iter() { + for input_variable in function.function.arguments.iter() { { let input_variable = input_variable.borrow(); let name = input_variable.name.name.clone(); @@ -73,7 +73,7 @@ impl> ConstrainedProgram { })?; let input_value = self.allocate_main_function_input( cs, - &input_variable.type_, + &input_variable.type_.clone().strong(), &name, input_option, &function.span.clone().unwrap_or_default(), diff --git a/compiler/tests/mutability/mod.rs b/compiler/tests/mutability/mod.rs index 8acd3f9e51..6c1555c870 100644 --- a/compiler/tests/mutability/mod.rs +++ b/compiler/tests/mutability/mod.rs @@ -149,7 +149,6 @@ fn test_function_input_mut() { } #[test] -#[ignore] fn test_swap() { let program_string = include_str!("swap.leo"); let program = parse_program(program_string).unwrap(); diff --git a/compiler/tests/mutability/swap.leo b/compiler/tests/mutability/swap.leo index 8234a3cfb7..0a209149ee 100644 --- a/compiler/tests/mutability/swap.leo +++ b/compiler/tests/mutability/swap.leo @@ -1,5 +1,5 @@ // Swap two elements of an array. -function swap(mut a: [u32; 2], i: u32, j: u32) -> [u32; 2] { +function swap(mut a: [u32; 2], const i: u32, const j: u32) -> [u32; 2] { let t = a[i]; a[i] = a[j]; a[j] = t; diff --git a/grammar/src/functions/input/function_input.rs b/grammar/src/functions/input/function_input.rs index efce2a17db..3cac32b798 100644 --- a/grammar/src/functions/input/function_input.rs +++ b/grammar/src/functions/input/function_input.rs @@ -16,7 +16,7 @@ use crate::{ ast::Rule, - common::{Identifier, Mutable}, + common::{Const, Identifier, Mutable}, types::Type, SpanDef, }; @@ -28,6 +28,7 @@ use serde::Serialize; #[derive(Clone, Debug, FromPest, PartialEq, Serialize)] #[pest_ast(rule(Rule::function_input))] pub struct FunctionInput<'ast> { + pub const_: Option, pub mutable: Option, pub identifier: Identifier<'ast>, pub type_: Type<'ast>, diff --git a/grammar/src/leo.pest b/grammar/src/leo.pest index 8b8eabd4e3..1d955ab44b 100644 --- a/grammar/src/leo.pest +++ b/grammar/src/leo.pest @@ -436,7 +436,7 @@ statement_return = { "return " ~ expression} function = { "function " ~ identifier ~ input_tuple ~ ("->" ~ type_)? ~ block } // Declared in functions/input/function_input.rs -function_input = { mutable? ~ identifier ~ ":" ~ type_ } +function_input = { const_? ~ mutable? ~ identifier ~ ":" ~ type_ } // Declared in functions/input/input_keyword.rs input_keyword = { "input" }