From cd1678f52af5ccbf9b1bbd92dd6da4e3c9459b47 Mon Sep 17 00:00:00 2001 From: Tom Anderson Date: Thu, 18 Jul 2024 17:53:29 +1000 Subject: [PATCH] create and implement `IRCtx` --- src/.DS_Store | Bin 6148 -> 0 bytes src/compile_pass.rs | 90 ++++++++++++++++++++- src/main.rs | 19 +++-- src/stage/codegen/llvm/mod.rs | 73 ++++++++--------- src/stage/lower_ir/ctx.rs | 86 ++++++++++++++++---- src/stage/lower_ir/lowering/mod.rs | 122 ++++++++++------------------- tests/programs.rs | 20 +++-- 7 files changed, 266 insertions(+), 144 deletions(-) delete mode 100644 src/.DS_Store diff --git a/src/.DS_Store b/src/.DS_Store deleted file mode 100644 index 22e0357a90efe843c9f49c982e2bee3cc71d7467..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~J!%6%427R!7lt%jx}3%b$PET#pTHLgIN-)O5R#B`j-IEV#+%e(2~Qxsk!Hp2 ze#OoTfbBnCo`4C!hVF`;hnX4k70!6a*XQZ&c745FSn(?G5HT}WCd~G0TOtA?AOa#F z0wS;=0(pvK^SGdA(xZrg2rPque;*p%wU>_6_;fJD2tZx49L9Cb64d4eYA+qBtkA5c z2dh?#F~sZ9PHlBvFCD4b4y)nA>dxj<49&6~)|k+&hbV}E2+RnqGN1hXKhl4j|7R^q zML-1p839`#PKP~Ts?OFQujlpWs``4+sc|{O-A@1$KZ>{XFzz>BP, function_symbols: HashMap, + + ir_functions: HashMap, } impl SymbolMap for CompilePass { @@ -55,3 +58,88 @@ impl TypeCheckCtx for CompilePass { self.function_symbols.get(&symbol).cloned() } } + +impl IRCtx for CompilePass { + type FunctionBuilder = FunctionBuilder; + + fn register_function(&mut self, idx: FunctionIdx, function: ir::Function) { + self.ir_functions.insert(idx, function); + } + + fn all_functions(&self) -> Vec<(FunctionIdx, ir::Function)> { + self.ir_functions + .iter() + .map(|(idx, function)| (*idx, function.clone())) + .collect() + } +} + +pub struct FunctionBuilder { + idx: FunctionIdx, + signature: FunctionSignature, + + basic_blocks: IndexVec, + current_basic_block: ir::BasicBlockIdx, + + scope: Vec<(Symbol, Ty)>, +} + +impl FunctionBuilderTrait for FunctionBuilder { + fn new(function: &Function) -> Self { + let mut basic_blocks = IndexVec::new(); + let current_basic_block = basic_blocks.push(ir::BasicBlock::default()); + + Self { + idx: function.name, + signature: FunctionSignature::from(function), + basic_blocks, + current_basic_block, + scope: Vec::new(), + } + } + + fn register_scoped(&mut self, symbol: Symbol, ty: Ty) { + self.scope.push((symbol, ty)); + } + + fn add_triple(&mut self, triple: ir::Triple) -> ir::TripleRef { + ir::TripleRef { + basic_block: self.current_basic_block, + triple: self.basic_blocks[self.current_basic_block] + .triples + .push(triple), + } + } + + fn current_bb(&self) -> ir::BasicBlockIdx { + self.current_basic_block + } + + fn goto_bb(&mut self, bb: ir::BasicBlockIdx) { + assert!( + bb < self.basic_blocks.len_idx(), + "can only goto basic block if it exists" + ); + self.current_basic_block = bb; + } + + fn push_bb(&mut self) -> ir::BasicBlockIdx { + let idx = self.basic_blocks.push(ir::BasicBlock::default()); + + self.current_basic_block = idx; + + idx + } + + fn build>(self, ctx: &mut Ctx) { + ctx.register_function( + self.idx, + ir::Function { + symbol: self.idx, + signature: self.signature, + basic_blocks: self.basic_blocks, + scope: self.scope.into_iter().map(|(symbol, _)| symbol).collect(), + }, + ) + } +} diff --git a/src/main.rs b/src/main.rs index 7775ed3..0291f2d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,11 @@ use lumina::{ compile_pass::CompilePass, - stage::{codegen::llvm::Pass, lex::Lexer, lower_ir as ir, parse::parse}, + stage::{ + codegen::llvm::Pass, + lex::Lexer, + lower_ir::{self as ir, IRCtx}, + parse::parse, + }, util::source::Source, }; @@ -36,11 +41,15 @@ fn main() -> int { let main = program.main.name; - let ir_ctx = ir::lower(program); - let ctx = inkwell::context::Context::create(); + ir::lower(&mut ctx, program); + let llvm_ctx = inkwell::context::Context::create(); - let function_ids = ir_ctx.functions.keys().cloned().collect::>(); - let mut llvm_pass = Pass::new(&ctx, ir_ctx); + let function_ids = ctx + .all_functions() + .iter() + .map(|(idx, _)| *idx) + .collect::>(); + let mut llvm_pass = Pass::new(&llvm_ctx, ctx); function_ids.into_iter().for_each(|function| { llvm_pass.compile(function); }); diff --git a/src/stage/codegen/llvm/mod.rs b/src/stage/codegen/llvm/mod.rs index 0e194ef..0b4bf1c 100644 --- a/src/stage/codegen/llvm/mod.rs +++ b/src/stage/codegen/llvm/mod.rs @@ -14,17 +14,15 @@ use inkwell::{ use crate::{ repr::{ identifier::FunctionIdx, - ir::{BinaryOp, Triple, UnaryOp, Value}, + ir::{BasicBlockIdx, BinaryOp, ConstantValue, Triple, TripleIdx, UnaryOp, Value}, }, stage::lower_ir::IRCtx, util::symbol_map::interner_symbol_map::Symbol, }; -use crate::repr::ir::{BasicBlockIdx, ConstantValue, TripleIdx}; - -pub struct Pass<'ctx> { +pub struct Pass<'ctx, I> { llvm_ctx: &'ctx LLVMContext, - ir_ctx: IRCtx, + ir_ctx: I, module: Module<'ctx>, symbols: HashMap>, @@ -32,27 +30,32 @@ pub struct Pass<'ctx> { pub function_values: HashMap>, } -impl<'ctx> Pass<'ctx> { - pub fn new(llvm_ctx: &'ctx LLVMContext, ir_ctx: IRCtx) -> Self { +impl<'ctx, I> Pass<'ctx, I> +where + I: IRCtx, +{ + pub fn new(llvm_ctx: &'ctx LLVMContext, ir_ctx: I) -> Self { let module = llvm_ctx.create_module("module"); Self { - function_values: HashMap::from_iter(ir_ctx.functions.iter().map(|(idx, function)| { - (*idx, { - // Forward-declare all the functions - // TODO: Pick appropriate return type depending on signature - let fn_type = llvm_ctx.i64_type().fn_type(&[], false); - let fn_value = module.add_function( - // TODO: Determine function name from identifier - // ir_ctx.symbol_map.resolve(function.symbol).unwrap(), - "my function", - fn_type, - None, - ); - - fn_value - }) - })), + function_values: HashMap::from_iter(ir_ctx.all_functions().iter().map( + |(idx, _function)| { + (*idx, { + // Forward-declare all the functions + // TODO: Pick appropriate return type depending on signature + let fn_type = llvm_ctx.i64_type().fn_type(&[], false); + let fn_value = module.add_function( + // TODO: Determine function name from identifier + // ir_ctx.symbol_map.resolve(function.symbol).unwrap(), + "my function", + fn_type, + None, + ); + + fn_value + }) + }, + )), llvm_ctx, ir_ctx, module, @@ -66,9 +69,12 @@ impl<'ctx> Pass<'ctx> { pub fn compile(&mut self, function_idx: FunctionIdx) -> FunctionValue<'ctx> { let function = self .ir_ctx - .functions - .get(&function_idx) - .expect("function to exist"); + .all_functions() + .iter() + .find(|(idx, _)| idx == &function_idx) + .expect("function to exist") + .1 + .clone(); let builder = self.llvm_ctx.create_builder(); @@ -92,13 +98,7 @@ impl<'ctx> Pass<'ctx> { ( symbol.to_owned(), builder - .build_alloca( - self.llvm_ctx.i64_type(), - self.ir_ctx - .symbol_map - .resolve(*symbol) - .expect("symbol to exist in map"), - ) + .build_alloca(self.llvm_ctx.i64_type(), "todo: work out symbol name") .unwrap(), ) })); @@ -123,9 +123,12 @@ impl<'ctx> Pass<'ctx> { let basic_block = self .ir_ctx - .functions - .get(&basic_block_id.0) + .all_functions() + .iter() + .find(|(idx, _)| idx == &basic_block_id.0) .expect("function to exist") + .1 + .clone() .basic_blocks .get(basic_block_id.1) .expect("requested basic block must exist") diff --git a/src/stage/lower_ir/ctx.rs b/src/stage/lower_ir/ctx.rs index d9e553e..1e752b9 100644 --- a/src/stage/lower_ir/ctx.rs +++ b/src/stage/lower_ir/ctx.rs @@ -1,22 +1,78 @@ -use std::collections::HashMap; - use crate::{ - repr::{identifier::FunctionIdx, ir::Function}, - util::symbol_map::interner_symbol_map::InternerSymbolMap, + repr::{ + ast::typed as ast, + identifier::FunctionIdx, + ir::{BasicBlockIdx, Function, Triple, TripleRef}, + ty::Ty, + }, + util::symbol_map::interner_symbol_map::Symbol, }; -#[derive(Default, Clone, Debug)] -pub struct IRCtx { - /// Map of function symbol to the basic block entry point - pub functions: HashMap, - pub symbol_map: InternerSymbolMap, +pub trait IRCtx { + type FunctionBuilder: FunctionBuilder; + + /// Register an IR function implementation against it's identifier. + fn register_function(&mut self, idx: FunctionIdx, function: Function); + + /// Prepare and return a new builder for the provided AST function representation. + fn new_builder(&self, function: &ast::Function) -> Self::FunctionBuilder { + Self::FunctionBuilder::new(function) + } + + // TODO: Probably get rid of this + fn all_functions(&self) -> Vec<(FunctionIdx, Function)>; +} + +#[cfg(test)] +mockall::mock! { + pub IRCtx {} + + impl IRCtx for IRCtx { + type FunctionBuilder = MockFunctionBuilder; + + fn register_function(&mut self, idx: FunctionIdx, function: Function); + fn new_builder(&self, function: &ast::Function) -> MockFunctionBuilder; + fn all_functions(&self) -> Vec<(FunctionIdx, Function)>; + } } -impl IRCtx { - pub fn new(symbol_map: InternerSymbolMap) -> Self { - Self { - symbol_map, - ..Default::default() - } +/// A stateful representation of a function that is being constructed. A function consists of basic +/// blocks, and the builder is always 'located' at a basic block. +pub trait FunctionBuilder { + /// Initialise a new builder with the provided function, positioned at the entry point. + fn new(function: &ast::Function) -> Self; + + /// Register the provided symbol with the given type into the function scope. + fn register_scoped(&mut self, symbol: Symbol, ty: Ty); + + /// Add a triple to the current basic block. + fn add_triple(&mut self, triple: Triple) -> TripleRef; + + /// Get the current basic block. + fn current_bb(&self) -> BasicBlockIdx; + + /// Go to a specific basic block. + fn goto_bb(&mut self, bb: BasicBlockIdx); + + /// Create a new basic block, switch to it, and return its identifier. + fn push_bb(&mut self) -> BasicBlockIdx; + + /// Consume the builder, and register the built function against the context. + fn build>(self, ctx: &mut Ctx); +} + +#[cfg(test)] +mockall::mock! { + pub FunctionBuilder {} + + impl FunctionBuilder for FunctionBuilder { + fn new(function: &ast::Function) -> Self; + fn register_scoped(&mut self, symbol: Symbol, ty: Ty); + fn add_triple(&mut self, triple: Triple) -> TripleRef; + fn current_bb(&self) -> BasicBlockIdx; + fn goto_bb(&mut self, bb: BasicBlockIdx) ; + fn push_bb(&mut self) -> BasicBlockIdx; + #[mockall::concretize] + fn build>(self, ctx: &mut Ctx); } } diff --git a/src/stage/lower_ir/lowering/mod.rs b/src/stage/lower_ir/lowering/mod.rs index b6b8902..d8b98fa 100644 --- a/src/stage/lower_ir/lowering/mod.rs +++ b/src/stage/lower_ir/lowering/mod.rs @@ -1,71 +1,33 @@ -use crate::repr::{ast::typed as ast, identifier::FunctionIdx, ir::*}; +use crate::repr::{ast::typed as ast, ir::*}; -use super::IRCtx; - -pub fn lower(program: ast::Program) -> IRCtx { - let mut ir = IRCtx::new(program.symbols); +use super::{FunctionBuilder, IRCtx}; +pub fn lower(ctx: &mut impl IRCtx, program: ast::Program) { // Fill up the functions in the IR for function in program.functions { - lower_function(&mut ir, function); + lower_function(ctx, function); } - lower_function(&mut ir, program.main); - - ir -} - -struct FunctionLoweringCtx<'ctx> { - ir_ctx: &'ctx mut IRCtx, - current_bb: BasicBlockIdx, - function: Function, -} - -impl FunctionLoweringCtx<'_> { - fn add_triple(&mut self, triple: Triple) -> TripleRef { - TripleRef { - basic_block: self.current_bb, - triple: self - .function - .basic_blocks - .get_mut(self.current_bb) - .expect("current basic block must exist") - .triples - .push(triple), - } - } + lower_function(ctx, program.main); } -fn lower_function(ir_ctx: &mut IRCtx, function: ast::Function) -> FunctionIdx { - let mut repr_function = Function::new(&function); - - // Insert entry basic block - assert!( - repr_function.basic_blocks.is_empty(), - "entry basic block should be first in function" - ); - let entry = repr_function.basic_blocks.push(BasicBlock::default()); +fn lower_function(ctx: &mut impl IRCtx, function: ast::Function) { + // Create a new function builder, which will already be positioned at the entry point. + let mut builder = ctx.new_builder(&function); // Perform the lowering - let repr_function = { - let mut ctx = FunctionLoweringCtx { - ir_ctx, - current_bb: entry, - function: repr_function, - }; - - lower_block(&mut ctx, &function.body); + lower_block(ctx, &mut builder, &function.body); - ctx.function - }; - - ir_ctx.functions.insert(function.name, repr_function); - - function.name + // Consume the builder + builder.build(ctx); } /// Lower an AST block into the current function context. -fn lower_block(ctx: &mut FunctionLoweringCtx, block: &ast::Block) -> Value { +fn lower_block( + ctx: &mut impl IRCtx, + builder: &mut impl FunctionBuilder, + block: &ast::Block, +) -> Value { assert!( !block.statements.is_empty(), "block must have statements within it" @@ -82,21 +44,17 @@ fn lower_block(ctx: &mut FunctionLoweringCtx, block: &ast::Block) -> Value { { match statement { ast::Statement::Return(ast::ReturnStatement { value, .. }) => { - let value = lower_expression(ctx, value); - ctx.add_triple(Triple::Return(value)); + let value = lower_expression(ctx, builder, value); + builder.add_triple(Triple::Return(value)); } ast::Statement::Let(ast::LetStatement { name, value, .. }) => { - assert!( - // Insert function name into scope - ctx.function.scope.insert(*name), - "cannot redeclare variable" - ); - - let value = lower_expression(ctx, value); - ctx.add_triple(Triple::Assign(*name, value)); + builder.register_scoped(*name, value.get_ty_info().ty); + + let value = lower_expression(ctx, builder, value); + builder.add_triple(Triple::Assign(*name, value)); } ast::Statement::Expression(ast::ExpressionStatement { expression, .. }) => { - let result = lower_expression(ctx, expression); + let result = lower_expression(ctx, builder, expression); // Implicit return // TODO: Check for semi-colon @@ -111,7 +69,11 @@ fn lower_block(ctx: &mut FunctionLoweringCtx, block: &ast::Block) -> Value { Value::Unit } -fn lower_expression(ctx: &mut FunctionLoweringCtx, expression: &ast::Expression) -> Value { +fn lower_expression( + ctx: &mut impl IRCtx, + builder: &mut impl FunctionBuilder, + expression: &ast::Expression, +) -> Value { match expression { ast::Expression::Infix(ast::Infix { left, @@ -119,47 +81,45 @@ fn lower_expression(ctx: &mut FunctionLoweringCtx, expression: &ast::Expression) right, .. }) => { - let lhs = lower_expression(ctx, left); - let rhs = lower_expression(ctx, right); + let lhs = lower_expression(ctx, builder, left); + let rhs = lower_expression(ctx, builder, right); let op = BinaryOp::from(operation); - Value::Triple(ctx.add_triple(Triple::BinaryOp { lhs, rhs, op })) + Value::Triple(builder.add_triple(Triple::BinaryOp { lhs, rhs, op })) } ast::Expression::Integer(integer) => Value::integer(integer.value), ast::Expression::Boolean(boolean) => Value::boolean(boolean.value), ast::Expression::Ident(ast::Ident { name, .. }) => Value::Name(*name), - ast::Expression::Block(block) => lower_block(ctx, block), + ast::Expression::Block(block) => lower_block(ctx, builder, block), ast::Expression::If(ast::If { condition, success, otherwise, .. }) => { - let condition = lower_expression(ctx, condition); + let condition = lower_expression(ctx, builder, condition); - let here = ctx.current_bb; + let original_bb = builder.current_bb(); // Lower success block into newly created basic block - let success_bb = ctx.function.basic_blocks.push(BasicBlock::default()); - ctx.current_bb = success_bb; - let success_value = lower_block(ctx, success); + let success_bb = builder.push_bb(); + let success_value = lower_block(ctx, builder, success); // Lower the otherwise block, if it exists let (otherwise_bb, otherwise_value) = otherwise .as_ref() .map(|otherwise| { - let otherwise_bb = ctx.function.basic_blocks.push(BasicBlock::default()); - ctx.current_bb = otherwise_bb; - let otherwise_value = lower_block(ctx, otherwise); + let otherwise_bb = builder.push_bb(); + let otherwise_value = lower_block(ctx, builder, otherwise); (otherwise_bb, otherwise_value) }) .expect("else branch to have value"); // Revert back to original location - ctx.current_bb = here; + builder.goto_bb(original_bb); - Value::Triple(ctx.add_triple(Triple::Switch { + Value::Triple(builder.add_triple(Triple::Switch { value: condition, default: (success_bb, success_value), branches: vec![(Value::integer(0), otherwise_bb, otherwise_value)], @@ -167,7 +127,7 @@ fn lower_expression(ctx: &mut FunctionLoweringCtx, expression: &ast::Expression) } ast::Expression::Call(call) => { let idx = call.name; - Value::Triple(ctx.add_triple(Triple::Call(idx))) + Value::Triple(builder.add_triple(Triple::Call(idx))) } } } diff --git a/tests/programs.rs b/tests/programs.rs index de5968d..76b3c2f 100644 --- a/tests/programs.rs +++ b/tests/programs.rs @@ -1,6 +1,6 @@ use lumina::{ compile_pass::CompilePass, - stage::{lex::Lexer, lower_ir}, + stage::{codegen::llvm::Pass, lex::Lexer, lower_ir, lower_ir::IRCtx, parse::parse}, util::source::Source, }; use rstest::rstest; @@ -95,8 +95,6 @@ fn main() -> int { }"# )] fn programs(#[case] expected: i64, #[case] source: &'static str) { - use lumina::stage::{codegen::llvm::Pass, parse::parse}; - let source = Source::new(source); let mut ctx = CompilePass::default(); @@ -119,11 +117,19 @@ fn programs(#[case] expected: i64, #[case] source: &'static str) { let main = program.main.name; - let ir_ctx = lower_ir::lower(program); - let ctx = inkwell::context::Context::create(); + lower_ir::lower(&mut ctx, program); + let llvm_ctx = inkwell::context::Context::create(); - let mut llvm_pass = Pass::new(&ctx, ir_ctx); - let main = llvm_pass.compile(main); + let function_ids = ctx + .all_functions() + .iter() + .map(|(idx, _)| *idx) + .collect::>(); + let mut llvm_pass = Pass::new(&llvm_ctx, ctx); + function_ids.into_iter().for_each(|function| { + llvm_pass.compile(function); + }); + let main = *llvm_pass.function_values.get(&main).unwrap(); let result = llvm_pass.jit(main);