From 452aa257712487f4f24079d93b4b5b5db6dd5878 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=82=E5=B7=9D=E6=81=AD=E4=BD=91?= Date: Wed, 9 Oct 2024 13:52:30 +0900 Subject: [PATCH] Structure the return type --- README.md | 24 ++++++------- breaql.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++ cmd/breaql/main.go | 15 ++++---- errors.go | 20 +++++++++++ go.mod | 3 +- go.sum | 4 +++ mysql.go | 27 ++++++++++----- 7 files changed, 151 insertions(+), 27 deletions(-) create mode 100644 breaql.go create mode 100644 errors.go diff --git a/README.md b/README.md index 2c99f03..4d71e7d 100644 --- a/README.md +++ b/README.md @@ -24,20 +24,22 @@ You can pass the DDL statements via stdin or as a file. ```shell echo ' - CREATE TABLE products (id INT PRIMARY KEY, name VARCHAR(100)); - ALTER TABLE users DROP COLUMN age; - DROP TABLE users; - ' | breaql --driver mysql + CREATE TABLE products (id INT PRIMARY KEY, name VARCHAR(100)); + ALTER TABLE users DROP COLUMN age; + DROP TABLE users; + DROP DATABASE foo; + ' | go run breaql --driver mysql ``` And then you will see the output like this: ```sql -- Detected destructive changes: --- No.1 - ALTER TABLE users DROP COLUMN age; --- No.2 - DROP TABLE users; +-- Table: users + ALTER TABLE users DROP COLUMN age; + DROP TABLE users; +-- Database: foo + DROP DATABASE foo; ``` ### via Go application @@ -63,13 +65,11 @@ func main() { log.Fatal(err) } - if len(changes) == 0 { + if changes.Exist() { fmt.Println("No breaking changes detected") } else { fmt.Println("-- Detected destructive changes:") - for i, change := range changes { - fmt.Printf("-- No.%d\n %s\n", i+1, change) - } + fmt.Printf(changes.FormatSQL()) } } diff --git a/breaql.go b/breaql.go new file mode 100644 index 0000000..fa4fa9a --- /dev/null +++ b/breaql.go @@ -0,0 +1,85 @@ +package breaql + +import ( + "strings" + + "github.com/samber/lo" +) + +type BreakingChanges struct { + Tables TableChanges `json:"tables"` + Databases DatabaseChanges `json:"databases"` +} + +func NewBreakingChanges() BreakingChanges { + return BreakingChanges{ + Tables: make(TableChanges), + Databases: make(DatabaseChanges), + } +} + +// Exist return if any changes exist. +func (bc BreakingChanges) Exist() bool { + return bc.Tables.Exist() || bc.Databases.Exist() +} + +// FormatSQL returns the breaking changes in SQL format. +func (bc BreakingChanges) FormatSQL() string { + builder := strings.Builder{} + for _, table := range bc.Tables.Tables() { + builder.WriteString("-- Table: " + table + "\n") + for _, stmt := range bc.Tables.Statements(table) { + builder.WriteString(" " + stmt + "\n") + } + } + for _, database := range bc.Databases.Databases() { + builder.WriteString("-- Database: " + database + "\n") + for _, stmt := range bc.Databases.Statements(database) { + builder.WriteString(" " + stmt + "\n") + } + } + + return builder.String() +} + +type TableChanges map[string][]string + +func (tc TableChanges) add(table string, statements ...string) { + tc[table] = append(tc[table], statements...) +} + +// Tables returns the affected table names. +func (tc TableChanges) Tables() []string { + return lo.Keys(tc) +} + +// Statements returns the breaking statements for the given table. +func (tc TableChanges) Statements(table string) []string { + return tc[table] +} + +// Exist return if any changes exist. +func (tc TableChanges) Exist() bool { + return len(tc) > 0 +} + +type DatabaseChanges map[string][]string + +func (dc DatabaseChanges) add(database string, statements ...string) { + dc[database] = append(dc[database], statements...) +} + +// Databases returns the affected database names. +func (dc DatabaseChanges) Databases() []string { + return lo.Keys(dc) +} + +// Statements returns the breaking statements for the given database. +func (dc DatabaseChanges) Statements(database string) []string { + return dc[database] +} + +// Exist return if any changes exist. +func (dc DatabaseChanges) Exist() bool { + return len(dc) > 0 +} diff --git a/cmd/breaql/main.go b/cmd/breaql/main.go index 33a0f9d..71f68e7 100644 --- a/cmd/breaql/main.go +++ b/cmd/breaql/main.go @@ -55,7 +55,7 @@ func main_() error { } // Detect destructive changes - var changes []string + var changes breaql.BreakingChanges switch input.Driver { case "mysql": changes, err = breaql.RunMySQL(string(ddl)) @@ -65,11 +65,9 @@ func main_() error { default: return errors.Errorf("unsupported driver: %s", input.Driver) } - if len(changes) > 0 { + if changes.Exist() { fmt.Println("-- Detected destructive changes:") - for i, change := range changes { - fmt.Printf("-- No.%d\n %s\n", i+1, change) - } + fmt.Printf(changes.FormatSQL()) } else { fmt.Println("-- No destructive changes detected. --") } @@ -79,7 +77,12 @@ func main_() error { func main() { if err := main_(); err != nil { - slog.Error(fmt.Sprintf("error: %v", err)) + switch err := errors.Cause(err).(type) { + case *breaql.ParseError: + slog.Error("Parse Error!", slog.String("message", err.Message)) + default: + slog.Error(fmt.Sprintf("error: %v", err)) + } os.Exit(1) } } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..768a7e0 --- /dev/null +++ b/errors.go @@ -0,0 +1,20 @@ +package breaql + +import ( + "fmt" +) + +type ParseError struct { + Message string // simple and human-readable error message + + funcName string + original error +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("error %s: %s", e.funcName, e.Message) +} + +func (e *ParseError) Unwrap() error { + return e.original +} diff --git a/go.mod b/go.mod index 08a2ddf..dabdf31 100644 --- a/go.mod +++ b/go.mod @@ -10,10 +10,11 @@ require ( github.com/pingcap/log v1.1.0 // indirect github.com/pingcap/tidb/pkg/parser v0.0.0-20240820100743-1a0c3ac3292f // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/samber/lo v1.47.0 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/text v0.16.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index c35c28b..9887ff8 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc= +github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -53,6 +55,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/mysql.go b/mysql.go index 86920c0..cb537eb 100644 --- a/mysql.go +++ b/mysql.go @@ -4,9 +4,9 @@ import ( "log/slog" "strings" - "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/samber/lo" // Importing the following parser driver causes a build error. //_ "github.com/pingcap/tidb/pkg/types/parser_driver" @@ -15,33 +15,44 @@ import ( ) // RunMySQL parses the given (possibly composite) DDL statements and returns the breaking ones. -func RunMySQL(sql string) ([]string, error) { +func RunMySQL(sql string) (BreakingChanges, error) { p := parser.New() stmtNodes, _, err := p.Parse(sql, "", "") if err != nil { - return nil, errors.Wrap(err, "error p.Parse") + return BreakingChanges{}, &ParseError{original: err, Message: err.Error(), funcName: "parser.Parse"} } - var breakingStmt []string + changes := NewBreakingChanges() for _, stmtNode := range stmtNodes { stmtText := strings.TrimSpace(stmtNode.Text()) slog.Debug("processing stmt", slog.String("stmt", stmtText)) switch stmt := stmtNode.(type) { - case *ast.DropTableStmt, *ast.TruncateTableStmt, *ast.DropDatabaseStmt, *ast.RenameTableStmt: - breakingStmt = append(breakingStmt, stmtText) + case *ast.DropDatabaseStmt: + changes.Databases.add(stmt.Name.String(), stmtText) + + case *ast.DropTableStmt: + lo.ForEach(stmt.Tables, func(stmt *ast.TableName, _ int) { changes.Tables.add(stmt.Name.String(), stmtText) }) + + case *ast.TruncateTableStmt: + changes.Tables.add(stmt.Table.Name.String(), stmtText) + + case *ast.RenameTableStmt: + lo.ForEach(stmt.TableToTables, func(ttt *ast.TableToTable, _ int) { changes.Tables.add(ttt.OldTable.Name.String(), stmtText) }) + case *ast.AlterTableStmt: for _, spec := range stmt.Specs { if isBreakingAlterTableSpec(spec) { - breakingStmt = append(breakingStmt, stmtText) + changes.Tables.add(stmt.Table.Name.String(), stmtText) break } } } + } - return breakingStmt, nil + return changes, nil } func isBreakingAlterTableSpec(spec *ast.AlterTableSpec) bool {