Skip to content

Commit

Permalink
Merge pull request #617 from AleoHQ/const-args
Browse files Browse the repository at this point in the history
[Feature] Const Function Arguments
  • Loading branch information
collinc97 authored Feb 5, 2021
2 parents 2e121b2 + 18ebf0c commit 770f660
Show file tree
Hide file tree
Showing 19 changed files with 142 additions and 66 deletions.
30 changes: 19 additions & 11 deletions asg/src/expression/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::{
Span,
Type,
};
pub use leo_ast::BinaryOperation;
pub use leo_ast::{BinaryOperation, Node as AstNode};

use std::{
cell::RefCell,
Expand Down Expand Up @@ -195,25 +195,33 @@ impl FromAst<leo_ast::CallExpression> for CallExpression {
));
}
}
if value.arguments.len() != function.argument_types.len() {
if value.arguments.len() != function.arguments.len() {
return Err(AsgConvertError::unexpected_call_argument_count(
function.argument_types.len(),
function.arguments.len(),
value.arguments.len(),
&value.span,
));
}

let arguments = value
.arguments
.iter()
.zip(function.arguments.iter())
.map(|(expr, argument)| {
let argument = argument.borrow();
let converted =
Arc::<Expression>::from_ast(scope, expr, Some(argument.type_.clone().strong().partial()))?;
if argument.const_ && !converted.is_consty() {
return Err(AsgConvertError::unexpected_nonconst(&expr.span()));
}
Ok(converted)
})
.collect::<Result<Vec<_>, AsgConvertError>>()?;

Ok(CallExpression {
parent: RefCell::new(None),
span: Some(value.span.clone()),
arguments: value
.arguments
.iter()
.zip(function.argument_types.iter())
.map(|(expr, argument)| {
Arc::<Expression>::from_ast(scope, expr, Some(argument.clone().strong().partial()))
})
.collect::<Result<Vec<_>, AsgConvertError>>()?,
arguments,
function,
target,
})
Expand Down
5 changes: 2 additions & 3 deletions asg/src/expression/variable_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ use crate::{
Statement,
Type,
Variable,
VariableDeclaration,
};

use std::{
Expand Down Expand Up @@ -62,7 +61,7 @@ impl ExpressionNode for VariableRef {
fn enforce_parents(&self, _expr: &Arc<Expression>) {}

fn get_type(&self) -> Option<Type> {
Some(self.variable.borrow().type_.clone())
Some(self.variable.borrow().type_.clone().strong())
}

fn is_mut_ref(&self) -> bool {
Expand Down Expand Up @@ -104,7 +103,7 @@ impl ExpressionNode for VariableRef {

fn is_consty(&self) -> bool {
let variable = self.variable.borrow();
if variable.declaration == VariableDeclaration::IterationDefinition {
if variable.const_ {
return true;
}
if variable.mutable || variable.assignments.len() != 1 {
Expand Down
3 changes: 2 additions & 1 deletion asg/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ impl Input {
container: Arc::new(RefCell::new(crate::InnerVariable {
id: uuid::Uuid::new_v4(),
name: Identifier::new("input".to_string()),
type_: Type::Circuit(container_circuit),
type_: Type::Circuit(container_circuit).weak(),
mutable: false,
const_: false,
declaration: crate::VariableDeclaration::Input,
references: vec![],
assignments: vec![],
Expand Down
65 changes: 30 additions & 35 deletions asg/src/program/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub struct Function {
pub name: RefCell<Identifier>,
pub output: WeakType,
pub has_input: bool,
pub argument_types: Vec<WeakType>,
pub arguments: Vec<Variable>,
pub circuit: RefCell<Option<Weak<Circuit>>>,
pub body: RefCell<Weak<FunctionBody>>,
pub qualifier: FunctionQualifier,
Expand All @@ -71,7 +71,6 @@ impl Eq for Function {}
pub struct FunctionBody {
pub span: Option<Span>,
pub function: Arc<Function>,
pub arguments: Vec<Variable>,
pub body: Arc<Statement>,
pub scope: Scope,
}
Expand All @@ -94,7 +93,7 @@ impl Function {
let mut qualifier = FunctionQualifier::Static;
let mut has_input = false;

let mut argument_types = vec![];
let mut arguments = vec![];
{
for input in value.input.iter() {
match input {
Expand All @@ -107,8 +106,24 @@ impl Function {
FunctionInput::MutSelfKeyword(_) => {
qualifier = FunctionQualifier::MutSelfRef;
}
FunctionInput::Variable(leo_ast::FunctionInputVariable { type_, .. }) => {
argument_types.push(scope.borrow().resolve_ast_type(&type_)?.into());
FunctionInput::Variable(leo_ast::FunctionInputVariable {
identifier,
mutable,
const_,
type_,
span: _span,
}) => {
let variable = Arc::new(RefCell::new(crate::InnerVariable {
id: Uuid::new_v4(),
name: identifier.clone(),
type_: scope.borrow().resolve_ast_type(&type_)?.weak(),
mutable: *mutable,
const_: *const_,
declaration: crate::VariableDeclaration::Parameter,
references: vec![],
assignments: vec![],
}));
arguments.push(variable.clone());
}
}
}
Expand All @@ -121,7 +136,7 @@ impl Function {
name: RefCell::new(value.identifier.clone()),
output: output.into(),
has_input,
argument_types,
arguments,
circuit: RefCell::new(None),
body: RefCell::new(Weak::new()),
qualifier,
Expand All @@ -136,47 +151,26 @@ impl FunctionBody {
function: Arc<Function>,
) -> Result<FunctionBody, AsgConvertError> {
let new_scope = InnerScope::make_subscope(scope);
let mut arguments = vec![];
{
let mut scope_borrow = new_scope.borrow_mut();
if function.qualifier != FunctionQualifier::Static {
let circuit = function.circuit.borrow();
let self_variable = Arc::new(RefCell::new(crate::InnerVariable {
id: Uuid::new_v4(),
name: Identifier::new("self".to_string()),
type_: Type::Circuit(circuit.as_ref().unwrap().upgrade().unwrap()),
type_: WeakType::Circuit(circuit.as_ref().unwrap().clone()),
mutable: function.qualifier == FunctionQualifier::MutSelfRef,
const_: false,
declaration: crate::VariableDeclaration::Parameter,
references: vec![],
assignments: vec![],
}));
scope_borrow.variables.insert("self".to_string(), self_variable);
}
scope_borrow.function = Some(function.clone());
for input in value.input.iter() {
match input {
FunctionInput::InputKeyword(_) => {}
FunctionInput::SelfKeyword(_) => {}
FunctionInput::MutSelfKeyword(_) => {}
FunctionInput::Variable(leo_ast::FunctionInputVariable {
identifier,
mutable,
type_,
span: _span,
}) => {
let variable = Arc::new(RefCell::new(crate::InnerVariable {
id: Uuid::new_v4(),
name: identifier.clone(),
type_: scope_borrow.resolve_ast_type(&type_)?,
mutable: *mutable,
declaration: crate::VariableDeclaration::Parameter,
references: vec![],
assignments: vec![],
}));
arguments.push(variable.clone());
scope_borrow.variables.insert(identifier.name.clone(), variable);
}
}
for argument in function.arguments.iter() {
let name = argument.borrow().name.name.clone();
scope_borrow.variables.insert(name, argument.clone());
}
}
let main_block = BlockStatement::from_ast(&new_scope, &value.block, None)?;
Expand All @@ -200,7 +194,6 @@ impl FunctionBody {
Ok(FunctionBody {
span: Some(value.span.clone()),
function,
arguments,
body: Arc::new(Statement::Block(main_block)),
scope: new_scope,
})
Expand All @@ -211,14 +204,16 @@ impl Into<leo_ast::Function> for &Function {
fn into(self) -> leo_ast::Function {
let (input, body, span) = match self.body.borrow().upgrade() {
Some(body) => (
body.arguments
body.function
.arguments
.iter()
.map(|variable| {
let variable = variable.borrow();
leo_ast::FunctionInput::Variable(leo_ast::FunctionInputVariable {
identifier: variable.name.clone(),
mutable: variable.mutable,
type_: (&variable.type_).into(),
const_: variable.const_,
type_: (&variable.type_.clone().strong()).into(),
span: Span::default(),
})
})
Expand Down
2 changes: 1 addition & 1 deletion asg/src/statement/assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl FromAst<leo_ast::AssignStatement> for Arc<Statement> {
if !variable.borrow().mutable {
return Err(AsgConvertError::immutable_assignment(&name, &statement.span));
}
let mut target_type: Option<PartialType> = Some(variable.borrow().type_.clone().into());
let mut target_type: Option<PartialType> = Some(variable.borrow().type_.clone().strong().into());

let mut target_accesses = vec![];
for access in statement.assignee.accesses.iter() {
Expand Down
6 changes: 4 additions & 2 deletions asg/src/statement/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ impl FromAst<leo_ast::DefinitionStatement> for Arc<Statement> {
id: uuid::Uuid::new_v4(),
name: variable.identifier.clone(),
type_: type_
.ok_or_else(|| AsgConvertError::unresolved_type(&variable.identifier.name, &statement.span))?,
.ok_or_else(|| AsgConvertError::unresolved_type(&variable.identifier.name, &statement.span))?
.weak(),
mutable: variable.mutable,
const_: false,
declaration: crate::VariableDeclaration::Definition,
references: vec![],
assignments: vec![],
Expand Down Expand Up @@ -145,7 +147,7 @@ impl Into<leo_ast::DefinitionStatement> for &DefinitionStatement {
span: variable.name.span.clone(),
});
if type_.is_none() {
type_ = Some((&variable.type_).into());
type_ = Some((&variable.type_.clone().strong()).into());
}
}

Expand Down
4 changes: 3 additions & 1 deletion asg/src/statement/iteration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ impl FromAst<leo_ast::IterationStatement> for Arc<Statement> {
name: statement.variable.clone(),
type_: start
.get_type()
.ok_or_else(|| AsgConvertError::unresolved_type(&statement.variable.name, &statement.span))?,
.ok_or_else(|| AsgConvertError::unresolved_type(&statement.variable.name, &statement.span))?
.weak(),
mutable: false,
const_: true,
declaration: crate::VariableDeclaration::IterationDefinition,
references: vec![],
assignments: vec![],
Expand Down
4 changes: 4 additions & 0 deletions asg/src/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ impl Type {
self.into()
}

pub fn weak(self) -> WeakType {
self.into()
}

pub fn is_unit(&self) -> bool {
matches!(self, Type::Tuple(t) if t.is_empty())
}
Expand Down
5 changes: 3 additions & 2 deletions asg/src/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::{Expression, Statement, Type};
use crate::{Expression, Statement, WeakType};
use leo_ast::Identifier;

use std::{
Expand All @@ -37,8 +37,9 @@ pub enum VariableDeclaration {
pub struct InnerVariable {
pub id: Uuid,
pub name: Identifier,
pub type_: Type,
pub type_: WeakType,
pub mutable: bool,
pub const_: bool, // only function arguments, const var definitions NOT included
pub declaration: VariableDeclaration,
pub references: Vec<Weak<Expression>>, // all Expression::VariableRef or panic
pub assignments: Vec<Weak<Statement>>, // all Statement::Assign or panic -- must be 1 if not mutable, or 0 if declaration == input | parameter
Expand Down
60 changes: 60 additions & 0 deletions asg/tests/pass/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,66 @@ fn test_iteration() {
load_asg(program_string).unwrap();
}

#[test]
fn test_const_args() {
let program_string = r#"
function one(const value: u32) -> u32 {
return value + 1
}
function main() {
let mut a = 0u32;
for i in 0..10 {
a += one(i);
}
console.assert(a == 20u32);
}
"#;
load_asg(program_string).unwrap();
}

#[test]
fn test_const_args_used() {
let program_string = r#"
function index(arr: [u8; 3], const value: u32) -> u8 {
return arr[value]
}
function main() {
let mut a = 0u8;
let arr = [1u8, 2, 3];
for i in 0..3 {
a += index(arr, i);
}
console.assert(a == 6u8);
}
"#;
load_asg(program_string).unwrap();
}

#[test]
fn test_const_args_fail() {
let program_string = r#"
function index(arr: [u8; 3], const value: u32) -> u8 {
return arr[value]
}
function main(x_value: u32) {
let mut a = 0u8;
let arr = [1u8, 2, 3];
a += index(arr, x_value);
console.assert(a == 1u8);
}
"#;
load_asg(program_string).err().unwrap();
}

#[test]
fn test_iteration_repeated() {
let program_string = include_str!("iteration_repeated.leo");
Expand Down
1 change: 0 additions & 1 deletion asg/tests/pass/mutability/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ fn test_function_input_mut() {
}

#[test]
#[ignore]
fn test_swap() {
let program_string = include_str!("swap.leo");
load_asg(program_string).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion asg/tests/pass/mutability/swap.leo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Swap two elements of an array.
function swap(mut a: [u32; 2], i: u32, j: u32) -> [u32; 2] {
function swap(mut a: [u32; 2], const i: u32, const j: u32) -> [u32; 2] {
let t = a[i];
a[i] = a[j];
a[j] = t;
Expand Down
Loading

0 comments on commit 770f660

Please sign in to comment.