diff --git a/db.go b/db.go index 20c5a8c..f056620 100644 --- a/db.go +++ b/db.go @@ -36,18 +36,18 @@ type Table struct { } func (t *Table) HasColumn(name string) bool { - _, err := t.GetColumn(name) + _, _, err := t.GetColumn(name) return err == nil } -func (t *Table) GetColumn(name string) (*Column, error) { - for _, column := range t.Columns { +func (t *Table) GetColumn(name string) (int, *Column, error) { + for i, column := range t.Columns { if column.Name == name { - return column, nil + return i, column, nil } } - return nil, fmt.Errorf("column '%s' not found", name) + return -1, nil, fmt.Errorf("column '%s' not found", name) } type Column struct { @@ -123,7 +123,7 @@ func BuildSchema(db DB) (*Schema, error) { func GetIncompatibleColumns(src, dst *Table) ([]*Column, error) { var incompatibleColumns []*Column for _, dstColumn := range dst.Columns { - srcColumn, err := src.GetColumn(dstColumn.Name) + _, srcColumn, err := src.GetColumn(dstColumn.Name) if err != nil { return nil, fmt.Errorf("failed to find column '%s/%s' in source schema: %s", dst.Name, dstColumn.Name, err) } diff --git a/pg2mysqlfakes/fake_verifier_watcher.go b/pg2mysqlfakes/fake_verifier_watcher.go index b1b68e3..1bd10bb 100644 --- a/pg2mysqlfakes/fake_verifier_watcher.go +++ b/pg2mysqlfakes/fake_verifier_watcher.go @@ -13,11 +13,12 @@ type FakeVerifierWatcher struct { tableVerificationDidStartArgsForCall []struct { tableName string } - TableVerificationDidFinishStub func(tableName string, missingRows int64) + TableVerificationDidFinishStub func(tableName string, missingRows int64, missingIDs []string) tableVerificationDidFinishMutex sync.RWMutex tableVerificationDidFinishArgsForCall []struct { tableName string missingRows int64 + missingIDs []string } TableVerificationDidFinishWithErrorStub func(tableName string, err error) tableVerificationDidFinishWithErrorMutex sync.RWMutex @@ -53,16 +54,22 @@ func (fake *FakeVerifierWatcher) TableVerificationDidStartArgsForCall(i int) str return fake.tableVerificationDidStartArgsForCall[i].tableName } -func (fake *FakeVerifierWatcher) TableVerificationDidFinish(tableName string, missingRows int64) { +func (fake *FakeVerifierWatcher) TableVerificationDidFinish(tableName string, missingRows int64, missingIDs []string) { + var missingIDsCopy []string + if missingIDs != nil { + missingIDsCopy = make([]string, len(missingIDs)) + copy(missingIDsCopy, missingIDs) + } fake.tableVerificationDidFinishMutex.Lock() fake.tableVerificationDidFinishArgsForCall = append(fake.tableVerificationDidFinishArgsForCall, struct { tableName string missingRows int64 - }{tableName, missingRows}) - fake.recordInvocation("TableVerificationDidFinish", []interface{}{tableName, missingRows}) + missingIDs []string + }{tableName, missingRows, missingIDsCopy}) + fake.recordInvocation("TableVerificationDidFinish", []interface{}{tableName, missingRows, missingIDsCopy}) fake.tableVerificationDidFinishMutex.Unlock() if fake.TableVerificationDidFinishStub != nil { - fake.TableVerificationDidFinishStub(tableName, missingRows) + fake.TableVerificationDidFinishStub(tableName, missingRows, missingIDs) } } @@ -72,10 +79,10 @@ func (fake *FakeVerifierWatcher) TableVerificationDidFinishCallCount() int { return len(fake.tableVerificationDidFinishArgsForCall) } -func (fake *FakeVerifierWatcher) TableVerificationDidFinishArgsForCall(i int) (string, int64) { +func (fake *FakeVerifierWatcher) TableVerificationDidFinishArgsForCall(i int) (string, int64, []string) { fake.tableVerificationDidFinishMutex.RLock() defer fake.tableVerificationDidFinishMutex.RUnlock() - return fake.tableVerificationDidFinishArgsForCall[i].tableName, fake.tableVerificationDidFinishArgsForCall[i].missingRows + return fake.tableVerificationDidFinishArgsForCall[i].tableName, fake.tableVerificationDidFinishArgsForCall[i].missingRows, fake.tableVerificationDidFinishArgsForCall[i].missingIDs } func (fake *FakeVerifierWatcher) TableVerificationDidFinishWithError(tableName string, err error) { diff --git a/verifier.go b/verifier.go index abdbc10..6830b6c 100644 --- a/verifier.go +++ b/verifier.go @@ -29,7 +29,13 @@ func (v *verifier) Verify() error { v.watcher.TableVerificationDidStart(table.Name) var missingRows int64 + var missingIDs []string err = EachMissingRow(v.src, v.dst, table, func(scanArgs []interface{}) { + if colIndex, _, getColErr := table.GetColumn("id"); getColErr == nil { + if colID, ok := scanArgs[colIndex].(*interface{}); ok { + missingIDs = append(missingIDs, fmt.Sprintf("%v", *colID)) + } + } missingRows++ }) if err != nil { @@ -37,7 +43,7 @@ func (v *verifier) Verify() error { continue } - v.watcher.TableVerificationDidFinish(table.Name, missingRows) + v.watcher.TableVerificationDidFinish(table.Name, missingRows, missingIDs) } return nil diff --git a/verifier_test.go b/verifier_test.go index d8f907e..9b676f9 100644 --- a/verifier_test.go +++ b/verifier_test.go @@ -58,18 +58,17 @@ var _ = Describe("Verifier", func() { Expect(err).NotTo(HaveOccurred()) Expect(watcher.TableVerificationDidFinishCallCount()).To(Equal(3)) for i := 0; i < watcher.TableVerificationDidFinishCallCount(); i++ { - _, missingRows := watcher.TableVerificationDidFinishArgsForCall(i) + _, missingRows, missingIDs := watcher.TableVerificationDidFinishArgsForCall(i) Expect(missingRows).To(BeZero()) + Expect(missingIDs).To(BeNil()) } }) Context("when there is data in postgres that is not in mysql", func() { + var lastInsertID int BeforeEach(func() { - result, err := pgRunner.DB().Exec("INSERT INTO table_with_id (id, name, ci_name, created_at, truthiness) VALUES (3, 'some-name', 'some-ci-name', now(), false);") + err := pgRunner.DB().QueryRow("INSERT INTO table_with_id (id, name, ci_name, created_at, truthiness) VALUES (3, 'some-name', 'some-ci-name', now(), false) RETURNING id;").Scan(&lastInsertID) Expect(err).NotTo(HaveOccurred()) - rowsAffected, err := result.RowsAffected() - Expect(err).NotTo(HaveOccurred()) - Expect(rowsAffected).To(BeNumerically("==", 1)) }) It("notifies the watcher", func() { @@ -84,8 +83,13 @@ var _ = Describe("Verifier", func() { } for i := 0; i < len(expected); i++ { - tableName, missingRows := watcher.TableVerificationDidFinishArgsForCall(i) + tableName, missingRows, missingIDs := watcher.TableVerificationDidFinishArgsForCall(i) Expect(missingRows).To(Equal(expected[tableName]), fmt.Sprintf("unexpected result for %s", tableName)) + if tableName == "table_with_id" { + Expect(missingIDs).To(Equal([]string{fmt.Sprintf("%d", lastInsertID)})) + } else { + Expect(missingIDs).To(BeNil()) + } } }) }) @@ -125,8 +129,9 @@ var _ = Describe("Verifier", func() { } for i := 0; i < len(expected); i++ { - tableName, missingRows := watcher.TableVerificationDidFinishArgsForCall(i) + tableName, missingRows, missingIDs := watcher.TableVerificationDidFinishArgsForCall(i) Expect(missingRows).To(Equal(expected[tableName]), fmt.Sprintf("unexpected result for %s", tableName)) + Expect(missingIDs).To(BeNil()) } }) }) diff --git a/watcher.go b/watcher.go index f757518..16ee8ca 100644 --- a/watcher.go +++ b/watcher.go @@ -1,12 +1,15 @@ package pg2mysql -import "fmt" +import ( + "fmt" + "strings" +) //go:generate counterfeiter . VerifierWatcher type VerifierWatcher interface { TableVerificationDidStart(tableName string) - TableVerificationDidFinish(tableName string, missingRows int64) + TableVerificationDidFinish(tableName string, missingRows int64, missingIDs []string) TableVerificationDidFinishWithError(tableName string, err error) } @@ -43,13 +46,16 @@ func (s *StdoutPrinter) TableVerificationDidStart(tableName string) { fmt.Printf("Verifying table %s...", tableName) } -func (s *StdoutPrinter) TableVerificationDidFinish(tableName string, missingRows int64) { +func (s *StdoutPrinter) TableVerificationDidFinish(tableName string, missingRows int64, missingIDs []string) { if missingRows != 0 { if missingRows == 1 { fmt.Println("\n\tFAILED: 1 row missing") } else { fmt.Printf("\n\tFAILED: %d rows missing\n", missingRows) } + if missingIDs != nil { + fmt.Printf("\tMissing IDs: %v\n", strings.Join(missingIDs, ",")) + } } else { s.done() }