Skip to content

Commit

Permalink
Merge pull request #8633 from planetscale/srvtopo-unwatch
Browse files Browse the repository at this point in the history
srvtopo: allow unwatching from watch callbacks
  • Loading branch information
deepthi authored Aug 19, 2021
2 parents 1f5a7ed + ffa5fc8 commit 441d4bd
Show file tree
Hide file tree
Showing 16 changed files with 148 additions and 78 deletions.
58 changes: 25 additions & 33 deletions go/test/endtoend/vtgate/vschema/vschema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import (
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -72,7 +73,7 @@ func TestMain(m *testing.M) {
Name: keyspaceName,
SchemaSQL: sqlSchema,
}
if err := clusterInstance.StartUnshardedKeyspace(*keyspace, 1, false); err != nil {
if err := clusterInstance.StartUnshardedKeyspace(*keyspace, 0, false); err != nil {
return 1, err
}

Expand All @@ -99,29 +100,19 @@ func TestVSchema(t *testing.T) {
defer cluster.PanicHandler(t)
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer conn.Close()

// Test the empty database with no vschema
exec(t, conn, "insert into vt_user (id,name) values(1,'test1'), (2,'test2'), (3,'test3'), (4,'test4')")

qr := exec(t, conn, "select id, name from vt_user order by id")
got := fmt.Sprintf("%v", qr.Rows)
want := `[[INT64(1) VARCHAR("test1")] [INT64(2) VARCHAR("test2")] [INT64(3) VARCHAR("test3")] [INT64(4) VARCHAR("test4")]]`
assert.Equal(t, want, got)
assertMatches(t, conn, "select id, name from vt_user order by id",
`[[INT64(1) VARCHAR("test1")] [INT64(2) VARCHAR("test2")] [INT64(3) VARCHAR("test3")] [INT64(4) VARCHAR("test4")]]`)

qr = exec(t, conn, "delete from vt_user")
got = fmt.Sprintf("%v", qr.Rows)
want = `[]`
assert.Equal(t, want, got)
assertMatches(t, conn, "delete from vt_user", `[]`)

// Test empty vschema
qr = exec(t, conn, "SHOW VSCHEMA TABLES")
got = fmt.Sprintf("%v", qr.Rows)
want = `[[VARCHAR("dual")]]`
assert.Equal(t, want, got)
assertMatches(t, conn, "SHOW VSCHEMA TABLES", `[[VARCHAR("dual")]]`)

// Use the DDL to create an unsharded vschema and test again

Expand All @@ -137,28 +128,19 @@ func TestVSchema(t *testing.T) {
exec(t, conn, "commit")

// Test Showing Tables
qr = exec(t, conn, "SHOW VSCHEMA TABLES")
got = fmt.Sprintf("%v", qr.Rows)
want = `[[VARCHAR("dual")] [VARCHAR("main")] [VARCHAR("vt_user")]]`
assert.Equal(t, want, got)
assertMatches(t, conn,
"SHOW VSCHEMA TABLES",
`[[VARCHAR("dual")] [VARCHAR("main")] [VARCHAR("vt_user")]]`)

// Test Showing Vindexes
qr = exec(t, conn, "SHOW VSCHEMA VINDEXES")
got = fmt.Sprintf("%v", qr.Rows)
want = `[]`
assert.Equal(t, want, got)
assertMatches(t, conn, "SHOW VSCHEMA VINDEXES", `[]`)

// Test DML operations
exec(t, conn, "insert into vt_user (id,name) values(1,'test1'), (2,'test2'), (3,'test3'), (4,'test4')")
qr = exec(t, conn, "select id, name from vt_user order by id")
got = fmt.Sprintf("%v", qr.Rows)
want = `[[INT64(1) VARCHAR("test1")] [INT64(2) VARCHAR("test2")] [INT64(3) VARCHAR("test3")] [INT64(4) VARCHAR("test4")]]`
assert.Equal(t, want, got)
assertMatches(t, conn, "select id, name from vt_user order by id",
`[[INT64(1) VARCHAR("test1")] [INT64(2) VARCHAR("test2")] [INT64(3) VARCHAR("test3")] [INT64(4) VARCHAR("test4")]]`)

qr = exec(t, conn, "delete from vt_user")
got = fmt.Sprintf("%v", qr.Rows)
want = `[]`
assert.Equal(t, want, got)
assertMatches(t, conn, "delete from vt_user", `[]`)

}

Expand All @@ -170,3 +152,13 @@ func exec(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result {
}
return qr
}

func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) {
t.Helper()
qr := exec(t, conn, query)
got := fmt.Sprintf("%v", qr.Rows)
diff := cmp.Diff(expected, got)
if diff != "" {
t.Errorf("Query: %s (-want +got):\n%s", query, diff)
}
}
9 changes: 4 additions & 5 deletions go/vt/srvtopo/keyspace_filtering_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ limitations under the License.
package srvtopo

import (
"fmt"

"context"
"fmt"

topodatapb "vitess.io/vitess/go/vt/proto/topodata"
vschemapb "vitess.io/vitess/go/vt/proto/vschema"
Expand Down Expand Up @@ -99,9 +98,9 @@ func (ksf keyspaceFilteringServer) GetSrvKeyspace(
func (ksf keyspaceFilteringServer) WatchSrvVSchema(
ctx context.Context,
cell string,
callback func(*vschemapb.SrvVSchema, error),
callback func(*vschemapb.SrvVSchema, error) bool,
) {
filteringCallback := func(schema *vschemapb.SrvVSchema, err error) {
filteringCallback := func(schema *vschemapb.SrvVSchema, err error) bool {
if schema != nil {
for ks := range schema.Keyspaces {
if !ksf.selectKeyspaces[ks] {
Expand All @@ -110,7 +109,7 @@ func (ksf keyspaceFilteringServer) WatchSrvVSchema(
}
}

callback(schema, err)
return callback(schema, err)
}

ksf.server.WatchSrvVSchema(ctx, cell, filteringCallback)
Expand Down
9 changes: 5 additions & 4 deletions go/vt/srvtopo/keyspace_filtering_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ limitations under the License.
package srvtopo

import (
"context"
"fmt"
"reflect"
"sync"
"testing"

"context"

topodatapb "vitess.io/vitess/go/vt/proto/topodata"
vschemapb "vitess.io/vitess/go/vt/proto/vschema"
"vitess.io/vitess/go/vt/srvtopo/srvtopotest"
Expand Down Expand Up @@ -182,7 +181,7 @@ func TestFilteringServerWatchSrvVSchemaFiltersPassthroughSrvVSchema(t *testing.T
wg := sync.WaitGroup{}
wg.Add(1)

cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) {
cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) bool {
// ensure that only selected keyspaces made it into the callback
for name, ks := range gotSchema.Keyspaces {
if !allowed[name] {
Expand All @@ -198,6 +197,7 @@ func TestFilteringServerWatchSrvVSchemaFiltersPassthroughSrvVSchema(t *testing.T
}
}
wg.Done()
return true
}

f.WatchSrvVSchema(stockCtx, stockCell, cb)
Expand All @@ -214,14 +214,15 @@ func TestFilteringServerWatchSrvVSchemaHandlesNilSchema(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(1)

cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) {
cb := func(gotSchema *vschemapb.SrvVSchema, gotErr error) bool {
if gotSchema != nil {
t.Errorf("Expected nil gotSchema: got %#v", gotSchema)
}
if gotErr != wantErr {
t.Errorf("Unexpected error: want %v got %v", wantErr, gotErr)
}
wg.Done()
return true
}

f.WatchSrvVSchema(stockCtx, "other-cell", cb)
Expand Down
59 changes: 57 additions & 2 deletions go/vt/srvtopo/resilient_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"testing"
"time"

"vitess.io/vitess/go/sync2"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -414,11 +416,12 @@ func TestWatchSrvVSchema(t *testing.T) {
mu := sync.Mutex{}
var watchValue *vschemapb.SrvVSchema
var watchErr error
rs.WatchSrvVSchema(ctx, "test_cell", func(v *vschemapb.SrvVSchema, e error) {
rs.WatchSrvVSchema(ctx, "test_cell", func(v *vschemapb.SrvVSchema, e error) bool {
mu.Lock()
defer mu.Unlock()
watchValue = v
watchErr = e
return true
})
get := func() (*vschemapb.SrvVSchema, error) {
mu.Lock()
Expand Down Expand Up @@ -684,10 +687,11 @@ func TestSrvKeyspaceWatcher(t *testing.T) {
return nil
}

rs.WatchSrvKeyspace(context.Background(), "test_cell", "test_ks", func(keyspace *topodatapb.SrvKeyspace, err error) {
rs.WatchSrvKeyspace(context.Background(), "test_cell", "test_ks", func(keyspace *topodatapb.SrvKeyspace, err error) bool {
wmu.Lock()
defer wmu.Unlock()
wseen = append(wseen, watched{keyspace: keyspace, err: err})
return true
})

seen1 := allSeen()
Expand Down Expand Up @@ -754,3 +758,54 @@ func TestSrvKeyspaceWatcher(t *testing.T) {
assert.NotNil(t, seen6[9].keyspace)
assert.Equal(t, seen6[9].keyspace.ShardingColumnName, "updated4")
}

func TestSrvKeyspaceListener(t *testing.T) {
ts, _ := memorytopo.NewServerAndFactory("test_cell")
*srvTopoCacheTTL = time.Duration(100 * time.Millisecond)
*srvTopoCacheRefresh = time.Duration(40 * time.Millisecond)
defer func() {
*srvTopoCacheTTL = 1 * time.Second
*srvTopoCacheRefresh = 1 * time.Second
}()

rs := NewResilientServer(ts, "TestGetSrvKeyspaceWatcher")

ctx, cancel := context.WithCancel(context.Background())
var callbackCount sync2.AtomicInt32

// adding listener will perform callback.
rs.WatchSrvKeyspace(context.Background(), "test_cell", "test_ks", func(srvKs *topodatapb.SrvKeyspace, err error) bool {
callbackCount.Add(1)
select {
case <-ctx.Done():
return false
default:
return true
}
})

// First update (callback - 2)
want := &topodatapb.SrvKeyspace{
ShardingColumnName: "id",
ShardingColumnType: topodatapb.KeyspaceIdType_UINT64,
}
err := ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
require.NoError(t, err)

// Next callback to remove from listener
cancel()

// multi updates thereafter
for i := 0; i < 5; i++ {
want = &topodatapb.SrvKeyspace{
ShardingColumnName: fmt.Sprintf("updated%d", i),
ShardingColumnType: topodatapb.KeyspaceIdType_UINT64,
}
err = ts.UpdateSrvKeyspace(context.Background(), "test_cell", "test_ks", want)
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
}

// only 3 times the callback called for the listener
assert.EqualValues(t, 3, callbackCount.Get())
}
2 changes: 1 addition & 1 deletion go/vt/srvtopo/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ type Server interface {
// WatchSrvVSchema starts watching the SrvVSchema object for
// the provided cell. It will call the callback when
// a new value or an error occurs.
WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error))
WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool)
}
2 changes: 1 addition & 1 deletion go/vt/srvtopo/srvtopotest/passthrough.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ func (srv *PassthroughSrvTopoServer) GetSrvKeyspace(ctx context.Context, cell, k
}

// WatchSrvVSchema implements srvtopo.Server
func (srv *PassthroughSrvTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) {
func (srv *PassthroughSrvTopoServer) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) {
callback(srv.WatchedSrvVSchema, srv.WatchedSrvVSchemaError)
}
13 changes: 9 additions & 4 deletions go/vt/srvtopo/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type watchEntry struct {
lastErrorCtx context.Context
lastErrorTime time.Time

listeners []func(interface{}, error)
listeners []func(interface{}, error) bool
}

type resilientWatcher struct {
Expand Down Expand Up @@ -92,7 +92,7 @@ func (w *resilientWatcher) getValue(ctx context.Context, wkey fmt.Stringer) (int
return entry.currentValueLocked(ctx)
}

func (entry *watchEntry) addListener(ctx context.Context, callback func(interface{}, error)) {
func (entry *watchEntry) addListener(ctx context.Context, callback func(interface{}, error) bool) {
entry.mutex.Lock()
defer entry.mutex.Unlock()

Expand Down Expand Up @@ -157,8 +157,13 @@ func (entry *watchEntry) update(ctx context.Context, value interface{}, err erro
entry.onValueLocked(value)
}

for _, callback := range entry.listeners {
callback(entry.value, entry.lastError)
listeners := entry.listeners
entry.listeners = entry.listeners[:0]

for _, callback := range listeners {
if callback(entry.value, entry.lastError) {
entry.listeners = append(entry.listeners, callback)
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions go/vt/srvtopo/watch_srvkeyspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ func (w *SrvKeyspaceWatcher) GetSrvKeyspace(ctx context.Context, cell, keyspace
return ks, err
}

func (w *SrvKeyspaceWatcher) WatchSrvKeyspace(ctx context.Context, cell, keyspace string, callback func(*topodata.SrvKeyspace, error)) {
func (w *SrvKeyspaceWatcher) WatchSrvKeyspace(ctx context.Context, cell, keyspace string, callback func(*topodata.SrvKeyspace, error) bool) {
entry := w.rw.getEntry(&srvKeyspaceKey{cell, keyspace})
entry.addListener(ctx, func(v interface{}, err error) {
entry.addListener(ctx, func(v interface{}, err error) bool {
srvkeyspace, _ := v.(*topodata.SrvKeyspace)
callback(srvkeyspace, err)
return callback(srvkeyspace, err)
})
}

Expand Down
6 changes: 3 additions & 3 deletions go/vt/srvtopo/watch_srvvschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ func (w *SrvVSchemaWatcher) GetSrvVSchema(ctx context.Context, cell string) (*vs
return vschema, err
}

func (w *SrvVSchemaWatcher) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) {
func (w *SrvVSchemaWatcher) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) {
entry := w.rw.getEntry(cellName(cell))
entry.addListener(ctx, func(v interface{}, err error) {
entry.addListener(ctx, func(v interface{}, err error) bool {
vschema, _ := v.(*vschemapb.SrvVSchema)
callback(vschema, err)
return callback(vschema, err)
})
}
5 changes: 2 additions & 3 deletions go/vt/vtexplain/vtexplain_topo.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ limitations under the License.
package vtexplain

import (
"context"
"fmt"
"sync"

"context"

"vitess.io/vitess/go/vt/topo"

topodatapb "vitess.io/vitess/go/vt/proto/topodata"
Expand Down Expand Up @@ -115,6 +114,6 @@ func (et *ExplainTopo) GetSrvKeyspace(ctx context.Context, cell, keyspace string
}

// WatchSrvVSchema is part of the srvtopo.Server interface.
func (et *ExplainTopo) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error)) {
func (et *ExplainTopo) WatchSrvVSchema(ctx context.Context, cell string, callback func(*vschemapb.SrvVSchema, error) bool) {
callback(et.getSrvVSchema(), nil)
}
Loading

0 comments on commit 441d4bd

Please sign in to comment.