From b4d163cce4c1d04eb7ebc25a504710200ba36e8c Mon Sep 17 00:00:00 2001 From: Tom Anderson Date: Wed, 17 Jul 2024 23:30:46 +1000 Subject: [PATCH] implement `TypeCheckCtx` --- src/compile_pass.rs | 29 ++++- src/ctx/symbol_map.rs | 1 - src/main.rs | 9 +- src/repr/ast/typed.rs | 7 +- src/repr/identifier/function.rs | 3 + src/repr/identifier/mod.rs | 3 + src/repr/ir/function.rs | 7 +- src/repr/ir/triple/mod.rs | 4 +- src/repr/mod.rs | 1 + src/stage/codegen/llvm/mod.rs | 21 +-- src/stage/lower_ir/ctx.rs | 15 +-- src/stage/lower_ir/lowering/mod.rs | 9 +- src/stage/type_check/ctx.rs | 142 +++++++++++++++++++++ src/stage/type_check/expression/block.rs | 31 +++-- src/stage/type_check/expression/call.rs | 15 +-- src/stage/type_check/expression/ident.rs | 62 ++++----- src/stage/type_check/expression/if_else.rs | 8 +- src/stage/type_check/expression/infix.rs | 25 ++-- src/stage/type_check/expression/mod.rs | 18 ++- src/stage/type_check/function.rs | 20 ++- src/stage/type_check/mod.rs | 33 +---- src/stage/type_check/program.rs | 20 ++- src/stage/type_check/statement.rs | 84 ++++++++---- tests/programs.rs | 9 +- 24 files changed, 380 insertions(+), 196 deletions(-) delete mode 100644 src/ctx/symbol_map.rs create mode 100644 src/repr/identifier/function.rs create mode 100644 src/repr/identifier/mod.rs create mode 100644 src/stage/type_check/ctx.rs diff --git a/src/compile_pass.rs b/src/compile_pass.rs index 856ed5a..20b8635 100644 --- a/src/compile_pass.rs +++ b/src/compile_pass.rs @@ -1,11 +1,22 @@ +use std::collections::HashMap; + +use index_vec::IndexVec; + use crate::{ - stage::parse::ParseCtx, + repr::identifier::FunctionIdx, + stage::{ + parse::ParseCtx, + type_check::{FunctionSignature, TypeCheckCtx}, + }, util::symbol_map::{interner_symbol_map::*, SymbolMap}, }; #[derive(Default)] pub struct CompilePass { symbols: InternerSymbolMap, + + function_signatures: IndexVec, + function_symbols: HashMap, } impl SymbolMap for CompilePass { @@ -28,3 +39,19 @@ impl SymbolMap for CompilePass { } impl ParseCtx for CompilePass {} + +impl TypeCheckCtx for CompilePass { + fn register_function(&mut self, symbol: Symbol, signature: FunctionSignature) -> FunctionIdx { + let idx = self.function_signatures.push(signature); + self.function_symbols.insert(symbol, idx); + idx + } + + fn get_function(&self, idx: FunctionIdx) -> FunctionSignature { + self.function_signatures[idx].clone() + } + + fn lookup_function_symbol(&self, symbol: Symbol) -> Option { + self.function_symbols.get(&symbol).cloned() + } +} diff --git a/src/ctx/symbol_map.rs b/src/ctx/symbol_map.rs deleted file mode 100644 index 6d87067..0000000 --- a/src/ctx/symbol_map.rs +++ /dev/null @@ -1 +0,0 @@ -use crate::util::symbol_map::InternerSymbolMap; diff --git a/src/main.rs b/src/main.rs index c8569e2..7775ed3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,7 +26,7 @@ fn main() -> int { } }; - let program = match program.ty_solve() { + let program = match program.ty_solve(&mut ctx) { Ok(program) => program, Err(e) => { eprintln!("{e}"); @@ -34,14 +34,11 @@ fn main() -> int { } }; - let ir_ctx = ir::lower(program); - let _main = ir_ctx - .function_for_name("main") - .expect("main function to exist"); + let main = program.main.name; + let ir_ctx = ir::lower(program); let ctx = inkwell::context::Context::create(); - let main = ir_ctx.symbol_map.get("main").unwrap(); let function_ids = ir_ctx.functions.keys().cloned().collect::>(); let mut llvm_pass = Pass::new(&ctx, ir_ctx); function_ids.into_iter().for_each(|function| { diff --git a/src/repr/ast/typed.rs b/src/repr/ast/typed.rs index ccc79e4..49ee0a6 100644 --- a/src/repr/ast/typed.rs +++ b/src/repr/ast/typed.rs @@ -1,6 +1,6 @@ use crate::{ - generate_ast, repr::ty::Ty, stage::lower_ir::FunctionIdx, - util::symbol_map::interner_symbol_map::Symbol, + generate_ast, + repr::{identifier::FunctionIdx, ty::Ty}, }; #[derive(Clone, Debug)] @@ -11,6 +11,5 @@ pub struct TyInfo { generate_ast! { TyInfo: TyInfo, - // TODO: Something else for this - FnIdentifier: Symbol + FnIdentifier: FunctionIdx } diff --git a/src/repr/identifier/function.rs b/src/repr/identifier/function.rs new file mode 100644 index 0000000..32f6c87 --- /dev/null +++ b/src/repr/identifier/function.rs @@ -0,0 +1,3 @@ +use index_vec::define_index_type; + +define_index_type! {pub struct FunctionIdx = usize;} diff --git a/src/repr/identifier/mod.rs b/src/repr/identifier/mod.rs new file mode 100644 index 0000000..2d74470 --- /dev/null +++ b/src/repr/identifier/mod.rs @@ -0,0 +1,3 @@ +mod function; + +pub use function::FunctionIdx; diff --git a/src/repr/ir/function.rs b/src/repr/ir/function.rs index 4b94a2a..6137d6c 100644 --- a/src/repr/ir/function.rs +++ b/src/repr/ir/function.rs @@ -3,8 +3,8 @@ use std::collections::HashSet; use index_vec::IndexVec; use crate::{ - repr::ast::typed as ast, - stage::{lower_ir::FunctionIdx, type_check::FunctionSignature}, + repr::{ast::typed as ast, identifier::FunctionIdx}, + stage::type_check::FunctionSignature, util::symbol_map::interner_symbol_map::Symbol, }; @@ -16,8 +16,7 @@ index_vec::define_index_type! { #[derive(Debug, Clone)] pub struct Function { - // TODO: This should be something else - pub symbol: Symbol, + pub symbol: FunctionIdx, pub signature: FunctionSignature, pub basic_blocks: IndexVec, pub scope: HashSet, diff --git a/src/repr/ir/triple/mod.rs b/src/repr/ir/triple/mod.rs index 5292c76..3894906 100644 --- a/src/repr/ir/triple/mod.rs +++ b/src/repr/ir/triple/mod.rs @@ -1,6 +1,6 @@ use index_vec::define_index_type; -use crate::{stage::lower_ir::FunctionIdx, util::symbol_map::interner_symbol_map::Symbol}; +use crate::{repr::identifier::FunctionIdx, util::symbol_map::interner_symbol_map::Symbol}; use super::{BasicBlockIdx, Value}; @@ -25,7 +25,7 @@ pub enum Triple { /// Jump to the corresponding basic block. Jump(BasicBlockIdx), /// Call the corresponding function. - Call(Symbol), + Call(FunctionIdx), /// Return with the provided value. Return(Value), /// Assign some symbol to some value. diff --git a/src/repr/mod.rs b/src/repr/mod.rs index 4035d42..188c8f8 100644 --- a/src/repr/mod.rs +++ b/src/repr/mod.rs @@ -1,4 +1,5 @@ pub mod ast; +pub mod identifier; pub mod ir; pub mod token; pub mod ty; diff --git a/src/stage/codegen/llvm/mod.rs b/src/stage/codegen/llvm/mod.rs index 8030e12..0e194ef 100644 --- a/src/stage/codegen/llvm/mod.rs +++ b/src/stage/codegen/llvm/mod.rs @@ -12,8 +12,11 @@ use inkwell::{ }; use crate::{ - repr::ir::{BinaryOp, Triple, UnaryOp, Value}, - stage::lower_ir::{FunctionIdx, IRCtx}, + repr::{ + identifier::FunctionIdx, + ir::{BinaryOp, Triple, UnaryOp, Value}, + }, + stage::lower_ir::IRCtx, util::symbol_map::interner_symbol_map::Symbol, }; @@ -25,8 +28,8 @@ pub struct Pass<'ctx> { module: Module<'ctx>, symbols: HashMap>, - basic_blocks: HashMap<(Symbol, BasicBlockIdx), inkwell::basic_block::BasicBlock<'ctx>>, - pub function_values: HashMap>, + basic_blocks: HashMap<(FunctionIdx, BasicBlockIdx), inkwell::basic_block::BasicBlock<'ctx>>, + pub function_values: HashMap>, } impl<'ctx> Pass<'ctx> { @@ -60,23 +63,23 @@ impl<'ctx> Pass<'ctx> { } /// Compile the provided function, returning the LLVM handle to it. - pub fn compile(&mut self, function_id: Symbol) -> FunctionValue<'ctx> { + pub fn compile(&mut self, function_idx: FunctionIdx) -> FunctionValue<'ctx> { let function = self .ir_ctx .functions - .get(&function_id) + .get(&function_idx) .expect("function to exist"); let builder = self.llvm_ctx.create_builder(); let fn_value = self .function_values - .get(&function_id) + .get(&function_idx) .expect("function to exist") .to_owned(); // BUG: This won't work with multiple functions - let entry_bb = (function_id, BasicBlockIdx::new(0)); + let entry_bb = (function_idx, BasicBlockIdx::new(0)); let entry = *self.basic_blocks.entry(entry_bb).or_insert_with(|| { self.llvm_ctx .append_basic_block(fn_value, &format!("bb_{:?}_{:?}", entry_bb.0, entry_bb.1)) @@ -108,7 +111,7 @@ impl<'ctx> Pass<'ctx> { fn compile_basic_block( &mut self, function: &FunctionValue<'ctx>, - basic_block_id: (Symbol, BasicBlockIdx), + basic_block_id: (FunctionIdx, BasicBlockIdx), ) { let bb = *self.basic_blocks.entry(basic_block_id).or_insert_with(|| { self.llvm_ctx diff --git a/src/stage/lower_ir/ctx.rs b/src/stage/lower_ir/ctx.rs index 62ee17a..d9e553e 100644 --- a/src/stage/lower_ir/ctx.rs +++ b/src/stage/lower_ir/ctx.rs @@ -1,18 +1,14 @@ use std::collections::HashMap; -use index_vec::{define_index_type, IndexVec}; - use crate::{ - repr::ir::Function, - util::symbol_map::interner_symbol_map::{InternerSymbolMap, Symbol}, + repr::{identifier::FunctionIdx, ir::Function}, + util::symbol_map::interner_symbol_map::InternerSymbolMap, }; -define_index_type! {pub struct FunctionIdx = usize;} - #[derive(Default, Clone, Debug)] pub struct IRCtx { /// Map of function symbol to the basic block entry point - pub functions: HashMap, + pub functions: HashMap, pub symbol_map: InternerSymbolMap, } @@ -23,9 +19,4 @@ impl IRCtx { ..Default::default() } } - - pub fn function_for_name(&self, s: &str) -> Option { - let symbol = self.symbol_map.get(s)?; - self.functions.get(&symbol).cloned() - } } diff --git a/src/stage/lower_ir/lowering/mod.rs b/src/stage/lower_ir/lowering/mod.rs index 1314445..b6b8902 100644 --- a/src/stage/lower_ir/lowering/mod.rs +++ b/src/stage/lower_ir/lowering/mod.rs @@ -1,9 +1,6 @@ -use crate::{ - repr::{ast::typed as ast, ir::*}, - util::symbol_map::interner_symbol_map::Symbol, -}; +use crate::repr::{ast::typed as ast, identifier::FunctionIdx, ir::*}; -use super::{FunctionIdx, IRCtx}; +use super::IRCtx; pub fn lower(program: ast::Program) -> IRCtx { let mut ir = IRCtx::new(program.symbols); @@ -39,7 +36,7 @@ impl FunctionLoweringCtx<'_> { } } -fn lower_function(ir_ctx: &mut IRCtx, function: ast::Function) -> Symbol { +fn lower_function(ir_ctx: &mut IRCtx, function: ast::Function) -> FunctionIdx { let mut repr_function = Function::new(&function); // Insert entry basic block diff --git a/src/stage/type_check/ctx.rs b/src/stage/type_check/ctx.rs new file mode 100644 index 0000000..97a70c4 --- /dev/null +++ b/src/stage/type_check/ctx.rs @@ -0,0 +1,142 @@ +use index_vec::{define_index_type, IndexVec}; + +use crate::{ + repr::{identifier::FunctionIdx, ty::Ty}, + util::symbol_map::interner_symbol_map::Symbol, +}; + +use super::FunctionSignature; + +pub trait TypeCheckCtx { + /// Register a function's signature and associated symbol, to produce a unique identifier for the function. + fn register_function(&mut self, symbol: Symbol, signature: FunctionSignature) -> FunctionIdx; + + /// Get the signature associated with a function identifier. + fn get_function(&self, idx: FunctionIdx) -> FunctionSignature; + + /// Attempt to look up a symbol, returning the associated function's identifier if it exists. + fn lookup_function_symbol(&self, symbol: Symbol) -> Option; +} + +#[cfg(test)] +mockall::mock! { + pub TypeCheckCtx {} + + impl TypeCheckCtx for TypeCheckCtx { + fn register_function(&mut self, symbol: Symbol, signature: FunctionSignature) -> FunctionIdx; + fn get_function(&self, idx: FunctionIdx) -> FunctionSignature; + fn lookup_function_symbol(&self, symbol: Symbol) -> Option; + } +} + +define_index_type! {pub struct ScopeIdx = usize;} +define_index_type! {pub struct BindingIdx = usize;} + +/// A binding within a specific scope. +pub struct ScopedBinding(ScopeIdx, BindingIdx); + +pub struct ScopePart { + /// Indicates that this scope (and potentially a descendant) is active. + active: bool, + + /// All of the bindings present within this scope. + bindings: IndexVec, +} + +impl ScopePart { + /// Create a new scope part, and automatically mark it as active. + pub fn new() -> Self { + Self { + active: true, + bindings: IndexVec::new(), + } + } + + pub fn exit(&mut self) { + self.active = false; + } + + pub fn add(&mut self, symbol: Symbol, ty: Ty) -> BindingIdx { + self.bindings.push((symbol, ty)) + } +} + +pub struct Scope { + scopes: IndexVec, +} + +impl Scope { + /// Create a new instance, and automatically enter the base scope. + pub fn new() -> Self { + let mut scope = Self { + scopes: IndexVec::new(), + }; + + // Automatically enter the first scope + scope.enter(); + + scope + } + + /// Enter a new scope. + pub fn enter(&mut self) -> ScopeIdx { + self.scopes.push(ScopePart::new()) + } + + /// Exit the most recent scope. + pub fn leave(&mut self) { + let Some(scope) = self.scopes.last_mut() else { + return; + }; + + scope.exit(); + } + + /// Register a symbol and associated type, and produce a unique binding for it. + pub fn register(&mut self, symbol: Symbol, ty: Ty) -> ScopedBinding { + // Fetch the currently active scope + let active_scope_idx = self.active_scope(); + let active_scope = &mut self.scopes[active_scope_idx]; + + // Register the binding type and symbol + let binding_idx = active_scope.add(symbol, ty); + + ScopedBinding(active_scope_idx, binding_idx) + } + + /// Attempt to retrieve a binding and type for the provided symbol. Will only search for items that are in scope. + pub fn resolve(&mut self, symbol: Symbol) -> Option<(ScopedBinding, Ty)> { + // Run through all possible scopes + self.scopes + .iter_enumerated() + // Only include active scopes + .filter(|(_, scope)| scope.active) + // Extract all of the bindings from the scopes + .flat_map(|(scope_idx, scope)| { + scope + .bindings + .iter_enumerated() + // Only consider bindings that match the symbol + .filter(|(_, (test_symbol, _))| *test_symbol == symbol) + // Generate the scoped binding representation to track which scope the binding originated from + .map(move |(binding_idx, (_, ty))| (ScopedBinding(scope_idx, binding_idx), *ty)) + }) + .next_back() + } + + /// Find the currently activated scope identifier. + fn active_scope(&self) -> ScopeIdx { + self.scopes + .iter_enumerated() + .rev() + .find(|(_, scope)| scope.active) + .map(|(idx, _)| idx) + .expect("should always be at least one scope active") + } +} + +impl Default for Scope { + fn default() -> Self { + Self::new() + } +} diff --git a/src/stage/type_check/expression/block.rs b/src/stage/type_check/expression/block.rs index 9b030ff..2f5a9cc 100644 --- a/src/stage/type_check/expression/block.rs +++ b/src/stage/type_check/expression/block.rs @@ -1,11 +1,17 @@ +use ctx::{Scope, TypeCheckCtx}; + use super::*; impl parse_ast::Block { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { + pub fn ty_solve( + self, + ctx: &mut impl TypeCheckCtx, + scope: &mut Scope, + ) -> Result { let statements = self .statements .into_iter() - .map(|statement| statement.ty_solve(ctx)) + .map(|statement| statement.ty_solve(ctx, scope)) .collect::, _>>()?; let ty_info = TyInfo::try_from(( @@ -43,15 +49,10 @@ mod test { use crate::{ repr::{ast::untyped::*, ty::Ty}, + stage::type_check::ctx::{MockTypeCheckCtx, Scope}, util::source::Span, }; - use super::expression::{FnCtx, TyError, TyInfo}; - - fn run(b: Block) -> Result { - Ok(b.ty_solve(&mut FnCtx::mock())?.ty_info) - } - #[test] fn ty_check_block() { // { @@ -76,7 +77,10 @@ mod test { Span::default(), ); - let ty_info = run(b).unwrap(); + let ty_info = b + .ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()) + .unwrap() + .ty_info; assert_eq!(ty_info.ty, Ty::Unit); assert_eq!(ty_info.return_ty, Some(Ty::Int)); @@ -108,7 +112,9 @@ mod test { Span::default(), ); - assert!(run(b).is_err()); + let result = b.ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()); + + assert!(result.is_err()); } #[test] @@ -133,7 +139,10 @@ mod test { Span::default(), ); - let ty_info = run(b).unwrap(); + let ty_info = b + .ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()) + .unwrap() + .ty_info; assert_eq!(ty_info.ty, Ty::Unit); assert_eq!(ty_info.return_ty, None); diff --git a/src/stage/type_check/expression/call.rs b/src/stage/type_check/expression/call.rs index e843115..a3068e7 100644 --- a/src/stage/type_check/expression/call.rs +++ b/src/stage/type_check/expression/call.rs @@ -1,21 +1,19 @@ use super::*; impl parse_ast::Call { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { + pub fn ty_solve(self, ctx: &mut impl TypeCheckCtx, scope: &mut Scope) -> Result { // Determine the types of all the arguments let args = self .args .into_iter() - .map(|arg| arg.ty_solve(ctx)) + .map(|arg| arg.ty_solve(ctx, scope)) .collect::, _>>()?; - let ty_ctx = ctx.ty_ctx.borrow(); - // Compare the arguments to the function types - let signature = ty_ctx - .function_signatures - .get(&self.name) + let function_idx = ctx + .lookup_function_symbol(self.name) .ok_or(TyError::SymbolNotFound(self.name))?; + let signature = ctx.get_function(function_idx); if args.len() != signature.arguments.len() { // TODO: Make new type error for when the function call has too many arguments @@ -39,8 +37,7 @@ impl parse_ast::Call { // Ensure all the return types from the arguments are correct args.iter().map(|arg| arg.get_ty_info().return_ty), ))?, - // TODO: Resolve a symbol into a function - name: todo!(), //self.name, + name: function_idx, args, span: self.span, }) diff --git a/src/stage/type_check/expression/ident.rs b/src/stage/type_check/expression/ident.rs index 7ffe705..8101312 100644 --- a/src/stage/type_check/expression/ident.rs +++ b/src/stage/type_check/expression/ident.rs @@ -1,13 +1,16 @@ use super::*; impl parse_ast::Ident { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { + pub fn ty_solve( + self, + _ctx: &mut impl TypeCheckCtx, + scope: &mut Scope, + ) -> Result { Ok(Ident { ty_info: TyInfo { - ty: ctx - .scope - .get(&self.name) - .cloned() + ty: scope + .resolve(self.name) + .map(|(_, ty)| ty) .ok_or(TyError::SymbolNotFound(self.name))?, return_ty: None, }, @@ -19,31 +22,32 @@ impl parse_ast::Ident { #[cfg(test)] mod test_ident { - use string_interner::Symbol; - use crate::{repr::ast::untyped::Ident, util::source::Span}; - - use super::expression::{FnCtx, Ty, TyError, TyInfo}; + use crate::{ + repr::ast::untyped::Ident, + stage::type_check::ctx::{MockTypeCheckCtx, Scope}, + util::source::Span, + }; - fn run(i: Ident, ident: bool) -> Result { - let mut fn_ctx = FnCtx::mock(); - if ident { - fn_ctx - .scope - .insert(Symbol::try_from_usize(0).unwrap(), Ty::Int); - } - - Ok(i.ty_solve(&mut fn_ctx)?.ty_info) - } + use super::expression::Ty; #[test] fn ident_present() { - let ty_info = run( - Ident::new(Symbol::try_from_usize(0).unwrap(), Span::default()), - true, - ) - .unwrap(); + // Set up a reference symbol + let symbol = Symbol::try_from_usize(0).unwrap(); + + // Create a scope and add the symbol to it + let mut scope = Scope::new(); + scope.register(symbol, Ty::Int); + + let i = Ident::new(symbol, Span::default()); + + // Run the type solve + let ty_info = i + .ty_solve(&mut MockTypeCheckCtx::new(), &mut scope) + .unwrap() + .ty_info; assert_eq!(ty_info.ty, Ty::Int); assert_eq!(ty_info.return_ty, None); @@ -51,10 +55,10 @@ mod test_ident { #[test] fn ident_infer_missing() { - assert!(run( - Ident::new(Symbol::try_from_usize(0).unwrap(), Span::default()), - false, - ) - .is_err()); + let i = Ident::new(Symbol::try_from_usize(0).unwrap(), Span::default()); + + let result = i.ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()); + + assert!(result.is_err()); } } diff --git a/src/stage/type_check/expression/if_else.rs b/src/stage/type_check/expression/if_else.rs index a0e0c35..1c57e96 100644 --- a/src/stage/type_check/expression/if_else.rs +++ b/src/stage/type_check/expression/if_else.rs @@ -1,18 +1,18 @@ use super::*; impl parse_ast::If { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { + pub fn ty_solve(self, ctx: &mut impl TypeCheckCtx, scope: &mut Scope) -> Result { // Make sure the condition is correctly typed - let condition = self.condition.ty_solve(ctx)?; + let condition = self.condition.ty_solve(ctx, scope)?; let condition_ty = condition.get_ty_info(); if !matches!(condition_ty.ty, Ty::Boolean) { return Err(TyError::Mismatch(Ty::Boolean, condition_ty.ty)); } - let success = self.success.ty_solve(ctx)?; + let success = self.success.ty_solve(ctx, scope)?; let otherwise = self .otherwise - .map(|otherwise| otherwise.ty_solve(ctx)) + .map(|otherwise| otherwise.ty_solve(ctx, scope)) .transpose()?; let ty_info = TyInfo::try_from(( diff --git a/src/stage/type_check/expression/infix.rs b/src/stage/type_check/expression/infix.rs index 58ec1d5..9837af6 100644 --- a/src/stage/type_check/expression/infix.rs +++ b/src/stage/type_check/expression/infix.rs @@ -14,9 +14,13 @@ impl InfixOperation { } impl parse_ast::Infix { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { - let left = self.left.ty_solve(ctx)?; - let right = self.right.ty_solve(ctx)?; + pub fn ty_solve( + self, + ctx: &mut impl TypeCheckCtx, + scope: &mut Scope, + ) -> Result { + let left = self.left.ty_solve(ctx, scope)?; + let right = self.right.ty_solve(ctx, scope)?; let left_ty_info = left.get_ty_info(); let right_ty_info = right.get_ty_info(); @@ -42,15 +46,10 @@ impl parse_ast::Infix { mod test_infix { use crate::{ repr::{ast::untyped::*, ty::Ty}, + stage::type_check::ctx::{MockTypeCheckCtx, Scope}, util::source::Span, }; - use super::expression::{FnCtx, TyError, TyInfo}; - - fn run(i: Infix) -> Result { - Ok(i.ty_solve(&mut FnCtx::mock())?.ty_info) - } - #[test] fn infix_same() { // 0 + 0 @@ -61,7 +60,10 @@ mod test_infix { Span::default(), ); - let ty_info = run(infix).unwrap(); + let ty_info = infix + .ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()) + .unwrap() + .ty_info; assert_eq!(ty_info.ty, Ty::Int); assert_eq!(ty_info.return_ty, None); } @@ -75,6 +77,7 @@ mod test_infix { Span::default(), ); - assert!(run(infix).is_err()); + let result = infix.ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()); + assert!(result.is_err()); } } diff --git a/src/stage/type_check/expression/mod.rs b/src/stage/type_check/expression/mod.rs index 14388ce..4a92705 100644 --- a/src/stage/type_check/expression/mod.rs +++ b/src/stage/type_check/expression/mod.rs @@ -1,3 +1,5 @@ +use ctx::{Scope, TypeCheckCtx}; + use super::*; mod block; @@ -9,15 +11,19 @@ mod infix; mod integer; impl parse_ast::Expression { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { + pub fn ty_solve( + self, + ctx: &mut impl TypeCheckCtx, + scope: &mut Scope, + ) -> Result { Ok(match self { - parse_ast::Expression::Infix(e) => Expression::Infix(e.ty_solve(ctx)?), + parse_ast::Expression::Infix(e) => Expression::Infix(e.ty_solve(ctx, scope)?), parse_ast::Expression::Integer(e) => Expression::Integer(e.ty_solve()?), parse_ast::Expression::Boolean(e) => Expression::Boolean(e.ty_solve()?), - parse_ast::Expression::Ident(e) => Expression::Ident(e.ty_solve(ctx)?), - parse_ast::Expression::Block(e) => Expression::Block(e.ty_solve(ctx)?), - parse_ast::Expression::If(e) => Expression::If(e.ty_solve(ctx)?), - parse_ast::Expression::Call(e) => Expression::Call(e.ty_solve(ctx)?), + parse_ast::Expression::Ident(e) => Expression::Ident(e.ty_solve(ctx, scope)?), + parse_ast::Expression::Block(e) => Expression::Block(e.ty_solve(ctx, scope)?), + parse_ast::Expression::If(e) => Expression::If(e.ty_solve(ctx, scope)?), + parse_ast::Expression::Call(e) => Expression::Call(e.ty_solve(ctx, scope)?), }) } } diff --git a/src/stage/type_check/function.rs b/src/stage/type_check/function.rs index e287af7..1affd58 100644 --- a/src/stage/type_check/function.rs +++ b/src/stage/type_check/function.rs @@ -1,16 +1,22 @@ +use ctx::{Scope, TypeCheckCtx}; + use super::*; impl parse_ast::Function { - pub fn ty_solve(self, ctx: Rc>) -> Result { - // TODO: Register this symbol to get a function identifier for it - let identifier = self.name; + pub fn ty_solve(self, ctx: &mut impl TypeCheckCtx) -> Result { + let identifier = ctx + .lookup_function_symbol(self.name) + .expect("function must already be registered"); - // Set up a fn ctx just for this function - let mut ctx = FnCtx::new(ctx); + // Create the scope for this function + let mut scope = Scope::new(); - // TODO: Need to insert parameters into scope + // Add all of the function's parameters into the scope so they're accessible + self.parameters.iter().for_each(|(symbol, ty)| { + scope.register(*symbol, *ty); + }); - let body = self.body.ty_solve(&mut ctx)?; + let body = self.body.ty_solve(ctx, &mut scope)?; // If the body contains any return statements, they must match the annotated return statement if let Some(return_ty) = body.ty_info.return_ty { diff --git a/src/stage/type_check/mod.rs b/src/stage/type_check/mod.rs index 4c48e63..9d309cc 100644 --- a/src/stage/type_check/mod.rs +++ b/src/stage/type_check/mod.rs @@ -1,10 +1,10 @@ +mod ctx; mod expression; mod function; mod program; mod statement; use itertools::Itertools; -use std::{cell::RefCell, collections::HashMap, rc::Rc}; use crate::repr::ty::Ty; use crate::{ @@ -12,6 +12,8 @@ use crate::{ util::symbol_map::interner_symbol_map::Symbol, }; +pub use ctx::TypeCheckCtx; + #[derive(Clone, Debug)] pub struct FunctionSignature { arguments: Vec, @@ -27,35 +29,6 @@ impl From<&base_ast::Function> for F } } -#[derive(Default)] -pub struct TyCtx { - function_signatures: HashMap, -} - -impl TyCtx { - pub fn mock() -> Self { - Self::default() - } -} - -pub struct FnCtx { - ty_ctx: Rc>, - scope: HashMap, -} - -impl FnCtx { - pub fn new(ty_ctx: Rc>) -> Self { - Self { - ty_ctx, - scope: HashMap::new(), - } - } - - pub fn mock() -> Self { - Self::new(Rc::new(RefCell::new(TyCtx::mock()))) - } -} - #[derive(Debug, thiserror::Error)] pub enum TyError { #[error("mismatched types: {0:?} and {1:?}")] diff --git a/src/stage/type_check/program.rs b/src/stage/type_check/program.rs index 827efff..ac46106 100644 --- a/src/stage/type_check/program.rs +++ b/src/stage/type_check/program.rs @@ -1,30 +1,28 @@ +use ctx::TypeCheckCtx; + use super::*; impl parse_ast::Program { - pub fn ty_solve(self) -> Result { + pub fn ty_solve(self, ctx: &mut impl TypeCheckCtx) -> Result { // Main function must return int if self.main.return_ty != Ty::Int { return Err(TyError::Mismatch(Ty::Int, self.main.return_ty)); } - let mut ctx = TyCtx::default(); + ctx.register_function(self.main.name, FunctionSignature::from(&self.main)); // Pre-register all functions - ctx.function_signatures.extend( - self.functions - .iter() - .map(|function| (function.name, FunctionSignature::from(function))), - ); - - let ctx = Rc::new(RefCell::new(ctx)); + self.functions.iter().for_each(|function| { + ctx.register_function(function.name, FunctionSignature::from(function)); + }); // Make sure the type of the function is correct - let main = self.main.ty_solve(Rc::clone(&ctx))?; + let main = self.main.ty_solve(ctx)?; let functions = self .functions .into_iter() - .map(|function| function.ty_solve(Rc::clone(&ctx))) + .map(|function| function.ty_solve(ctx)) .collect::, _>>()?; Ok(Program { diff --git a/src/stage/type_check/statement.rs b/src/stage/type_check/statement.rs index fd1012e..07a004d 100644 --- a/src/stage/type_check/statement.rs +++ b/src/stage/type_check/statement.rs @@ -1,19 +1,29 @@ +use ctx::{Scope, TypeCheckCtx}; + use super::*; impl parse_ast::Statement { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { + pub fn ty_solve( + self, + ctx: &mut impl TypeCheckCtx, + scope: &mut Scope, + ) -> Result { Ok(match self { - parse_ast::Statement::Return(s) => Statement::Return(s.ty_solve(ctx)?), - parse_ast::Statement::Let(s) => Statement::Let(s.ty_solve(ctx)?), - parse_ast::Statement::Expression(s) => Statement::Expression(s.ty_solve(ctx)?), + parse_ast::Statement::Return(s) => Statement::Return(s.ty_solve(ctx, scope)?), + parse_ast::Statement::Let(s) => Statement::Let(s.ty_solve(ctx, scope)?), + parse_ast::Statement::Expression(s) => Statement::Expression(s.ty_solve(ctx, scope)?), }) } } impl parse_ast::LetStatement { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { + pub fn ty_solve( + self, + ctx: &mut impl TypeCheckCtx, + scope: &mut Scope, + ) -> Result { // Work out what the type of the value is - let value = self.value.ty_solve(ctx)?; + let value = self.value.ty_solve(ctx, scope)?; // Make sure the value type matches what the statement was annotated with if let Some(ty) = self.ty_info { @@ -24,7 +34,7 @@ impl parse_ast::LetStatement { } // Record the type - ctx.scope.insert(self.name, value.get_ty_info().ty); + scope.register(self.name, value.get_ty_info().ty); Ok(LetStatement { ty_info: TyInfo { @@ -39,8 +49,12 @@ impl parse_ast::LetStatement { } impl parse_ast::ReturnStatement { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { - let value = self.value.ty_solve(ctx)?; + pub fn ty_solve( + self, + ctx: &mut impl TypeCheckCtx, + scope: &mut Scope, + ) -> Result { + let value = self.value.ty_solve(ctx, scope)?; Ok(ReturnStatement { ty_info: TyInfo::try_from(( @@ -54,8 +68,12 @@ impl parse_ast::ReturnStatement { } impl parse_ast::ExpressionStatement { - pub fn ty_solve(self, ctx: &mut FnCtx) -> Result { - let expression = self.expression.ty_solve(ctx)?; + pub fn ty_solve( + self, + ctx: &mut impl TypeCheckCtx, + scope: &mut Scope, + ) -> Result { + let expression = self.expression.ty_solve(ctx, scope)?; // Expression statement has same type as the underlying expression let mut ty_info = expression.get_ty_info().clone(); @@ -74,29 +92,27 @@ impl parse_ast::ExpressionStatement { #[cfg(test)] mod test_statement { - use std::collections::HashMap; use string_interner::Symbol as _; use crate::{ - repr::ast::untyped::*, util::source::Span, util::symbol_map::interner_symbol_map::Symbol, + repr::ast::untyped::*, + stage::type_check::ctx::{MockTypeCheckCtx, Scope}, + util::{source::Span, symbol_map::interner_symbol_map::Symbol}, }; - use super::{FnCtx, Ty, TyError, TyInfo}; - - fn run(s: Statement) -> Result<(TyInfo, HashMap), TyError> { - let mut fn_ctx = FnCtx::mock(); - let ty = s.ty_solve(&mut fn_ctx)?.get_ty_info().clone(); - - Ok((ty, fn_ctx.scope.clone())) - } + use super::Ty; #[test] fn return_statement() { // return 0; let s = Statement::_return(Expression::integer(0, Span::default()), Span::default()); - let (ty_info, _) = run(s).unwrap(); + let ty_info = s + .ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()) + .unwrap() + .get_ty_info() + .clone(); assert_eq!(ty_info.ty, Ty::Unit); assert_eq!(ty_info.return_ty, Some(Ty::Int)); @@ -111,13 +127,19 @@ mod test_statement { Span::default(), ); - let (ty_info, symbols) = run(s).unwrap(); + let mut scope = Scope::new(); + + let ty_info = s + .ty_solve(&mut MockTypeCheckCtx::new(), &mut scope) + .unwrap() + .get_ty_info() + .clone(); assert_eq!(ty_info.ty, Ty::Unit); assert_eq!(ty_info.return_ty, None); assert_eq!( - symbols.get(&Symbol::try_from_usize(0).unwrap()).cloned(), - Some(Ty::Int) + scope.resolve(Symbol::try_from_usize(0).unwrap()).unwrap().1, + Ty::Int ); } @@ -130,7 +152,11 @@ mod test_statement { Span::default(), ); - let (ty_info, _) = run(s).unwrap(); + let ty_info = s + .ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()) + .unwrap() + .get_ty_info() + .clone(); assert_eq!(ty_info.ty, Ty::Unit); assert_eq!(ty_info.return_ty, None); @@ -145,7 +171,11 @@ mod test_statement { Span::default(), ); - let (ty_info, _) = run(s).unwrap(); + let ty_info = s + .ty_solve(&mut MockTypeCheckCtx::new(), &mut Scope::new()) + .unwrap() + .get_ty_info() + .clone(); assert_eq!(ty_info.ty, Ty::Int); assert_eq!(ty_info.return_ty, None); diff --git a/tests/programs.rs b/tests/programs.rs index dceaea0..de5968d 100644 --- a/tests/programs.rs +++ b/tests/programs.rs @@ -109,7 +109,7 @@ fn programs(#[case] expected: i64, #[case] source: &'static str) { } }; - let program = match program.ty_solve() { + let program = match program.ty_solve(&mut ctx) { Ok(program) => program, Err(e) => { eprintln!("{e}"); @@ -117,12 +117,9 @@ fn programs(#[case] expected: i64, #[case] source: &'static str) { } }; - let ir_ctx = lower_ir::lower(program); - let main = ir_ctx - .symbol_map - .get("main") - .expect("main function to exist"); + let main = program.main.name; + let ir_ctx = lower_ir::lower(program); let ctx = inkwell::context::Context::create(); let mut llvm_pass = Pass::new(&ctx, ir_ctx);