diff --git a/go/vt/throttler/throttler.go b/go/vt/throttler/throttler.go index 03a20013396..83a1c52225e 100644 --- a/go/vt/throttler/throttler.go +++ b/go/vt/throttler/throttler.go @@ -130,19 +130,31 @@ func NewThrottler(name, unit string, threadCount int, maxRate, maxReplicationLag return newThrottler(GlobalManager, name, unit, threadCount, maxRate, maxReplicationLag, time.Now) } +func NewThrottlerFromConfig(name, unit string, threadCount int, maxRateModuleMaxRate int64, maxReplicationLagModuleConfig MaxReplicationLagModuleConfig, nowFunc func() time.Time) (*Throttler, error) { + return newThrottlerFromConfig(GlobalManager, name, unit, threadCount, maxRateModuleMaxRate, maxReplicationLagModuleConfig, nowFunc) +} + func newThrottler(manager *managerImpl, name, unit string, threadCount int, maxRate, maxReplicationLag int64, nowFunc func() time.Time) (*Throttler, error) { - // Verify input parameters. - if maxRate < 0 { - return nil, fmt.Errorf("maxRate must be >= 0: %v", maxRate) + config := NewMaxReplicationLagModuleConfig(maxReplicationLag) + config.MaxReplicationLagSec = maxReplicationLag + + return newThrottlerFromConfig(manager, name, unit, threadCount, maxRate, config, nowFunc) + +} + +func newThrottlerFromConfig(manager *managerImpl, name, unit string, threadCount int, maxRateModuleMaxRate int64, maxReplicationLagModuleConfig MaxReplicationLagModuleConfig, nowFunc func() time.Time) (*Throttler, error) { + err := maxReplicationLagModuleConfig.Verify() + if err != nil { + return nil, fmt.Errorf("invalid max replication lag config: %w", err) } - if maxReplicationLag < 0 { - return nil, fmt.Errorf("maxReplicationLag must be >= 0: %v", maxReplicationLag) + if maxRateModuleMaxRate < 0 { + return nil, fmt.Errorf("maxRate must be >= 0: %v", maxRateModuleMaxRate) } // Enable the configured modules. - maxRateModule := NewMaxRateModule(maxRate) + maxRateModule := NewMaxRateModule(maxRateModuleMaxRate) actualRateHistory := newAggregatedIntervalHistory(1024, 1*time.Second, threadCount) - maxReplicationLagModule, err := NewMaxReplicationLagModule(NewMaxReplicationLagModuleConfig(maxReplicationLag), actualRateHistory, nowFunc) + maxReplicationLagModule, err := NewMaxReplicationLagModule(maxReplicationLagModuleConfig, actualRateHistory, nowFunc) if err != nil { return nil, err } diff --git a/go/vt/vttablet/tabletserver/txthrottler/tx_throttler.go b/go/vt/vttablet/tabletserver/txthrottler/tx_throttler.go index d9c2294a808..1208e4a303c 100644 --- a/go/vt/vttablet/tabletserver/txthrottler/tx_throttler.go +++ b/go/vt/vttablet/tabletserver/txthrottler/tx_throttler.go @@ -191,7 +191,7 @@ type txThrottlerState struct { // in tests to generate mocks. type healthCheckFactoryFunc func(topoServer *topo.Server, cell string, cellsToWatch []string) discovery.HealthCheck type topologyWatcherFactoryFunc func(topoServer *topo.Server, hc discovery.HealthCheck, cell, keyspace, shard string, refreshInterval time.Duration, topoReadConcurrency int) TopologyWatcherInterface -type throttlerFactoryFunc func(name, unit string, threadCount int, maxRate, maxReplicationLag int64) (ThrottlerInterface, error) +type throttlerFactoryFunc func(name, unit string, threadCount int, maxRate int64, maxReplicationLagConfig throttler.MaxReplicationLagModuleConfig) (ThrottlerInterface, error) var ( healthCheckFactory healthCheckFactoryFunc @@ -210,8 +210,8 @@ func resetTxThrottlerFactories() { topologyWatcherFactory = func(topoServer *topo.Server, hc discovery.HealthCheck, cell, keyspace, shard string, refreshInterval time.Duration, topoReadConcurrency int) TopologyWatcherInterface { return discovery.NewCellTabletsWatcher(context.Background(), topoServer, hc, discovery.NewFilterByKeyspace([]string{keyspace}), cell, refreshInterval, true, topoReadConcurrency) } - throttlerFactory = func(name, unit string, threadCount int, maxRate, maxReplicationLag int64) (ThrottlerInterface, error) { - return throttler.NewThrottler(name, unit, threadCount, maxRate, maxReplicationLag) + throttlerFactory = func(name, unit string, threadCount int, maxRate int64, maxReplicationLagConfig throttler.MaxReplicationLagModuleConfig) (ThrottlerInterface, error) { + return throttler.NewThrottlerFromConfig(name, unit, threadCount, maxRate, maxReplicationLagConfig, time.Now) } } @@ -285,12 +285,15 @@ func (t *TxThrottler) Throttle() (result bool) { } func newTxThrottlerState(config *txThrottlerConfig, keyspace, shard, cell string) (*txThrottlerState, error) { + maxReplicationLagModuleConfig := throttler.MaxReplicationLagModuleConfig{Configuration: config.throttlerConfig} + t, err := throttlerFactory( TxThrottlerName, "TPS", /* unit */ 1, /* threadCount */ throttler.MaxRateModuleDisabled, /* maxRate */ - config.throttlerConfig.MaxReplicationLagSec /* maxReplicationLag */) + maxReplicationLagModuleConfig, + ) if err != nil { return nil, err } diff --git a/go/vt/vttablet/tabletserver/txthrottler/tx_throttler_test.go b/go/vt/vttablet/tabletserver/txthrottler/tx_throttler_test.go index 8eafa64b458..1053068d14a 100644 --- a/go/vt/vttablet/tabletserver/txthrottler/tx_throttler_test.go +++ b/go/vt/vttablet/tabletserver/txthrottler/tx_throttler_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/assert" "vitess.io/vitess/go/vt/discovery" + "vitess.io/vitess/go/vt/throttler" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/memorytopo" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" @@ -79,7 +80,7 @@ func TestEnabledThrottler(t *testing.T) { } mockThrottler := NewMockThrottlerInterface(mockCtrl) - throttlerFactory = func(name, unit string, threadCount int, maxRate, maxReplicationLag int64) (ThrottlerInterface, error) { + throttlerFactory = func(name, unit string, threadCount int, maxRate int64, maxReplicationLagConfig throttler.MaxReplicationLagModuleConfig) (ThrottlerInterface, error) { assert.Equal(t, 1, threadCount) return mockThrottler, nil }