diff --git a/go/test/endtoend/vtgate/vschema/vschema_test.go b/go/test/endtoend/vtgate/vschema/vschema_test.go index 341e47037f7..99993583f02 100644 --- a/go/test/endtoend/vtgate/vschema/vschema_test.go +++ b/go/test/endtoend/vtgate/vschema/vschema_test.go @@ -23,7 +23,8 @@ import ( "os" "testing" - "github.com/stretchr/testify/assert" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" @@ -72,7 +73,7 @@ func TestMain(m *testing.M) { Name: keyspaceName, SchemaSQL: sqlSchema, } - if err := clusterInstance.StartUnshardedKeyspace(*keyspace, 1, false); err != nil { + if err := clusterInstance.StartUnshardedKeyspace(*keyspace, 0, false); err != nil { return 1, err } @@ -99,29 +100,19 @@ func TestVSchema(t *testing.T) { defer cluster.PanicHandler(t) ctx := context.Background() conn, err := mysql.Connect(ctx, &vtParams) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer conn.Close() // Test the empty database with no vschema exec(t, conn, "insert into vt_user (id,name) values(1,'test1'), (2,'test2'), (3,'test3'), (4,'test4')") - qr := exec(t, conn, "select id, name from vt_user order by id") - got := fmt.Sprintf("%v", qr.Rows) - want := `[[INT64(1) VARCHAR("test1")] [INT64(2) VARCHAR("test2")] [INT64(3) VARCHAR("test3")] [INT64(4) VARCHAR("test4")]]` - assert.Equal(t, want, got) + assertMatches(t, conn, "select id, name from vt_user order by id", + `[[INT64(1) VARCHAR("test1")] [INT64(2) VARCHAR("test2")] [INT64(3) VARCHAR("test3")] [INT64(4) VARCHAR("test4")]]`) - qr = exec(t, conn, "delete from vt_user") - got = fmt.Sprintf("%v", qr.Rows) - want = `[]` - assert.Equal(t, want, got) + assertMatches(t, conn, "delete from vt_user", `[]`) // Test empty vschema - qr = exec(t, conn, "SHOW VSCHEMA TABLES") - got = fmt.Sprintf("%v", qr.Rows) - want = `[[VARCHAR("dual")]]` - assert.Equal(t, want, got) + assertMatches(t, conn, "SHOW VSCHEMA TABLES", `[[VARCHAR("dual")]]`) // Use the DDL to create an unsharded vschema and test again @@ -137,28 +128,19 @@ func TestVSchema(t *testing.T) { exec(t, conn, "commit") // Test Showing Tables - qr = exec(t, conn, "SHOW VSCHEMA TABLES") - got = fmt.Sprintf("%v", qr.Rows) - want = `[[VARCHAR("dual")] [VARCHAR("main")] [VARCHAR("vt_user")]]` - assert.Equal(t, want, got) + assertMatches(t, conn, + "SHOW VSCHEMA TABLES", + `[[VARCHAR("dual")] [VARCHAR("main")] [VARCHAR("vt_user")]]`) // Test Showing Vindexes - qr = exec(t, conn, "SHOW VSCHEMA VINDEXES") - got = fmt.Sprintf("%v", qr.Rows) - want = `[]` - assert.Equal(t, want, got) + assertMatches(t, conn, "SHOW VSCHEMA VINDEXES", `[]`) // Test DML operations exec(t, conn, "insert into vt_user (id,name) values(1,'test1'), (2,'test2'), (3,'test3'), (4,'test4')") - qr = exec(t, conn, "select id, name from vt_user order by id") - got = fmt.Sprintf("%v", qr.Rows) - want = `[[INT64(1) VARCHAR("test1")] [INT64(2) VARCHAR("test2")] [INT64(3) VARCHAR("test3")] [INT64(4) VARCHAR("test4")]]` - assert.Equal(t, want, got) + assertMatches(t, conn, "select id, name from vt_user order by id", + `[[INT64(1) VARCHAR("test1")] [INT64(2) VARCHAR("test2")] [INT64(3) VARCHAR("test3")] [INT64(4) VARCHAR("test4")]]`) - qr = exec(t, conn, "delete from vt_user") - got = fmt.Sprintf("%v", qr.Rows) - want = `[]` - assert.Equal(t, want, got) + assertMatches(t, conn, "delete from vt_user", `[]`) } @@ -170,3 +152,13 @@ func exec(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result { } return qr } + +func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) { + t.Helper() + qr := exec(t, conn, query) + got := fmt.Sprintf("%v", qr.Rows) + diff := cmp.Diff(expected, got) + if diff != "" { + t.Errorf("Query: %s (-want +got):\n%s", query, diff) + } +} diff --git a/go/vt/srvtopo/keyspace_filtering_server.go b/go/vt/srvtopo/keyspace_filtering_server.go index 8b4f27faf0e..3762ed64476 100644 --- a/go/vt/srvtopo/keyspace_filtering_server.go +++ b/go/vt/srvtopo/keyspace_filtering_server.go @@ -17,9 +17,8 @@ limitations under the License. package srvtopo import ( - "fmt" - "context" + "fmt" topodatapb "vitess.io/vitess/go/vt/proto/topodata" vschemapb "vitess.io/vitess/go/vt/proto/vschema" @@ -99,9 +98,9 @@ func (ksf keyspaceFilteringServer) GetSrvKeyspace( func (ksf keyspaceFilteringServer) WatchSrvVSchema( ctx context.Context, cell string, - callback func(*vschemapb.SrvVSchema, error), + callback func(*vschemapb.SrvVSchema, error) bool, ) { - filteringCallback := func(schema *vschemapb.SrvVSchema, err error) { + filteringCallback := func(schema *vschemapb.SrvVSchema, err error) bool { if schema != nil { for ks := range schema.Keyspaces { if !ksf.selectKeyspaces[ks] { @@ -110,7 +109,7 @@ func (ksf keyspaceFilteringServer) WatchSrvVSchema( } } - callback(schema, err) + return callback(schema, err) } ksf.server.WatchSrvVSchema(ctx, cell, filteringCallback) diff --git a/go/vt/srvtopo/keyspace_filtering_server_test.go b/go/vt/srvtopo/keyspace_filtering_server_test.go index f06a18b73ef..a0986e17f1a 100644 --- a/go/vt/srvtopo/keyspace_filtering_server_test.go +++ b/go/vt/srvtopo/keyspace_filtering_server_test.go @@ -17,13 +17,12 @@ limitations under the License. package srvtopo import ( + "context" "fmt" "reflect" "sync" "testing" - "context" - topodatapb "vitess.io/vitess/go/vt/proto/topodata" vschemapb "vitess.io/vitess/go/vt/proto/vschema" "vitess.io/vitess/go/vt/srvtopo/srvtopotest" @@ -182,7 +181,7 @@ func TestFilteringServerWatchSrvVSchemaFiltersPassthroughSrvVSchema(t *testing.T wg := sync.WaitGroup{} wg.Add(1) - cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) { + cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) bool { // ensure that only selected keyspaces made it into the callback for name, ks := range gotSchema.Keyspaces { if !allowed[name] { @@ -198,6 +197,7 @@ func TestFilteringServerWatchSrvVSchemaFiltersPassthroughSrvVSchema(t *testing.T } } wg.Done() + return true } f.WatchSrvVSchema(stockCtx, stockCell, cb) @@ -214,7 +214,7 @@ func TestFilteringServerWatchSrvVSchemaHandlesNilSchema(t *testing.T) { wg := sync.WaitGroup{} wg.Add(1) - cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) { + cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) bool { if gotSchema != nil { t.Errorf("Expected nil gotSchema: got %#v", gotSchema) } @@ -222,6 +222,7 @@ func TestFilteringServerWatchSrvVSchemaHandlesNilSchema(t *testing.T) { t.Errorf("Unexpected error: want %v got %v", wantErr, gotErr) } wg.Done() + return true } f.WatchSrvVSchema(stockCtx, "other-cell", cb) diff --git a/go/vt/srvtopo/resilient_server_test.go b/go/vt/srvtopo/resilient_server_test.go index 842c8e92426..caffc75a52a 100644 --- a/go/vt/srvtopo/resilient_server_test.go +++ b/go/vt/srvtopo/resilient_server_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "vitess.io/vitess/go/sync2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" @@ -414,11 +416,12 @@ func TestWatchSrvVSchema(t *testing.T) { mu := sync.Mutex{} var watchValue *vschemapb.SrvVSchema var watchErr error - rs.WatchSrvVSchema(ctx, "test_cell", func(v *vschemapb.SrvVSchema, e error) { + rs.WatchSrvVSchema(ctx, "test_cell", func(v *vschemapb.SrvVSchema, e error) bool { mu.Lock() defer mu.Unlock() watchValue = v watchErr = e + return true }) get := func() (*vschemapb.SrvVSchema, error) { mu.Lock() @@ -684,10 +687,11 @@ func TestSrvKeyspaceWatcher(t *testing.T) { return nil } - rs.WatchSrvKeyspace(context.Background(), "test_cell", "test_ks", func(keyspace *topodatapb.SrvKeyspace, err error) { + rs.WatchSrvKeyspace(context.Background(), "test_cell", "test_ks", func(keyspace *topodatapb.SrvKeyspace, err error) bool { wmu.Lock() defer wmu.Unlock() wseen = append(wseen, watched{keyspace: keyspace, err: err}) + return true }) seen1 := allSeen() @@ -754,3 +758,54 @@ func TestSrvKeyspaceWatcher(t *testing.T) { assert.NotNil(t, seen6[9].keyspace) assert.Equal(t, seen6[9].keyspace.ShardingColumnName, "updated4") } + +func TestSrvKeyspaceListener(t *testing.T) { + ts, _ := memorytopo.NewServerAndFactory("test_cell") + *srvTopoCacheTTL = time.Duration(100 * time.Millisecond) + *srvTopoCacheRefresh = time.Duration(40 * time.Millisecond) + defer func() { + *srvTopoCacheTTL = 1 * time.Second + *srvTopoCacheRefresh = 1 * time.Second + }() + + rs := NewResilientServer(ts, "TestGetSrvKeyspaceWatcher") + + ctx, cancel := context.WithCancel(context.Background()) + var callbackCount sync2.AtomicInt32 + + // adding listener will perform callback. + rs.WatchSrvKeyspace(context.Background(), "test_cell", "test_ks", func(srvKs *topodatapb.SrvKeyspace, err error) bool { + callbackCount.Add(1) + select { + case <-ctx.Done(): + return false + default: + return true + } + }) + + // First update (callback - 2) + want := &topodatapb.SrvKeyspace{ + ShardingColumnName: "id", + ShardingColumnType: topodatapb.KeyspaceIdType_UINT64, + } + err := ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want) + require.NoError(t, err) + + // Next callback to remove from listener + cancel() + + // multi updates thereafter + for i := 0; i < 5; i++ { + want = &topodatapb.SrvKeyspace{ + ShardingColumnName: fmt.Sprintf("updated%d", i), + ShardingColumnType: topodatapb.KeyspaceIdType_UINT64, + } + err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + } + + // only 3 times the callback called for the listener + assert.EqualValues(t, 3, callbackCount.Get()) +} diff --git a/go/vt/srvtopo/server.go b/go/vt/srvtopo/server.go index 4ead206ce13..3c5842480a9 100644 --- a/go/vt/srvtopo/server.go +++ b/go/vt/srvtopo/server.go @@ -45,5 +45,5 @@ type Server interface { // WatchSrvVSchema starts watching the SrvVSchema object for // the provided cell. It will call the callback when // a new value or an error occurs. - WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) + WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) } diff --git a/go/vt/srvtopo/srvtopotest/passthrough.go b/go/vt/srvtopo/srvtopotest/passthrough.go index 1b0b2508443..7a49c599e1a 100644 --- a/go/vt/srvtopo/srvtopotest/passthrough.go +++ b/go/vt/srvtopo/srvtopotest/passthrough.go @@ -60,6 +60,6 @@ func (srv *PassthroughSrvTopoServer) GetSrvKeyspace(ctx context.Context, cell, k } // WatchSrvVSchema implements srvtopo.Server -func (srv *PassthroughSrvTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) { +func (srv *PassthroughSrvTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) { callback(srv.WatchedSrvVSchema, srv.WatchedSrvVSchemaError) } diff --git a/go/vt/srvtopo/watch.go b/go/vt/srvtopo/watch.go index 5b7b2bad7a6..fa734ac6493 100644 --- a/go/vt/srvtopo/watch.go +++ b/go/vt/srvtopo/watch.go @@ -52,7 +52,7 @@ type watchEntry struct { lastErrorCtx context.Context lastErrorTime time.Time - listeners []func(interface{}, error) + listeners []func(interface{}, error) bool } type resilientWatcher struct { @@ -92,7 +92,7 @@ func (w *resilientWatcher) getValue(ctx context.Context, wkey fmt.Stringer) (int return entry.currentValueLocked(ctx) } -func (entry *watchEntry) addListener(ctx context.Context, callback func(interface{}, error)) { +func (entry *watchEntry) addListener(ctx context.Context, callback func(interface{}, error) bool) { entry.mutex.Lock() defer entry.mutex.Unlock() @@ -157,8 +157,13 @@ func (entry *watchEntry) update(ctx context.Context, value interface{}, err erro entry.onValueLocked(value) } - for _, callback := range entry.listeners { - callback(entry.value, entry.lastError) + listeners := entry.listeners + entry.listeners = entry.listeners[:0] + + for _, callback := range listeners { + if callback(entry.value, entry.lastError) { + entry.listeners = append(entry.listeners, callback) + } } } diff --git a/go/vt/srvtopo/watch_srvkeyspace.go b/go/vt/srvtopo/watch_srvkeyspace.go index 5ed275f75e2..9e87cc10ca0 100644 --- a/go/vt/srvtopo/watch_srvkeyspace.go +++ b/go/vt/srvtopo/watch_srvkeyspace.go @@ -74,11 +74,11 @@ func (w *SrvKeyspaceWatcher) GetSrvKeyspace(ctx context.Context, cell, keyspace return ks, err } -func (w *SrvKeyspaceWatcher) WatchSrvKeyspace(ctx context.Context, cell, keyspace string, callback func(*topodata.SrvKeyspace, error)) { +func (w *SrvKeyspaceWatcher) WatchSrvKeyspace(ctx context.Context, cell, keyspace string, callback func(*topodata.SrvKeyspace, error) bool) { entry := w.rw.getEntry(&srvKeyspaceKey{cell, keyspace}) - entry.addListener(ctx, func(v interface{}, err error) { + entry.addListener(ctx, func(v interface{}, err error) bool { srvkeyspace, _ := v.(*topodata.SrvKeyspace) - callback(srvkeyspace, err) + return callback(srvkeyspace, err) }) } diff --git a/go/vt/srvtopo/watch_srvvschema.go b/go/vt/srvtopo/watch_srvvschema.go index 4b52f409996..5a814b1c676 100644 --- a/go/vt/srvtopo/watch_srvvschema.go +++ b/go/vt/srvtopo/watch_srvvschema.go @@ -71,10 +71,10 @@ func (w *SrvVSchemaWatcher) GetSrvVSchema(ctx context.Context, cell string) (*vs return vschema, err } -func (w *SrvVSchemaWatcher) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) { +func (w *SrvVSchemaWatcher) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) { entry := w.rw.getEntry(cellName(cell)) - entry.addListener(ctx, func(v interface{}, err error) { + entry.addListener(ctx, func(v interface{}, err error) bool { vschema, _ := v.(*vschemapb.SrvVSchema) - callback(vschema, err) + return callback(vschema, err) }) } diff --git a/go/vt/vtexplain/vtexplain_topo.go b/go/vt/vtexplain/vtexplain_topo.go index de00bf9fcf7..3302b68d5ce 100644 --- a/go/vt/vtexplain/vtexplain_topo.go +++ b/go/vt/vtexplain/vtexplain_topo.go @@ -17,11 +17,10 @@ limitations under the License. package vtexplain import ( + "context" "fmt" "sync" - "context" - "vitess.io/vitess/go/vt/topo" topodatapb "vitess.io/vitess/go/vt/proto/topodata" @@ -115,6 +114,6 @@ func (et *ExplainTopo) GetSrvKeyspace(ctx context.Context, cell, keyspace string } // WatchSrvVSchema is part of the srvtopo.Server interface. -func (et *ExplainTopo) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) { +func (et *ExplainTopo) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) { callback(et.getSrvVSchema(), nil) } diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index c0fabf69e50..c7eaa253cc9 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -1224,8 +1224,9 @@ func TestExecutorAlterVSchemaKeyspace(t *testing.T) { session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) vschemaUpdates := make(chan *vschemapb.SrvVSchema, 2) - executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) { + executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema + return true }) vschema := <-vschemaUpdates @@ -1251,8 +1252,9 @@ func TestExecutorCreateVindexDDL(t *testing.T) { ks := "TestExecutor" vschemaUpdates := make(chan *vschemapb.SrvVSchema, 4) - executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) { + executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema + return true }) vschema := <-vschemaUpdates @@ -1322,8 +1324,9 @@ func TestExecutorAddDropVschemaTableDDL(t *testing.T) { ks := KsTestUnsharded vschemaUpdates := make(chan *vschemapb.SrvVSchema, 4) - executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) { + executor.serv.WatchSrvVSchema(ctx, "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema + return true }) vschema := <-vschemaUpdates diff --git a/go/vt/vtgate/executor_vschema_ddl_test.go b/go/vt/vtgate/executor_vschema_ddl_test.go index cbeefa300a2..1ce5a87a62a 100644 --- a/go/vt/vtgate/executor_vschema_ddl_test.go +++ b/go/vt/vtgate/executor_vschema_ddl_test.go @@ -138,8 +138,9 @@ func TestPlanExecutorAlterVSchemaKeyspace(t *testing.T) { session := NewSafeSession(&vtgatepb.Session{TargetString: "@primary", Autocommit: true}) vschemaUpdates := make(chan *vschemapb.SrvVSchema, 2) - executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) { + executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema + return true }) vschema := <-vschemaUpdates @@ -165,8 +166,9 @@ func TestPlanExecutorCreateVindexDDL(t *testing.T) { ks := "TestExecutor" vschemaUpdates := make(chan *vschemapb.SrvVSchema, 4) - executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) { + executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema + return true }) vschema := <-vschemaUpdates @@ -206,8 +208,9 @@ func TestPlanExecutorDropVindexDDL(t *testing.T) { ks := "TestExecutor" vschemaUpdates := make(chan *vschemapb.SrvVSchema, 4) - executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) { + executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema + return true }) vschema := <-vschemaUpdates @@ -274,8 +277,9 @@ func TestPlanExecutorAddDropVschemaTableDDL(t *testing.T) { ks := KsTestUnsharded vschemaUpdates := make(chan *vschemapb.SrvVSchema, 4) - executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) { + executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema + return true }) vschema := <-vschemaUpdates @@ -390,8 +394,9 @@ func TestExecutorAddDropVindexDDL(t *testing.T) { ks := "TestExecutor" session := NewSafeSession(&vtgatepb.Session{TargetString: ks}) vschemaUpdates := make(chan *vschemapb.SrvVSchema, 4) - executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) { + executor.serv.WatchSrvVSchema(context.Background(), "aa", func(vschema *vschemapb.SrvVSchema, err error) bool { vschemaUpdates <- vschema + return true }) vschema := <-vschemaUpdates diff --git a/go/vt/vtgate/sandbox_test.go b/go/vt/vtgate/sandbox_test.go index 7d72cfc9899..96bac8d6f74 100644 --- a/go/vt/vtgate/sandbox_test.go +++ b/go/vt/vtgate/sandbox_test.go @@ -17,12 +17,11 @@ limitations under the License. package vtgate import ( + "context" "flag" "fmt" "sync" - "context" - "vitess.io/vitess/go/json2" "vitess.io/vitess/go/vt/grpcclient" "vitess.io/vitess/go/vt/key" @@ -289,7 +288,7 @@ func (sct *sandboxTopo) GetSrvKeyspace(ctx context.Context, cell, keyspace strin // If the sandbox was created with a backing topo service, piggy back on it // to properly simulate watches, otherwise just immediately call back the // caller. -func (sct *sandboxTopo) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) { +func (sct *sandboxTopo) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) { srvVSchema := getSandboxSrvVSchema() if sct.topoServer == nil { @@ -299,11 +298,15 @@ func (sct *sandboxTopo) WatchSrvVSchema(ctx context.Context, cell string, callba sct.topoServer.UpdateSrvVSchema(ctx, cell, srvVSchema) current, updateChan, _ := sct.topoServer.WatchSrvVSchema(ctx, cell) - callback(current.Value, nil) + if !callback(current.Value, nil) { + panic("sandboxTopo callback returned false") + } go func() { for { update := <-updateChan - callback(update.Value, update.Err) + if !callback(update.Value, update.Err) { + panic("sandboxTopo callback returned false") + } } }() } diff --git a/go/vt/vtgate/vcursor_impl_test.go b/go/vt/vtgate/vcursor_impl_test.go index e283468db95..0df824f43d3 100644 --- a/go/vt/vtgate/vcursor_impl_test.go +++ b/go/vt/vtgate/vcursor_impl_test.go @@ -71,7 +71,7 @@ func (f *fakeTopoServer) GetSrvKeyspace(ctx context.Context, cell, keyspace stri // WatchSrvVSchema starts watching the SrvVSchema object for // the provided cell. It will call the callback when // a new value or an error occurs. -func (f *fakeTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) { +func (f *fakeTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) { } diff --git a/go/vt/vtgate/vschema_manager.go b/go/vt/vtgate/vschema_manager.go index ccc7b8ae661..7097b33052a 100644 --- a/go/vt/vtgate/vschema_manager.go +++ b/go/vt/vtgate/vschema_manager.go @@ -87,12 +87,18 @@ func (vm *VSchemaManager) UpdateVSchema(ctx context.Context, ksName string, vsch log.Errorf("error updating vschema in cell %s: %v", cell, cellErr) } } + if err != nil { + return err + } + + // Update all the local copy of VSchema if the topo update is successful. + vm.VSchemaUpdate(vschema, err) - return err + return nil } // VSchemaUpdate builds the VSchema from SrvVschema and call subscribers. -func (vm *VSchemaManager) VSchemaUpdate(v *vschemapb.SrvVSchema, err error) { +func (vm *VSchemaManager) VSchemaUpdate(v *vschemapb.SrvVSchema, err error) bool { log.Infof("Received vschema update") switch { case err == nil: @@ -129,6 +135,7 @@ func (vm *VSchemaManager) VSchemaUpdate(v *vschemapb.SrvVSchema, err error) { if vm.subscriber != nil { vm.subscriber(vschema, vSchemaStats(err, vschema)) } + return true } func vSchemaStats(err error, vschema *vindexes.VSchema) *VSchemaStats { diff --git a/go/vt/vttablet/tabletserver/vstreamer/engine.go b/go/vt/vttablet/tabletserver/vstreamer/engine.go index dbf68707a9c..0a6c92d53ee 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/engine.go +++ b/go/vt/vttablet/tabletserver/vstreamer/engine.go @@ -335,7 +335,7 @@ func (vse *Engine) setWatch() { } // WatchSrvVSchema does not return until the inner func has been called at least once. - vse.ts.WatchSrvVSchema(context.TODO(), vse.cell, func(v *vschemapb.SrvVSchema, err error) { + vse.ts.WatchSrvVSchema(context.TODO(), vse.cell, func(v *vschemapb.SrvVSchema, err error) bool { switch { case err == nil: // Build vschema down below. @@ -344,7 +344,7 @@ func (vse *Engine) setWatch() { default: log.Errorf("Error fetching vschema: %v", err) vse.vschemaErrors.Add(1) - return + return true } var vschema *vindexes.VSchema if v != nil { @@ -352,7 +352,7 @@ func (vse *Engine) setWatch() { if err != nil { log.Errorf("Error building vschema: %v", err) vse.vschemaErrors.Add(1) - return + return true } } else { vschema = &vindexes.VSchema{} @@ -371,6 +371,7 @@ func (vse *Engine) setWatch() { s.SetVSchema(vse.lvschema) } vse.vschemaUpdates.Add(1) + return true }) }