From 172fa0361f85fb02b90157e864a3d78415c4a939 Mon Sep 17 00:00:00 2001 From: overlookmotel <557937+overlookmotel@users.noreply.github.com> Date: Wed, 18 Sep 2024 02:23:34 +0000 Subject: [PATCH] fix(transformer): fix stacks in arrow function transform (#5828) Push/pop to stacks in matching `enter_*` / `exit_*` visitors. The reason why there was a bug here is a little bit subtle. Between `enter_expression` and `exit_expression`, another transform can replace the `Expression` with something else. Ditto `enter_declaration` + `exit_declaration`. So pushing+popping from stacks in `enter_expression` + `exit_expression` can make the stack get out of sync. e.g.: ```rs impl<'a> Traverse for TransformerWithStack { fn enter_expression(&mut self, expr: &mut Expression<'a>, ctx: TraverseCtx<'a>) { if let Expression::FunctionExpression(_) = expr { self.stack.push(true); } } fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: TraverseCtx<'a>) { if let Expression::FunctionExpression(_) = expr { self.stack.pop(); } } } // Transformer that replaces `null` with a function expression (!) impl<'a> Traverse for SomeOtherTransformer { fn enter_expression(&mut self, expr: &mut Expression<'a>, ctx: TraverseCtx<'a>) { if let Expression::NullLiteral(_) = expr { *expr = ctx.ast.expression_function( /* ... */ ); } } } ``` `TransformerWithStack` is assuming in `exit_expression` that it previously saw a `FunctionExpression` in `enter_expression`. But `SomeOtherTransformer` has created a *new* `FunctionExpression` after `TransformerWithStack::enter_expression` ran. So in `TransformerWithStack`, `exit_expression` is called 1 more time than `enter_expression`. When `exit_expression` calls `self.stack.pop()`, `self.stack` is empty, and it panics. This example is silly, but real cases of this do exist. e.g. TS transformer creates a new functions when transforming `enum`s. `enter_function` + `exit_function` / `enter_arrow_function_expression` + `exit_arrow_function_expression` not have this problem. As we cannot mutate upwards, there are always the same number of calls to `enter_*` and `exit_*` (same for all `enter_*` / `exit_*` pairs which operate on a struct, *not an enum*). --- .../src/es2015/arrow_functions.rs | 141 +++++++++--------- crates/oxc_transformer/src/es2015/mod.rs | 24 ++- crates/oxc_transformer/src/lib.rs | 12 +- 3 files changed, 93 insertions(+), 84 deletions(-) diff --git a/crates/oxc_transformer/src/es2015/arrow_functions.rs b/crates/oxc_transformer/src/es2015/arrow_functions.rs index 021eb3f438e2d..04ed071e4a09c 100644 --- a/crates/oxc_transformer/src/es2015/arrow_functions.rs +++ b/crates/oxc_transformer/src/es2015/arrow_functions.rs @@ -91,7 +91,7 @@ impl<'a> ArrowFunctions<'a> { Self { ctx, _options: options, - // Reserve for the global scope + // Initial entry for `Program` scope this_var_stack: vec![None], inside_arrow_function_stack: vec![], } @@ -101,12 +101,19 @@ impl<'a> ArrowFunctions<'a> { impl<'a> Traverse<'a> for ArrowFunctions<'a> { /// Insert `var _this = this;` for the global scope. fn exit_program(&mut self, program: &mut Program<'a>, _ctx: &mut TraverseCtx<'a>) { - self.insert_this_var_statement_at_the_top_of_statements(&mut program.body); + debug_assert!(self.inside_arrow_function_stack.is_empty()); + + assert!(self.this_var_stack.len() == 1); + let this_var = self.this_var_stack.pop().unwrap(); + if let Some(this_var) = this_var { + self.insert_this_var_statement_at_the_top_of_statements(&mut program.body, &this_var); + } } fn enter_function(&mut self, func: &mut Function<'a>, _ctx: &mut TraverseCtx<'a>) { if func.body.is_some() { self.this_var_stack.push(None); + self.inside_arrow_function_stack.push(false); } } @@ -126,7 +133,31 @@ impl<'a> Traverse<'a> for ArrowFunctions<'a> { return; }; - self.insert_this_var_statement_at_the_top_of_statements(&mut body.statements); + let this_var = self.this_var_stack.pop().unwrap(); + if let Some(this_var) = this_var { + self.insert_this_var_statement_at_the_top_of_statements( + &mut body.statements, + &this_var, + ); + } + + self.inside_arrow_function_stack.pop().unwrap(); + } + + fn enter_arrow_function_expression( + &mut self, + _arrow: &mut ArrowFunctionExpression<'a>, + _ctx: &mut TraverseCtx<'a>, + ) { + self.inside_arrow_function_stack.push(true); + } + + fn exit_arrow_function_expression( + &mut self, + _arrow: &mut ArrowFunctionExpression<'a>, + _ctx: &mut TraverseCtx<'a>, + ) { + self.inside_arrow_function_stack.pop().unwrap(); } fn enter_jsx_element_name( @@ -160,52 +191,25 @@ impl<'a> Traverse<'a> for ArrowFunctions<'a> { } fn enter_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) { - match expr { - Expression::ThisExpression(this_expr) => { - if !self.is_inside_arrow_function() { - return; - } - - let ident = - self.get_this_name(ctx).create_spanned_read_reference(this_expr.span, ctx); - *expr = self.ctx.ast.expression_from_identifier_reference(ident); - } - Expression::ArrowFunctionExpression(_) => { - self.inside_arrow_function_stack.push(true); - } - Expression::FunctionExpression(_) => self.inside_arrow_function_stack.push(false), - _ => {} - } - } - - fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) { - match expr { - Expression::ArrowFunctionExpression(_) => { - let Expression::ArrowFunctionExpression(arrow_function_expr) = - ctx.ast.move_expression(expr) - else { - unreachable!() - }; - - *expr = self.transform_arrow_function_expression(arrow_function_expr.unbox(), ctx); - self.inside_arrow_function_stack.pop(); - } - Expression::FunctionExpression(_) => { - self.inside_arrow_function_stack.pop(); + if let Expression::ThisExpression(this_expr) = expr { + if !self.is_inside_arrow_function() { + return; } - _ => {} - } - } - fn enter_declaration(&mut self, decl: &mut Declaration<'a>, _ctx: &mut TraverseCtx<'a>) { - if let Declaration::FunctionDeclaration(_) = decl { - self.inside_arrow_function_stack.push(false); + let ident = self.get_this_name(ctx).create_spanned_read_reference(this_expr.span, ctx); + *expr = self.ctx.ast.expression_from_identifier_reference(ident); } } - fn exit_declaration(&mut self, decl: &mut Declaration<'a>, _ctx: &mut TraverseCtx<'a>) { - if let Declaration::FunctionDeclaration(_) = decl { - self.inside_arrow_function_stack.pop(); + fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) { + if let Expression::ArrowFunctionExpression(_) = expr { + let Expression::ArrowFunctionExpression(arrow_function_expr) = + ctx.ast.move_expression(expr) + else { + unreachable!() + }; + + *expr = self.transform_arrow_function_expression(arrow_function_expr.unbox(), ctx); } } @@ -214,7 +218,7 @@ impl<'a> Traverse<'a> for ArrowFunctions<'a> { } fn exit_class(&mut self, _class: &mut Class<'a>, _ctx: &mut TraverseCtx<'a>) { - self.inside_arrow_function_stack.pop(); + self.inside_arrow_function_stack.pop().unwrap(); } fn enter_variable_declarator( @@ -299,34 +303,33 @@ impl<'a> ArrowFunctions<'a> { fn insert_this_var_statement_at_the_top_of_statements( &mut self, statements: &mut Vec<'a, Statement<'a>>, + this_var: &BoundIdentifier<'a>, ) { - if let Some(id) = &self.this_var_stack.pop().unwrap() { - let binding_pattern = self.ctx.ast.binding_pattern( - self.ctx - .ast - .binding_pattern_kind_from_binding_identifier(id.create_binding_identifier()), - NONE, - false, - ); + let binding_pattern = self.ctx.ast.binding_pattern( + self.ctx + .ast + .binding_pattern_kind_from_binding_identifier(this_var.create_binding_identifier()), + NONE, + false, + ); - let variable_declarator = self.ctx.ast.variable_declarator( - SPAN, - VariableDeclarationKind::Var, - binding_pattern, - Some(self.ctx.ast.expression_this(SPAN)), - false, - ); + let variable_declarator = self.ctx.ast.variable_declarator( + SPAN, + VariableDeclarationKind::Var, + binding_pattern, + Some(self.ctx.ast.expression_this(SPAN)), + false, + ); - let stmt = self.ctx.ast.alloc_variable_declaration( - SPAN, - VariableDeclarationKind::Var, - self.ctx.ast.vec1(variable_declarator), - false, - ); + let stmt = self.ctx.ast.alloc_variable_declaration( + SPAN, + VariableDeclarationKind::Var, + self.ctx.ast.vec1(variable_declarator), + false, + ); - let stmt = Statement::VariableDeclaration(stmt); + let stmt = Statement::VariableDeclaration(stmt); - statements.insert(0, stmt); - } + statements.insert(0, stmt); } } diff --git a/crates/oxc_transformer/src/es2015/mod.rs b/crates/oxc_transformer/src/es2015/mod.rs index f6643c0d10bce..6500930fa4e41 100644 --- a/crates/oxc_transformer/src/es2015/mod.rs +++ b/crates/oxc_transformer/src/es2015/mod.rs @@ -51,27 +51,35 @@ impl<'a> Traverse<'a> for ES2015<'a> { } } - fn enter_declaration(&mut self, decl: &mut Declaration<'a>, ctx: &mut TraverseCtx<'a>) { + fn enter_arrow_function_expression( + &mut self, + arrow: &mut ArrowFunctionExpression<'a>, + ctx: &mut TraverseCtx<'a>, + ) { if self.options.arrow_function.is_some() { - self.arrow_functions.enter_declaration(decl, ctx); + self.arrow_functions.enter_arrow_function_expression(arrow, ctx); } } - fn enter_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) { + fn exit_arrow_function_expression( + &mut self, + arrow: &mut ArrowFunctionExpression<'a>, + ctx: &mut TraverseCtx<'a>, + ) { if self.options.arrow_function.is_some() { - self.arrow_functions.enter_expression(expr, ctx); + self.arrow_functions.exit_arrow_function_expression(arrow, ctx); } } - fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) { + fn enter_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) { if self.options.arrow_function.is_some() { - self.arrow_functions.exit_expression(expr, ctx); + self.arrow_functions.enter_expression(expr, ctx); } } - fn exit_declaration(&mut self, decl: &mut Declaration<'a>, ctx: &mut TraverseCtx<'a>) { + fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) { if self.options.arrow_function.is_some() { - self.arrow_functions.exit_declaration(decl, ctx); + self.arrow_functions.exit_expression(expr, ctx); } } diff --git a/crates/oxc_transformer/src/lib.rs b/crates/oxc_transformer/src/lib.rs index 2d5b6eabcd22c..39f2ca91408fa 100644 --- a/crates/oxc_transformer/src/lib.rs +++ b/crates/oxc_transformer/src/lib.rs @@ -138,10 +138,11 @@ impl<'a> Traverse<'a> for Transformer<'a> { fn enter_arrow_function_expression( &mut self, - expr: &mut ArrowFunctionExpression<'a>, + arrow: &mut ArrowFunctionExpression<'a>, ctx: &mut TraverseCtx<'a>, ) { - self.x0_typescript.enter_arrow_function_expression(expr, ctx); + self.x0_typescript.enter_arrow_function_expression(arrow, ctx); + self.x3_es2015.enter_arrow_function_expression(arrow, ctx); } fn enter_binding_pattern(&mut self, pat: &mut BindingPattern<'a>, ctx: &mut TraverseCtx<'a>) { @@ -302,6 +303,8 @@ impl<'a> Traverse<'a> for Transformer<'a> { arrow: &mut ArrowFunctionExpression<'a>, ctx: &mut TraverseCtx<'a>, ) { + self.x3_es2015.exit_arrow_function_expression(arrow, ctx); + // Some plugins may add new statements to the ArrowFunctionExpression's body, // which can cause issues with the `() => x;` case, as it only allows a single statement. // To address this, we wrap the last statement in a return statement and set the expression to false. @@ -343,11 +346,6 @@ impl<'a> Traverse<'a> for Transformer<'a> { fn enter_declaration(&mut self, decl: &mut Declaration<'a>, ctx: &mut TraverseCtx<'a>) { self.x0_typescript.enter_declaration(decl, ctx); - self.x3_es2015.enter_declaration(decl, ctx); - } - - fn exit_declaration(&mut self, decl: &mut Declaration<'a>, ctx: &mut TraverseCtx<'a>) { - self.x3_es2015.exit_declaration(decl, ctx); } fn enter_if_statement(&mut self, stmt: &mut IfStatement<'a>, ctx: &mut TraverseCtx<'a>) {