Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release-19.0] Handle Nullability for Columns from Outer Tables (#16174) #16185

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,17 @@ func TestAlterTableWithView(t *testing.T) {

mcmp.AssertMatches("select * from v1", `[[INT64(1) INT64(1)]]`)
}

// TestHandleNullableColumn checks that a column coming from the right-hand
// side of a LEFT JOIN is handled as nullable when it is used in a predicate,
// by comparing the results of the same query between MySQL and Vitess.
func TestHandleNullableColumn(t *testing.T) {
	// This change is the release-19.0 backport of the fix, so the gate has to be
	// the major version of this branch. Gating on 21 would skip the test for
	// every vtgate binary built from release-19.0 and leave the fix without
	// end-to-end coverage.
	utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
	require.NoError(t,
		utils.WaitForAuthoritative(t, keyspaceName, "tbl", clusterInstance.VtgateProcess.ReadVSchema))
	mcmp, closer := start(t)
	defer closer()

	mcmp.Exec("insert into t1(id1, id2) values (0,0), (1,1), (2,2)")
	mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (0,0,0), (1,1,6)")
	// The t1 row (2,2) has no matching row in tbl, so the left join produces
	// NULL for every tbl column on that row. Whatever the type information for
	// tbl.nonunq_col says about nullability, the predicate below must be
	// evaluated with tbl.nonunq_col treated as nullable.
	mcmp.ExecWithColumnCompare(`select * from t1 left join tbl on t1.id2 = tbl.id where t1.id1 = 6 or tbl.nonunq_col = 6`)
}
20 changes: 15 additions & 5 deletions go/test/endtoend/vtgate/queries/misc/schema.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
-- t1: two bigint columns, keyed by id1. Created only if it does not already exist.
create table if not exists t1(
id1 bigint,
id2 bigint,
primary key(id1)
) Engine=InnoDB;
-- t1: two bigint columns, keyed by id1.
create table t1
(
id1 bigint,
id2 bigint,
primary key (id1)
) Engine=InnoDB;

-- tbl: keyed by id, with a unique index on unq_col and no index on nonunq_col.
-- None of the columns carries an explicit NOT NULL constraint.
create table tbl
(
id bigint,
unq_col bigint,
nonunq_col bigint,
primary key (id),
unique (unq_col)
) Engine = InnoDB;
4 changes: 4 additions & 0 deletions go/vt/vtgate/evalengine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ func (t *Type) Nullable() bool {
return true // nullable by default for unknown types
}

// SetNullability overwrites the nullable flag of the type with n.
func (t *Type) SetNullability(n bool) {
t.nullable = n
}

// Valid returns the type's init flag.
// NOTE(review): presumably init is set when the type is properly constructed — confirm against the constructors.
func (t *Type) Valid() bool {
return t.init
}
Expand Down
12 changes: 6 additions & 6 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega
oa.aggregates = append(oa.aggregates, aggrParam)
}
for _, groupBy := range op.Grouping {
typ, _ := ctx.SemTable.TypeForExpr(groupBy.Inner)
typ, _ := ctx.TypeForExpr(groupBy.Inner)
oa.groupByKeys = append(oa.groupByKeys, &engine.GroupByParams{
KeyCol: groupBy.ColOffset,
WeightStringCol: groupBy.WSOffset,
Expand Down Expand Up @@ -332,7 +332,7 @@ func createMemorySort(ctx *plancontext.PlanningContext, src logicalPlan, orderin
}

for idx, order := range ordering.Order {
typ, _ := ctx.SemTable.TypeForExpr(order.SimplifiedExpr)
typ, _ := ctx.TypeForExpr(order.SimplifiedExpr)
ms.eMemorySort.OrderBy = append(ms.eMemorySort.OrderBy, evalengine.OrderByParams{
Col: ordering.Offset[idx],
WeightStringCol: ordering.WOffset[idx],
Expand Down Expand Up @@ -389,7 +389,7 @@ func getEvalEngingeExpr(ctx *plancontext.PlanningContext, pe *operators.ProjExpr
case *operators.EvalEngine:
return e.EExpr, nil
case operators.Offset:
typ, _ := ctx.SemTable.TypeForExpr(pe.EvalExpr)
typ, _ := ctx.TypeForExpr(pe.EvalExpr)
return evalengine.NewColumn(int(e), typ, pe.EvalExpr), nil
default:
return nil, vterrors.VT13001("project not planned for: %s", pe.String())
Expand Down Expand Up @@ -560,7 +560,7 @@ func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route

eroute, err := routeToEngineRoute(ctx, op, hints)
for _, order := range op.Ordering {
typ, _ := ctx.SemTable.TypeForExpr(order.AST)
typ, _ := ctx.TypeForExpr(order.AST)
eroute.OrderBy = append(eroute.OrderBy, evalengine.OrderByParams{
Col: order.Offset,
WeightStringCol: order.WOffset,
Expand Down Expand Up @@ -877,11 +877,11 @@ func transformHashJoin(ctx *plancontext.PlanningContext, op *operators.HashJoin)

var missingTypes []string

ltyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].LHS)
ltyp, found := ctx.TypeForExpr(op.JoinComparisons[0].LHS)
if !found {
missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].LHS))
}
rtyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].RHS)
rtyp, found := ctx.TypeForExpr(op.JoinComparisons[0].RHS)
if !found {
missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].RHS))
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func getOperatorFromJoinTableExpr(ctx *plancontext.PlanningContext, tableExpr *s
case sqlparser.NormalJoinType:
return createInnerJoin(ctx, tableExpr, lhs, rhs)
case sqlparser.LeftJoinType, sqlparser.RightJoinType:
return createOuterJoin(tableExpr, lhs, rhs)
return createOuterJoin(ctx, tableExpr, lhs, rhs)
default:
panic(vterrors.VT13001("unsupported: %s", tableExpr.Join.ToString()))
}
Expand Down
1 change: 0 additions & 1 deletion go/vt/vtgate/planbuilder/operators/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) Operator {
offset := d.Source.AddColumn(ctx, true, false, aeWrap(weightStringFor(e)))
wsCol = &offset
}

d.Columns = append(d.Columns, engine.CheckCol{
Col: idx,
WsCol: wsCol,
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (f *Filter) Compact(*plancontext.PlanningContext) (Operator, *ApplyResult)

func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) Operator {
cfg := &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp

rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr)
cfg := &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
}
Expand Down Expand Up @@ -432,7 +432,7 @@ func (hj *HashJoin) addSingleSidedColumn(

rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr)
cfg := &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ func insertRowsPlan(ctx *plancontext.PlanningContext, insOp *Insert, ins *sqlpar
colNum, _ := findOrAddColumn(ins, col)
for rowNum, row := range rows {
innerpv, err := evalengine.Translate(row[colNum], &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
})
Expand Down Expand Up @@ -637,7 +637,7 @@ func modifyForAutoinc(ctx *plancontext.PlanningContext, ins *sqlparser.Insert, v
}
var err error
gen.Values, err = evalengine.Translate(autoIncValues, &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
})
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/operators/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (j *Join) Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult
return newOp, Rewrote("merge querygraphs into a single one")
}

func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
func createOuterJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Operator {
if tableExpr.Join == sqlparser.RightJoinType {
lhs, rhs = rhs, lhs
}
Expand All @@ -93,6 +93,8 @@ func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Oper
}
predicate := tableExpr.Condition.On
sqlparser.RemoveKeyspaceInCol(predicate)
// mark the RHS as outer tables so we know which columns are nullable
ctx.OuterTables = ctx.OuterTables.Merge(TableID(rhs))
return &Join{LHS: lhs, RHS: rhs, LeftJoin: true, Predicate: predicate}
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) Operator {

// for everything else, we'll turn to the evalengine
eexpr, err := evalengine.Translate(rewritten, &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
})
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.T
}
switch aggr.OpCode {
case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct:
typ, _ := ctx.SemTable.TypeForExpr(aggr.Func.GetArg())
typ, _ := ctx.TypeForExpr(aggr.Func.GetArg())
return typ

}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/sharded_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func (tr *ShardedRouting) planCompositeInOpArg(
Key: right.String(),
Index: idx,
}
if typ, found := ctx.SemTable.TypeForExpr(col); found {
if typ, found := ctx.TypeForExpr(col); found {
value.Type = typ.Type()
value.Collation = typ.Collation()
}
Expand Down Expand Up @@ -654,7 +654,7 @@ func makeEvalEngineExpr(ctx *plancontext.PlanningContext, n sqlparser.Expr) eval
for _, expr := range ctx.SemTable.GetExprAndEqualities(n) {
ee, _ := evalengine.Translate(expr, &evalengine.Config{
Collation: ctx.SemTable.Collation,
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Environment: ctx.VSchema.Environment(),
})
if ee != nil {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/union_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ func createMergedUnion(
continue
}
deps = deps.Merge(ctx.SemTable.RecursiveDeps(rae.Expr))
rt, foundR := ctx.SemTable.TypeForExpr(rae.Expr)
lt, foundL := ctx.SemTable.TypeForExpr(lae.Expr)
rt, foundR := ctx.TypeForExpr(rae.Expr)
lt, foundL := ctx.TypeForExpr(lae.Expr)
if foundR && foundL {
types := []sqltypes.Type{rt.Type(), lt.Type()}
t := evalengine.AggregateTypes(types)
Expand Down
21 changes: 21 additions & 0 deletions go/vt/vtgate/planbuilder/plancontext/planning_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/evalengine"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

Expand Down Expand Up @@ -57,6 +58,10 @@ type PlanningContext struct {

// Statement contains the originally parsed statement
Statement sqlparser.Statement

// OuterTables is the set of tables on the nullable (outer) side of the outer joins in the current query.
// Columns that depend on any of these tables are reported as nullable by TypeForExpr.
OuterTables semantics.TableSet
}

// CreatePlanningContext initializes a new PlanningContext with the given parameters.
Expand Down Expand Up @@ -201,3 +206,19 @@ func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, t
}
return modifiedExpr
}

// TypeForExpr looks up the type of e in the semantic table and returns it,
// forcing the nullable flag to true when e depends on any table recorded in
// ctx.OuterTables. The boolean result is false when the semantic table has no
// type for e; in that case the type is returned untouched.
func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) {
	typ, ok := ctx.SemTable.TypeForExpr(e)
	// Anything that reads from an outer table may be NULL-extended by the join.
	// A few such expressions can never actually evaluate to NULL, but treating
	// all of them as nullable is the conservative choice.
	if ok && ctx.SemTable.RecursiveDeps(e).IsOverlapping(ctx.OuterTables) {
		typ.SetNullability(true)
	}
	return typ, ok
}
108 changes: 108 additions & 0 deletions go/vt/vtgate/planbuilder/plancontext/planning_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
Copyright 2024 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package plancontext

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vtgate/evalengine"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

// TestOuterTableNullability asserts that PlanningContext.TypeForExpr reports
// columns of an outer table as nullable even though the semantic table records
// them as NOT NULL, since the outer side of a join may have no matching row.
// Columns of the non-outer table must stay non-nullable in both views.
func TestOuterTableNullability(t *testing.T) {
	const query = "select * from t1 left join t2 on t1.a = t2.a where t1.a+t2.a/abs(t2.boing)"
	ctx, columns := prepareContextAndFindColumns(t, query)

	for _, col := range columns {
		name := "column: " + sqlparser.String(col)
		t.Run(name, func(t *testing.T) {
			// Resolve the same column through the planning context and
			// directly through the semantic table.
			fromCtx, ok := ctx.TypeForExpr(col)
			require.True(t, ok, name)
			fromSem, ok := ctx.SemTable.TypeForExpr(col)
			require.True(t, ok, name)

			nullableInCtx := fromCtx.Nullable()
			nullableInSem := fromSem.Nullable()

			switch col.Qualifier.Name.String() {
			case "t1":
				// Not an outer table: both sources agree on NOT NULL.
				assert.False(t, nullableInCtx, name)
				assert.False(t, nullableInSem, name)
			case "t2":
				// Outer table: only the planning context knows the column can be NULL.
				assert.True(t, nullableInCtx, name)
				// The semantic table still claims NOT NULL and must not be trusted here.
				assert.False(t, nullableInSem, name)
			}
		})
	}
}

// prepareContextAndFindColumns parses query, which must be a SELECT with a
// WHERE clause, and builds a PlanningContext around a hand-populated semantic
// table instead of running semantic analysis.
//
// Every column found in the WHERE expression is typed as a non-nullable INT64.
// Columns qualified with t1 or t2 get table set 0 or 1 respectively as their
// recursive dependency, and t2 is registered as the outer table of the
// context. The returned slice holds the columns in the order they were visited.
func prepareContextAndFindColumns(t *testing.T, query string) (ctx *PlanningContext, columns []*sqlparser.ColName) {
	parsed, err := sqlparser.NewTestParser().Parse(query)
	require.NoError(t, err)

	semTable := semantics.EmptySemTable()
	tableIDs := map[string]semantics.TableSet{
		"t1": semantics.SingleTableSet(0),
		"t2": semantics.SingleTableSet(1),
	}
	stmt := parsed.(*sqlparser.Select)
	where := stmt.Where.Expr

	// The same NOT NULL integer type is recorded for every column.
	notNullInt := evalengine.NewType(sqltypes.Int64, collations.Unknown)
	notNullInt.SetNullability(false)

	visit := func(node sqlparser.SQLNode) (bool, error) {
		col, isCol := node.(*sqlparser.ColName)
		if !isCol {
			return true, nil
		}

		if id, known := tableIDs[col.Qualifier.Name.String()]; known {
			semTable.Recursive[col] = id
		}
		semTable.ExprTypes[col] = notNullInt
		columns = append(columns, col)
		// Nothing of interest lives below a column name.
		return false, nil
	}
	// The visitor never returns an error, so the result of Walk is ignored.
	_ = sqlparser.Walk(visit, nil, where)

	ctx = &PlanningContext{
		SemTable:          semTable,
		joinPredicates:    map[sqlparser.Expr][]sqlparser.Expr{},
		skipPredicates:    map[sqlparser.Expr]any{},
		ReservedArguments: map[sqlparser.Expr]string{},
		Statement:         stmt,
		OuterTables:       tableIDs["t2"], // t2 is the outer table.
	}
	return ctx, columns
}
1 change: 1 addition & 0 deletions go/vt/vtgate/semantics/semantic_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.Sel
}

// TypeForExpr returns the type of expressions in the query
// Note that PlanningContext has the same method, and you should use that if you have a PlanningContext
func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) {
if typ, found := st.ExprTypes[e]; found {
return typ, true
Expand Down
Loading