diff --git a/examples/assignment.no b/examples/assignment.no index b82be55cc..6d966a131 100644 --- a/examples/assignment.no +++ b/examples/assignment.no @@ -1,5 +1,5 @@ struct Thing { - xx: Field, + pub xx: Field, } fn try_to_mutate(thing: Thing) { diff --git a/examples/hint.no b/examples/hint.no index e8ac1bc2d..bea559acf 100644 --- a/examples/hint.no +++ b/examples/hint.no @@ -1,6 +1,6 @@ struct Thing { - xx: Field, - yy: Field, + pub xx: Field, + pub yy: Field, } hint fn mul(lhs: Field, rhs: Field) -> Field { diff --git a/examples/types.no b/examples/types.no index 7aa792d6d..d219d3f3c 100644 --- a/examples/types.no +++ b/examples/types.no @@ -1,6 +1,6 @@ struct Thing { - xx: Field, - yy: Field, + pub xx: Field, + pub yy: Field, } fn main(pub xx: Field, pub yy: Field) { diff --git a/examples/types_array.no b/examples/types_array.no index ef08c51bf..e50830164 100644 --- a/examples/types_array.no +++ b/examples/types_array.no @@ -1,6 +1,6 @@ struct Thing { - xx: Field, - yy: Field, + pub xx: Field, + pub yy: Field, } fn main(pub xx: Field, pub yy: Field) { diff --git a/examples/types_array_output.no b/examples/types_array_output.no index 43c75493e..1693be0c7 100644 --- a/examples/types_array_output.no +++ b/examples/types_array_output.no @@ -1,6 +1,6 @@ struct Thing { - xx: Field, - yy: Field, + pub xx: Field, + pub yy: Field, } fn main(pub xx: Field, pub yy: Field) -> [Thing; 2] { diff --git a/src/circuit_writer/ir.rs b/src/circuit_writer/ir.rs index 4bdf45581..dc9f9e3f8 100644 --- a/src/circuit_writer/ir.rs +++ b/src/circuit_writer/ir.rs @@ -706,7 +706,7 @@ impl IRWriter { // find range of field let mut start = 0; let mut len = 0; - for (field, field_typ) in &struct_info.fields { + for (field, field_typ, _attribute) in &struct_info.fields { if field == &rhs.value { len = self.size_of(field_typ); break; diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 8380707ee..70d3a42e8 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -320,7 +320,7 @@ impl CircuitWriter { .clone(); let mut offset = 0; - for (_field_name, field_typ) in &struct_info.fields { + for (_field_name, field_typ, _attribute) in &struct_info.fields { let len = self.size_of(field_typ); let range = offset..(offset + len); self.constrain_inputs_to_main(&input[range], field_typ, span)?; @@ -492,7 +492,7 @@ impl CircuitWriter { // find range of field let mut start = 0; let mut len = 0; - for (field, field_typ) in &struct_info.fields { + for (field, field_typ, _attribute) in &struct_info.fields { if field == &rhs.value { len = self.size_of(field_typ); break; diff --git a/src/error.rs b/src/error.rs index e39b50636..8aa982104 100644 --- a/src/error.rs +++ b/src/error.rs @@ -369,6 +369,10 @@ pub enum ErrorKind { #[error("division by zero")] DivisionByZero, + #[error("cannot access private field `{1}` of struct `{0}` from outside its methods.")] + PrivateFieldAccess(String, String), + #[error("Not enough variables provided to fill placeholders in the formatted string")] InsufficientVariables, + } diff --git a/src/inputs.rs b/src/inputs.rs index d4dbaee5f..04f90ddf0 100644 --- a/src/inputs.rs +++ b/src/inputs.rs @@ -141,7 +141,7 @@ impl CompiledCircuit { // parse each field let mut res = vec![]; - for (field_name, field_ty) in fields { + for (field_name, field_ty, _attribute) in fields { let value = map.remove(field_name).ok_or_else(|| { ParsingError::MissingStructFieldIdent(field_name.to_string()) })?; diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 31e40a3da..09e134602 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -446,7 +446,7 @@ impl Mast { let mut sum = 0; - for (_, t) in &struct_info.fields { + for (_, t, _) in &struct_info.fields { sum += self.size_of(t); } @@ -549,8 +549,8 @@ fn monomorphize_expr( let typ = struct_info .fields .iter() - .find(|(name, _)| name == &rhs.value) - .map(|(_, typ)| typ.clone()); + .find(|(name, _, _)| name == &rhs.value) + .map(|(_, typ, _)| typ.clone()); let mexpr = expr.to_mast( ctx, diff --git a/src/name_resolution/context.rs b/src/name_resolution/context.rs index f06338772..849fe53c9 100644 --- a/src/name_resolution/context.rs +++ b/src/name_resolution/context.rs @@ -151,7 +151,7 @@ impl NameResCtx { self.resolve(module, true)?; // we resolve the fully-qualified types of the fields - for (_field_name, field_typ) in fields { + for (_field_name, field_typ, _attribute) in fields { self.resolve_typ_kind(&mut field_typ.kind)?; } diff --git a/src/negative_tests.rs b/src/negative_tests.rs index dd557cf71..f630f845f 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -700,7 +700,7 @@ fn test_nonhint_call_with_unsafe() { fn test_no_cst_struct_field_prop() { let code = r#" struct Thing { - val: Field, + pub val: Field, } fn gen(const LEN: Field) -> [Field; LEN] { @@ -725,7 +725,7 @@ fn test_no_cst_struct_field_prop() { fn test_mut_cst_struct_field_prop() { let code = r#" struct Thing { - val: Field, + pub val: Field, } fn gen(const LEN: Field) -> [Field; LEN] { @@ -747,3 +747,24 @@ fn test_mut_cst_struct_field_prop() { ErrorKind::ArgumentTypeMismatch(..) )); } + +#[test] +fn test_private_field_access() { + let code = r#" + struct Room { + pub beds: Field, // public + size: Field // private + } + + fn main(pub beds: Field) { + let room = Room {beds: beds, size: 10}; + room.size = 5; // not allowed + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::PrivateFieldAccess(..) + )); +} diff --git a/src/parser/structs.rs b/src/parser/structs.rs index 6bfed47a1..9e5a7a328 100644 --- a/src/parser/structs.rs +++ b/src/parser/structs.rs @@ -3,12 +3,12 @@ use serde::{Deserialize, Serialize}; use crate::{ constants::Span, error::{ErrorKind, Result}, - lexer::{Token, TokenKind, Tokens}, + lexer::{Keyword, Token, TokenKind, Tokens}, syntax::is_type, }; use super::{ - types::{Ident, ModulePath, Ty, TyKind}, + types::{Attribute, AttributeKind, Ident, ModulePath, Ty, TyKind}, Error, ParserCtx, }; @@ -17,7 +17,7 @@ pub struct StructDef { //pub attribute: Attribute, pub module: ModulePath, // name resolution pub name: CustomType, - pub fields: Vec<(Ident, Ty)>, + pub fields: Vec<(Ident, Ty, Option)>, pub span: Span, } @@ -55,6 +55,26 @@ impl StructDef { tokens.bump(ctx); break; } + + // check for pub keyword + // struct Foo { pub a: Field, b: Field } + // ^ + let attribute = if matches!( + tokens.peek(), + Some(Token { + kind: TokenKind::Keyword(Keyword::Pub), + .. + }) + ) { + let token = tokens.bump(ctx).unwrap(); + Some(Attribute { + kind: AttributeKind::Pub, + span: token.span, + }) + } else { + None + }; + // struct Foo { a: Field, b: Field } // ^ let field_name = Ident::parse(ctx, tokens)?; @@ -67,7 +87,7 @@ impl StructDef { // ^^^^^ let field_ty = Ty::parse(ctx, tokens)?; span = span.merge_with(field_ty.span); - fields.push((field_name, field_ty)); + fields.push((field_name, field_ty, attribute)); // struct Foo { a: Field, b: Field } // ^ ^ diff --git a/src/stdlib/native/int/lib.no b/src/stdlib/native/int/lib.no index 315c252cd..648a80e77 100644 --- a/src/stdlib/native/int/lib.no +++ b/src/stdlib/native/int/lib.no @@ -291,4 +291,22 @@ fn Uint32.mod(self, rhs: Uint32) -> Uint32 { fn Uint64.mod(self, rhs: Uint64) -> Uint64 { let res = self.divmod(rhs); return res[1]; -} \ No newline at end of file +} + +// implement to field +fn Uint8.to_field(self) -> Field { + return self.inner; +} + +fn Uint16.to_field(self) -> Field { + return self.inner; +} + +fn Uint32.to_field(self) -> Field { + return self.inner; +} + +fn Uint64.to_field(self) -> Field { + return self.inner; +} + \ No newline at end of file diff --git a/src/tests/modules.rs b/src/tests/modules.rs index 6517cd4b8..e4b18b35c 100644 --- a/src/tests/modules.rs +++ b/src/tests/modules.rs @@ -31,7 +31,7 @@ use mimoo::liblib; // test a library's type that links to its own type struct Inner { - inner: Field, + pub inner: Field, } struct Lib { diff --git a/src/tests/stdlib/uints/mod.rs b/src/tests/stdlib/uints/mod.rs index efd1ede83..474c017bf 100644 --- a/src/tests/stdlib/uints/mod.rs +++ b/src/tests/stdlib/uints/mod.rs @@ -14,7 +14,7 @@ fn main(pub lhs: Field, rhs: Field) -> Field { let res = lhs_u.{opr}(rhs_u); - return res.inner; + return res.to_field(); } "#; diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 28bd32141..70185ad59 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -11,8 +11,8 @@ use crate::{ imports::FnKind, parser::{ types::{ - is_numeric, FnSig, ForLoopArgument, FunctionDef, ModulePath, Stmt, StmtKind, Symbolic, - Ty, TyKind, + is_numeric, Attribute, AttributeKind, FnSig, ForLoopArgument, FuncOrMethod, + FunctionDef, Stmt, StmtKind, Symbolic, Ty, TyKind, ModulePath, }, CustomType, Expr, ExprKind, Op2, }, @@ -58,7 +58,7 @@ impl FnInfo { #[derive(Deserialize, Serialize, Default, Debug, Clone)] pub struct StructInfo { pub name: String, - pub fields: Vec<(String, TyKind)>, + pub fields: Vec<(String, TyKind, Option)>, pub methods: HashMap, } @@ -119,14 +119,36 @@ impl TypeChecker { .expect("this struct is not defined, or you're trying to access a field of a struct defined in a third-party library (TODO: better error)"); // find field type - let res = struct_info + if let Some((_, field_typ, attribute)) = struct_info .fields .iter() - .find(|(name, _)| name == &rhs.value) - .map(|(_, typ)| typ.clone()); + .find(|(field_name, _, _)| field_name == &rhs.value) + { + // check for the pub attribute + let is_public = matches!( + attribute, + &Some(Attribute { + kind: AttributeKind::Pub, + .. + }) + ); + + // check if we're inside a method of the same struct + let in_method = matches!( + typed_fn_env.current_fn_kind(), + FuncOrMethod::Method(method_struct) if method_struct.name == struct_name + ); - if let Some(res) = res { - Some(ExprTyInfo::new(lhs_node.var_name, res)) + if is_public || in_method { + // allow access + Some(ExprTyInfo::new(lhs_node.var_name, field_typ.clone())) + } else { + // block access + Err(self.error( + ErrorKind::PrivateFieldAccess(struct_name.clone(), rhs.value.clone()), + expr.span, + ))? + } } else { return Err(self.error( ErrorKind::UndefinedField(struct_info.name.clone(), rhs.value.clone()), diff --git a/src/type_checker/fn_env.rs b/src/type_checker/fn_env.rs index a8540cfd6..3c6b8eb7a 100644 --- a/src/type_checker/fn_env.rs +++ b/src/type_checker/fn_env.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::{ constants::Span, error::{Error, ErrorKind, Result}, - parser::types::TyKind, + parser::types::{FuncOrMethod, TyKind}, }; /// Some type information on local variables that we want to track in the [TypedFnEnv] environment. @@ -39,7 +39,7 @@ impl TypeInfo { } /// The environment we use to type check functions. -#[derive(Default, Debug, Clone)] +#[derive(Debug, Clone)] pub struct TypedFnEnv { /// The current nesting level. /// Starting at 0 (top level), and increasing as we go into a block. @@ -55,12 +55,21 @@ pub struct TypedFnEnv { /// Determines if forloop variables are allowed to be accessed. forbid_forloop_scope: bool, + + /// The kind of function we're currently type checking + current_fn_kind: FuncOrMethod, } impl TypedFnEnv { - /// Creates a new TypeEnv - pub fn new() -> Self { - Self::default() + /// Creates a new TypeEnv with the given function kind + pub fn new(fn_kind: &FuncOrMethod) -> Self { + Self { + current_scope: 0, + vars: HashMap::new(), + forloop_scopes: Vec::new(), + forbid_forloop_scope: false, + current_fn_kind: fn_kind.clone(), + } } /// Enters a scoped block. @@ -182,4 +191,8 @@ impl TypedFnEnv { Ok(None) } } + + pub fn current_fn_kind(&self) -> &FuncOrMethod { + &self.current_fn_kind + } } diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 2b405c477..7982b080b 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -298,8 +298,8 @@ impl TypeChecker { let fields: Vec<_> = fields .iter() .map(|field| { - let (name, typ) = field; - (name.value.clone(), typ.kind.clone()) + let (name, typ, attribute) = field; + (name.value.clone(), typ.kind.clone(), attribute.clone()) }) .collect(); @@ -329,7 +329,7 @@ impl TypeChecker { // `fn main() { ... }` RootKind::FunctionDef(function) => { // create a new typed fn environment to type check the function - let mut typed_fn_env = TypedFnEnv::default(); + let mut typed_fn_env = TypedFnEnv::new(&function.sig.kind); // if we're expecting a library, this should not be the main function let is_main = function.is_main();