Skip to content

Commit

Permalink
ast: introduce Params and replace AstNodes<Params> with it
Browse files Browse the repository at this point in the history
  • Loading branch information
feds01 committed Aug 31, 2023
1 parent a19b420 commit 7ff7ee7
Show file tree
Hide file tree
Showing 14 changed files with 238 additions and 293 deletions.
77 changes: 26 additions & 51 deletions compiler/hash-ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ define_tree! {
#[node]
pub struct TupleTy {
/// inner types of the tuple type
pub entries: Child!(TyParams),
pub entries: Child!(Params),
}

/// Array type, , e.g. `[T]`, `[T; N]`.
Expand All @@ -587,7 +587,7 @@ define_tree! {
#[node]
pub struct FnTy {
/// Any defined parameters for the function type
pub params: Child!(TyParams),
pub params: Child!(Params),

/// Optional return type
pub return_ty: Child!(Ty),
Expand Down Expand Up @@ -1417,7 +1417,7 @@ define_tree! {
pub ty_params: OptionalChild!(TyParams),

/// The fields of the struct, in the form of [Param].
pub fields: Children!(Param),
pub fields: Child!(Params),
}

/// A variant of an enum definition, e.g. `Some(T)`.
Expand All @@ -1428,7 +1428,7 @@ define_tree! {
pub name: Child!(Name),

/// The parameters of the enum variant, if any.
pub fields: Children!(Param),
pub fields: OptionalChild!(Params),

/// The type of the enum variant, if any.
pub ty: OptionalChild!(Ty),
Expand Down Expand Up @@ -1756,72 +1756,63 @@ define_tree! {
}
}

/// A parameter list, e.g. `(a: i32, b := 'c')`.
#[derive(Debug, PartialEq, Clone)]
#[node]
pub struct Params {
/// The parameters.
pub params: Children!(Param),

/// The origin of the type parameters.
pub origin: ParamOrigin,
}

/// This enum describes the origin kind of the subject that a parameter
/// unification occurred on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParamOrigin {
/// If at the current time, it's not known the origin of the parameter list,
/// the function will default to using this.
Unknown,

/// Parameters came from a tuple
/// Parameters came from a tuple expression or a tuple type.
Tuple,

/// Parameters came from a struct
Struct,

/// Parameters came from a function definition
/// Parameters came from a function definition or a function type.
Fn,

/// Parameters came from a function call
FnCall,

/// Parameters came from an enum variant initialisation
EnumVariant,

/// Array pattern parameters, the parameters are all the same, but it's
/// used to represent the inner terms of the array pattern since spread
/// patterns may become named parameters.
ArrayPat,

/// Module pattern
ModulePat,

/// Constructor pattern, although this is likely to be erased into a
/// [`ParamOrigin::Struct`] or [`ParamOrigin::EnumVariant`] when inspected.
ConstructorPat,
}

impl ParamOrigin {
pub fn is_item_def(&self) -> bool {
matches!(self, ParamOrigin::Fn | ParamOrigin::EnumVariant)
}

/// Get the name of the `field` that the [ParamOrigin] refers to.
/// In other words, what is the name for the parameters that are
/// associated with the [ParamOrigin].
pub fn field_name(&self) -> &'static str {
match self {
ParamOrigin::Unknown => "field",
ParamOrigin::Tuple => "field",
ParamOrigin::Struct => "field",
ParamOrigin::Fn => "parameter",
ParamOrigin::FnCall => "argument",
ParamOrigin::Tuple |
ParamOrigin::Struct |
ParamOrigin::EnumVariant => "field",
ParamOrigin::ArrayPat => "element",
ParamOrigin::ModulePat => "field",
ParamOrigin::ConstructorPat => "field",
}
}
}

impl Display for ParamOrigin {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ParamOrigin::Unknown => write!(f, "unknown"),
ParamOrigin::Tuple => write!(f, "tuple"),
ParamOrigin::Struct => write!(f, "struct"),
ParamOrigin::Fn | ParamOrigin::FnCall => write!(f, "function"),
ParamOrigin::EnumVariant => write!(f, "enum variant"),
ParamOrigin::ArrayPat => write!(f, "list pattern"),
ParamOrigin::ModulePat => write!(f, "module pattern"),
ParamOrigin::ConstructorPat => write!(f, "constructor pattern"),
}
}
}
Expand All @@ -1843,10 +1834,6 @@ define_tree! {
/// argument.
pub default: OptionalChild!(Expr),

/// The origin of the parameter, whether it is from a struct field, function
/// def, type function def, etc.
pub origin: ParamOrigin,

/// Any macros are invoked on the parameter.
pub macros: OptionalChild!(MacroInvocations),
}
Expand Down Expand Up @@ -1905,12 +1892,6 @@ define_tree! {

/// The definition is a `mod` block.
Mod,

/// Funtion type.
FnTy,

/// Tuple type.
TupleTy,
}

impl TyParamOrigin {
Expand All @@ -1923,28 +1904,22 @@ define_tree! {
TyParamOrigin::Trait => "trait",
TyParamOrigin::Impl => "impl",
TyParamOrigin::Mod => "mod",
TyParamOrigin::FnTy => "function type",
TyParamOrigin::TupleTy => "tuple type",
}
}

/// Whether the origin is either a [TyParamOrigin::TupleTy] or
/// [TyParamOrigin::FnTy].
pub fn is_fn_or_tuple_ty(&self) -> bool {
matches!(self, TyParamOrigin::TupleTy | TyParamOrigin::FnTy)
}
}

/// A function definition.
#[derive(Debug, PartialEq, Clone)]
#[node]
pub struct FnDef {
/// The parameters of the function definition.
pub params: Children!(Param),
pub params: Child!(Params),

/// The return type of the function definition.
///
/// Will be inferred if [None].
pub return_ty: OptionalChild!(Ty),

/// The body/contents of the function, in the form of an expression.
pub fn_body: Child!(Expr),
}
Expand Down
29 changes: 15 additions & 14 deletions compiler/hash-ast/src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,13 +501,22 @@ impl AstVisitor for AstTreeGenerator {

Ok(TreeNode::branch(
"function_def",
iter::once(TreeNode::branch("params", params))
iter::once(params)
.chain(return_ty.map(|r| TreeNode::branch("return_type", vec![r])))
.chain(iter::once(fn_body))
.collect(),
))
}

type ParamsRet = TreeNode;
fn visit_params(
&self,
node: ast::AstNodeRef<ast::Params>,
) -> Result<Self::TyParamsRet, Self::Error> {
let walk::Params { params } = walk::walk_params(self, node)?;
Ok(TreeNode::branch(format!("{}s", node.origin.field_name()), params))
}

type ParamRet = TreeNode;
fn visit_param(
&self,
Expand Down Expand Up @@ -860,7 +869,6 @@ impl AstVisitor for AstTreeGenerator {
) -> Result<Self::StructDefRet, Self::Error> {
let walk::StructDef { fields, ty_params } = walk::walk_struct_def(self, node)?;

let fields = TreeNode::branch("fields", fields);
let children = {
if let Some(ty_params) = ty_params && !ty_params.children.is_empty() {
vec![ty_params, fields]
Expand All @@ -879,19 +887,12 @@ impl AstVisitor for AstTreeGenerator {
) -> Result<Self::EnumDefEntryRet, Self::Error> {
let walk::EnumDefEntry { name, fields, ty, macros } =
walk::walk_enum_def_entry(self, node)?;
let mut children = Vec::new();

if !fields.is_empty() {
children.push(TreeNode::branch("fields", fields))
}

if let Some(ty) = ty {
children.push(TreeNode::branch("type", vec![ty]))
}

if let Some(macros) = macros {
children.push(macros)
}
let children = iter::once(TreeNode::leaf("variant"))
.chain(macros)
.chain(fields)
.chain(ty.map(|t| TreeNode::branch("type", vec![t])))
.collect_vec();

Ok(TreeNode::branch(labelled("variant", name.label, "\""), children))
}
Expand Down
54 changes: 31 additions & 23 deletions compiler/hash-fmt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ mod state;
use collection::CollectionPrintingOptions;
use config::AstPrintingConfig;
use hash_ast::{
ast::{self, walk_mut_self, AstVisitorMutSelf},
ast::{self, walk_mut_self, AstVisitorMutSelf, ParamOrigin},
ast_visitor_mut_self_default_impl,
};
use hash_source::constant::{IntConstant, CONSTANT_MAP};
Expand Down Expand Up @@ -176,9 +176,8 @@ where

self.visit_name(name.ast_ref())?;

if fields.len() > 0 {
let opts = CollectionPrintingOptions::delimited(Delimiter::Paren, ", ");
self.print_separated_collection(fields, opts, |this, field| this.visit_param(field))?;
if let Some(params) = fields {
self.visit_params(params.ast_ref())?;
}

if let Some(ty) = ty {
Expand Down Expand Up @@ -371,12 +370,7 @@ where
self.visit_ty_params(ty_params.ast_ref())?;
}

let mut opts = CollectionPrintingOptions::delimited(Delimiter::Paren, ", ");
opts.indented();

self.print_separated_collection(fields, opts, |this, field| this.visit_param(field))?;

Ok(())
self.visit_params(fields.ast_ref())
}

type PropertyKindRet = ();
Expand All @@ -398,7 +392,7 @@ where
node: ast::AstNodeRef<ast::TupleTy>,
) -> Result<Self::TupleTyRet, Self::Error> {
let ast::TupleTy { entries } = node.body();
self.visit_ty_params(entries.ast_ref())
self.visit_params(entries.ast_ref())
}

type ContinueStatementRet = ();
Expand Down Expand Up @@ -478,6 +472,28 @@ where
self.write(")")
}

type ParamsRet = ();
fn visit_params(
&mut self,
node: ast::AstNodeRef<ast::Params>,
) -> Result<Self::TyParamsRet, Self::Error> {
let ast::Params { params, origin } = node.body();

// Return early if no params are specified.
if params.is_empty() {
return Ok(());
}

let mut opts = CollectionPrintingOptions::delimited(Delimiter::Paren, ", ");

// @@HardCoded: Struct definition fields are indented.
if *origin == ParamOrigin::Struct {
opts.indented();
}

self.print_separated_collection(params, opts, |this, param| this.visit_param(param))
}

type ParamRet = ();

fn visit_param(
Expand Down Expand Up @@ -539,21 +555,15 @@ where
&mut self,
node: ast::AstNodeRef<ast::TyParams>,
) -> Result<Self::TyParamsRet, Self::Error> {
let ast::TyParams { params, origin } = node.body();
let ast::TyParams { params, .. } = node.body();

// Return early if no params are specified.
if params.is_empty() {
return Ok(());
}

let delimiter = if matches!(origin, ast::TyParamOrigin::FnTy | ast::TyParamOrigin::TupleTy)
{
Delimiter::Paren
} else {
Delimiter::Angle
};
let opts = CollectionPrintingOptions::delimited(Delimiter::Angle, ", ");

let opts = CollectionPrintingOptions::delimited(delimiter, ", ");
self.print_separated_collection(params, opts, |this, param| this.visit_ty_param(param))
}

Expand Down Expand Up @@ -802,8 +812,7 @@ where
) -> Result<Self::FnDefRet, Self::Error> {
let ast::FnDef { params, return_ty, fn_body } = node.body();

let opts = CollectionPrintingOptions::delimited(Delimiter::Paren, ", ");
self.print_separated_collection(params, opts, |this, param| this.visit_param(param))?;
self.visit_params(params.ast_ref())?;

if let Some(return_ty) = return_ty {
self.write(" -> ")?;
Expand Down Expand Up @@ -845,8 +854,7 @@ where
) -> Result<Self::FnTyRet, Self::Error> {
let ast::FnTy { params, return_ty } = node.body();

self.visit_ty_params(params.ast_ref())?;

self.visit_params(params.ast_ref())?;
self.write(" -> ")?;
self.visit_ty(return_ty.ast_ref())
}
Expand Down
Loading

0 comments on commit 7ff7ee7

Please sign in to comment.