From 7bee1b5342598fd5aec07901fb16271b7d159685 Mon Sep 17 00:00:00 2001 From: Constantine Theocharis Date: Sat, 9 Sep 2023 19:09:29 +0000 Subject: [PATCH 1/4] Rename `FnCallTerm` -> `CallTerm` and `FnRef` -> `Fn` --- compiler/hash-intrinsics/src/intrinsics.rs | 2 +- compiler/hash-lower/src/build/category.rs | 4 ++-- compiler/hash-lower/src/build/into.rs | 6 +++--- compiler/hash-lower/src/build/matches/mod.rs | 2 +- compiler/hash-lower/src/build/place.rs | 4 ++-- compiler/hash-lower/src/build/rvalue.rs | 6 +++--- compiler/hash-lower/src/build/ty.rs | 10 +++++----- .../hash-semantics/src/passes/evaluation/mod.rs | 4 ++-- .../src/passes/resolution/exprs.rs | 14 +++++++------- .../src/passes/resolution/params.rs | 8 ++++---- .../src/passes/resolution/paths.rs | 6 +++--- .../hash-semantics/src/passes/resolution/tys.rs | 8 ++++---- compiler/hash-tir/src/fns.rs | 4 ++-- compiler/hash-tir/src/scopes.rs | 2 +- compiler/hash-tir/src/terms.rs | 10 +++++----- compiler/hash-tir/src/utils/traversing.rs | 16 ++++++++-------- compiler/hash-typecheck/src/inference.rs | 16 ++++++++-------- compiler/hash-typecheck/src/normalisation.rs | 16 ++++++++-------- compiler/hash-typecheck/src/unification.rs | 14 ++++++-------- 19 files changed, 75 insertions(+), 77 deletions(-) diff --git a/compiler/hash-intrinsics/src/intrinsics.rs b/compiler/hash-intrinsics/src/intrinsics.rs index a4f86b1bc..6f4ec3109 100644 --- a/compiler/hash-intrinsics/src/intrinsics.rs +++ b/compiler/hash-intrinsics/src/intrinsics.rs @@ -738,7 +738,7 @@ impl DefinedIntrinsics { "print_fn_directives", FnTy::builder().params(params).return_ty(ret).build(), |_, args| { - if let Term::FnRef(fn_def) = *args[1].value() { + if let Term::Fn(fn_def) = *args[1].value() { attr_store().map_with_default(fn_def.node_id_or_default(), |attrs| { stream_less_writeln!("{:?}", attrs); }); diff --git a/compiler/hash-lower/src/build/category.rs b/compiler/hash-lower/src/build/category.rs index fe9e18a2a..990afc6b2 100644 --- a/compiler/hash-lower/src/build/category.rs +++ b/compiler/hash-lower/src/build/category.rs @@ -55,9 +55,9 @@ impl Category { | Term::LoopControl(..) | Term::Block(_) | Term::Ctor(_) - | Term::FnRef(_) + | Term::Fn(_) | Term::Match(..) - | Term::FnCall(..) => Category::RValue(RValueKind::Into), + | Term::Call(..) => Category::RValue(RValueKind::Into), Term::Tuple(_) | Term::Decl(_) diff --git a/compiler/hash-lower/src/build/into.rs b/compiler/hash-lower/src/build/into.rs index a807b3b8a..4b40beddb 100644 --- a/compiler/hash-lower/src/build/into.rs +++ b/compiler/hash-lower/src/build/into.rs @@ -23,7 +23,7 @@ use hash_tir::{ control::{LoopControlTerm, ReturnTerm}, data::CtorTerm, environment::env::AccessToEnv, - fns::FnCallTerm, + fns::CallTerm, node::NodesId, params::ParamIndex, refs::{self, RefTerm}, @@ -129,7 +129,7 @@ impl<'tcx> BodyBuilder<'tcx> { self.constructor_into_dest(destination, block, ctor, adt, span) } } - Term::FnCall(ref fn_term @ FnCallTerm { subject, args, .. }) => { + Term::Call(ref fn_term @ CallTerm { subject, args, .. }) => { match self.classify_fn_call_term(fn_term) { FnCallTermKind::Call(_) => { // Get the type of the function into or to to get the @@ -336,7 +336,7 @@ impl<'tcx> BodyBuilder<'tcx> { | Ty::RefTy(_) | Ty::Universe | Term::Hole(_) - | Term::FnRef(_) => block.unit(), + | Term::Fn(_) => block.unit(), }; block_and diff --git a/compiler/hash-lower/src/build/matches/mod.rs b/compiler/hash-lower/src/build/matches/mod.rs index 27e0f0908..b49301be0 100644 --- a/compiler/hash-lower/src/build/matches/mod.rs +++ b/compiler/hash-lower/src/build/matches/mod.rs @@ -107,7 +107,7 @@ impl<'tcx> BodyBuilder<'tcx> { // If this is a `&&`, we can create a `then-else` block sequence // that respects the short-circuiting behaviour of `&&`. - if let Term::FnCall(ref fn_call) = *expr.value() { + if let Term::Call(ref fn_call) = *expr.value() { if let FnCallTermKind::LogicalBinOp(LogicalBinOp::And, lhs, rhs) = self.classify_fn_call_term(fn_call) { diff --git a/compiler/hash-lower/src/build/place.rs b/compiler/hash-lower/src/build/place.rs index 544089218..462345f0e 100644 --- a/compiler/hash-lower/src/build/place.rs +++ b/compiler/hash-lower/src/build/place.rs @@ -144,9 +144,9 @@ impl<'tcx> BodyBuilder<'tcx> { Term::Tuple(_) | Term::Lit(_) | Term::Array(_) - | Term::FnCall(_) + | Term::Call(_) | Term::Ctor(_) - | Term::FnRef(_) + | Term::Fn(_) | Term::Block(_) | Term::Loop(_) | Term::LoopControl(_) diff --git a/compiler/hash-lower/src/build/rvalue.rs b/compiler/hash-lower/src/build/rvalue.rs index bbd57a806..fd4116b5c 100644 --- a/compiler/hash-lower/src/build/rvalue.rs +++ b/compiler/hash-lower/src/build/rvalue.rs @@ -38,7 +38,7 @@ impl<'tcx> BodyBuilder<'tcx> { let value = self.as_constant(lit).into(); block.and(value) } - ref fn_call_term @ Term::FnCall(fn_call) => { + ref fn_call_term @ Term::Call(fn_call) => { match self.classify_fn_call_term(&fn_call) { FnCallTermKind::BinaryOp(op, lhs, rhs) => { let lhs = unpack!(block = self.as_operand(block, lhs, Mutability::Mutable)); @@ -113,7 +113,7 @@ impl<'tcx> BodyBuilder<'tcx> { ref term @ (Term::Array(_) | Term::Tuple(_) | Term::Ctor(_) - | Term::FnRef(_) + | Term::Fn(_) | Term::Block(_) | Term::Var(_) | Term::Loop(_) @@ -172,7 +172,7 @@ impl<'tcx> BodyBuilder<'tcx> { // If the item is a reference to a function, i.e. the subject of a call, then // we emit a constant that refers to the function. - if let Term::FnRef(def_id) = *term { + if let Term::Fn(def_id) = *term { let ty_id = self.ty_id_from_tir_fn_def(def_id); // If this is a function type, we emit a ZST to represent the operand diff --git a/compiler/hash-lower/src/build/ty.rs b/compiler/hash-lower/src/build/ty.rs index 9a2f70d5a..38e9a7d13 100644 --- a/compiler/hash-lower/src/build/ty.rs +++ b/compiler/hash-lower/src/build/ty.rs @@ -20,7 +20,7 @@ use hash_tir::{ atom_info::ItemInAtomInfo, data::DataTy, environment::env::AccessToEnv, - fns::{FnCallTerm, FnDefId}, + fns::{CallTerm, FnDefId}, lits::{Lit, LitPat}, pats::PatId, terms::{Term, TermId, TyId}, @@ -37,7 +37,7 @@ use super::BodyBuilder; pub enum FnCallTermKind { /// A function call, the term doesn't change and should just be /// handled as a function call. - Call(FnCallTerm), + Call(CallTerm), /// A cast intrinsic operation, we perform a cast from the type of the /// first term into the desired second [IrTyId]. @@ -90,11 +90,11 @@ impl<'tcx> BodyBuilder<'tcx> { /// Function which is used to classify a [FnCallTerm] into a /// [FnCallTermKind]. - pub(crate) fn classify_fn_call_term(&self, term: &FnCallTerm) -> FnCallTermKind { - let FnCallTerm { subject, args, .. } = term; + pub(crate) fn classify_fn_call_term(&self, term: &CallTerm) -> FnCallTermKind { + let CallTerm { subject, args, .. } = term; match *subject.value() { - Term::FnRef(fn_def) => { + Term::Fn(fn_def) => { // Check if the fn_def is a `un_op` intrinsic if fn_def == self.intrinsics().un_op() { let (op, subject) = diff --git a/compiler/hash-semantics/src/passes/evaluation/mod.rs b/compiler/hash-semantics/src/passes/evaluation/mod.rs index 74b921ea8..6901ab55f 100644 --- a/compiler/hash-semantics/src/passes/evaluation/mod.rs +++ b/compiler/hash-semantics/src/passes/evaluation/mod.rs @@ -10,7 +10,7 @@ use hash_storage::store::statics::SequenceStoreValue; use hash_tir::{ args::Arg, environment::env::AccessToEnv, - fns::FnCallTerm, + fns::CallTerm, node::{Node, NodeId}, terms::{Term, TermId}, utils::common::dump_tir, @@ -54,7 +54,7 @@ impl EvaluationPass<'_> { match def { Some(def) => { let call_term = Term::from( - FnCallTerm { + CallTerm { subject: Term::from(def, def.origin()), implicit: false, args: Node::create_at(Node::::empty_seq(), def.origin()), diff --git a/compiler/hash-semantics/src/passes/resolution/exprs.rs b/compiler/hash-semantics/src/passes/resolution/exprs.rs index 54c3a9873..f75d0c148 100644 --- a/compiler/hash-semantics/src/passes/resolution/exprs.rs +++ b/compiler/hash-semantics/src/passes/resolution/exprs.rs @@ -25,7 +25,7 @@ use hash_tir::{ control::{LoopControlTerm, LoopTerm, MatchCase, MatchTerm, ReturnTerm}, data::DataTy, environment::env::AccessToEnv, - fns::{FnBody, FnCallTerm, FnDefId}, + fns::{CallTerm, FnBody, FnDefId}, lits::{CharLit, FloatLit, IntLit, Lit, StrLit}, node::{Node, NodeOrigin}, params::ParamIndex, @@ -345,7 +345,7 @@ impl<'tc> ResolutionPass<'tc> { ResolvedAstPathComponent::Terminal(terminal) => match terminal { TerminalResolvedPathComponent::FnDef(fn_def_id) => { // Reference to a function definition - Ok(Term::from(Term::FnRef(*fn_def_id), origin)) + Ok(Term::from(Term::Fn(*fn_def_id), origin)) } TerminalResolvedPathComponent::CtorPat(_) => { panic_on_span!( @@ -360,7 +360,7 @@ impl<'tc> ResolutionPass<'tc> { } TerminalResolvedPathComponent::FnCall(fn_call_term) => { // Function call - Ok(Term::from(Term::FnCall(**fn_call_term), origin)) + Ok(Term::from(Term::Call(**fn_call_term), origin)) } TerminalResolvedPathComponent::Var(bound_var) => { // Bound variable @@ -399,7 +399,7 @@ impl<'tc> ResolutionPass<'tc> { match (subject, args) { (Some(subject), Some(args)) => Ok(Term::from( - Term::FnCall(FnCallTerm { subject, args, implicit: false }), + Term::Call(CallTerm { subject, args, implicit: false }), NodeOrigin::Given(node.id()), )), _ => Err(SemanticError::Signal), @@ -888,7 +888,7 @@ impl<'tc> ResolutionPass<'tc> { // If all ok, create a fn ref term match (params, return_ty, return_value) { (Some(_), None | Some(Some(_)), Some(_)) => { - Ok(Term::from(Term::FnRef(fn_def_id), NodeOrigin::Given(node_id))) + Ok(Term::from(Term::Fn(fn_def_id), NodeOrigin::Given(node_id))) } _ => Err(SemanticError::Signal), } @@ -1011,7 +1011,7 @@ impl<'tc> ResolutionPass<'tc> { // Invoke the intrinsic function Ok(Term::from( - FnCallTerm { + CallTerm { subject: Term::from(intrinsic_fn_def, origin), args: Arg::seq_positional( [typeof_lhs, self.create_term_from_integer_lit(bin_op_num), lhs, rhs], @@ -1044,7 +1044,7 @@ impl<'tc> ResolutionPass<'tc> { // Invoke the intrinsic function Ok(Term::from( - FnCallTerm { + CallTerm { subject: Term::from(intrinsic_fn_def, origin), args: Arg::seq_positional( [typeof_a, self.create_term_from_integer_lit(op_num), a], diff --git a/compiler/hash-semantics/src/passes/resolution/params.rs b/compiler/hash-semantics/src/passes/resolution/params.rs index 81165849c..a9c51a6ac 100644 --- a/compiler/hash-semantics/src/passes/resolution/params.rs +++ b/compiler/hash-semantics/src/passes/resolution/params.rs @@ -8,7 +8,7 @@ use hash_storage::store::{ }; use hash_tir::{ args::{ArgsId, PatArgsId}, - fns::FnCallTerm, + fns::CallTerm, node::{Node, NodeOrigin}, params::{Param, ParamId, ParamsId, SomeParamsOrArgsId}, pats::Spread, @@ -292,7 +292,7 @@ impl<'tc> ResolutionPass<'tc> { subject: TermId, args: &[AstArgGroup], original_node_id: AstNodeId, - ) -> SemanticResult { + ) -> SemanticResult { debug_assert!(!args.is_empty()); let mut current_subject = subject; for arg_group in args { @@ -303,7 +303,7 @@ impl<'tc> ResolutionPass<'tc> { // Here we are trying to call a function with term arguments. // Apply the arguments to the current subject and continue. current_subject = Term::from( - Term::FnCall(FnCallTerm { + Term::Call(CallTerm { subject: current_subject, args, implicit: matches!(arg_group, AstArgGroup::ImplicitArgs(_)), @@ -321,7 +321,7 @@ impl<'tc> ResolutionPass<'tc> { } } match *current_subject.value() { - Term::FnCall(call) => Ok(call), + Term::Call(call) => Ok(call), _ => unreachable!(), } } diff --git a/compiler/hash-semantics/src/passes/resolution/paths.rs b/compiler/hash-semantics/src/passes/resolution/paths.rs index e229cbdf2..585fb9f57 100644 --- a/compiler/hash-semantics/src/passes/resolution/paths.rs +++ b/compiler/hash-semantics/src/passes/resolution/paths.rs @@ -27,7 +27,7 @@ use hash_storage::store::statics::{SequenceStoreValue, StoreId}; use hash_tir::{ args::{Arg, ArgsId}, data::{CtorPat, CtorTerm, DataDefId}, - fns::{FnCallTerm, FnDefId}, + fns::{CallTerm, FnDefId}, mods::{ModDefId, ModMemberValue}, node::{Node, NodeId, NodeOrigin}, symbols::SymbolId, @@ -121,7 +121,7 @@ pub enum TerminalResolvedPathComponent { /// A data constructor term. CtorTerm(Node), /// A function call term. - FnCall(Node), + FnCall(Node), /// A variable bound in the current context. Var(SymbolId), } @@ -302,7 +302,7 @@ impl<'tc> ResolutionPass<'tc> { )), args => { let resultant_term = self.wrap_term_in_fn_call_from_ast_args( - Term::from(Term::FnRef(fn_def_id), component.origin()), + Term::from(Term::Fn(fn_def_id), component.origin()), args, component.node_id, )?; diff --git a/compiler/hash-semantics/src/passes/resolution/tys.rs b/compiler/hash-semantics/src/passes/resolution/tys.rs index a9e531eb3..2a9efabe8 100644 --- a/compiler/hash-semantics/src/passes/resolution/tys.rs +++ b/compiler/hash-semantics/src/passes/resolution/tys.rs @@ -14,7 +14,7 @@ use hash_tir::{ args::{Arg, ArgsId}, data::DataTy, environment::env::AccessToEnv, - fns::FnCallTerm, + fns::CallTerm, node::{Node, NodeOrigin}, params::ParamIndex, primitives::primitives, @@ -136,7 +136,7 @@ impl<'tc> ResolutionPass<'tc> { }, ResolvedAstPathComponent::Terminal(terminal) => match terminal { TerminalResolvedPathComponent::FnDef(fn_def_id) => { - Ok(Term::from(Term::FnRef(*fn_def_id), origin)) + Ok(Term::from(Term::Fn(*fn_def_id), origin)) } TerminalResolvedPathComponent::CtorPat(_) => { panic_on_span!( @@ -149,7 +149,7 @@ impl<'tc> ResolutionPass<'tc> { Ok(Term::from(Term::Ctor(**ctor_term), origin)) } TerminalResolvedPathComponent::FnCall(fn_call_term) => { - Ok(Term::from(Term::FnCall(**fn_call_term), origin)) + Ok(Term::from(Term::Call(**fn_call_term), origin)) } TerminalResolvedPathComponent::Var(bound_var) => { Ok(Ty::from(Ty::Var(*bound_var), origin)) @@ -218,7 +218,7 @@ impl<'tc> ResolutionPass<'tc> { match (subject, args) { (Some(subject), Some(args)) => Ok(Term::from( - Term::FnCall(FnCallTerm { subject, args, implicit: true }), + Term::Call(CallTerm { subject, args, implicit: true }), NodeOrigin::Given(node.id()), )), _ => Err(SemanticError::Signal), diff --git a/compiler/hash-tir/src/fns.rs b/compiler/hash-tir/src/fns.rs index dbfe7c2ef..2c8d6e203 100644 --- a/compiler/hash-tir/src/fns.rs +++ b/compiler/hash-tir/src/fns.rs @@ -106,7 +106,7 @@ tir_node_single_store!(FnDef); /// A function call. #[derive(Debug, Clone, Copy)] -pub struct FnCallTerm { +pub struct CallTerm { /// The function being called /// /// This could be a function definition, a value of a function type, or a @@ -201,7 +201,7 @@ impl Display for FnDefId { } } -impl Display for FnCallTerm { +impl Display for CallTerm { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.subject)?; diff --git a/compiler/hash-tir/src/scopes.rs b/compiler/hash-tir/src/scopes.rs index 54a08f3fa..7c00ceec0 100644 --- a/compiler/hash-tir/src/scopes.rs +++ b/compiler/hash-tir/src/scopes.rs @@ -177,7 +177,7 @@ impl fmt::Display for DeclTerm { match *term_id.value() { // If a function is being declared, print the body, otherwise just // its name. - Term::FnRef(fn_def_id) + Term::Fn(fn_def_id) if fn_def_id.map(|def| def.name == binding_pat.name) => { fn_def_id.to_string() diff --git a/compiler/hash-tir/src/terms.rs b/compiler/hash-tir/src/terms.rs index 3fc906485..16729b120 100644 --- a/compiler/hash-tir/src/terms.rs +++ b/compiler/hash-tir/src/terms.rs @@ -17,7 +17,7 @@ use crate::{ control::{LoopControlTerm, LoopTerm, MatchTerm, ReturnTerm}, data::{CtorTerm, DataDefId, DataTy}, environment::stores::tir_stores, - fns::{FnCallTerm, FnDefId, FnTy}, + fns::{CallTerm, FnDefId, FnTy}, lits::LitId, node::{Node, NodeId, NodeOrigin}, params::Param, @@ -68,8 +68,8 @@ pub enum Term { Ctor(CtorTerm), // Functions - FnCall(FnCallTerm), - FnRef(FnDefId), + Call(CallTerm), + Fn(FnDefId), // Loops Loop(LoopTerm), @@ -233,8 +233,8 @@ impl fmt::Display for Term { Term::Tuple(tuple_term) => write!(f, "{}", tuple_term), Term::Lit(lit) => write!(f, "{}", *lit.value()), Term::Ctor(ctor_term) => write!(f, "{}", ctor_term), - Term::FnCall(fn_call_term) => write!(f, "{}", fn_call_term), - Term::FnRef(fn_def_id) => write!( + Term::Call(fn_call_term) => write!(f, "{}", fn_call_term), + Term::Fn(fn_def_id) => write!( f, "{}", if fn_def_id.map(|fn_def| fn_def.name.map(|sym| sym.name.is_none())) { diff --git a/compiler/hash-tir/src/utils/traversing.rs b/compiler/hash-tir/src/utils/traversing.rs index 4c5ce8ca6..9c3e4ea74 100644 --- a/compiler/hash-tir/src/utils/traversing.rs +++ b/compiler/hash-tir/src/utils/traversing.rs @@ -17,7 +17,7 @@ use crate::{ control::{IfPat, LoopTerm, MatchCase, MatchTerm, OrPat, ReturnTerm}, data::{CtorDefId, CtorPat, CtorTerm, DataDefCtors, DataDefId, DataTy, PrimitiveCtorInfo}, environment::env::Env, - fns::{FnBody, FnCallTerm, FnDef, FnDefId, FnTy}, + fns::{CallTerm, FnBody, FnDef, FnDefId, FnTy}, mods::{ModDefId, ModMemberId, ModMemberValue}, node::{HasAstNodeId, Node, NodeId, NodeOrigin, NodesId}, params::{Param, ParamsId}, @@ -118,7 +118,7 @@ impl TraversingUtils { let result = match f(term_id.into())? { ControlFlow::Break(atom) => match atom { Atom::Term(t) => Ok(t), - Atom::FnDef(fn_def_id) => Ok(Node::create_at(Term::FnRef(fn_def_id), origin)), + Atom::FnDef(fn_def_id) => Ok(Node::create_at(Term::Fn(fn_def_id), origin)), Atom::Pat(_) => unreachable!("cannot use a pattern as a term"), }, ControlFlow::Continue(()) => match *term_id.value() { @@ -136,17 +136,17 @@ impl TraversingUtils { let ctor_args = self.fmap_args(ctor_term.ctor_args, f)?; Ok(Term::from(CtorTerm { ctor: ctor_term.ctor, data_args, ctor_args }, origin)) } - Term::FnCall(fn_call_term) => { + Term::Call(fn_call_term) => { let subject = self.fmap_term(fn_call_term.subject, f)?; let args = self.fmap_args(fn_call_term.args, f)?; Ok(Term::from( - FnCallTerm { args, subject, implicit: fn_call_term.implicit }, + CallTerm { args, subject, implicit: fn_call_term.implicit }, origin, )) } - Term::FnRef(fn_def_id) => { + Term::Fn(fn_def_id) => { let fn_def_id = self.fmap_fn_def(fn_def_id, f)?; - Ok(Term::from(Term::FnRef(fn_def_id), origin)) + Ok(Term::from(Term::Fn(fn_def_id), origin)) } Term::Block(block_term) => { let statements = self.fmap_term_list(block_term.statements, f)?; @@ -485,11 +485,11 @@ impl TraversingUtils { self.visit_args(ctor_term.data_args, f)?; self.visit_args(ctor_term.ctor_args, f) } - Term::FnCall(fn_call_term) => { + Term::Call(fn_call_term) => { self.visit_term(fn_call_term.subject, f)?; self.visit_args(fn_call_term.args, f) } - Term::FnRef(fn_def_id) => self.visit_fn_def(fn_def_id, f), + Term::Fn(fn_def_id) => self.visit_fn_def(fn_def_id, f), Term::Block(block_term) => { self.visit_term_list(block_term.statements, f)?; self.visit_term(block_term.return_value, f) diff --git a/compiler/hash-typecheck/src/inference.rs b/compiler/hash-typecheck/src/inference.rs index 336033369..e0eb23839 100644 --- a/compiler/hash-typecheck/src/inference.rs +++ b/compiler/hash-typecheck/src/inference.rs @@ -26,7 +26,7 @@ use hash_tir::{ context::ScopeKind, control::{IfPat, LoopControlTerm, LoopTerm, MatchTerm, OrPat, ReturnTerm}, data::{CtorDefId, CtorPat, CtorTerm, DataDefCtors, DataDefId, DataTy, PrimitiveCtorInfo}, - fns::{FnBody, FnCallTerm, FnDefId, FnTy}, + fns::{CallTerm, FnBody, FnDefId, FnTy}, lits::{Lit, LitId}, mods::{ModDefId, ModMemberId, ModMemberValue}, node::{HasAstNodeId, Node, NodeId, NodeOrigin, NodesId}, @@ -854,7 +854,7 @@ impl InferenceOps<'_, T> { /// Infer the type of a function call. pub fn infer_fn_call_term( &self, - fn_call_term: &FnCallTerm, + fn_call_term: &CallTerm, annotation_ty: TyId, original_term_id: TermId, ) -> TcResult<()> { @@ -869,7 +869,7 @@ impl InferenceOps<'_, T> { if let Ty::FnTy(_) = *fn_ty.return_ty.value() && fn_ty.implicit && !fn_call_term.implicit { let applied_args = Arg::seq_from_params_as_holes(fn_ty.params); let copied_subject = Term::inherited_from(fn_call_term.subject, *fn_call_term.subject.value()); - let new_subject = FnCallTerm { + let new_subject = CallTerm { args: applied_args, subject: copied_subject, implicit: fn_ty.implicit, @@ -948,8 +948,8 @@ impl InferenceOps<'_, T> { // @@MissingOrigin // Maybe it is better to check this manually? let call_term = Node::create_at( - Term::FnCall(FnCallTerm { - subject: Node::create_at(Term::FnRef(fn_def_id), NodeOrigin::Generated), + Term::Call(CallTerm { + subject: Node::create_at(Term::Fn(fn_def_id), NodeOrigin::Generated), implicit: false, args: Node::create_at(Node::::empty_seq(), NodeOrigin::Generated), }), @@ -1011,7 +1011,7 @@ impl InferenceOps<'_, T> { self.infer_params(fn_def.ty.params, || { self.infer_term(fn_def.ty.return_ty, Ty::universe_of(fn_def.ty.return_ty))?; if let FnBody::Defined(fn_body) = fn_def.body { - if let Term::FnRef(immediate_body_fn) = *fn_body.value() { + if let Term::Fn(immediate_body_fn) = *fn_body.value() { self.infer_fn_def( immediate_body_fn, Ty::hole_for(fn_body), @@ -1530,10 +1530,10 @@ impl InferenceOps<'_, T> { Term::Lit(lit_term) => self.infer_lit(lit_term, annotation_ty)?, Term::Array(prim_term) => self.infer_array_term(&prim_term, annotation_ty)?, Term::Ctor(ctor_term) => self.infer_ctor_term(&ctor_term, annotation_ty, term_id)?, - Term::FnCall(fn_call_term) => { + Term::Call(fn_call_term) => { self.infer_fn_call_term(&fn_call_term, annotation_ty, term_id)? } - Term::FnRef(fn_def_id) => { + Term::Fn(fn_def_id) => { self.infer_fn_def(fn_def_id, annotation_ty, term_id, FnInferMode::Body)? } Term::Var(var_term) => self.infer_var(var_term, annotation_ty)?, diff --git a/compiler/hash-typecheck/src/normalisation.rs b/compiler/hash-typecheck/src/normalisation.rs index 6bc82260c..d7fe5a078 100644 --- a/compiler/hash-typecheck/src/normalisation.rs +++ b/compiler/hash-typecheck/src/normalisation.rs @@ -15,7 +15,7 @@ use hash_tir::{ casting::CastTerm, context::ScopeKind, control::{LoopControlTerm, LoopTerm, MatchTerm, ReturnTerm}, - fns::{FnBody, FnCallTerm, FnDefId}, + fns::{CallTerm, FnBody, FnDefId}, holes::Hole, lits::{Lit, LitPat}, node::{Node, NodeId, NodesId}, @@ -232,7 +232,7 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { pub fn maybe_to_fn_def(&self, atom: Atom) -> Option { match atom { Atom::Term(term) => match *term.value() { - Term::FnRef(fn_def_id) => Some(fn_def_id), + Term::Fn(fn_def_id) => Some(fn_def_id), _ => None, }, Atom::FnDef(fn_def_id) => Some(fn_def_id), @@ -286,7 +286,7 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { match atom { Atom::Term(term) => match *term.value() { // Never has effects - Term::Hole(_) | Term::FnRef(_) => Ok(ControlFlow::Break(())), + Term::Hole(_) | Term::Fn(_) => Ok(ControlFlow::Break(())), // These have effects if their constituents do Term::Lit(_) @@ -308,7 +308,7 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { | Term::FnTy(_) | Term::Block(_) => Ok(ControlFlow::Continue(())), - Term::FnCall(fn_call) => { + Term::Call(fn_call) => { // Get its inferred type and check if it is pure match self.try_get_inferred_ty(fn_call.subject) { Some(fn_ty) => { @@ -721,14 +721,14 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { } /// Evaluate a function call. - fn eval_fn_call(&self, mut fn_call: Node) -> AtomEvaluation { + fn eval_fn_call(&self, mut fn_call: Node) -> AtomEvaluation { let st = eval_state(); fn_call.subject = self.to_term(self.eval_and_record(fn_call.subject.into(), &st)?); fn_call.args = st.update_from_evaluation(fn_call.args, self.eval_args(fn_call.args))?; // Beta-reduce: - if let Term::FnRef(fn_def_id) = *fn_call.subject.value() { + if let Term::Fn(fn_def_id) = *fn_call.subject.value() { let fn_def = fn_def_id.value(); if (fn_def.ty.pure || matches!(self.mode.get(), NormalisationMode::Full)) && self.try_get_inferred_ty(fn_def_id).is_some() @@ -844,7 +844,7 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { Term::TypeOf(term) => ctrl_map(self.eval_type_of(term)), Term::Unsafe(unsafe_expr) => ctrl_map(self.eval_unsafe(unsafe_expr)), Term::Match(match_term) => ctrl_map(self.eval_match(match_term)), - Term::FnCall(fn_call) => { + Term::Call(fn_call) => { ctrl_map(self.eval_fn_call(term.origin().with_data(fn_call))) } Term::Cast(cast_term) => ctrl_map(self.eval_cast(cast_term)), @@ -857,7 +857,7 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { // Introduction forms: Term::Ref(_) - | Term::FnRef(_) + | Term::Fn(_) | Term::Lit(_) | Term::Array(_) | Term::Tuple(_) diff --git a/compiler/hash-typecheck/src/unification.rs b/compiler/hash-typecheck/src/unification.rs index ced87b04a..1e839a81e 100644 --- a/compiler/hash-typecheck/src/unification.rs +++ b/compiler/hash-typecheck/src/unification.rs @@ -7,7 +7,7 @@ use hash_tir::{ args::ArgsId, context::ScopeKind, data::DataDefCtors, - fns::{FnCallTerm, FnTy}, + fns::{CallTerm, FnTy}, holes::Hole, lits::Lit, params::ParamsId, @@ -260,7 +260,7 @@ impl<'tc, T: AccessToTypechecking> UnificationOps<'tc, T> { } } - pub fn unify_fn_calls(&self, src: FnCallTerm, target: FnCallTerm) -> TcResult<()> { + pub fn unify_fn_calls(&self, src: CallTerm, target: CallTerm) -> TcResult<()> { self.unify_terms(src.subject, target.subject)?; self.unify_args(src.args, target.args)?; Ok(()) @@ -350,13 +350,11 @@ impl<'tc, T: AccessToTypechecking> UnificationOps<'tc, T> { } (Term::Ref(_), _) | (_, Term::Ref(_)) => self.mismatching_atoms(src_id, target_id), - (Term::FnCall(c1), Term::FnCall(c2)) => self.unify_fn_calls(c1, c2), - (Term::FnCall(_), _) | (_, Term::FnCall(_)) => { - self.mismatching_atoms(src_id, target_id) - } + (Term::Call(c1), Term::Call(c2)) => self.unify_fn_calls(c1, c2), + (Term::Call(_), _) | (_, Term::Call(_)) => self.mismatching_atoms(src_id, target_id), - (Term::FnRef(f1), Term::FnRef(f2)) if f1 == f2 => Ok(()), - (Term::FnRef(_), _) | (_, Term::FnRef(_)) => self.mismatching_atoms(src_id, target_id), + (Term::Fn(f1), Term::Fn(f2)) if f1 == f2 => Ok(()), + (Term::Fn(_), _) | (_, Term::Fn(_)) => self.mismatching_atoms(src_id, target_id), // @@Todo: rest _ => self.mismatching_atoms(src_id, target_id), From 558001edcf7f9f62f63cb8091f1c4a8fd0268d0f Mon Sep 17 00:00:00 2001 From: Constantine Theocharis Date: Mon, 11 Sep 2023 10:16:02 +0000 Subject: [PATCH 2/4] Make `Decl` a special part of blocks rather than a general term --- compiler/hash-lower/src/build/block.rs | 18 +-- compiler/hash-lower/src/build/category.rs | 1 - compiler/hash-lower/src/build/into.rs | 4 - .../src/build/matches/declarations.rs | 4 +- compiler/hash-lower/src/build/place.rs | 1 - compiler/hash-lower/src/build/rvalue.rs | 1 - .../src/environment/ast_info.rs | 4 +- .../src/passes/discovery/defs.rs | 21 ++-- .../src/passes/resolution/exprs.rs | 117 ++++++++---------- compiler/hash-storage/src/store/statics.rs | 6 +- compiler/hash-tir/src/args.rs | 2 +- compiler/hash-tir/src/context.rs | 50 +++++--- compiler/hash-tir/src/control.rs | 15 +-- compiler/hash-tir/src/data.rs | 2 +- compiler/hash-tir/src/environment/stores.rs | 4 +- compiler/hash-tir/src/lits.rs | 35 +++--- compiler/hash-tir/src/mods.rs | 2 +- compiler/hash-tir/src/params.rs | 9 +- compiler/hash-tir/src/scopes.rs | 72 ++++++++--- compiler/hash-tir/src/terms.rs | 8 +- compiler/hash-tir/src/utils/traversing.rs | 96 ++++++++------ compiler/hash-typecheck/src/inference.rs | 63 +++++----- compiler/hash-typecheck/src/normalisation.rs | 81 ++++++------ compiler/hash-typecheck/src/substitution.rs | 4 +- 24 files changed, 329 insertions(+), 291 deletions(-) diff --git a/compiler/hash-lower/src/build/block.rs b/compiler/hash-lower/src/build/block.rs index f92b4dbbe..14b681bf2 100644 --- a/compiler/hash-lower/src/build/block.rs +++ b/compiler/hash-lower/src/build/block.rs @@ -12,7 +12,8 @@ use hash_storage::store::{statics::StoreId, TrivialSequenceStoreKey}; use hash_tir::{ context::{Context, ScopeKind}, control::{LoopTerm, MatchTerm}, - scopes::BlockTerm, + node::HasAstNodeId, + scopes::{BlockStatement, BlockTerm}, terms::{Term, TermId}, }; @@ -30,7 +31,7 @@ impl<'tcx> BodyBuilder<'tcx> { match *block_term.value() { Term::Block(ref body) => self.body_block_into_dest(place, block, body), - Term::Loop(LoopTerm { block: ref body }) => { + Term::Loop(LoopTerm { inner }) => { // Begin the loop block by connecting the previous block // and terminating it with a `goto` instruction to this block let loop_body = self.control_flow_graph.start_new_block(); @@ -43,8 +44,7 @@ impl<'tcx> BodyBuilder<'tcx> { // always going to be `()` let tmp_place = this.make_tmp_unit(); - let body_block_end = - unpack!(this.body_block_into_dest(tmp_place, loop_body, body)); + let body_block_end = unpack!(this.term_into_dest(tmp_place, loop_body, inner)); // In the situation that we have the final statement in the loop, this // block should go back to the start of the loop... @@ -68,7 +68,7 @@ impl<'tcx> BodyBuilder<'tcx> { mut block: BasicBlock, body: &BlockTerm, ) -> BlockAnd<()> { - let BlockTerm { stack_id, statements, return_value } = body; + let BlockTerm { stack_id, statements, expr } = body; if self.reached_terminator { return block.unit(); @@ -84,11 +84,11 @@ impl<'tcx> BodyBuilder<'tcx> { // We need to handle declarations here specifically, otherwise // in order to not have to create a temporary for the declaration // which doesn't make sense because we are just declaring a local(s) - Term::Decl(decl) => { - let span = this.span_of_term(statement); + BlockStatement::Decl(decl) => { + let span = statement.node_id_or_default(); unpack!(block = this.lower_declaration(block, &decl, span)); } - _ => { + BlockStatement::Expr(statement) => { // @@Investigate: do we need to deal with the temporary here? unpack!( block = this.term_into_temp(block, statement, Mutability::Immutable) @@ -100,7 +100,7 @@ impl<'tcx> BodyBuilder<'tcx> { // If this block has an expression, we need to deal with it since // it might change the destination of this block. if !this.reached_terminator { - unpack!(block = this.term_into_dest(place, block, *return_value)); + unpack!(block = this.term_into_dest(place, block, *expr)); } }); diff --git a/compiler/hash-lower/src/build/category.rs b/compiler/hash-lower/src/build/category.rs index 990afc6b2..d5d6ded2a 100644 --- a/compiler/hash-lower/src/build/category.rs +++ b/compiler/hash-lower/src/build/category.rs @@ -60,7 +60,6 @@ impl Category { | Term::Call(..) => Category::RValue(RValueKind::Into), Term::Tuple(_) - | Term::Decl(_) | Term::Assign(_) | Term::Array(_) | Term::Cast(_) diff --git a/compiler/hash-lower/src/build/into.rs b/compiler/hash-lower/src/build/into.rs index 4b40beddb..91416e298 100644 --- a/compiler/hash-lower/src/build/into.rs +++ b/compiler/hash-lower/src/build/into.rs @@ -286,10 +286,6 @@ impl<'tcx> BodyBuilder<'tcx> { self.control_flow_graph.goto(block, return_block, span); self.control_flow_graph.start_new_block().unit() } - // For declarations, we have to perform some bookkeeping in regards - // to locals..., but this expression should never return any value - // so we should just return a unit block here - Term::Decl(ref decl) => self.lower_declaration(block, decl, span), Term::Assign(assign_term) => { // Deal with the actual assignment block = unpack!(self.lower_assign_term(block, assign_term, span)); diff --git a/compiler/hash-lower/src/build/matches/declarations.rs b/compiler/hash-lower/src/build/matches/declarations.rs index 1e158a4b6..373cbfa56 100644 --- a/compiler/hash-lower/src/build/matches/declarations.rs +++ b/compiler/hash-lower/src/build/matches/declarations.rs @@ -15,7 +15,7 @@ use hash_tir::{ environment::env::AccessToEnv, node::NodesId, pats::{Pat, PatId}, - scopes::{BindingPat, DeclTerm}, + scopes::{BindingPat, Decl}, symbols::SymbolId, terms::TermId, tuples::TuplePat, @@ -46,7 +46,7 @@ impl<'tcx> BodyBuilder<'tcx> { pub(crate) fn lower_declaration( &mut self, mut block: BasicBlock, - decl: &DeclTerm, + decl: &Decl, decl_origin: AstNodeId, ) -> BlockAnd<()> { if let Some(value) = &decl.value { diff --git a/compiler/hash-lower/src/build/place.rs b/compiler/hash-lower/src/build/place.rs index 462345f0e..ba8ac1b85 100644 --- a/compiler/hash-lower/src/build/place.rs +++ b/compiler/hash-lower/src/build/place.rs @@ -152,7 +152,6 @@ impl<'tcx> BodyBuilder<'tcx> { | Term::LoopControl(_) | Term::Match(_) | Term::Return(_) - | Term::Decl(_) | Term::Assign(_) | Term::Unsafe(_) | Term::Cast(_) diff --git a/compiler/hash-lower/src/build/rvalue.rs b/compiler/hash-lower/src/build/rvalue.rs index fd4116b5c..c10a7574c 100644 --- a/compiler/hash-lower/src/build/rvalue.rs +++ b/compiler/hash-lower/src/build/rvalue.rs @@ -120,7 +120,6 @@ impl<'tcx> BodyBuilder<'tcx> { | Term::LoopControl(_) | Term::Match(_) | Term::Return(_) - | Term::Decl(_) | Term::Assign(_) | Term::Unsafe(_) | Term::Access(_) diff --git a/compiler/hash-semantics/src/environment/ast_info.rs b/compiler/hash-semantics/src/environment/ast_info.rs index 4c77f4cb2..0b22c9e84 100644 --- a/compiler/hash-semantics/src/environment/ast_info.rs +++ b/compiler/hash-semantics/src/environment/ast_info.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, fmt::Debug, hash::Hash}; use hash_ast::ast::AstNodeId; use hash_tir::{ args::{ArgId, ArgsSeqId, PatArgId, PatArgsSeqId}, - context::Decl, + context::ContextMember, data::{CtorDefId, CtorDefsSeqId, DataDefId}, fns::FnDefId, mods::{ModDefId, ModMemberId, ModMembersSeqId}, @@ -161,7 +161,7 @@ ast_info! { fn_defs: AstMap, stacks: AstMap, - stack_members: AstMap, + stack_members: AstMap, terms: AstMap, tys: AstMap, diff --git a/compiler/hash-semantics/src/passes/discovery/defs.rs b/compiler/hash-semantics/src/passes/discovery/defs.rs index 2b3520ea5..675780d64 100644 --- a/compiler/hash-semantics/src/passes/discovery/defs.rs +++ b/compiler/hash-semantics/src/passes/discovery/defs.rs @@ -9,7 +9,7 @@ use hash_storage::store::{ DefaultPartialStore, PartialStore, SequenceStoreKey, StoreKey, }; use hash_tir::{ - context::Decl, + context::ContextMember, data::{CtorDef, CtorDefData, CtorDefId, DataDefCtors, DataDefId}, environment::env::AccessToEnv, fns::FnDefId, @@ -75,7 +75,7 @@ pub(super) enum ItemId { /// contain local definitions. #[derive(Debug, Copy, Clone, From)] enum StackMemberOrModMember { - StackMember(Decl), + StackMember(ContextMember), ModMember(ModMember), } @@ -509,14 +509,15 @@ impl<'tc> DiscoveryPass<'tc> { pub(super) fn add_stack_members_in_pat_to_buf( &self, node: AstNodeRef, - buf: &mut SmallVec<[(AstNodeId, Decl); 3]>, + buf: &mut SmallVec<[(AstNodeId, ContextMember); 3]>, ) { let register_spread_pat = - |spread: &AstNode, buf: &mut SmallVec<[(AstNodeId, Decl); 3]>| { + |spread: &AstNode, + buf: &mut SmallVec<[(AstNodeId, ContextMember); 3]>| { if let Some(name) = &spread.name { buf.push(( name.id(), - Decl { + ContextMember { name: SymbolId::from_name(name.ident, NodeOrigin::Given(name.id())), ty: None, value: None, @@ -529,7 +530,7 @@ impl<'tc> DiscoveryPass<'tc> { ast::Pat::Binding(binding) => { buf.push(( node.id(), - Decl { + ContextMember { name: SymbolId::from_name( binding.name.ident, NodeOrigin::Given(binding.name.id()), @@ -584,7 +585,7 @@ impl<'tc> DiscoveryPass<'tc> { ast::Pat::If(if_pat) => self.add_stack_members_in_pat_to_buf(if_pat.pat.ast_ref(), buf), ast::Pat::Wild(_) => buf.push(( node.id(), - Decl { + ContextMember { name: SymbolId::fresh(NodeOrigin::Given(node.id())), // is_mutable: false, ty: None, @@ -622,8 +623,10 @@ impl<'tc> DiscoveryPass<'tc> { (Some(declaration_name), ast::Pat::Binding(binding_pat)) if declaration_name.borrow().name == Some(binding_pat.name.ident) => { - found_members - .push((node.id(), Decl { name: declaration_name, ty: None, value: None })) + found_members.push(( + node.id(), + ContextMember { name: declaration_name, ty: None, value: None }, + )) } _ => self.add_stack_members_in_pat_to_buf(node, &mut found_members), } diff --git a/compiler/hash-semantics/src/passes/resolution/exprs.rs b/compiler/hash-semantics/src/passes/resolution/exprs.rs index f75d0c148..f125ff08a 100644 --- a/compiler/hash-semantics/src/passes/resolution/exprs.rs +++ b/compiler/hash-semantics/src/passes/resolution/exprs.rs @@ -27,12 +27,11 @@ use hash_tir::{ environment::env::AccessToEnv, fns::{CallTerm, FnBody, FnDefId}, lits::{CharLit, FloatLit, IntLit, Lit, StrLit}, - node::{Node, NodeOrigin}, + node::{Node, NodeId, NodeOrigin}, params::ParamIndex, primitives::primitives, refs::{DerefTerm, RefKind, RefTerm}, - scopes::{AssignTerm, BlockTerm, DeclTerm, Stack}, - term_as_variant, + scopes::{AssignTerm, BlockStatement, BlockTerm, Decl}, terms::{Term, TermId, Ty, TyOfTerm, UnsafeTerm}, tuples::TupleTerm, }; @@ -139,9 +138,6 @@ impl<'tc> ResolutionPass<'tc> { ast::Expr::Macro(invocation) => { self.make_term_from_ast_macro_invocation_expr(node.with_body(invocation))? } - ast::Expr::Declaration(declaration) => { - self.make_term_from_ast_stack_declaration(node.with_body(declaration))? - } ast::Expr::Ref(ref_expr) => { self.make_term_from_ast_ref_expr(node.with_body(ref_expr))? } @@ -201,6 +197,7 @@ impl<'tc> ResolutionPass<'tc> { // No-ops (not supported or handled earlier): ast::Expr::TraitDef(_) | ast::Expr::MergeDeclaration(_) + | ast::Expr::Declaration(_) | ast::Expr::ImplDef(_) | ast::Expr::TraitImpl(_) => Term::unit(NodeOrigin::Given(node.id())), @@ -445,10 +442,10 @@ impl<'tc> ResolutionPass<'tc> { } /// Make a term from an [`ast::Declaration`] in non-constant scope. - fn make_term_from_ast_stack_declaration( + fn make_decl_from_ast_declaration( &self, node: AstNodeRef, - ) -> SemanticResult { + ) -> SemanticResult> { self.scoping().register_declaration(node); // Pattern @@ -471,10 +468,9 @@ impl<'tc> ResolutionPass<'tc> { }; match (pat, ty, value) { - (Some(pat), Some(ty), Some(value)) => Ok(Term::from( - Term::Decl(DeclTerm { bind_pat: pat, ty, value }), - NodeOrigin::Given(node.id()), - )), + (Some(pat), Some(ty), Some(value)) => { + Ok(Node::at(Decl { bind_pat: pat, ty, value }, NodeOrigin::Given(node.id()))) + } _ => { // If pat had an error, then we can't make a term, and the // error will have been added already. @@ -706,54 +702,63 @@ impl<'tc> ResolutionPass<'tc> { self.scoping().add_mod_members(local_mod_def); } - // Traverse the statements and the end expression - let statements = node + // Traverse the statements: + let mut statements = node .statements .iter() .filter(|statement| !mod_member_ids.contains(&statement.id())) .filter_map(|statement| { if let ast::Expr::Declaration(declaration) = statement.body() { + // Handle declarations using `BlockStatement::Decl` self.scoping().register_declaration(node.with_body(declaration)); + let decl = + self.try_or_add_error(self.make_decl_from_ast_declaration( + statement.with_body(declaration), + ))?; + Some(decl.with_data(BlockStatement::Decl(decl.data))) + } else { + // Everything else is `BlockStatement::Expr` + let expr = self.try_or_add_error( + self.make_term_from_ast_expr(statement.ast_ref()), + )?; + Some(Node::at(BlockStatement::Expr(expr), expr.origin())) } - self.try_or_add_error(self.make_term_from_ast_expr(statement.ast_ref())) }) .collect_vec(); - let expr = node.expr.as_ref().and_then(|expr| { - if mod_member_ids.contains(&expr.id()) { - None - } else { - Some({ - if let ast::Expr::Declaration(declaration) = expr.body() { - self.scoping().register_declaration(node.with_body(declaration)); - } + // If an expression is given, use it as the returning expression, and otherwise + // use a unit `()` as the returning expression. + let total_origin = NodeOrigin::Given(node.id()); + let empty_expr = || Term::unit(total_origin); + let expr = match node.expr.as_ref() { + Some(expr) => { + if mod_member_ids.contains(&expr.id()) { + Some(empty_expr()) + } else if let ast::Expr::Declaration(declaration) = expr.body() { + self.try_or_add_error( + self.make_decl_from_ast_declaration(expr.with_body(declaration)), + ) + .map(|decl| { + statements.push(decl.with_data(BlockStatement::Decl(decl.data))); + empty_expr() + }) + } else { self.try_or_add_error(self.make_term_from_ast_expr(expr.ast_ref())) - }) + } } - }); + None => Some(empty_expr()), + }; // If all ok, create a block term - match ( - expr, - statements.len() - == (node.statements.len().saturating_sub(mod_member_ids.len())), - ) { - (Some(Some(expr)), true) => { - let statements = - Node::create_at(TermId::seq(statements), NodeOrigin::Given(node.id())); - Ok(Term::from( - Term::Block(BlockTerm { statements, return_value: expr, stack_id }), - NodeOrigin::Given(node.id()), - )) - } - (None, true) => { + match expr { + Some(expr) => { let statements = - Node::create_at(TermId::seq(statements), NodeOrigin::Given(node.id())); - let return_value = Term::unit(NodeOrigin::Given(node.id())); - Ok(Term::from( - Term::Block(BlockTerm { statements, return_value, stack_id }), + Node::create_at(Node::seq(statements), NodeOrigin::Given(node.id())); + let result = Term::from( + Term::Block(BlockTerm { statements, expr, stack_id }), NodeOrigin::Given(node.id()), - )) + ); + Ok(result) } _ => Err(SemanticError::Signal), } @@ -772,28 +777,8 @@ impl<'tc> ResolutionPass<'tc> { &self, node: AstNodeRef, ) -> SemanticResult { - let inner = match node.contents.body() { - ast::Block::Body(body_block) => { - self.make_term_from_ast_body_block(node.contents.with_body(body_block))? - } - inner => Term::from( - BlockTerm { - return_value: self.make_term_from_ast_block(node.contents.with_body(inner))?, - statements: Node::create_at( - TermId::empty_seq(), - NodeOrigin::Given(node.contents.id()), - ), - stack_id: Stack::empty(NodeOrigin::Given(node.contents.id())), - }, - NodeOrigin::Given(node.contents.id()), - ), - }; - - let block = term_as_variant!(self, inner.value(), Block); - Ok(Term::from( - Term::Loop(LoopTerm { block: inner.value().with_data(block) }), - NodeOrigin::Given(node.id()), - )) + let inner = self.make_term_from_ast_block(node.contents.ast_ref())?; + Ok(Term::from(Term::Loop(LoopTerm { inner }), NodeOrigin::Given(node.id()))) } /// Make a term from an [`ast::Block`]. diff --git a/compiler/hash-storage/src/store/statics.rs b/compiler/hash-storage/src/store/statics.rs index 9ebd21f34..89f1524b1 100644 --- a/compiler/hash-storage/src/store/statics.rs +++ b/compiler/hash-storage/src/store/statics.rs @@ -542,17 +542,17 @@ macro_rules! static_sequence_store_direct { } fn map(self, f: impl FnOnce(&Self::ValueRef) -> R) -> R { - use $crate::store::Store; + use $crate::store::SequenceStore; $store_source.$store_name().map_fast(self.0, |v| f(&v[self.1])) } fn modify(self, f: impl FnOnce(&mut Self::ValueRef) -> R) -> R { - use $crate::store::Store; + use $crate::store::SequenceStore; $store_source.$store_name().modify_fast(self.0, |v| f(&mut v[self.1])) } fn set(self, value: Self::Value) { - use $crate::store::Store; + use $crate::store::SequenceStore; $store_source.$store_name().set_at_index(self.0, self.1, value); } } diff --git a/compiler/hash-tir/src/args.rs b/compiler/hash-tir/src/args.rs index 4ba9b442e..516e9f8f9 100644 --- a/compiler/hash-tir/src/args.rs +++ b/compiler/hash-tir/src/args.rs @@ -5,7 +5,7 @@ use std::{fmt::Debug, option::Option}; use hash_storage::store::{ statics::{SequenceStoreValue, SingleStoreValue, StoreId}, - SequenceStore, SequenceStoreKey, TrivialSequenceStoreKey, + SequenceStoreKey, TrivialSequenceStoreKey, }; use hash_utils::{derive_more::From, itertools::Itertools}; diff --git a/compiler/hash-tir/src/context.rs b/compiler/hash-tir/src/context.rs index adaedb00b..c232ce907 100644 --- a/compiler/hash-tir/src/context.rs +++ b/compiler/hash-tir/src/context.rs @@ -30,7 +30,7 @@ use crate::{ /// A binding that contains a type and optional value. #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct Decl { +pub struct ContextMember { pub name: SymbolId, pub ty: Option, pub value: Option, @@ -70,7 +70,7 @@ pub struct Scope { /// The kind of the scope. pub kind: ScopeKind, /// The bindings of the scope - pub decls: RefCell>, + pub decls: RefCell>, } impl Scope { @@ -80,19 +80,23 @@ impl Scope { } /// Add a binding to the scope. - pub fn add_decl(&self, decl: Decl) { + pub fn add_decl(&self, decl: ContextMember) { self.decls.borrow_mut().insert(decl.name, decl); } /// Get the decl corresponding to the given symbol. - pub fn get_decl(&self, symbol: SymbolId) -> Option { + pub fn get_decl(&self, symbol: SymbolId) -> Option { self.decls.borrow().get(&symbol).copied() } /// Set an existing decl kind of the given symbol. /// /// Returns `true` if the decl was found and updated, `false` otherwise. - pub fn set_existing_decl(&self, symbol: SymbolId, f: &impl Fn(Decl) -> Decl) -> bool { + pub fn set_existing_decl( + &self, + symbol: SymbolId, + f: &impl Fn(ContextMember) -> ContextMember, + ) -> bool { if let Some(old) = self.get_decl(symbol) { self.decls.borrow_mut().insert(symbol, f(old)); true @@ -164,25 +168,25 @@ impl Context { /// Add a new decl to the current scope context. pub fn add_decl(&self, name: SymbolId, ty: Option, value: Option) { - self.get_current_scope_ref().add_decl(Decl { name, ty, value }) + self.get_current_scope_ref().add_decl(ContextMember { name, ty, value }) } /// Get a decl from the context, reading all accessible scopes. - pub fn try_get_decl(&self, name: SymbolId) -> Option { + pub fn try_get_decl(&self, name: SymbolId) -> Option { self.scopes.borrow().iter().rev().find_map(|scope| scope.get_decl(name)) } /// Get a decl from the context, reading all accessible scopes. /// /// Panics if the decl doesn't exist. - pub fn get_decl(&self, name: SymbolId) -> Decl { + pub fn get_decl(&self, name: SymbolId) -> ContextMember { self.try_get_decl(name) .unwrap_or_else(|| panic!("cannot find a declaration with name {}", name)) } /// Modify a decl in the context, with a function that takes the current /// decl kind and returns the new decl kind. - pub fn modify_decl_with(&self, name: SymbolId, f: impl Fn(Decl) -> Decl) { + pub fn modify_decl_with(&self, name: SymbolId, f: impl Fn(ContextMember) -> ContextMember) { let _ = self .scopes .borrow() @@ -193,7 +197,7 @@ impl Context { } /// Modify a decl in the context. - pub fn modify_decl(&self, decl: Decl) { + pub fn modify_decl(&self, decl: ContextMember) { self.modify_decl_with(decl.name, |_| decl); } @@ -254,7 +258,7 @@ impl Context { pub fn try_for_decls_of_scope_rev( &self, scope_index: usize, - mut f: impl FnMut(&Decl) -> Result<(), E>, + mut f: impl FnMut(&ContextMember) -> Result<(), E>, ) -> Result<(), E> { self.scopes.borrow()[scope_index] .decls @@ -269,7 +273,7 @@ impl Context { pub fn try_for_decls_of_scope( &self, scope_index: usize, - mut f: impl FnMut(&Decl) -> Result<(), E>, + mut f: impl FnMut(&ContextMember) -> Result<(), E>, ) -> Result<(), E> { self.scopes.borrow()[scope_index].decls.borrow().iter().try_for_each(|(_, decl)| f(decl)) } @@ -282,7 +286,7 @@ impl Context { /// Iterate over all the decls in the context for the scope with the /// given index (reversed). - pub fn for_decls_of_scope_rev(&self, scope_index: usize, mut f: impl FnMut(&Decl)) { + pub fn for_decls_of_scope_rev(&self, scope_index: usize, mut f: impl FnMut(&ContextMember)) { let _ = self.try_for_decls_of_scope_rev(scope_index, |decl| -> Result<(), Infallible> { f(decl); Ok(()) @@ -291,7 +295,7 @@ impl Context { /// Iterate over all the decls in the context for the scope with the /// given index. - pub fn for_decls_of_scope(&self, scope_index: usize, mut f: impl FnMut(&Decl)) { + pub fn for_decls_of_scope(&self, scope_index: usize, mut f: impl FnMut(&ContextMember)) { let _ = self.try_for_decls_of_scope(scope_index, |decl| -> Result<(), Infallible> { f(decl); Ok(()) @@ -340,12 +344,20 @@ impl Context { /// Add a typing binding to the closest stack scope. pub fn add_assignment_to_closest_stack(&self, name: SymbolId, ty: TyId, value: TermId) { - self.get_closest_stack_scope_ref().add_decl(Decl { name, ty: Some(ty), value: Some(value) }) + self.get_closest_stack_scope_ref().add_decl(ContextMember { + name, + ty: Some(ty), + value: Some(value), + }) } /// Add a typing binding to the closest stack scope. pub fn add_typing_to_closest_stack(&self, name: SymbolId, ty: TyId) { - self.get_closest_stack_scope_ref().add_decl(Decl { name, ty: Some(ty), value: None }) + self.get_closest_stack_scope_ref().add_decl(ContextMember { + name, + ty: Some(ty), + value: None, + }) } /// Add a typing binding. @@ -361,13 +373,13 @@ impl Context { /// Modify the type of an assignment binding. pub fn modify_typing(&self, name: SymbolId, new_ty: TyId) { let current_value = self.try_get_decl_value(name); - self.modify_decl(Decl { name, ty: Some(new_ty), value: current_value }) + self.modify_decl(ContextMember { name, ty: Some(new_ty), value: current_value }) } /// Modify the value of an assignment binding. pub fn modify_assignment(&self, name: SymbolId, new_value: TermId) { let current_ty = self.try_get_decl_ty(name); - self.modify_decl(Decl { name, ty: current_ty, value: Some(new_value) }) + self.modify_decl(ContextMember { name, ty: current_ty, value: Some(new_value) }) } /// Add parameter bindings from the given parameters. @@ -529,7 +541,7 @@ impl fmt::Display for EqualityJudgement { } } -impl fmt::Display for Decl { +impl fmt::Display for ContextMember { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let ty_or_unknown = { if let Some(ty) = self.ty { diff --git a/compiler/hash-tir/src/control.rs b/compiler/hash-tir/src/control.rs index 23a68bf3d..0d25902f9 100644 --- a/compiler/hash-tir/src/control.rs +++ b/compiler/hash-tir/src/control.rs @@ -4,9 +4,7 @@ use core::fmt; use std::fmt::Debug; use hash_ast::ast::MatchOrigin; -use hash_storage::store::{ - statics::StoreId, SequenceStore, SequenceStoreKey, TrivialSequenceStoreKey, -}; +use hash_storage::store::{statics::StoreId, SequenceStoreKey, TrivialSequenceStoreKey}; use textwrap::indent; use super::{ @@ -14,20 +12,17 @@ use super::{ scopes::StackId, terms::Term, }; -use crate::{ - environment::stores::tir_stores, node::Node, scopes::BlockTerm, terms::TermId, - tir_node_sequence_store_direct, -}; +use crate::{environment::stores::tir_stores, terms::TermId, tir_node_sequence_store_direct}; /// A loop term. /// -/// Contains a block. +/// Contains an inner term which should produce side-effects. /// /// The type of a loop is `void`, unless it can be proven to never terminate (in /// which case it is `never`). #[derive(Debug, Clone, Copy)] pub struct LoopTerm { - pub block: Node, + pub inner: TermId, } /// A match term. @@ -106,7 +101,7 @@ pub struct OrPat { impl fmt::Display for LoopTerm { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "loop {}", &*self.block) + write!(f, "loop {}", self.inner) } } diff --git a/compiler/hash-tir/src/data.rs b/compiler/hash-tir/src/data.rs index 803a3e4ce..bb0a35fa5 100644 --- a/compiler/hash-tir/src/data.rs +++ b/compiler/hash-tir/src/data.rs @@ -7,7 +7,7 @@ use hash_storage::{ get, store::{ statics::{SequenceStoreValue, SingleStoreValue, StoreId}, - SequenceStore, SequenceStoreKey, TrivialSequenceStoreKey, + SequenceStoreKey, TrivialSequenceStoreKey, }, }; use hash_utils::itertools::Itertools; diff --git a/compiler/hash-tir/src/environment/stores.rs b/compiler/hash-tir/src/environment/stores.rs index cce5b8b24..9330a2775 100644 --- a/compiler/hash-tir/src/environment/stores.rs +++ b/compiler/hash-tir/src/environment/stores.rs @@ -15,7 +15,7 @@ use crate::{ mods::{ModDefStore, ModMembersSeqStore, ModMembersStore}, params::{ParamsSeqStore, ParamsStore}, pats::{PatListSeqStore, PatListStore, PatStore}, - scopes::StackStore, + scopes::{BlockStatementsSeqStore, BlockStatementsStore, StackStore}, symbols::SymbolStore, terms::{TermListSeqStore, TermListStore, TermStore}, }; @@ -50,6 +50,8 @@ stores! { term_list_seq: TermListSeqStore, match_cases: MatchCasesStore, match_cases_seq: MatchCasesSeqStore, + block_statements: BlockStatementsStore, + block_statements_seq: BlockStatementsSeqStore, atom_info: AtomInfoStore, } diff --git a/compiler/hash-tir/src/lits.rs b/compiler/hash-tir/src/lits.rs index 34f1add7f..ee2534cc7 100644 --- a/compiler/hash-tir/src/lits.rs +++ b/compiler/hash-tir/src/lits.rs @@ -324,20 +324,27 @@ impl Display for LitPat { // and `MAX` for these situations since it is easier for the // user to understand the problem. Lit::Int(lit) => { - let value = lit.value.value(); - let kind = value.map(|constant| constant.ty()); - - // ##Hack: we don't use size since it is never invoked because of - // integer constant don't store usize values. - let dummy_size = Size::ZERO; - let value = value.map(|constant| constant.value.as_u128()); - - if kind.numeric_min(dummy_size) == value { - write!(f, "{kind}::MIN") - } else if kind.numeric_max(dummy_size) == value { - write!(f, "{kind}::MAX") - } else { - write!(f, "{lit}") + match lit.value { + LitValue::Raw(_) => { + // Defer to display impl for `Lit` below. + write!(f, "{lit}") + } + LitValue::Value(value) => { + let kind = value.map(|constant| constant.ty()); + + // ##Hack: we don't use size since it is never invoked because of + // integer constant don't store usize values. + let dummy_size = Size::ZERO; + let value = value.map(|constant| constant.value.as_u128()); + + if kind.numeric_min(dummy_size) == value { + write!(f, "{kind}::MIN") + } else if kind.numeric_max(dummy_size) == value { + write!(f, "{kind}::MAX") + } else { + write!(f, "{lit}") + } + } } } Lit::Str(lit) => write!(f, "{lit}"), diff --git a/compiler/hash-tir/src/mods.rs b/compiler/hash-tir/src/mods.rs index 535253558..c9cc82b4d 100644 --- a/compiler/hash-tir/src/mods.rs +++ b/compiler/hash-tir/src/mods.rs @@ -5,7 +5,7 @@ use std::{fmt::Display, path::Path}; use hash_source::{identifier::Identifier, SourceId}; use hash_storage::{ get, - store::{statics::StoreId, SequenceStore, Store, StoreKey, TrivialSequenceStoreKey}, + store::{statics::StoreId, Store, StoreKey, TrivialSequenceStoreKey}, }; use textwrap::indent; use utility_types::omit; diff --git a/compiler/hash-tir/src/params.rs b/compiler/hash-tir/src/params.rs index d69aa9a17..d35340ca7 100644 --- a/compiler/hash-tir/src/params.rs +++ b/compiler/hash-tir/src/params.rs @@ -5,7 +5,7 @@ use std::{fmt::Debug, option::Option}; use hash_source::identifier::Identifier; use hash_storage::store::{ statics::{SequenceStoreValue, SingleStoreValue, StoreId}, - SequenceStore, SequenceStoreKey, TrivialSequenceStoreKey, + SequenceStoreKey, TrivialSequenceStoreKey, }; use hash_utils::{derive_more::From, itertools::Itertools}; @@ -258,8 +258,11 @@ impl fmt::Display for Param { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "{}: {}{}", - self.name, + "{}{}{}", + match self.name.value().name { + Some(name) => format!("{}: ", name), + None => "".to_string(), + }, self.ty, if let Some(default) = self.default { format!(" = {}", default) diff --git a/compiler/hash-tir/src/scopes.rs b/compiler/hash-tir/src/scopes.rs index 7c00ceec0..d96fac259 100644 --- a/compiler/hash-tir/src/scopes.rs +++ b/compiler/hash-tir/src/scopes.rs @@ -18,14 +18,14 @@ use utility_types::omit; use super::{pats::Pat, terms::Term}; use crate::{ - context::Decl, + context::ContextMember, environment::stores::tir_stores, mods::ModDefId, node::{Node, NodeOrigin}, pats::PatId, symbols::SymbolId, - terms::{TermId, TermListId, TyId}, - tir_node_single_store, + terms::{TermId, TyId}, + tir_node_sequence_store_direct, tir_node_single_store, }; /// A binding pattern, which is essentially a declaration left-hand side. @@ -76,7 +76,7 @@ impl StackIndices { /// Depending on the `bind_pat` used, this can be used to declare a single or /// multiple variables. #[derive(Debug, Clone, Copy)] -pub struct DeclTerm { +pub struct Decl { pub bind_pat: PatId, pub ty: TyId, pub value: Option, @@ -102,7 +102,7 @@ pub struct StackMember { /// A stack, which is a list of stack members. #[derive(Debug, Clone)] pub struct Stack { - pub members: Vec, + pub members: Vec, /// Local module definition containing members that are defined in this /// stack. pub local_mod_def: Option, @@ -122,10 +122,10 @@ pub struct StackMemberId(pub StackId, pub usize); impl SingleStoreId for StackMemberId {} impl StoreId for StackMemberId { - type Value = Decl; - type ValueRef = Decl; - type ValueBorrow = MappedRwLockReadGuard<'static, Decl>; - type ValueBorrowMut = MappedRwLockWriteGuard<'static, Decl>; + type Value = ContextMember; + type ValueRef = ContextMember; + type ValueBorrow = MappedRwLockReadGuard<'static, ContextMember>; + type ValueBorrowMut = MappedRwLockWriteGuard<'static, ContextMember>; fn borrow(self) -> Self::ValueBorrow { MappedRwLockReadGuard::map(self.0.borrow(), |stack| &stack.members[self.1]) @@ -158,17 +158,28 @@ impl StoreId for StackMemberId { #[derive(Debug, Clone, Copy)] pub struct BlockTerm { pub stack_id: StackId, // The associated stack ID for this block. - pub statements: TermListId, - pub return_value: TermId, + pub statements: BlockStatementsId, + pub expr: TermId, } +/// A statement in a block. +/// +/// This is either an expression, or a declaration. +#[derive(Debug, Clone, Copy)] +pub enum BlockStatement { + Decl(Decl), + Expr(TermId), +} + +tir_node_sequence_store_direct!(BlockStatement); + impl fmt::Display for BindingPat { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}{}", if self.is_mutable { "mut " } else { "" }, self.name) } } -impl fmt::Display for DeclTerm { +impl fmt::Display for Decl { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let value = match self.value { Some(term_id) => { @@ -239,6 +250,34 @@ impl fmt::Display for StackId { } } +impl fmt::Display for BlockStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BlockStatement::Decl(d) => { + write!(f, "{}", d) + } + BlockStatement::Expr(e) => { + write!(f, "{}", e) + } + } + } +} + +impl fmt::Display for BlockStatementId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", *self.value()) + } +} + +impl fmt::Display for BlockStatementsId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for term in self.iter() { + writeln!(f, "{};", term)?; + } + Ok(()) + } +} + impl fmt::Display for BlockTerm { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "{{")?; @@ -251,13 +290,8 @@ impl fmt::Display for BlockTerm { write!(f, "{}", indent(&members, " "))?; } - for term in self.statements.iter() { - let term = term.to_string(); - writeln!(f, "{};", indent(&term, " "))?; - } - - let return_value = (self.return_value).to_string(); - writeln!(f, "{}", indent(&return_value, " "))?; + write!(f, "{}", indent(&self.statements.to_string(), " "))?; + writeln!(f, "{}", indent(&self.expr.to_string(), " "))?; write!(f, "}}") } diff --git a/compiler/hash-tir/src/terms.rs b/compiler/hash-tir/src/terms.rs index 16729b120..0d8d4ad07 100644 --- a/compiler/hash-tir/src/terms.rs +++ b/compiler/hash-tir/src/terms.rs @@ -23,7 +23,7 @@ use crate::{ params::Param, primitives::primitives, refs::{DerefTerm, RefTerm, RefTy}, - scopes::{AssignTerm, BlockTerm, DeclTerm}, + scopes::{AssignTerm, BlockTerm}, tir_node_sequence_store_indirect, tir_node_single_store, tuples::{TupleTerm, TupleTy}, utils::traversing::Atom, @@ -79,8 +79,7 @@ pub enum Term { Match(MatchTerm), Return(ReturnTerm), - // Declarations and assignments - Decl(DeclTerm), + // Assignments Assign(AssignTerm), // Unsafe @@ -251,9 +250,6 @@ impl fmt::Display for Term { } Term::Match(match_term) => write!(f, "{}", match_term), Term::Return(return_term) => write!(f, "{}", return_term), - Term::Decl(decl_stack_member_term) => { - write!(f, "{}", decl_stack_member_term) - } Term::Assign(assign_term) => write!(f, "{}", assign_term), Term::Unsafe(unsafe_term) => write!(f, "{}", unsafe_term), Term::Access(access_term) => write!(f, "{}", access_term), diff --git a/compiler/hash-tir/src/utils/traversing.rs b/compiler/hash-tir/src/utils/traversing.rs index 9c3e4ea74..c790e943d 100644 --- a/compiler/hash-tir/src/utils/traversing.rs +++ b/compiler/hash-tir/src/utils/traversing.rs @@ -23,7 +23,7 @@ use crate::{ params::{Param, ParamsId}, pats::{Pat, PatId, PatListId}, refs::{DerefTerm, RefTerm, RefTy}, - scopes::{AssignTerm, BlockTerm, DeclTerm}, + scopes::{AssignTerm, BlockStatement, BlockStatementsId, BlockTerm, Decl}, terms::{Term, TermId, TermListId, Ty, TyOfTerm, UnsafeTerm}, tuples::{TuplePat, TupleTerm, TupleTy}, }; @@ -149,30 +149,17 @@ impl TraversingUtils { Ok(Term::from(Term::Fn(fn_def_id), origin)) } Term::Block(block_term) => { - let statements = self.fmap_term_list(block_term.statements, f)?; - let return_value = self.fmap_term(block_term.return_value, f)?; + let statements = self.fmap_block_statements(block_term.statements, f)?; + let expr = self.fmap_term(block_term.expr, f)?; Ok(Term::from( - BlockTerm { statements, return_value, stack_id: block_term.stack_id }, + BlockTerm { statements, stack_id: block_term.stack_id, expr }, origin, )) } Term::Var(var_term) => Ok(Term::from(var_term, origin)), Term::Loop(loop_term) => { - let statements = self.fmap_term_list(loop_term.block.statements, f)?; - let return_value = self.fmap_term(loop_term.block.return_value, f)?; - Ok(Term::from( - LoopTerm { - block: Node::at( - BlockTerm { - statements, - return_value, - stack_id: loop_term.block.stack_id, - }, - loop_term.block.origin, - ), - }, - origin, - )) + let inner = self.fmap_term(loop_term.inner, f)?; + Ok(Term::from(LoopTerm { inner }, origin)) } Term::LoopControl(loop_control_term) => Ok(Term::from(loop_control_term, origin)), Term::Match(match_term) => { @@ -207,13 +194,6 @@ impl TraversingUtils { let expression = self.fmap_term(return_term.expression, f)?; Ok(Term::from(ReturnTerm { expression }, origin)) } - Term::Decl(decl_stack_member_term) => { - let bind_pat = self.fmap_pat(decl_stack_member_term.bind_pat, f)?; - let ty = self.fmap_term(decl_stack_member_term.ty, f)?; - let value = - decl_stack_member_term.value.map(|v| self.fmap_term(v, f)).transpose()?; - Ok(Term::from(DeclTerm { ty, bind_pat, value }, origin)) - } Term::Assign(assign_term) => { let subject = self.fmap_term(assign_term.subject, f)?; let value = self.fmap_term(assign_term.value, f)?; @@ -336,6 +316,34 @@ impl TraversingUtils { Ok(result) } + pub fn fmap_block_statements>( + &self, + block_statements: BlockStatementsId, + f: F, + ) -> Result { + let mut new_list = Vec::with_capacity(block_statements.len()); + for statement in block_statements.elements().value() { + match *statement { + BlockStatement::Decl(decl) => { + let bind_pat = self.fmap_pat(decl.bind_pat, f)?; + let ty = self.fmap_term(decl.ty, f)?; + let value = decl.value.map(|v| self.fmap_term(v, f)).transpose()?; + new_list.push(Node::at( + BlockStatement::Decl(Decl { ty, bind_pat, value }), + statement.origin, + )); + } + BlockStatement::Expr(expr) => { + new_list.push(Node::at( + BlockStatement::Expr(self.fmap_term(expr, f)?), + statement.origin, + )); + } + } + } + Ok(Node::create_at(Node::seq(new_list), block_statements.origin())) + } + pub fn fmap_term_list>( &self, term_list: TermListId, @@ -491,14 +499,11 @@ impl TraversingUtils { } Term::Fn(fn_def_id) => self.visit_fn_def(fn_def_id, f), Term::Block(block_term) => { - self.visit_term_list(block_term.statements, f)?; - self.visit_term(block_term.return_value, f) + self.visit_block_statements(block_term.statements, f)?; + self.visit_term(block_term.expr, f) } Term::Var(_) => Ok(()), - Term::Loop(loop_term) => { - self.visit_term_list(loop_term.block.statements, f)?; - self.visit_term(loop_term.block.return_value, f) - } + Term::Loop(loop_term) => self.visit_term(loop_term.inner, f), Term::LoopControl(_) => Ok(()), Term::Match(match_term) => { self.visit_term(match_term.subject, f)?; @@ -509,13 +514,6 @@ impl TraversingUtils { Ok(()) } Term::Return(return_term) => self.visit_term(return_term.expression, f), - Term::Decl(decl_stack_member_term) => { - self.visit_pat(decl_stack_member_term.bind_pat, f)?; - self.visit_term(decl_stack_member_term.ty, f)?; - let (Some(()) | None) = - decl_stack_member_term.value.map(|v| self.visit_term(v, f)).transpose()?; - Ok(()) - } Term::Assign(assign_term) => { self.visit_term(assign_term.subject, f)?; self.visit_term(assign_term.value, f) @@ -611,6 +609,26 @@ impl TraversingUtils { Ok(()) } + pub fn visit_block_statements>( + &self, + block_statements: BlockStatementsId, + f: &mut F, + ) -> Result<(), E> { + for statement in block_statements.elements().value() { + match *statement { + BlockStatement::Decl(decl) => { + self.visit_pat(decl.bind_pat, f)?; + self.visit_term(decl.ty, f)?; + decl.value.map(|v| self.visit_term(v, f)).transpose()?; + } + BlockStatement::Expr(expr) => { + self.visit_term(expr, f)?; + } + } + } + Ok(()) + } + pub fn visit_pat_list>( &self, pat_list_id: PatListId, diff --git a/compiler/hash-typecheck/src/inference.rs b/compiler/hash-typecheck/src/inference.rs index e0eb23839..dcea29c0a 100644 --- a/compiler/hash-typecheck/src/inference.rs +++ b/compiler/hash-typecheck/src/inference.rs @@ -34,7 +34,7 @@ use hash_tir::{ pats::{Pat, PatId, PatListId, RangePat, Spread}, primitives::primitives, refs::{DerefTerm, RefTerm, RefTy}, - scopes::{AssignTerm, BlockTerm, DeclTerm}, + scopes::{AssignTerm, BlockStatement, BlockTerm}, sub::Sub, symbols::SymbolId, term_as_variant, @@ -1051,6 +1051,7 @@ impl InferenceOps<'_, T> { self.check_by_unify(fn_ty_id, annotation_ty)?; self.register_atom_inference(fn_def_id, fn_def_id, fn_def.ty); + Ok(()) } @@ -1143,11 +1144,7 @@ impl InferenceOps<'_, T> { original_term_id: TermId, ) -> TcResult<()> { // Forward to the inner term. - self.infer_block_term( - &loop_term.block, - Ty::hole(loop_term.block.origin), - original_term_id, - )?; + self.infer_term(loop_term.inner, Ty::hole(loop_term.inner.origin().inferred()))?; let loop_term = Ty::expect_is(original_term_id, Ty::unit_ty(original_term_id.origin().inferred())); self.check_by_unify(loop_term, annotation_ty)?; @@ -1175,11 +1172,30 @@ impl InferenceOps<'_, T> { let mut diverges = false; for statement in block_term.statements.iter() { - let statement_ty = Ty::hole_for(statement); - self.infer_term(statement, statement_ty)?; + let ty_to_check_divergence = match *statement.value() { + BlockStatement::Decl(decl) => { + self.check_ty(decl.ty)?; + if let Some(value) = decl.value { + self.infer_term(value, decl.ty)?; + }; + self.infer_pat(decl.bind_pat, decl.ty, decl.value)?; + + // Check that the binding pattern of the declaration is irrefutable. + let eck = self.exhaustiveness_checker(decl.bind_pat); + eck.is_pat_irrefutable(&[decl.bind_pat], decl.ty, None); + self.append_exhaustiveness_diagnostics(eck); + + decl.ty + } + BlockStatement::Expr(expr) => { + let statement_ty = Ty::hole_for(expr); + self.infer_term(expr, statement_ty)?; + statement_ty + } + }; // If the statement diverges, we can already exit - if self.uni_ops().is_uninhabitable(statement_ty)? { + if self.uni_ops().is_uninhabitable(ty_to_check_divergence)? { diverges = true; } } @@ -1193,13 +1209,13 @@ impl InferenceOps<'_, T> { } _ => { // Infer the return value - let return_value_ty = Ty::hole_for(block_term.return_value); - self.infer_term(block_term.return_value, return_value_ty)?; + let return_value_ty = Ty::hole_for(block_term.expr); + self.infer_term(block_term.expr, return_value_ty)?; } } } else { // Infer the return value - self.infer_term(block_term.return_value, annotation_ty)?; + self.infer_term(block_term.expr, annotation_ty)?; }; let sub = self.sub_ops().create_sub_from_current_scope(); @@ -1274,28 +1290,6 @@ impl InferenceOps<'_, T> { Ok(()) } - /// Infer a stack declaration term, and return its type. - pub fn infer_decl_term( - &self, - decl_term: &DeclTerm, - annotation_ty: TyId, - original_term_id: TermId, - ) -> TcResult<()> { - self.check_ty(decl_term.ty)?; - if let Some(value) = decl_term.value { - self.infer_term(value, decl_term.ty)?; - }; - self.infer_pat(decl_term.bind_pat, decl_term.ty, decl_term.value)?; - self.check_by_unify(Ty::unit_ty(original_term_id.origin().inferred()), annotation_ty)?; - - // Check that the binding pattern of the declaration is irrefutable. - let eck = self.exhaustiveness_checker(decl_term.bind_pat); - eck.is_pat_irrefutable(&[decl_term.bind_pat], decl_term.ty, None); - self.append_exhaustiveness_diagnostics(eck); - - Ok(()) - } - /// Infer an access term. pub fn infer_access_term( &self, @@ -1554,7 +1548,6 @@ impl InferenceOps<'_, T> { } Term::Ref(ref_term) => self.infer_ref_term(&ref_term, annotation_ty, term_id)?, Term::Cast(cast_term) => self.infer_cast_term(cast_term, annotation_ty)?, - Term::Decl(decl_term) => self.infer_decl_term(&decl_term, annotation_ty, term_id)?, Term::Access(access_term) => { self.infer_access_term(&access_term, annotation_ty, term_id)? } diff --git a/compiler/hash-typecheck/src/normalisation.rs b/compiler/hash-typecheck/src/normalisation.rs index d7fe5a078..e049cd046 100644 --- a/compiler/hash-typecheck/src/normalisation.rs +++ b/compiler/hash-typecheck/src/normalisation.rs @@ -22,7 +22,7 @@ use hash_tir::{ params::ParamIndex, pats::{Pat, PatId, PatListId, RangePat, Spread}, refs::DerefTerm, - scopes::{AssignTerm, BlockTerm, DeclTerm}, + scopes::{AssignTerm, BlockStatement, BlockTerm}, symbols::SymbolId, terms::{Term, TermId, TermListId, Ty, TyId, TyOfTerm, UnsafeTerm}, tuples::TupleTerm, @@ -294,7 +294,6 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { | Term::Tuple(_) | Term::Var(_) | Term::Match(_) - | Term::Decl(_) | Term::Unsafe(_) | Term::Access(_) | Term::Array(_) @@ -441,11 +440,46 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { let st = eval_state(); for statement in block_term.statements.iter() { - let _ = self.eval_and_record(statement.into(), &st)?; + match *statement.value() { + BlockStatement::Decl(mut decl_term) => { + decl_term.value = decl_term + .value + .map(|v| -> Result<_, Signal> { + Ok(self.to_term(self.eval_nested_and_record(v.into(), &st)?)) + }) + .transpose()?; + + match decl_term.value { + Some(value) => match self.match_value_and_get_binds( + value, + decl_term.bind_pat, + &mut |name, term_id| { + self.context().add_untyped_assignment(name, term_id) + }, + )? { + MatchResult::Successful => { + // All good + } + MatchResult::Failed => { + panic!("Non-exhaustive let-binding: {}", decl_term) + } + MatchResult::Stuck => { + info!("Stuck evaluating let-binding: {}", decl_term); + } + }, + None => { + panic!("Let binding with no value: {}", decl_term) + } + } + } + BlockStatement::Expr(expr) => { + let _ = self.eval_and_record(expr.into(), &st)?; + } + } } let sub = self.sub_ops().create_sub_from_current_scope(); - let result_term = self.eval_and_record(block_term.return_value.into(), &st)?; + let result_term = self.eval_and_record(block_term.expr.into(), &st)?; let subbed_result_term = self.sub_ops().apply_sub_to_atom(result_term, &sub); evaluation_to(subbed_result_term) @@ -643,40 +677,6 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { panic!("Non-exhaustive match: {}", &match_term) } - /// Evaluate a declaration term. - fn eval_decl(&self, mut decl_term: Node) -> AtomEvaluation { - let st = eval_state(); - decl_term.value = decl_term - .value - .map(|v| -> Result<_, Signal> { - Ok(self.to_term(self.eval_nested_and_record(v.into(), &st)?)) - }) - .transpose()?; - - match decl_term.value { - Some(value) => match self.match_value_and_get_binds( - value, - decl_term.bind_pat, - &mut |name, term_id| self.context().add_untyped_assignment(name, term_id), - )? { - MatchResult::Successful => { - // All good - evaluation_to(Term::unit(decl_term.origin.computed())) - } - MatchResult::Failed => { - panic!("Non-exhaustive let-binding: {}", &*decl_term) - } - MatchResult::Stuck => { - info!("Stuck evaluating let-binding: {}", &*decl_term); - evaluation_if(|| Term::from(*decl_term, decl_term.origin.computed()), &st) - } - }, - None => { - panic!("Let binding with no value: {}", &*decl_term) - } - } - } - /// Evaluate a `return` term. fn eval_return(&self, return_term: ReturnTerm) -> Result { let normalised = self.eval(return_term.expression.into())?; @@ -686,7 +686,7 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { /// Evaluate a `loop` term. fn eval_loop(&self, loop_term: Node) -> FullEvaluation { loop { - match self.eval_block(*loop_term.block) { + match self.eval(loop_term.inner.into()) { Ok(_) | Err(Signal::Continue) => continue, Err(Signal::Break) => break, Err(e) => return Err(e), @@ -868,9 +868,6 @@ impl<'tc, T: AccessToTypechecking> NormalisationOps<'tc, T> { Term::Assign(assign_term) => { ctrl_map_full(self.eval_assign(term.origin().with_data(assign_term))) } - Term::Decl(decl_term) => { - ctrl_map(self.eval_decl(term.origin().with_data(decl_term))) - } Term::Return(return_expr) => self.eval_return(return_expr)?, Term::Block(block_term) => ctrl_map(self.eval_block(block_term)), Term::Loop(loop_term) => { diff --git a/compiler/hash-typecheck/src/substitution.rs b/compiler/hash-typecheck/src/substitution.rs index 906890b48..099083f52 100644 --- a/compiler/hash-typecheck/src/substitution.rs +++ b/compiler/hash-typecheck/src/substitution.rs @@ -7,7 +7,7 @@ use hash_tir::{ access::AccessTerm, args::{ArgsId, PatArgsId}, atom_info::ItemInAtomInfo, - context::Decl, + context::ContextMember, fns::FnBody, holes::Hole, mods::ModDefId, @@ -375,7 +375,7 @@ impl<'a, T: AccessToTypechecking> SubstitutionOps<'a, T> { let _current_scope_index = self.context().get_current_scope_index(); match self.context().get_current_scope_ref().get_decl(var) { Some(var) => { - matches!(var, Decl { value: None, .. }) + matches!(var, ContextMember { value: None, .. }) } None => { warn!("Not found var {} in current scope", var); From 166512d519a89069f42250ac5647dbebdc9097b4 Mon Sep 17 00:00:00 2001 From: Alexander Fedotov Date: Mon, 11 Sep 2023 12:37:10 -0400 Subject: [PATCH 3/4] lower: fix loop lowering for new TIR `LoopTerm` structure --- compiler/hash-lower/src/build/block.rs | 1 + compiler/hash-lower/src/build/mod.rs | 9 ++++++--- .../property_access/out_of_bounds_access.stderr | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/compiler/hash-lower/src/build/block.rs b/compiler/hash-lower/src/build/block.rs index 14b681bf2..2378fb08c 100644 --- a/compiler/hash-lower/src/build/block.rs +++ b/compiler/hash-lower/src/build/block.rs @@ -52,6 +52,7 @@ impl<'tcx> BodyBuilder<'tcx> { this.control_flow_graph.goto(body_block_end, loop_body, span); } + this.reached_terminator = false; next_block.unit() }) } diff --git a/compiler/hash-lower/src/build/mod.rs b/compiler/hash-lower/src/build/mod.rs index 471ea8b15..f3cc5ff88 100644 --- a/compiler/hash-lower/src/build/mod.rs +++ b/compiler/hash-lower/src/build/mod.rs @@ -172,9 +172,12 @@ pub(crate) struct BodyBuilder<'tcx> { /// after a block terminator. loop_block_info: Option, - /// If the current [terms::BlockTerm] has reached a terminating statement, - /// i.e. a statement that is typed as `!`. Examples of such statements - /// are `return`, `break`, `continue`, etc. + /// If the lowerer has reached a terminating statement within some block, + /// meaning that further statements do not require to be lowered. + /// + /// A statement that is typed as `!`. Examples of such statements + /// are `return`, `break`, `continue`, or expressions that are of type + /// `!`. reached_terminator: bool, /// A temporary [Place] that is used to throw away results from expressions diff --git a/tests/cases/typecheck/property_access/out_of_bounds_access.stderr b/tests/cases/typecheck/property_access/out_of_bounds_access.stderr index 28bb2b59f..4020583cc 100644 --- a/tests/cases/typecheck/property_access/out_of_bounds_access.stderr +++ b/tests/cases/typecheck/property_access/out_of_bounds_access.stderr @@ -1,6 +1,6 @@ -error[0016]: property `2` not found on type `(s515: i32, s517: i32)` +error[0016]: property `2` not found on type `(i32, i32)` --> $DIR/out_of_bounds_access.hash:4:3 3 | t := (1, 2) 4 | t.2 - | ^ term has type `(s515: i32, s517: i32)`. Property `2` is not present on this type + | ^ term has type `(i32, i32)`. Property `2` is not present on this type 5 | } From 3090e40becc7211606e36bb841156a6ebbd43911 Mon Sep 17 00:00:00 2001 From: Constantine Theocharis Date: Mon, 11 Sep 2023 20:33:57 +0000 Subject: [PATCH 4/4] Fix cargo fmt --- compiler/hash-lower/src/build/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compiler/hash-lower/src/build/mod.rs b/compiler/hash-lower/src/build/mod.rs index f3cc5ff88..176f1d4e7 100644 --- a/compiler/hash-lower/src/build/mod.rs +++ b/compiler/hash-lower/src/build/mod.rs @@ -172,11 +172,11 @@ pub(crate) struct BodyBuilder<'tcx> { /// after a block terminator. loop_block_info: Option, - /// If the lowerer has reached a terminating statement within some block, + /// If the lowerer has reached a terminating statement within some block, /// meaning that further statements do not require to be lowered. - /// + /// /// A statement that is typed as `!`. Examples of such statements - /// are `return`, `break`, `continue`, or expressions that are of type + /// are `return`, `break`, `continue`, or expressions that are of type /// `!`. reached_terminator: bool,