Skip to content

Commit

Permalink
Pulling the GenerateSqlPatchSchemaStatements into the sqlfmt package …
Browse files Browse the repository at this point in the history
…(and cleaning up pkg import cycles)
  • Loading branch information
fulghum committed May 17, 2024
1 parent 02b3834 commit 9b0176b
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 255 deletions.
13 changes: 10 additions & 3 deletions go/libraries/doltcore/diff/async_differ_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

dtu "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/store/chunks"
"github.com/dolthub/dolt/go/store/constants"
Expand Down Expand Up @@ -332,7 +331,15 @@ func getKeylessRow(ctx context.Context, vals []types.Value) ([]types.Value, erro
vals = append(prefix, vals...)

return []types.Value{
dtu.MustTuple(rowIdTag, id1),
dtu.MustTuple(vals...),
mustTuple(rowIdTag, id1),
mustTuple(vals...),
}, nil
}

func mustTuple(vals ...types.Value) types.Tuple {
tup, err := types.NewTuple(types.Format_Default, vals...)
if err != nil {
panic(err)
}
return tup
}
31 changes: 0 additions & 31 deletions go/libraries/doltcore/diff/table_deltas.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt"
"github.com/dolthub/dolt/go/libraries/utils/set"
"github.com/dolthub/dolt/go/store/prolly/tree"
"github.com/dolthub/dolt/go/store/types"
Expand Down Expand Up @@ -620,36 +619,6 @@ func (td TableDelta) GetRowData(ctx context.Context) (from, to durable.Index, er
return from, to, nil
}

// GetDataDiffStatement returns any data diff in SQL statements for given table including INSERT, UPDATE and DELETE row statements.
func GetDataDiffStatement(tableName string, sch schema.Schema, row sql.Row, rowDiffType ChangeType, colDiffTypes []ChangeType) (string, error) {
if len(row) != len(colDiffTypes) {
return "", fmt.Errorf("expected the same size for columns and diff types, got %d and %d", len(row), len(colDiffTypes))
}

switch rowDiffType {
case Added:
return sqlfmt.SqlRowAsInsertStmt(row, tableName, sch)
case Removed:
return sqlfmt.SqlRowAsDeleteStmt(row, tableName, sch, 0)
case ModifiedNew:
updatedCols := set.NewEmptyStrSet()
for i, diffType := range colDiffTypes {
if diffType != None {
updatedCols.Add(sch.GetAllCols().GetByIndex(i).Name)
}
}
if updatedCols.Size() == 0 {
return "", nil
}
return sqlfmt.SqlRowAsUpdateStmt(row, tableName, sch, updatedCols)
case ModifiedOld:
// do nothing, we only issue UPDATE for ModifiedNew
return "", nil
default:
return "", fmt.Errorf("unexpected row diff type: %v", rowDiffType)
}
}

// WorkingSetContainsOnlyIgnoredTables returns true if all changes in working set are ignored tables.
// Otherwise, if there are any non-ignored changes, returns false.
// Note that only unstaged tables are subject to dolt_ignore (this is consistent with what git does.)
Expand Down
5 changes: 3 additions & 2 deletions go/libraries/doltcore/merge/merge_prolly_rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/expranalysis"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/pool"
Expand Down Expand Up @@ -374,7 +375,7 @@ func newCheckValidator(ctx *sql.Context, tm *TableMerger, vm *valueMerger, sch s
continue
}

expr, err := index.ResolveCheckExpression(ctx, tm.name, sch, check.Expression())
expr, err := expranalysis.ResolveCheckExpression(ctx, tm.name, sch, check.Expression())
if err != nil {
return checkValidator{}, err
}
Expand Down Expand Up @@ -1192,7 +1193,7 @@ func resolveDefaults(ctx *sql.Context, tableName string, mergedSchema schema.Sch
}

if col.Default != "" || col.Generated != "" || col.OnUpdate != "" {
expr, err := index.ResolveDefaultExpression(ctx, tableName, mergedSchema, col)
expr, err := expranalysis.ResolveDefaultExpression(ctx, tableName, mergedSchema, col)
if err != nil {
return true, err
}
Expand Down
133 changes: 3 additions & 130 deletions go/libraries/doltcore/sqle/dolt_patch_table_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ import (
"github.com/dolthub/vitess/go/mysql"
"golang.org/x/exp/slices"

"github.com/dolthub/dolt/go/cmd/dolt/errhand"
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
Expand Down Expand Up @@ -514,7 +512,7 @@ func getPatchNodes(ctx *sql.Context, dbData env.DbData, tableDeltas []diff.Table
// Get SCHEMA DIFF
var schemaStmts []string
if includeSchemaDiff {
schemaStmts, err = GenerateSqlPatchSchemaStatements(ctx, toRefDetails.root, td)
schemaStmts, err = sqlfmt.GenerateSqlPatchSchemaStatements(ctx, toRefDetails.root, td)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -595,14 +593,14 @@ func getDataSqlPatchResults(ctx *sql.Context, diffQuerySch, targetSch sql.Schema

var stmt string
if oldRow.Row != nil {
stmt, err = diff.GetDataDiffStatement(tn, tsch, oldRow.Row, oldRow.RowDiff, oldRow.ColDiffs)
stmt, err = sqlfmt.GenerateDataDiffStatement(tn, tsch, oldRow.Row, oldRow.RowDiff, oldRow.ColDiffs)
if err != nil {
return nil, err
}
}

if newRow.Row != nil {
stmt, err = diff.GetDataDiffStatement(tn, tsch, newRow.Row, newRow.RowDiff, newRow.ColDiffs)
stmt, err = sqlfmt.GenerateDataDiffStatement(tn, tsch, newRow.Row, newRow.RowDiff, newRow.ColDiffs)
if err != nil {
return nil, err
}
Expand All @@ -614,131 +612,6 @@ func getDataSqlPatchResults(ctx *sql.Context, diffQuerySch, targetSch sql.Schema
}
}

// GenerateSqlPatchSchemaStatements examines the table schema changes in the specified TableDelta |td| and returns
// a slice of SQL path statements that represent the equivalent SQL DDL statements for those schema changes. The
// specified RootValue, |toRoot|, must be the RootValue that was used as the "To" root when computing the specified
// TableDelta.
func GenerateSqlPatchSchemaStatements(ctx *sql.Context, toRoot doltdb.RootValue, td diff.TableDelta) ([]string, error) {
toSchemas, err := doltdb.GetAllSchemas(ctx, toRoot)
if err != nil {
return nil, fmt.Errorf("could not read schemas from toRoot, cause: %s", err.Error())
}

fromSch, toSch, err := td.GetSchemas(ctx)
if err != nil {
return nil, fmt.Errorf("cannot retrieve schema for table %s, cause: %s", td.ToName, err.Error())
}

var ddlStatements []string
if td.IsDrop() {
ddlStatements = append(ddlStatements, sqlfmt.DropTableStmt(td.FromName))
} else if td.IsAdd() {
stmt, err := sqlfmt.GenerateCreateTableStatement(td.ToName, td.ToSch, td.ToFks, td.ToFksParentSch)
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
}
ddlStatements = append(ddlStatements, stmt)
} else {
stmts, err := GetNonCreateNonDropTableSqlSchemaDiff(td, toSchemas, fromSch, toSch)
if err != nil {
return nil, err
}
ddlStatements = append(ddlStatements, stmts...)
}

return ddlStatements, nil
}

// GetNonCreateNonDropTableSqlSchemaDiff returns any schema diff in SQL statements that is NEITHER 'CREATE TABLE' NOR 'DROP TABLE' statements.
func GetNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas map[string]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) {
if td.IsAdd() || td.IsDrop() {
// use add and drop specific methods
return nil, nil
}

var ddlStatements []string
if td.FromName != td.ToName {
ddlStatements = append(ddlStatements, sqlfmt.RenameTableStmt(td.FromName, td.ToName))
}

eq := schema.SchemasAreEqual(fromSch, toSch)
if eq && !td.HasFKChanges() {
return ddlStatements, nil
}

colDiffs, unionTags := diff.DiffSchColumns(fromSch, toSch)
for _, tag := range unionTags {
cd := colDiffs[tag]
switch cd.DiffType {
case diff.SchDiffNone:
case diff.SchDiffAdded:
ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddColStmt(td.ToName, sqlfmt.GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation()))))
case diff.SchDiffRemoved:
ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropColStmt(td.ToName, cd.Old.Name))
case diff.SchDiffModified:
// Ignore any primary key set changes here
if cd.Old.IsPartOfPK != cd.New.IsPartOfPK {
continue
}
if cd.Old.Name != cd.New.Name {
ddlStatements = append(ddlStatements, sqlfmt.AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name))
}
if !cd.Old.TypeInfo.Equals(cd.New.TypeInfo) {
ddlStatements = append(ddlStatements, sqlfmt.AlterTableModifyColStmt(td.ToName,
sqlfmt.GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation()))))
}
}
}

// Print changes between a primary key set change. It contains an ALTER TABLE DROP and an ALTER TABLE ADD
if !schema.ColCollsAreEqual(fromSch.GetPKCols(), toSch.GetPKCols()) {
ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropPks(td.ToName))
if toSch.GetPKCols().Size() > 0 {
ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols().GetColumnNames()))
}
}

for _, idxDiff := range diff.DiffSchIndexes(fromSch, toSch) {
switch idxDiff.DiffType {
case diff.SchDiffNone:
case diff.SchDiffAdded:
ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To))
case diff.SchDiffRemoved:
ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From))
case diff.SchDiffModified:
ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From))
ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To))
}
}

for _, fkDiff := range diff.DiffForeignKeys(td.FromFks, td.ToFks) {
switch fkDiff.DiffType {
case diff.SchDiffNone:
case diff.SchDiffAdded:
parentSch := toSchemas[fkDiff.To.ReferencedTableName]
ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch))
case diff.SchDiffRemoved:
from := fkDiff.From
ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name))
case diff.SchDiffModified:
from := fkDiff.From
ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name))

parentSch := toSchemas[fkDiff.To.ReferencedTableName]
ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch))
}
}

// Handle charset/collation changes
toCollation := toSch.GetCollation()
fromCollation := fromSch.GetCollation()
if toCollation != fromCollation {
ddlStatements = append(ddlStatements, sqlfmt.AlterTableCollateStmt(td.ToName, fromCollation, toCollation))
}

return ddlStatements, nil
}

// getDiffQuery returns diff schema for specified columns and array of sql.Expression as projection to be used
// on diff table function row iter. This function attempts to imitate running a query
// fmt.Sprintf("select %s, %s from dolt_diff('%s', '%s', '%s')", columnsWithDiff, "diff_type", fromRef, toRef, tableName)
Expand Down
109 changes: 109 additions & 0 deletions go/libraries/doltcore/sqle/expranalysis/expranalysis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expranalysis

import (
"fmt"

"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/planbuilder"
"github.com/dolthub/go-mysql-server/sql/transform"

"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt"
)

// ResolveDefaultExpression returns a sql.Expression for the column default or generated expression for the
// column provided
func ResolveDefaultExpression(ctx *sql.Context, tableName string, sch schema.Schema, col schema.Column) (sql.Expression, error) {
ct, err := parseCreateTable(ctx, tableName, sch)
if err != nil {
return nil, err
}

colIdx := ct.PkSchema().Schema.IndexOfColName(col.Name)
if colIdx < 0 {
return nil, fmt.Errorf("unable to find column %s in analyzed query", col.Name)
}

sqlCol := ct.PkSchema().Schema[colIdx]
expr := sqlCol.Default
if expr == nil || expr.Expr == nil {
expr = sqlCol.Generated
}

if expr == nil || expr.Expr == nil {
return nil, fmt.Errorf("unable to find default or generated expression")
}

return expr.Expr, nil
}

// ResolveCheckExpression returns a sql.Expression for the check provided
func ResolveCheckExpression(ctx *sql.Context, tableName string, sch schema.Schema, checkExpr string) (sql.Expression, error) {
ct, err := parseCreateTable(ctx, tableName, sch)
if err != nil {
return nil, err
}

for _, check := range ct.Checks() {
if stripTableNamesFromExpression(check.Expr).String() == checkExpr {
return check.Expr, nil
}
}

return nil, fmt.Errorf("unable to find check expression")
}

func stripTableNamesFromExpression(expr sql.Expression) sql.Expression {
e, _, _ := transform.Expr(expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
if col, ok := e.(*expression.GetField); ok {
return col.WithTable(""), transform.NewTree, nil
}
return e, transform.SameTree, nil
})
return e
}

func parseCreateTable(ctx *sql.Context, tableName string, sch schema.Schema) (*plan.CreateTable, error) {
createTable, err := sqlfmt.GenerateCreateTableStatement(tableName, sch, nil, nil)
if err != nil {
return nil, err
}

query := createTable

mockDatabase := memory.NewDatabase("mydb")
mockProvider := memory.NewDBProvider(mockDatabase)
catalog := analyzer.NewCatalog(mockProvider)
parseCtx := sql.NewEmptyContext()
parseCtx.SetCurrentDatabase("mydb")

b := planbuilder.New(parseCtx, catalog, sql.NewMysqlParser())
pseudoAnalyzedQuery, _, _, err := b.Parse(query, false)
if err != nil {
return nil, err
}

ct, ok := pseudoAnalyzedQuery.(*plan.CreateTable)
if !ok {
return nil, fmt.Errorf("expected a *plan.CreateTable node, but got %T", pseudoAnalyzedQuery)
}
return ct, nil
}
Loading

0 comments on commit 9b0176b

Please sign in to comment.