diff --git a/go/test/endtoend/vtgate/queries/derived/derived_test.go b/go/test/endtoend/vtgate/queries/derived/derived_test.go index 80ae36633e1..c41161d9bcf 100644 --- a/go/test/endtoend/vtgate/queries/derived/derived_test.go +++ b/go/test/endtoend/vtgate/queries/derived/derived_test.go @@ -113,3 +113,15 @@ func TestDerivedTablesWithLimit(t *testing.T) { (SELECT id, user_id FROM music LIMIT 10) as m on u.id = m.user_id`, `[[INT64(1) INT64(1)] [INT64(5) INT64(2)] [INT64(1) INT64(3)] [INT64(2) INT64(4)] [INT64(3) INT64(5)] [INT64(5) INT64(7)] [INT64(4) INT64(6)] [INT64(6) NULL]]`) } + +// TestDerivedTableColumnAliasWithJoin tests the derived table having alias column and using it in the join condition +func TestDerivedTableColumnAliasWithJoin(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + mcmp, closer := start(t) + defer closer() + + mcmp.Exec(`SELECT user.id FROM user join (SELECT id as uid FROM user) t on t.uid = user.id`) + mcmp.Exec(`SELECT user.id FROM user left join (SELECT id as uid FROM user) t on t.uid = user.id`) + mcmp.Exec(`SELECT user.id FROM user join (SELECT id FROM user) t(uid) on t.uid = user.id`) + mcmp.Exec(`SELECT user.id FROM user left join (SELECT id FROM user) t(uid) on t.uid = user.id`) +} diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index 30cce2617f8..532441d6a34 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -99,7 +99,7 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser. panic(err) } - newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) + newExpr := ctx.RewriteDerivedTableExpression(expr, tableInfo) if ContainsAggr(ctx, newExpr) { return newFilter(h, expr) } diff --git a/go/vt/vtgate/planbuilder/operators/rewriters.go b/go/vt/vtgate/planbuilder/operators/rewriters.go index 6a329860b4b..7ec8379dfab 100644 --- a/go/vt/vtgate/planbuilder/operators/rewriters.go +++ b/go/vt/vtgate/planbuilder/operators/rewriters.go @@ -218,6 +218,9 @@ func bottomUp( childID = childID.Merge(resolveID(oldInputs[0])) } in, changed := bottomUp(operator, childID, resolveID, rewriter, shouldVisit, false) + if DebugOperatorTree && changed.Changed() { + fmt.Println(ToTree(in)) + } anythingChanged = anythingChanged.Merge(changed) newInputs[i] = in } diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 49039ddd347..3c2a1c97434 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -188,3 +188,16 @@ func (ctx *PlanningContext) execOnJoinPredicateEqual(joinPred sqlparser.Expr, fn } return false } + +func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, tableInfo semantics.TableInfo) sqlparser.Expr { + modifiedExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) + for key, exprs := range ctx.joinPredicates { + for _, rhsExpr := range exprs { + if ctx.SemTable.EqualsExpr(expr, rhsExpr) { + ctx.joinPredicates[key] = append(ctx.joinPredicates[key], modifiedExpr) + return modifiedExpr + } + } + } + return modifiedExpr +} diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index d19e07be662..473c231a750 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -5221,5 +5221,27 @@ "main.unsharded_a" ] } + }, + { + "comment": "join with derived table with alias and join condition - merge into route", + "query": "select 1 from user join (select id as uid from user) as t where t.uid = user.id", + "plan": { + "QueryType": "SELECT", + "Original": "select 1 from user join (select id as uid from user) as t where t.uid = user.id", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from (select id as uid from `user` where 1 != 1) as t, `user` where 1 != 1", + "Query": "select 1 from (select id as uid from `user`) as t, `user` where t.uid = `user`.id", + "Table": "`user`" + }, + "TablesUsed": [ + "user.user" + ] + } } ]