From 4748720a07f59dbf734cf609248ed9a5306d109b Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:33:24 +0100 Subject: [PATCH] refactor pkg/leakybucket (#3371) * refact pkg/leakybucket - call LoadBuckets with Item instances * extract compileScopeFilter() --- cmd/crowdsec/main.go | 15 ++---- pkg/apiserver/alerts_test.go | 82 +++++++++++++++---------------- pkg/leakybucket/buckets_test.go | 38 +++++++++------ pkg/leakybucket/manager_load.go | 86 ++++++++++++++------------------- pkg/types/ip_test.go | 11 ++++- 5 files changed, 116 insertions(+), 116 deletions(-) diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go index e414f59f3e2..518bd8e9c0d 100644 --- a/cmd/crowdsec/main.go +++ b/cmd/crowdsec/main.go @@ -86,20 +86,15 @@ func (f *Flags) haveTimeMachine() bool { type labelsMap map[string]string func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error { - var ( - err error - files []string - ) - - for _, hubScenarioItem := range hub.GetInstalledByType(cwhub.SCENARIOS, false) { - files = append(files, hubScenarioItem.State.LocalPath) - } + var err error buckets = leakybucket.NewBuckets() - log.Infof("Loading %d scenario files", len(files)) + scenarios := hub.GetInstalledByType(cwhub.SCENARIOS, false) + + log.Infof("Loading %d scenario files", len(scenarios)) - holders, outputEventChan, err = leakybucket.LoadBuckets(cConfig.Crowdsec, hub, files, &bucketsTomb, buckets, flags.OrderEvent) + holders, outputEventChan, err = leakybucket.LoadBuckets(cConfig.Crowdsec, hub, scenarios, &bucketsTomb, buckets, flags.OrderEvent) if err != nil { return fmt.Errorf("scenario loading failed: %w", err) } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index e51987ba71a..4c5c6ef129c 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -103,13 +103,13 @@ func TestSimulatedAlert(t *testing.T) { // exclude decision in simulation mode w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", alertContent, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) // include decision in simulation mode w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", alertContent, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) } @@ -120,21 +120,21 @@ func TestCreateAlert(t *testing.T) { // Create Alert with invalid format w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") - assert.Equal(t, 400, w.Code) + assert.Equal(t, http.StatusBadRequest, w.Code) assert.JSONEq(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create Alert with invalid input alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json") w = lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertContent, "password") - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusInternalServerError, w.Code) assert.JSONEq(t, `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, w.Body.String()) // Create Valid Alert w = lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, `["1"]`, w.Body.String()) } @@ -175,13 +175,13 @@ func TestAlertListFilters(t *testing.T) { // bad filter w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", alertContent, "password") - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusInternalServerError, w.Code) assert.JSONEq(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // get without filters w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) // check alert and decision assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) @@ -189,149 +189,149 @@ func TestAlertListFilters(t *testing.T) { // test decision_type filter (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test decision_type filter (bad value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) // test scope (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=Ip", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test scope (bad value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=rarara", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) // test scenario (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test scenario (bad value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) // test ip (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test ip (bad value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) // test ip (invalid value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusInternalServerError, w.Code) assert.JSONEq(t, `{"message":"invalid ip address 'gruueq'"}`, w.Body.String()) // test range (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test range w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) // test range (invalid value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password") - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusInternalServerError, w.Code) assert.JSONEq(t, `{"message":"invalid ip address 'ratata'"}`, w.Body.String()) // test since (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1h", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test since (ok but yields no results) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1ns", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) // test since (invalid value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password") - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) // test until (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1ns", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test until (ok but no return) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1m", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) // test until (invalid value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password") - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) // test simulated (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test simulated (ok) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test has active decision w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) // test has active decision w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) // test has active decision (invalid value) w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusInternalServerError, w.Code) assert.JSONEq(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) } @@ -343,7 +343,7 @@ func TestAlertBulkInsert(t *testing.T) { alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json") w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", alertContent, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) } func TestListAlert(t *testing.T) { @@ -353,13 +353,13 @@ func TestListAlert(t *testing.T) { // List Alert with invalid filter w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", emptyBody, "password") - assert.Equal(t, 500, w.Code) + assert.Equal(t, http.StatusInternalServerError, w.Code) assert.JSONEq(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // List Alert w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password") - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), "crowdsecurity/test") } @@ -374,7 +374,7 @@ func TestCreateAlertErrors(t *testing.T) { req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "ratata")) lapi.router.ServeHTTP(w, req) - assert.Equal(t, 401, w.Code) + assert.Equal(t, http.StatusUnauthorized, w.Code) // test invalid bearer w = httptest.NewRecorder() @@ -382,7 +382,7 @@ func TestCreateAlertErrors(t *testing.T) { req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s")) lapi.router.ServeHTTP(w, req) - assert.Equal(t, 401, w.Code) + assert.Equal(t, http.StatusUnauthorized, w.Code) } func TestDeleteAlert(t *testing.T) { @@ -396,7 +396,7 @@ func TestDeleteAlert(t *testing.T) { AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) + assert.Equal(t, http.StatusForbidden, w.Code) assert.JSONEq(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String()) // Delete Alert @@ -405,7 +405,7 @@ func TestDeleteAlert(t *testing.T) { AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String()) } @@ -420,7 +420,7 @@ func TestDeleteAlertByID(t *testing.T) { AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) + assert.Equal(t, http.StatusForbidden, w.Code) assert.JSONEq(t, `{"message":"access forbidden from this IP (127.0.0.2)"}`, w.Body.String()) // Delete Alert @@ -429,7 +429,7 @@ func TestDeleteAlertByID(t *testing.T) { AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String()) } @@ -463,7 +463,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { req.RemoteAddr = ip + ":1234" router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) + assert.Equal(t, http.StatusForbidden, w.Code) assert.Contains(t, w.Body.String(), fmt.Sprintf(`{"message":"access forbidden from this IP (%s)"}`, ip)) } @@ -474,7 +474,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { req.RemoteAddr = ip + ":1234" router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.JSONEq(t, `{"nbDeleted":"1"}`, w.Body.String()) } diff --git a/pkg/leakybucket/buckets_test.go b/pkg/leakybucket/buckets_test.go index 1da906cb555..8bb7a3d4c47 100644 --- a/pkg/leakybucket/buckets_test.go +++ b/pkg/leakybucket/buckets_test.go @@ -139,14 +139,24 @@ func testOneBucket(t *testing.T, hub *cwhub.Hub, dir string, tomb *tomb.Tomb) er t.Fatalf("failed to parse %s : %s", stagecfg, err) } - files := []string{} + scenarios := []*cwhub.Item{} for _, x := range stages { - files = append(files, x.Filename) + // XXX: LoadBuckets should take an interface, BucketProvider ScenarioProvider or w/e + item := &cwhub.Item{ + Name: x.Filename, + State: cwhub.ItemState{ + LocalVersion: "", + LocalPath: x.Filename, + LocalHash: "", + }, + } + + scenarios = append(scenarios, item) } cscfg := &csconfig.CrowdsecServiceCfg{} - holders, response, err := LoadBuckets(cscfg, hub, files, tomb, buckets, false) + holders, response, err := LoadBuckets(cscfg, hub, scenarios, tomb, buckets, false) if err != nil { t.Fatalf("failed loading bucket : %s", err) } @@ -184,7 +194,7 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res } dec := json.NewDecoder(yamlFile) dec.DisallowUnknownFields() - //dec.SetStrict(true) + // dec.SetStrict(true) tf := TestFile{} err = dec.Decode(&tf) if err != nil { @@ -196,7 +206,7 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res } var latest_ts time.Time for _, in := range tf.Lines { - //just to avoid any race during ingestion of funny scenarios + // just to avoid any race during ingestion of funny scenarios time.Sleep(50 * time.Millisecond) var ts time.Time @@ -226,7 +236,7 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res time.Sleep(1 * time.Second) - //Read results from chan + // Read results from chan POLL_AGAIN: fails := 0 for fails < 2 { @@ -287,37 +297,37 @@ POLL_AGAIN: log.Tracef("Checking next expected result.") - //empty overflow + // empty overflow if out.Overflow.Alert == nil && expected.Overflow.Alert == nil { - //match stuff + // match stuff } else { if out.Overflow.Alert == nil || expected.Overflow.Alert == nil { log.Printf("Here ?") continue } - //Scenario + // Scenario if *out.Overflow.Alert.Scenario != *expected.Overflow.Alert.Scenario { log.Errorf("(scenario) %v != %v", *out.Overflow.Alert.Scenario, *expected.Overflow.Alert.Scenario) continue } log.Infof("(scenario) %v == %v", *out.Overflow.Alert.Scenario, *expected.Overflow.Alert.Scenario) - //EventsCount + // EventsCount if *out.Overflow.Alert.EventsCount != *expected.Overflow.Alert.EventsCount { log.Errorf("(EventsCount) %d != %d", *out.Overflow.Alert.EventsCount, *expected.Overflow.Alert.EventsCount) continue } log.Infof("(EventsCount) %d == %d", *out.Overflow.Alert.EventsCount, *expected.Overflow.Alert.EventsCount) - //Sources + // Sources if !reflect.DeepEqual(out.Overflow.Sources, expected.Overflow.Sources) { log.Errorf("(Sources %s != %s)", spew.Sdump(out.Overflow.Sources), spew.Sdump(expected.Overflow.Sources)) continue } log.Infof("(Sources: %s == %s)", spew.Sdump(out.Overflow.Sources), spew.Sdump(expected.Overflow.Sources)) } - //Events + // Events // if !reflect.DeepEqual(out.Overflow.Alert.Events, expected.Overflow.Alert.Events) { // log.Errorf("(Events %s != %s)", spew.Sdump(out.Overflow.Alert.Events), spew.Sdump(expected.Overflow.Alert.Events)) // valid = false @@ -326,10 +336,10 @@ POLL_AGAIN: // log.Infof("(Events: %s == %s)", spew.Sdump(out.Overflow.Alert.Events), spew.Sdump(expected.Overflow.Alert.Events)) // } - //CheckFailed: + // CheckFailed: log.Warningf("The test is valid, remove entry %d from expects, and %d from t.Results", eidx, ridx) - //don't do this at home : delete current element from list and redo + // don't do this at home : delete current element from list and redo results[eidx] = results[len(results)-1] results = results[:len(results)-1] tf.Results[ridx] = tf.Results[len(tf.Results)-1] diff --git a/pkg/leakybucket/manager_load.go b/pkg/leakybucket/manager_load.go index bc907ac257b..5e8bab8486e 100644 --- a/pkg/leakybucket/manager_load.go +++ b/pkg/leakybucket/manager_load.go @@ -7,7 +7,6 @@ import ( "io" "os" "path/filepath" - "strings" "sync" "time" @@ -201,44 +200,41 @@ func ValidateFactory(bucketFactory *BucketFactory) error { return fmt.Errorf("unknown bucket type '%s'", bucketFactory.Type) } - switch bucketFactory.ScopeType.Scope { - case types.Undefined: + return compileScopeFilter(bucketFactory) +} + +func compileScopeFilter(bucketFactory *BucketFactory) error { + if bucketFactory.ScopeType.Scope == types.Undefined { bucketFactory.ScopeType.Scope = types.Ip - case types.Ip: - case types.Range: - var ( - runTimeFilter *vm.Program - err error - ) + } + if bucketFactory.ScopeType.Scope == types.Ip { if bucketFactory.ScopeType.Filter != "" { - if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil { - return fmt.Errorf("error compiling the scope filter: %w", err) - } - - bucketFactory.ScopeType.RunTimeFilter = runTimeFilter + return errors.New("filter is not allowed for IP scope") } - default: - // Compile the scope filter - var ( - runTimeFilter *vm.Program - err error - ) + return nil + } - if bucketFactory.ScopeType.Filter != "" { - if runTimeFilter, err = expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...); err != nil { - return fmt.Errorf("error compiling the scope filter: %w", err) - } + if bucketFactory.ScopeType.Scope == types.Range && bucketFactory.ScopeType.Filter == "" { + return nil + } - bucketFactory.ScopeType.RunTimeFilter = runTimeFilter - } + if bucketFactory.ScopeType.Filter == "" { + return errors.New("filter is mandatory for non-IP, non-Range scope") } + runTimeFilter, err := expr.Compile(bucketFactory.ScopeType.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + if err != nil { + return fmt.Errorf("error compiling the scope filter: %w", err) + } + + bucketFactory.ScopeType.RunTimeFilter = runTimeFilter + return nil } -func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []string, tomb *tomb.Tomb, buckets *Buckets, orderEvent bool) ([]BucketFactory, chan types.Event, error) { +func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, scenarios []*cwhub.Item, tomb *tomb.Tomb, buckets *Buckets, orderEvent bool) ([]BucketFactory, chan types.Event, error) { var ( ret = []BucketFactory{} response chan types.Event @@ -246,18 +242,15 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str response = make(chan types.Event, 1) - for _, f := range files { - log.Debugf("Loading '%s'", f) + for _, item := range scenarios { + log.Debugf("Loading '%s'", item.State.LocalPath) - if !strings.HasSuffix(f, ".yaml") && !strings.HasSuffix(f, ".yml") { - log.Debugf("Skipping %s : not a yaml file", f) - continue - } + itemPath := item.State.LocalPath // process the yaml - bucketConfigurationFile, err := os.Open(f) + bucketConfigurationFile, err := os.Open(itemPath) if err != nil { - log.Errorf("Can't access leaky configuration file %s", f) + log.Errorf("Can't access leaky configuration file %s", itemPath) return nil, nil, err } @@ -271,8 +264,8 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str err = dec.Decode(&bucketFactory) if err != nil { if !errors.Is(err, io.EOF) { - log.Errorf("Bad yaml in %s: %v", f, err) - return nil, nil, fmt.Errorf("bad yaml in %s: %w", f, err) + log.Errorf("Bad yaml in %s: %v", itemPath, err) + return nil, nil, fmt.Errorf("bad yaml in %s: %w", itemPath, err) } log.Tracef("End of yaml file") @@ -288,7 +281,7 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str } // check compat if bucketFactory.FormatVersion == "" { - log.Tracef("no version in %s : %s, assuming '1.0'", bucketFactory.Name, f) + log.Tracef("no version in %s : %s, assuming '1.0'", bucketFactory.Name, itemPath) bucketFactory.FormatVersion = "1.0" } @@ -302,22 +295,17 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str continue } - bucketFactory.Filename = filepath.Clean(f) + bucketFactory.Filename = filepath.Clean(itemPath) bucketFactory.BucketName = seed.Generate() bucketFactory.ret = response - hubItem := hub.GetItemByPath(bucketFactory.Filename) - if hubItem == nil { - log.Errorf("scenario %s (%s) could not be found in hub (ignore if in unit tests)", bucketFactory.Name, bucketFactory.Filename) - } else { - if cscfg.SimulationConfig != nil { - bucketFactory.Simulated = cscfg.SimulationConfig.IsSimulated(hubItem.Name) - } - - bucketFactory.ScenarioVersion = hubItem.State.LocalVersion - bucketFactory.hash = hubItem.State.LocalHash + if cscfg.SimulationConfig != nil { + bucketFactory.Simulated = cscfg.SimulationConfig.IsSimulated(item.Name) } + bucketFactory.ScenarioVersion = item.State.LocalVersion + bucketFactory.hash = item.State.LocalHash + bucketFactory.wgDumpState = buckets.wgDumpState bucketFactory.wgPour = buckets.wgPour diff --git a/pkg/types/ip_test.go b/pkg/types/ip_test.go index ef7253f8a9b..571163761d4 100644 --- a/pkg/types/ip_test.go +++ b/pkg/types/ip_test.go @@ -9,6 +9,7 @@ import ( func TestIP2Int(t *testing.T) { tEmpty := net.IP{} + _, _, _, err := IP2Ints(tEmpty) if !strings.Contains(err.Error(), "unexpected len 0 for ") { t.Fatalf("unexpected: %s", err) @@ -189,31 +190,37 @@ func TestAdd2Int(t *testing.T) { if err != nil && test.exp_error == "" { t.Fatalf("%d unexpected error : %s", idx, err) } + if test.exp_error != "" { if !strings.Contains(err.Error(), test.exp_error) { t.Fatalf("%d unmatched error : %s != %s", idx, err, test.exp_error) } - continue //we can skip this one + + continue // we can skip this one } + if sz != test.exp_sz { t.Fatalf("%d unexpected size %d != %d", idx, sz, test.exp_sz) } + if start_ip != test.exp_start_ip { t.Fatalf("%d unexpected start_ip %d != %d", idx, start_ip, test.exp_start_ip) } + if sz == 16 { if start_sfx != test.exp_start_sfx { t.Fatalf("%d unexpected start sfx %d != %d", idx, start_sfx, test.exp_start_sfx) } } + if end_ip != test.exp_end_ip { t.Fatalf("%d unexpected end ip %d != %d", idx, end_ip, test.exp_end_ip) } + if sz == 16 { if end_sfx != test.exp_end_sfx { t.Fatalf("%d unexpected end sfx %d != %d", idx, end_sfx, test.exp_end_sfx) } } - } }