diff --git a/command/command.go b/command/command.go index a0026fcf..2e8c0c55 100644 --- a/command/command.go +++ b/command/command.go @@ -2,6 +2,7 @@ package command import ( "fmt" + log "github.com/sirupsen/logrus" "github.com/squareup/pranadb/conf" "github.com/squareup/pranadb/protos/squareup/cash/pranadb/v1/clustermsgs" "github.com/squareup/pranadb/remoting" @@ -109,28 +110,32 @@ func (e *Executor) ExecuteSQLStatement(execCtx *execctx.ExecutionContext, sql st } return exec.Empty, nil case ast.Create != nil && ast.Create.MaterializedView != nil: - sequences, err := e.generateTableIDSequences(3) - if err != nil { - return nil, errors.WithStack(err) - } - command, err := NewOriginatingCreateMVCommand(e, execCtx.Planner(), execCtx.Schema, sql, sequences, ast.Create.MaterializedView) - if err != nil { - return nil, errors.WithStack(err) - } - if err := e.executeCommandWithRetry(command); err != nil { + if err := e.executeCommandWithRetry(func() (DDLCommand, error) { + sequences, err := e.generateTableIDSequences(3) + if err != nil { + return nil, errors.WithStack(err) + } + command, err := NewOriginatingCreateMVCommand(e, execCtx.Planner(), execCtx.Schema, sql, sequences, ast.Create.MaterializedView) + if err != nil { + return nil, errors.WithStack(err) + } + return command, nil + }); err != nil { return nil, errors.WithStack(err) } return exec.Empty, nil case ast.Create != nil && ast.Create.Index != nil: - sequences, err := e.generateTableIDSequences(1) - if err != nil { - return nil, errors.WithStack(err) - } - command, err := NewOriginatingCreateIndexCommand(e, execCtx.Planner(), execCtx.Schema, sql, sequences, ast.Create.Index) - if err != nil { - return nil, errors.WithStack(err) - } - if err := e.executeCommandWithRetry(command); err != nil { + if err := e.executeCommandWithRetry(func() (DDLCommand, error) { + sequences, err := e.generateTableIDSequences(1) + if err != nil { + return nil, errors.WithStack(err) + } + command, err := NewOriginatingCreateIndexCommand(e, execCtx.Planner(), execCtx.Schema, sql, sequences, ast.Create.Index) + if err != nil { + return nil, errors.WithStack(err) + } + return command, nil + }); err != nil { return nil, errors.WithStack(err) } return exec.Empty, nil @@ -184,8 +189,8 @@ func (e *Executor) ExecuteSQLStatement(execCtx *execctx.ExecutionContext, sql st return nil, errors.WithStack(err) } return exec.Empty, nil - case ast.ConsumerRate != nil: - if err := e.execConsumerRate(execCtx, ast.ConsumerRate.SourceName, ast.ConsumerRate.Rate); err != nil { + case ast.SourceSetMaxRate != nil: + if err := e.execSetMaxSourceIngestRate(execCtx, ast.SourceSetMaxRate.SourceName, ast.SourceSetMaxRate.Rate); err != nil { return nil, errors.WithStack(err) } return exec.Empty, nil @@ -208,13 +213,20 @@ func (e *Executor) GetPullEngine() *pull.Engine { return e.pullEngine } -func (e *Executor) executeCommandWithRetry(command DDLCommand) error { +func (e *Executor) executeCommandWithRetry(commandFactory func() (DDLCommand, error)) error { start := time.Now() for { - err := e.ddlRunner.RunCommand(command) + command, err := commandFactory() + if err != nil { + return err + } + log.Debugf("executing command %s with potential retry", command.SQL()) + err = e.ddlRunner.RunCommand(command) if err != nil { + log.Errorf("failed to run command %s %v", command.SQL(), err) var perr errors.PranaError if errors.As(err, &perr) && perr.Code == errors.DdlRetry { + log.Debugf("It is a ddl retry - will retry it after a short delay %s", command.SQL()) // Some DDL commands like create MV or index can return DdlRetry if they fail because Raft // leadership changed - in this case we retry rather than returning an error as this can be transient // e.g. cluster is starting up or node is being rolled @@ -362,8 +374,8 @@ func (e *Executor) execDescribe(execCtx *execctx.ExecutionContext, tableName str return describeRows(tableInfo) } -func (e *Executor) execConsumerRate(execCtx *execctx.ExecutionContext, sourceName string, rate int64) error { - return e.ddlClient.Broadcast(&clustermsgs.ConsumerSetRate{ +func (e *Executor) execSetMaxSourceIngestRate(execCtx *execctx.ExecutionContext, sourceName string, rate int64) error { + return e.ddlClient.Broadcast(&clustermsgs.SourceSetMaxIngestRate{ SchemaName: execCtx.Schema.Name, SourceName: sourceName, Rate: rate, diff --git a/command/parser/ast.go b/command/parser/ast.go index 7bad899e..3dd8f801 100644 --- a/command/parser/ast.go +++ b/command/parser/ast.go @@ -188,19 +188,19 @@ type Show struct { TableName string `("ON" @Ident)?` } -type ConsumerRate struct { +type SourceSetMaxRate struct { SourceName string `@Ident` Rate int64 `@Number` } // AST root. type AST struct { - Select string // Unaltered SELECT statement, if any. - Use string `( "USE" @Ident` - Drop *Drop ` | "DROP" @@ ` - Create *Create ` | "CREATE" @@ ` - Show *Show ` | "SHOW" @@ ` - Describe string ` | "DESCRIBE" @Ident ` - ConsumerRate *ConsumerRate ` | "CONSUMER" "RATE" @@ ` - ResetDdl string ` | "RESET" "DDL" @Ident ) ';'?` + Select string // Unaltered SELECT statement, if any. + Use string `( "USE" @Ident` + Drop *Drop ` | "DROP" @@ ` + Create *Create ` | "CREATE" @@ ` + Show *Show ` | "SHOW" @@ ` + Describe string ` | "DESCRIBE" @Ident ` + SourceSetMaxRate *SourceSetMaxRate ` | "SOURCE" "SET" "MAX" "RATE" @@ ` + ResetDdl string ` | "RESET" "DDL" @Ident ) ';'?` } diff --git a/kafka/cflt_client.go b/kafka/cflt_client.go index b5252b64..46d7466c 100644 --- a/kafka/cflt_client.go +++ b/kafka/cflt_client.go @@ -158,6 +158,3 @@ func (cmp *ConfluentMessageProvider) Start() error { cmp.consumer = consumer return nil } - -func (cmp *ConfluentMessageProvider) SetMaxRate(rate int) { -} diff --git a/kafka/kafka.go b/kafka/kafka.go index 9d82bf1a..ad08ae48 100644 --- a/kafka/kafka.go +++ b/kafka/kafka.go @@ -15,7 +15,6 @@ type MessageProvider interface { Start() error Close() error SetRebalanceCallback(callback RebalanceCallback) - SetMaxRate(rate int) } type Message struct { diff --git a/kafka/load/load_client.go b/kafka/load/load_client.go index a209867f..38d47eba 100644 --- a/kafka/load/load_client.go +++ b/kafka/load/load_client.go @@ -7,7 +7,6 @@ import ( "github.com/squareup/pranadb/errors" "github.com/squareup/pranadb/kafka" "github.com/squareup/pranadb/msggen" - "go.uber.org/ratelimit" "math" "math/rand" "strings" @@ -24,7 +23,6 @@ type LoadClientMessageProviderFactory struct { partitionsStart int nextPartition int properties map[string]string - maxRate int maxMessagesPerConsumer int64 uniqueIDsPerPartition int64 messageGeneratorName string @@ -34,7 +32,6 @@ type LoadClientMessageProviderFactory struct { const ( produceTimeout = 100 * time.Millisecond - maxRatePropName = "prana.loadclient.maxrateperconsumer" partitionsPerConsumerPropName = "prana.loadclient.partitionsperconsumer" uniqueIDsPerPartitionPropName = "prana.loadclient.uniqueidsperpartition" maxMessagesPerConsumerPropName = "prana.loadclient.maxmessagesperconsumer" @@ -51,10 +48,6 @@ func NewMessageProviderFactory(bufferSize int, numConsumersPerSource int, nodeID } partitionsPerNode := numConsumersPerSource * partitionsPerConsumer partitionsStart := nodeID * partitionsPerNode - maxRate, err := common.GetOrDefaultIntProperty(maxRatePropName, properties, -1) - if err != nil { - return nil, err - } uniqueIDsPerPartition, err := common.GetOrDefaultIntProperty(uniqueIDsPerPartitionPropName, properties, math.MaxInt64) if err != nil { return nil, err @@ -74,7 +67,6 @@ func NewMessageProviderFactory(bufferSize int, numConsumersPerSource int, nodeID partitionsStart: partitionsStart, nextPartition: partitionsStart, properties: properties, - maxRate: maxRate, uniqueIDsPerPartition: int64(uniqueIDsPerPartition), maxMessagesPerConsumer: int64(maxMessagesPerConsumer), messageGeneratorName: msgGeneratorName, @@ -104,10 +96,6 @@ func (l *LoadClientMessageProviderFactory) NewMessageProvider() (kafka.MessagePr for i, partitionID := range partitions { offsets[i] = l.committedOffsets[partitionID] + 1 } - var rl ratelimit.Limiter - if l.maxRate > 0 { - rl = ratelimit.New(l.maxRate) - } rnd := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) msgGen, err := l.getMessageGenerator(l.messageGeneratorName) if err != nil { @@ -119,7 +107,6 @@ func (l *LoadClientMessageProviderFactory) NewMessageProvider() (kafka.MessagePr partitions: partitions, numPartitions: len(partitions), offsets: offsets, - rateLimiter: rl, uniqueIDsPerPartition: l.uniqueIDsPerPartition, maxMessages: l.maxMessagesPerConsumer, rnd: rnd, @@ -153,12 +140,10 @@ type LoadClientMessageProvider struct { partitions []int32 offsets []int64 sequence int64 - rateLimiter ratelimit.Limiter uniqueIDsPerPartition int64 maxMessages int64 msgGenerator msggen.MessageGenerator rnd *rand.Rand - limiterLock sync.Mutex msgLock sync.Mutex committedOffsets map[int32]int64 } @@ -204,35 +189,12 @@ func (l *LoadClientMessageProvider) Close() error { func (l *LoadClientMessageProvider) SetRebalanceCallback(callback kafka.RebalanceCallback) { } -func (l *LoadClientMessageProvider) SetMaxRate(rate int) { - l.limiterLock.Lock() - defer l.limiterLock.Unlock() - if rate == 0 { - return - } else if rate == -1 { - l.rateLimiter = nil - } else { - l.rateLimiter = ratelimit.New(rate) - } -} - -func (l *LoadClientMessageProvider) getLimiter() ratelimit.Limiter { - // This lock should almost always be uncontended so perf should be ok - l.limiterLock.Lock() - defer l.limiterLock.Unlock() - return l.rateLimiter -} - func (l *LoadClientMessageProvider) genLoop() { var msgCount int64 var msg *kafka.Message for l.running.Get() && msgCount < l.maxMessages { if msg == nil { var err error - limiter := l.getLimiter() - if limiter != nil { - limiter.Take() - } msg, err = l.genMessage() if err != nil { log.Errorf("failed to generate message %+v", err) diff --git a/protos/descriptors/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.bin b/protos/descriptors/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.bin index d9db6d0f..25018dd6 100644 --- a/protos/descriptors/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.bin +++ b/protos/descriptors/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.bin @@ -1,5 +1,5 @@ -– + 6squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto$squareup.cash.pranadb.clustermsgs.v1"• DDLStatementInfo. @@ -32,8 +32,8 @@ schemaName" shard_id (RshardId! request_body ( R requestBody": ClusterReadResponse# - response_body ( R responseBody"g -ConsumerSetRate + response_body ( R responseBody"n +SourceSetMaxIngestRate schema_name ( R schemaName source_name ( R diff --git a/protos/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto b/protos/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto index 3f2c8a8d..0f09a216 100644 --- a/protos/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto +++ b/protos/squareup/cash/pranadb/clustermsgs/v1/clustermsgs.proto @@ -49,7 +49,7 @@ message ClusterReadResponse { bytes response_body = 1; } -message ConsumerSetRate { +message SourceSetMaxIngestRate { string schema_name = 1; string source_name = 2; int64 rate = 3; diff --git a/protos/squareup/cash/pranadb/v1/clustermsgs/clustermsgs.pb.go b/protos/squareup/cash/pranadb/v1/clustermsgs/clustermsgs.pb.go index a72f71a9..9698338b 100644 --- a/protos/squareup/cash/pranadb/v1/clustermsgs/clustermsgs.pb.go +++ b/protos/squareup/cash/pranadb/v1/clustermsgs/clustermsgs.pb.go @@ -513,7 +513,7 @@ func (x *ClusterReadResponse) GetResponseBody() []byte { return nil } -type ConsumerSetRate struct { +type SourceSetMaxIngestRate struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields @@ -523,8 +523,8 @@ type ConsumerSetRate struct { Rate int64 `protobuf:"varint,3,opt,name=rate,proto3" json:"rate,omitempty"` } -func (x *ConsumerSetRate) Reset() { - *x = ConsumerSetRate{} +func (x *SourceSetMaxIngestRate) Reset() { + *x = SourceSetMaxIngestRate{} if protoimpl.UnsafeEnabled { mi := &file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -532,13 +532,13 @@ func (x *ConsumerSetRate) Reset() { } } -func (x *ConsumerSetRate) String() string { +func (x *SourceSetMaxIngestRate) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ConsumerSetRate) ProtoMessage() {} +func (*SourceSetMaxIngestRate) ProtoMessage() {} -func (x *ConsumerSetRate) ProtoReflect() protoreflect.Message { +func (x *SourceSetMaxIngestRate) ProtoReflect() protoreflect.Message { mi := &file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_msgTypes[9] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -550,26 +550,26 @@ func (x *ConsumerSetRate) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ConsumerSetRate.ProtoReflect.Descriptor instead. -func (*ConsumerSetRate) Descriptor() ([]byte, []int) { +// Deprecated: Use SourceSetMaxIngestRate.ProtoReflect.Descriptor instead. +func (*SourceSetMaxIngestRate) Descriptor() ([]byte, []int) { return file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_rawDescGZIP(), []int{9} } -func (x *ConsumerSetRate) GetSchemaName() string { +func (x *SourceSetMaxIngestRate) GetSchemaName() string { if x != nil { return x.SchemaName } return "" } -func (x *ConsumerSetRate) GetSourceName() string { +func (x *SourceSetMaxIngestRate) GetSourceName() string { if x != nil { return x.SourceName } return "" } -func (x *ConsumerSetRate) GetRate() int64 { +func (x *SourceSetMaxIngestRate) GetRate() int64 { if x != nil { return x.Rate } @@ -783,33 +783,33 @@ var file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_rawDesc = []byte 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x52, 0x65, 0x61, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x5f, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x42, 0x6f, 0x64, 0x79, 0x22, 0x67, 0x0a, 0x0f, 0x43, 0x6f, 0x6e, 0x73, 0x75, - 0x6d, 0x65, 0x72, 0x53, 0x65, 0x74, 0x52, 0x61, 0x74, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x63, - 0x68, 0x65, 0x6d, 0x61, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x0a, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, - 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x72, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x72, 0x61, 0x74, 0x65, - 0x22, 0x69, 0x0a, 0x12, 0x4c, 0x65, 0x61, 0x64, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x73, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x53, 0x0a, 0x0c, 0x6c, 0x65, 0x61, 0x64, 0x65, 0x72, - 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x73, - 0x71, 0x75, 0x61, 0x72, 0x65, 0x75, 0x70, 0x2e, 0x63, 0x61, 0x73, 0x68, 0x2e, 0x70, 0x72, 0x61, - 0x6e, 0x61, 0x64, 0x62, 0x2e, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x6d, 0x73, 0x67, 0x73, - 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x65, 0x61, 0x64, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0b, - 0x6c, 0x65, 0x61, 0x64, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x73, 0x22, 0x40, 0x0a, 0x0a, 0x4c, - 0x65, 0x61, 0x64, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x19, 0x0a, 0x08, 0x73, 0x68, 0x61, - 0x72, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x73, 0x68, 0x61, - 0x72, 0x64, 0x49, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x6e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x22, 0x34, 0x0a, - 0x13, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x69, 0x6e, 0x67, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x66, 0x69, 0x65, - 0x6c, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x6f, 0x6d, 0x65, 0x46, 0x69, - 0x65, 0x6c, 0x64, 0x42, 0x49, 0x5a, 0x47, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x73, 0x71, 0x75, 0x61, 0x72, 0x65, 0x75, 0x70, 0x2f, 0x70, 0x72, 0x61, 0x6e, 0x61, - 0x64, 0x62, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x73, 0x71, 0x75, 0x61, 0x72, 0x65, - 0x75, 0x70, 0x2f, 0x63, 0x61, 0x73, 0x68, 0x2f, 0x70, 0x72, 0x61, 0x6e, 0x61, 0x64, 0x62, 0x2f, - 0x76, 0x31, 0x2f, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, 0x72, 0x6d, 0x73, 0x67, 0x73, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6e, 0x73, 0x65, 0x42, 0x6f, 0x64, 0x79, 0x22, 0x6e, 0x0a, 0x16, 0x53, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x53, 0x65, 0x74, 0x4d, 0x61, 0x78, 0x49, 0x6e, 0x67, 0x65, 0x73, 0x74, 0x52, 0x61, 0x74, + 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x5f, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x4e, 0x61, + 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x72, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x04, 0x72, 0x61, 0x74, 0x65, 0x22, 0x69, 0x0a, 0x12, 0x4c, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x73, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x53, 0x0a, + 0x0c, 0x6c, 0x65, 0x61, 0x64, 0x65, 0x72, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x30, 0x2e, 0x73, 0x71, 0x75, 0x61, 0x72, 0x65, 0x75, 0x70, 0x2e, 0x63, + 0x61, 0x73, 0x68, 0x2e, 0x70, 0x72, 0x61, 0x6e, 0x61, 0x64, 0x62, 0x2e, 0x63, 0x6c, 0x75, 0x73, + 0x74, 0x65, 0x72, 0x6d, 0x73, 0x67, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0b, 0x6c, 0x65, 0x61, 0x64, 0x65, 0x72, 0x49, 0x6e, 0x66, + 0x6f, 0x73, 0x22, 0x40, 0x0a, 0x0a, 0x4c, 0x65, 0x61, 0x64, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, + 0x12, 0x19, 0x0a, 0x08, 0x73, 0x68, 0x61, 0x72, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x07, 0x73, 0x68, 0x61, 0x72, 0x64, 0x49, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x6e, + 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x6e, 0x6f, + 0x64, 0x65, 0x49, 0x64, 0x22, 0x34, 0x0a, 0x13, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x69, 0x6e, 0x67, + 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x73, + 0x6f, 0x6d, 0x65, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x73, 0x6f, 0x6d, 0x65, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x42, 0x49, 0x5a, 0x47, 0x67, 0x69, + 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x71, 0x75, 0x61, 0x72, 0x65, 0x75, + 0x70, 0x2f, 0x70, 0x72, 0x61, 0x6e, 0x61, 0x64, 0x62, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, + 0x2f, 0x73, 0x71, 0x75, 0x61, 0x72, 0x65, 0x75, 0x70, 0x2f, 0x63, 0x61, 0x73, 0x68, 0x2f, 0x70, + 0x72, 0x61, 0x6e, 0x61, 0x64, 0x62, 0x2f, 0x76, 0x31, 0x2f, 0x63, 0x6c, 0x75, 0x73, 0x74, 0x65, + 0x72, 0x6d, 0x73, 0x67, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -835,7 +835,7 @@ var file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_goTypes = []inte (*ClusterForwardWriteResponse)(nil), // 6: squareup.cash.pranadb.clustermsgs.v1.ClusterForwardWriteResponse (*ClusterReadRequest)(nil), // 7: squareup.cash.pranadb.clustermsgs.v1.ClusterReadRequest (*ClusterReadResponse)(nil), // 8: squareup.cash.pranadb.clustermsgs.v1.ClusterReadResponse - (*ConsumerSetRate)(nil), // 9: squareup.cash.pranadb.clustermsgs.v1.ConsumerSetRate + (*SourceSetMaxIngestRate)(nil), // 9: squareup.cash.pranadb.clustermsgs.v1.SourceSetMaxIngestRate (*LeaderInfosMessage)(nil), // 10: squareup.cash.pranadb.clustermsgs.v1.LeaderInfosMessage (*LeaderInfo)(nil), // 11: squareup.cash.pranadb.clustermsgs.v1.LeaderInfo (*RemotingTestMessage)(nil), // 12: squareup.cash.pranadb.clustermsgs.v1.RemotingTestMessage @@ -964,7 +964,7 @@ func file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_init() { } } file_squareup_cash_pranadb_clustermsgs_v1_clustermsgs_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ConsumerSetRate); i { + switch v := v.(*SourceSetMaxIngestRate); i { case 0: return &v.state case 1: diff --git a/push/engine.go b/push/engine.go index 00ea5836..335c5085 100644 --- a/push/engine.go +++ b/push/engine.go @@ -764,7 +764,7 @@ type loadClientSetRateHandler struct { } func (l *loadClientSetRateHandler) HandleMessage(clusterMsg remoting.ClusterMessage) (remoting.ClusterMessage, error) { - setRate, ok := clusterMsg.(*clustermsgs.ConsumerSetRate) + setRate, ok := clusterMsg.(*clustermsgs.SourceSetMaxIngestRate) if !ok { panic("not a ConsumerSetRate") } @@ -779,7 +779,7 @@ func (l *loadClientSetRateHandler) HandleMessage(clusterMsg remoting.ClusterMess // Internal error return nil, errors.Errorf("can't find source %s.%s", setRate.SchemaName, setRate.SourceName) } - source.SetMaxConsumerRate(int(setRate.Rate)) + source.SetMaxIngestRate(int(setRate.Rate)) return nil, nil } diff --git a/push/source/source.go b/push/source/source.go index 110925f0..57c51edb 100644 --- a/push/source/source.go +++ b/push/source/source.go @@ -2,8 +2,8 @@ package source import ( "github.com/squareup/pranadb/kafka/load" - "github.com/squareup/pranadb/push/util" + "go.uber.org/ratelimit" "sync" "sync/atomic" "time" @@ -33,6 +33,7 @@ const ( numConsumersPerSourcePropName = "prana.source.numconsumers" pollTimeoutPropName = "prana.source.polltimeoutms" maxPollMessagesPropName = "prana.source.maxpollmessages" + maxRatePropName = "prana.source.maxingestrate" ) type RowProcessor interface { @@ -53,6 +54,7 @@ type Source struct { numConsumersPerSource int pollTimeoutMs int maxPollMessages int + rateLimiter atomic.Value committedCount int64 enableStats bool commitOffsets common.AtomicBool @@ -104,6 +106,10 @@ func NewSource(sourceInfo *common.SourceInfo, tableExec *exec.TableExecutor, ing if err != nil { return nil, errors.WithStack(err) } + maxIngestRate, err := common.GetOrDefaultIntProperty(maxRatePropName, sourceInfo.OriginInfo.Properties, -1) + if err != nil { + return nil, err + } ti := sourceInfo.OriginInfo var brokerConf conf.BrokerConfig @@ -160,15 +166,18 @@ func NewSource(sourceInfo *common.SourceInfo, tableExec *exec.TableExecutor, ing ingestExpressions: ingestExpressions, cfg: cfg, } + var rl ratelimit.Limiter + if maxIngestRate > 0 { + rl = ratelimit.New(maxIngestRate) + source.rateLimiter.Store(rl) + } source.commitOffsets.Set(true) return source, nil } func (s *Source) Start() error { - s.lock.Lock() defer s.lock.Unlock() - return s.start() } @@ -367,6 +376,8 @@ func (s *Source) ingestMessages(messages []*kafka.Message, mp *MessageParser) er } } + s.maybeLimit() + key := make([]byte, 0, 8) key, err := common.EncodeKeyCols(&row, pkCols, colTypes, key) if err != nil { @@ -447,12 +458,28 @@ func (s *Source) SetCommitOffsets(enable bool) { s.commitOffsets.Set(enable) } -func (s *Source) SetMaxConsumerRate(rate int) { - s.lock.Lock() - defer s.lock.Unlock() +func (s *Source) SetMaxIngestRate(rate int) { + if rate == -1 { + s.setRateLimiter(nil) + } else { + s.setRateLimiter(ratelimit.New(rate)) + } +} - for _, consumer := range s.msgConsumers { - provider := consumer.msgProvider - provider.SetMaxRate(rate) +func (s *Source) maybeLimit() { + if rl := s.getRateLimiter(); rl != nil { + rl.Take() } } + +func (s *Source) getRateLimiter() ratelimit.Limiter { + v := s.rateLimiter.Load() + if v == nil { + return nil + } + return v.(ratelimit.Limiter) //nolint:forcetypeassert +} + +func (s *Source) setRateLimiter(limiter ratelimit.Limiter) { + s.rateLimiter.Store(limiter) +} diff --git a/remoting/cluster_message.go b/remoting/cluster_message.go index 03597227..6b2157c6 100644 --- a/remoting/cluster_message.go +++ b/remoting/cluster_message.go @@ -23,7 +23,7 @@ const ( ClusterMessageClusterReadResponse ClusterMessageForwardWriteRequest ClusterMessageForwardWriteResponse - ClusterMessageConsumerSetRate + ClusterMessageSourceSetMaxRate ClusterMessageLeaderInfos ClusterMessageRemotingTestMessage ) @@ -48,8 +48,8 @@ func TypeForClusterMessage(clusterMessage ClusterMessage) ClusterMessageType { return ClusterMessageForwardWriteRequest case *clustermsgs.ClusterForwardWriteResponse: return ClusterMessageForwardWriteResponse - case *clustermsgs.ConsumerSetRate: - return ClusterMessageConsumerSetRate + case *clustermsgs.SourceSetMaxIngestRate: + return ClusterMessageSourceSetMaxRate case *clustermsgs.LeaderInfosMessage: return ClusterMessageLeaderInfos case *clustermsgs.RemotingTestMessage: @@ -95,8 +95,8 @@ func DeserializeClusterMessage(data []byte) (ClusterMessage, error) { msg = &clustermsgs.DDLCancelMessage{} case ClusterMessageReloadProtobuf: msg = &clustermsgs.ReloadProtobuf{} - case ClusterMessageConsumerSetRate: - msg = &clustermsgs.ConsumerSetRate{} + case ClusterMessageSourceSetMaxRate: + msg = &clustermsgs.SourceSetMaxIngestRate{} case ClusterMessageLeaderInfos: msg = &clustermsgs.LeaderInfosMessage{} case ClusterMessageRemotingTestMessage: diff --git a/server/server.go b/server/server.go index e39241e2..f0c626f3 100644 --- a/server/server.go +++ b/server/server.go @@ -86,7 +86,7 @@ func NewServer(config conf.Config) (*Server, error) { if drag != nil { drag.SetForwardWriteHandler(pushEngine) } - remotingServer.RegisterMessageHandler(remoting.ClusterMessageConsumerSetRate, pushEngine.GetLoadClientSetRateHandler()) + remotingServer.RegisterMessageHandler(remoting.ClusterMessageSourceSetMaxRate, pushEngine.GetLoadClientSetRateHandler()) remotingServer.RegisterMessageHandler(remoting.ClusterMessageForwardWriteRequest, pushEngine.GetForwardWriteHandler()) commandExecutor := command.NewCommandExecutor(metaController, pushEngine, pullEngine, clus, ddlClient, ddlResetClient, protoRegistry, failureInjector, &config) diff --git a/sqltest/testdata/gen_source_test_out.txt b/sqltest/testdata/gen_source_test_out.txt index eba1905f..df8986ae 100644 --- a/sqltest/testdata/gen_source_test_out.txt +++ b/sqltest/testdata/gen_source_test_out.txt @@ -80,7 +80,7 @@ create source test_source_2( 0 rows returned -- now change rate; -consumer rate test_source_2 1000000; +source set max rate test_source_2 1000000; 0 rows returned --wait for rows test_source_2 20; diff --git a/sqltest/testdata/gen_source_test_script.txt b/sqltest/testdata/gen_source_test_script.txt index 5226e654..5c6a57ef 100644 --- a/sqltest/testdata/gen_source_test_script.txt +++ b/sqltest/testdata/gen_source_test_script.txt @@ -51,7 +51,7 @@ create source test_source_2( ); -- now change rate; -consumer rate test_source_2 1000000; +source set max rate test_source_2 1000000; --wait for rows test_source_2 20;