Skip to content

Commit

Permalink
add arthurscreiber's review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
meiji163 committed Oct 24, 2024
1 parent 1bd2b0b commit c1b6000
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 28 deletions.
35 changes: 10 additions & 25 deletions go/logic/applier.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier {

func (this *Applier) InitDBConnections() (err error) {
applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName)
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil {
uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri)
if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, uriWithMulti); err != nil {
return err
}
singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri)
Expand Down Expand Up @@ -1210,7 +1211,7 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB
// ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table
func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error {
var totalDelta int64
ctx := context.TODO()
ctx := context.Background()

err := func() error {
conn, err := this.db.Conn(ctx)
Expand All @@ -1236,31 +1237,23 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))
}

buildResults := make([]*dmlBuildResult, 0, len(dmlEvents))
nArgs := 0
for _, dmlEvent := range dmlEvents {
for _, buildResult := range this.buildDMLEventQuery(dmlEvent) {
if buildResult.err != nil {
return rollback(buildResult.err)
}

nArgs += len(buildResult.args)
buildResults = append(buildResults, buildResult)
}
}

execErr := conn.Raw(func(driverConn any) error {
ex, ok := driverConn.(driver.ExecerContext)
if !ok {
return fmt.Errorf("could not cast driverConn to ExecerContext")
}

nvc, ok := driverConn.(driver.NamedValueChecker)
if !ok {
return fmt.Errorf("could not cast driverConn to NamedValueChecker")
}
ex := driverConn.(driver.ExecerContext)
nvc := driverConn.(driver.NamedValueChecker)

var multiArgs []driver.NamedValue
multiArgs := make([]driver.NamedValue, 0, nArgs)
multiQueryBuilder := strings.Builder{}
var rowDeltas []int64

for _, buildResult := range buildResults {
for _, arg := range buildResult.args {
nv := driver.NamedValue{Value: driver.Value(arg)}
Expand All @@ -1270,29 +1263,21 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))

multiQueryBuilder.WriteString(buildResult.query)
multiQueryBuilder.WriteString(";\n")

rowDeltas = append(rowDeltas, buildResult.rowsDelta)
}

// this.migrationContext.Log.Infof("Executing query: %s, args: %+v", multiQueryBuilder.String(), multiArgs)
res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs)
if err != nil {
err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs)
this.migrationContext.Log.Errorf("Error exec: %+v", err)
return err
}

mysqlRes, ok := res.(drivermysql.Result)
if !ok {
return fmt.Errorf("Could not cast %+v to mysql.Result", res)
}
mysqlRes := res.(drivermysql.Result)

// each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1).
// multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event
for i, rowsAffected := range mysqlRes.AllRowsAffected() {
totalDelta += rowDeltas[i] * rowsAffected
totalDelta += buildResults[i].rowsDelta * rowsAffected
}

return nil
})

Expand Down
1 change: 0 additions & 1 deletion go/mysql/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ func (this *ConnectionConfig) GetDBUri(databaseName string) string {
connectionParams := []string{
"autocommit=true",
"interpolateParams=true",
"multiStatements=true",
fmt.Sprintf("charset=%s", this.Charset),
fmt.Sprintf("tls=%s", tlsOption),
fmt.Sprintf("transaction_isolation=%q", this.TransactionIsolation),
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestGetDBUri(t *testing.T) {
c.Charset = "utf8mb4,utf8,latin1"

uri := c.GetDBUri("test")
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&multiStatements=true&charset=utf8mb4,utf8,latin1&tls=false&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4,utf8,latin1&tls=false&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
}

func TestGetDBUriWithTLSSetup(t *testing.T) {
Expand All @@ -100,5 +100,5 @@ func TestGetDBUriWithTLSSetup(t *testing.T) {
c.Charset = "utf8mb4_general_ci,utf8_general_ci,latin1"

uri := c.GetDBUri("test")
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&multiStatements=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=ghost&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
require.Equal(t, `gromit:penguin@tcp(myhost:3306)/test?autocommit=true&interpolateParams=true&charset=utf8mb4_general_ci,utf8_general_ci,latin1&tls=ghost&transaction_isolation="REPEATABLE-READ"&timeout=1.234500s&readTimeout=1.234500s&writeTimeout=1.234500s`, uri)
}

0 comments on commit c1b6000

Please sign in to comment.