From a2dcc0ef9a534fd9998651fa0b31338acd24e109 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 6 May 2024 12:33:54 +0200 Subject: [PATCH] cscli: remove global dbClient (#2985) * cscli: remove global dbClient * lint (whitespace, errors) --- cmd/crowdsec-cli/main.go | 2 - cmd/crowdsec-cli/papi.go | 14 +++---- cmd/crowdsec-cli/support.go | 2 +- pkg/database/database.go | 46 ++++++++++++++--------- pkg/leakybucket/manager_run.go | 6 +-- pkg/leakybucket/overflows.go | 68 +++++++++++++++++++++++++++------- 6 files changed, 94 insertions(+), 44 deletions(-) diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index e3c45390a18..95c528f20b5 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -15,14 +15,12 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/fflag" ) var ( ConfigFilePath string csConfig *csconfig.Config - dbClient *database.Client ) type configGetter func() *csconfig.Config diff --git a/cmd/crowdsec-cli/papi.go b/cmd/crowdsec-cli/papi.go index 5808fcce5f6..558409b2d4d 100644 --- a/cmd/crowdsec-cli/papi.go +++ b/cmd/crowdsec-cli/papi.go @@ -62,17 +62,17 @@ func (cli *cliPapi) NewStatusCmd() *cobra.Command { RunE: func(_ *cobra.Command, _ []string) error { var err error cfg := cli.cfg() - dbClient, err = database.NewClient(cfg.DbConfig) + db, err := database.NewClient(cfg.DbConfig) if err != nil { return fmt.Errorf("unable to initialize database client: %w", err) } - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, dbClient, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) if err != nil { return fmt.Errorf("unable to initialize API client: %w", err) } - papi, err := apiserver.NewPAPI(apic, dbClient, cfg.API.Server.ConsoleConfig, log.GetLevel()) + papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) if err != nil { return fmt.Errorf("unable to initialize PAPI client: %w", err) } @@ -82,7 +82,7 @@ func (cli *cliPapi) NewStatusCmd() *cobra.Command { return fmt.Errorf("unable to get PAPI permissions: %w", err) } var lastTimestampStr *string - lastTimestampStr, err = dbClient.GetConfigItem(apiserver.PapiPullKey) + lastTimestampStr, err = db.GetConfigItem(apiserver.PapiPullKey) if err != nil { lastTimestampStr = ptr.Of("never") } @@ -113,19 +113,19 @@ func (cli *cliPapi) NewSyncCmd() *cobra.Command { cfg := cli.cfg() t := tomb.Tomb{} - dbClient, err = database.NewClient(cfg.DbConfig) + db, err := database.NewClient(cfg.DbConfig) if err != nil { return fmt.Errorf("unable to initialize database client: %w", err) } - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, dbClient, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) if err != nil { return fmt.Errorf("unable to initialize API client: %w", err) } t.Go(apic.Push) - papi, err := apiserver.NewPAPI(apic, dbClient, cfg.API.Server.ConsoleConfig, log.GetLevel()) + papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) if err != nil { return fmt.Errorf("unable to initialize PAPI client: %w", err) } diff --git a/cmd/crowdsec-cli/support.go b/cmd/crowdsec-cli/support.go index 418a981adee..5890061f502 100644 --- a/cmd/crowdsec-cli/support.go +++ b/cmd/crowdsec-cli/support.go @@ -331,7 +331,7 @@ cscli support dump -f /tmp/crowdsec-support.zip outFile = "/tmp/crowdsec-support.zip" } - dbClient, err = database.NewClient(csConfig.DbConfig) + dbClient, err := database.NewClient(csConfig.DbConfig) if err != nil { log.Warnf("Could not connect to database: %s", err) skipDB = true diff --git a/pkg/database/database.go b/pkg/database/database.go index d984aefb170..96a495f6731 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -35,72 +35,84 @@ func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig. if err != nil { return nil, err } + if config.MaxOpenConns == nil { log.Warningf("MaxOpenConns is 0, defaulting to %d", csconfig.DEFAULT_MAX_OPEN_CONNS) config.MaxOpenConns = ptr.Of(csconfig.DEFAULT_MAX_OPEN_CONNS) } + db.SetMaxOpenConns(*config.MaxOpenConns) drv := entsql.OpenDB(dbdialect, db) + return drv, nil } func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { var client *ent.Client - var err error + if config == nil { - return &Client{}, errors.New("DB config is empty") + return nil, errors.New("DB config is empty") } /*The logger that will be used by db operations*/ clog := log.New() if err := types.ConfigureLogger(clog); err != nil { return nil, fmt.Errorf("while configuring db logger: %w", err) } + if config.LogLevel != nil { clog.SetLevel(*config.LogLevel) } - entLogger := clog.WithField("context", "ent") + entLogger := clog.WithField("context", "ent") entOpt := ent.Log(entLogger.Debug) + typ, dia, err := config.ConnectionDialect() if err != nil { - return &Client{}, err //unsupported database caught here + return nil, err // unsupported database caught here } + if config.Type == "sqlite" { /*if it's the first startup, we want to touch and chmod file*/ if _, err := os.Stat(config.DbPath); os.IsNotExist(err) { - f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0600) + f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0o600) if err != nil { - return &Client{}, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) + return nil, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) } + if err := f.Close(); err != nil { - return &Client{}, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) + return nil, fmt.Errorf("failed to create SQLite database file %q: %w", config.DbPath, err) } } - //Always try to set permissions to simplify a bit the code for windows (as the permissions set by OpenFile will be garbage) - if err := setFilePerm(config.DbPath, 0640); err != nil { - return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err) + // Always try to set permissions to simplify a bit the code for windows (as the permissions set by OpenFile will be garbage) + if err := setFilePerm(config.DbPath, 0o640); err != nil { + return nil, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err) } } + drv, err := getEntDriver(typ, dia, config.ConnectionString(), config) if err != nil { - return &Client{}, fmt.Errorf("failed opening connection to %s: %v", config.Type, err) + return nil, fmt.Errorf("failed opening connection to %s: %v", config.Type, err) } + client = ent.NewClient(ent.Driver(drv), entOpt) + if config.LogLevel != nil && *config.LogLevel >= log.DebugLevel { clog.Debugf("Enabling request debug") + client = client.Debug() } + if err = client.Schema.Create(context.Background()); err != nil { return nil, fmt.Errorf("failed creating schema resources: %v", err) } return &Client{ - Ent: client, - CTX: context.Background(), - Log: clog, - CanFlush: true, - Type: config.Type, - WalMode: config.UseWal, + Ent: client, + CTX: context.Background(), + Log: clog, + CanFlush: true, + Type: config.Type, + WalMode: config.UseWal, decisionBulkSize: config.DecisionBulkSize, }, nil } diff --git a/pkg/leakybucket/manager_run.go b/pkg/leakybucket/manager_run.go index ae7a86a4e4e..1d34c238ea5 100644 --- a/pkg/leakybucket/manager_run.go +++ b/pkg/leakybucket/manager_run.go @@ -85,7 +85,7 @@ func DumpBucketsStateAt(deadline time.Time, outputdir string, buckets *Buckets) defer buckets.wgDumpState.Done() if outputdir == "" { - return "", fmt.Errorf("empty output dir for dump bucket state") + return "", errors.New("empty output dir for dump bucket state") } tmpFd, err := os.CreateTemp(os.TempDir(), "crowdsec-buckets-dump-") if err != nil { @@ -132,11 +132,11 @@ func DumpBucketsStateAt(deadline time.Time, outputdir string, buckets *Buckets) }) bbuckets, err := json.MarshalIndent(serialized, "", " ") if err != nil { - return "", fmt.Errorf("Failed to unmarshal buckets : %s", err) + return "", fmt.Errorf("failed to unmarshal buckets: %s", err) } size, err := tmpFd.Write(bbuckets) if err != nil { - return "", fmt.Errorf("failed to write temp file : %s", err) + return "", fmt.Errorf("failed to write temp file: %s", err) } log.Infof("Serialized %d live buckets (+%d expired) in %d bytes to %s", len(serialized), discard, size, tmpFd.Name()) serialized = nil diff --git a/pkg/leakybucket/overflows.go b/pkg/leakybucket/overflows.go index 80226aafb2a..8092ef35e77 100644 --- a/pkg/leakybucket/overflows.go +++ b/pkg/leakybucket/overflows.go @@ -1,6 +1,7 @@ package leakybucket import ( + "errors" "fmt" "net" "sort" @@ -22,9 +23,7 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e /*if it's already an overflow, we have properly formatted sources. we can just twitch them to reflect the requested scope*/ if evt.Type == types.OVFLW { - for k, v := range evt.Overflow.Sources { - /*the scopes are already similar, nothing to do*/ if leaky.scopeType.Scope == *v.Scope { srcs[k] = v @@ -46,20 +45,25 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e src.Scope = new(string) *src.Scope = leaky.scopeType.Scope *src.Value = "" + if v.Range != "" { *src.Value = v.Range } + if leaky.scopeType.RunTimeFilter != nil { retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) if err != nil { return srcs, fmt.Errorf("while running scope filter: %w", err) } + value, ok := retValue.(string) if !ok { value = "" } + src.Value = &value } + if *src.Value != "" { srcs[*src.Value] = src } else { @@ -71,50 +75,64 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e } } } + return srcs, nil } + src := models.Source{} + switch leaky.scopeType.Scope { case types.Range, types.Ip: v, ok := evt.Meta["source_ip"] if !ok { return srcs, fmt.Errorf("scope is %s but Meta[source_ip] doesn't exist", leaky.scopeType.Scope) } + if net.ParseIP(v) == nil { return srcs, fmt.Errorf("scope is %s but '%s' isn't a valid ip", leaky.scopeType.Scope, v) } + src.IP = v src.Scope = &leaky.scopeType.Scope + if v, ok := evt.Enriched["ASNumber"]; ok { src.AsNumber = v } else if v, ok := evt.Enriched["ASNNumber"]; ok { src.AsNumber = v } + if v, ok := evt.Enriched["IsoCode"]; ok { src.Cn = v } + if v, ok := evt.Enriched["ASNOrg"]; ok { src.AsName = v } + if v, ok := evt.Enriched["Latitude"]; ok { l, err := strconv.ParseFloat(v, 32) if err != nil { log.Warningf("bad latitude %s : %s", v, err) } + src.Latitude = float32(l) } + if v, ok := evt.Enriched["Longitude"]; ok { l, err := strconv.ParseFloat(v, 32) if err != nil { log.Warningf("bad longitude %s : %s", v, err) } + src.Longitude = float32(l) } + if v, ok := evt.Meta["SourceRange"]; ok && v != "" { _, ipNet, err := net.ParseCIDR(v) if err != nil { - return srcs, fmt.Errorf("Declared range %s of %s can't be parsed", v, src.IP) + return srcs, fmt.Errorf("declared range %s of %s can't be parsed", v, src.IP) } + if ipNet != nil { src.Range = ipNet.String() leaky.logger.Tracef("Valid range from %s : %s", src.IP, src.Range) @@ -124,6 +142,7 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e src.Value = &src.IP } else if leaky.scopeType.Scope == types.Range { src.Value = &src.Range + if leaky.scopeType.RunTimeFilter != nil { retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) if err != nil { @@ -134,14 +153,17 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e if !ok { value = "" } + src.Value = &value } } + srcs[*src.Value] = src default: if leaky.scopeType.RunTimeFilter == nil { - return srcs, fmt.Errorf("empty scope information") + return srcs, errors.New("empty scope information") } + retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) if err != nil { return srcs, fmt.Errorf("while running scope filter: %w", err) @@ -151,30 +173,34 @@ func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, e if !ok { value = "" } + src.Value = &value src.Scope = new(string) *src.Scope = leaky.scopeType.Scope srcs[*src.Value] = src } + return srcs, nil } // EventsFromQueue iterates the queue to collect & prepare meta-datas from alert func EventsFromQueue(queue *types.Queue) []*models.Event { - events := []*models.Event{} for _, evt := range queue.Queue { if evt.Meta == nil { continue } + meta := models.Meta{} - //we want consistence + // we want consistence skeys := make([]string, 0, len(evt.Meta)) for k := range evt.Meta { skeys = append(skeys, k) } + sort.Strings(skeys) + for _, k := range skeys { v := evt.Meta[k] subMeta := models.MetaItems0{Key: k, Value: v} @@ -185,12 +211,13 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { ovflwEvent := models.Event{ Meta: meta, } - //either MarshaledTime is present and is extracted from log + // either MarshaledTime is present and is extracted from log if evt.MarshaledTime != "" { tmpTimeStamp := evt.MarshaledTime ovflwEvent.Timestamp = &tmpTimeStamp - } else if !evt.Time.IsZero() { //or .Time has been set during parse as time.Now().UTC() + } else if !evt.Time.IsZero() { // or .Time has been set during parse as time.Now().UTC() ovflwEvent.Timestamp = new(string) + raw, err := evt.Time.MarshalText() if err != nil { log.Warningf("while marshaling time '%s' : %s", evt.Time.String(), err) @@ -203,6 +230,7 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { events = append(events, &ovflwEvent) } + return events } @@ -218,17 +246,21 @@ func alertFormatSource(leaky *Leaky, queue *types.Queue) (map[string]models.Sour if err != nil { return nil, "", fmt.Errorf("while extracting scope from bucket %s: %w", leaky.Name, err) } + for key, src := range srcs { if source_type == types.Undefined { source_type = *src.Scope } + if *src.Scope != source_type { return nil, "", fmt.Errorf("event has multiple source types : %s != %s", *src.Scope, source_type) } + sources[key] = src } } + return sources, source_type, nil } @@ -244,10 +276,12 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { if err != nil { log.Warningf("failed to marshal start ts %s : %s", leaky.First_ts.String(), err) } + stop_at, err := leaky.Ovflw_ts.MarshalText() if err != nil { log.Warningf("failed to marshal ovflw ts %s : %s", leaky.First_ts.String(), err) } + capacity := int32(leaky.Capacity) EventsCount := int32(leaky.Total_count) leakSpeed := leaky.Leakspeed.String() @@ -266,19 +300,20 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { Simulated: &leaky.Simulated, } if leaky.BucketConfig == nil { - return runtimeAlert, fmt.Errorf("leaky.BucketConfig is nil") + return runtimeAlert, errors.New("leaky.BucketConfig is nil") } - //give information about the bucket + // give information about the bucket runtimeAlert.Mapkey = leaky.Mapkey - //Get the sources from Leaky/Queue + // Get the sources from Leaky/Queue sources, source_scope, err := alertFormatSource(leaky, queue) if err != nil { return runtimeAlert, fmt.Errorf("unable to collect sources from bucket: %w", err) } + runtimeAlert.Sources = sources - //Include source info in format string + // Include source info in format string sourceStr := "UNKNOWN" if len(sources) > 1 { sourceStr = fmt.Sprintf("%d sources", len(sources)) @@ -290,19 +325,22 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { } *apiAlert.Message = fmt.Sprintf("%s %s performed '%s' (%d events over %s) at %s", source_scope, sourceStr, leaky.Name, leaky.Total_count, leaky.Ovflw_ts.Sub(leaky.First_ts), leaky.Last_ts) - //Get the events from Leaky/Queue + // Get the events from Leaky/Queue apiAlert.Events = EventsFromQueue(queue) + var warnings []error + apiAlert.Meta, warnings = alertcontext.EventToContext(leaky.Queue.GetQueue()) for _, w := range warnings { log.Warningf("while extracting context from bucket %s : %s", leaky.Name, w) } - //Loop over the Sources and generate appropriate number of ApiAlerts + // Loop over the Sources and generate appropriate number of ApiAlerts for _, srcValue := range sources { newApiAlert := apiAlert srcCopy := srcValue newApiAlert.Source = &srcCopy + if v, ok := leaky.BucketConfig.Labels["remediation"]; ok && v == true { newApiAlert.Remediation = true } @@ -312,6 +350,7 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { log.Errorf("->%s", spew.Sdump(newApiAlert)) log.Fatalf("error : %s", err) } + runtimeAlert.APIAlerts = append(runtimeAlert.APIAlerts, newApiAlert) } @@ -322,5 +361,6 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { if leaky.Reprocess { runtimeAlert.Reprocess = true } + return runtimeAlert, nil }