From fc17c0c61341bff22a18d6e0bfeb567b3cad672b Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 27 Dec 2024 10:40:50 +0100 Subject: [PATCH] lint: replace type assertions and type switch on errors (#3376) * errorlint: replace type assertions on errors * errorlint: replace type switch on errors * lint --- .golangci.yml | 10 +- cmd/crowdsec/win_service.go | 2 +- pkg/acquisition/acquisition.go | 3 +- pkg/acquisition/modules/kinesis/kinesis.go | 128 ++++++++++++++++----- pkg/apiclient/client_http.go | 9 +- pkg/apiserver/apiserver.go | 42 ++++--- pkg/database/alerts.go | 3 +- 7 files changed, 132 insertions(+), 65 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index b51f17df489..deb073c2eea 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -118,7 +118,7 @@ linters-settings: arguments: [6] - name: function-length # lower this after refactoring - arguments: [110, 237] + arguments: [111, 238] - name: get-return disabled: true - name: increment-decrement @@ -333,14 +333,6 @@ issues: - errorlint text: "non-wrapping format verb for fmt.Errorf. Use `%w` to format errors" - - linters: - - errorlint - text: "type assertion on error will fail on wrapped errors. Use errors.As to check for specific errors" - - - linters: - - errorlint - text: "type switch on error will fail on wrapped errors. Use errors.As to check for specific errors" - - linters: - nosprintfhostport text: "host:port in url should be constructed with net.JoinHostPort and not directly with fmt.Sprintf" diff --git a/cmd/crowdsec/win_service.go b/cmd/crowdsec/win_service.go index 6aa363ca3a7..ae48e77447c 100644 --- a/cmd/crowdsec/win_service.go +++ b/cmd/crowdsec/win_service.go @@ -67,7 +67,7 @@ func runService(name string) error { // All the calls to logging before the logger is configured are pretty much useless, but we keep them for clarity err := eventlog.InstallAsEventCreate("CrowdSec", eventlog.Error|eventlog.Warning|eventlog.Info) if err != nil { - if errno, ok := err.(syscall.Errno); ok { + if errno, ok := err.(syscall.Errno); ok { //nolint:errorlint if errno == windows.ERROR_ACCESS_DENIED { log.Warnf("Access denied when installing event source, running as non-admin ?") } else { diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 4e233aad616..291bc369c3e 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -328,7 +328,8 @@ func GetMetrics(sources []DataSource, aggregated bool) error { for _, metric := range metrics { if err := prometheus.Register(metric); err != nil { - if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { + var alreadyRegisteredErr prometheus.AlreadyRegisteredError + if !errors.As(err, &alreadyRegisteredErr) { return fmt.Errorf("could not register metrics for datasource %s: %w", sources[i].GetName(), err) } // ignore the error diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go index 3744e43f38d..b166a706ca9 100644 --- a/pkg/acquisition/modules/kinesis/kinesis.go +++ b/pkg/acquisition/modules/kinesis/kinesis.go @@ -99,17 +99,22 @@ func (k *KinesisSource) newClient() error { if sess == nil { return errors.New("failed to create aws session") } + config := aws.NewConfig() + if k.Config.AwsRegion != "" { config = config.WithRegion(k.Config.AwsRegion) } + if k.Config.AwsEndpoint != "" { config = config.WithEndpoint(k.Config.AwsEndpoint) } + k.kClient = kinesis.New(sess, config) if k.kClient == nil { return errors.New("failed to create kinesis client") } + return nil } @@ -136,15 +141,19 @@ func (k *KinesisSource) UnmarshalConfig(yamlConfig []byte) error { if k.Config.StreamName == "" && !k.Config.UseEnhancedFanOut { return errors.New("stream_name is mandatory when use_enhanced_fanout is false") } + if k.Config.StreamARN == "" && k.Config.UseEnhancedFanOut { return errors.New("stream_arn is mandatory when use_enhanced_fanout is true") } + if k.Config.ConsumerName == "" && k.Config.UseEnhancedFanOut { return errors.New("consumer_name is mandatory when use_enhanced_fanout is true") } + if k.Config.StreamARN != "" && k.Config.StreamName != "" { return errors.New("stream_arn and stream_name are mutually exclusive") } + if k.Config.MaxRetries <= 0 { k.Config.MaxRetries = 10 } @@ -167,6 +176,7 @@ func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsL } k.shardReaderTomb = &tomb.Tomb{} + return nil } @@ -188,22 +198,27 @@ func (k *KinesisSource) OneShotAcquisition(_ context.Context, _ chan types.Event func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubscriptionLogEvent, error) { b := bytes.NewBuffer(record) + r, err := gzip.NewReader(b) if err != nil { k.logger.Error(err) return nil, err } + decompressed, err := io.ReadAll(r) if err != nil { k.logger.Error(err) return nil, err } + var subscriptionRecord CloudWatchSubscriptionRecord + err = json.Unmarshal(decompressed, &subscriptionRecord) if err != nil { k.logger.Error(err) return nil, err } + return subscriptionRecord.LogEvents, nil } @@ -214,17 +229,20 @@ func (k *KinesisSource) WaitForConsumerDeregistration(consumerName string, strea ConsumerName: aws.String(consumerName), StreamARN: aws.String(streamARN), }) + + var resourceNotFoundErr *kinesis.ResourceNotFoundException + if errors.As(err, &resourceNotFoundErr) { + return nil + } + if err != nil { - switch err.(type) { - case *kinesis.ResourceNotFoundException: - return nil - default: - k.logger.Errorf("Error while waiting for consumer deregistration: %s", err) - return fmt.Errorf("cannot describe stream consumer: %w", err) - } + k.logger.Errorf("Error while waiting for consumer deregistration: %s", err) + return fmt.Errorf("cannot describe stream consumer: %w", err) } + time.Sleep(time.Millisecond * 200 * time.Duration(i+1)) } + return fmt.Errorf("consumer %s is not deregistered after %d tries", consumerName, maxTries) } @@ -234,17 +252,21 @@ func (k *KinesisSource) DeregisterConsumer() error { ConsumerName: aws.String(k.Config.ConsumerName), StreamARN: aws.String(k.Config.StreamARN), }) + + var resourceNotFoundErr *kinesis.ResourceNotFoundException + if errors.As(err, &resourceNotFoundErr) { + return nil + } + if err != nil { - switch err.(type) { - case *kinesis.ResourceNotFoundException: - default: - return fmt.Errorf("cannot deregister stream consumer: %w", err) - } + return fmt.Errorf("cannot deregister stream consumer: %w", err) } + err = k.WaitForConsumerDeregistration(k.Config.ConsumerName, k.Config.StreamARN) if err != nil { return fmt.Errorf("cannot wait for consumer deregistration: %w", err) } + return nil } @@ -257,18 +279,22 @@ func (k *KinesisSource) WaitForConsumerRegistration(consumerARN string) error { if err != nil { return fmt.Errorf("cannot describe stream consumer: %w", err) } + if *describeOutput.ConsumerDescription.ConsumerStatus == "ACTIVE" { k.logger.Debugf("Consumer %s is active", consumerARN) return nil } + time.Sleep(time.Millisecond * 200 * time.Duration(i+1)) k.logger.Debugf("Waiting for consumer registration %d", i) } + return fmt.Errorf("consumer %s is not active after %d tries", consumerARN, maxTries) } func (k *KinesisSource) RegisterConsumer() (*kinesis.RegisterStreamConsumerOutput, error) { k.logger.Debugf("Registering consumer %s", k.Config.ConsumerName) + streamConsumer, err := k.kClient.RegisterStreamConsumer(&kinesis.RegisterStreamConsumerInput{ ConsumerName: aws.String(k.Config.ConsumerName), StreamARN: aws.String(k.Config.StreamARN), @@ -276,10 +302,12 @@ func (k *KinesisSource) RegisterConsumer() (*kinesis.RegisterStreamConsumerOutpu if err != nil { return nil, fmt.Errorf("cannot register stream consumer: %w", err) } + err = k.WaitForConsumerRegistration(*streamConsumer.Consumer.ConsumerARN) if err != nil { return nil, fmt.Errorf("timeout while waiting for consumer to be active: %w", err) } + return streamConsumer, nil } @@ -296,8 +324,12 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan linesRead.With(prometheus.Labels{"stream": k.Config.StreamName}).Inc() } } - var data []CloudwatchSubscriptionLogEvent - var err error + + var ( + data []CloudwatchSubscriptionLogEvent + err error + ) + if k.Config.FromSubscription { // The AWS docs says that the data is base64 encoded // but apparently GetRecords decodes it for us ? @@ -309,19 +341,22 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan } else { data = []CloudwatchSubscriptionLogEvent{{Message: string(record.Data)}} } + for _, event := range data { logger.Tracef("got record %s", event.Message) + l := types.Line{} l.Raw = event.Message l.Labels = k.Config.Labels l.Time = time.Now().UTC() l.Process = true l.Module = k.GetName() - if k.Config.StreamARN != "" { - l.Src = k.Config.StreamARN - } else { + + l.Src = k.Config.StreamARN + if l.Src == "" { l.Src = k.Config.StreamName } + evt := types.MakeEvent(k.Config.UseTimeMachine, types.LOG, true) evt.Line = l out <- evt @@ -335,20 +370,23 @@ func (k *KinesisSource) ReadFromSubscription(reader kinesis.SubscribeToShardEven // and we won't be able to start a new one if this is the first one started by the tomb // TODO: look into parent shards to see if a shard is closed before starting to read it ? time.Sleep(time.Second) + for { select { case <-k.shardReaderTomb.Dying(): logger.Infof("Subscribed shard reader is dying") - err := reader.Close() - if err != nil { + + if err := reader.Close(); err != nil { return fmt.Errorf("cannot close kinesis subscribed shard reader: %w", err) } + return nil case event, ok := <-reader.Events(): if !ok { logger.Infof("Event chan has been closed") return nil } + switch event := event.(type) { case *kinesis.SubscribeToShardEvent: k.ParseAndPushRecords(event.Records, out, logger, shardId) @@ -369,6 +407,7 @@ func (k *KinesisSource) SubscribeToShards(arn arn.ARN, streamConsumer *kinesis.R for _, shard := range shards.Shards { shardId := *shard.ShardId + r, err := k.kClient.SubscribeToShard(&kinesis.SubscribeToShardInput{ ShardId: aws.String(shardId), StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)}, @@ -377,10 +416,12 @@ func (k *KinesisSource) SubscribeToShards(arn arn.ARN, streamConsumer *kinesis.R if err != nil { return fmt.Errorf("cannot subscribe to shard: %w", err) } + k.shardReaderTomb.Go(func() error { return k.ReadFromSubscription(r.GetEventStream().Reader, out, shardId, arn.Resource[7:]) }) } + return nil } @@ -389,12 +430,14 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { if err != nil { return fmt.Errorf("cannot parse stream ARN: %w", err) } + if !strings.HasPrefix(parsedARN.Resource, "stream/") { return fmt.Errorf("resource part of stream ARN %s does not start with stream/", k.Config.StreamARN) } k.logger = k.logger.WithField("stream", parsedARN.Resource[7:]) k.logger.Info("starting kinesis acquisition with enhanced fan-out") + err = k.DeregisterConsumer() if err != nil { return fmt.Errorf("cannot deregister consumer: %w", err) @@ -417,18 +460,22 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { k.logger.Infof("Kinesis source is dying") k.shardReaderTomb.Kill(nil) _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves + err = k.DeregisterConsumer() if err != nil { return fmt.Errorf("cannot deregister consumer: %w", err) } + return nil case <-k.shardReaderTomb.Dying(): k.logger.Debugf("Kinesis subscribed shard reader is dying") + if k.shardReaderTomb.Err() != nil { return k.shardReaderTomb.Err() } // All goroutines have exited without error, so a resharding event, start again k.logger.Debugf("All reader goroutines have exited, resharding event or periodic resubscribe") + continue } } @@ -437,6 +484,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) error { logger := k.logger.WithField("shard", shardId) logger.Debugf("Starting to read shard") + sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ ShardId: aws.String(shardId), StreamName: &k.Config.StreamName, @@ -446,28 +494,35 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro logger.Errorf("Cannot get shard iterator: %s", err) return fmt.Errorf("cannot get shard iterator: %w", err) } + it := sharIt.ShardIterator // AWS recommends to wait for a second between calls to GetRecords for a given shard ticker := time.NewTicker(time.Second) + for { select { case <-ticker.C: records, err := k.kClient.GetRecords(&kinesis.GetRecordsInput{ShardIterator: it}) it = records.NextShardIterator + + var throughputErr *kinesis.ProvisionedThroughputExceededException + if errors.As(err, &throughputErr) { + logger.Warn("Provisioned throughput exceeded") + // TODO: implement exponential backoff + continue + } + + var expiredIteratorErr *kinesis.ExpiredIteratorException + if errors.As(err, &expiredIteratorErr) { + logger.Warn("Expired iterator") + continue + } + if err != nil { - switch err.(type) { - case *kinesis.ProvisionedThroughputExceededException: - logger.Warn("Provisioned throughput exceeded") - // TODO: implement exponential backoff - continue - case *kinesis.ExpiredIteratorException: - logger.Warn("Expired iterator") - continue - default: - logger.Error("Cannot get records") - return fmt.Errorf("cannot get records: %w", err) - } + logger.Error("Cannot get records") + return fmt.Errorf("cannot get records: %w", err) } + k.ParseAndPushRecords(records.Records, out, logger, shardId) if it == nil { @@ -477,6 +532,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro case <-k.shardReaderTomb.Dying(): logger.Infof("shardReaderTomb is dying, exiting ReadFromShard") ticker.Stop() + return nil } } @@ -485,6 +541,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error { k.logger = k.logger.WithField("stream", k.Config.StreamName) k.logger.Info("starting kinesis acquisition from shards") + for { shards, err := k.kClient.ListShards(&kinesis.ListShardsInput{ StreamName: aws.String(k.Config.StreamName), @@ -492,9 +549,12 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error if err != nil { return fmt.Errorf("cannot list shards: %w", err) } + k.shardReaderTomb = &tomb.Tomb{} + for _, shard := range shards.Shards { shardId := *shard.ShardId + k.shardReaderTomb.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming/shard") return k.ReadFromShard(out, shardId) @@ -505,6 +565,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error k.logger.Info("kinesis source is dying") k.shardReaderTomb.Kill(nil) _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves + return nil case <-k.shardReaderTomb.Dying(): reason := k.shardReaderTomb.Err() @@ -512,7 +573,9 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error k.logger.Errorf("Unexpected error from shard reader : %s", reason) return reason } + k.logger.Infof("All shards have been closed, probably a resharding event, restarting acquisition") + continue } } @@ -521,11 +584,14 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error func (k *KinesisSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming") + if k.Config.UseEnhancedFanOut { return k.EnhancedRead(out, t) } + return k.ReadFromStream(out, t) }) + return nil } diff --git a/pkg/apiclient/client_http.go b/pkg/apiclient/client_http.go index eeca929ea6e..c64404dc7ee 100644 --- a/pkg/apiclient/client_http.go +++ b/pkg/apiclient/client_http.go @@ -78,10 +78,11 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* } // If the error type is *url.Error, sanitize its URL before returning. - if e, ok := err.(*url.Error); ok { - if url, err := url.Parse(e.URL); err == nil { - e.URL = url.String() - return newResponse(resp), e + var urlErr *url.Error + if errors.As(err, &urlErr) { + if parsedURL, parseErr := url.Parse(urlErr.URL); parseErr == nil { + urlErr.URL = parsedURL.String() + return newResponse(resp), urlErr } return newResponse(resp), err diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index e1d9ce95349..88f1bd21dc4 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -46,10 +46,18 @@ type APIServer struct { consoleConfig *csconfig.ConsoleConfig } -func isBrokenConnection(err any) bool { - if ne, ok := err.(*net.OpError); ok { - if se, ok := ne.Err.(*os.SyscallError); ok { - if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { +func isBrokenConnection(maybeError any) bool { + err, ok := maybeError.(error) + if !ok { + return false + } + + var netOpError *net.OpError + if errors.As(err, &netOpError) { + var syscallError *os.SyscallError + if errors.As(netOpError.Err, &syscallError) { + if strings.Contains(strings.ToLower(syscallError.Error()), "broken pipe") || + strings.Contains(strings.ToLower(syscallError.Error()), "connection reset by peer") { return true } } @@ -57,21 +65,19 @@ func isBrokenConnection(err any) bool { // because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go // and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them - if strErr, ok := err.(error); ok { - // stolen from http2/server.go in x/net - var ( - errClientDisconnected = errors.New("client disconnected") - errClosedBody = errors.New("body closed by handler") - errHandlerComplete = errors.New("http2: request body closed due to handler exiting") - errStreamClosed = errors.New("http2: stream closed") - ) + // stolen from http2/server.go in x/net + var ( + errClientDisconnected = errors.New("client disconnected") + errClosedBody = errors.New("body closed by handler") + errHandlerComplete = errors.New("http2: request body closed due to handler exiting") + errStreamClosed = errors.New("http2: stream closed") + ) - if errors.Is(strErr, errClientDisconnected) || - errors.Is(strErr, errClosedBody) || - errors.Is(strErr, errHandlerComplete) || - errors.Is(strErr, errStreamClosed) { - return true - } + if errors.Is(err, errClientDisconnected) || + errors.Is(err, errClosedBody) || + errors.Is(err, errHandlerComplete) || + errors.Is(err, errStreamClosed) { + return true } return false diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 4e3f209b012..107abcbb1d0 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -642,7 +642,8 @@ func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner * break } - if sqliteErr, ok := err.(sqlite3.Error); ok { + var sqliteErr sqlite3.Error + if errors.As(err, &sqliteErr) { if sqliteErr.Code == sqlite3.ErrBusy { // sqlite3.Error{ // Code: 5,