diff --git a/.github/codecov-ignore-generated.sh b/.github/codecov-ignore-generated.sh deleted file mode 100755 index 3c896d47be7..00000000000 --- a/.github/codecov-ignore-generated.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -# Run this from the repository root: -# -# .github/codecov-ignore-generated.sh >> .github/codecov.yml - -find . -name "*.go" | while read -r file; do - if head -n 1 "$file" | grep -q "Code generated by"; then - echo " - \"$file\"" - fi -done diff --git a/.github/codecov.yml b/.github/codecov.yml index 82598c15511..e3a81070324 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -1,5 +1,8 @@ # we measure coverage but don't enforce it # https://docs.codecov.com/docs/codecov-yaml +codecov: + require_ci_to_pass: false + coverage: status: patch: @@ -10,7 +13,7 @@ coverage: target: 0% # if a directory is ignored, there is no way to un-ignore files like pkg/models/helpers.go -# so we make a full list, manually updated - but it could be generated right before running codecov +# so we make a full list ignore: - "./pkg/modelscapi/success_response.go" - "./pkg/modelscapi/get_decisions_stream_response_deleted.go" @@ -41,17 +44,26 @@ ignore: - "./pkg/modelscapi/enroll_request.go" - "./pkg/modelscapi/register_request.go" - "./pkg/modelscapi/add_signals_request_item_source.go" + - "./pkg/models/success_response.go" + - "./pkg/models/hub_items.go" - "./pkg/models/alert.go" - "./pkg/models/metrics_bouncer_info.go" - "./pkg/models/add_signals_request_item.go" + - "./pkg/models/metrics_meta.go" + - "./pkg/models/metrics_detail_item.go" - "./pkg/models/add_signals_request_item_decisions_item.go" + - "./pkg/models/hub_item.go" - "./pkg/models/get_alerts_response.go" + - "./pkg/models/metrics_labels.go" - "./pkg/models/watcher_auth_request.go" - "./pkg/models/add_alerts_request.go" - "./pkg/models/event.go" - "./pkg/models/decisions_delete_request_item.go" - "./pkg/models/meta.go" + - "./pkg/models/detailed_metrics.go" - "./pkg/models/delete_alerts_response.go" + - "./pkg/models/remediation_components_metrics.go" + - "./pkg/models/console_options.go" - "./pkg/models/topx_response.go" - "./pkg/models/add_signals_request.go" - "./pkg/models/delete_decision_response.go" @@ -60,24 +72,34 @@ ignore: - "./pkg/models/source.go" - "./pkg/models/decisions_stream_response.go" - "./pkg/models/error_response.go" + - "./pkg/models/all_metrics.go" + - "./pkg/models/o_sversion.go" - "./pkg/models/decision.go" - "./pkg/models/decisions_delete_request.go" - "./pkg/models/flush_decision_response.go" - "./pkg/models/watcher_auth_response.go" + - "./pkg/models/lapi_metrics.go" - "./pkg/models/watcher_registration_request.go" - "./pkg/models/metrics_agent_info.go" + - "./pkg/models/log_processors_metrics.go" - "./pkg/models/add_signals_request_item_source.go" + - "./pkg/models/base_metrics.go" - "./pkg/models/add_alerts_response.go" - "./pkg/models/metrics.go" - "./pkg/protobufs/notifier.pb.go" + - "./pkg/protobufs/notifier_grpc.pb.go" + - "./pkg/database/ent/metric_update.go" - "./pkg/database/ent/machine_delete.go" - "./pkg/database/ent/decision_query.go" - "./pkg/database/ent/meta_query.go" + - "./pkg/database/ent/metric/where.go" + - "./pkg/database/ent/metric/metric.go" - "./pkg/database/ent/machine_create.go" - "./pkg/database/ent/alert.go" - "./pkg/database/ent/event_update.go" - "./pkg/database/ent/alert_create.go" - "./pkg/database/ent/alert_query.go" + - "./pkg/database/ent/metric_delete.go" - "./pkg/database/ent/lock_create.go" - "./pkg/database/ent/bouncer_update.go" - "./pkg/database/ent/meta_update.go" @@ -92,6 +114,7 @@ ignore: - "./pkg/database/ent/migrate/migrate.go" - "./pkg/database/ent/migrate/schema.go" - "./pkg/database/ent/configitem.go" + - "./pkg/database/ent/metric_query.go" - "./pkg/database/ent/event.go" - "./pkg/database/ent/event_query.go" - "./pkg/database/ent/lock_update.go" @@ -111,6 +134,7 @@ ignore: - "./pkg/database/ent/bouncer/bouncer.go" - "./pkg/database/ent/bouncer/where.go" - "./pkg/database/ent/hook/hook.go" + - "./pkg/database/ent/metric.go" - "./pkg/database/ent/configitem_create.go" - "./pkg/database/ent/configitem_delete.go" - "./pkg/database/ent/tx.go" @@ -120,6 +144,7 @@ ignore: - "./pkg/database/ent/machine/where.go" - "./pkg/database/ent/machine/machine.go" - "./pkg/database/ent/event_create.go" + - "./pkg/database/ent/metric_create.go" - "./pkg/database/ent/decision/where.go" - "./pkg/database/ent/decision/decision.go" - "./pkg/database/ent/enttest/enttest.go" diff --git a/.github/generate-codecov-yml.sh b/.github/generate-codecov-yml.sh new file mode 100755 index 00000000000..ddb60d0ce80 --- /dev/null +++ b/.github/generate-codecov-yml.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Run this from the repository root: +# +# .github/generate-codecov-yml.sh >> .github/codecov.yml + +cat <> .github/codecov.yml + - name: "Run tests" run: ./test/run-tests ./test/bats --formatter $(pwd)/test/lib/color-formatter diff --git a/.github/workflows/ci-windows-build-msi.yml b/.github/workflows/ci-windows-build-msi.yml index 03cdb4bd871..07e29071e05 100644 --- a/.github/workflows/ci-windows-build-msi.yml +++ b/.github/workflows/ci-windows-build-msi.yml @@ -35,7 +35,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.22.6" + go-version: "1.23" - name: Build run: make windows_installer BUILD_RE2_WASM=1 diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 42b52490ea8..4128cb435f9 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -52,7 +52,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.22.6" + go-version: "1.23" cache-dependency-path: "**/go.sum" # Initializes the CodeQL tools for scanning. diff --git a/.github/workflows/docker-tests.yml b/.github/workflows/docker-tests.yml index 228a0829984..918f3bcaf1d 100644 --- a/.github/workflows/docker-tests.yml +++ b/.github/workflows/docker-tests.yml @@ -53,23 +53,12 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.x" - - - name: "Install pipenv" - run: | - cd docker/test - python -m pip install --upgrade pipenv wheel - - - name: "Cache virtualenvs" - id: cache-pipenv - uses: actions/cache@v4 - with: - path: ~/.local/share/virtualenvs - key: ${{ runner.os }}-pipenv-${{ hashFiles('**/Pipfile.lock') }} + cache: 'pipenv' - name: "Install dependencies" - if: steps.cache-pipenv.outputs.cache-hit != 'true' run: | cd docker/test + python -m pip install --upgrade pipenv wheel pipenv install --deploy - name: "Create Docker network" diff --git a/.github/workflows/go-tests-windows.yml b/.github/workflows/go-tests-windows.yml index 5a463bab99c..2966b999a4a 100644 --- a/.github/workflows/go-tests-windows.yml +++ b/.github/workflows/go-tests-windows.yml @@ -34,12 +34,16 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.22.6" + go-version: "1.23" - name: Build run: | make build BUILD_RE2_WASM=1 + - name: Generate codecov configuration + run: | + .github/generate-codecov-yml.sh >> .github/codecov.yml + - name: Run tests run: | go install github.com/kyoh86/richgo@v0.3.10 @@ -57,6 +61,6 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.59 + version: v1.61 args: --issues-exit-code=1 --timeout 10m only-new-issues: false diff --git a/.github/workflows/go-tests.yml b/.github/workflows/go-tests.yml index 58b8dc61a0d..3f4aa67e139 100644 --- a/.github/workflows/go-tests.yml +++ b/.github/workflows/go-tests.yml @@ -126,13 +126,40 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.22.6" + go-version: "1.23" + + - name: Run "go generate" and check for changes + run: | + set -e + # ensure the version of 'protoc' matches the one that generated the files + PROTOBUF_VERSION="21.12" + # don't pollute the repo + pushd $HOME + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip + unzip protoc-${PROTOBUF_VERSION}-linux-x86_64.zip -d $HOME/.protoc + popd + export PATH="$HOME/.protoc/bin:$PATH" + go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 + go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1 + go generate ./... + protoc --version + if [[ $(git status --porcelain) ]]; then + echo "Error: Uncommitted changes found after running 'make generate'. Please commit all generated code." + git diff + exit 1 + else + echo "No changes detected after running 'make generate'." + fi - name: Create localstack streams run: | aws --endpoint-url=http://127.0.0.1:4566 --region us-east-1 kinesis create-stream --stream-name stream-1-shard --shard-count 1 aws --endpoint-url=http://127.0.0.1:4566 --region us-east-1 kinesis create-stream --stream-name stream-2-shards --shard-count 2 + - name: Generate codecov configuration + run: | + .github/generate-codecov-yml.sh >> .github/codecov.yml + - name: Build and run tests, static run: | sudo apt -qq -y -o=Dpkg::Use-Pty=0 install build-essential libre2-dev @@ -142,6 +169,11 @@ jobs: make build BUILD_STATIC=1 make go-acc | sed 's/ *coverage:.*of statements in.*//' | richgo testfilter + # check if some component stubs are missing + - name: "Build profile: minimal" + run: | + make build BUILD_PROFILE=minimal + - name: Run tests again, dynamic run: | make clean build @@ -158,6 +190,6 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.59 + version: v1.61 args: --issues-exit-code=1 --timeout 10m only-new-issues: false diff --git a/.github/workflows/publish-tarball-release.yml b/.github/workflows/publish-tarball-release.yml index 2f809a29a9b..6a41c3fba53 100644 --- a/.github/workflows/publish-tarball-release.yml +++ b/.github/workflows/publish-tarball-release.yml @@ -25,7 +25,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v5 with: - go-version: "1.22.6" + go-version: "1.23" - name: Build the binaries run: | diff --git a/.golangci.yml b/.golangci.yml index fb1dab623c1..acde901dbe6 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -20,14 +20,14 @@ linters-settings: maintidx: # raise this after refactoring - under: 16 + under: 15 misspell: locale: US nestif: # lower this after refactoring - min-complexity: 24 + min-complexity: 16 nlreturn: block-size: 5 @@ -103,7 +103,7 @@ linters-settings: disabled: true - name: cyclomatic # lower this after refactoring - arguments: [42] + arguments: [39] - name: defer disabled: true - name: empty-block @@ -118,7 +118,7 @@ linters-settings: arguments: [6] - name: function-length # lower this after refactoring - arguments: [110, 235] + arguments: [110, 237] - name: get-return disabled: true - name: increment-decrement @@ -135,14 +135,10 @@ linters-settings: arguments: [7] - name: max-public-structs disabled: true - - name: optimize-operands-order - disabled: true - name: nested-structs disabled: true - name: package-comments disabled: true - - name: struct-tag - disabled: true - name: redundant-import-alias disabled: true - name: time-equal @@ -178,6 +174,37 @@ linters-settings: # Allow blocks to end with comments allow-trailing-comment: true + gocritic: + enable-all: true + disabled-checks: + - typeDefFirst + - paramTypeCombine + - httpNoBody + - ifElseChain + - importShadow + - hugeParam + - rangeValCopy + - commentedOutCode + - commentedOutImport + - unnamedResult + - sloppyReassign + - appendCombine + - captLocal + - typeUnparen + - commentFormatting + - deferInLoop # + - sprintfQuotedString # + - whyNoLint + - equalFold # + - unnecessaryBlock # + - ptrToRefParam # + - stringXbytes # + - appendAssign # + - tooManyResultsChecker + - unnecessaryDefer + - docStub + - preferFprint + linters: enable-all: true disable: @@ -185,6 +212,8 @@ linters: # DEPRECATED by golangi-lint # - execinquery + - exportloopref + - gomnd # # Redundant @@ -196,75 +225,9 @@ linters: - funlen # revive - gocognit # revive - # - # Disabled until fixed for go 1.22 - # - - - copyloopvar # copyloopvar is a linter detects places where loop variables are copied - - intrange # intrange is a linter to find places where for loops could make use of an integer range. + # Disabled atm - # - # Enabled - # - - # - asasalint # check for pass []any as any in variadic func(...any) - # - asciicheck # checks that all code identifiers does not have non-ASCII symbols in the name - # - bidichk # Checks for dangerous unicode character sequences - # - bodyclose # checks whether HTTP response body is closed successfully - # - decorder # check declaration order and count of types, constants, variables and functions - # - depguard # Go linter that checks if package imports are in a list of acceptable packages - # - dupword # checks for duplicate words in the source code - # - durationcheck # check for two durations multiplied together - # - errcheck # errcheck is a program for checking for unchecked errors in Go code. These unchecked errors can be critical bugs in some cases - # - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - # - exportloopref # checks for pointers to enclosing loop variables - # - ginkgolinter # enforces standards of using ginkgo and gomega - # - gocheckcompilerdirectives # Checks that go compiler directive comments (//go:) are valid. - # - gochecknoinits # Checks that no init functions are present in Go code - # - gochecksumtype # Run exhaustiveness checks on Go "sum types" - # - gocritic # Provides diagnostics that check for bugs, performance and style issues. - # - goheader # Checks is file header matches to pattern - # - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - # - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - # - goprintffuncname # Checks that printf-like functions are named with `f` at the end - # - gosimple # (megacheck): Linter for Go source code that specializes in simplifying code - # - gosmopolitan # Report certain i18n/l10n anti-patterns in your Go codebase - # - govet # (vet, vetshadow): Vet examines Go source code and reports suspicious constructs. It is roughly the same as 'go vet' and uses its passes. - # - grouper # Analyze expression groups. - # - importas # Enforces consistent import aliases - # - ineffassign # Detects when assignments to existing variables are not used - # - interfacebloat # A linter that checks the number of methods inside an interface. - # - loggercheck # (logrlint): Checks key value pairs for common logger libraries (kitlog,klog,logr,zap). - # - logrlint # Check logr arguments. - # - maintidx # maintidx measures the maintainability index of each function. - # - makezero # Finds slice declarations with non-zero initial length - # - mirror # reports wrong mirror patterns of bytes/strings usage - # - misspell # Finds commonly misspelled English words - # - nakedret # Checks that functions with naked returns are not longer than a maximum size (can be zero). - # - nestif # Reports deeply nested if statements - # - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - # - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - # - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - # - perfsprint # Checks that fmt.Sprintf can be replaced with a faster alternative. - # - predeclared # find code that shadows one of Go's predeclared identifiers - # - reassign # Checks that package variables are not reassigned - # - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - # - rowserrcheck # checks whether Rows.Err of rows is checked successfully - # - sloglint # ensure consistent code style when using log/slog - # - spancheck # Checks for mistakes with OpenTelemetry/Census spans. - # - sqlclosecheck # Checks that sql.Rows, sql.Stmt, sqlx.NamedStmt, pgx.Query are closed. - # - staticcheck # (megacheck): It's a set of rules from staticcheck. It's not the same thing as the staticcheck binary. The author of staticcheck doesn't support or approve the use of staticcheck as a library inside golangci-lint. - # - stylecheck # Stylecheck is a replacement for golint - # - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - # - testableexamples # linter checks if examples are testable (have an expected output) - # - testifylint # Checks usage of github.com/stretchr/testify. - # - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - # - unconvert # Remove unnecessary type conversions - # - unused # (megacheck): Checks Go code for unused constants, variables, functions and types - # - usestdlibvars # A linter that detect the possibility to use variables/constants from the Go standard library. - # - wastedassign # Finds wasted assignment statements - # - zerologlint # Detects the wrong usage of `zerolog` that a user forgets to dispatch with `Send` or `Msg` + - intrange # intrange is a linter to find places where for loops could make use of an integer range. # # Recommended? (easy) @@ -291,9 +254,7 @@ linters: # - containedctx # containedctx is a linter that detects struct contained context.Context field - - contextcheck # check whether the function uses a non-inherited context - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - - gomnd # An analyzer to detect magic numbers. - ireturn # Accept Interfaces, Return Concrete Types - mnd # An analyzer to detect magic numbers. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. @@ -359,28 +320,12 @@ issues: # `err` is often shadowed, we may continue to do it - linters: - govet - text: "shadow: declaration of \"err\" shadows declaration" + text: "shadow: declaration of \"(err|ctx)\" shadows declaration" - linters: - errcheck text: "Error return value of `.*` is not checked" - - linters: - - gocritic - text: "ifElseChain: rewrite if-else to switch statement" - - - linters: - - gocritic - text: "captLocal: `.*' should not be capitalized" - - - linters: - - gocritic - text: "appendAssign: append result not assigned to the same slice" - - - linters: - - gocritic - text: "commentFormatting: put a space between `//` and comment text" - # Will fix, trivial - just beware of merge conflicts - linters: @@ -403,10 +348,6 @@ issues: - errorlint text: "type switch on error will fail on wrapped errors. Use errors.As to check for specific errors" - - linters: - - errorlint - text: "comparing with .* will fail on wrapped errors. Use errors.Is to check for a specific error" - - linters: - nosprintfhostport text: "host:port in url should be constructed with net.JoinHostPort and not directly with fmt.Sprintf" @@ -474,25 +415,26 @@ issues: path: "pkg/(.+)_test.go" text: "deep-exit: .*" - # tolerate deep exit in cobra's OnInitialize, for now + # we use t,ctx instead of ctx,t in tests - linters: - revive - path: "cmd/crowdsec-cli/main.go" - text: "deep-exit: .*" + path: "pkg/(.+)_test.go" + text: "context-as-argument: context.Context should be the first parameter of a function" + # tolerate deep exit in cobra's OnInitialize, for now - linters: - revive - path: "cmd/crowdsec-cli/item_metrics.go" + path: "cmd/crowdsec-cli/main.go" text: "deep-exit: .*" - linters: - revive - path: "cmd/crowdsec-cli/machines.go" + path: "cmd/crowdsec-cli/clihub/item_metrics.go" text: "deep-exit: .*" - linters: - revive - path: "cmd/crowdsec-cli/utils.go" + path: "cmd/crowdsec-cli/idgen/password.go" text: "deep-exit: .*" - linters: diff --git a/Dockerfile b/Dockerfile index 731e08fb1a6..880df88dc02 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # vim: set ft=dockerfile: -FROM golang:1.22.6-alpine3.20 AS build +FROM golang:1.23-alpine3.20 AS build ARG BUILD_VERSION @@ -16,7 +16,7 @@ RUN apk add --no-cache git g++ gcc libc-dev make bash gettext binutils-gold core cd re2-${RE2_VERSION} && \ make install && \ echo "githubciXXXXXXXXXXXXXXXXXXXXXXXX" > /etc/machine-id && \ - go install github.com/mikefarah/yq/v4@v4.43.1 + go install github.com/mikefarah/yq/v4@v4.44.3 COPY . . diff --git a/Dockerfile.debian b/Dockerfile.debian index ec961a4a1ec..5d47f167e99 100644 --- a/Dockerfile.debian +++ b/Dockerfile.debian @@ -1,5 +1,5 @@ # vim: set ft=dockerfile: -FROM golang:1.22.6-bookworm AS build +FROM golang:1.23-bookworm AS build ARG BUILD_VERSION @@ -21,7 +21,7 @@ RUN apt-get update && \ make && \ make install && \ echo "githubciXXXXXXXXXXXXXXXXXXXXXXXX" > /etc/machine-id && \ - go install github.com/mikefarah/yq/v4@v4.43.1 + go install github.com/mikefarah/yq/v4@v4.44.3 COPY . . diff --git a/Makefile b/Makefile index 207b5d610f0..f8ae66e1cb6 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ BUILD_RE2_WASM ?= 0 # for your distribution (look for libre2.a). See the Dockerfile for an example of how to build it. BUILD_STATIC ?= 0 -# List of plugins to build +# List of notification plugins to build PLUGINS ?= $(patsubst ./cmd/notification-%,%,$(wildcard ./cmd/notification-*)) #-------------------------------------- @@ -80,9 +80,17 @@ endif #expr_debug tag is required to enable the debug mode in expr GO_TAGS := netgo,osusergo,sqlite_omit_load_extension,expr_debug +# Allow building on ubuntu 24.10, see https://github.com/golang/go/issues/70023 +export CGO_LDFLAGS_ALLOW=-Wl,--(push|pop)-state.* + # this will be used by Go in the make target, some distributions require it export PKG_CONFIG_PATH:=/usr/local/lib/pkgconfig:$(PKG_CONFIG_PATH) +#-------------------------------------- +# +# Choose the re2 backend. +# + ifeq ($(call bool,$(BUILD_RE2_WASM)),0) ifeq ($(PKG_CONFIG),) $(error "pkg-config is not available. Please install pkg-config.") @@ -90,14 +98,88 @@ endif ifeq ($(RE2_CHECK),) RE2_FAIL := "libre2-dev is not installed, please install it or set BUILD_RE2_WASM=1 to use the WebAssembly version" +# if you prefer to build WASM instead of a critical error, comment out RE2_FAIL and uncomment RE2_MSG. +# RE2_MSG := Fallback to WebAssembly regexp library. To use the C++ version, make sure you have installed libre2-dev and pkg-config. else # += adds a space that we don't want GO_TAGS := $(GO_TAGS),re2_cgo LD_OPTS_VARS += -X '$(GO_MODULE_NAME)/pkg/cwversion.Libre2=C++' +RE2_MSG := Using C++ regexp library +endif +else +RE2_MSG := Using WebAssembly regexp library +endif + +ifeq ($(call bool,$(BUILD_RE2_WASM)),1) +else +ifneq (,$(RE2_CHECK)) endif endif -# Build static to avoid the runtime dependency on libre2.so +#-------------------------------------- +# +# Handle optional components and build profiles, to save space on the final binaries. +# +# Keep it safe for now until we decide how to expand on the idea. Either choose a profile or exclude components manually. +# For example if we want to disable some component by default, or have opt-in components (INCLUDE?). + +ifeq ($(and $(BUILD_PROFILE),$(EXCLUDE)),1) +$(error "Cannot specify both BUILD_PROFILE and EXCLUDE") +endif + +COMPONENTS := \ + datasource_appsec \ + datasource_cloudwatch \ + datasource_docker \ + datasource_file \ + datasource_http \ + datasource_k8saudit \ + datasource_kafka \ + datasource_journalctl \ + datasource_kinesis \ + datasource_loki \ + datasource_s3 \ + datasource_syslog \ + datasource_wineventlog \ + cscli_setup + +comma := , +space := $(empty) $(empty) + +# Predefined profiles + +# keep only datasource-file +EXCLUDE_MINIMAL := $(subst $(space),$(comma),$(filter-out datasource_file,,$(COMPONENTS))) + +# example +# EXCLUDE_MEDIUM := datasource_kafka,datasource_kinesis,datasource_s3 + +BUILD_PROFILE ?= default + +# Set the EXCLUDE_LIST based on the chosen profile, unless EXCLUDE is already set +ifeq ($(BUILD_PROFILE),minimal) +EXCLUDE ?= $(EXCLUDE_MINIMAL) +else ifneq ($(BUILD_PROFILE),default) +$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default) +endif + +# Create list of excluded components from the EXCLUDE variable +EXCLUDE_LIST := $(subst $(comma),$(space),$(EXCLUDE)) + +INVALID_COMPONENTS := $(filter-out $(COMPONENTS),$(EXCLUDE_LIST)) +ifneq ($(INVALID_COMPONENTS),) +$(error Invalid optional components specified in EXCLUDE: $(INVALID_COMPONENTS). Valid components are: $(COMPONENTS)) +endif + +# Convert the excluded components to "no_" form +COMPONENT_TAGS := $(foreach component,$(EXCLUDE_LIST),no_$(component)) + +ifneq ($(COMPONENT_TAGS),) +GO_TAGS := $(GO_TAGS),$(subst $(space),$(comma),$(COMPONENT_TAGS)) +endif + +#-------------------------------------- + ifeq ($(call bool,$(BUILD_STATIC)),1) BUILD_TYPE = static EXTLDFLAGS := -extldflags '-static' @@ -111,7 +193,7 @@ ifeq ($(call bool,$(DEBUG)),1) STRIP_SYMBOLS := DISABLE_OPTIMIZATION := -gcflags "-N -l" else -STRIP_SYMBOLS := -s -w +STRIP_SYMBOLS := -s DISABLE_OPTIMIZATION := endif @@ -130,16 +212,13 @@ build: build-info crowdsec cscli plugins ## Build crowdsec, cscli and plugins .PHONY: build-info build-info: ## Print build information $(info Building $(BUILD_VERSION) ($(BUILD_TAG)) $(BUILD_TYPE) for $(GOOS)/$(GOARCH)) + $(info Excluded components: $(if $(EXCLUDE_LIST),$(EXCLUDE_LIST),none)) ifneq (,$(RE2_FAIL)) $(error $(RE2_FAIL)) endif -ifneq (,$(RE2_CHECK)) - $(info Using C++ regexp library) -else - $(info Fallback to WebAssembly regexp library. To use the C++ version, make sure you have installed libre2-dev and pkg-config.) -endif + $(info $(RE2_MSG)) ifeq ($(call bool,$(DEBUG)),1) $(info Building with debug symbols and disabled optimizations) @@ -199,11 +278,6 @@ cscli: ## Build cscli crowdsec: ## Build crowdsec @$(MAKE) -C $(CROWDSEC_FOLDER) build $(MAKE_FLAGS) -.PHONY: generate -generate: ## Generate code for the database and APIs - $(GO) generate ./pkg/database/ent - $(GO) generate ./pkg/models - .PHONY: testclean testclean: bats-clean ## Remove test artifacts @$(RM) pkg/apiserver/ent $(WIN_IGNORE_ERR) diff --git a/README.md b/README.md index a900f0ee514..1e57d4e91c4 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ The architecture is as follows : CrowdSec

-Once an unwanted behavior is detected, deal with it through a [bouncer](https://hub.crowdsec.net/browse/#bouncers). The aggressive IP, scenario triggered and timestamp are sent for curation, to avoid poisoning & false positives. (This can be disabled). If verified, this IP is then redistributed to all CrowdSec users running the same scenario. +Once an unwanted behavior is detected, deal with it through a [bouncer](https://app.crowdsec.net/hub/remediation-components). The aggressive IP, scenario triggered and timestamp are sent for curation, to avoid poisoning & false positives. (This can be disabled). If verified, this IP is then redistributed to all CrowdSec users running the same scenario. ## Outnumbering hackers all together diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0ceb9e5cffc..bcf327bdf38 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -21,7 +21,7 @@ stages: - task: GoTool@0 displayName: "Install Go" inputs: - version: '1.22.6' + version: '1.23.3' - pwsh: | choco install -y make diff --git a/cmd/crowdsec-cli/ask/ask.go b/cmd/crowdsec-cli/ask/ask.go new file mode 100644 index 00000000000..484ccb30c8a --- /dev/null +++ b/cmd/crowdsec-cli/ask/ask.go @@ -0,0 +1,20 @@ +package ask + +import ( + "github.com/AlecAivazis/survey/v2" +) + +func YesNo(message string, defaultAnswer bool) (bool, error) { + var answer bool + + prompt := &survey.Confirm{ + Message: message, + Default: defaultAnswer, + } + + if err := survey.AskOne(prompt, &answer); err != nil { + return defaultAnswer, err + } + + return answer, nil +} diff --git a/cmd/crowdsec-cli/bouncers.go b/cmd/crowdsec-cli/bouncers.go deleted file mode 100644 index d3edcea0db9..00000000000 --- a/cmd/crowdsec-cli/bouncers.go +++ /dev/null @@ -1,537 +0,0 @@ -package main - -import ( - "encoding/csv" - "encoding/json" - "errors" - "fmt" - "io" - "os" - "slices" - "strings" - "time" - - "github.com/AlecAivazis/survey/v2" - "github.com/fatih/color" - "github.com/jedib0t/go-pretty/v6/table" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/database/ent" - "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" - "github.com/crowdsecurity/crowdsec/pkg/emoji" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -type featureflagProvider interface { - GetFeatureflags() string -} - -type osProvider interface { - GetOsname() string - GetOsversion() string -} - -func getOSNameAndVersion(o osProvider) string { - ret := o.GetOsname() - if o.GetOsversion() != "" { - if ret != "" { - ret += "/" - } - - ret += o.GetOsversion() - } - - if ret == "" { - return "?" - } - - return ret -} - -func getFeatureFlagList(o featureflagProvider) []string { - if o.GetFeatureflags() == "" { - return nil - } - - return strings.Split(o.GetFeatureflags(), ",") -} - -func askYesNo(message string, defaultAnswer bool) (bool, error) { - var answer bool - - prompt := &survey.Confirm{ - Message: message, - Default: defaultAnswer, - } - - if err := survey.AskOne(prompt, &answer); err != nil { - return defaultAnswer, err - } - - return answer, nil -} - -type cliBouncers struct { - db *database.Client - cfg configGetter -} - -func NewCLIBouncers(cfg configGetter) *cliBouncers { - return &cliBouncers{ - cfg: cfg, - } -} - -func (cli *cliBouncers) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "bouncers [action]", - Short: "Manage bouncers [requires local API]", - Long: `To list/add/delete/prune bouncers. -Note: This command requires database direct access, so is intended to be run on Local API/master. -`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"bouncer"}, - DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { - var err error - - cfg := cli.cfg() - - if err = require.LAPI(cfg); err != nil { - return err - } - - cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) - if err != nil { - return err - } - - return nil - }, - } - - cmd.AddCommand(cli.newListCmd()) - cmd.AddCommand(cli.newAddCmd()) - cmd.AddCommand(cli.newDeleteCmd()) - cmd.AddCommand(cli.newPruneCmd()) - cmd.AddCommand(cli.newInspectCmd()) - - return cmd -} - -func (cli *cliBouncers) listHuman(out io.Writer, bouncers ent.Bouncers) { - t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer - t.AppendHeader(table.Row{"Name", "IP Address", "Valid", "Last API pull", "Type", "Version", "Auth Type"}) - - for _, b := range bouncers { - revoked := emoji.CheckMark - if b.Revoked { - revoked = emoji.Prohibited - } - - lastPull := "" - if b.LastPull != nil { - lastPull = b.LastPull.Format(time.RFC3339) - } - - t.AppendRow(table.Row{b.Name, b.IPAddress, revoked, lastPull, b.Type, b.Version, b.AuthType}) - } - - io.WriteString(out, t.Render() + "\n") -} - -// bouncerInfo contains only the data we want for inspect/list -type bouncerInfo struct { - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - Name string `json:"name"` - Revoked bool `json:"revoked"` - IPAddress string `json:"ip_address"` - Type string `json:"type"` - Version string `json:"version"` - LastPull *time.Time `json:"last_pull"` - AuthType string `json:"auth_type"` - OS string `json:"os,omitempty"` - Featureflags []string `json:"featureflags,omitempty"` -} - -func newBouncerInfo(b *ent.Bouncer) bouncerInfo { - return bouncerInfo{ - CreatedAt: b.CreatedAt, - UpdatedAt: b.UpdatedAt, - Name: b.Name, - Revoked: b.Revoked, - IPAddress: b.IPAddress, - Type: b.Type, - Version: b.Version, - LastPull: b.LastPull, - AuthType: b.AuthType, - OS: getOSNameAndVersion(b), - Featureflags: getFeatureFlagList(b), - } -} - -func (cli *cliBouncers) listCSV(out io.Writer, bouncers ent.Bouncers) error { - csvwriter := csv.NewWriter(out) - - if err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version", "auth_type"}); err != nil { - return fmt.Errorf("failed to write raw header: %w", err) - } - - for _, b := range bouncers { - valid := "validated" - if b.Revoked { - valid = "pending" - } - - lastPull := "" - if b.LastPull != nil { - lastPull = b.LastPull.Format(time.RFC3339) - } - - if err := csvwriter.Write([]string{b.Name, b.IPAddress, valid, lastPull, b.Type, b.Version, b.AuthType}); err != nil { - return fmt.Errorf("failed to write raw: %w", err) - } - } - - csvwriter.Flush() - - return nil -} - -func (cli *cliBouncers) list(out io.Writer) error { - bouncers, err := cli.db.ListBouncers() - if err != nil { - return fmt.Errorf("unable to list bouncers: %w", err) - } - - switch cli.cfg().Cscli.Output { - case "human": - cli.listHuman(out, bouncers) - case "json": - info := make([]bouncerInfo, 0, len(bouncers)) - for _, b := range bouncers { - info = append(info, newBouncerInfo(b)) - } - - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - - if err := enc.Encode(info); err != nil { - return errors.New("failed to marshal") - } - - return nil - case "raw": - return cli.listCSV(out, bouncers) - } - - return nil -} - -func (cli *cliBouncers) newListCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "list", - Short: "list all bouncers within the database", - Example: `cscli bouncers list`, - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.list(color.Output) - }, - } - - return cmd -} - -func (cli *cliBouncers) add(bouncerName string, key string) error { - var err error - - keyLength := 32 - - if key == "" { - key, err = middlewares.GenerateAPIKey(keyLength) - if err != nil { - return fmt.Errorf("unable to generate api key: %w", err) - } - } - - _, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType) - if err != nil { - return fmt.Errorf("unable to create bouncer: %w", err) - } - - switch cli.cfg().Cscli.Output { - case "human": - fmt.Printf("API key for '%s':\n\n", bouncerName) - fmt.Printf(" %s\n\n", key) - fmt.Print("Please keep this key since you will not be able to retrieve it!\n") - case "raw": - fmt.Print(key) - case "json": - j, err := json.Marshal(key) - if err != nil { - return errors.New("unable to marshal api key") - } - - fmt.Print(string(j)) - } - - return nil -} - -func (cli *cliBouncers) newAddCmd() *cobra.Command { - var key string - - cmd := &cobra.Command{ - Use: "add MyBouncerName", - Short: "add a single bouncer to the database", - Example: `cscli bouncers add MyBouncerName -cscli bouncers add MyBouncerName --key `, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args[0], key) - }, - } - - flags := cmd.Flags() - flags.StringP("length", "l", "", "length of the api key") - _ = flags.MarkDeprecated("length", "use --key instead") - flags.StringVarP(&key, "key", "k", "", "api key for the bouncer") - - return cmd -} - -// validBouncerID returns a list of bouncer IDs for command completion -func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - var err error - - cfg := cli.cfg() - - // need to load config and db because PersistentPreRunE is not called for completions - - if err = require.LAPI(cfg); err != nil { - cobra.CompError("unable to list bouncers " + err.Error()) - return nil, cobra.ShellCompDirectiveNoFileComp - } - - cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) - if err != nil { - cobra.CompError("unable to list bouncers " + err.Error()) - return nil, cobra.ShellCompDirectiveNoFileComp - } - - bouncers, err := cli.db.ListBouncers() - if err != nil { - cobra.CompError("unable to list bouncers " + err.Error()) - return nil, cobra.ShellCompDirectiveNoFileComp - } - - ret := []string{} - - for _, bouncer := range bouncers { - if strings.Contains(bouncer.Name, toComplete) && !slices.Contains(args, bouncer.Name) { - ret = append(ret, bouncer.Name) - } - } - - return ret, cobra.ShellCompDirectiveNoFileComp -} - -func (cli *cliBouncers) delete(bouncers []string, ignoreMissing bool) error { - for _, bouncerID := range bouncers { - if err := cli.db.DeleteBouncer(bouncerID); err != nil { - var notFoundErr *database.BouncerNotFoundError - if ignoreMissing && errors.As(err, ¬FoundErr) { - return nil - } - - return fmt.Errorf("unable to delete bouncer: %w", err) - } - - log.Infof("bouncer '%s' deleted successfully", bouncerID) - } - - return nil -} - -func (cli *cliBouncers) newDeleteCmd() *cobra.Command { - var ignoreMissing bool - - cmd := &cobra.Command{ - Use: "delete MyBouncerName", - Short: "delete bouncer(s) from the database", - Example: `cscli bouncers delete "bouncer1" "bouncer2"`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"remove"}, - DisableAutoGenTag: true, - ValidArgsFunction: cli.validBouncerID, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args, ignoreMissing) - }, - } - - flags := cmd.Flags() - flags.BoolVar(&ignoreMissing, "ignore-missing", false, "don't print errors if one or more bouncers don't exist") - - return cmd -} - -func (cli *cliBouncers) prune(duration time.Duration, force bool) error { - if duration < 2*time.Minute { - if yes, err := askYesNo( - "The duration you provided is less than 2 minutes. "+ - "This may remove active bouncers. Continue?", false); err != nil { - return err - } else if !yes { - fmt.Println("User aborted prune. No changes were made.") - return nil - } - } - - bouncers, err := cli.db.QueryBouncersInactiveSince(time.Now().UTC().Add(-duration)) - if err != nil { - return fmt.Errorf("unable to query bouncers: %w", err) - } - - if len(bouncers) == 0 { - fmt.Println("No bouncers to prune.") - return nil - } - - cli.listHuman(color.Output, bouncers) - - if !force { - if yes, err := askYesNo( - "You are about to PERMANENTLY remove the above bouncers from the database. "+ - "These will NOT be recoverable. Continue?", false); err != nil { - return err - } else if !yes { - fmt.Println("User aborted prune. No changes were made.") - return nil - } - } - - deleted, err := cli.db.BulkDeleteBouncers(bouncers) - if err != nil { - return fmt.Errorf("unable to prune bouncers: %w", err) - } - - fmt.Fprintf(os.Stderr, "Successfully deleted %d bouncers\n", deleted) - - return nil -} - -func (cli *cliBouncers) newPruneCmd() *cobra.Command { - var ( - duration time.Duration - force bool - ) - - const defaultDuration = 60 * time.Minute - - cmd := &cobra.Command{ - Use: "prune", - Short: "prune multiple bouncers from the database", - Args: cobra.NoArgs, - DisableAutoGenTag: true, - Example: `cscli bouncers prune -d 45m -cscli bouncers prune -d 45m --force`, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, force) - }, - } - - flags := cmd.Flags() - flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since last pull") - flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") - - return cmd -} - -func (cli *cliBouncers) inspectHuman(out io.Writer, bouncer *ent.Bouncer) { - t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer - - t.SetTitle("Bouncer: " + bouncer.Name) - - t.SetColumnConfigs([]table.ColumnConfig{ - {Number: 1, AutoMerge: true}, - }) - - lastPull := "" - if bouncer.LastPull != nil { - lastPull = bouncer.LastPull.String() - } - - t.AppendRows([]table.Row{ - {"Created At", bouncer.CreatedAt}, - {"Last Update", bouncer.UpdatedAt}, - {"Revoked?", bouncer.Revoked}, - {"IP Address", bouncer.IPAddress}, - {"Type", bouncer.Type}, - {"Version", bouncer.Version}, - {"Last Pull", lastPull}, - {"Auth type", bouncer.AuthType}, - {"OS", getOSNameAndVersion(bouncer)}, - }) - - for _, ff := range getFeatureFlagList(bouncer) { - t.AppendRow(table.Row{"Feature Flags", ff}) - } - - io.WriteString(out, t.Render() + "\n") -} - -func (cli *cliBouncers) inspect(bouncer *ent.Bouncer) error { - out := color.Output - outputFormat := cli.cfg().Cscli.Output - - switch outputFormat { - case "human": - cli.inspectHuman(out, bouncer) - case "json": - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - - if err := enc.Encode(newBouncerInfo(bouncer)); err != nil { - return errors.New("failed to marshal") - } - - return nil - default: - return fmt.Errorf("output format '%s' not supported for this command", outputFormat) - } - - return nil -} - -func (cli *cliBouncers) newInspectCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "inspect [bouncer_name]", - Short: "inspect a bouncer by name", - Example: `cscli bouncers inspect "bouncer1"`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - ValidArgsFunction: cli.validBouncerID, - RunE: func(cmd *cobra.Command, args []string) error { - bouncerName := args[0] - - b, err := cli.db.Ent.Bouncer.Query(). - Where(bouncer.Name(bouncerName)). - Only(cmd.Context()) - if err != nil { - return fmt.Errorf("unable to read bouncer data '%s': %w", bouncerName, err) - } - - return cli.inspect(b) - }, - } - - return cmd -} diff --git a/cmd/crowdsec-cli/alerts.go b/cmd/crowdsec-cli/clialert/alerts.go similarity index 77% rename from cmd/crowdsec-cli/alerts.go rename to cmd/crowdsec-cli/clialert/alerts.go index 37f9ab435c7..5907d4a0fa8 100644 --- a/cmd/crowdsec-cli/alerts.go +++ b/cmd/crowdsec-cli/clialert/alerts.go @@ -1,4 +1,4 @@ -package main +package clialert import ( "context" @@ -24,19 +24,19 @@ import ( "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) -func DecisionsFromAlert(alert *models.Alert) string { +func decisionsFromAlert(alert *models.Alert) string { ret := "" decMap := make(map[string]int) for _, decision := range alert.Decisions { k := *decision.Type if *decision.Simulated { - k = fmt.Sprintf("(simul)%s", k) + k = "(simul)" + k } v := decMap[k] @@ -44,7 +44,7 @@ func DecisionsFromAlert(alert *models.Alert) string { } for _, key := range maptools.SortedKeys(decMap) { - if len(ret) > 0 { + if ret != "" { ret += " " } @@ -77,7 +77,7 @@ func (cli *cliAlerts) alertsToTable(alerts *models.GetAlertsResponse, printMachi *alertItem.Scenario, alertItem.Source.Cn, alertItem.Source.GetAsNumberName(), - DecisionsFromAlert(alertItem), + decisionsFromAlert(alertItem), *alertItem.StartAt, } if printMachine { @@ -183,12 +183,14 @@ func (cli *cliAlerts) displayOneAlert(alert *models.Alert, withDetail bool) erro return nil } +type configGetter func() *csconfig.Config + type cliAlerts struct { client *apiclient.ApiClient cfg configGetter } -func NewCLIAlerts(getconfig configGetter) *cliAlerts { +func New(getconfig configGetter) *cliAlerts { return &cliAlerts{ cfg: getconfig, } @@ -214,7 +216,6 @@ func (cli *cliAlerts) NewCommand() *cobra.Command { cli.client, err = apiclient.NewClient(&apiclient.Config{ MachineID: cfg.API.Client.Credentials.Login, Password: strfmt.Password(cfg.API.Client.Credentials.Password), - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -226,17 +227,19 @@ func (cli *cliAlerts) NewCommand() *cobra.Command { }, } - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewInspectCmd()) - cmd.AddCommand(cli.NewFlushCmd()) - cmd.AddCommand(cli.NewDeleteCmd()) + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newInspectCmd()) + cmd.AddCommand(cli.newFlushCmd()) + cmd.AddCommand(cli.newDeleteCmd()) return cmd } -func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error { - if err := manageCliDecisionAlerts(alertListFilter.IPEquals, alertListFilter.RangeEquals, - alertListFilter.ScopeEquals, alertListFilter.ValueEquals); err != nil { +func (cli *cliAlerts) list(ctx context.Context, alertListFilter apiclient.AlertsListOpts, limit *int, contained *bool, printMachine bool) error { + var err error + + *alertListFilter.ScopeEquals, err = SanitizeScope(*alertListFilter.ScopeEquals, *alertListFilter.IPEquals, *alertListFilter.RangeEquals) + if err != nil { return err } @@ -308,7 +311,7 @@ func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int, alertListFilter.Contains = new(bool) } - alerts, _, err := cli.client.Alerts.List(context.Background(), alertListFilter) + alerts, _, err := cli.client.Alerts.List(ctx, alertListFilter) if err != nil { return fmt.Errorf("unable to list alerts: %w", err) } @@ -320,7 +323,7 @@ func (cli *cliAlerts) list(alertListFilter apiclient.AlertsListOpts, limit *int, return nil } -func (cli *cliAlerts) NewListCmd() *cobra.Command { +func (cli *cliAlerts) newListCmd() *cobra.Command { alertListFilter := apiclient.AlertsListOpts{ ScopeEquals: new(string), ValueEquals: new(string), @@ -351,7 +354,7 @@ cscli alerts list --type ban`, Long: `List alerts with optional filters`, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { - return cli.list(alertListFilter, limit, contained, printMachine) + return cli.list(cmd.Context(), alertListFilter, limit, contained, printMachine) }, } @@ -374,58 +377,58 @@ cscli alerts list --type ban`, return cmd } -func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, ActiveDecision *bool, AlertDeleteAll bool, delAlertByID string, contained *bool) error { +func (cli *cliAlerts) delete(ctx context.Context, delFilter apiclient.AlertsDeleteOpts, activeDecision *bool, deleteAll bool, delAlertByID string, contained *bool) error { var err error - if !AlertDeleteAll { - if err = manageCliDecisionAlerts(alertDeleteFilter.IPEquals, alertDeleteFilter.RangeEquals, - alertDeleteFilter.ScopeEquals, alertDeleteFilter.ValueEquals); err != nil { + if !deleteAll { + *delFilter.ScopeEquals, err = SanitizeScope(*delFilter.ScopeEquals, *delFilter.IPEquals, *delFilter.RangeEquals) + if err != nil { return err } - if ActiveDecision != nil { - alertDeleteFilter.ActiveDecisionEquals = ActiveDecision + if activeDecision != nil { + delFilter.ActiveDecisionEquals = activeDecision } - if *alertDeleteFilter.ScopeEquals == "" { - alertDeleteFilter.ScopeEquals = nil + if *delFilter.ScopeEquals == "" { + delFilter.ScopeEquals = nil } - if *alertDeleteFilter.ValueEquals == "" { - alertDeleteFilter.ValueEquals = nil + if *delFilter.ValueEquals == "" { + delFilter.ValueEquals = nil } - if *alertDeleteFilter.ScenarioEquals == "" { - alertDeleteFilter.ScenarioEquals = nil + if *delFilter.ScenarioEquals == "" { + delFilter.ScenarioEquals = nil } - if *alertDeleteFilter.IPEquals == "" { - alertDeleteFilter.IPEquals = nil + if *delFilter.IPEquals == "" { + delFilter.IPEquals = nil } - if *alertDeleteFilter.RangeEquals == "" { - alertDeleteFilter.RangeEquals = nil + if *delFilter.RangeEquals == "" { + delFilter.RangeEquals = nil } if contained != nil && *contained { - alertDeleteFilter.Contains = new(bool) + delFilter.Contains = new(bool) } limit := 0 - alertDeleteFilter.Limit = &limit + delFilter.Limit = &limit } else { limit := 0 - alertDeleteFilter = apiclient.AlertsDeleteOpts{Limit: &limit} + delFilter = apiclient.AlertsDeleteOpts{Limit: &limit} } var alerts *models.DeleteAlertsResponse if delAlertByID == "" { - alerts, _, err = cli.client.Alerts.Delete(context.Background(), alertDeleteFilter) + alerts, _, err = cli.client.Alerts.Delete(ctx, delFilter) if err != nil { return fmt.Errorf("unable to delete alerts: %w", err) } } else { - alerts, _, err = cli.client.Alerts.DeleteOne(context.Background(), delAlertByID) + alerts, _, err = cli.client.Alerts.DeleteOne(ctx, delAlertByID) if err != nil { return fmt.Errorf("unable to delete alert: %w", err) } @@ -436,14 +439,14 @@ func (cli *cliAlerts) delete(alertDeleteFilter apiclient.AlertsDeleteOpts, Activ return nil } -func (cli *cliAlerts) NewDeleteCmd() *cobra.Command { +func (cli *cliAlerts) newDeleteCmd() *cobra.Command { var ( - ActiveDecision *bool - AlertDeleteAll bool + activeDecision *bool + deleteAll bool delAlertByID string ) - alertDeleteFilter := apiclient.AlertsDeleteOpts{ + delFilter := apiclient.AlertsDeleteOpts{ ScopeEquals: new(string), ValueEquals: new(string), ScenarioEquals: new(string), @@ -462,14 +465,14 @@ cscli alerts delete --range 1.2.3.0/24 cscli alerts delete -s crowdsecurity/ssh-bf"`, DisableAutoGenTag: true, Aliases: []string{"remove"}, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, PreRunE: func(cmd *cobra.Command, _ []string) error { - if AlertDeleteAll { + if deleteAll { return nil } - if *alertDeleteFilter.ScopeEquals == "" && *alertDeleteFilter.ValueEquals == "" && - *alertDeleteFilter.ScenarioEquals == "" && *alertDeleteFilter.IPEquals == "" && - *alertDeleteFilter.RangeEquals == "" && delAlertByID == "" { + if *delFilter.ScopeEquals == "" && *delFilter.ValueEquals == "" && + *delFilter.ScenarioEquals == "" && *delFilter.IPEquals == "" && + *delFilter.RangeEquals == "" && delAlertByID == "" { _ = cmd.Usage() return errors.New("at least one filter or --all must be specified") } @@ -477,25 +480,25 @@ cscli alerts delete -s crowdsecurity/ssh-bf"`, return nil }, RunE: func(cmd *cobra.Command, _ []string) error { - return cli.delete(alertDeleteFilter, ActiveDecision, AlertDeleteAll, delAlertByID, contained) + return cli.delete(cmd.Context(), delFilter, activeDecision, deleteAll, delAlertByID, contained) }, } flags := cmd.Flags() flags.SortFlags = false - flags.StringVar(alertDeleteFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)") - flags.StringVarP(alertDeleteFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") - flags.StringVarP(alertDeleteFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") - flags.StringVarP(alertDeleteFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") - flags.StringVarP(alertDeleteFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") + flags.StringVar(delFilter.ScopeEquals, "scope", "", "the scope (ie. ip,range)") + flags.StringVarP(delFilter.ValueEquals, "value", "v", "", "the value to match for in the specified scope") + flags.StringVarP(delFilter.ScenarioEquals, "scenario", "s", "", "the scenario (ie. crowdsecurity/ssh-bf)") + flags.StringVarP(delFilter.IPEquals, "ip", "i", "", "Source ip (shorthand for --scope ip --value )") + flags.StringVarP(delFilter.RangeEquals, "range", "r", "", "Range source ip (shorthand for --scope range --value )") flags.StringVar(&delAlertByID, "id", "", "alert ID") - flags.BoolVarP(&AlertDeleteAll, "all", "a", false, "delete all alerts") + flags.BoolVarP(&deleteAll, "all", "a", false, "delete all alerts") flags.BoolVar(contained, "contained", false, "query decisions contained by range") return cmd } -func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error { +func (cli *cliAlerts) inspect(ctx context.Context, details bool, alertIDs ...string) error { cfg := cli.cfg() for _, alertID := range alertIDs { @@ -504,7 +507,7 @@ func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error { return fmt.Errorf("bad alert id %s", alertID) } - alert, _, err := cli.client.Alerts.GetByID(context.Background(), id) + alert, _, err := cli.client.Alerts.GetByID(ctx, id) if err != nil { return fmt.Errorf("can't find alert with id %s: %w", alertID, err) } @@ -518,14 +521,14 @@ func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error { case "json": data, err := json.MarshalIndent(alert, "", " ") if err != nil { - return fmt.Errorf("unable to marshal alert with id %s: %w", alertID, err) + return fmt.Errorf("unable to serialize alert with id %s: %w", alertID, err) } fmt.Printf("%s\n", string(data)) case "raw": data, err := yaml.Marshal(alert) if err != nil { - return fmt.Errorf("unable to marshal alert with id %s: %w", alertID, err) + return fmt.Errorf("unable to serialize alert with id %s: %w", alertID, err) } fmt.Println(string(data)) @@ -535,7 +538,7 @@ func (cli *cliAlerts) inspect(details bool, alertIDs ...string) error { return nil } -func (cli *cliAlerts) NewInspectCmd() *cobra.Command { +func (cli *cliAlerts) newInspectCmd() *cobra.Command { var details bool cmd := &cobra.Command{ @@ -548,7 +551,7 @@ func (cli *cliAlerts) NewInspectCmd() *cobra.Command { _ = cmd.Help() return errors.New("missing alert_id") } - return cli.inspect(details, args...) + return cli.inspect(cmd.Context(), details, args...) }, } @@ -558,7 +561,7 @@ func (cli *cliAlerts) NewInspectCmd() *cobra.Command { return cmd } -func (cli *cliAlerts) NewFlushCmd() *cobra.Command { +func (cli *cliAlerts) newFlushCmd() *cobra.Command { var ( maxItems int maxAge string @@ -572,15 +575,17 @@ func (cli *cliAlerts) NewFlushCmd() *cobra.Command { DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { cfg := cli.cfg() + ctx := cmd.Context() + if err := require.LAPI(cfg); err != nil { return err } - db, err := require.DBClient(cmd.Context(), cfg.DbConfig) + db, err := require.DBClient(ctx, cfg.DbConfig) if err != nil { return err } log.Info("Flushing alerts. !! This may take a long time !!") - err = db.FlushAlerts(maxAge, maxItems) + err = db.FlushAlerts(ctx, maxAge, maxItems) if err != nil { return fmt.Errorf("unable to flush alerts: %w", err) } diff --git a/cmd/crowdsec-cli/clialert/sanitize.go b/cmd/crowdsec-cli/clialert/sanitize.go new file mode 100644 index 00000000000..87b110649da --- /dev/null +++ b/cmd/crowdsec-cli/clialert/sanitize.go @@ -0,0 +1,26 @@ +package clialert + +import ( + "fmt" + "net" + + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +// SanitizeScope validates ip and range and sets the scope accordingly to our case convention. +func SanitizeScope(scope, ip, ipRange string) (string, error) { + if ipRange != "" { + _, _, err := net.ParseCIDR(ipRange) + if err != nil { + return "", fmt.Errorf("%s is not a valid range", ipRange) + } + } + + if ip != "" { + if net.ParseIP(ip) == nil { + return "", fmt.Errorf("%s is not a valid ip", ip) + } + } + + return types.NormalizeScope(scope), nil +} diff --git a/cmd/crowdsec-cli/alerts_table.go b/cmd/crowdsec-cli/clialert/table.go similarity index 97% rename from cmd/crowdsec-cli/alerts_table.go rename to cmd/crowdsec-cli/clialert/table.go index 29383457ced..1416e1e435c 100644 --- a/cmd/crowdsec-cli/alerts_table.go +++ b/cmd/crowdsec-cli/clialert/table.go @@ -1,4 +1,4 @@ -package main +package clialert import ( "fmt" @@ -38,7 +38,7 @@ func alertsTable(out io.Writer, wantColor string, alerts *models.GetAlertsRespon *alertItem.Scenario, alertItem.Source.Cn, alertItem.Source.GetAsNumberName(), - DecisionsFromAlert(alertItem), + decisionsFromAlert(alertItem), *alertItem.StartAt, } diff --git a/cmd/crowdsec-cli/clibouncer/add.go b/cmd/crowdsec-cli/clibouncer/add.go new file mode 100644 index 00000000000..7cc74e45fba --- /dev/null +++ b/cmd/crowdsec-cli/clibouncer/add.go @@ -0,0 +1,72 @@ +package clibouncer + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/spf13/cobra" + + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) error { + var err error + + keyLength := 32 + + if key == "" { + key, err = middlewares.GenerateAPIKey(keyLength) + if err != nil { + return fmt.Errorf("unable to generate api key: %w", err) + } + } + + _, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType, false) + if err != nil { + return fmt.Errorf("unable to create bouncer: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + fmt.Printf("API key for '%s':\n\n", bouncerName) + fmt.Printf(" %s\n\n", key) + fmt.Print("Please keep this key since you will not be able to retrieve it!\n") + case "raw": + fmt.Print(key) + case "json": + j, err := json.Marshal(key) + if err != nil { + return errors.New("unable to serialize api key") + } + + fmt.Print(string(j)) + } + + return nil +} + +func (cli *cliBouncers) newAddCmd() *cobra.Command { + var key string + + cmd := &cobra.Command{ + Use: "add MyBouncerName", + Short: "add a single bouncer to the database", + Example: `cscli bouncers add MyBouncerName +cscli bouncers add MyBouncerName --key `, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args[0], key) + }, + } + + flags := cmd.Flags() + flags.StringP("length", "l", "", "length of the api key") + _ = flags.MarkDeprecated("length", "use --key instead") + flags.StringVarP(&key, "key", "k", "", "api key for the bouncer") + + return cmd +} diff --git a/cmd/crowdsec-cli/clibouncer/bouncers.go b/cmd/crowdsec-cli/clibouncer/bouncers.go new file mode 100644 index 00000000000..2b0a3556873 --- /dev/null +++ b/cmd/crowdsec-cli/clibouncer/bouncers.go @@ -0,0 +1,135 @@ +package clibouncer + +import ( + "slices" + "strings" + "time" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" +) + +type configGetter = func() *csconfig.Config + +type cliBouncers struct { + db *database.Client + cfg configGetter +} + +func New(cfg configGetter) *cliBouncers { + return &cliBouncers{ + cfg: cfg, + } +} + +func (cli *cliBouncers) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "bouncers [action]", + Short: "Manage bouncers [requires local API]", + Long: `To list/add/delete/prune bouncers. +Note: This command requires database direct access, so is intended to be run on Local API/master. +`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"bouncer"}, + DisableAutoGenTag: true, + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + + cfg := cli.cfg() + + if err = require.LAPI(cfg); err != nil { + return err + } + + cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) + if err != nil { + return err + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newAddCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + cmd.AddCommand(cli.newPruneCmd()) + cmd.AddCommand(cli.newInspectCmd()) + + return cmd +} + +// bouncerInfo contains only the data we want for inspect/list +type bouncerInfo struct { + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Name string `json:"name"` + Revoked bool `json:"revoked"` + IPAddress string `json:"ip_address"` + Type string `json:"type"` + Version string `json:"version"` + LastPull *time.Time `json:"last_pull"` + AuthType string `json:"auth_type"` + OS string `json:"os,omitempty"` + Featureflags []string `json:"featureflags,omitempty"` + AutoCreated bool `json:"auto_created"` +} + +func newBouncerInfo(b *ent.Bouncer) bouncerInfo { + return bouncerInfo{ + CreatedAt: b.CreatedAt, + UpdatedAt: b.UpdatedAt, + Name: b.Name, + Revoked: b.Revoked, + IPAddress: b.IPAddress, + Type: b.Type, + Version: b.Version, + LastPull: b.LastPull, + AuthType: b.AuthType, + OS: clientinfo.GetOSNameAndVersion(b), + Featureflags: clientinfo.GetFeatureFlagList(b), + AutoCreated: b.AutoCreated, + } +} + +// validBouncerID returns a list of bouncer IDs for command completion +func (cli *cliBouncers) validBouncerID(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + var err error + + cfg := cli.cfg() + ctx := cmd.Context() + + // need to load config and db because PersistentPreRunE is not called for completions + + if err = require.LAPI(cfg); err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + cli.db, err = require.DBClient(ctx, cfg.DbConfig) + if err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + bouncers, err := cli.db.ListBouncers(ctx) + if err != nil { + cobra.CompError("unable to list bouncers " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + ret := []string{} + + for _, bouncer := range bouncers { + if strings.Contains(bouncer.Name, toComplete) && !slices.Contains(args, bouncer.Name) { + ret = append(ret, bouncer.Name) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} diff --git a/cmd/crowdsec-cli/clibouncer/delete.go b/cmd/crowdsec-cli/clibouncer/delete.go new file mode 100644 index 00000000000..33419f483b6 --- /dev/null +++ b/cmd/crowdsec-cli/clibouncer/delete.go @@ -0,0 +1,99 @@ +package clibouncer + +import ( + "context" + "errors" + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func (cli *cliBouncers) findParentBouncer(bouncerName string, bouncers []*ent.Bouncer) (string, error) { + bouncerPrefix := strings.Split(bouncerName, "@")[0] + for _, bouncer := range bouncers { + if strings.HasPrefix(bouncer.Name, bouncerPrefix) && !bouncer.AutoCreated { + return bouncer.Name, nil + } + } + + return "", errors.New("no parent bouncer found") +} + +func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error { + allBouncers, err := cli.db.ListBouncers(ctx) + if err != nil { + return fmt.Errorf("unable to list bouncers: %w", err) + } + for _, bouncerName := range bouncers { + bouncer, err := cli.db.SelectBouncerByName(ctx, bouncerName) + if err != nil { + var notFoundErr *ent.NotFoundError + if ignoreMissing && errors.As(err, ¬FoundErr) { + continue + } + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) + } + + // For TLS bouncers, always delete them, they have no parents + if bouncer.AuthType == types.TlsAuthType { + if err := cli.db.DeleteBouncer(ctx, bouncerName); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) + } + continue + } + + if bouncer.AutoCreated { + parentBouncer, err := cli.findParentBouncer(bouncerName, allBouncers) + if err != nil { + log.Errorf("bouncer '%s' is auto-created, but couldn't find a parent bouncer", err) + continue + } + log.Warnf("bouncer '%s' is auto-created and cannot be deleted, delete parent bouncer %s instead", bouncerName, parentBouncer) + continue + } + //Try to find all child bouncers and delete them + for _, childBouncer := range allBouncers { + if strings.HasPrefix(childBouncer.Name, bouncerName+"@") && childBouncer.AutoCreated { + if err := cli.db.DeleteBouncer(ctx, childBouncer.Name); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", childBouncer.Name, err) + } + log.Infof("bouncer '%s' deleted successfully", childBouncer.Name) + } + } + + if err := cli.db.DeleteBouncer(ctx, bouncerName); err != nil { + return fmt.Errorf("unable to delete bouncer %s: %w", bouncerName, err) + } + + log.Infof("bouncer '%s' deleted successfully", bouncerName) + } + + return nil +} + +func (cli *cliBouncers) newDeleteCmd() *cobra.Command { + var ignoreMissing bool + + cmd := &cobra.Command{ + Use: "delete MyBouncerName", + Short: "delete bouncer(s) from the database", + Example: `cscli bouncers delete "bouncer1" "bouncer2"`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"remove"}, + DisableAutoGenTag: true, + ValidArgsFunction: cli.validBouncerID, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&ignoreMissing, "ignore-missing", false, "don't print errors if one or more bouncers don't exist") + + return cmd +} diff --git a/cmd/crowdsec-cli/clibouncer/inspect.go b/cmd/crowdsec-cli/clibouncer/inspect.go new file mode 100644 index 00000000000..b62344baa9b --- /dev/null +++ b/cmd/crowdsec-cli/clibouncer/inspect.go @@ -0,0 +1,99 @@ +package clibouncer + +import ( + "encoding/json" + "errors" + "fmt" + "io" + + "github.com/fatih/color" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" +) + +func (cli *cliBouncers) inspectHuman(out io.Writer, bouncer *ent.Bouncer) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + + t.SetTitle("Bouncer: " + bouncer.Name) + + t.SetColumnConfigs([]table.ColumnConfig{ + {Number: 1, AutoMerge: true}, + }) + + lastPull := "" + if bouncer.LastPull != nil { + lastPull = bouncer.LastPull.String() + } + + t.AppendRows([]table.Row{ + {"Created At", bouncer.CreatedAt}, + {"Last Update", bouncer.UpdatedAt}, + {"Revoked?", bouncer.Revoked}, + {"IP Address", bouncer.IPAddress}, + {"Type", bouncer.Type}, + {"Version", bouncer.Version}, + {"Last Pull", lastPull}, + {"Auth type", bouncer.AuthType}, + {"OS", clientinfo.GetOSNameAndVersion(bouncer)}, + {"Auto Created", bouncer.AutoCreated}, + }) + + for _, ff := range clientinfo.GetFeatureFlagList(bouncer) { + t.AppendRow(table.Row{"Feature Flags", ff}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliBouncers) inspect(bouncer *ent.Bouncer) error { + out := color.Output + outputFormat := cli.cfg().Cscli.Output + + switch outputFormat { + case "human": + cli.inspectHuman(out, bouncer) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(newBouncerInfo(bouncer)); err != nil { + return errors.New("failed to serialize") + } + + return nil + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} + +func (cli *cliBouncers) newInspectCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "inspect [bouncer_name]", + Short: "inspect a bouncer by name", + Example: `cscli bouncers inspect "bouncer1"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: cli.validBouncerID, + RunE: func(cmd *cobra.Command, args []string) error { + bouncerName := args[0] + + b, err := cli.db.Ent.Bouncer.Query(). + Where(bouncer.Name(bouncerName)). + Only(cmd.Context()) + if err != nil { + return fmt.Errorf("unable to read bouncer data '%s': %w", bouncerName, err) + } + + return cli.inspect(b) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clibouncer/list.go b/cmd/crowdsec-cli/clibouncer/list.go new file mode 100644 index 00000000000..a13ca994e1e --- /dev/null +++ b/cmd/crowdsec-cli/clibouncer/list.go @@ -0,0 +1,117 @@ +package clibouncer + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "time" + + "github.com/fatih/color" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +func (cli *cliBouncers) listHuman(out io.Writer, bouncers ent.Bouncers) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "IP Address", "Valid", "Last API pull", "Type", "Version", "Auth Type"}) + + for _, b := range bouncers { + revoked := emoji.CheckMark + if b.Revoked { + revoked = emoji.Prohibited + } + + lastPull := "" + if b.LastPull != nil { + lastPull = b.LastPull.Format(time.RFC3339) + } + + t.AppendRow(table.Row{b.Name, b.IPAddress, revoked, lastPull, b.Type, b.Version, b.AuthType}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliBouncers) listCSV(out io.Writer, bouncers ent.Bouncers) error { + csvwriter := csv.NewWriter(out) + + if err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version", "auth_type"}); err != nil { + return fmt.Errorf("failed to write raw header: %w", err) + } + + for _, b := range bouncers { + valid := "validated" + if b.Revoked { + valid = "pending" + } + + lastPull := "" + if b.LastPull != nil { + lastPull = b.LastPull.Format(time.RFC3339) + } + + if err := csvwriter.Write([]string{b.Name, b.IPAddress, valid, lastPull, b.Type, b.Version, b.AuthType}); err != nil { + return fmt.Errorf("failed to write raw: %w", err) + } + } + + csvwriter.Flush() + + return nil +} + +func (cli *cliBouncers) List(ctx context.Context, out io.Writer, db *database.Client) error { + // XXX: must use the provided db object, the one in the struct might be nil + // (calling List directly skips the PersistentPreRunE) + + bouncers, err := db.ListBouncers(ctx) + if err != nil { + return fmt.Errorf("unable to list bouncers: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + cli.listHuman(out, bouncers) + case "json": + info := make([]bouncerInfo, 0, len(bouncers)) + for _, b := range bouncers { + info = append(info, newBouncerInfo(b)) + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(info); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + return cli.listCSV(out, bouncers) + } + + return nil +} + +func (cli *cliBouncers) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list all bouncers within the database", + Example: `cscli bouncers list`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clibouncer/prune.go b/cmd/crowdsec-cli/clibouncer/prune.go new file mode 100644 index 00000000000..754e0898a3b --- /dev/null +++ b/cmd/crowdsec-cli/clibouncer/prune.go @@ -0,0 +1,85 @@ +package clibouncer + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/fatih/color" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/ask" +) + +func (cli *cliBouncers) prune(ctx context.Context, duration time.Duration, force bool) error { + if duration < 2*time.Minute { + if yes, err := ask.YesNo( + "The duration you provided is less than 2 minutes. "+ + "This may remove active bouncers. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + bouncers, err := cli.db.QueryBouncersInactiveSince(ctx, time.Now().UTC().Add(-duration)) + if err != nil { + return fmt.Errorf("unable to query bouncers: %w", err) + } + + if len(bouncers) == 0 { + fmt.Println("No bouncers to prune.") + return nil + } + + cli.listHuman(color.Output, bouncers) + + if !force { + if yes, err := ask.YesNo( + "You are about to PERMANENTLY remove the above bouncers from the database. "+ + "These will NOT be recoverable. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + deleted, err := cli.db.BulkDeleteBouncers(ctx, bouncers) + if err != nil { + return fmt.Errorf("unable to prune bouncers: %w", err) + } + + fmt.Fprintf(os.Stderr, "Successfully deleted %d bouncers\n", deleted) + + return nil +} + +func (cli *cliBouncers) newPruneCmd() *cobra.Command { + var ( + duration time.Duration + force bool + ) + + const defaultDuration = 60 * time.Minute + + cmd := &cobra.Command{ + Use: "prune", + Short: "prune multiple bouncers from the database", + Args: cobra.NoArgs, + DisableAutoGenTag: true, + Example: `cscli bouncers prune -d 45m +cscli bouncers prune -d 45m --force`, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, force) + }, + } + + flags := cmd.Flags() + flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since last pull") + flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") + + return cmd +} diff --git a/cmd/crowdsec-cli/capi.go b/cmd/crowdsec-cli/clicapi/capi.go similarity index 57% rename from cmd/crowdsec-cli/capi.go rename to cmd/crowdsec-cli/clicapi/capi.go index 1888aa3545a..61d59836fdd 100644 --- a/cmd/crowdsec-cli/capi.go +++ b/cmd/crowdsec-cli/clicapi/capi.go @@ -1,36 +1,36 @@ -package main +package clicapi import ( "context" "errors" "fmt" + "io" "net/url" "os" + "github.com/fatih/color" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) -const ( - CAPIBaseURL = "https://api.crowdsec.net/" - CAPIURLPrefix = "v3" -) +type configGetter = func() *csconfig.Config type cliCapi struct { cfg configGetter } -func NewCLICapi(cfg configGetter) *cliCapi { +func New(cfg configGetter) *cliCapi { return &cliCapi{ cfg: cfg, } @@ -58,27 +58,26 @@ func (cli *cliCapi) NewCommand() *cobra.Command { return cmd } -func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error { +func (cli *cliCapi) register(ctx context.Context, capiUserPrefix string, outputFile string) error { cfg := cli.cfg() - capiUser, err := generateID(capiUserPrefix) + capiUser, err := idgen.GenerateMachineID(capiUserPrefix) if err != nil { return fmt.Errorf("unable to generate machine id: %w", err) } - password := strfmt.Password(generatePassword(passwordLength)) + password := strfmt.Password(idgen.GeneratePassword(idgen.PasswordLength)) apiurl, err := url.Parse(types.CAPIBaseURL) if err != nil { return fmt.Errorf("unable to parse api url %s: %w", types.CAPIBaseURL, err) } - _, err = apiclient.RegisterClient(&apiclient.Config{ + _, err = apiclient.RegisterClient(ctx, &apiclient.Config{ MachineID: capiUser, Password: password, - UserAgent: cwversion.UserAgent(), URL: apiurl, - VersionPrefix: CAPIURLPrefix, + VersionPrefix: "v3", }, nil) if err != nil { return fmt.Errorf("api client register ('%s'): %w", types.CAPIBaseURL, err) @@ -105,7 +104,7 @@ func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error { apiConfigDump, err := yaml.Marshal(apiCfg) if err != nil { - return fmt.Errorf("unable to marshal api credentials: %w", err) + return fmt.Errorf("unable to serialize api credentials: %w", err) } if dumpFile != "" { @@ -119,7 +118,7 @@ func (cli *cliCapi) register(capiUserPrefix string, outputFile string) error { fmt.Println(string(apiConfigDump)) } - log.Warning(ReloadMessage()) + log.Warning(reload.Message) return nil } @@ -135,8 +134,8 @@ func (cli *cliCapi) newRegisterCmd() *cobra.Command { Short: "Register to Central API (CAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.register(capiUserPrefix, outputFile) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.register(cmd.Context(), capiUserPrefix, outputFile) }, } @@ -148,21 +147,17 @@ func (cli *cliCapi) newRegisterCmd() *cobra.Command { return cmd } -// QueryCAPIStatus checks if the Local API is reachable, and if the credentials are correct. It then checks if the instance is enrolle in the console. -func QueryCAPIStatus(hub *cwhub.Hub, credURL string, login string, password string) (bool, bool, error) { - +// queryCAPIStatus checks if the Central API is reachable, and if the credentials are correct. It then checks if the instance is enrolle in the console. +func queryCAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login string, password string) (bool, bool, error) { apiURL, err := url.Parse(credURL) if err != nil { - return false, false, fmt.Errorf("parsing api url: %w", err) + return false, false, err } - scenarios, err := hub.GetInstalledNamesByType(cwhub.SCENARIOS) - if err != nil { - return false, false, fmt.Errorf("failed to get scenarios: %w", err) - } + itemsForAPI := hub.GetInstalledListForAPI() - if len(scenarios) == 0 { - return false, false, errors.New("no scenarios installed, abort") + if len(itemsForAPI) == 0 { + return false, false, errors.New("no scenarios or appsec-rules installed, abort") } passwd := strfmt.Password(password) @@ -170,31 +165,17 @@ func QueryCAPIStatus(hub *cwhub.Hub, credURL string, login string, password stri client, err := apiclient.NewClient(&apiclient.Config{ MachineID: login, Password: passwd, - Scenarios: scenarios, - UserAgent: cwversion.UserAgent(), + Scenarios: itemsForAPI, URL: apiURL, - //I don't believe papi is neede to check enrollement - //PapiURL: papiURL, + // I don't believe papi is neede to check enrollement + // PapiURL: papiURL, VersionPrefix: "v3", - UpdateScenario: func() ([]string, error) { - l_scenarios, err := hub.GetInstalledNamesByType(cwhub.SCENARIOS) - if err != nil { - return nil, err - } - appsecRules, err := hub.GetInstalledNamesByType(cwhub.APPSEC_RULES) - if err != nil { - return nil, err - } - ret := make([]string, 0, len(l_scenarios)+len(appsecRules)) - ret = append(ret, l_scenarios...) - ret = append(ret, appsecRules...) - - return ret, nil + UpdateScenario: func(_ context.Context) ([]string, error) { + return itemsForAPI, nil }, }) - if err != nil { - return false, false, fmt.Errorf("new client api: %w", err) + return false, false, err } pw := strfmt.Password(password) @@ -202,10 +183,10 @@ func QueryCAPIStatus(hub *cwhub.Hub, credURL string, login string, password stri t := models.WatcherAuthRequest{ MachineID: &login, Password: &pw, - Scenarios: scenarios, + Scenarios: itemsForAPI, } - authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), t) + authResp, _, err := client.Auth.AuthenticateWatcher(ctx, t) if err != nil { return false, false, err } @@ -215,11 +196,11 @@ func QueryCAPIStatus(hub *cwhub.Hub, credURL string, login string, password stri if client.IsEnrolled() { return true, true, nil } - return true, false, nil + return true, false, nil } -func (cli *cliCapi) status() error { +func (cli *cliCapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) error { cfg := cli.cfg() if err := require.CAPIRegistered(cfg); err != nil { @@ -228,25 +209,43 @@ func (cli *cliCapi) status() error { cred := cfg.API.Server.OnlineClient.Credentials - hub, err := require.Hub(cfg, nil, nil) + fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Server.OnlineClient.CredentialsFilePath) + fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL) + + auth, enrolled, err := queryCAPIStatus(ctx, hub, cred.URL, cred.Login, cred.Password) if err != nil { - return err + return fmt.Errorf("failed to authenticate to Central API (CAPI): %w", err) } - log.Infof("Loaded credentials from %s", cfg.API.Server.OnlineClient.CredentialsFilePath) - log.Infof("Trying to authenticate with username %s on %s", cred.Login, cred.URL) + if auth { + fmt.Fprint(out, "You can successfully interact with Central API (CAPI)\n") + } - auth, enrolled, err := QueryCAPIStatus(hub, cred.URL, cred.Login, cred.Password) + if enrolled { + fmt.Fprint(out, "Your instance is enrolled in the console\n") + } - if err != nil { - return fmt.Errorf("CAPI: failed to authenticate to Central API (CAPI): %s", err) + switch *cfg.API.Server.OnlineClient.Sharing { + case true: + fmt.Fprint(out, "Sharing signals is enabled\n") + case false: + fmt.Fprint(out, "Sharing signals is disabled\n") } - if auth { - log.Info("You can successfully interact with Central API (CAPI)") + + switch *cfg.API.Server.OnlineClient.PullConfig.Community { + case true: + fmt.Fprint(out, "Pulling community blocklist is enabled\n") + case false: + fmt.Fprint(out, "Pulling community blocklist is disabled\n") } - if enrolled { - log.Info("Your instance is enrolled in the console") + + switch *cfg.API.Server.OnlineClient.PullConfig.Blocklists { + case true: + fmt.Fprint(out, "Pulling blocklists from the console is enabled\n") + case false: + fmt.Fprint(out, "Pulling blocklists from the console is disabled\n") } + return nil } @@ -256,8 +255,13 @@ func (cli *cliCapi) newStatusCmd() *cobra.Command { Short: "Check status with the Central API (CAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.status() + RunE: func(cmd *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + return cli.Status(cmd.Context(), color.Output, hub) }, } diff --git a/cmd/crowdsec-cli/console.go b/cmd/crowdsec-cli/cliconsole/console.go similarity index 79% rename from cmd/crowdsec-cli/console.go rename to cmd/crowdsec-cli/cliconsole/console.go index 979c9f0ea60..448ddcee7fa 100644 --- a/cmd/crowdsec-cli/console.go +++ b/cmd/crowdsec-cli/cliconsole/console.go @@ -1,4 +1,4 @@ -package main +package cliconsole import ( "context" @@ -20,19 +20,20 @@ import ( "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/types" ) +type configGetter func() *csconfig.Config + type cliConsole struct { cfg configGetter } -func NewCLIConsole(cfg configGetter) *cliConsole { +func New(cfg configGetter) *cliConsole { return &cliConsole{ cfg: cfg, } @@ -65,7 +66,7 @@ func (cli *cliConsole) NewCommand() *cobra.Command { return cmd } -func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []string, opts []string) error { +func (cli *cliConsole) enroll(ctx context.Context, key string, name string, overwrite bool, tags []string, opts []string) error { cfg := cli.cfg() password := strfmt.Password(cfg.API.Server.OnlineClient.Credentials.Password) @@ -74,20 +75,6 @@ func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []st return fmt.Errorf("could not parse CAPI URL: %w", err) } - hub, err := require.Hub(cfg, nil, nil) - if err != nil { - return err - } - - scenarios, err := hub.GetInstalledNamesByType(cwhub.SCENARIOS) - if err != nil { - return fmt.Errorf("failed to get installed scenarios: %w", err) - } - - if len(scenarios) == 0 { - scenarios = make([]string, 0) - } - enableOpts := []string{csconfig.SEND_MANUAL_SCENARIOS, csconfig.SEND_TAINTED_SCENARIOS} if len(opts) != 0 { @@ -100,23 +87,25 @@ func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []st } for _, availableOpt := range csconfig.CONSOLE_CONFIGS { - if opt == availableOpt { - valid = true - enable := true - - for _, enabledOpt := range enableOpts { - if opt == enabledOpt { - enable = false - continue - } - } + if opt != availableOpt { + continue + } + + valid = true + enable := true - if enable { - enableOpts = append(enableOpts, opt) + for _, enabledOpt := range enableOpts { + if opt == enabledOpt { + enable = false + continue } + } - break + if enable { + enableOpts = append(enableOpts, opt) } + + break } if !valid { @@ -125,16 +114,20 @@ func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []st } } + hub, err := require.Hub(cfg, nil, nil) + if err != nil { + return err + } + c, _ := apiclient.NewClient(&apiclient.Config{ MachineID: cli.cfg().API.Server.OnlineClient.Credentials.Login, Password: password, - Scenarios: scenarios, - UserAgent: cwversion.UserAgent(), + Scenarios: hub.GetInstalledListForAPI(), URL: apiURL, VersionPrefix: "v3", }) - resp, err := c.Auth.EnrollWatcher(context.Background(), key, name, tags, overwrite) + resp, err := c.Auth.EnrollWatcher(ctx, key, name, tags, overwrite) if err != nil { return fmt.Errorf("could not enroll instance: %w", err) } @@ -180,8 +173,8 @@ After running this command your will need to validate the enrollment in the weba valid options are : %s,all (see 'cscli console status' for details)`, strings.Join(csconfig.CONSOLE_CONFIGS, ",")), Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.enroll(args[0], name, overwrite, tags, opts) + RunE: func(cmd *cobra.Command, args []string) error { + return cli.enroll(cmd.Context(), args[0], name, overwrite, tags, opts) }, } @@ -221,7 +214,7 @@ Enable given information push to the central API. Allows to empower the console` log.Infof("%v have been enabled", args) } - log.Infof(ReloadMessage()) + log.Info(reload.Message) return nil }, @@ -255,7 +248,7 @@ Disable given information push to the central API.`, log.Infof("%v have been disabled", args) } - log.Infof(ReloadMessage()) + log.Info(reload.Message) return nil }, @@ -287,7 +280,7 @@ func (cli *cliConsole) newStatusCmd() *cobra.Command { } data, err := json.MarshalIndent(out, "", " ") if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } fmt.Println(string(data)) case "raw": @@ -325,7 +318,7 @@ func (cli *cliConsole) dumpConfig() error { out, err := yaml.Marshal(serverCfg.ConsoleConfig) if err != nil { - return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", serverCfg.ConsoleConfigPath, err) + return fmt.Errorf("while serializing ConsoleConfig (for %s): %w", serverCfg.ConsoleConfigPath, err) } if serverCfg.ConsoleConfigPath == "" { @@ -348,13 +341,8 @@ func (cli *cliConsole) setConsoleOpts(args []string, wanted bool) error { switch arg { case csconfig.CONSOLE_MANAGEMENT: /*for each flag check if it's already set before setting it*/ - if consoleCfg.ConsoleManagement != nil { - if *consoleCfg.ConsoleManagement == wanted { - log.Debugf("%s already set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - } else { - log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) - *consoleCfg.ConsoleManagement = wanted - } + if consoleCfg.ConsoleManagement != nil && *consoleCfg.ConsoleManagement == wanted { + log.Debugf("%s already set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) } else { log.Infof("%s set to %t", csconfig.CONSOLE_MANAGEMENT, wanted) consoleCfg.ConsoleManagement = ptr.Of(wanted) @@ -373,7 +361,7 @@ func (cli *cliConsole) setConsoleOpts(args []string, wanted bool) error { if changed { fileContent, err := yaml.Marshal(cfg.API.Server.OnlineClient.Credentials) if err != nil { - return fmt.Errorf("cannot marshal credentials: %w", err) + return fmt.Errorf("cannot serialize credentials: %w", err) } log.Infof("Updating credentials file: %s", cfg.API.Server.OnlineClient.CredentialsFilePath) @@ -386,52 +374,32 @@ func (cli *cliConsole) setConsoleOpts(args []string, wanted bool) error { } case csconfig.SEND_CUSTOM_SCENARIOS: /*for each flag check if it's already set before setting it*/ - if consoleCfg.ShareCustomScenarios != nil { - if *consoleCfg.ShareCustomScenarios == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) - *consoleCfg.ShareCustomScenarios = wanted - } + if consoleCfg.ShareCustomScenarios != nil && *consoleCfg.ShareCustomScenarios == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) } else { log.Infof("%s set to %t", csconfig.SEND_CUSTOM_SCENARIOS, wanted) consoleCfg.ShareCustomScenarios = ptr.Of(wanted) } case csconfig.SEND_TAINTED_SCENARIOS: /*for each flag check if it's already set before setting it*/ - if consoleCfg.ShareTaintedScenarios != nil { - if *consoleCfg.ShareTaintedScenarios == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) - *consoleCfg.ShareTaintedScenarios = wanted - } + if consoleCfg.ShareTaintedScenarios != nil && *consoleCfg.ShareTaintedScenarios == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) } else { log.Infof("%s set to %t", csconfig.SEND_TAINTED_SCENARIOS, wanted) consoleCfg.ShareTaintedScenarios = ptr.Of(wanted) } case csconfig.SEND_MANUAL_SCENARIOS: /*for each flag check if it's already set before setting it*/ - if consoleCfg.ShareManualDecisions != nil { - if *consoleCfg.ShareManualDecisions == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) - *consoleCfg.ShareManualDecisions = wanted - } + if consoleCfg.ShareManualDecisions != nil && *consoleCfg.ShareManualDecisions == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) } else { log.Infof("%s set to %t", csconfig.SEND_MANUAL_SCENARIOS, wanted) consoleCfg.ShareManualDecisions = ptr.Of(wanted) } case csconfig.SEND_CONTEXT: /*for each flag check if it's already set before setting it*/ - if consoleCfg.ShareContext != nil { - if *consoleCfg.ShareContext == wanted { - log.Debugf("%s already set to %t", csconfig.SEND_CONTEXT, wanted) - } else { - log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) - *consoleCfg.ShareContext = wanted - } + if consoleCfg.ShareContext != nil && *consoleCfg.ShareContext == wanted { + log.Debugf("%s already set to %t", csconfig.SEND_CONTEXT, wanted) } else { log.Infof("%s set to %t", csconfig.SEND_CONTEXT, wanted) consoleCfg.ShareContext = ptr.Of(wanted) diff --git a/cmd/crowdsec-cli/console_table.go b/cmd/crowdsec-cli/cliconsole/console_table.go similarity index 98% rename from cmd/crowdsec-cli/console_table.go rename to cmd/crowdsec-cli/cliconsole/console_table.go index 94976618573..8f17b97860a 100644 --- a/cmd/crowdsec-cli/console_table.go +++ b/cmd/crowdsec-cli/cliconsole/console_table.go @@ -1,4 +1,4 @@ -package main +package cliconsole import ( "io" diff --git a/cmd/crowdsec-cli/decisions.go b/cmd/crowdsec-cli/clidecision/decisions.go similarity index 90% rename from cmd/crowdsec-cli/decisions.go rename to cmd/crowdsec-cli/clidecision/decisions.go index d485c90254f..307cabffe51 100644 --- a/cmd/crowdsec-cli/decisions.go +++ b/cmd/crowdsec-cli/clidecision/decisions.go @@ -1,4 +1,4 @@ -package main +package clidecision import ( "context" @@ -17,8 +17,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clialert" "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -114,12 +115,14 @@ func (cli *cliDecisions) decisionsToTable(alerts *models.GetAlertsResponse, prin return nil } +type configGetter func() *csconfig.Config + type cliDecisions struct { client *apiclient.ApiClient cfg configGetter } -func NewCLIDecisions(cfg configGetter) *cliDecisions { +func New(cfg configGetter) *cliDecisions { return &cliDecisions{ cfg: cfg, } @@ -148,7 +151,6 @@ func (cli *cliDecisions) NewCommand() *cobra.Command { cli.client, err = apiclient.NewClient(&apiclient.Config{ MachineID: cfg.API.Client.Credentials.Login, Password: strfmt.Password(cfg.API.Client.Credentials.Password), - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -168,10 +170,11 @@ func (cli *cliDecisions) NewCommand() *cobra.Command { return cmd } -func (cli *cliDecisions) list(filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error { +func (cli *cliDecisions) list(ctx context.Context, filter apiclient.AlertsListOpts, NoSimu *bool, contained *bool, printMachine bool) error { var err error - /*take care of shorthand options*/ - if err = manageCliDecisionAlerts(filter.IPEquals, filter.RangeEquals, filter.ScopeEquals, filter.ValueEquals); err != nil { + + *filter.ScopeEquals, err = clialert.SanitizeScope(*filter.ScopeEquals, *filter.IPEquals, *filter.RangeEquals) + if err != nil { return err } @@ -246,7 +249,7 @@ func (cli *cliDecisions) list(filter apiclient.AlertsListOpts, NoSimu *bool, con filter.Contains = new(bool) } - alerts, _, err := cli.client.Alerts.List(context.Background(), filter) + alerts, _, err := cli.client.Alerts.List(ctx, filter) if err != nil { return fmt.Errorf("unable to retrieve decisions: %w", err) } @@ -287,10 +290,10 @@ cscli decisions list -r 1.2.3.0/24 cscli decisions list -s crowdsecurity/ssh-bf cscli decisions list --origin lists --scenario list_name `, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { - return cli.list(filter, NoSimu, contained, printMachine) + return cli.list(cmd.Context(), filter, NoSimu, contained, printMachine) }, } @@ -314,7 +317,7 @@ cscli decisions list --origin lists --scenario list_name return cmd } -func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error { +func (cli *cliDecisions) add(ctx context.Context, addIP, addRange, addDuration, addValue, addScope, addReason, addType string) error { alerts := models.AddAlertsRequest{} origin := types.CscliOrigin capacity := int32(0) @@ -326,8 +329,10 @@ func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, a stopAt := time.Now().UTC().Format(time.RFC3339) createdAt := time.Now().UTC().Format(time.RFC3339) - /*take care of shorthand options*/ - if err := manageCliDecisionAlerts(&addIP, &addRange, &addScope, &addValue); err != nil { + var err error + + addScope, err = clialert.SanitizeScope(addScope, addIP, addRange) + if err != nil { return err } @@ -381,7 +386,7 @@ func (cli *cliDecisions) add(addIP, addRange, addDuration, addValue, addScope, a } alerts = append(alerts, &alert) - _, _, err := cli.client.Alerts.Add(context.Background(), alerts) + _, _, err = cli.client.Alerts.Add(ctx, alerts) if err != nil { return err } @@ -411,10 +416,10 @@ cscli decisions add --ip 1.2.3.4 --duration 24h --type captcha cscli decisions add --scope username --value foobar `, /*TBD : fix long and example*/ - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { - return cli.add(addIP, addRange, addDuration, addValue, addScope, addReason, addType) + return cli.add(cmd.Context(), addIP, addRange, addDuration, addValue, addScope, addReason, addType) }, } @@ -431,11 +436,12 @@ cscli decisions add --scope username --value foobar return cmd } -func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error { +func (cli *cliDecisions) delete(ctx context.Context, delFilter apiclient.DecisionsDeleteOpts, delDecisionID string, contained *bool) error { var err error /*take care of shorthand options*/ - if err = manageCliDecisionAlerts(delFilter.IPEquals, delFilter.RangeEquals, delFilter.ScopeEquals, delFilter.ValueEquals); err != nil { + *delFilter.ScopeEquals, err = clialert.SanitizeScope(*delFilter.ScopeEquals, *delFilter.IPEquals, *delFilter.RangeEquals) + if err != nil { return err } @@ -474,7 +480,7 @@ func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDeci var decisions *models.DeleteDecisionResponse if delDecisionID == "" { - decisions, _, err = cli.client.Decisions.Delete(context.Background(), delFilter) + decisions, _, err = cli.client.Decisions.Delete(ctx, delFilter) if err != nil { return fmt.Errorf("unable to delete decisions: %w", err) } @@ -483,7 +489,7 @@ func (cli *cliDecisions) delete(delFilter apiclient.DecisionsDeleteOpts, delDeci return fmt.Errorf("id '%s' is not an integer: %w", delDecisionID, err) } - decisions, _, err = cli.client.Decisions.DeleteOne(context.Background(), delDecisionID) + decisions, _, err = cli.client.Decisions.DeleteOne(ctx, delDecisionID) if err != nil { return fmt.Errorf("unable to delete decision: %w", err) } @@ -537,8 +543,8 @@ cscli decisions delete --origin lists --scenario list_name return nil }, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.delete(delFilter, delDecisionID, contained) + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.delete(cmd.Context(), delFilter, delDecisionID, contained) }, } diff --git a/cmd/crowdsec-cli/decisions_import.go b/cmd/crowdsec-cli/clidecision/import.go similarity index 70% rename from cmd/crowdsec-cli/decisions_import.go rename to cmd/crowdsec-cli/clidecision/import.go index 338c1b7fb3e..5b34b74a250 100644 --- a/cmd/crowdsec-cli/decisions_import.go +++ b/cmd/crowdsec-cli/clidecision/import.go @@ -1,4 +1,4 @@ -package main +package clidecision import ( "bufio" @@ -67,65 +67,29 @@ func parseDecisionList(content []byte, format string) ([]decisionRaw, error) { return ret, nil } -func (cli *cliDecisions) runImport(cmd *cobra.Command, args []string) error { - flags := cmd.Flags() - - input, err := flags.GetString("input") - if err != nil { - return err - } - - defaultDuration, err := flags.GetString("duration") - if err != nil { - return err - } - - if defaultDuration == "" { - return errors.New("--duration cannot be empty") - } - - defaultScope, err := flags.GetString("scope") - if err != nil { - return err - } - - if defaultScope == "" { - return errors.New("--scope cannot be empty") - } - - defaultReason, err := flags.GetString("reason") - if err != nil { - return err - } - - if defaultReason == "" { - return errors.New("--reason cannot be empty") - } +func (cli *cliDecisions) import_(ctx context.Context, input string, duration string, scope string, reason string, type_ string, batch int, format string) error { + var ( + content []byte + fin *os.File + err error + ) - defaultType, err := flags.GetString("type") - if err != nil { - return err + if duration == "" { + return errors.New("default duration cannot be empty") } - if defaultType == "" { - return errors.New("--type cannot be empty") + if scope == "" { + return errors.New("default scope cannot be empty") } - batchSize, err := flags.GetInt("batch") - if err != nil { - return err + if reason == "" { + return errors.New("default reason cannot be empty") } - format, err := flags.GetString("format") - if err != nil { - return err + if type_ == "" { + return errors.New("default type cannot be empty") } - var ( - content []byte - fin *os.File - ) - // set format if the file has a json or csv extension if format == "" { if strings.HasSuffix(input, ".json") { @@ -167,23 +131,23 @@ func (cli *cliDecisions) runImport(cmd *cobra.Command, args []string) error { } if d.Duration == "" { - d.Duration = defaultDuration - log.Debugf("item %d: missing 'duration', using default '%s'", i, defaultDuration) + d.Duration = duration + log.Debugf("item %d: missing 'duration', using default '%s'", i, duration) } if d.Scenario == "" { - d.Scenario = defaultReason - log.Debugf("item %d: missing 'reason', using default '%s'", i, defaultReason) + d.Scenario = reason + log.Debugf("item %d: missing 'reason', using default '%s'", i, reason) } if d.Type == "" { - d.Type = defaultType - log.Debugf("item %d: missing 'type', using default '%s'", i, defaultType) + d.Type = type_ + log.Debugf("item %d: missing 'type', using default '%s'", i, type_) } if d.Scope == "" { - d.Scope = defaultScope - log.Debugf("item %d: missing 'scope', using default '%s'", i, defaultScope) + d.Scope = scope + log.Debugf("item %d: missing 'scope', using default '%s'", i, scope) } decisions[i] = &models.Decision{ @@ -201,7 +165,7 @@ func (cli *cliDecisions) runImport(cmd *cobra.Command, args []string) error { log.Infof("You are about to add %d decisions, this may take a while", len(decisions)) } - for _, chunk := range slicetools.Chunks(decisions, batchSize) { + for _, chunk := range slicetools.Chunks(decisions, batch) { log.Debugf("Processing chunk of %d decisions", len(chunk)) importAlert := models.Alert{ CreatedAt: time.Now().UTC().Format(time.RFC3339), @@ -224,7 +188,7 @@ func (cli *cliDecisions) runImport(cmd *cobra.Command, args []string) error { Decisions: chunk, } - _, _, err = cli.client.Alerts.Add(context.Background(), models.AddAlertsRequest{&importAlert}) + _, _, err = cli.client.Alerts.Add(ctx, models.AddAlertsRequest{&importAlert}) if err != nil { return err } @@ -236,12 +200,22 @@ func (cli *cliDecisions) runImport(cmd *cobra.Command, args []string) error { } func (cli *cliDecisions) newImportCmd() *cobra.Command { + var ( + input string + duration string + scope string + reason string + decisionType string + batch int + format string + ) + cmd := &cobra.Command{ Use: "import [options]", Short: "Import decisions from a file or pipe", Long: "expected format:\n" + "csv : any of duration,reason,scope,type,value, with a header line\n" + - "json :" + "`{" + `"duration" : "24h", "reason" : "my_scenario", "scope" : "ip", "type" : "ban", "value" : "x.y.z.z"` + "}`", + "json :" + "`{" + `"duration": "24h", "reason": "my_scenario", "scope": "ip", "type": "ban", "value": "x.y.z.z"` + "}`", Args: cobra.NoArgs, DisableAutoGenTag: true, Example: `decisions.csv: @@ -251,7 +225,7 @@ duration,scope,value $ cscli decisions import -i decisions.csv decisions.json: -[{"duration" : "4h", "scope" : "ip", "type" : "ban", "value" : "1.2.3.4"}] +[{"duration": "4h", "scope": "ip", "type": "ban", "value": "1.2.3.4"}] The file format is detected from the extension, but can be forced with the --format option which is required when reading from standard input. @@ -260,18 +234,20 @@ Raw values, standard input: $ echo "1.2.3.4" | cscli decisions import -i - --format values `, - RunE: cli.runImport, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.import_(cmd.Context(), input, duration, scope, reason, decisionType, batch, format) + }, } flags := cmd.Flags() flags.SortFlags = false - flags.StringP("input", "i", "", "Input file") - flags.StringP("duration", "d", "4h", "Decision duration: 1h,4h,30m") - flags.String("scope", types.Ip, "Decision scope: ip,range,username") - flags.StringP("reason", "R", "manual", "Decision reason: ") - flags.StringP("type", "t", "ban", "Decision type: ban,captcha,throttle") - flags.Int("batch", 0, "Split import in batches of N decisions") - flags.String("format", "", "Input format: 'json', 'csv' or 'values' (each line is a value, no headers)") + flags.StringVarP(&input, "input", "i", "", "Input file") + flags.StringVarP(&duration, "duration", "d", "4h", "Decision duration: 1h,4h,30m") + flags.StringVar(&scope, "scope", types.Ip, "Decision scope: ip,range,username") + flags.StringVarP(&reason, "reason", "R", "manual", "Decision reason: ") + flags.StringVarP(&decisionType, "type", "t", "ban", "Decision type: ban,captcha,throttle") + flags.IntVar(&batch, "batch", 0, "Split import in batches of N decisions") + flags.StringVar(&format, "format", "", "Input format: 'json', 'csv' or 'values' (each line is a value, no headers)") _ = cmd.MarkFlagRequired("input") diff --git a/cmd/crowdsec-cli/decisions_table.go b/cmd/crowdsec-cli/clidecision/table.go similarity index 92% rename from cmd/crowdsec-cli/decisions_table.go rename to cmd/crowdsec-cli/clidecision/table.go index 02952f93b85..189eb80b8e5 100644 --- a/cmd/crowdsec-cli/decisions_table.go +++ b/cmd/crowdsec-cli/clidecision/table.go @@ -1,7 +1,6 @@ -package main +package clidecision import ( - "fmt" "io" "strconv" @@ -23,7 +22,7 @@ func (cli *cliDecisions) decisionsTable(out io.Writer, alerts *models.GetAlertsR for _, alertItem := range *alerts { for _, decisionItem := range alertItem.Decisions { if *alertItem.Simulated { - *decisionItem.Type = fmt.Sprintf("(simul)%s", *decisionItem.Type) + *decisionItem.Type = "(simul)" + *decisionItem.Type } row := []string{ diff --git a/cmd/crowdsec-cli/clientinfo/clientinfo.go b/cmd/crowdsec-cli/clientinfo/clientinfo.go new file mode 100644 index 00000000000..0bf1d98804f --- /dev/null +++ b/cmd/crowdsec-cli/clientinfo/clientinfo.go @@ -0,0 +1,39 @@ +package clientinfo + +import ( + "strings" +) + +type featureflagProvider interface { + GetFeatureflags() string +} + +type osProvider interface { + GetOsname() string + GetOsversion() string +} + +func GetOSNameAndVersion(o osProvider) string { + ret := o.GetOsname() + if o.GetOsversion() != "" { + if ret != "" { + ret += "/" + } + + ret += o.GetOsversion() + } + + if ret == "" { + return "?" + } + + return ret +} + +func GetFeatureFlagList(o featureflagProvider) []string { + if o.GetFeatureflags() == "" { + return nil + } + + return strings.Split(o.GetFeatureflags(), ",") +} diff --git a/cmd/crowdsec-cli/explain.go b/cmd/crowdsec-cli/cliexplain/explain.go similarity index 92% rename from cmd/crowdsec-cli/explain.go rename to cmd/crowdsec-cli/cliexplain/explain.go index c322cce47fe..d6e821e4e6c 100644 --- a/cmd/crowdsec-cli/explain.go +++ b/cmd/crowdsec-cli/cliexplain/explain.go @@ -1,4 +1,4 @@ -package main +package cliexplain import ( "bufio" @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/hubtest" ) @@ -40,9 +41,12 @@ func getLineCountForFile(filepath string) (int, error) { return lc, nil } +type configGetter func() *csconfig.Config + type cliExplain struct { - cfg configGetter - flags struct { + cfg configGetter + configFilePath string + flags struct { logFile string dsn string logLine string @@ -56,9 +60,10 @@ type cliExplain struct { } } -func NewCLIExplain(cfg configGetter) *cliExplain { +func New(cfg configGetter, configFilePath string) *cliExplain { return &cliExplain{ - cfg: cfg, + cfg: cfg, + configFilePath: configFilePath, } } @@ -75,7 +80,7 @@ cscli explain --log "Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth) cscli explain --dsn "file://myfile.log" --type nginx tail -n 5 myfile.log | cscli explain --type nginx -f - `, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { return cli.run() @@ -103,7 +108,7 @@ tail -n 5 myfile.log | cscli explain --type nginx -f - flags.StringVar(&cli.flags.crowdsec, "crowdsec", "crowdsec", "Path to crowdsec") flags.BoolVar(&cli.flags.noClean, "no-clean", false, "Don't clean runtime environment after tests") - cmd.MarkFlagRequired("type") + _ = cmd.MarkFlagRequired("type") cmd.MarkFlagsOneRequired("log", "file", "dsn") return cmd @@ -192,7 +197,7 @@ func (cli *cliExplain) run() error { return fmt.Errorf("unable to get absolute path of '%s', exiting", logFile) } - dsn = fmt.Sprintf("file://%s", absolutePath) + dsn = "file://" + absolutePath lineCount, err := getLineCountForFile(absolutePath) if err != nil { @@ -214,7 +219,7 @@ func (cli *cliExplain) run() error { return errors.New("no acquisition (--file or --dsn) provided, can't run cscli test") } - cmdArgs := []string{"-c", ConfigFilePath, "-type", logType, "-dsn", dsn, "-dump-data", dir, "-no-api"} + cmdArgs := []string{"-c", cli.configFilePath, "-type", logType, "-dsn", dsn, "-dump-data", dir, "-no-api"} if labels != "" { log.Debugf("adding labels %s", labels) diff --git a/cmd/crowdsec-cli/hub.go b/cmd/crowdsec-cli/clihub/hub.go similarity index 85% rename from cmd/crowdsec-cli/hub.go rename to cmd/crowdsec-cli/clihub/hub.go index 70df30fc410..f189d6a2e13 100644 --- a/cmd/crowdsec-cli/hub.go +++ b/cmd/crowdsec-cli/clihub/hub.go @@ -1,9 +1,10 @@ -package main +package clihub import ( "context" "encoding/json" "fmt" + "io" "github.com/fatih/color" log "github.com/sirupsen/logrus" @@ -11,14 +12,17 @@ import ( "gopkg.in/yaml.v3" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) +type configGetter = func() *csconfig.Config + type cliHub struct { cfg configGetter } -func NewCLIHub(cfg configGetter) *cliHub { +func New(cfg configGetter) *cliHub { return &cliHub{ cfg: cfg, } @@ -35,7 +39,7 @@ The Hub is managed by cscli, to get the latest hub files from [Crowdsec Hub](htt Example: `cscli hub list cscli hub update cscli hub upgrade`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, } @@ -47,14 +51,9 @@ cscli hub upgrade`, return cmd } -func (cli *cliHub) list(all bool) error { +func (cli *cliHub) List(out io.Writer, hub *cwhub.Hub, all bool) error { cfg := cli.cfg() - hub, err := require.Hub(cfg, nil, log.StandardLogger()) - if err != nil { - return err - } - for _, v := range hub.Warnings { log.Info(v) } @@ -65,14 +64,16 @@ func (cli *cliHub) list(all bool) error { items := make(map[string][]*cwhub.Item) + var err error + for _, itemType := range cwhub.ItemTypes { - items[itemType], err = selectItems(hub, itemType, nil, !all) + items[itemType], err = SelectItems(hub, itemType, nil, !all) if err != nil { return err } } - err = listItems(color.Output, cfg.Cscli.Color, cwhub.ItemTypes, items, true, cfg.Cscli.Output) + err = ListItems(out, cfg.Cscli.Color, cwhub.ItemTypes, items, true, cfg.Cscli.Output) if err != nil { return err } @@ -86,10 +87,15 @@ func (cli *cliHub) newListCmd() *cobra.Command { cmd := &cobra.Command{ Use: "list [-a]", Short: "List all installed configurations", - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { - return cli.list(all) + hub, err := require.Hub(cli.cfg(), nil, log.StandardLogger()) + if err != nil { + return err + } + + return cli.List(color.Output, hub, all) }, } @@ -134,7 +140,7 @@ func (cli *cliHub) newUpdateCmd() *cobra.Command { Long: ` Fetches the .index.json file from the hub, containing the list of available configs. `, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { return cli.update(cmd.Context(), withContent) @@ -154,16 +160,11 @@ func (cli *cliHub) upgrade(ctx context.Context, force bool) error { } for _, itemType := range cwhub.ItemTypes { - items, err := hub.GetInstalledItemsByType(itemType) - if err != nil { - return err - } - updated := 0 log.Infof("Upgrading %s", itemType) - for _, item := range items { + for _, item := range hub.GetInstalledByType(itemType, true) { didUpdate, err := item.Upgrade(ctx, force) if err != nil { return err @@ -189,7 +190,7 @@ func (cli *cliHub) newUpgradeCmd() *cobra.Command { Long: ` Upgrade all configs installed from Crowdsec Hub. Run 'sudo cscli hub update' if you want the latest versions available. `, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { return cli.upgrade(cmd.Context(), force) @@ -234,7 +235,7 @@ func (cli *cliHub) newTypesCmd() *cobra.Command { Long: ` List the types of supported hub items. `, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { return cli.types() diff --git a/cmd/crowdsec-cli/item_metrics.go b/cmd/crowdsec-cli/clihub/item_metrics.go similarity index 89% rename from cmd/crowdsec-cli/item_metrics.go rename to cmd/crowdsec-cli/clihub/item_metrics.go index f00ae08b00b..f4af8f635db 100644 --- a/cmd/crowdsec-cli/item_metrics.go +++ b/cmd/crowdsec-cli/clihub/item_metrics.go @@ -1,4 +1,4 @@ -package main +package clihub import ( "net/http" @@ -16,22 +16,22 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func ShowMetrics(prometheusURL string, hubItem *cwhub.Item, wantColor string) error { +func showMetrics(prometheusURL string, hubItem *cwhub.Item, wantColor string) error { switch hubItem.Type { case cwhub.PARSERS: - metrics := GetParserMetric(prometheusURL, hubItem.Name) + metrics := getParserMetric(prometheusURL, hubItem.Name) parserMetricsTable(color.Output, wantColor, hubItem.Name, metrics) case cwhub.SCENARIOS: - metrics := GetScenarioMetric(prometheusURL, hubItem.Name) + metrics := getScenarioMetric(prometheusURL, hubItem.Name) scenarioMetricsTable(color.Output, wantColor, hubItem.Name, metrics) case cwhub.COLLECTIONS: for _, sub := range hubItem.SubItems() { - if err := ShowMetrics(prometheusURL, sub, wantColor); err != nil { + if err := showMetrics(prometheusURL, sub, wantColor); err != nil { return err } } case cwhub.APPSEC_RULES: - metrics := GetAppsecRuleMetric(prometheusURL, hubItem.Name) + metrics := getAppsecRuleMetric(prometheusURL, hubItem.Name) appsecMetricsTable(color.Output, wantColor, hubItem.Name, metrics) default: // no metrics for this item type } @@ -39,11 +39,11 @@ func ShowMetrics(prometheusURL string, hubItem *cwhub.Item, wantColor string) er return nil } -// GetParserMetric is a complete rip from prom2json -func GetParserMetric(url string, itemName string) map[string]map[string]int { +// getParserMetric is a complete rip from prom2json +func getParserMetric(url string, itemName string) map[string]map[string]int { stats := make(map[string]map[string]int) - result := GetPrometheusMetric(url) + result := getPrometheusMetric(url) for idx, fam := range result { if !strings.HasPrefix(fam.Name, "cs_") { continue @@ -131,7 +131,7 @@ func GetParserMetric(url string, itemName string) map[string]map[string]int { return stats } -func GetScenarioMetric(url string, itemName string) map[string]int { +func getScenarioMetric(url string, itemName string) map[string]int { stats := make(map[string]int) stats["instantiation"] = 0 @@ -140,7 +140,7 @@ func GetScenarioMetric(url string, itemName string) map[string]int { stats["pour"] = 0 stats["underflow"] = 0 - result := GetPrometheusMetric(url) + result := getPrometheusMetric(url) for idx, fam := range result { if !strings.HasPrefix(fam.Name, "cs_") { continue @@ -195,13 +195,13 @@ func GetScenarioMetric(url string, itemName string) map[string]int { return stats } -func GetAppsecRuleMetric(url string, itemName string) map[string]int { +func getAppsecRuleMetric(url string, itemName string) map[string]int { stats := make(map[string]int) stats["inband_hits"] = 0 stats["outband_hits"] = 0 - results := GetPrometheusMetric(url) + results := getPrometheusMetric(url) for idx, fam := range results { if !strings.HasPrefix(fam.Name, "cs_") { continue @@ -260,7 +260,7 @@ func GetAppsecRuleMetric(url string, itemName string) map[string]int { return stats } -func GetPrometheusMetric(url string) []*prom2json.Family { +func getPrometheusMetric(url string) []*prom2json.Family { mfChan := make(chan *dto.MetricFamily, 1024) // Start with the DefaultTransport for sane defaults. diff --git a/cmd/crowdsec-cli/items.go b/cmd/crowdsec-cli/clihub/items.go similarity index 84% rename from cmd/crowdsec-cli/items.go rename to cmd/crowdsec-cli/clihub/items.go index b0c03922166..f86fe65a2a1 100644 --- a/cmd/crowdsec-cli/items.go +++ b/cmd/crowdsec-cli/clihub/items.go @@ -1,4 +1,4 @@ -package main +package clihub import ( "encoding/csv" @@ -16,8 +16,13 @@ import ( ) // selectItems returns a slice of items of a given type, selected by name and sorted by case-insensitive name -func selectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly bool) ([]*cwhub.Item, error) { - itemNames := hub.GetNamesByType(itemType) +func SelectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly bool) ([]*cwhub.Item, error) { + allItems := hub.GetItemsByType(itemType, true) + + itemNames := make([]string, len(allItems)) + for idx, item := range allItems { + itemNames[idx] = item.Name + } notExist := []string{} @@ -38,7 +43,7 @@ func selectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly b installedOnly = false } - items := make([]*cwhub.Item, 0, len(itemNames)) + wantedItems := make([]*cwhub.Item, 0, len(itemNames)) for _, itemName := range itemNames { item := hub.GetItem(itemType, itemName) @@ -46,15 +51,13 @@ func selectItems(hub *cwhub.Hub, itemType string, args []string, installedOnly b continue } - items = append(items, item) + wantedItems = append(wantedItems, item) } - cwhub.SortItemSlice(items) - - return items, nil + return wantedItems, nil } -func listItems(out io.Writer, wantColor string, itemTypes []string, items map[string][]*cwhub.Item, omitIfEmpty bool, output string) error { +func ListItems(out io.Writer, wantColor string, itemTypes []string, items map[string][]*cwhub.Item, omitIfEmpty bool, output string) error { switch output { case "human": nothingToDisplay := true @@ -103,7 +106,7 @@ func listItems(out io.Writer, wantColor string, itemTypes []string, items map[st x, err := json.MarshalIndent(hubStatus, "", " ") if err != nil { - return fmt.Errorf("failed to unmarshal: %w", err) + return fmt.Errorf("failed to parse: %w", err) } out.Write(x) @@ -143,7 +146,7 @@ func listItems(out io.Writer, wantColor string, itemTypes []string, items map[st return nil } -func inspectItem(item *cwhub.Item, showMetrics bool, output string, prometheusURL string, wantColor string) error { +func InspectItem(item *cwhub.Item, wantMetrics bool, output string, prometheusURL string, wantColor string) error { switch output { case "human", "raw": enc := yaml.NewEncoder(os.Stdout) @@ -155,7 +158,7 @@ func inspectItem(item *cwhub.Item, showMetrics bool, output string, prometheusUR case "json": b, err := json.MarshalIndent(*item, "", " ") if err != nil { - return fmt.Errorf("unable to marshal item: %w", err) + return fmt.Errorf("unable to serialize item: %w", err) } fmt.Print(string(b)) @@ -171,10 +174,10 @@ func inspectItem(item *cwhub.Item, showMetrics bool, output string, prometheusUR fmt.Println() } - if showMetrics { + if wantMetrics { fmt.Printf("\nCurrent metrics: \n") - if err := ShowMetrics(prometheusURL, item, wantColor); err != nil { + if err := showMetrics(prometheusURL, item, wantColor); err != nil { return err } } diff --git a/cmd/crowdsec-cli/utils_table.go b/cmd/crowdsec-cli/clihub/utils_table.go similarity index 92% rename from cmd/crowdsec-cli/utils_table.go rename to cmd/crowdsec-cli/clihub/utils_table.go index 6df16cd85f5..98f14341b10 100644 --- a/cmd/crowdsec-cli/utils_table.go +++ b/cmd/crowdsec-cli/clihub/utils_table.go @@ -1,4 +1,4 @@ -package main +package clihub import ( "fmt" @@ -22,7 +22,7 @@ func listHubItemTable(out io.Writer, wantColor string, title string, items []*cw } io.WriteString(out, title+"\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, t.Render()+"\n") } func appsecMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]int) { @@ -35,7 +35,7 @@ func appsecMetricsTable(out io.Writer, wantColor string, itemName string, metric }) io.WriteString(out, fmt.Sprintf("\n - (AppSec Rule) %s:\n", itemName)) - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, t.Render()+"\n") } func scenarioMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]int) { @@ -55,7 +55,7 @@ func scenarioMetricsTable(out io.Writer, wantColor string, itemName string, metr }) io.WriteString(out, fmt.Sprintf("\n - (Scenario) %s:\n", itemName)) - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, t.Render()+"\n") } func parserMetricsTable(out io.Writer, wantColor string, itemName string, metrics map[string]map[string]int) { @@ -80,6 +80,6 @@ func parserMetricsTable(out io.Writer, wantColor string, itemName string, metric if showTable { io.WriteString(out, fmt.Sprintf("\n - (Parser) %s:\n", itemName)) - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, t.Render()+"\n") } } diff --git a/cmd/crowdsec-cli/clihubtest/clean.go b/cmd/crowdsec-cli/clihubtest/clean.go new file mode 100644 index 00000000000..e3b40b6bd57 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/clean.go @@ -0,0 +1,31 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newCleanCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "clean", + Short: "clean [test_name]", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/coverage.go b/cmd/crowdsec-cli/clihubtest/coverage.go new file mode 100644 index 00000000000..5a4f231caf5 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/coverage.go @@ -0,0 +1,166 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + "math" + + "github.com/fatih/color" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +// getCoverage returns the coverage and the percentage of tests that passed +func getCoverage(show bool, getCoverageFunc func() ([]hubtest.Coverage, error)) ([]hubtest.Coverage, int, error) { + if !show { + return nil, 0, nil + } + + coverage, err := getCoverageFunc() + if err != nil { + return nil, 0, fmt.Errorf("while getting coverage: %w", err) + } + + tested := 0 + + for _, test := range coverage { + if test.TestsCount > 0 { + tested++ + } + } + + // keep coverage 0 if there's no tests? + percent := 0 + if len(coverage) > 0 { + percent = int(math.Round((float64(tested) / float64(len(coverage)) * 100))) + } + + return coverage, percent, nil +} + +func (cli *cliHubTest) coverage(showScenarioCov bool, showParserCov bool, showAppsecCov bool, showOnlyPercent bool) error { + cfg := cli.cfg() + + // for this one we explicitly don't do for appsec + if err := HubTest.LoadAllTests(); err != nil { + return fmt.Errorf("unable to load all tests: %+v", err) + } + + var err error + + // if all are false (flag by default), show them + if !showParserCov && !showScenarioCov && !showAppsecCov { + showParserCov = true + showScenarioCov = true + showAppsecCov = true + } + + parserCoverage, parserCoveragePercent, err := getCoverage(showParserCov, HubTest.GetParsersCoverage) + if err != nil { + return err + } + + scenarioCoverage, scenarioCoveragePercent, err := getCoverage(showScenarioCov, HubTest.GetScenariosCoverage) + if err != nil { + return err + } + + appsecRuleCoverage, appsecRuleCoveragePercent, err := getCoverage(showAppsecCov, HubTest.GetAppsecCoverage) + if err != nil { + return err + } + + if showOnlyPercent { + switch { + case showParserCov: + fmt.Printf("parsers=%d%%", parserCoveragePercent) + case showScenarioCov: + fmt.Printf("scenarios=%d%%", scenarioCoveragePercent) + case showAppsecCov: + fmt.Printf("appsec_rules=%d%%", appsecRuleCoveragePercent) + } + + return nil + } + + switch cfg.Cscli.Output { + case "human": + if showParserCov { + hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Parser", "Status", "Number of tests"}, parserCoverage) + } + + if showScenarioCov { + hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Scenario", "Status", "Number of tests"}, parserCoverage) + } + + if showAppsecCov { + hubTestCoverageTable(color.Output, cfg.Cscli.Color, []string{"Appsec Rule", "Status", "Number of tests"}, parserCoverage) + } + + fmt.Println() + + if showParserCov { + fmt.Printf("PARSERS : %d%% of coverage\n", parserCoveragePercent) + } + + if showScenarioCov { + fmt.Printf("SCENARIOS : %d%% of coverage\n", scenarioCoveragePercent) + } + + if showAppsecCov { + fmt.Printf("APPSEC RULES : %d%% of coverage\n", appsecRuleCoveragePercent) + } + case "json": + dump, err := json.MarshalIndent(parserCoverage, "", " ") + if err != nil { + return err + } + + fmt.Printf("%s", dump) + + dump, err = json.MarshalIndent(scenarioCoverage, "", " ") + if err != nil { + return err + } + + fmt.Printf("%s", dump) + + dump, err = json.MarshalIndent(appsecRuleCoverage, "", " ") + if err != nil { + return err + } + + fmt.Printf("%s", dump) + default: + return errors.New("only human/json output modes are supported") + } + + return nil +} + +func (cli *cliHubTest) newCoverageCmd() *cobra.Command { + var ( + showParserCov bool + showScenarioCov bool + showOnlyPercent bool + showAppsecCov bool + ) + + cmd := &cobra.Command{ + Use: "coverage", + Short: "coverage", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + return cli.coverage(showScenarioCov, showParserCov, showAppsecCov, showOnlyPercent) + }, + } + + cmd.PersistentFlags().BoolVar(&showOnlyPercent, "percent", false, "Show only percentages of coverage") + cmd.PersistentFlags().BoolVar(&showParserCov, "parsers", false, "Show only parsers coverage") + cmd.PersistentFlags().BoolVar(&showScenarioCov, "scenarios", false, "Show only scenarios coverage") + cmd.PersistentFlags().BoolVar(&showAppsecCov, "appsec", false, "Show only appsec coverage") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/create.go b/cmd/crowdsec-cli/clihubtest/create.go new file mode 100644 index 00000000000..3822bed8903 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/create.go @@ -0,0 +1,158 @@ +package clihubtest + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "text/template" + + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func (cli *cliHubTest) newCreateCmd() *cobra.Command { + var ( + ignoreParsers bool + labels map[string]string + logType string + ) + + parsers := []string{} + postoverflows := []string{} + scenarios := []string{} + + cmd := &cobra.Command{ + Use: "create", + Short: "create [test_name]", + Example: `cscli hubtest create my-awesome-test --type syslog +cscli hubtest create my-nginx-custom-test --type nginx +cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios crowdsecurity/http-probing`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + testName := args[0] + testPath := filepath.Join(hubPtr.HubTestPath, testName) + if _, err := os.Stat(testPath); os.IsExist(err) { + return fmt.Errorf("test '%s' already exists in '%s', exiting", testName, testPath) + } + + if isAppsecTest { + logType = "appsec" + } + + if logType == "" { + return errors.New("please provide a type (--type) for the test") + } + + if err := os.MkdirAll(testPath, os.ModePerm); err != nil { + return fmt.Errorf("unable to create folder '%s': %+v", testPath, err) + } + + configFilePath := filepath.Join(testPath, "config.yaml") + + configFileData := &hubtest.HubTestItemConfig{} + if logType == "appsec" { + // create empty nuclei template file + nucleiFileName := testName + ".yaml" + nucleiFilePath := filepath.Join(testPath, nucleiFileName) + + nucleiFile, err := os.OpenFile(nucleiFilePath, os.O_RDWR|os.O_CREATE, 0o755) + if err != nil { + return err + } + + ntpl := template.Must(template.New("nuclei").Parse(hubtest.TemplateNucleiFile)) + if ntpl == nil { + return errors.New("unable to parse nuclei template") + } + ntpl.ExecuteTemplate(nucleiFile, "nuclei", struct{ TestName string }{TestName: testName}) + nucleiFile.Close() + configFileData.AppsecRules = []string{"./appsec-rules//your_rule_here.yaml"} + configFileData.NucleiTemplate = nucleiFileName + fmt.Println() + fmt.Printf(" Test name : %s\n", testName) + fmt.Printf(" Test path : %s\n", testPath) + fmt.Printf(" Config File : %s\n", configFilePath) + fmt.Printf(" Nuclei Template : %s\n", nucleiFilePath) + } else { + // create empty log file + logFileName := testName + ".log" + logFilePath := filepath.Join(testPath, logFileName) + logFile, err := os.Create(logFilePath) + if err != nil { + return err + } + logFile.Close() + + // create empty parser assertion file + parserAssertFilePath := filepath.Join(testPath, hubtest.ParserAssertFileName) + parserAssertFile, err := os.Create(parserAssertFilePath) + if err != nil { + return err + } + parserAssertFile.Close() + // create empty scenario assertion file + scenarioAssertFilePath := filepath.Join(testPath, hubtest.ScenarioAssertFileName) + scenarioAssertFile, err := os.Create(scenarioAssertFilePath) + if err != nil { + return err + } + scenarioAssertFile.Close() + + parsers = append(parsers, "crowdsecurity/syslog-logs") + parsers = append(parsers, "crowdsecurity/dateparse-enrich") + + if len(scenarios) == 0 { + scenarios = append(scenarios, "") + } + + if len(postoverflows) == 0 { + postoverflows = append(postoverflows, "") + } + configFileData.Parsers = parsers + configFileData.Scenarios = scenarios + configFileData.PostOverflows = postoverflows + configFileData.LogFile = logFileName + configFileData.LogType = logType + configFileData.IgnoreParsers = ignoreParsers + configFileData.Labels = labels + fmt.Println() + fmt.Printf(" Test name : %s\n", testName) + fmt.Printf(" Test path : %s\n", testPath) + fmt.Printf(" Log file : %s (please fill it with logs)\n", logFilePath) + fmt.Printf(" Parser assertion file : %s (please fill it with assertion)\n", parserAssertFilePath) + fmt.Printf(" Scenario assertion file : %s (please fill it with assertion)\n", scenarioAssertFilePath) + fmt.Printf(" Configuration File : %s (please fill it with parsers, scenarios...)\n", configFilePath) + } + + fd, err := os.Create(configFilePath) + if err != nil { + return fmt.Errorf("open: %w", err) + } + data, err := yaml.Marshal(configFileData) + if err != nil { + return fmt.Errorf("serialize: %w", err) + } + _, err = fd.Write(data) + if err != nil { + return fmt.Errorf("write: %w", err) + } + if err := fd.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + + return nil + }, + } + + cmd.PersistentFlags().StringVarP(&logType, "type", "t", "", "Log type of the test") + cmd.Flags().StringSliceVarP(&parsers, "parsers", "p", parsers, "Parsers to add to test") + cmd.Flags().StringSliceVar(&postoverflows, "postoverflows", postoverflows, "Postoverflows to add to test") + cmd.Flags().StringSliceVarP(&scenarios, "scenarios", "s", scenarios, "Scenarios to add to test") + cmd.PersistentFlags().BoolVar(&ignoreParsers, "ignore-parsers", false, "Don't run test on parsers") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/eval.go b/cmd/crowdsec-cli/clihubtest/eval.go new file mode 100644 index 00000000000..83e9eae9c15 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/eval.go @@ -0,0 +1,44 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newEvalCmd() *cobra.Command { + var evalExpression string + + cmd := &cobra.Command{ + Use: "eval", + Short: "eval [test_name]", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("can't load test: %+v", err) + } + + err = test.ParserAssert.LoadTest(test.ParserResultFile) + if err != nil { + return fmt.Errorf("can't load test results from '%s': %+v", test.ParserResultFile, err) + } + + output, err := test.ParserAssert.EvalExpression(evalExpression) + if err != nil { + return err + } + + fmt.Print(output) + } + + return nil + }, + } + + cmd.PersistentFlags().StringVarP(&evalExpression, "expr", "e", "", "Expression to eval") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/explain.go b/cmd/crowdsec-cli/clihubtest/explain.go new file mode 100644 index 00000000000..dbe10fa7ec0 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/explain.go @@ -0,0 +1,76 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/dumps" +) + +func (cli *cliHubTest) explain(testName string, details bool, skipOk bool) error { + test, err := HubTest.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("can't load test: %+v", err) + } + + err = test.ParserAssert.LoadTest(test.ParserResultFile) + if err != nil { + if err = test.Run(); err != nil { + return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) + } + + if err = test.ParserAssert.LoadTest(test.ParserResultFile); err != nil { + return fmt.Errorf("unable to load parser result after run: %w", err) + } + } + + err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile) + if err != nil { + if err = test.Run(); err != nil { + return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) + } + + if err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile); err != nil { + return fmt.Errorf("unable to load scenario result after run: %w", err) + } + } + + opts := dumps.DumpOpts{ + Details: details, + SkipOk: skipOk, + } + + dumps.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts) + + return nil +} + +func (cli *cliHubTest) newExplainCmd() *cobra.Command { + var ( + details bool + skipOk bool + ) + + cmd := &cobra.Command{ + Use: "explain", + Short: "explain [test_name]", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + if err := cli.explain(testName, details, skipOk); err != nil { + return err + } + } + + return nil + }, + } + + flags := cmd.Flags() + flags.BoolVarP(&details, "verbose", "v", false, "Display individual changes") + flags.BoolVar(&skipOk, "failures", false, "Only show failed lines") + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/hubtest.go b/cmd/crowdsec-cli/clihubtest/hubtest.go new file mode 100644 index 00000000000..f4cfed2e1cb --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/hubtest.go @@ -0,0 +1,81 @@ +package clihubtest + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +type configGetter func() *csconfig.Config + +var ( + HubTest hubtest.HubTest + HubAppsecTests hubtest.HubTest + hubPtr *hubtest.HubTest + isAppsecTest bool +) + +type cliHubTest struct { + cfg configGetter +} + +func New(cfg configGetter) *cliHubTest { + return &cliHubTest{ + cfg: cfg, + } +} + +func (cli *cliHubTest) NewCommand() *cobra.Command { + var ( + hubPath string + crowdsecPath string + cscliPath string + ) + + cmd := &cobra.Command{ + Use: "hubtest", + Short: "Run functional tests on hub configurations", + Long: "Run functional tests on hub configurations (parsers, scenarios, collections...)", + Args: cobra.NoArgs, + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + var err error + HubTest, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, false) + if err != nil { + return fmt.Errorf("unable to load hubtest: %+v", err) + } + + HubAppsecTests, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, true) + if err != nil { + return fmt.Errorf("unable to load appsec specific hubtest: %+v", err) + } + + // commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests + hubPtr = &HubTest + if isAppsecTest { + hubPtr = &HubAppsecTests + } + + return nil + }, + } + + cmd.PersistentFlags().StringVar(&hubPath, "hub", ".", "Path to hub folder") + cmd.PersistentFlags().StringVar(&crowdsecPath, "crowdsec", "crowdsec", "Path to crowdsec") + cmd.PersistentFlags().StringVar(&cscliPath, "cscli", "cscli", "Path to cscli") + cmd.PersistentFlags().BoolVar(&isAppsecTest, "appsec", false, "Command relates to appsec tests") + + cmd.AddCommand(cli.newCreateCmd()) + cmd.AddCommand(cli.newRunCmd()) + cmd.AddCommand(cli.newCleanCmd()) + cmd.AddCommand(cli.newInfoCmd()) + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newCoverageCmd()) + cmd.AddCommand(cli.newEvalCmd()) + cmd.AddCommand(cli.newExplainCmd()) + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/info.go b/cmd/crowdsec-cli/clihubtest/info.go new file mode 100644 index 00000000000..a5d760eea01 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/info.go @@ -0,0 +1,44 @@ +package clihubtest + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func (cli *cliHubTest) newInfoCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "info", + Short: "info [test_name]", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + for _, testName := range args { + test, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + fmt.Println() + fmt.Printf(" Test name : %s\n", test.Name) + fmt.Printf(" Test path : %s\n", test.Path) + if isAppsecTest { + fmt.Printf(" Nuclei Template : %s\n", test.Config.NucleiTemplate) + fmt.Printf(" Appsec Rules : %s\n", strings.Join(test.Config.AppsecRules, ", ")) + } else { + fmt.Printf(" Log file : %s\n", filepath.Join(test.Path, test.Config.LogFile)) + fmt.Printf(" Parser assertion file : %s\n", filepath.Join(test.Path, hubtest.ParserAssertFileName)) + fmt.Printf(" Scenario assertion file : %s\n", filepath.Join(test.Path, hubtest.ScenarioAssertFileName)) + } + fmt.Printf(" Configuration File : %s\n", filepath.Join(test.Path, "config.yaml")) + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/list.go b/cmd/crowdsec-cli/clihubtest/list.go new file mode 100644 index 00000000000..3e76824a18e --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/list.go @@ -0,0 +1,42 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/fatih/color" + "github.com/spf13/cobra" +) + +func (cli *cliHubTest) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + + if err := hubPtr.LoadAllTests(); err != nil { + return fmt.Errorf("unable to load all tests: %w", err) + } + + switch cfg.Cscli.Output { + case "human": + hubTestListTable(color.Output, cfg.Cscli.Color, hubPtr.Tests) + case "json": + j, err := json.MarshalIndent(hubPtr.Tests, " ", " ") + if err != nil { + return err + } + fmt.Println(string(j)) + default: + return errors.New("only human/json output modes are supported") + } + + return nil + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clihubtest/run.go b/cmd/crowdsec-cli/clihubtest/run.go new file mode 100644 index 00000000000..31cceb81884 --- /dev/null +++ b/cmd/crowdsec-cli/clihubtest/run.go @@ -0,0 +1,213 @@ +package clihubtest + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "strings" + + "github.com/AlecAivazis/survey/v2" + "github.com/fatih/color" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/emoji" + "github.com/crowdsecurity/crowdsec/pkg/hubtest" +) + +func (cli *cliHubTest) run(runAll bool, nucleiTargetHost string, appSecHost string, args []string) error { + cfg := cli.cfg() + + if !runAll && len(args) == 0 { + return errors.New("please provide test to run or --all flag") + } + + hubPtr.NucleiTargetHost = nucleiTargetHost + hubPtr.AppSecHost = appSecHost + + if runAll { + if err := hubPtr.LoadAllTests(); err != nil { + return fmt.Errorf("unable to load all tests: %+v", err) + } + } else { + for _, testName := range args { + _, err := hubPtr.LoadTestItem(testName) + if err != nil { + return fmt.Errorf("unable to load test '%s': %w", testName, err) + } + } + } + + // set timezone to avoid DST issues + os.Setenv("TZ", "UTC") + + for _, test := range hubPtr.Tests { + if cfg.Cscli.Output == "human" { + log.Infof("Running test '%s'", test.Name) + } + + err := test.Run() + if err != nil { + log.Errorf("running test '%s' failed: %+v", test.Name, err) + } + } + + return nil +} + +func printParserFailures(test *hubtest.HubTestItem) { + if len(test.ParserAssert.Fails) == 0 { + return + } + + fmt.Println() + log.Errorf("Parser test '%s' failed (%d errors)\n", test.Name, len(test.ParserAssert.Fails)) + + for _, fail := range test.ParserAssert.Fails { + fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) + fmt.Printf(" Actual expression values:\n") + + for key, value := range fail.Debug { + fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) + } + + fmt.Println() + } +} + +func printScenarioFailures(test *hubtest.HubTestItem) { + if len(test.ScenarioAssert.Fails) == 0 { + return + } + + fmt.Println() + log.Errorf("Scenario test '%s' failed (%d errors)\n", test.Name, len(test.ScenarioAssert.Fails)) + + for _, fail := range test.ScenarioAssert.Fails { + fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) + fmt.Printf(" Actual expression values:\n") + + for key, value := range fail.Debug { + fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) + } + + fmt.Println() + } +} + +func (cli *cliHubTest) newRunCmd() *cobra.Command { + var ( + noClean bool + runAll bool + forceClean bool + nucleiTargetHost string + appSecHost string + ) + + cmd := &cobra.Command{ + Use: "run", + Short: "run [test_name]", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + return cli.run(runAll, nucleiTargetHost, appSecHost, args) + }, + PersistentPostRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + + success := true + testResult := make(map[string]bool) + for _, test := range hubPtr.Tests { + if test.AutoGen && !isAppsecTest { + if test.ParserAssert.AutoGenAssert { + log.Warningf("Assert file '%s' is empty, generating assertion:", test.ParserAssert.File) + fmt.Println() + fmt.Println(test.ParserAssert.AutoGenAssertData) + } + if test.ScenarioAssert.AutoGenAssert { + log.Warningf("Assert file '%s' is empty, generating assertion:", test.ScenarioAssert.File) + fmt.Println() + fmt.Println(test.ScenarioAssert.AutoGenAssertData) + } + if !noClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + + return fmt.Errorf("please fill your assert file(s) for test '%s', exiting", test.Name) + } + testResult[test.Name] = test.Success + if test.Success { + if cfg.Cscli.Output == "human" { + log.Infof("Test '%s' passed successfully (%d assertions)\n", test.Name, test.ParserAssert.NbAssert+test.ScenarioAssert.NbAssert) + } + if !noClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + } else { + success = false + cleanTestEnv := false + if cfg.Cscli.Output == "human" { + printParserFailures(test) + printScenarioFailures(test) + if !forceClean && !noClean { + prompt := &survey.Confirm{ + Message: fmt.Sprintf("\nDo you want to remove runtime folder for test '%s'? (default: Yes)", test.Name), + Default: true, + } + if err := survey.AskOne(prompt, &cleanTestEnv); err != nil { + return fmt.Errorf("unable to ask to remove runtime folder: %w", err) + } + } + } + + if cleanTestEnv || forceClean { + if err := test.Clean(); err != nil { + return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) + } + } + } + } + + switch cfg.Cscli.Output { + case "human": + hubTestResultTable(color.Output, cfg.Cscli.Color, testResult) + case "json": + jsonResult := make(map[string][]string, 0) + jsonResult["success"] = make([]string, 0) + jsonResult["fail"] = make([]string, 0) + for testName, success := range testResult { + if success { + jsonResult["success"] = append(jsonResult["success"], testName) + } else { + jsonResult["fail"] = append(jsonResult["fail"], testName) + } + } + jsonStr, err := json.Marshal(jsonResult) + if err != nil { + return fmt.Errorf("unable to json test result: %w", err) + } + fmt.Println(string(jsonStr)) + default: + return errors.New("only human/json output modes are supported") + } + + if !success { + return errors.New("some tests failed") + } + + return nil + }, + } + + cmd.Flags().BoolVar(&noClean, "no-clean", false, "Don't clean runtime environment if test succeed") + cmd.Flags().BoolVar(&forceClean, "clean", false, "Clean runtime environment if test fail") + cmd.Flags().StringVar(&nucleiTargetHost, "target", hubtest.DefaultNucleiTarget, "Target for AppSec Test") + cmd.Flags().StringVar(&appSecHost, "host", hubtest.DefaultAppsecHost, "Address to expose AppSec for hubtest") + cmd.Flags().BoolVar(&runAll, "all", false, "Run all tests") + + return cmd +} diff --git a/cmd/crowdsec-cli/hubtest_table.go b/cmd/crowdsec-cli/clihubtest/table.go similarity index 50% rename from cmd/crowdsec-cli/hubtest_table.go rename to cmd/crowdsec-cli/clihubtest/table.go index 1fa0f990be2..2a105a1f5c1 100644 --- a/cmd/crowdsec-cli/hubtest_table.go +++ b/cmd/crowdsec-cli/clihubtest/table.go @@ -1,4 +1,4 @@ -package main +package clihubtest import ( "fmt" @@ -42,51 +42,9 @@ func hubTestListTable(out io.Writer, wantColor string, tests []*hubtest.HubTestI t.Render() } -func hubTestParserCoverageTable(out io.Writer, wantColor string, coverage []hubtest.Coverage) { +func hubTestCoverageTable(out io.Writer, wantColor string, headers []string, coverage []hubtest.Coverage) { t := cstable.NewLight(out, wantColor) - t.SetHeaders("Parser", "Status", "Number of tests") - t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) - t.SetAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) - - parserTested := 0 - - for _, test := range coverage { - status := emoji.RedCircle - if test.TestsCount > 0 { - status = emoji.GreenCircle - parserTested++ - } - - t.AddRow(test.Name, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) - } - - t.Render() -} - -func hubTestAppsecRuleCoverageTable(out io.Writer, wantColor string, coverage []hubtest.Coverage) { - t := cstable.NewLight(out, wantColor) - t.SetHeaders("Appsec Rule", "Status", "Number of tests") - t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) - t.SetAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) - - parserTested := 0 - - for _, test := range coverage { - status := emoji.RedCircle - if test.TestsCount > 0 { - status = emoji.GreenCircle - parserTested++ - } - - t.AddRow(test.Name, status, fmt.Sprintf("%d times (across %d tests)", test.TestsCount, len(test.PresentIn))) - } - - t.Render() -} - -func hubTestScenarioCoverageTable(out io.Writer, wantColor string, coverage []hubtest.Coverage) { - t := cstable.NewLight(out, wantColor) - t.SetHeaders("Scenario", "Status", "Number of tests") + t.SetHeaders(headers...) t.SetHeaderAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) t.SetAlignment(text.AlignLeft, text.AlignLeft, text.AlignLeft) diff --git a/cmd/crowdsec-cli/hubappsec.go b/cmd/crowdsec-cli/cliitem/appsec.go similarity index 93% rename from cmd/crowdsec-cli/hubappsec.go rename to cmd/crowdsec-cli/cliitem/appsec.go index 1df3212f941..44afa2133bd 100644 --- a/cmd/crowdsec-cli/hubappsec.go +++ b/cmd/crowdsec-cli/cliitem/appsec.go @@ -1,4 +1,4 @@ -package main +package cliitem import ( "fmt" @@ -13,7 +13,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIAppsecConfig(cfg configGetter) *cliItem { +func NewAppsecConfig(cfg configGetter) *cliItem { return &cliItem{ cfg: cfg, name: cwhub.APPSEC_CONFIGS, @@ -47,7 +47,7 @@ cscli appsec-configs list crowdsecurity/vpatch`, } } -func NewCLIAppsecRule(cfg configGetter) *cliItem { +func NewAppsecRule(cfg configGetter) *cliItem { inspectDetail := func(item *cwhub.Item) error { // Only show the converted rules in human mode if cfg().Cscli.Output != "human" { @@ -62,7 +62,7 @@ func NewCLIAppsecRule(cfg configGetter) *cliItem { } if err := yaml.Unmarshal(yamlContent, &appsecRule); err != nil { - return fmt.Errorf("unable to unmarshal yaml file %s: %w", item.State.LocalPath, err) + return fmt.Errorf("unable to parse yaml file %s: %w", item.State.LocalPath, err) } for _, ruleType := range appsec_rule.SupportedTypes() { diff --git a/cmd/crowdsec-cli/hubcollection.go b/cmd/crowdsec-cli/cliitem/collection.go similarity index 95% rename from cmd/crowdsec-cli/hubcollection.go rename to cmd/crowdsec-cli/cliitem/collection.go index 655b36eb1b8..ea91c1e537a 100644 --- a/cmd/crowdsec-cli/hubcollection.go +++ b/cmd/crowdsec-cli/cliitem/collection.go @@ -1,10 +1,10 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLICollection(cfg configGetter) *cliItem { +func NewCollection(cfg configGetter) *cliItem { return &cliItem{ cfg: cfg, name: cwhub.COLLECTIONS, diff --git a/cmd/crowdsec-cli/hubcontext.go b/cmd/crowdsec-cli/cliitem/context.go similarity index 94% rename from cmd/crowdsec-cli/hubcontext.go rename to cmd/crowdsec-cli/cliitem/context.go index 2a777327379..7d110b8203d 100644 --- a/cmd/crowdsec-cli/hubcontext.go +++ b/cmd/crowdsec-cli/cliitem/context.go @@ -1,10 +1,10 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIContext(cfg configGetter) *cliItem { +func NewContext(cfg configGetter) *cliItem { return &cliItem{ cfg: cfg, name: cwhub.CONTEXTS, diff --git a/cmd/crowdsec-cli/hubscenario.go b/cmd/crowdsec-cli/cliitem/hubscenario.go similarity index 95% rename from cmd/crowdsec-cli/hubscenario.go rename to cmd/crowdsec-cli/cliitem/hubscenario.go index 4434b9a2c45..a5e854b3c82 100644 --- a/cmd/crowdsec-cli/hubscenario.go +++ b/cmd/crowdsec-cli/cliitem/hubscenario.go @@ -1,10 +1,10 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIScenario(cfg configGetter) *cliItem { +func NewScenario(cfg configGetter) *cliItem { return &cliItem{ cfg: cfg, name: cwhub.SCENARIOS, diff --git a/cmd/crowdsec-cli/itemcli.go b/cmd/crowdsec-cli/cliitem/item.go similarity index 93% rename from cmd/crowdsec-cli/itemcli.go rename to cmd/crowdsec-cli/cliitem/item.go index 64c18ae89b1..28828eb9c95 100644 --- a/cmd/crowdsec-cli/itemcli.go +++ b/cmd/crowdsec-cli/cliitem/item.go @@ -1,4 +1,4 @@ -package main +package cliitem import ( "cmp" @@ -15,7 +15,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) @@ -28,6 +31,8 @@ type cliHelp struct { example string } +type configGetter func() *csconfig.Config + type cliItem struct { cfg configGetter name string // plural, as used in the hub index @@ -78,7 +83,7 @@ func (cli cliItem) install(ctx context.Context, args []string, downloadOnly bool return errors.New(msg) } - log.Errorf(msg) + log.Error(msg) continue } @@ -92,7 +97,7 @@ func (cli cliItem) install(ctx context.Context, args []string, downloadOnly bool } } - log.Infof(ReloadMessage()) + log.Info(reload.Message) return nil } @@ -147,19 +152,14 @@ func (cli cliItem) remove(args []string, purge bool, force bool, all bool) error } if all { - getter := hub.GetInstalledItemsByType + itemGetter := hub.GetInstalledByType if purge { - getter = hub.GetItemsByType - } - - items, err := getter(cli.name) - if err != nil { - return err + itemGetter = hub.GetItemsByType } removed := 0 - for _, item := range items { + for _, item := range itemGetter(cli.name, true) { didRemove, err := item.Remove(purge, force) if err != nil { return err @@ -175,7 +175,7 @@ func (cli cliItem) remove(args []string, purge bool, force bool, all bool) error log.Infof("Removed %d %s", removed, cli.name) if removed > 0 { - log.Infof(ReloadMessage()) + log.Info(reload.Message) } return nil @@ -217,7 +217,7 @@ func (cli cliItem) remove(args []string, purge bool, force bool, all bool) error log.Infof("Removed %d %s", removed, cli.name) if removed > 0 { - log.Infof(ReloadMessage()) + log.Info(reload.Message) } return nil @@ -262,14 +262,9 @@ func (cli cliItem) upgrade(ctx context.Context, args []string, force bool, all b } if all { - items, err := hub.GetInstalledItemsByType(cli.name) - if err != nil { - return err - } - updated := 0 - for _, item := range items { + for _, item := range hub.GetInstalledByType(cli.name, true) { didUpdate, err := item.Upgrade(ctx, force) if err != nil { return err @@ -283,7 +278,7 @@ func (cli cliItem) upgrade(ctx context.Context, args []string, force bool, all b log.Infof("Updated %d %s", updated, cli.name) if updated > 0 { - log.Infof(ReloadMessage()) + log.Info(reload.Message) } return nil @@ -314,7 +309,7 @@ func (cli cliItem) upgrade(ctx context.Context, args []string, force bool, all b } if updated > 0 { - log.Infof(ReloadMessage()) + log.Info(reload.Message) } return nil @@ -381,7 +376,7 @@ func (cli cliItem) inspect(ctx context.Context, args []string, url string, diff continue } - if err = inspectItem(item, !noMetrics, cfg.Cscli.Output, cfg.Cscli.PrometheusUrl, cfg.Cscli.Color); err != nil { + if err = clihub.InspectItem(item, !noMetrics, cfg.Cscli.Output, cfg.Cscli.PrometheusUrl, cfg.Cscli.Color); err != nil { return err } @@ -437,12 +432,12 @@ func (cli cliItem) list(args []string, all bool) error { items := make(map[string][]*cwhub.Item) - items[cli.name], err = selectItems(hub, cli.name, args, !all) + items[cli.name], err = clihub.SelectItems(hub, cli.name, args, !all) if err != nil { return err } - return listItems(color.Output, cfg.Cscli.Color, []string{cli.name}, items, false, cfg.Cscli.Output) + return clihub.ListItems(color.Output, cfg.Cscli.Color, []string{cli.name}, items, false, cfg.Cscli.Output) } func (cli cliItem) newListCmd() *cobra.Command { diff --git a/cmd/crowdsec-cli/hubparser.go b/cmd/crowdsec-cli/cliitem/parser.go similarity index 95% rename from cmd/crowdsec-cli/hubparser.go rename to cmd/crowdsec-cli/cliitem/parser.go index cc856cbedb9..bc1d96bdaf0 100644 --- a/cmd/crowdsec-cli/hubparser.go +++ b/cmd/crowdsec-cli/cliitem/parser.go @@ -1,10 +1,10 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIParser(cfg configGetter) *cliItem { +func NewParser(cfg configGetter) *cliItem { return &cliItem{ cfg: cfg, name: cwhub.PARSERS, diff --git a/cmd/crowdsec-cli/hubpostoverflow.go b/cmd/crowdsec-cli/cliitem/postoverflow.go similarity index 95% rename from cmd/crowdsec-cli/hubpostoverflow.go rename to cmd/crowdsec-cli/cliitem/postoverflow.go index 3fd45fd113d..ea53aef327d 100644 --- a/cmd/crowdsec-cli/hubpostoverflow.go +++ b/cmd/crowdsec-cli/cliitem/postoverflow.go @@ -1,10 +1,10 @@ -package main +package cliitem import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -func NewCLIPostOverflow(cfg configGetter) *cliItem { +func NewPostOverflow(cfg configGetter) *cliItem { return &cliItem{ cfg: cfg, name: cwhub.POSTOVERFLOWS, diff --git a/cmd/crowdsec-cli/item_suggest.go b/cmd/crowdsec-cli/cliitem/suggest.go similarity index 77% rename from cmd/crowdsec-cli/item_suggest.go rename to cmd/crowdsec-cli/cliitem/suggest.go index 0ea656549ba..5b080722af9 100644 --- a/cmd/crowdsec-cli/item_suggest.go +++ b/cmd/crowdsec-cli/cliitem/suggest.go @@ -1,4 +1,4 @@ -package main +package cliitem import ( "fmt" @@ -19,7 +19,7 @@ func suggestNearestMessage(hub *cwhub.Hub, itemType string, itemName string) str score := 100 nearest := "" - for _, item := range hub.GetItemMap(itemType) { + for _, item := range hub.GetItemsByType(itemType, false) { d := levenshtein.Distance(itemName, item.Name, nil) if d < score { score = d @@ -44,7 +44,7 @@ func compAllItems(itemType string, args []string, toComplete string, cfg configG comp := make([]string, 0) - for _, item := range hub.GetItemMap(itemType) { + for _, item := range hub.GetItemsByType(itemType, false) { if !slices.Contains(args, item.Name) && strings.Contains(item.Name, toComplete) { comp = append(comp, item.Name) } @@ -61,22 +61,14 @@ func compInstalledItems(itemType string, args []string, toComplete string, cfg c return nil, cobra.ShellCompDirectiveDefault } - items, err := hub.GetInstalledNamesByType(itemType) - if err != nil { - cobra.CompDebugln(fmt.Sprintf("list installed %s err: %s", itemType, err), true) - return nil, cobra.ShellCompDirectiveDefault - } + items := hub.GetInstalledByType(itemType, true) comp := make([]string, 0) - if toComplete != "" { - for _, item := range items { - if strings.Contains(item, toComplete) { - comp = append(comp, item) - } + for _, item := range items { + if strings.Contains(item.Name, toComplete) { + comp = append(comp, item.Name) } - } else { - comp = items } cobra.CompDebugln(fmt.Sprintf("%s: %+v", itemType, comp), true) diff --git a/cmd/crowdsec-cli/lapi.go b/cmd/crowdsec-cli/clilapi/context.go similarity index 59% rename from cmd/crowdsec-cli/lapi.go rename to cmd/crowdsec-cli/clilapi/context.go index c0543f98733..20ceb2b9596 100644 --- a/cmd/crowdsec-cli/lapi.go +++ b/cmd/crowdsec-cli/clilapi/context.go @@ -1,261 +1,22 @@ -package main +package clilapi import ( - "context" "errors" "fmt" - "net/url" - "os" "slices" "sort" "strings" - "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "gopkg.in/yaml.v3" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/alertcontext" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/parser" ) -const LAPIURLPrefix = "v1" - -type cliLapi struct { - cfg configGetter -} - -func NewCLILapi(cfg configGetter) *cliLapi { - return &cliLapi{ - cfg: cfg, - } -} - -// QueryLAPIStatus checks if the Local API is reachable, and if the credentials are correct -func QueryLAPIStatus(hub *cwhub.Hub, credURL string, login string, password string) error { - apiURL, err := url.Parse(credURL) - if err != nil { - return fmt.Errorf("parsing api url: %w", err) - } - - scenarios, err := hub.GetInstalledNamesByType(cwhub.SCENARIOS) - if err != nil { - return fmt.Errorf("failed to get scenarios: %w", err) - } - - client, err := apiclient.NewDefaultClient(apiURL, - LAPIURLPrefix, - cwversion.UserAgent(), - nil) - if err != nil { - return fmt.Errorf("init default client: %w", err) - } - - pw := strfmt.Password(password) - - t := models.WatcherAuthRequest{ - MachineID: &login, - Password: &pw, - Scenarios: scenarios, - } - - _, _, err = client.Auth.AuthenticateWatcher(context.Background(), t) - if err != nil { - return err - } - - return nil -} - -func (cli *cliLapi) status() error { - cfg := cli.cfg() - - cred := cfg.API.Client.Credentials - - hub, err := require.Hub(cfg, nil, nil) - if err != nil { - return err - } - - log.Infof("Loaded credentials from %s", cfg.API.Client.CredentialsFilePath) - log.Infof("Trying to authenticate with username %s on %s", cred.Login, cred.URL) - - if err := QueryLAPIStatus(hub, cred.URL, cred.Login, cred.Password); err != nil { - return fmt.Errorf("failed to authenticate to Local API (LAPI): %w", err) - } - - log.Infof("You can successfully interact with Local API (LAPI)") - - return nil -} - -func (cli *cliLapi) register(apiURL string, outputFile string, machine string, token string) error { - var err error - - lapiUser := machine - cfg := cli.cfg() - - if lapiUser == "" { - lapiUser, err = generateID("") - if err != nil { - return fmt.Errorf("unable to generate machine id: %w", err) - } - } - - password := strfmt.Password(generatePassword(passwordLength)) - - apiurl, err := prepareAPIURL(cfg.API.Client, apiURL) - if err != nil { - return fmt.Errorf("parsing api url: %w", err) - } - - _, err = apiclient.RegisterClient(&apiclient.Config{ - MachineID: lapiUser, - Password: password, - UserAgent: cwversion.UserAgent(), - RegistrationToken: token, - URL: apiurl, - VersionPrefix: LAPIURLPrefix, - }, nil) - if err != nil { - return fmt.Errorf("api client register: %w", err) - } - - log.Printf("Successfully registered to Local API (LAPI)") - - var dumpFile string - - if outputFile != "" { - dumpFile = outputFile - } else if cfg.API.Client.CredentialsFilePath != "" { - dumpFile = cfg.API.Client.CredentialsFilePath - } else { - dumpFile = "" - } - - apiCfg := cfg.API.Client.Credentials - apiCfg.Login = lapiUser - apiCfg.Password = password.String() - - if apiURL != "" { - apiCfg.URL = apiURL - } - - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to marshal api credentials: %w", err) - } - - if dumpFile != "" { - err = os.WriteFile(dumpFile, apiConfigDump, 0o600) - if err != nil { - return fmt.Errorf("write api credentials to '%s' failed: %w", dumpFile, err) - } - - log.Printf("Local API credentials written to '%s'", dumpFile) - } else { - fmt.Printf("%s\n", string(apiConfigDump)) - } - - log.Warning(ReloadMessage()) - - return nil -} - -// prepareAPIURL checks/fixes a LAPI connection url (http, https or socket) and returns an URL struct -func prepareAPIURL(clientCfg *csconfig.LocalApiClientCfg, apiURL string) (*url.URL, error) { - if apiURL == "" { - if clientCfg == nil || clientCfg.Credentials == nil || clientCfg.Credentials.URL == "" { - return nil, errors.New("no Local API URL. Please provide it in your configuration or with the -u parameter") - } - - apiURL = clientCfg.Credentials.URL - } - - // URL needs to end with /, but user doesn't care - if !strings.HasSuffix(apiURL, "/") { - apiURL += "/" - } - - // URL needs to start with http://, but user doesn't care - if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") && !strings.HasPrefix(apiURL, "/") { - apiURL = "http://" + apiURL - } - - return url.Parse(apiURL) -} - -func (cli *cliLapi) newStatusCmd() *cobra.Command { - cmdLapiStatus := &cobra.Command{ - Use: "status", - Short: "Check authentication to Local API (LAPI)", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.status() - }, - } - - return cmdLapiStatus -} - -func (cli *cliLapi) newRegisterCmd() *cobra.Command { - var ( - apiURL string - outputFile string - machine string - token string - ) - - cmd := &cobra.Command{ - Use: "register", - Short: "Register a machine to Local API (LAPI)", - Long: `Register your machine to the Local API (LAPI). -Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`, - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.register(apiURL, outputFile, machine, token) - }, - } - - flags := cmd.Flags() - flags.StringVarP(&apiURL, "url", "u", "", "URL of the API (ie. http://127.0.0.1)") - flags.StringVarP(&outputFile, "file", "f", "", "output file destination") - flags.StringVar(&machine, "machine", "", "Name of the machine to register with") - flags.StringVar(&token, "token", "", "Auto registration token to use") - - return cmd -} - -func (cli *cliLapi) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "lapi [action]", - Short: "Manage interaction with Local API (LAPI)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - if err := cli.cfg().LoadAPIClient(); err != nil { - return fmt.Errorf("loading api client: %w", err) - } - return nil - }, - } - - cmd.AddCommand(cli.newRegisterCmd()) - cmd.AddCommand(cli.newStatusCmd()) - cmd.AddCommand(cli.newContextCmd()) - - return cmd -} - func (cli *cliLapi) addContext(key string, values []string) error { cfg := cli.cfg() @@ -511,14 +272,14 @@ func detectStaticField(grokStatics []parser.ExtraField) []string { for _, static := range grokStatics { if static.Parsed != "" { - fieldName := fmt.Sprintf("evt.Parsed.%s", static.Parsed) + fieldName := "evt.Parsed." + static.Parsed if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } } if static.Meta != "" { - fieldName := fmt.Sprintf("evt.Meta.%s", static.Meta) + fieldName := "evt.Meta." + static.Meta if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } @@ -544,7 +305,7 @@ func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { if node.Grok.RunTimeRegexp != nil { for _, capturedField := range node.Grok.RunTimeRegexp.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) + fieldName := "evt.Parsed." + capturedField if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } @@ -556,7 +317,7 @@ func detectNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { // ignore error (parser does not exist?) if err == nil { for _, capturedField := range grokCompiled.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) + fieldName := "evt.Parsed." + capturedField if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } @@ -591,7 +352,7 @@ func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { for _, subnode := range node.LeavesNodes { if subnode.Grok.RunTimeRegexp != nil { for _, capturedField := range subnode.Grok.RunTimeRegexp.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) + fieldName := "evt.Parsed." + capturedField if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } @@ -603,7 +364,7 @@ func detectSubNode(node parser.Node, parserCTX parser.UnixParserCtx) []string { if err == nil { // ignore error (parser does not exist?) for _, capturedField := range grokCompiled.Names() { - fieldName := fmt.Sprintf("evt.Parsed.%s", capturedField) + fieldName := "evt.Parsed." + capturedField if !slices.Contains(ret, fieldName) { ret = append(ret, fieldName) } diff --git a/cmd/crowdsec-cli/clilapi/lapi.go b/cmd/crowdsec-cli/clilapi/lapi.go new file mode 100644 index 00000000000..01341330ae8 --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/lapi.go @@ -0,0 +1,42 @@ +package clilapi + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/csconfig" +) + +type configGetter = func() *csconfig.Config + +type cliLapi struct { + cfg configGetter +} + +func New(cfg configGetter) *cliLapi { + return &cliLapi{ + cfg: cfg, + } +} + +func (cli *cliLapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "lapi [action]", + Short: "Manage interaction with Local API (LAPI)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + if err := cli.cfg().LoadAPIClient(); err != nil { + return fmt.Errorf("loading api client: %w", err) + } + return nil + }, + } + + cmd.AddCommand(cli.newRegisterCmd()) + cmd.AddCommand(cli.newStatusCmd()) + cmd.AddCommand(cli.newContextCmd()) + + return cmd +} diff --git a/cmd/crowdsec-cli/clilapi/register.go b/cmd/crowdsec-cli/clilapi/register.go new file mode 100644 index 00000000000..4c9b0f39903 --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/register.go @@ -0,0 +1,117 @@ +package clilapi + +import ( + "context" + "fmt" + "os" + + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" +) + +func (cli *cliLapi) register(ctx context.Context, apiURL string, outputFile string, machine string, token string) error { + var err error + + lapiUser := machine + cfg := cli.cfg() + + if lapiUser == "" { + lapiUser, err = idgen.GenerateMachineID("") + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + } + + password := strfmt.Password(idgen.GeneratePassword(idgen.PasswordLength)) + + apiurl, err := prepareAPIURL(cfg.API.Client, apiURL) + if err != nil { + return fmt.Errorf("parsing api url: %w", err) + } + + _, err = apiclient.RegisterClient(ctx, &apiclient.Config{ + MachineID: lapiUser, + Password: password, + RegistrationToken: token, + URL: apiurl, + VersionPrefix: LAPIURLPrefix, + }, nil) + if err != nil { + return fmt.Errorf("api client register: %w", err) + } + + log.Printf("Successfully registered to Local API (LAPI)") + + var dumpFile string + + if outputFile != "" { + dumpFile = outputFile + } else if cfg.API.Client.CredentialsFilePath != "" { + dumpFile = cfg.API.Client.CredentialsFilePath + } else { + dumpFile = "" + } + + apiCfg := cfg.API.Client.Credentials + apiCfg.Login = lapiUser + apiCfg.Password = password.String() + + if apiURL != "" { + apiCfg.URL = apiURL + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to serialize api credentials: %w", err) + } + + if dumpFile != "" { + err = os.WriteFile(dumpFile, apiConfigDump, 0o600) + if err != nil { + return fmt.Errorf("write api credentials to '%s' failed: %w", dumpFile, err) + } + + log.Printf("Local API credentials written to '%s'", dumpFile) + } else { + fmt.Printf("%s\n", string(apiConfigDump)) + } + + log.Warning(reload.Message) + + return nil +} + +func (cli *cliLapi) newRegisterCmd() *cobra.Command { + var ( + apiURL string + outputFile string + machine string + token string + ) + + cmd := &cobra.Command{ + Use: "register", + Short: "Register a machine to Local API (LAPI)", + Long: `Register your machine to the Local API (LAPI). +Keep in mind the machine needs to be validated by an administrator on LAPI side to be effective.`, + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.register(cmd.Context(), apiURL, outputFile, machine, token) + }, + } + + flags := cmd.Flags() + flags.StringVarP(&apiURL, "url", "u", "", "URL of the API (ie. http://127.0.0.1)") + flags.StringVarP(&outputFile, "file", "f", "", "output file destination") + flags.StringVar(&machine, "machine", "", "Name of the machine to register with") + flags.StringVar(&token, "token", "", "Auto registration token to use") + + return cmd +} diff --git a/cmd/crowdsec-cli/clilapi/status.go b/cmd/crowdsec-cli/clilapi/status.go new file mode 100644 index 00000000000..6ff88834602 --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/status.go @@ -0,0 +1,115 @@ +package clilapi + +import ( + "context" + "errors" + "fmt" + "io" + "net/url" + "strings" + + "github.com/fatih/color" + "github.com/go-openapi/strfmt" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/models" +) + +const LAPIURLPrefix = "v1" + +// queryLAPIStatus checks if the Local API is reachable, and if the credentials are correct. +func queryLAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login string, password string) (bool, error) { + apiURL, err := url.Parse(credURL) + if err != nil { + return false, err + } + + client, err := apiclient.NewDefaultClient(apiURL, + LAPIURLPrefix, + "", + nil) + if err != nil { + return false, err + } + + pw := strfmt.Password(password) + + itemsForAPI := hub.GetInstalledListForAPI() + + t := models.WatcherAuthRequest{ + MachineID: &login, + Password: &pw, + Scenarios: itemsForAPI, + } + + _, _, err = client.Auth.AuthenticateWatcher(ctx, t) + if err != nil { + return false, err + } + + return true, nil +} + +func (cli *cliLapi) Status(ctx context.Context, out io.Writer, hub *cwhub.Hub) error { + cfg := cli.cfg() + + cred := cfg.API.Client.Credentials + + fmt.Fprintf(out, "Loaded credentials from %s\n", cfg.API.Client.CredentialsFilePath) + fmt.Fprintf(out, "Trying to authenticate with username %s on %s\n", cred.Login, cred.URL) + + _, err := queryLAPIStatus(ctx, hub, cred.URL, cred.Login, cred.Password) + if err != nil { + return fmt.Errorf("failed to authenticate to Local API (LAPI): %w", err) + } + + fmt.Fprintf(out, "You can successfully interact with Local API (LAPI)\n") + + return nil +} + +// prepareAPIURL checks/fixes a LAPI connection url (http, https or socket) and returns an URL struct +func prepareAPIURL(clientCfg *csconfig.LocalApiClientCfg, apiURL string) (*url.URL, error) { + if apiURL == "" { + if clientCfg == nil || clientCfg.Credentials == nil || clientCfg.Credentials.URL == "" { + return nil, errors.New("no Local API URL. Please provide it in your configuration or with the -u parameter") + } + + apiURL = clientCfg.Credentials.URL + } + + // URL needs to end with /, but user doesn't care + if !strings.HasSuffix(apiURL, "/") { + apiURL += "/" + } + + // URL needs to start with http://, but user doesn't care + if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") && !strings.HasPrefix(apiURL, "/") { + apiURL = "http://" + apiURL + } + + return url.Parse(apiURL) +} + +func (cli *cliLapi) newStatusCmd() *cobra.Command { + cmdLapiStatus := &cobra.Command{ + Use: "status", + Short: "Check authentication to Local API (LAPI)", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + hub, err := require.Hub(cli.cfg(), nil, nil) + if err != nil { + return err + } + + return cli.Status(cmd.Context(), color.Output, hub) + }, + } + + return cmdLapiStatus +} diff --git a/cmd/crowdsec-cli/lapi_test.go b/cmd/crowdsec-cli/clilapi/status_test.go similarity index 98% rename from cmd/crowdsec-cli/lapi_test.go rename to cmd/crowdsec-cli/clilapi/status_test.go index 018ecad8118..caf986d847a 100644 --- a/cmd/crowdsec-cli/lapi_test.go +++ b/cmd/crowdsec-cli/clilapi/status_test.go @@ -1,4 +1,4 @@ -package main +package clilapi import ( "testing" diff --git a/cmd/crowdsec-cli/clilapi/utils.go b/cmd/crowdsec-cli/clilapi/utils.go new file mode 100644 index 00000000000..e3ec65f2145 --- /dev/null +++ b/cmd/crowdsec-cli/clilapi/utils.go @@ -0,0 +1,24 @@ +package clilapi + +func removeFromSlice(val string, slice []string) []string { + var i int + var value string + + valueFound := false + + // get the index + for i, value = range slice { + if value == val { + valueFound = true + break + } + } + + if valueFound { + slice[i] = slice[len(slice)-1] + slice[len(slice)-1] = "" + slice = slice[:len(slice)-1] + } + + return slice +} diff --git a/cmd/crowdsec-cli/climachine/add.go b/cmd/crowdsec-cli/climachine/add.go new file mode 100644 index 00000000000..afddb4e4b65 --- /dev/null +++ b/cmd/crowdsec-cli/climachine/add.go @@ -0,0 +1,152 @@ +package climachine + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/AlecAivazis/survey/v2" + "github.com/go-openapi/strfmt" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func (cli *cliMachines) add(ctx context.Context, args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { + var ( + err error + machineID string + ) + + // create machineID if not specified by user + if len(args) == 0 { + if !autoAdd { + return errors.New("please specify a machine name to add, or use --auto") + } + + machineID, err = idgen.GenerateMachineID("") + if err != nil { + return fmt.Errorf("unable to generate machine id: %w", err) + } + } else { + machineID = args[0] + } + + clientCfg := cli.cfg().API.Client + serverCfg := cli.cfg().API.Server + + /*check if file already exists*/ + if dumpFile == "" && clientCfg != nil && clientCfg.CredentialsFilePath != "" { + credFile := clientCfg.CredentialsFilePath + // use the default only if the file does not exist + _, err = os.Stat(credFile) + + switch { + case os.IsNotExist(err) || force: + dumpFile = credFile + case err != nil: + return fmt.Errorf("unable to stat '%s': %w", credFile, err) + default: + return fmt.Errorf(`credentials file '%s' already exists: please remove it, use "--force" or specify a different file with "-f" ("-f -" for standard output)`, credFile) + } + } + + if dumpFile == "" { + return errors.New(`please specify a file to dump credentials to, with -f ("-f -" for standard output)`) + } + + // create a password if it's not specified by user + if machinePassword == "" && !interactive { + if !autoAdd { + return errors.New("please specify a password with --password or use --auto") + } + + machinePassword = idgen.GeneratePassword(idgen.PasswordLength) + } else if machinePassword == "" && interactive { + qs := &survey.Password{ + Message: "Please provide a password for the machine:", + } + survey.AskOne(qs, &machinePassword) + } + + password := strfmt.Password(machinePassword) + + _, err = cli.db.CreateMachine(ctx, &machineID, &password, "", true, force, types.PasswordAuthType) + if err != nil { + return fmt.Errorf("unable to create machine: %w", err) + } + + fmt.Fprintf(os.Stderr, "Machine '%s' successfully added to the local API.\n", machineID) + + if apiURL == "" { + if clientCfg != nil && clientCfg.Credentials != nil && clientCfg.Credentials.URL != "" { + apiURL = clientCfg.Credentials.URL + } else if serverCfg.ClientURL() != "" { + apiURL = serverCfg.ClientURL() + } else { + return errors.New("unable to dump an api URL. Please provide it in your configuration or with the -u parameter") + } + } + + apiCfg := csconfig.ApiCredentialsCfg{ + Login: machineID, + Password: password.String(), + URL: apiURL, + } + + apiConfigDump, err := yaml.Marshal(apiCfg) + if err != nil { + return fmt.Errorf("unable to serialize api credentials: %w", err) + } + + if dumpFile != "" && dumpFile != "-" { + if err = os.WriteFile(dumpFile, apiConfigDump, 0o600); err != nil { + return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) + } + + fmt.Fprintf(os.Stderr, "API credentials written to '%s'.\n", dumpFile) + } else { + fmt.Print(string(apiConfigDump)) + } + + return nil +} + +func (cli *cliMachines) newAddCmd() *cobra.Command { + var ( + password MachinePassword + dumpFile string + apiURL string + interactive bool + autoAdd bool + force bool + ) + + cmd := &cobra.Command{ + Use: "add", + Short: "add a single machine to the database", + DisableAutoGenTag: true, + Long: `Register a new machine in the database. cscli should be on the same machine as LAPI.`, + Example: `cscli machines add --auto +cscli machines add MyTestMachine --auto +cscli machines add MyTestMachine --password MyPassword +cscli machines add -f- --auto > /tmp/mycreds.yaml`, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.add(cmd.Context(), args, string(password), dumpFile, apiURL, interactive, autoAdd, force) + }, + } + + flags := cmd.Flags() + flags.VarP(&password, "password", "p", "machine password to login to the API") + flags.StringVarP(&dumpFile, "file", "f", "", "output file destination (defaults to "+csconfig.DefaultConfigPath("local_api_credentials.yaml")+")") + flags.StringVarP(&apiURL, "url", "u", "", "URL of the local API") + flags.BoolVarP(&interactive, "interactive", "i", false, "interfactive mode to enter the password") + flags.BoolVarP(&autoAdd, "auto", "a", false, "automatically generate password (and username if not provided)") + flags.BoolVar(&force, "force", false, "will force add the machine if it already exist") + + return cmd +} diff --git a/cmd/crowdsec-cli/climachine/delete.go b/cmd/crowdsec-cli/climachine/delete.go new file mode 100644 index 00000000000..644ce93c642 --- /dev/null +++ b/cmd/crowdsec-cli/climachine/delete.go @@ -0,0 +1,52 @@ +package climachine + +import ( + "context" + "errors" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/pkg/database" +) + +func (cli *cliMachines) delete(ctx context.Context, machines []string, ignoreMissing bool) error { + for _, machineID := range machines { + if err := cli.db.DeleteWatcher(ctx, machineID); err != nil { + var notFoundErr *database.MachineNotFoundError + if ignoreMissing && errors.As(err, ¬FoundErr) { + return nil + } + + log.Errorf("unable to delete machine: %s", err) + + return nil + } + + log.Infof("machine '%s' deleted successfully", machineID) + } + + return nil +} + +func (cli *cliMachines) newDeleteCmd() *cobra.Command { + var ignoreMissing bool + + cmd := &cobra.Command{ + Use: "delete [machine_name]...", + Short: "delete machine(s) by name", + Example: `cscli machines delete "machine1" "machine2"`, + Args: cobra.MinimumNArgs(1), + Aliases: []string{"remove"}, + DisableAutoGenTag: true, + ValidArgsFunction: cli.validMachineID, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.delete(cmd.Context(), args, ignoreMissing) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&ignoreMissing, "ignore-missing", false, "don't print errors if one or more machines don't exist") + + return cmd +} diff --git a/cmd/crowdsec-cli/flag.go b/cmd/crowdsec-cli/climachine/flag.go similarity index 96% rename from cmd/crowdsec-cli/flag.go rename to cmd/crowdsec-cli/climachine/flag.go index 1780d08e5f7..c3fefd896e1 100644 --- a/cmd/crowdsec-cli/flag.go +++ b/cmd/crowdsec-cli/climachine/flag.go @@ -1,4 +1,4 @@ -package main +package climachine // Custom types for flag validation and conversion. diff --git a/cmd/crowdsec-cli/climachine/inspect.go b/cmd/crowdsec-cli/climachine/inspect.go new file mode 100644 index 00000000000..b08f2f62794 --- /dev/null +++ b/cmd/crowdsec-cli/climachine/inspect.go @@ -0,0 +1,184 @@ +package climachine + +import ( + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + + "github.com/fatih/color" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" +) + +func (cli *cliMachines) inspectHubHuman(out io.Writer, machine *ent.Machine) { + state := machine.Hubstate + + if len(state) == 0 { + fmt.Println("No hub items found for this machine") + return + } + + // group state rows by type for multiple tables + rowsByType := make(map[string][]table.Row) + + for itemType, items := range state { + for _, item := range items { + if _, ok := rowsByType[itemType]; !ok { + rowsByType[itemType] = make([]table.Row, 0) + } + + row := table.Row{item.Name, item.Status, item.Version} + rowsByType[itemType] = append(rowsByType[itemType], row) + } + } + + for itemType, rows := range rowsByType { + t := cstable.New(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "Status", "Version"}) + t.SetTitle(itemType) + t.AppendRows(rows) + io.WriteString(out, t.Render()+"\n") + } +} + +func (cli *cliMachines) inspectHuman(out io.Writer, machine *ent.Machine) { + t := cstable.New(out, cli.cfg().Cscli.Color).Writer + + t.SetTitle("Machine: " + machine.MachineId) + + t.SetColumnConfigs([]table.ColumnConfig{ + {Number: 1, AutoMerge: true}, + }) + + t.AppendRows([]table.Row{ + {"IP Address", machine.IpAddress}, + {"Created At", machine.CreatedAt}, + {"Last Update", machine.UpdatedAt}, + {"Last Heartbeat", machine.LastHeartbeat}, + {"Validated?", machine.IsValidated}, + {"CrowdSec version", machine.Version}, + {"OS", clientinfo.GetOSNameAndVersion(machine)}, + {"Auth type", machine.AuthType}, + }) + + for dsName, dsCount := range machine.Datasources { + t.AppendRow(table.Row{"Datasources", fmt.Sprintf("%s: %d", dsName, dsCount)}) + } + + for _, ff := range clientinfo.GetFeatureFlagList(machine) { + t.AppendRow(table.Row{"Feature Flags", ff}) + } + + for _, coll := range machine.Hubstate[cwhub.COLLECTIONS] { + t.AppendRow(table.Row{"Collections", coll.Name}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliMachines) inspect(machine *ent.Machine) error { + out := color.Output + outputFormat := cli.cfg().Cscli.Output + + switch outputFormat { + case "human": + cli.inspectHuman(out, machine) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(newMachineInfo(machine)); err != nil { + return errors.New("failed to serialize") + } + + return nil + default: + return fmt.Errorf("output format '%s' not supported for this command", outputFormat) + } + + return nil +} + +func (cli *cliMachines) inspectHub(machine *ent.Machine) error { + out := color.Output + + switch cli.cfg().Cscli.Output { + case "human": + cli.inspectHubHuman(out, machine) + case "json": + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(machine.Hubstate); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + csvwriter := csv.NewWriter(out) + + err := csvwriter.Write([]string{"type", "name", "status", "version"}) + if err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + rows := make([][]string, 0) + + for itemType, items := range machine.Hubstate { + for _, item := range items { + rows = append(rows, []string{itemType, item.Name, item.Status, item.Version}) + } + } + + for _, row := range rows { + if err := csvwriter.Write(row); err != nil { + return fmt.Errorf("failed to write raw output: %w", err) + } + } + + csvwriter.Flush() + } + + return nil +} + +func (cli *cliMachines) newInspectCmd() *cobra.Command { + var showHub bool + + cmd := &cobra.Command{ + Use: "inspect [machine_name]", + Short: "inspect a machine by name", + Example: `cscli machines inspect "machine1"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + ValidArgsFunction: cli.validMachineID, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + machineID := args[0] + + machine, err := cli.db.QueryMachineByID(ctx, machineID) + if err != nil { + return fmt.Errorf("unable to read machine data '%s': %w", machineID, err) + } + + if showHub { + return cli.inspectHub(machine) + } + + return cli.inspect(machine) + }, + } + + flags := cmd.Flags() + + flags.BoolVarP(&showHub, "hub", "H", false, "show hub state") + + return cmd +} diff --git a/cmd/crowdsec-cli/climachine/list.go b/cmd/crowdsec-cli/climachine/list.go new file mode 100644 index 00000000000..6bedb2ad807 --- /dev/null +++ b/cmd/crowdsec-cli/climachine/list.go @@ -0,0 +1,137 @@ +package climachine + +import ( + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "time" + + "github.com/fatih/color" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/emoji" +) + +// getLastHeartbeat returns the last heartbeat timestamp of a machine +// and a boolean indicating if the machine is considered active or not. +func getLastHeartbeat(m *ent.Machine) (string, bool) { + if m.LastHeartbeat == nil { + return "-", false + } + + elapsed := time.Now().UTC().Sub(*m.LastHeartbeat) + + hb := elapsed.Truncate(time.Second).String() + if elapsed > 2*time.Minute { + return hb, false + } + + return hb, true +} + +func (cli *cliMachines) listHuman(out io.Writer, machines ent.Machines) { + t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer + t.AppendHeader(table.Row{"Name", "IP Address", "Last Update", "Status", "Version", "OS", "Auth Type", "Last Heartbeat"}) + + for _, m := range machines { + validated := emoji.Prohibited + if m.IsValidated { + validated = emoji.CheckMark + } + + hb, active := getLastHeartbeat(m) + if !active { + hb = emoji.Warning + " " + hb + } + + t.AppendRow(table.Row{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, clientinfo.GetOSNameAndVersion(m), m.AuthType, hb}) + } + + io.WriteString(out, t.Render()+"\n") +} + +func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error { + csvwriter := csv.NewWriter(out) + + err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat", "os"}) + if err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + for _, m := range machines { + validated := "false" + if m.IsValidated { + validated = "true" + } + + hb := "-" + if m.LastHeartbeat != nil { + hb = m.LastHeartbeat.Format(time.RFC3339) + } + + if err := csvwriter.Write([]string{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb, fmt.Sprintf("%s/%s", m.Osname, m.Osversion)}); err != nil { + return fmt.Errorf("failed to write raw output: %w", err) + } + } + + csvwriter.Flush() + + return nil +} + +func (cli *cliMachines) List(ctx context.Context, out io.Writer, db *database.Client) error { + // XXX: must use the provided db object, the one in the struct might be nil + // (calling List directly skips the PersistentPreRunE) + + machines, err := db.ListMachines(ctx) + if err != nil { + return fmt.Errorf("unable to list machines: %w", err) + } + + switch cli.cfg().Cscli.Output { + case "human": + cli.listHuman(out, machines) + case "json": + info := make([]machineInfo, 0, len(machines)) + for _, m := range machines { + info = append(info, newMachineInfo(m)) + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + + if err := enc.Encode(info); err != nil { + return errors.New("failed to serialize") + } + + return nil + case "raw": + return cli.listCSV(out, machines) + } + + return nil +} + +func (cli *cliMachines) newListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "list all machines in the database", + Long: `list all machines in the database with their status and last heartbeat`, + Example: `cscli machines list`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.List(cmd.Context(), color.Output, cli.db) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/climachine/machines.go b/cmd/crowdsec-cli/climachine/machines.go new file mode 100644 index 00000000000..ad503c6e936 --- /dev/null +++ b/cmd/crowdsec-cli/climachine/machines.go @@ -0,0 +1,132 @@ +package climachine + +import ( + "slices" + "strings" + "time" + + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clientinfo" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" +) + +type configGetter = func() *csconfig.Config + +type cliMachines struct { + db *database.Client + cfg configGetter +} + +func New(cfg configGetter) *cliMachines { + return &cliMachines{ + cfg: cfg, + } +} + +func (cli *cliMachines) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "machines [action]", + Short: "Manage local API machines [requires local API]", + Long: `To list/add/delete/validate/prune machines. +Note: This command requires database direct access, so is intended to be run on the local API machine. +`, + Example: `cscli machines [action]`, + DisableAutoGenTag: true, + Aliases: []string{"machine"}, + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + var err error + if err = require.LAPI(cli.cfg()); err != nil { + return err + } + cli.db, err = require.DBClient(cmd.Context(), cli.cfg().DbConfig) + if err != nil { + return err + } + + return nil + }, + } + + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newAddCmd()) + cmd.AddCommand(cli.newDeleteCmd()) + cmd.AddCommand(cli.newValidateCmd()) + cmd.AddCommand(cli.newPruneCmd()) + cmd.AddCommand(cli.newInspectCmd()) + + return cmd +} + +// machineInfo contains only the data we want for inspect/list: no hub status, scenarios, edges, etc. +type machineInfo struct { + CreatedAt time.Time `json:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + LastPush *time.Time `json:"last_push,omitempty"` + LastHeartbeat *time.Time `json:"last_heartbeat,omitempty"` + MachineId string `json:"machineId,omitempty"` + IpAddress string `json:"ipAddress,omitempty"` + Version string `json:"version,omitempty"` + IsValidated bool `json:"isValidated,omitempty"` + AuthType string `json:"auth_type"` + OS string `json:"os,omitempty"` + Featureflags []string `json:"featureflags,omitempty"` + Datasources map[string]int64 `json:"datasources,omitempty"` +} + +func newMachineInfo(m *ent.Machine) machineInfo { + return machineInfo{ + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + LastPush: m.LastPush, + LastHeartbeat: m.LastHeartbeat, + MachineId: m.MachineId, + IpAddress: m.IpAddress, + Version: m.Version, + IsValidated: m.IsValidated, + AuthType: m.AuthType, + OS: clientinfo.GetOSNameAndVersion(m), + Featureflags: clientinfo.GetFeatureFlagList(m), + Datasources: m.Datasources, + } +} + +// validMachineID returns a list of machine IDs for command completion +func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + var err error + + cfg := cli.cfg() + ctx := cmd.Context() + + // need to load config and db because PersistentPreRunE is not called for completions + + if err = require.LAPI(cfg); err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + cli.db, err = require.DBClient(ctx, cfg.DbConfig) + if err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + machines, err := cli.db.ListMachines(ctx) + if err != nil { + cobra.CompError("unable to list machines " + err.Error()) + return nil, cobra.ShellCompDirectiveNoFileComp + } + + ret := []string{} + + for _, machine := range machines { + if strings.Contains(machine.MachineId, toComplete) && !slices.Contains(args, machine.MachineId) { + ret = append(ret, machine.MachineId) + } + } + + return ret, cobra.ShellCompDirectiveNoFileComp +} diff --git a/cmd/crowdsec-cli/climachine/prune.go b/cmd/crowdsec-cli/climachine/prune.go new file mode 100644 index 00000000000..ed41ef0a736 --- /dev/null +++ b/cmd/crowdsec-cli/climachine/prune.go @@ -0,0 +1,96 @@ +package climachine + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/fatih/color" + "github.com/spf13/cobra" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/ask" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" +) + +func (cli *cliMachines) prune(ctx context.Context, duration time.Duration, notValidOnly bool, force bool) error { + if duration < 2*time.Minute && !notValidOnly { + if yes, err := ask.YesNo( + "The duration you provided is less than 2 minutes. "+ + "This can break installations if the machines are only temporarily disconnected. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + machines := []*ent.Machine{} + if pending, err := cli.db.QueryPendingMachine(ctx); err == nil { + machines = append(machines, pending...) + } + + if !notValidOnly { + if pending, err := cli.db.QueryMachinesInactiveSince(ctx, time.Now().UTC().Add(-duration)); err == nil { + machines = append(machines, pending...) + } + } + + if len(machines) == 0 { + fmt.Println("No machines to prune.") + return nil + } + + cli.listHuman(color.Output, machines) + + if !force { + if yes, err := ask.YesNo( + "You are about to PERMANENTLY remove the above machines from the database. "+ + "These will NOT be recoverable. Continue?", false); err != nil { + return err + } else if !yes { + fmt.Println("User aborted prune. No changes were made.") + return nil + } + } + + deleted, err := cli.db.BulkDeleteWatchers(ctx, machines) + if err != nil { + return fmt.Errorf("unable to prune machines: %w", err) + } + + fmt.Fprintf(os.Stderr, "successfully deleted %d machines\n", deleted) + + return nil +} + +func (cli *cliMachines) newPruneCmd() *cobra.Command { + var ( + duration time.Duration + notValidOnly bool + force bool + ) + + const defaultDuration = 10 * time.Minute + + cmd := &cobra.Command{ + Use: "prune", + Short: "prune multiple machines from the database", + Long: `prune multiple machines that are not validated or have not connected to the local API in a given duration.`, + Example: `cscli machines prune +cscli machines prune --duration 1h +cscli machines prune --not-validated-only --force`, + Args: cobra.NoArgs, + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + return cli.prune(cmd.Context(), duration, notValidOnly, force) + }, + } + + flags := cmd.Flags() + flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since validated machine last heartbeat") + flags.BoolVar(¬ValidOnly, "not-validated-only", false, "only prune machines that are not validated") + flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") + + return cmd +} diff --git a/cmd/crowdsec-cli/climachine/validate.go b/cmd/crowdsec-cli/climachine/validate.go new file mode 100644 index 00000000000..cba872aa05d --- /dev/null +++ b/cmd/crowdsec-cli/climachine/validate.go @@ -0,0 +1,35 @@ +package climachine + +import ( + "context" + "fmt" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +func (cli *cliMachines) validate(ctx context.Context, machineID string) error { + if err := cli.db.ValidateMachine(ctx, machineID); err != nil { + return fmt.Errorf("unable to validate machine '%s': %w", machineID, err) + } + + log.Infof("machine '%s' validated successfully", machineID) + + return nil +} + +func (cli *cliMachines) newValidateCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "validate", + Short: "validate a machine to access the local API", + Long: `validate a machine to access the local API.`, + Example: `cscli machines validate "machine_name"`, + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(cmd.Context(), args[0]) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/climetrics/list.go b/cmd/crowdsec-cli/climetrics/list.go index d3afbef0669..27fa99710c8 100644 --- a/cmd/crowdsec-cli/climetrics/list.go +++ b/cmd/crowdsec-cli/climetrics/list.go @@ -64,11 +64,11 @@ func (cli *cliMetrics) list() error { t.AppendRow(table.Row{metric.Type, metric.Title, metric.Description}) } - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, t.Render()+"\n") case "json": x, err := json.MarshalIndent(allMetrics, "", " ") if err != nil { - return fmt.Errorf("failed to marshal metric types: %w", err) + return fmt.Errorf("failed to serialize metric types: %w", err) } fmt.Println(string(x)) @@ -84,7 +84,7 @@ func (cli *cliMetrics) newListCmd() *cobra.Command { Use: "list", Short: "List available types of metrics.", Long: `List available types of metrics.`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { return cli.list() diff --git a/cmd/crowdsec-cli/climetrics/metrics.go b/cmd/crowdsec-cli/climetrics/metrics.go index f3bc4874460..67bd7b6ad93 100644 --- a/cmd/crowdsec-cli/climetrics/metrics.go +++ b/cmd/crowdsec-cli/climetrics/metrics.go @@ -36,7 +36,7 @@ cscli metrics --url http://lapi.local:6060/metrics show acquisition parsers # List available metric types cscli metrics list`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(cmd *cobra.Command, _ []string) error { return cli.show(cmd.Context(), nil, url, noUnit) diff --git a/cmd/crowdsec-cli/climetrics/show.go b/cmd/crowdsec-cli/climetrics/show.go index 7559463b66b..045959048f6 100644 --- a/cmd/crowdsec-cli/climetrics/show.go +++ b/cmd/crowdsec-cli/climetrics/show.go @@ -5,9 +5,8 @@ import ( "errors" "fmt" - log "github.com/sirupsen/logrus" - "github.com/fatih/color" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" diff --git a/cmd/crowdsec-cli/climetrics/statacquis.go b/cmd/crowdsec-cli/climetrics/statacquis.go index 827dcf036c3..0af2e796f40 100644 --- a/cmd/crowdsec-cli/climetrics/statacquis.go +++ b/cmd/crowdsec-cli/climetrics/statacquis.go @@ -37,8 +37,8 @@ func (s statAcquis) Table(out io.Writer, wantColor string, noUnit bool, showEmpt log.Warningf("while collecting acquis stats: %s", err) } else if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statalert.go b/cmd/crowdsec-cli/climetrics/statalert.go index e48dd6c924f..942eceaa75c 100644 --- a/cmd/crowdsec-cli/climetrics/statalert.go +++ b/cmd/crowdsec-cli/climetrics/statalert.go @@ -38,8 +38,8 @@ func (s statAlert) Table(out io.Writer, wantColor string, noUnit bool, showEmpty if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statappsecengine.go b/cmd/crowdsec-cli/climetrics/statappsecengine.go index 4a249e11687..d924375247f 100644 --- a/cmd/crowdsec-cli/climetrics/statappsecengine.go +++ b/cmd/crowdsec-cli/climetrics/statappsecengine.go @@ -34,8 +34,8 @@ func (s statAppsecEngine) Table(out io.Writer, wantColor string, noUnit bool, sh log.Warningf("while collecting appsec stats: %s", err) } else if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statappsecrule.go b/cmd/crowdsec-cli/climetrics/statappsecrule.go index 2f859d70cfb..e06a7c2e2b3 100644 --- a/cmd/crowdsec-cli/climetrics/statappsecrule.go +++ b/cmd/crowdsec-cli/climetrics/statappsecrule.go @@ -41,7 +41,7 @@ func (s statAppsecRule) Table(out io.Writer, wantColor string, noUnit bool, show log.Warningf("while collecting appsec rules stats: %s", err) } else if numRows > 0 || showEmpty { io.WriteString(out, fmt.Sprintf("Appsec '%s' Rules Metrics:\n", appsecEngine)) - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statbouncer.go b/cmd/crowdsec-cli/climetrics/statbouncer.go index 62e68b6bc41..bc0da152d6d 100644 --- a/cmd/crowdsec-cli/climetrics/statbouncer.go +++ b/cmd/crowdsec-cli/climetrics/statbouncer.go @@ -129,7 +129,7 @@ func (*statBouncer) Description() (string, string) { func logWarningOnce(warningsLogged map[string]bool, msg string) { if _, ok := warningsLogged[msg]; !ok { - log.Warningf(msg) + log.Warning(msg) warningsLogged[msg] = true } diff --git a/cmd/crowdsec-cli/climetrics/statbucket.go b/cmd/crowdsec-cli/climetrics/statbucket.go index 507d9f3a476..1882fe21df1 100644 --- a/cmd/crowdsec-cli/climetrics/statbucket.go +++ b/cmd/crowdsec-cli/climetrics/statbucket.go @@ -35,8 +35,8 @@ func (s statBucket) Table(out io.Writer, wantColor string, noUnit bool, showEmpt log.Warningf("while collecting scenario stats: %s", err) } else if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statdecision.go b/cmd/crowdsec-cli/climetrics/statdecision.go index 145665cfba2..b862f49ff12 100644 --- a/cmd/crowdsec-cli/climetrics/statdecision.go +++ b/cmd/crowdsec-cli/climetrics/statdecision.go @@ -53,8 +53,8 @@ func (s statDecision) Table(out io.Writer, wantColor string, noUnit bool, showEm if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statlapi.go b/cmd/crowdsec-cli/climetrics/statlapi.go index 45b384708bf..9559eacf0f4 100644 --- a/cmd/crowdsec-cli/climetrics/statlapi.go +++ b/cmd/crowdsec-cli/climetrics/statlapi.go @@ -49,8 +49,8 @@ func (s statLapi) Table(out io.Writer, wantColor string, noUnit bool, showEmpty if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statlapibouncer.go b/cmd/crowdsec-cli/climetrics/statlapibouncer.go index 828ccb33413..5e5f63a79d3 100644 --- a/cmd/crowdsec-cli/climetrics/statlapibouncer.go +++ b/cmd/crowdsec-cli/climetrics/statlapibouncer.go @@ -35,8 +35,8 @@ func (s statLapiBouncer) Table(out io.Writer, wantColor string, noUnit bool, sho if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statlapidecision.go b/cmd/crowdsec-cli/climetrics/statlapidecision.go index ffc999555c1..44f0e8f4b87 100644 --- a/cmd/crowdsec-cli/climetrics/statlapidecision.go +++ b/cmd/crowdsec-cli/climetrics/statlapidecision.go @@ -57,8 +57,8 @@ func (s statLapiDecision) Table(out io.Writer, wantColor string, noUnit bool, sh if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statlapimachine.go b/cmd/crowdsec-cli/climetrics/statlapimachine.go index 09abe2dd44b..0e6693bea82 100644 --- a/cmd/crowdsec-cli/climetrics/statlapimachine.go +++ b/cmd/crowdsec-cli/climetrics/statlapimachine.go @@ -35,8 +35,8 @@ func (s statLapiMachine) Table(out io.Writer, wantColor string, noUnit bool, sho if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statparser.go b/cmd/crowdsec-cli/climetrics/statparser.go index 0b3512052b9..520e68f9adf 100644 --- a/cmd/crowdsec-cli/climetrics/statparser.go +++ b/cmd/crowdsec-cli/climetrics/statparser.go @@ -36,8 +36,8 @@ func (s statParser) Table(out io.Writer, wantColor string, noUnit bool, showEmpt log.Warningf("while collecting parsers stats: %s", err) } else if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statstash.go b/cmd/crowdsec-cli/climetrics/statstash.go index 5938ac05fc8..2729de931a1 100644 --- a/cmd/crowdsec-cli/climetrics/statstash.go +++ b/cmd/crowdsec-cli/climetrics/statstash.go @@ -52,8 +52,8 @@ func (s statStash) Table(out io.Writer, wantColor string, noUnit bool, showEmpty if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/statwhitelist.go b/cmd/crowdsec-cli/climetrics/statwhitelist.go index ccb7e52153b..7f533b45b4b 100644 --- a/cmd/crowdsec-cli/climetrics/statwhitelist.go +++ b/cmd/crowdsec-cli/climetrics/statwhitelist.go @@ -36,8 +36,8 @@ func (s statWhitelist) Table(out io.Writer, wantColor string, noUnit bool, showE log.Warningf("while collecting parsers stats: %s", err) } else if numRows > 0 || showEmpty { title, _ := s.Description() - io.WriteString(out, title + ":\n") - io.WriteString(out, t.Render() + "\n") + io.WriteString(out, title+":\n") + io.WriteString(out, t.Render()+"\n") io.WriteString(out, "\n") } } diff --git a/cmd/crowdsec-cli/climetrics/store.go b/cmd/crowdsec-cli/climetrics/store.go index 5de50558e89..55fab5dbd7f 100644 --- a/cmd/crowdsec-cli/climetrics/store.go +++ b/cmd/crowdsec-cli/climetrics/store.go @@ -260,7 +260,7 @@ func (ms metricStore) Format(out io.Writer, wantColor string, sections []string, case "json": x, err := json.MarshalIndent(want, "", " ") if err != nil { - return fmt.Errorf("failed to marshal metrics: %w", err) + return fmt.Errorf("failed to serialize metrics: %w", err) } out.Write(x) default: diff --git a/cmd/crowdsec-cli/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go similarity index 89% rename from cmd/crowdsec-cli/notifications.go rename to cmd/crowdsec-cli/clinotifications/notifications.go index 8c6b6631b33..baf899c10cf 100644 --- a/cmd/crowdsec-cli/notifications.go +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -1,4 +1,4 @@ -package main +package clinotifications import ( "context" @@ -29,7 +29,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csprofiles" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -40,11 +39,13 @@ type NotificationsCfg struct { ids []uint } +type configGetter func() *csconfig.Config + type cliNotifications struct { cfg configGetter } -func NewCLINotifications(cfg configGetter) *cliNotifications { +func New(cfg configGetter) *cliNotifications { return &cliNotifications{ cfg: cfg, } @@ -71,10 +72,10 @@ func (cli *cliNotifications) NewCommand() *cobra.Command { }, } - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewInspectCmd()) - cmd.AddCommand(cli.NewReinjectCmd()) - cmd.AddCommand(cli.NewTestCmd()) + cmd.AddCommand(cli.newListCmd()) + cmd.AddCommand(cli.newInspectCmd()) + cmd.AddCommand(cli.newReinjectCmd()) + cmd.AddCommand(cli.newTestCmd()) return cmd } @@ -151,13 +152,13 @@ func (cli *cliNotifications) getProfilesConfigs() (map[string]NotificationsCfg, return ncfgs, nil } -func (cli *cliNotifications) NewListCmd() *cobra.Command { +func (cli *cliNotifications) newListCmd() *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "list notifications plugins", Long: `list notifications plugins and their status (active or not)`, Example: `cscli notifications list`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { cfg := cli.cfg() @@ -171,7 +172,7 @@ func (cli *cliNotifications) NewListCmd() *cobra.Command { } else if cfg.Cscli.Output == "json" { x, err := json.MarshalIndent(ncfgs, "", " ") if err != nil { - return fmt.Errorf("failed to marshal notification configuration: %w", err) + return fmt.Errorf("failed to serialize notification configuration: %w", err) } fmt.Printf("%s", string(x)) } else if cfg.Cscli.Output == "raw" { @@ -200,7 +201,7 @@ func (cli *cliNotifications) NewListCmd() *cobra.Command { return cmd } -func (cli *cliNotifications) NewInspectCmd() *cobra.Command { +func (cli *cliNotifications) newInspectCmd() *cobra.Command { cmd := &cobra.Command{ Use: "inspect", Short: "Inspect notifications plugin", @@ -230,7 +231,7 @@ func (cli *cliNotifications) NewInspectCmd() *cobra.Command { } else if cfg.Cscli.Output == "json" { x, err := json.MarshalIndent(cfg, "", " ") if err != nil { - return fmt.Errorf("failed to marshal notification configuration: %w", err) + return fmt.Errorf("failed to serialize notification configuration: %w", err) } fmt.Printf("%s", string(x)) } @@ -259,7 +260,7 @@ func (cli *cliNotifications) notificationConfigFilter(cmd *cobra.Command, args [ return ret, cobra.ShellCompDirectiveNoFileComp } -func (cli cliNotifications) NewTestCmd() *cobra.Command { +func (cli cliNotifications) newTestCmd() *cobra.Command { var ( pluginBroker csplugin.PluginBroker pluginTomb tomb.Tomb @@ -274,7 +275,8 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { Args: cobra.ExactArgs(1), DisableAutoGenTag: true, ValidArgsFunction: cli.notificationConfigFilter, - PreRunE: func(_ *cobra.Command, args []string) error { + PreRunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() cfg := cli.cfg() pconfigs, err := cli.getPluginConfigs() if err != nil { @@ -285,7 +287,7 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { return fmt.Errorf("plugin name: '%s' does not exist", args[0]) } // Create a single profile with plugin name as notification name - return pluginBroker.Init(cfg.PluginConfig, []*csconfig.ProfileCfg{ + return pluginBroker.Init(ctx, cfg.PluginConfig, []*csconfig.ProfileCfg{ { Notifications: []string{ pcfg.Name, @@ -330,7 +332,7 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { CreatedAt: time.Now().UTC().Format(time.RFC3339), } if err := yaml.Unmarshal([]byte(alertOverride), alert); err != nil { - return fmt.Errorf("failed to unmarshal alert override: %w", err) + return fmt.Errorf("failed to parse alert override: %w", err) } pluginBroker.PluginChannel <- csplugin.ProfileAlert{ @@ -350,7 +352,7 @@ func (cli cliNotifications) NewTestCmd() *cobra.Command { return cmd } -func (cli *cliNotifications) NewReinjectCmd() *cobra.Command { +func (cli *cliNotifications) newReinjectCmd() *cobra.Command { var ( alertOverride string alert *models.Alert @@ -367,30 +369,31 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not `, Args: cobra.ExactArgs(1), DisableAutoGenTag: true, - PreRunE: func(_ *cobra.Command, args []string) error { + PreRunE: func(cmd *cobra.Command, args []string) error { var err error - alert, err = cli.fetchAlertFromArgString(args[0]) + alert, err = cli.fetchAlertFromArgString(cmd.Context(), args[0]) if err != nil { return err } return nil }, - RunE: func(_ *cobra.Command, _ []string) error { + RunE: func(cmd *cobra.Command, _ []string) error { var ( pluginBroker csplugin.PluginBroker pluginTomb tomb.Tomb ) + ctx := cmd.Context() cfg := cli.cfg() if alertOverride != "" { if err := json.Unmarshal([]byte(alertOverride), alert); err != nil { - return fmt.Errorf("can't unmarshal data in the alert flag: %w", err) + return fmt.Errorf("can't parse data in the alert flag: %w", err) } } - err := pluginBroker.Init(cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths) + err := pluginBroker.Init(ctx, cfg.PluginConfig, cfg.API.Server.Profiles, cfg.ConfigPaths) if err != nil { return fmt.Errorf("can't initialize plugins: %w", err) } @@ -446,7 +449,7 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return cmd } -func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Alert, error) { +func (cli *cliNotifications) fetchAlertFromArgString(ctx context.Context, toParse string) (*models.Alert, error) { cfg := cli.cfg() id, err := strconv.Atoi(toParse) @@ -462,7 +465,6 @@ func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Al client, err := apiclient.NewClient(&apiclient.Config{ MachineID: cfg.API.Client.Credentials.Login, Password: strfmt.Password(cfg.API.Client.Credentials.Password), - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -470,7 +472,7 @@ func (cli *cliNotifications) fetchAlertFromArgString(toParse string) (*models.Al return nil, fmt.Errorf("error creating the client for the API: %w", err) } - alert, _, err := client.Alerts.GetByID(context.Background(), id) + alert, _, err := client.Alerts.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("can't find alert with id %d: %w", id, err) } diff --git a/cmd/crowdsec-cli/notifications_table.go b/cmd/crowdsec-cli/clinotifications/notifications_table.go similarity index 97% rename from cmd/crowdsec-cli/notifications_table.go rename to cmd/crowdsec-cli/clinotifications/notifications_table.go index 2976797bd8a..0b6a3f58efc 100644 --- a/cmd/crowdsec-cli/notifications_table.go +++ b/cmd/crowdsec-cli/clinotifications/notifications_table.go @@ -1,4 +1,4 @@ -package main +package clinotifications import ( "io" diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go new file mode 100644 index 00000000000..7ac2455d28f --- /dev/null +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -0,0 +1,174 @@ +package clipapi + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/fatih/color" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/tomb.v2" + + "github.com/crowdsecurity/go-cs-lib/ptr" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/apiserver" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/database" +) + +type configGetter = func() *csconfig.Config + +type cliPapi struct { + cfg configGetter +} + +func New(cfg configGetter) *cliPapi { + return &cliPapi{ + cfg: cfg, + } +} + +func (cli *cliPapi) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "papi [action]", + Short: "Manage interaction with Polling API (PAPI)", + Args: cobra.MinimumNArgs(1), + DisableAutoGenTag: true, + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { + cfg := cli.cfg() + if err := require.LAPI(cfg); err != nil { + return err + } + if err := require.CAPI(cfg); err != nil { + return err + } + + return require.PAPI(cfg) + }, + } + + cmd.AddCommand(cli.newStatusCmd()) + cmd.AddCommand(cli.newSyncCmd()) + + return cmd +} + +func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Client) error { + cfg := cli.cfg() + + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + if err != nil { + return fmt.Errorf("unable to initialize API client: %w", err) + } + + papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) + if err != nil { + return fmt.Errorf("unable to initialize PAPI client: %w", err) + } + + perms, err := papi.GetPermissions(ctx) + if err != nil { + return fmt.Errorf("unable to get PAPI permissions: %w", err) + } + + lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey) + if err != nil { + lastTimestampStr = ptr.Of("never") + } + + // both can and did happen + if lastTimestampStr == nil || *lastTimestampStr == "0001-01-01T00:00:00Z" { + lastTimestampStr = ptr.Of("never") + } + + fmt.Fprint(out, "You can successfully interact with Polling API (PAPI)\n") + fmt.Fprintf(out, "Console plan: %s\n", perms.Plan) + fmt.Fprintf(out, "Last order received: %s\n", *lastTimestampStr) + fmt.Fprint(out, "PAPI subscriptions:\n") + + for _, sub := range perms.Categories { + fmt.Fprintf(out, " - %s\n", sub) + } + + return nil +} + +func (cli *cliPapi) newStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status", + Short: "Get status of the Polling API", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg := cli.cfg() + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + return err + } + + return cli.Status(ctx, color.Output, db) + }, + } + + return cmd +} + +func (cli *cliPapi) sync(ctx context.Context, out io.Writer, db *database.Client) error { + cfg := cli.cfg() + t := tomb.Tomb{} + + apic, err := apiserver.NewAPIC(ctx, cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) + if err != nil { + return fmt.Errorf("unable to initialize API client: %w", err) + } + + t.Go(func() error { return apic.Push(ctx) }) + + papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) + if err != nil { + return fmt.Errorf("unable to initialize PAPI client: %w", err) + } + + t.Go(papi.SyncDecisions) + + err = papi.PullOnce(ctx, time.Time{}, true) + if err != nil { + return fmt.Errorf("unable to sync decisions: %w", err) + } + + log.Infof("Sending acknowledgements to CAPI") + + apic.Shutdown() + papi.Shutdown() + t.Wait() + time.Sleep(5 * time.Second) // FIXME: the push done by apic.Push is run inside a sub goroutine, sleep to make sure it's done + + return nil +} + +func (cli *cliPapi) newSyncCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "sync", + Short: "Sync with the Polling API, pulling all non-expired orders for the instance", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg := cli.cfg() + ctx := cmd.Context() + + db, err := require.DBClient(ctx, cfg.DbConfig) + if err != nil { + return err + } + + return cli.sync(ctx, color.Output, db) + }, + } + + return cmd +} diff --git a/cmd/crowdsec-cli/clisetup/setup.go b/cmd/crowdsec-cli/clisetup/setup.go new file mode 100644 index 00000000000..269cdfb78e9 --- /dev/null +++ b/cmd/crowdsec-cli/clisetup/setup.go @@ -0,0 +1,307 @@ +package clisetup + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "os/exec" + + goccyyaml "github.com/goccy/go-yaml" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/setup" +) + +type configGetter func() *csconfig.Config + +type cliSetup struct { + cfg configGetter +} + +func New(cfg configGetter) *cliSetup { + return &cliSetup{ + cfg: cfg, + } +} + +func (cli *cliSetup) NewCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "setup", + Short: "Tools to configure crowdsec", + Long: "Manage hub configuration and service detection", + Args: cobra.MinimumNArgs(0), + DisableAutoGenTag: true, + } + + cmd.AddCommand(cli.newDetectCmd()) + cmd.AddCommand(cli.newInstallHubCmd()) + cmd.AddCommand(cli.newDataSourcesCmd()) + cmd.AddCommand(cli.newValidateCmd()) + + return cmd +} + +type detectFlags struct { + detectConfigFile string + listSupportedServices bool + forcedUnits []string + forcedProcesses []string + forcedOSFamily string + forcedOSID string + forcedOSVersion string + skipServices []string + snubSystemd bool + outYaml bool +} + +func (f *detectFlags) bind(cmd *cobra.Command) { + defaultServiceDetect := csconfig.DefaultConfigPath("hub", "detect.yaml") + + flags := cmd.Flags() + flags.StringVar(&f.detectConfigFile, "detect-config", defaultServiceDetect, "path to service detection configuration") + flags.BoolVar(&f.listSupportedServices, "list-supported-services", false, "do not detect; only print supported services") + flags.StringSliceVar(&f.forcedUnits, "force-unit", nil, "force detection of a systemd unit (can be repeated)") + flags.StringSliceVar(&f.forcedProcesses, "force-process", nil, "force detection of a running process (can be repeated)") + flags.StringSliceVar(&f.skipServices, "skip-service", nil, "ignore a service, don't recommend hub/datasources (can be repeated)") + flags.StringVar(&f.forcedOSFamily, "force-os-family", "", "override OS.Family: one of linux, freebsd, windows or darwin") + flags.StringVar(&f.forcedOSID, "force-os-id", "", "override OS.ID=[debian | ubuntu | , redhat...]") + flags.StringVar(&f.forcedOSVersion, "force-os-version", "", "override OS.RawVersion (of OS or Linux distribution)") + flags.BoolVar(&f.snubSystemd, "snub-systemd", false, "don't use systemd, even if available") + flags.BoolVar(&f.outYaml, "yaml", false, "output yaml, not json") +} + +func (cli *cliSetup) newDetectCmd() *cobra.Command { + f := detectFlags{} + + cmd := &cobra.Command{ + Use: "detect", + Short: "detect running services, generate a setup file", + DisableAutoGenTag: true, + RunE: func(_ *cobra.Command, args []string) error { + return cli.detect(f) + }, + } + + f.bind(cmd) + + return cmd +} + +func (cli *cliSetup) newInstallHubCmd() *cobra.Command { + var dryRun bool + + cmd := &cobra.Command{ + Use: "install-hub [setup_file] [flags]", + Short: "install items from a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.install(cmd.Context(), dryRun, args[0]) + }, + } + + flags := cmd.Flags() + flags.BoolVar(&dryRun, "dry-run", false, "don't install anything; print out what would have been") + + return cmd +} + +func (cli *cliSetup) newDataSourcesCmd() *cobra.Command { + var toDir string + + cmd := &cobra.Command{ + Use: "datasources [setup_file] [flags]", + Short: "generate datasource (acquisition) configuration from a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.dataSources(args[0], toDir) + }, + } + + flags := cmd.Flags() + flags.StringVar(&toDir, "to-dir", "", "write the configuration to a directory, in multiple files") + + return cmd +} + +func (cli *cliSetup) newValidateCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "validate [setup_file]", + Short: "validate a setup file", + Args: cobra.ExactArgs(1), + DisableAutoGenTag: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cli.validate(args[0]) + }, + } + + return cmd +} + +func (cli *cliSetup) detect(f detectFlags) error { + var ( + detectReader *os.File + err error + ) + + switch f.detectConfigFile { + case "-": + log.Tracef("Reading detection rules from stdin") + + detectReader = os.Stdin + default: + log.Tracef("Reading detection rules: %s", f.detectConfigFile) + + detectReader, err = os.Open(f.detectConfigFile) + if err != nil { + return err + } + } + + if !f.snubSystemd { + _, err = exec.LookPath("systemctl") + if err != nil { + log.Debug("systemctl not available: snubbing systemd") + + f.snubSystemd = true + } + } + + if f.forcedOSFamily == "" && f.forcedOSID != "" { + log.Debug("force-os-id is set: force-os-family defaults to 'linux'") + + f.forcedOSFamily = "linux" + } + + if f.listSupportedServices { + supported, err := setup.ListSupported(detectReader) + if err != nil { + return err + } + + for _, svc := range supported { + fmt.Println(svc) + } + + return nil + } + + opts := setup.DetectOptions{ + ForcedUnits: f.forcedUnits, + ForcedProcesses: f.forcedProcesses, + ForcedOS: setup.ExprOS{ + Family: f.forcedOSFamily, + ID: f.forcedOSID, + RawVersion: f.forcedOSVersion, + }, + SkipServices: f.skipServices, + SnubSystemd: f.snubSystemd, + } + + hubSetup, err := setup.Detect(detectReader, opts) + if err != nil { + return fmt.Errorf("detecting services: %w", err) + } + + setup, err := setupAsString(hubSetup, f.outYaml) + if err != nil { + return err + } + + fmt.Println(setup) + + return nil +} + +func setupAsString(cs setup.Setup, outYaml bool) (string, error) { + var ( + ret []byte + err error + ) + + wrap := func(err error) error { + return fmt.Errorf("while serializing setup: %w", err) + } + + indentLevel := 2 + buf := &bytes.Buffer{} + enc := yaml.NewEncoder(buf) + enc.SetIndent(indentLevel) + + if err = enc.Encode(cs); err != nil { + return "", wrap(err) + } + + if err = enc.Close(); err != nil { + return "", wrap(err) + } + + ret = buf.Bytes() + + if !outYaml { + // take a general approach to output json, so we avoid the + // double tags in the structures and can use go-yaml features + // missing from the json package + ret, err = goccyyaml.YAMLToJSON(ret) + if err != nil { + return "", wrap(err) + } + } + + return string(ret), nil +} + +func (cli *cliSetup) dataSources(fromFile string, toDir string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading setup file: %w", err) + } + + output, err := setup.DataSources(input, toDir) + if err != nil { + return err + } + + if toDir == "" { + fmt.Println(output) + } + + return nil +} + +func (cli *cliSetup) install(ctx context.Context, dryRun bool, fromFile string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading file %s: %w", fromFile, err) + } + + cfg := cli.cfg() + + hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), log.StandardLogger()) + if err != nil { + return err + } + + return setup.InstallHubItems(ctx, hub, input, dryRun) +} + +func (cli *cliSetup) validate(fromFile string) error { + input, err := os.ReadFile(fromFile) + if err != nil { + return fmt.Errorf("while reading stdin: %w", err) + } + + if err = setup.Validate(input); err != nil { + fmt.Printf("%v\n", err) + return errors.New("invalid setup file") + } + + return nil +} diff --git a/cmd/crowdsec-cli/simulation.go b/cmd/crowdsec-cli/clisimulation/simulation.go similarity index 91% rename from cmd/crowdsec-cli/simulation.go rename to cmd/crowdsec-cli/clisimulation/simulation.go index f8d8a660b8c..8136aa213c3 100644 --- a/cmd/crowdsec-cli/simulation.go +++ b/cmd/crowdsec-cli/clisimulation/simulation.go @@ -1,4 +1,4 @@ -package main +package clisimulation import ( "errors" @@ -10,15 +10,19 @@ import ( "github.com/spf13/cobra" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/reload" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) +type configGetter func() *csconfig.Config + type cliSimulation struct { cfg configGetter } -func NewCLISimulation(cfg configGetter) *cliSimulation { +func New(cfg configGetter) *cliSimulation { return &cliSimulation{ cfg: cfg, } @@ -44,21 +48,21 @@ cscli simulation disable crowdsecurity/ssh-bf`, }, PersistentPostRun: func(cmd *cobra.Command, _ []string) { if cmd.Name() != "status" { - log.Infof(ReloadMessage()) + log.Info(reload.Message) } }, } cmd.Flags().SortFlags = false cmd.PersistentFlags().SortFlags = false - cmd.AddCommand(cli.NewEnableCmd()) - cmd.AddCommand(cli.NewDisableCmd()) - cmd.AddCommand(cli.NewStatusCmd()) + cmd.AddCommand(cli.newEnableCmd()) + cmd.AddCommand(cli.newDisableCmd()) + cmd.AddCommand(cli.newStatusCmd()) return cmd } -func (cli *cliSimulation) NewEnableCmd() *cobra.Command { +func (cli *cliSimulation) newEnableCmd() *cobra.Command { var forceGlobalSimulation bool cmd := &cobra.Command{ @@ -118,7 +122,7 @@ func (cli *cliSimulation) NewEnableCmd() *cobra.Command { return cmd } -func (cli *cliSimulation) NewDisableCmd() *cobra.Command { +func (cli *cliSimulation) newDisableCmd() *cobra.Command { var forceGlobalSimulation bool cmd := &cobra.Command{ @@ -165,7 +169,7 @@ func (cli *cliSimulation) NewDisableCmd() *cobra.Command { return cmd } -func (cli *cliSimulation) NewStatusCmd() *cobra.Command { +func (cli *cliSimulation) newStatusCmd() *cobra.Command { cmd := &cobra.Command{ Use: "status", Short: "Show simulation mode status", @@ -216,7 +220,7 @@ func (cli *cliSimulation) dumpSimulationFile() error { newConfigSim, err := yaml.Marshal(cfg.Cscli.SimulationConfig) if err != nil { - return fmt.Errorf("unable to marshal simulation configuration: %w", err) + return fmt.Errorf("unable to serialize simulation configuration: %w", err) } err = os.WriteFile(cfg.ConfigPaths.SimulationFilePath, newConfigSim, 0o644) @@ -238,7 +242,7 @@ func (cli *cliSimulation) disableGlobalSimulation() error { newConfigSim, err := yaml.Marshal(cfg.Cscli.SimulationConfig) if err != nil { - return fmt.Errorf("unable to marshal new simulation configuration: %w", err) + return fmt.Errorf("unable to serialize new simulation configuration: %w", err) } err = os.WriteFile(cfg.ConfigPaths.SimulationFilePath, newConfigSim, 0o644) diff --git a/cmd/crowdsec-cli/support.go b/cmd/crowdsec-cli/clisupport/support.go similarity index 82% rename from cmd/crowdsec-cli/support.go rename to cmd/crowdsec-cli/clisupport/support.go index ef14f90df17..4474f5c8f11 100644 --- a/cmd/crowdsec-cli/support.go +++ b/cmd/crowdsec-cli/clisupport/support.go @@ -1,4 +1,4 @@ -package main +package clisupport import ( "archive/zip" @@ -22,7 +22,13 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clibouncer" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clicapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clilapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climachine" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climetrics" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clipapi" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" @@ -36,12 +42,13 @@ const ( SUPPORT_VERSION_PATH = "version.txt" SUPPORT_FEATURES_PATH = "features.txt" SUPPORT_OS_INFO_PATH = "osinfo.txt" - SUPPORT_HUB_DIR = "hub/" + SUPPORT_HUB = "hub.txt" SUPPORT_BOUNCERS_PATH = "lapi/bouncers.txt" SUPPORT_AGENTS_PATH = "lapi/agents.txt" SUPPORT_CROWDSEC_CONFIG_PATH = "config/crowdsec.yaml" SUPPORT_LAPI_STATUS_PATH = "lapi_status.txt" SUPPORT_CAPI_STATUS_PATH = "capi_status.txt" + SUPPORT_PAPI_STATUS_PATH = "papi_status.txt" SUPPORT_ACQUISITION_DIR = "config/acquis/" SUPPORT_CROWDSEC_PROFILE_PATH = "config/profiles.yaml" SUPPORT_CRASH_DIR = "crash/" @@ -161,31 +168,28 @@ func (cli *cliSupport) dumpOSInfo(zw *zip.Writer) error { return nil } -func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub, itemType string) error { - var err error +func (cli *cliSupport) dumpHubItems(zw *zip.Writer, hub *cwhub.Hub) error { + log.Infof("Collecting hub") - out := new(bytes.Buffer) - - log.Infof("Collecting hub: %s", itemType) - - items := make(map[string][]*cwhub.Item) - - if items[itemType], err = selectItems(hub, itemType, nil, true); err != nil { - return fmt.Errorf("could not collect %s list: %w", itemType, err) + if hub == nil { + return errors.New("no hub connection") } - if err := listItems(out, cli.cfg().Cscli.Color, []string{itemType}, items, false, "human"); err != nil { - return fmt.Errorf("could not list %s: %w", itemType, err) + out := new(bytes.Buffer) + ch := clihub.New(cli.cfg) + + if err := ch.List(out, hub, false); err != nil { + return err } stripped := stripAnsiString(out.String()) - cli.writeToZip(zw, SUPPORT_HUB_DIR+itemType+".txt", time.Now(), strings.NewReader(stripped)) + cli.writeToZip(zw, SUPPORT_HUB, time.Now(), strings.NewReader(stripped)) return nil } -func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpBouncers(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting bouncers") if db == nil { @@ -193,10 +197,11 @@ func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { } out := new(bytes.Buffer) + cb := clibouncer.New(cli.cfg) - // call the "cscli bouncers list" command directly, skip any preRun - cm := cliBouncers{db: db, cfg: cli.cfg} - cm.list(out) + if err := cb.List(ctx, out, db); err != nil { + return err + } stripped := stripAnsiString(out.String()) @@ -205,7 +210,7 @@ func (cli *cliSupport) dumpBouncers(zw *zip.Writer, db *database.Client) error { return nil } -func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { +func (cli *cliSupport) dumpAgents(ctx context.Context, zw *zip.Writer, db *database.Client) error { log.Info("Collecting agents") if db == nil { @@ -213,10 +218,11 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { } out := new(bytes.Buffer) + cm := climachine.New(cli.cfg) - // call the "cscli machines list" command directly, skip any preRun - cm := cliMachines{db: db, cfg: cli.cfg} - cm.list(out) + if err := cm.List(ctx, out, db); err != nil { + return err + } stripped := stripAnsiString(out.String()) @@ -225,54 +231,56 @@ func (cli *cliSupport) dumpAgents(zw *zip.Writer, db *database.Client) error { return nil } -func (cli *cliSupport) dumpLAPIStatus(zw *zip.Writer, hub *cwhub.Hub) error { +func (cli *cliSupport) dumpLAPIStatus(ctx context.Context, zw *zip.Writer, hub *cwhub.Hub) error { log.Info("Collecting LAPI status") - cfg := cli.cfg() - cred := cfg.API.Client.Credentials - out := new(bytes.Buffer) + cl := clilapi.New(cli.cfg) - fmt.Fprintf(out, "LAPI credentials file: %s\n", cfg.API.Client.CredentialsFilePath) - fmt.Fprintf(out, "LAPI URL: %s\n", cred.URL) - fmt.Fprintf(out, "LAPI username: %s\n", cred.Login) - - if err := QueryLAPIStatus(hub, cred.URL, cred.Login, cred.Password); err != nil { - return fmt.Errorf("could not authenticate to Local API (LAPI): %w", err) + err := cl.Status(ctx, out, hub) + if err != nil { + fmt.Fprintf(out, "%s\n", err) } - fmt.Fprintln(out, "You can successfully interact with Local API (LAPI)") + stripped := stripAnsiString(out.String()) - cli.writeToZip(zw, SUPPORT_LAPI_STATUS_PATH, time.Now(), out) + cli.writeToZip(zw, SUPPORT_LAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) return nil } -func (cli *cliSupport) dumpCAPIStatus(zw *zip.Writer, hub *cwhub.Hub) error { +func (cli *cliSupport) dumpCAPIStatus(ctx context.Context, zw *zip.Writer, hub *cwhub.Hub) error { log.Info("Collecting CAPI status") - cfg := cli.cfg() - cred := cfg.API.Server.OnlineClient.Credentials - out := new(bytes.Buffer) + cc := clicapi.New(cli.cfg) - fmt.Fprintf(out, "CAPI credentials file: %s\n", cfg.API.Server.OnlineClient.CredentialsFilePath) - fmt.Fprintf(out, "CAPI URL: %s\n", cred.URL) - fmt.Fprintf(out, "CAPI username: %s\n", cred.Login) - - auth, enrolled, err := QueryCAPIStatus(hub, cred.URL, cred.Login, cred.Password) + err := cc.Status(ctx, out, hub) if err != nil { - return fmt.Errorf("could not authenticate to Central API (CAPI): %w", err) - } - if auth { - fmt.Fprintln(out, "You can successfully interact with Central API (CAPI)") + fmt.Fprintf(out, "%s\n", err) } - if enrolled { - fmt.Fprintln(out, "Your instance is enrolled in the console") + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_CAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) + + return nil +} + +func (cli *cliSupport) dumpPAPIStatus(ctx context.Context, zw *zip.Writer, db *database.Client) error { + log.Info("Collecting PAPI status") + + out := new(bytes.Buffer) + cp := clipapi.New(cli.cfg) + + err := cp.Status(ctx, out, db) + if err != nil { + fmt.Fprintf(out, "%s\n", err) } - cli.writeToZip(zw, SUPPORT_CAPI_STATUS_PATH, time.Now(), out) + stripped := stripAnsiString(out.String()) + + cli.writeToZip(zw, SUPPORT_PAPI_STATUS_PATH, time.Now(), strings.NewReader(stripped)) return nil } @@ -385,11 +393,13 @@ func (cli *cliSupport) dumpCrash(zw *zip.Writer) error { return nil } +type configGetter func() *csconfig.Config + type cliSupport struct { cfg configGetter } -func NewCLISupport(cfg configGetter) *cliSupport { +func New(cfg configGetter) *cliSupport { return &cliSupport{ cfg: cfg, } @@ -511,30 +521,30 @@ func (cli *cliSupport) dump(ctx context.Context, outFile string) error { log.Warnf("could not collect main config file: %s", err) } - if hub != nil { - for _, itemType := range cwhub.ItemTypes { - if err = cli.dumpHubItems(zipWriter, hub, itemType); err != nil { - log.Warnf("could not collect %s information: %s", itemType, err) - } - } + if err = cli.dumpHubItems(zipWriter, hub); err != nil { + log.Warnf("could not collect hub information: %s", err) } - if err = cli.dumpBouncers(zipWriter, db); err != nil { + if err = cli.dumpBouncers(ctx, zipWriter, db); err != nil { log.Warnf("could not collect bouncers information: %s", err) } - if err = cli.dumpAgents(zipWriter, db); err != nil { + if err = cli.dumpAgents(ctx, zipWriter, db); err != nil { log.Warnf("could not collect agents information: %s", err) } if !skipCAPI { - if err = cli.dumpCAPIStatus(zipWriter, hub); err != nil { + if err = cli.dumpCAPIStatus(ctx, zipWriter, hub); err != nil { log.Warnf("could not collect CAPI status: %s", err) } + + if err = cli.dumpPAPIStatus(ctx, zipWriter, db); err != nil { + log.Warnf("could not collect PAPI status: %s", err) + } } if !skipLAPI { - if err = cli.dumpLAPIStatus(zipWriter, hub); err != nil { + if err = cli.dumpLAPIStatus(ctx, zipWriter, hub); err != nil { log.Warnf("could not collect LAPI status: %s", err) } diff --git a/cmd/crowdsec-cli/config.go b/cmd/crowdsec-cli/config.go index e88845798e2..4cf8916ad4b 100644 --- a/cmd/crowdsec-cli/config.go +++ b/cmd/crowdsec-cli/config.go @@ -18,7 +18,7 @@ func (cli *cliConfig) NewCommand() *cobra.Command { cmd := &cobra.Command{ Use: "config [command]", Short: "Allows to view current config", - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, } diff --git a/cmd/crowdsec-cli/config_backup.go b/cmd/crowdsec-cli/config_backup.go index e8ac6213530..d23aff80a78 100644 --- a/cmd/crowdsec-cli/config_backup.go +++ b/cmd/crowdsec-cli/config_backup.go @@ -74,7 +74,7 @@ func (cli *cliConfig) backupHub(dirPath string) error { upstreamParsersContent, err := json.MarshalIndent(upstreamParsers, "", " ") if err != nil { - return fmt.Errorf("failed marshaling upstream parsers: %w", err) + return fmt.Errorf("failed to serialize upstream parsers: %w", err) } err = os.WriteFile(upstreamParsersFname, upstreamParsersContent, 0o644) diff --git a/cmd/crowdsec-cli/config_feature_flags.go b/cmd/crowdsec-cli/config_feature_flags.go index d1dbe2b93b7..760e2194bb3 100644 --- a/cmd/crowdsec-cli/config_feature_flags.go +++ b/cmd/crowdsec-cli/config_feature_flags.go @@ -121,7 +121,7 @@ func (cli *cliConfig) newFeatureFlagsCmd() *cobra.Command { Use: "feature-flags", Short: "Displays feature flag status", Long: `Displays the supported feature flags and their current status.`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { return cli.featureFlags(showRetired) diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index fc3670165f8..c32328485ec 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -40,7 +40,7 @@ func (cli *cliConfig) restoreHub(ctx context.Context, dirPath string) error { err = json.Unmarshal(file, &upstreamList) if err != nil { - return fmt.Errorf("error unmarshaling %s: %w", upstreamListFN, err) + return fmt.Errorf("error parsing %s: %w", upstreamListFN, err) } for _, toinstall := range upstreamList { diff --git a/cmd/crowdsec-cli/config_show.go b/cmd/crowdsec-cli/config_show.go index e411f5a322b..3d17d264574 100644 --- a/cmd/crowdsec-cli/config_show.go +++ b/cmd/crowdsec-cli/config_show.go @@ -50,7 +50,7 @@ func (cli *cliConfig) showKey(key string) error { case "json": data, err := json.MarshalIndent(output, "", " ") if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } fmt.Println(string(data)) @@ -212,14 +212,14 @@ func (cli *cliConfig) show() error { case "json": data, err := json.MarshalIndent(cfg, "", " ") if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } fmt.Println(string(data)) case "raw": data, err := yaml.Marshal(cfg) if err != nil { - return fmt.Errorf("failed to marshal configuration: %w", err) + return fmt.Errorf("failed to serialize configuration: %w", err) } fmt.Println(string(data)) @@ -235,7 +235,7 @@ func (cli *cliConfig) newShowCmd() *cobra.Command { Use: "show", Short: "Displays current config", Long: `Displays the current cli configuration.`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { if err := cli.cfg().LoadAPIClient(); err != nil { diff --git a/cmd/crowdsec-cli/config_showyaml.go b/cmd/crowdsec-cli/config_showyaml.go index 52daee6a65e..10549648d09 100644 --- a/cmd/crowdsec-cli/config_showyaml.go +++ b/cmd/crowdsec-cli/config_showyaml.go @@ -15,7 +15,7 @@ func (cli *cliConfig) newShowYAMLCmd() *cobra.Command { cmd := &cobra.Command{ Use: "show-yaml", Short: "Displays merged config.yaml + config.yaml.local", - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { return cli.showYAML() diff --git a/cmd/crowdsec-cli/dashboard.go b/cmd/crowdsec-cli/dashboard.go index c61fc8eeded..53a7dff85a0 100644 --- a/cmd/crowdsec-cli/dashboard.go +++ b/cmd/crowdsec-cli/dashboard.go @@ -20,9 +20,11 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/idgen" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/metabase" - "github.com/crowdsecurity/go-cs-lib/version" ) var ( @@ -127,7 +129,7 @@ func (cli *cliDashboard) newSetupCmd() *cobra.Command { Use: "setup", Short: "Setup a metabase container.", Long: `Perform a metabase docker setup, download standard dashboards, create a fresh user and start the container`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, Example: ` cscli dashboard setup @@ -142,7 +144,7 @@ cscli dashboard setup -l 0.0.0.0 -p 443 --password if metabasePassword == "" { isValid := passwordIsValid(metabasePassword) for !isValid { - metabasePassword = generatePassword(16) + metabasePassword = idgen.GeneratePassword(16) isValid = passwordIsValid(metabasePassword) } } @@ -196,7 +198,7 @@ func (cli *cliDashboard) newStartCmd() *cobra.Command { Use: "start", Short: "Start the metabase container.", Long: `Stats the metabase container using docker.`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { mb, err := metabase.NewMetabase(metabaseConfigPath, metabaseContainerID) @@ -227,7 +229,7 @@ func (cli *cliDashboard) newStopCmd() *cobra.Command { Use: "stop", Short: "Stops the metabase container.", Long: `Stops the metabase container using docker.`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { if err := metabase.StopContainer(metabaseContainerID); err != nil { @@ -243,7 +245,7 @@ func (cli *cliDashboard) newStopCmd() *cobra.Command { func (cli *cliDashboard) newShowPasswordCmd() *cobra.Command { cmd := &cobra.Command{Use: "show-password", Short: "displays password of metabase.", - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, RunE: func(_ *cobra.Command, _ []string) error { m := metabase.Metabase{} @@ -266,7 +268,7 @@ func (cli *cliDashboard) newRemoveCmd() *cobra.Command { Use: "remove", Short: "removes the metabase container.", Long: `removes the metabase container using docker.`, - Args: cobra.ExactArgs(0), + Args: cobra.NoArgs, DisableAutoGenTag: true, Example: ` cscli dashboard remove diff --git a/cmd/crowdsec-cli/hubtest.go b/cmd/crowdsec-cli/hubtest.go deleted file mode 100644 index 2a4635d39f1..00000000000 --- a/cmd/crowdsec-cli/hubtest.go +++ /dev/null @@ -1,746 +0,0 @@ -package main - -import ( - "encoding/json" - "errors" - "fmt" - "math" - "os" - "path/filepath" - "strings" - "text/template" - - "github.com/AlecAivazis/survey/v2" - "github.com/fatih/color" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - - "github.com/crowdsecurity/crowdsec/pkg/dumps" - "github.com/crowdsecurity/crowdsec/pkg/emoji" - "github.com/crowdsecurity/crowdsec/pkg/hubtest" -) - -var ( - HubTest hubtest.HubTest - HubAppsecTests hubtest.HubTest - hubPtr *hubtest.HubTest - isAppsecTest bool -) - -type cliHubTest struct { - cfg configGetter -} - -func NewCLIHubTest(cfg configGetter) *cliHubTest { - return &cliHubTest{ - cfg: cfg, - } -} - -func (cli *cliHubTest) NewCommand() *cobra.Command { - var ( - hubPath string - crowdsecPath string - cscliPath string - ) - - cmd := &cobra.Command{ - Use: "hubtest", - Short: "Run functional tests on hub configurations", - Long: "Run functional tests on hub configurations (parsers, scenarios, collections...)", - Args: cobra.ExactArgs(0), - DisableAutoGenTag: true, - PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - var err error - HubTest, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, false) - if err != nil { - return fmt.Errorf("unable to load hubtest: %+v", err) - } - - HubAppsecTests, err = hubtest.NewHubTest(hubPath, crowdsecPath, cscliPath, true) - if err != nil { - return fmt.Errorf("unable to load appsec specific hubtest: %+v", err) - } - - // commands will use the hubPtr, will point to the default hubTest object, or the one dedicated to appsec tests - hubPtr = &HubTest - if isAppsecTest { - hubPtr = &HubAppsecTests - } - - return nil - }, - } - - cmd.PersistentFlags().StringVar(&hubPath, "hub", ".", "Path to hub folder") - cmd.PersistentFlags().StringVar(&crowdsecPath, "crowdsec", "crowdsec", "Path to crowdsec") - cmd.PersistentFlags().StringVar(&cscliPath, "cscli", "cscli", "Path to cscli") - cmd.PersistentFlags().BoolVar(&isAppsecTest, "appsec", false, "Command relates to appsec tests") - - cmd.AddCommand(cli.NewCreateCmd()) - cmd.AddCommand(cli.NewRunCmd()) - cmd.AddCommand(cli.NewCleanCmd()) - cmd.AddCommand(cli.NewInfoCmd()) - cmd.AddCommand(cli.NewListCmd()) - cmd.AddCommand(cli.NewCoverageCmd()) - cmd.AddCommand(cli.NewEvalCmd()) - cmd.AddCommand(cli.NewExplainCmd()) - - return cmd -} - -func (cli *cliHubTest) NewCreateCmd() *cobra.Command { - var ( - ignoreParsers bool - labels map[string]string - logType string - ) - - parsers := []string{} - postoverflows := []string{} - scenarios := []string{} - - cmd := &cobra.Command{ - Use: "create", - Short: "create [test_name]", - Example: `cscli hubtest create my-awesome-test --type syslog -cscli hubtest create my-nginx-custom-test --type nginx -cscli hubtest create my-scenario-test --parsers crowdsecurity/nginx --scenarios crowdsecurity/http-probing`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - testName := args[0] - testPath := filepath.Join(hubPtr.HubTestPath, testName) - if _, err := os.Stat(testPath); os.IsExist(err) { - return fmt.Errorf("test '%s' already exists in '%s', exiting", testName, testPath) - } - - if isAppsecTest { - logType = "appsec" - } - - if logType == "" { - return errors.New("please provide a type (--type) for the test") - } - - if err := os.MkdirAll(testPath, os.ModePerm); err != nil { - return fmt.Errorf("unable to create folder '%s': %+v", testPath, err) - } - - configFilePath := filepath.Join(testPath, "config.yaml") - - configFileData := &hubtest.HubTestItemConfig{} - if logType == "appsec" { - // create empty nuclei template file - nucleiFileName := fmt.Sprintf("%s.yaml", testName) - nucleiFilePath := filepath.Join(testPath, nucleiFileName) - - nucleiFile, err := os.OpenFile(nucleiFilePath, os.O_RDWR|os.O_CREATE, 0o755) - if err != nil { - return err - } - - ntpl := template.Must(template.New("nuclei").Parse(hubtest.TemplateNucleiFile)) - if ntpl == nil { - return errors.New("unable to parse nuclei template") - } - ntpl.ExecuteTemplate(nucleiFile, "nuclei", struct{ TestName string }{TestName: testName}) - nucleiFile.Close() - configFileData.AppsecRules = []string{"./appsec-rules//your_rule_here.yaml"} - configFileData.NucleiTemplate = nucleiFileName - fmt.Println() - fmt.Printf(" Test name : %s\n", testName) - fmt.Printf(" Test path : %s\n", testPath) - fmt.Printf(" Config File : %s\n", configFilePath) - fmt.Printf(" Nuclei Template : %s\n", nucleiFilePath) - } else { - // create empty log file - logFileName := fmt.Sprintf("%s.log", testName) - logFilePath := filepath.Join(testPath, logFileName) - logFile, err := os.Create(logFilePath) - if err != nil { - return err - } - logFile.Close() - - // create empty parser assertion file - parserAssertFilePath := filepath.Join(testPath, hubtest.ParserAssertFileName) - parserAssertFile, err := os.Create(parserAssertFilePath) - if err != nil { - return err - } - parserAssertFile.Close() - // create empty scenario assertion file - scenarioAssertFilePath := filepath.Join(testPath, hubtest.ScenarioAssertFileName) - scenarioAssertFile, err := os.Create(scenarioAssertFilePath) - if err != nil { - return err - } - scenarioAssertFile.Close() - - parsers = append(parsers, "crowdsecurity/syslog-logs") - parsers = append(parsers, "crowdsecurity/dateparse-enrich") - - if len(scenarios) == 0 { - scenarios = append(scenarios, "") - } - - if len(postoverflows) == 0 { - postoverflows = append(postoverflows, "") - } - configFileData.Parsers = parsers - configFileData.Scenarios = scenarios - configFileData.PostOverflows = postoverflows - configFileData.LogFile = logFileName - configFileData.LogType = logType - configFileData.IgnoreParsers = ignoreParsers - configFileData.Labels = labels - fmt.Println() - fmt.Printf(" Test name : %s\n", testName) - fmt.Printf(" Test path : %s\n", testPath) - fmt.Printf(" Log file : %s (please fill it with logs)\n", logFilePath) - fmt.Printf(" Parser assertion file : %s (please fill it with assertion)\n", parserAssertFilePath) - fmt.Printf(" Scenario assertion file : %s (please fill it with assertion)\n", scenarioAssertFilePath) - fmt.Printf(" Configuration File : %s (please fill it with parsers, scenarios...)\n", configFilePath) - } - - fd, err := os.Create(configFilePath) - if err != nil { - return fmt.Errorf("open: %w", err) - } - data, err := yaml.Marshal(configFileData) - if err != nil { - return fmt.Errorf("marshal: %w", err) - } - _, err = fd.Write(data) - if err != nil { - return fmt.Errorf("write: %w", err) - } - if err := fd.Close(); err != nil { - return fmt.Errorf("close: %w", err) - } - - return nil - }, - } - - cmd.PersistentFlags().StringVarP(&logType, "type", "t", "", "Log type of the test") - cmd.Flags().StringSliceVarP(&parsers, "parsers", "p", parsers, "Parsers to add to test") - cmd.Flags().StringSliceVar(&postoverflows, "postoverflows", postoverflows, "Postoverflows to add to test") - cmd.Flags().StringSliceVarP(&scenarios, "scenarios", "s", scenarios, "Scenarios to add to test") - cmd.PersistentFlags().BoolVar(&ignoreParsers, "ignore-parsers", false, "Don't run test on parsers") - - return cmd -} - - -func (cli *cliHubTest) run(runAll bool, NucleiTargetHost string, AppSecHost string, args []string) error { - cfg := cli.cfg() - - if !runAll && len(args) == 0 { - return errors.New("please provide test to run or --all flag") - } - hubPtr.NucleiTargetHost = NucleiTargetHost - hubPtr.AppSecHost = AppSecHost - if runAll { - if err := hubPtr.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %+v", err) - } - } else { - for _, testName := range args { - _, err := hubPtr.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %w", testName, err) - } - } - } - - // set timezone to avoid DST issues - os.Setenv("TZ", "UTC") - for _, test := range hubPtr.Tests { - if cfg.Cscli.Output == "human" { - log.Infof("Running test '%s'", test.Name) - } - err := test.Run() - if err != nil { - log.Errorf("running test '%s' failed: %+v", test.Name, err) - } - } - - return nil -} - - -func (cli *cliHubTest) NewRunCmd() *cobra.Command { - var ( - noClean bool - runAll bool - forceClean bool - NucleiTargetHost string - AppSecHost string - ) - - cmd := &cobra.Command{ - Use: "run", - Short: "run [test_name]", - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.run(runAll, NucleiTargetHost, AppSecHost, args) - }, - PersistentPostRunE: func(_ *cobra.Command, _ []string) error { - cfg := cli.cfg() - - success := true - testResult := make(map[string]bool) - for _, test := range hubPtr.Tests { - if test.AutoGen && !isAppsecTest { - if test.ParserAssert.AutoGenAssert { - log.Warningf("Assert file '%s' is empty, generating assertion:", test.ParserAssert.File) - fmt.Println() - fmt.Println(test.ParserAssert.AutoGenAssertData) - } - if test.ScenarioAssert.AutoGenAssert { - log.Warningf("Assert file '%s' is empty, generating assertion:", test.ScenarioAssert.File) - fmt.Println() - fmt.Println(test.ScenarioAssert.AutoGenAssertData) - } - if !noClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) - } - } - return fmt.Errorf("please fill your assert file(s) for test '%s', exiting", test.Name) - } - testResult[test.Name] = test.Success - if test.Success { - if cfg.Cscli.Output == "human" { - log.Infof("Test '%s' passed successfully (%d assertions)\n", test.Name, test.ParserAssert.NbAssert+test.ScenarioAssert.NbAssert) - } - if !noClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) - } - } - } else { - success = false - cleanTestEnv := false - if cfg.Cscli.Output == "human" { - if len(test.ParserAssert.Fails) > 0 { - fmt.Println() - log.Errorf("Parser test '%s' failed (%d errors)\n", test.Name, len(test.ParserAssert.Fails)) - for _, fail := range test.ParserAssert.Fails { - fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) - fmt.Printf(" Actual expression values:\n") - for key, value := range fail.Debug { - fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) - } - fmt.Println() - } - } - if len(test.ScenarioAssert.Fails) > 0 { - fmt.Println() - log.Errorf("Scenario test '%s' failed (%d errors)\n", test.Name, len(test.ScenarioAssert.Fails)) - for _, fail := range test.ScenarioAssert.Fails { - fmt.Printf("(L.%d) %s => %s\n", fail.Line, emoji.RedCircle, fail.Expression) - fmt.Printf(" Actual expression values:\n") - for key, value := range fail.Debug { - fmt.Printf(" %s = '%s'\n", key, strings.TrimSuffix(value, "\n")) - } - fmt.Println() - } - } - if !forceClean && !noClean { - prompt := &survey.Confirm{ - Message: fmt.Sprintf("\nDo you want to remove runtime folder for test '%s'? (default: Yes)", test.Name), - Default: true, - } - if err := survey.AskOne(prompt, &cleanTestEnv); err != nil { - return fmt.Errorf("unable to ask to remove runtime folder: %w", err) - } - } - } - - if cleanTestEnv || forceClean { - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) - } - } - } - } - - switch cfg.Cscli.Output { - case "human": - hubTestResultTable(color.Output, cfg.Cscli.Color, testResult) - case "json": - jsonResult := make(map[string][]string, 0) - jsonResult["success"] = make([]string, 0) - jsonResult["fail"] = make([]string, 0) - for testName, success := range testResult { - if success { - jsonResult["success"] = append(jsonResult["success"], testName) - } else { - jsonResult["fail"] = append(jsonResult["fail"], testName) - } - } - jsonStr, err := json.Marshal(jsonResult) - if err != nil { - return fmt.Errorf("unable to json test result: %w", err) - } - fmt.Println(string(jsonStr)) - default: - return errors.New("only human/json output modes are supported") - } - - if !success { - return errors.New("some tests failed") - } - - return nil - }, - } - - cmd.Flags().BoolVar(&noClean, "no-clean", false, "Don't clean runtime environment if test succeed") - cmd.Flags().BoolVar(&forceClean, "clean", false, "Clean runtime environment if test fail") - cmd.Flags().StringVar(&NucleiTargetHost, "target", hubtest.DefaultNucleiTarget, "Target for AppSec Test") - cmd.Flags().StringVar(&AppSecHost, "host", hubtest.DefaultAppsecHost, "Address to expose AppSec for hubtest") - cmd.Flags().BoolVar(&runAll, "all", false, "Run all tests") - - return cmd -} - -func (cli *cliHubTest) NewCleanCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "clean", - Short: "clean [test_name]", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - for _, testName := range args { - test, err := hubPtr.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %w", testName, err) - } - if err := test.Clean(); err != nil { - return fmt.Errorf("unable to clean test '%s' env: %w", test.Name, err) - } - } - - return nil - }, - } - - return cmd -} - -func (cli *cliHubTest) NewInfoCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "info", - Short: "info [test_name]", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - for _, testName := range args { - test, err := hubPtr.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("unable to load test '%s': %w", testName, err) - } - fmt.Println() - fmt.Printf(" Test name : %s\n", test.Name) - fmt.Printf(" Test path : %s\n", test.Path) - if isAppsecTest { - fmt.Printf(" Nuclei Template : %s\n", test.Config.NucleiTemplate) - fmt.Printf(" Appsec Rules : %s\n", strings.Join(test.Config.AppsecRules, ", ")) - } else { - fmt.Printf(" Log file : %s\n", filepath.Join(test.Path, test.Config.LogFile)) - fmt.Printf(" Parser assertion file : %s\n", filepath.Join(test.Path, hubtest.ParserAssertFileName)) - fmt.Printf(" Scenario assertion file : %s\n", filepath.Join(test.Path, hubtest.ScenarioAssertFileName)) - } - fmt.Printf(" Configuration File : %s\n", filepath.Join(test.Path, "config.yaml")) - } - - return nil - }, - } - - return cmd -} - -func (cli *cliHubTest) NewListCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "list", - Short: "list", - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - cfg := cli.cfg() - - if err := hubPtr.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %w", err) - } - - switch cfg.Cscli.Output { - case "human": - hubTestListTable(color.Output, cfg.Cscli.Color, hubPtr.Tests) - case "json": - j, err := json.MarshalIndent(hubPtr.Tests, " ", " ") - if err != nil { - return err - } - fmt.Println(string(j)) - default: - return errors.New("only human/json output modes are supported") - } - - return nil - }, - } - - return cmd -} - -func (cli *cliHubTest) coverage(showScenarioCov bool, showParserCov bool, showAppsecCov bool, showOnlyPercent bool) error { - cfg := cli.cfg() - - // for this one we explicitly don't do for appsec - if err := HubTest.LoadAllTests(); err != nil { - return fmt.Errorf("unable to load all tests: %+v", err) - } - - var err error - - scenarioCoverage := []hubtest.Coverage{} - parserCoverage := []hubtest.Coverage{} - appsecRuleCoverage := []hubtest.Coverage{} - scenarioCoveragePercent := 0 - parserCoveragePercent := 0 - appsecRuleCoveragePercent := 0 - - // if both are false (flag by default), show both - showAll := !showScenarioCov && !showParserCov && !showAppsecCov - - if showParserCov || showAll { - parserCoverage, err = HubTest.GetParsersCoverage() - if err != nil { - return fmt.Errorf("while getting parser coverage: %w", err) - } - - parserTested := 0 - - for _, test := range parserCoverage { - if test.TestsCount > 0 { - parserTested++ - } - } - - parserCoveragePercent = int(math.Round((float64(parserTested) / float64(len(parserCoverage)) * 100))) - } - - if showScenarioCov || showAll { - scenarioCoverage, err = HubTest.GetScenariosCoverage() - if err != nil { - return fmt.Errorf("while getting scenario coverage: %w", err) - } - - scenarioTested := 0 - - for _, test := range scenarioCoverage { - if test.TestsCount > 0 { - scenarioTested++ - } - } - - scenarioCoveragePercent = int(math.Round((float64(scenarioTested) / float64(len(scenarioCoverage)) * 100))) - } - - if showAppsecCov || showAll { - appsecRuleCoverage, err = HubTest.GetAppsecCoverage() - if err != nil { - return fmt.Errorf("while getting scenario coverage: %w", err) - } - - appsecRuleTested := 0 - - for _, test := range appsecRuleCoverage { - if test.TestsCount > 0 { - appsecRuleTested++ - } - } - - appsecRuleCoveragePercent = int(math.Round((float64(appsecRuleTested) / float64(len(appsecRuleCoverage)) * 100))) - } - - if showOnlyPercent { - switch { - case showAll: - fmt.Printf("parsers=%d%%\nscenarios=%d%%\nappsec_rules=%d%%", parserCoveragePercent, scenarioCoveragePercent, appsecRuleCoveragePercent) - case showParserCov: - fmt.Printf("parsers=%d%%", parserCoveragePercent) - case showScenarioCov: - fmt.Printf("scenarios=%d%%", scenarioCoveragePercent) - case showAppsecCov: - fmt.Printf("appsec_rules=%d%%", appsecRuleCoveragePercent) - } - - return nil - } - - switch cfg.Cscli.Output { - case "human": - if showParserCov || showAll { - hubTestParserCoverageTable(color.Output, cfg.Cscli.Color, parserCoverage) - } - - if showScenarioCov || showAll { - hubTestScenarioCoverageTable(color.Output, cfg.Cscli.Color, scenarioCoverage) - } - - if showAppsecCov || showAll { - hubTestAppsecRuleCoverageTable(color.Output, cfg.Cscli.Color, appsecRuleCoverage) - } - - fmt.Println() - - if showParserCov || showAll { - fmt.Printf("PARSERS : %d%% of coverage\n", parserCoveragePercent) - } - - if showScenarioCov || showAll { - fmt.Printf("SCENARIOS : %d%% of coverage\n", scenarioCoveragePercent) - } - - if showAppsecCov || showAll { - fmt.Printf("APPSEC RULES : %d%% of coverage\n", appsecRuleCoveragePercent) - } - case "json": - dump, err := json.MarshalIndent(parserCoverage, "", " ") - if err != nil { - return err - } - - fmt.Printf("%s", dump) - - dump, err = json.MarshalIndent(scenarioCoverage, "", " ") - if err != nil { - return err - } - - fmt.Printf("%s", dump) - - dump, err = json.MarshalIndent(appsecRuleCoverage, "", " ") - if err != nil { - return err - } - - fmt.Printf("%s", dump) - default: - return errors.New("only human/json output modes are supported") - } - - return nil -} - -func (cli *cliHubTest) NewCoverageCmd() *cobra.Command { - var ( - showParserCov bool - showScenarioCov bool - showOnlyPercent bool - showAppsecCov bool - ) - - cmd := &cobra.Command{ - Use: "coverage", - Short: "coverage", - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.coverage(showScenarioCov, showParserCov, showAppsecCov, showOnlyPercent) - }, - } - - cmd.PersistentFlags().BoolVar(&showOnlyPercent, "percent", false, "Show only percentages of coverage") - cmd.PersistentFlags().BoolVar(&showParserCov, "parsers", false, "Show only parsers coverage") - cmd.PersistentFlags().BoolVar(&showScenarioCov, "scenarios", false, "Show only scenarios coverage") - cmd.PersistentFlags().BoolVar(&showAppsecCov, "appsec", false, "Show only appsec coverage") - - return cmd -} - -func (cli *cliHubTest) NewEvalCmd() *cobra.Command { - var evalExpression string - - cmd := &cobra.Command{ - Use: "eval", - Short: "eval [test_name]", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - for _, testName := range args { - test, err := hubPtr.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("can't load test: %+v", err) - } - - err = test.ParserAssert.LoadTest(test.ParserResultFile) - if err != nil { - return fmt.Errorf("can't load test results from '%s': %+v", test.ParserResultFile, err) - } - - output, err := test.ParserAssert.EvalExpression(evalExpression) - if err != nil { - return err - } - - fmt.Print(output) - } - - return nil - }, - } - - cmd.PersistentFlags().StringVarP(&evalExpression, "expr", "e", "", "Expression to eval") - - return cmd -} - -func (cli *cliHubTest) NewExplainCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "explain", - Short: "explain [test_name]", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - for _, testName := range args { - test, err := HubTest.LoadTestItem(testName) - if err != nil { - return fmt.Errorf("can't load test: %+v", err) - } - err = test.ParserAssert.LoadTest(test.ParserResultFile) - if err != nil { - if err = test.Run(); err != nil { - return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) - } - - if err = test.ParserAssert.LoadTest(test.ParserResultFile); err != nil { - return fmt.Errorf("unable to load parser result after run: %w", err) - } - } - - err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile) - if err != nil { - if err = test.Run(); err != nil { - return fmt.Errorf("running test '%s' failed: %+v", test.Name, err) - } - - if err = test.ScenarioAssert.LoadTest(test.ScenarioResultFile, test.BucketPourResultFile); err != nil { - return fmt.Errorf("unable to load scenario result after run: %w", err) - } - } - opts := dumps.DumpOpts{} - dumps.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts) - } - - return nil - }, - } - - return cmd -} diff --git a/cmd/crowdsec-cli/idgen/machineid.go b/cmd/crowdsec-cli/idgen/machineid.go new file mode 100644 index 00000000000..4bd356b3abc --- /dev/null +++ b/cmd/crowdsec-cli/idgen/machineid.go @@ -0,0 +1,48 @@ +package idgen + +import ( + "fmt" + "strings" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/machineid" +) + +// Returns a unique identifier for each crowdsec installation, using an +// identifier of the OS installation where available, otherwise a random +// string. +func generateMachineIDPrefix() (string, error) { + prefix, err := machineid.ID() + if err == nil { + return prefix, nil + } + + log.Debugf("failed to get machine-id with usual files: %s", err) + + bID, err := uuid.NewRandom() + if err == nil { + return bID.String(), nil + } + + return "", fmt.Errorf("generating machine id: %w", err) +} + +// Generate a unique identifier, composed by a prefix and a random suffix. +// The prefix can be provided by a parameter to use in test environments. +func GenerateMachineID(prefix string) (string, error) { + var err error + if prefix == "" { + prefix, err = generateMachineIDPrefix() + } + + if err != nil { + return "", err + } + + prefix = strings.ReplaceAll(prefix, "-", "")[:32] + suffix := GeneratePassword(16) + + return prefix + suffix, nil +} diff --git a/cmd/crowdsec-cli/idgen/password.go b/cmd/crowdsec-cli/idgen/password.go new file mode 100644 index 00000000000..e0faa4daacc --- /dev/null +++ b/cmd/crowdsec-cli/idgen/password.go @@ -0,0 +1,32 @@ +package idgen + +import ( + saferand "crypto/rand" + "math/big" + + log "github.com/sirupsen/logrus" +) + +const PasswordLength = 64 + +func GeneratePassword(length int) string { + upper := "ABCDEFGHIJKLMNOPQRSTUVWXY" + lower := "abcdefghijklmnopqrstuvwxyz" + digits := "0123456789" + + charset := upper + lower + digits + charsetLength := len(charset) + + buf := make([]byte, length) + + for i := range length { + rInt, err := saferand.Int(saferand.Reader, big.NewInt(int64(charsetLength))) + if err != nil { + log.Fatalf("failed getting data from prng for password generation : %s", err) + } + + buf[i] = charset[rInt.Int64()] + } + + return string(buf) +} diff --git a/cmd/crowdsec-cli/machines.go b/cmd/crowdsec-cli/machines.go deleted file mode 100644 index dcdb1963b49..00000000000 --- a/cmd/crowdsec-cli/machines.go +++ /dev/null @@ -1,771 +0,0 @@ -package main - -import ( - saferand "crypto/rand" - "encoding/csv" - "encoding/json" - "errors" - "fmt" - "io" - "math/big" - "os" - "slices" - "strings" - "time" - - "github.com/AlecAivazis/survey/v2" - "github.com/fatih/color" - "github.com/go-openapi/strfmt" - "github.com/google/uuid" - "github.com/jedib0t/go-pretty/v6/table" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - - "github.com/crowdsecurity/machineid" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cstable" - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/database" - "github.com/crowdsecurity/crowdsec/pkg/database/ent" - "github.com/crowdsecurity/crowdsec/pkg/emoji" - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -const passwordLength = 64 - -func generatePassword(length int) string { - upper := "ABCDEFGHIJKLMNOPQRSTUVWXY" - lower := "abcdefghijklmnopqrstuvwxyz" - digits := "0123456789" - - charset := upper + lower + digits - charsetLength := len(charset) - - buf := make([]byte, length) - - for i := range length { - rInt, err := saferand.Int(saferand.Reader, big.NewInt(int64(charsetLength))) - if err != nil { - log.Fatalf("failed getting data from prng for password generation : %s", err) - } - - buf[i] = charset[rInt.Int64()] - } - - return string(buf) -} - -// Returns a unique identifier for each crowdsec installation, using an -// identifier of the OS installation where available, otherwise a random -// string. -func generateIDPrefix() (string, error) { - prefix, err := machineid.ID() - if err == nil { - return prefix, nil - } - - log.Debugf("failed to get machine-id with usual files: %s", err) - - bID, err := uuid.NewRandom() - if err == nil { - return bID.String(), nil - } - - return "", fmt.Errorf("generating machine id: %w", err) -} - -// Generate a unique identifier, composed by a prefix and a random suffix. -// The prefix can be provided by a parameter to use in test environments. -func generateID(prefix string) (string, error) { - var err error - if prefix == "" { - prefix, err = generateIDPrefix() - } - - if err != nil { - return "", err - } - - prefix = strings.ReplaceAll(prefix, "-", "")[:32] - suffix := generatePassword(16) - - return prefix + suffix, nil -} - -// getLastHeartbeat returns the last heartbeat timestamp of a machine -// and a boolean indicating if the machine is considered active or not. -func getLastHeartbeat(m *ent.Machine) (string, bool) { - if m.LastHeartbeat == nil { - return "-", false - } - - elapsed := time.Now().UTC().Sub(*m.LastHeartbeat) - - hb := elapsed.Truncate(time.Second).String() - if elapsed > 2*time.Minute { - return hb, false - } - - return hb, true -} - -type cliMachines struct { - db *database.Client - cfg configGetter -} - -func NewCLIMachines(cfg configGetter) *cliMachines { - return &cliMachines{ - cfg: cfg, - } -} - -func (cli *cliMachines) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "machines [action]", - Short: "Manage local API machines [requires local API]", - Long: `To list/add/delete/validate/prune machines. -Note: This command requires database direct access, so is intended to be run on the local API machine. -`, - Example: `cscli machines [action]`, - DisableAutoGenTag: true, - Aliases: []string{"machine"}, - PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { - var err error - if err = require.LAPI(cli.cfg()); err != nil { - return err - } - cli.db, err = require.DBClient(cmd.Context(), cli.cfg().DbConfig) - if err != nil { - return err - } - - return nil - }, - } - - cmd.AddCommand(cli.newListCmd()) - cmd.AddCommand(cli.newAddCmd()) - cmd.AddCommand(cli.newDeleteCmd()) - cmd.AddCommand(cli.newValidateCmd()) - cmd.AddCommand(cli.newPruneCmd()) - cmd.AddCommand(cli.newInspectCmd()) - - return cmd -} - -func (cli *cliMachines) inspectHubHuman(out io.Writer, machine *ent.Machine) { - state := machine.Hubstate - - if len(state) == 0 { - fmt.Println("No hub items found for this machine") - return - } - - // group state rows by type for multiple tables - rowsByType := make(map[string][]table.Row) - - for itemType, items := range state { - for _, item := range items { - if _, ok := rowsByType[itemType]; !ok { - rowsByType[itemType] = make([]table.Row, 0) - } - - row := table.Row{item.Name, item.Status, item.Version} - rowsByType[itemType] = append(rowsByType[itemType], row) - } - } - - for itemType, rows := range rowsByType { - t := cstable.New(out, cli.cfg().Cscli.Color).Writer - t.AppendHeader(table.Row{"Name", "Status", "Version"}) - t.SetTitle(itemType) - t.AppendRows(rows) - io.WriteString(out, t.Render() + "\n") - } -} - -func (cli *cliMachines) listHuman(out io.Writer, machines ent.Machines) { - t := cstable.NewLight(out, cli.cfg().Cscli.Color).Writer - t.AppendHeader(table.Row{"Name", "IP Address", "Last Update", "Status", "Version", "OS", "Auth Type", "Last Heartbeat"}) - - for _, m := range machines { - validated := emoji.Prohibited - if m.IsValidated { - validated = emoji.CheckMark - } - - hb, active := getLastHeartbeat(m) - if !active { - hb = emoji.Warning + " " + hb - } - - t.AppendRow(table.Row{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, getOSNameAndVersion(m), m.AuthType, hb}) - } - - io.WriteString(out, t.Render() + "\n") -} - -// machineInfo contains only the data we want for inspect/list: no hub status, scenarios, edges, etc. -type machineInfo struct { - CreatedAt time.Time `json:"created_at,omitempty"` - UpdatedAt time.Time `json:"updated_at,omitempty"` - LastPush *time.Time `json:"last_push,omitempty"` - LastHeartbeat *time.Time `json:"last_heartbeat,omitempty"` - MachineId string `json:"machineId,omitempty"` - IpAddress string `json:"ipAddress,omitempty"` - Version string `json:"version,omitempty"` - IsValidated bool `json:"isValidated,omitempty"` - AuthType string `json:"auth_type"` - OS string `json:"os,omitempty"` - Featureflags []string `json:"featureflags,omitempty"` - Datasources map[string]int64 `json:"datasources,omitempty"` -} - -func newMachineInfo(m *ent.Machine) machineInfo { - return machineInfo{ - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - LastPush: m.LastPush, - LastHeartbeat: m.LastHeartbeat, - MachineId: m.MachineId, - IpAddress: m.IpAddress, - Version: m.Version, - IsValidated: m.IsValidated, - AuthType: m.AuthType, - OS: getOSNameAndVersion(m), - Featureflags: getFeatureFlagList(m), - Datasources: m.Datasources, - } -} - -func (cli *cliMachines) listCSV(out io.Writer, machines ent.Machines) error { - csvwriter := csv.NewWriter(out) - - err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat", "os"}) - if err != nil { - return fmt.Errorf("failed to write header: %w", err) - } - - for _, m := range machines { - validated := "false" - if m.IsValidated { - validated = "true" - } - - hb := "-" - if m.LastHeartbeat != nil { - hb = m.LastHeartbeat.Format(time.RFC3339) - } - - if err := csvwriter.Write([]string{m.MachineId, m.IpAddress, m.UpdatedAt.Format(time.RFC3339), validated, m.Version, m.AuthType, hb, fmt.Sprintf("%s/%s", m.Osname, m.Osversion)}); err != nil { - return fmt.Errorf("failed to write raw output: %w", err) - } - } - - csvwriter.Flush() - - return nil -} - -func (cli *cliMachines) list(out io.Writer) error { - machines, err := cli.db.ListMachines() - if err != nil { - return fmt.Errorf("unable to list machines: %w", err) - } - - switch cli.cfg().Cscli.Output { - case "human": - cli.listHuman(out, machines) - case "json": - info := make([]machineInfo, 0, len(machines)) - for _, m := range machines { - info = append(info, newMachineInfo(m)) - } - - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - - if err := enc.Encode(info); err != nil { - return errors.New("failed to marshal") - } - - return nil - case "raw": - return cli.listCSV(out, machines) - } - - return nil -} - -func (cli *cliMachines) newListCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "list", - Short: "list all machines in the database", - Long: `list all machines in the database with their status and last heartbeat`, - Example: `cscli machines list`, - Args: cobra.NoArgs, - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.list(color.Output) - }, - } - - return cmd -} - -func (cli *cliMachines) newAddCmd() *cobra.Command { - var ( - password MachinePassword - dumpFile string - apiURL string - interactive bool - autoAdd bool - force bool - ) - - cmd := &cobra.Command{ - Use: "add", - Short: "add a single machine to the database", - DisableAutoGenTag: true, - Long: `Register a new machine in the database. cscli should be on the same machine as LAPI.`, - Example: `cscli machines add --auto -cscli machines add MyTestMachine --auto -cscli machines add MyTestMachine --password MyPassword -cscli machines add -f- --auto > /tmp/mycreds.yaml`, - RunE: func(_ *cobra.Command, args []string) error { - return cli.add(args, string(password), dumpFile, apiURL, interactive, autoAdd, force) - }, - } - - flags := cmd.Flags() - flags.VarP(&password, "password", "p", "machine password to login to the API") - flags.StringVarP(&dumpFile, "file", "f", "", "output file destination (defaults to "+csconfig.DefaultConfigPath("local_api_credentials.yaml")+")") - flags.StringVarP(&apiURL, "url", "u", "", "URL of the local API") - flags.BoolVarP(&interactive, "interactive", "i", false, "interfactive mode to enter the password") - flags.BoolVarP(&autoAdd, "auto", "a", false, "automatically generate password (and username if not provided)") - flags.BoolVar(&force, "force", false, "will force add the machine if it already exist") - - return cmd -} - -func (cli *cliMachines) add(args []string, machinePassword string, dumpFile string, apiURL string, interactive bool, autoAdd bool, force bool) error { - var ( - err error - machineID string - ) - - // create machineID if not specified by user - if len(args) == 0 { - if !autoAdd { - return errors.New("please specify a machine name to add, or use --auto") - } - - machineID, err = generateID("") - if err != nil { - return fmt.Errorf("unable to generate machine id: %w", err) - } - } else { - machineID = args[0] - } - - clientCfg := cli.cfg().API.Client - serverCfg := cli.cfg().API.Server - - /*check if file already exists*/ - if dumpFile == "" && clientCfg != nil && clientCfg.CredentialsFilePath != "" { - credFile := clientCfg.CredentialsFilePath - // use the default only if the file does not exist - _, err = os.Stat(credFile) - - switch { - case os.IsNotExist(err) || force: - dumpFile = credFile - case err != nil: - return fmt.Errorf("unable to stat '%s': %w", credFile, err) - default: - return fmt.Errorf(`credentials file '%s' already exists: please remove it, use "--force" or specify a different file with "-f" ("-f -" for standard output)`, credFile) - } - } - - if dumpFile == "" { - return errors.New(`please specify a file to dump credentials to, with -f ("-f -" for standard output)`) - } - - // create a password if it's not specified by user - if machinePassword == "" && !interactive { - if !autoAdd { - return errors.New("please specify a password with --password or use --auto") - } - - machinePassword = generatePassword(passwordLength) - } else if machinePassword == "" && interactive { - qs := &survey.Password{ - Message: "Please provide a password for the machine:", - } - survey.AskOne(qs, &machinePassword) - } - - password := strfmt.Password(machinePassword) - - _, err = cli.db.CreateMachine(&machineID, &password, "", true, force, types.PasswordAuthType) - if err != nil { - return fmt.Errorf("unable to create machine: %w", err) - } - - fmt.Fprintf(os.Stderr, "Machine '%s' successfully added to the local API.\n", machineID) - - if apiURL == "" { - if clientCfg != nil && clientCfg.Credentials != nil && clientCfg.Credentials.URL != "" { - apiURL = clientCfg.Credentials.URL - } else if serverCfg.ClientURL() != "" { - apiURL = serverCfg.ClientURL() - } else { - return errors.New("unable to dump an api URL. Please provide it in your configuration or with the -u parameter") - } - } - - apiCfg := csconfig.ApiCredentialsCfg{ - Login: machineID, - Password: password.String(), - URL: apiURL, - } - - apiConfigDump, err := yaml.Marshal(apiCfg) - if err != nil { - return fmt.Errorf("unable to marshal api credentials: %w", err) - } - - if dumpFile != "" && dumpFile != "-" { - if err = os.WriteFile(dumpFile, apiConfigDump, 0o600); err != nil { - return fmt.Errorf("write api credentials in '%s' failed: %w", dumpFile, err) - } - - fmt.Fprintf(os.Stderr, "API credentials written to '%s'.\n", dumpFile) - } else { - fmt.Print(string(apiConfigDump)) - } - - return nil -} - -// validMachineID returns a list of machine IDs for command completion -func (cli *cliMachines) validMachineID(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - var err error - - cfg := cli.cfg() - - // need to load config and db because PersistentPreRunE is not called for completions - - if err = require.LAPI(cfg); err != nil { - cobra.CompError("unable to list machines " + err.Error()) - return nil, cobra.ShellCompDirectiveNoFileComp - } - - cli.db, err = require.DBClient(cmd.Context(), cfg.DbConfig) - if err != nil { - cobra.CompError("unable to list machines " + err.Error()) - return nil, cobra.ShellCompDirectiveNoFileComp - } - - machines, err := cli.db.ListMachines() - if err != nil { - cobra.CompError("unable to list machines " + err.Error()) - return nil, cobra.ShellCompDirectiveNoFileComp - } - - ret := []string{} - - for _, machine := range machines { - if strings.Contains(machine.MachineId, toComplete) && !slices.Contains(args, machine.MachineId) { - ret = append(ret, machine.MachineId) - } - } - - return ret, cobra.ShellCompDirectiveNoFileComp -} - -func (cli *cliMachines) delete(machines []string, ignoreMissing bool) error { - for _, machineID := range machines { - if err := cli.db.DeleteWatcher(machineID); err != nil { - var notFoundErr *database.MachineNotFoundError - if ignoreMissing && errors.As(err, ¬FoundErr) { - return nil - } - - log.Errorf("unable to delete machine: %s", err) - - return nil - } - - log.Infof("machine '%s' deleted successfully", machineID) - } - - return nil -} - -func (cli *cliMachines) newDeleteCmd() *cobra.Command { - var ignoreMissing bool - - cmd := &cobra.Command{ - Use: "delete [machine_name]...", - Short: "delete machine(s) by name", - Example: `cscli machines delete "machine1" "machine2"`, - Args: cobra.MinimumNArgs(1), - Aliases: []string{"remove"}, - DisableAutoGenTag: true, - ValidArgsFunction: cli.validMachineID, - RunE: func(_ *cobra.Command, args []string) error { - return cli.delete(args, ignoreMissing) - }, - } - - flags := cmd.Flags() - flags.BoolVar(&ignoreMissing, "ignore-missing", false, "don't print errors if one or more machines don't exist") - - return cmd -} - -func (cli *cliMachines) prune(duration time.Duration, notValidOnly bool, force bool) error { - if duration < 2*time.Minute && !notValidOnly { - if yes, err := askYesNo( - "The duration you provided is less than 2 minutes. "+ - "This can break installations if the machines are only temporarily disconnected. Continue?", false); err != nil { - return err - } else if !yes { - fmt.Println("User aborted prune. No changes were made.") - return nil - } - } - - machines := []*ent.Machine{} - if pending, err := cli.db.QueryPendingMachine(); err == nil { - machines = append(machines, pending...) - } - - if !notValidOnly { - if pending, err := cli.db.QueryMachinesInactiveSince(time.Now().UTC().Add(-duration)); err == nil { - machines = append(machines, pending...) - } - } - - if len(machines) == 0 { - fmt.Println("No machines to prune.") - return nil - } - - cli.listHuman(color.Output, machines) - - if !force { - if yes, err := askYesNo( - "You are about to PERMANENTLY remove the above machines from the database. "+ - "These will NOT be recoverable. Continue?", false); err != nil { - return err - } else if !yes { - fmt.Println("User aborted prune. No changes were made.") - return nil - } - } - - deleted, err := cli.db.BulkDeleteWatchers(machines) - if err != nil { - return fmt.Errorf("unable to prune machines: %w", err) - } - - fmt.Fprintf(os.Stderr, "successfully deleted %d machines\n", deleted) - - return nil -} - -func (cli *cliMachines) newPruneCmd() *cobra.Command { - var ( - duration time.Duration - notValidOnly bool - force bool - ) - - const defaultDuration = 10 * time.Minute - - cmd := &cobra.Command{ - Use: "prune", - Short: "prune multiple machines from the database", - Long: `prune multiple machines that are not validated or have not connected to the local API in a given duration.`, - Example: `cscli machines prune -cscli machines prune --duration 1h -cscli machines prune --not-validated-only --force`, - Args: cobra.NoArgs, - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, _ []string) error { - return cli.prune(duration, notValidOnly, force) - }, - } - - flags := cmd.Flags() - flags.DurationVarP(&duration, "duration", "d", defaultDuration, "duration of time since validated machine last heartbeat") - flags.BoolVar(¬ValidOnly, "not-validated-only", false, "only prune machines that are not validated") - flags.BoolVar(&force, "force", false, "force prune without asking for confirmation") - - return cmd -} - -func (cli *cliMachines) validate(machineID string) error { - if err := cli.db.ValidateMachine(machineID); err != nil { - return fmt.Errorf("unable to validate machine '%s': %w", machineID, err) - } - - log.Infof("machine '%s' validated successfully", machineID) - - return nil -} - -func (cli *cliMachines) newValidateCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "validate", - Short: "validate a machine to access the local API", - Long: `validate a machine to access the local API.`, - Example: `cscli machines validate "machine_name"`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.validate(args[0]) - }, - } - - return cmd -} - -func (cli *cliMachines) inspectHuman(out io.Writer, machine *ent.Machine) { - t := cstable.New(out, cli.cfg().Cscli.Color).Writer - - t.SetTitle("Machine: " + machine.MachineId) - - t.SetColumnConfigs([]table.ColumnConfig{ - {Number: 1, AutoMerge: true}, - }) - - t.AppendRows([]table.Row{ - {"IP Address", machine.IpAddress}, - {"Created At", machine.CreatedAt}, - {"Last Update", machine.UpdatedAt}, - {"Last Heartbeat", machine.LastHeartbeat}, - {"Validated?", machine.IsValidated}, - {"CrowdSec version", machine.Version}, - {"OS", getOSNameAndVersion(machine)}, - {"Auth type", machine.AuthType}, - }) - - for dsName, dsCount := range machine.Datasources { - t.AppendRow(table.Row{"Datasources", fmt.Sprintf("%s: %d", dsName, dsCount)}) - } - - for _, ff := range getFeatureFlagList(machine) { - t.AppendRow(table.Row{"Feature Flags", ff}) - } - - for _, coll := range machine.Hubstate[cwhub.COLLECTIONS] { - t.AppendRow(table.Row{"Collections", coll.Name}) - } - - io.WriteString(out, t.Render() + "\n") -} - -func (cli *cliMachines) inspect(machine *ent.Machine) error { - out := color.Output - outputFormat := cli.cfg().Cscli.Output - - switch outputFormat { - case "human": - cli.inspectHuman(out, machine) - case "json": - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - - if err := enc.Encode(newMachineInfo(machine)); err != nil { - return errors.New("failed to marshal") - } - - return nil - default: - return fmt.Errorf("output format '%s' not supported for this command", outputFormat) - } - - return nil -} - -func (cli *cliMachines) inspectHub(machine *ent.Machine) error { - out := color.Output - - switch cli.cfg().Cscli.Output { - case "human": - cli.inspectHubHuman(out, machine) - case "json": - enc := json.NewEncoder(out) - enc.SetIndent("", " ") - - if err := enc.Encode(machine.Hubstate); err != nil { - return errors.New("failed to marshal") - } - - return nil - case "raw": - csvwriter := csv.NewWriter(out) - - err := csvwriter.Write([]string{"type", "name", "status", "version"}) - if err != nil { - return fmt.Errorf("failed to write header: %w", err) - } - - rows := make([][]string, 0) - - for itemType, items := range machine.Hubstate { - for _, item := range items { - rows = append(rows, []string{itemType, item.Name, item.Status, item.Version}) - } - } - - for _, row := range rows { - if err := csvwriter.Write(row); err != nil { - return fmt.Errorf("failed to write raw output: %w", err) - } - } - - csvwriter.Flush() - } - - return nil -} - -func (cli *cliMachines) newInspectCmd() *cobra.Command { - var showHub bool - - cmd := &cobra.Command{ - Use: "inspect [machine_name]", - Short: "inspect a machine by name", - Example: `cscli machines inspect "machine1"`, - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - ValidArgsFunction: cli.validMachineID, - RunE: func(_ *cobra.Command, args []string) error { - machineID := args[0] - machine, err := cli.db.QueryMachineByID(machineID) - if err != nil { - return fmt.Errorf("unable to read machine data '%s': %w", machineID, err) - } - - if showHub { - return cli.inspectHub(machine) - } - - return cli.inspect(machine) - }, - } - - flags := cmd.Flags() - - flags.BoolVarP(&showHub, "hub", "H", false, "show hub state") - - return cmd -} diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index d4046414030..1cca03b1d3d 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -14,8 +14,22 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clialert" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clibouncer" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clicapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliconsole" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clidecision" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliexplain" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihub" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clihubtest" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/cliitem" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clilapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climachine" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/climetrics" - + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clinotifications" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clipapi" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisimulation" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisupport" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/fflag" ) @@ -152,14 +166,6 @@ func (cli *cliRoot) initialize() error { return nil } -// list of valid subcommands for the shell completion -var validArgs = []string{ - "alerts", "appsec-configs", "appsec-rules", "bouncers", "capi", "collections", - "completion", "config", "console", "contexts", "dashboard", "decisions", "explain", - "hub", "hubtest", "lapi", "machines", "metrics", "notifications", "parsers", - "postoverflows", "scenarios", "simulation", "support", "version", -} - func (cli *cliRoot) colorize(cmd *cobra.Command) { cc.Init(&cc.Config{ RootCmd: cmd, @@ -191,6 +197,14 @@ func (cli *cliRoot) NewCommand() (*cobra.Command, error) { return nil, fmt.Errorf("failed to set feature flags from env: %w", err) } + // list of valid subcommands for the shell completion + validArgs := []string{ + "alerts", "appsec-configs", "appsec-rules", "bouncers", "capi", "collections", + "completion", "config", "console", "contexts", "dashboard", "decisions", "explain", + "hub", "hubtest", "lapi", "machines", "metrics", "notifications", "parsers", + "postoverflows", "scenarios", "simulation", "support", "version", + } + cmd := &cobra.Command{ Use: "cscli", Short: "cscli allows you to manage crowdsec", @@ -238,6 +252,36 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall return nil, err } + cmd.AddCommand(NewCLIDoc().NewCommand(cmd)) + cmd.AddCommand(NewCLIVersion().NewCommand()) + cmd.AddCommand(NewCLIConfig(cli.cfg).NewCommand()) + cmd.AddCommand(clihub.New(cli.cfg).NewCommand()) + cmd.AddCommand(climetrics.New(cli.cfg).NewCommand()) + cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand()) + cmd.AddCommand(clidecision.New(cli.cfg).NewCommand()) + cmd.AddCommand(clialert.New(cli.cfg).NewCommand()) + cmd.AddCommand(clisimulation.New(cli.cfg).NewCommand()) + cmd.AddCommand(clibouncer.New(cli.cfg).NewCommand()) + cmd.AddCommand(climachine.New(cli.cfg).NewCommand()) + cmd.AddCommand(clicapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(clilapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(NewCompletionCmd()) + cmd.AddCommand(cliconsole.New(cli.cfg).NewCommand()) + cmd.AddCommand(cliexplain.New(cli.cfg, ConfigFilePath).NewCommand()) + cmd.AddCommand(clihubtest.New(cli.cfg).NewCommand()) + cmd.AddCommand(clinotifications.New(cli.cfg).NewCommand()) + cmd.AddCommand(clisupport.New(cli.cfg).NewCommand()) + cmd.AddCommand(clipapi.New(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewCollection(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewParser(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewScenario(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewPostOverflow(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewContext(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewAppsecConfig(cli.cfg).NewCommand()) + cmd.AddCommand(cliitem.NewAppsecRule(cli.cfg).NewCommand()) + + cli.addSetup(cmd) + if len(os.Args) > 1 { cobra.OnInitialize( func() { @@ -248,38 +292,6 @@ It is meant to allow you to manage bans, parsers/scenarios/etc, api and generall ) } - cmd.AddCommand(NewCLIDoc().NewCommand(cmd)) - cmd.AddCommand(NewCLIVersion().NewCommand()) - cmd.AddCommand(NewCLIConfig(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIHub(cli.cfg).NewCommand()) - cmd.AddCommand(climetrics.New(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIDashboard(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIDecisions(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIAlerts(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLISimulation(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIBouncers(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIMachines(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLICapi(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLILapi(cli.cfg).NewCommand()) - cmd.AddCommand(NewCompletionCmd()) - cmd.AddCommand(NewCLIConsole(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIExplain(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIHubTest(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLINotifications(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLISupport(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIPapi(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLICollection(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIParser(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIScenario(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIPostOverflow(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIContext(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIAppsecConfig(cli.cfg).NewCommand()) - cmd.AddCommand(NewCLIAppsecRule(cli.cfg).NewCommand()) - - if fflag.CscliSetup.IsEnabled() { - cmd.AddCommand(NewCLISetup(cli.cfg).NewCommand()) - } - return cmd, nil } diff --git a/cmd/crowdsec-cli/messages.go b/cmd/crowdsec-cli/messages.go deleted file mode 100644 index 02f051601e4..00000000000 --- a/cmd/crowdsec-cli/messages.go +++ /dev/null @@ -1,23 +0,0 @@ -package main - -import ( - "fmt" - "runtime" -) - -// ReloadMessage returns a description of the task required to reload -// the crowdsec configuration, according to the operating system. -func ReloadMessage() string { - var msg string - - switch runtime.GOOS { - case "windows": - msg = "Please restart the crowdsec service" - case "freebsd": - msg = `Run 'sudo service crowdsec reload'` - default: - msg = `Run 'sudo systemctl reload crowdsec'` - } - - return fmt.Sprintf("%s for the new configuration to be effective.", msg) -} diff --git a/cmd/crowdsec-cli/papi.go b/cmd/crowdsec-cli/papi.go deleted file mode 100644 index a2fa0a90871..00000000000 --- a/cmd/crowdsec-cli/papi.go +++ /dev/null @@ -1,148 +0,0 @@ -package main - -import ( - "fmt" - "time" - - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" - "gopkg.in/tomb.v2" - - "github.com/crowdsecurity/go-cs-lib/ptr" - - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/apiserver" -) - -type cliPapi struct { - cfg configGetter -} - -func NewCLIPapi(cfg configGetter) *cliPapi { - return &cliPapi{ - cfg: cfg, - } -} - -func (cli *cliPapi) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "papi [action]", - Short: "Manage interaction with Polling API (PAPI)", - Args: cobra.MinimumNArgs(1), - DisableAutoGenTag: true, - PersistentPreRunE: func(_ *cobra.Command, _ []string) error { - cfg := cli.cfg() - if err := require.LAPI(cfg); err != nil { - return err - } - if err := require.CAPI(cfg); err != nil { - return err - } - - return require.PAPI(cfg) - }, - } - - cmd.AddCommand(cli.NewStatusCmd()) - cmd.AddCommand(cli.NewSyncCmd()) - - return cmd -} - -func (cli *cliPapi) NewStatusCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "status", - Short: "Get status of the Polling API", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, _ []string) error { - var err error - cfg := cli.cfg() - db, err := require.DBClient(cmd.Context(), cfg.DbConfig) - if err != nil { - return err - } - - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) - if err != nil { - return fmt.Errorf("unable to initialize API client: %w", err) - } - - papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) - if err != nil { - return fmt.Errorf("unable to initialize PAPI client: %w", err) - } - - perms, err := papi.GetPermissions() - if err != nil { - return fmt.Errorf("unable to get PAPI permissions: %w", err) - } - var lastTimestampStr *string - lastTimestampStr, err = db.GetConfigItem(apiserver.PapiPullKey) - if err != nil { - lastTimestampStr = ptr.Of("never") - } - log.Infof("You can successfully interact with Polling API (PAPI)") - log.Infof("Console plan: %s", perms.Plan) - log.Infof("Last order received: %s", *lastTimestampStr) - - log.Infof("PAPI subscriptions:") - for _, sub := range perms.Categories { - log.Infof(" - %s", sub) - } - - return nil - }, - } - - return cmd -} - -func (cli *cliPapi) NewSyncCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "sync", - Short: "Sync with the Polling API, pulling all non-expired orders for the instance", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, _ []string) error { - var err error - cfg := cli.cfg() - t := tomb.Tomb{} - - db, err := require.DBClient(cmd.Context(), cfg.DbConfig) - if err != nil { - return err - } - - apic, err := apiserver.NewAPIC(cfg.API.Server.OnlineClient, db, cfg.API.Server.ConsoleConfig, cfg.API.Server.CapiWhitelists) - if err != nil { - return fmt.Errorf("unable to initialize API client: %w", err) - } - - t.Go(apic.Push) - - papi, err := apiserver.NewPAPI(apic, db, cfg.API.Server.ConsoleConfig, log.GetLevel()) - if err != nil { - return fmt.Errorf("unable to initialize PAPI client: %w", err) - } - - t.Go(papi.SyncDecisions) - - err = papi.PullOnce(time.Time{}, true) - if err != nil { - return fmt.Errorf("unable to sync decisions: %w", err) - } - - log.Infof("Sending acknowledgements to CAPI") - - apic.Shutdown() - papi.Shutdown() - t.Wait() - time.Sleep(5 * time.Second) // FIXME: the push done by apic.Push is run inside a sub goroutine, sleep to make sure it's done - - return nil - }, - } - - return cmd -} diff --git a/cmd/crowdsec-cli/reload/reload.go b/cmd/crowdsec-cli/reload/reload.go new file mode 100644 index 00000000000..fe03af1ea79 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload.go @@ -0,0 +1,6 @@ +//go:build !windows && !freebsd && !linux + +package reload + +// generic message since we don't know the platform +const Message = "Please reload the crowdsec process for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_freebsd.go b/cmd/crowdsec-cli/reload/reload_freebsd.go new file mode 100644 index 00000000000..0dac99f2315 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_freebsd.go @@ -0,0 +1,4 @@ +package reload + +// actually sudo is not that popular on freebsd, but this will do +const Message = "Run 'sudo service crowdsec reload' for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_linux.go b/cmd/crowdsec-cli/reload/reload_linux.go new file mode 100644 index 00000000000..fbe16e5f168 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_linux.go @@ -0,0 +1,4 @@ +package reload + +// assume systemd, although gentoo and others may differ +const Message = "Run 'sudo systemctl reload crowdsec' for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/reload/reload_windows.go b/cmd/crowdsec-cli/reload/reload_windows.go new file mode 100644 index 00000000000..88642425ae2 --- /dev/null +++ b/cmd/crowdsec-cli/reload/reload_windows.go @@ -0,0 +1,3 @@ +package reload + +const Message = "Please restart the crowdsec service for the new configuration to be effective." diff --git a/cmd/crowdsec-cli/require/require.go b/cmd/crowdsec-cli/require/require.go index 15d8bce682d..191eee55bc5 100644 --- a/cmd/crowdsec-cli/require/require.go +++ b/cmd/crowdsec-cli/require/require.go @@ -34,6 +34,14 @@ func CAPI(c *csconfig.Config) error { } func PAPI(c *csconfig.Config) error { + if err := CAPI(c); err != nil { + return err + } + + if err := CAPIRegistered(c); err != nil { + return err + } + if c.API.Server.OnlineClient.Credentials.PapiURL == "" { return errors.New("no PAPI URL in configuration") } diff --git a/cmd/crowdsec-cli/setup.go b/cmd/crowdsec-cli/setup.go index d747af9225f..66c0d71e777 100644 --- a/cmd/crowdsec-cli/setup.go +++ b/cmd/crowdsec-cli/setup.go @@ -1,304 +1,18 @@ +//go:build !no_cscli_setup package main import ( - "bytes" - "context" - "errors" - "fmt" - "os" - "os/exec" - - goccyyaml "github.com/goccy/go-yaml" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/yaml.v3" - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" - "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/setup" + "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/clisetup" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/fflag" ) -type cliSetup struct { - cfg configGetter -} - -func NewCLISetup(cfg configGetter) *cliSetup { - return &cliSetup{ - cfg: cfg, - } -} - -func (cli *cliSetup) NewCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "setup", - Short: "Tools to configure crowdsec", - Long: "Manage hub configuration and service detection", - Args: cobra.MinimumNArgs(0), - DisableAutoGenTag: true, - } - - cmd.AddCommand(cli.NewDetectCmd()) - cmd.AddCommand(cli.NewInstallHubCmd()) - cmd.AddCommand(cli.NewDataSourcesCmd()) - cmd.AddCommand(cli.NewValidateCmd()) - - return cmd -} - -type detectFlags struct { - detectConfigFile string - listSupportedServices bool - forcedUnits []string - forcedProcesses []string - forcedOSFamily string - forcedOSID string - forcedOSVersion string - skipServices []string - snubSystemd bool - outYaml bool -} - -func (f *detectFlags) bind(cmd *cobra.Command) { - defaultServiceDetect := csconfig.DefaultConfigPath("hub", "detect.yaml") - - flags := cmd.Flags() - flags.StringVar(&f.detectConfigFile, "detect-config", defaultServiceDetect, "path to service detection configuration") - flags.BoolVar(&f.listSupportedServices, "list-supported-services", false, "do not detect; only print supported services") - flags.StringSliceVar(&f.forcedUnits, "force-unit", nil, "force detection of a systemd unit (can be repeated)") - flags.StringSliceVar(&f.forcedProcesses, "force-process", nil, "force detection of a running process (can be repeated)") - flags.StringSliceVar(&f.skipServices, "skip-service", nil, "ignore a service, don't recommend hub/datasources (can be repeated)") - flags.StringVar(&f.forcedOSFamily, "force-os-family", "", "override OS.Family: one of linux, freebsd, windows or darwin") - flags.StringVar(&f.forcedOSID, "force-os-id", "", "override OS.ID=[debian | ubuntu | , redhat...]") - flags.StringVar(&f.forcedOSVersion, "force-os-version", "", "override OS.RawVersion (of OS or Linux distribution)") - flags.BoolVar(&f.snubSystemd, "snub-systemd", false, "don't use systemd, even if available") - flags.BoolVar(&f.outYaml, "yaml", false, "output yaml, not json") -} - -func (cli *cliSetup) NewDetectCmd() *cobra.Command { - f := detectFlags{} - - cmd := &cobra.Command{ - Use: "detect", - Short: "detect running services, generate a setup file", - DisableAutoGenTag: true, - RunE: func(_ *cobra.Command, args []string) error { - return cli.detect(f) - }, - } - - f.bind(cmd) - return cmd -} - -func (cli *cliSetup) NewInstallHubCmd() *cobra.Command { - var dryRun bool - - cmd := &cobra.Command{ - Use: "install-hub [setup_file] [flags]", - Short: "install items from a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - return cli.install(cmd.Context(), dryRun, args[0]) - }, - } - - flags := cmd.Flags() - flags.BoolVar(&dryRun, "dry-run", false, "don't install anything; print out what would have been") - - return cmd -} - -func (cli *cliSetup) NewDataSourcesCmd() *cobra.Command { - var toDir string - - cmd := &cobra.Command{ - Use: "datasources [setup_file] [flags]", - Short: "generate datasource (acquisition) configuration from a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - return cli.dataSources(args[0], toDir) - }, - } - - flags := cmd.Flags() - flags.StringVar(&toDir, "to-dir", "", "write the configuration to a directory, in multiple files") - - return cmd -} - -func (cli *cliSetup) NewValidateCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "validate [setup_file]", - Short: "validate a setup file", - Args: cobra.ExactArgs(1), - DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - return cli.validate(args[0]) - }, - } - - return cmd -} - -func (cli *cliSetup) detect(f detectFlags) error { - var ( - detectReader *os.File - err error - ) - - switch f.detectConfigFile { - case "-": - log.Tracef("Reading detection rules from stdin") - - detectReader = os.Stdin - default: - log.Tracef("Reading detection rules: %s", f.detectConfigFile) - - detectReader, err = os.Open(f.detectConfigFile) - if err != nil { - return err - } - } - - if !f.snubSystemd { - _, err := exec.LookPath("systemctl") - if err != nil { - log.Debug("systemctl not available: snubbing systemd") - - f.snubSystemd = true - } - } - - if f.forcedOSFamily == "" && f.forcedOSID != "" { - log.Debug("force-os-id is set: force-os-family defaults to 'linux'") - - f.forcedOSFamily = "linux" - } - - if f.listSupportedServices { - supported, err := setup.ListSupported(detectReader) - if err != nil { - return err - } - - for _, svc := range supported { - fmt.Println(svc) - } - - return nil - } - - opts := setup.DetectOptions{ - ForcedUnits: f.forcedUnits, - ForcedProcesses: f.forcedProcesses, - ForcedOS: setup.ExprOS{ - Family: f.forcedOSFamily, - ID: f.forcedOSID, - RawVersion: f.forcedOSVersion, - }, - SkipServices: f.skipServices, - SnubSystemd: f.snubSystemd, - } - - hubSetup, err := setup.Detect(detectReader, opts) - if err != nil { - return fmt.Errorf("detecting services: %w", err) - } - - setup, err := setupAsString(hubSetup, f.outYaml) - if err != nil { - return err - } - - fmt.Println(setup) - - return nil -} - -func setupAsString(cs setup.Setup, outYaml bool) (string, error) { - var ( - ret []byte - err error - ) - - wrap := func(err error) error { - return fmt.Errorf("while marshaling setup: %w", err) - } - - indentLevel := 2 - buf := &bytes.Buffer{} - enc := yaml.NewEncoder(buf) - enc.SetIndent(indentLevel) - - if err = enc.Encode(cs); err != nil { - return "", wrap(err) - } - - if err = enc.Close(); err != nil { - return "", wrap(err) - } - - ret = buf.Bytes() - - if !outYaml { - // take a general approach to output json, so we avoid the - // double tags in the structures and can use go-yaml features - // missing from the json package - ret, err = goccyyaml.YAMLToJSON(ret) - if err != nil { - return "", wrap(err) - } - } - - return string(ret), nil -} - -func (cli *cliSetup) dataSources(fromFile string, toDir string) error { - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading setup file: %w", err) - } - - output, err := setup.DataSources(input, toDir) - if err != nil { - return err - } - - if toDir == "" { - fmt.Println(output) - } - - return nil -} - -func (cli *cliSetup) install(ctx context.Context, dryRun bool, fromFile string) error { - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading file %s: %w", fromFile, err) - } - - cfg := cli.cfg() - - hub, err := require.Hub(cfg, require.RemoteHub(ctx, cfg), log.StandardLogger()) - if err != nil { - return err - } - - return setup.InstallHubItems(ctx, hub, input, dryRun) -} - -func (cli *cliSetup) validate(fromFile string) error { - input, err := os.ReadFile(fromFile) - if err != nil { - return fmt.Errorf("while reading stdin: %w", err) - } - - if err = setup.Validate(input); err != nil { - fmt.Printf("%v\n", err) - return errors.New("invalid setup file") +func (cli *cliRoot) addSetup(cmd *cobra.Command) { + if fflag.CscliSetup.IsEnabled() { + cmd.AddCommand(clisetup.New(cli.cfg).NewCommand()) } - return nil + component.Register("cscli_setup") } diff --git a/cmd/crowdsec-cli/setup_stub.go b/cmd/crowdsec-cli/setup_stub.go new file mode 100644 index 00000000000..e001f93c797 --- /dev/null +++ b/cmd/crowdsec-cli/setup_stub.go @@ -0,0 +1,9 @@ +//go:build no_cscli_setup +package main + +import ( + "github.com/spf13/cobra" +) + +func (cli *cliRoot) addSetup(_ *cobra.Command) { +} diff --git a/cmd/crowdsec-cli/utils.go b/cmd/crowdsec-cli/utils.go deleted file mode 100644 index f6c32094958..00000000000 --- a/cmd/crowdsec-cli/utils.go +++ /dev/null @@ -1,63 +0,0 @@ -package main - -import ( - "fmt" - "net" - "strings" - - "github.com/crowdsecurity/crowdsec/pkg/types" -) - -func manageCliDecisionAlerts(ip *string, ipRange *string, scope *string, value *string) error { - /*if a range is provided, change the scope*/ - if *ipRange != "" { - _, _, err := net.ParseCIDR(*ipRange) - if err != nil { - return fmt.Errorf("%s isn't a valid range", *ipRange) - } - } - - if *ip != "" { - ipRepr := net.ParseIP(*ip) - if ipRepr == nil { - return fmt.Errorf("%s isn't a valid ip", *ip) - } - } - - // avoid confusion on scope (ip vs Ip and range vs Range) - switch strings.ToLower(*scope) { - case "ip": - *scope = types.Ip - case "range": - *scope = types.Range - case "country": - *scope = types.Country - case "as": - *scope = types.AS - } - - return nil -} - -func removeFromSlice(val string, slice []string) []string { - var i int - var value string - - valueFound := false - - // get the index - for i, value = range slice { - if value == val { - valueFound = true - break - } - } - - if valueFound { - slice[i] = slice[len(slice)-1] - slice[len(slice)-1] = "" - slice = slice[:len(slice)-1] - } - - return slice -} diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index c57b8d87cff..ccb0acf0209 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "runtime" @@ -14,12 +15,12 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) -func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { +func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.APIServer, error) { if cConfig.API.Server.OnlineClient == nil || cConfig.API.Server.OnlineClient.Credentials == nil { log.Info("push and pull to Central API disabled") } - apiServer, err := apiserver.NewServer(cConfig.API.Server) + apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server) if err != nil { return nil, fmt.Errorf("unable to run local API: %w", err) } @@ -39,7 +40,7 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { return nil, errors.New("plugins are enabled, but config_paths.plugin_dir is not defined") } - err = pluginBroker.Init(cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths) + err = pluginBroker.Init(ctx, cConfig.PluginConfig, cConfig.API.Server.Profiles, cConfig.ConfigPaths) if err != nil { return nil, fmt.Errorf("unable to run plugin broker: %w", err) } @@ -58,11 +59,14 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { func serveAPIServer(apiServer *apiserver.APIServer) { apiReady := make(chan bool, 1) + apiTomb.Go(func() error { defer trace.CatchPanic("crowdsec/serveAPIServer") + go func() { defer trace.CatchPanic("crowdsec/runAPIServer") log.Debugf("serving API after %s ms", time.Since(crowdsecT0)) + if err := apiServer.Run(apiReady); err != nil { log.Fatal(err) } @@ -76,6 +80,7 @@ func serveAPIServer(apiServer *apiserver.APIServer) { <-apiTomb.Dying() // lock until go routine is dying pluginTomb.Kill(nil) log.Infof("serve: shutting down api server") + return apiServer.Shutdown() }) <-apiReady @@ -87,5 +92,6 @@ func hasPlugins(profiles []*csconfig.ProfileCfg) bool { return true } } + return false } diff --git a/cmd/crowdsec/appsec.go b/cmd/crowdsec/appsec.go new file mode 100644 index 00000000000..cb02b137dcd --- /dev/null +++ b/cmd/crowdsec/appsec.go @@ -0,0 +1,18 @@ +// +build !no_datasource_appsec + +package main + +import ( + "fmt" + + "github.com/crowdsecurity/crowdsec/pkg/appsec" + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func LoadAppsecRules(hub *cwhub.Hub) error { + if err := appsec.LoadAppsecRules(hub); err != nil { + return fmt.Errorf("while loading appsec rules: %w", err) + } + + return nil +} diff --git a/cmd/crowdsec/appsec_stub.go b/cmd/crowdsec/appsec_stub.go new file mode 100644 index 00000000000..4a65b32a9ad --- /dev/null +++ b/cmd/crowdsec/appsec_stub.go @@ -0,0 +1,11 @@ +//go:build no_datasource_appsec + +package main + +import ( + "github.com/crowdsecurity/crowdsec/pkg/cwhub" +) + +func LoadAppsecRules(hub *cwhub.Hub) error { + return nil +} diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index 5aafc6b0dfe..db93992605d 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -14,7 +14,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/acquisition" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/alertcontext" - "github.com/crowdsecurity/crowdsec/pkg/appsec" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" @@ -43,12 +42,13 @@ func initCrowdsec(cConfig *csconfig.Config, hub *cwhub.Hub) (*parser.Parsers, [] return nil, nil, fmt.Errorf("while loading parsers: %w", err) } - if err := LoadBuckets(cConfig, hub); err != nil { + if err = LoadBuckets(cConfig, hub); err != nil { return nil, nil, fmt.Errorf("while loading scenarios: %w", err) } - if err := appsec.LoadAppsecRules(hub); err != nil { - return nil, nil, fmt.Errorf("while loading appsec rules: %w", err) + // can be nerfed by a build flag + if err = LoadAppsecRules(hub); err != nil { + return nil, nil, err } datasources, err := LoadAcquisition(cConfig) @@ -82,6 +82,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H return nil }) } + parserWg.Done() return nil @@ -108,13 +109,14 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H return runPour(inputEventChan, holders, buckets, cConfig) }) } + bucketWg.Done() return nil }) bucketWg.Wait() - apiClient, err := AuthenticatedLAPIClient(*cConfig.API.Client.Credentials, hub) + apiClient, err := AuthenticatedLAPIClient(context.TODO(), *cConfig.API.Client.Credentials, hub) if err != nil { return err } @@ -134,6 +136,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H return runOutput(inputEventChan, outputEventChan, buckets, *parsers.Povfwctx, parsers.Povfwnodes, apiClient) }) } + outputWg.Done() return nil @@ -166,7 +169,7 @@ func runCrowdsec(cConfig *csconfig.Config, parsers *parser.Parsers, hub *cwhub.H log.Info("Starting processing data") - if err := acquisition.StartAcquisition(dataSources, inputLineChan, &acquisTomb); err != nil { + if err := acquisition.StartAcquisition(context.TODO(), dataSources, inputLineChan, &acquisTomb); err != nil { return fmt.Errorf("starting acquisition error: %w", err) } diff --git a/cmd/crowdsec/lapiclient.go b/cmd/crowdsec/lapiclient.go index 6cc0fba9515..6656ba6b4c2 100644 --- a/cmd/crowdsec/lapiclient.go +++ b/cmd/crowdsec/lapiclient.go @@ -11,25 +11,10 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { - scenarios, err := hub.GetInstalledNamesByType(cwhub.SCENARIOS) - if err != nil { - return nil, fmt.Errorf("loading list of installed hub scenarios: %w", err) - } - - appsecRules, err := hub.GetInstalledNamesByType(cwhub.APPSEC_RULES) - if err != nil { - return nil, fmt.Errorf("loading list of installed hub appsec rules: %w", err) - } - - installedScenariosAndAppsecRules := make([]string, 0, len(scenarios)+len(appsecRules)) - installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, scenarios...) - installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, appsecRules...) - +func AuthenticatedLAPIClient(ctx context.Context, credentials csconfig.ApiCredentialsCfg, hub *cwhub.Hub) (*apiclient.ApiClient, error) { apiURL, err := url.Parse(credentials.URL) if err != nil { return nil, fmt.Errorf("parsing api url ('%s'): %w", credentials.URL, err) @@ -42,38 +27,27 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub. password := strfmt.Password(credentials.Password) + itemsForAPI := hub.GetInstalledListForAPI() + client, err := apiclient.NewClient(&apiclient.Config{ MachineID: credentials.Login, Password: password, - Scenarios: installedScenariosAndAppsecRules, - UserAgent: cwversion.UserAgent(), + Scenarios: itemsForAPI, URL: apiURL, PapiURL: papiURL, VersionPrefix: "v1", - UpdateScenario: func() ([]string, error) { - scenarios, err := hub.GetInstalledNamesByType(cwhub.SCENARIOS) - if err != nil { - return nil, err - } - appsecRules, err := hub.GetInstalledNamesByType(cwhub.APPSEC_RULES) - if err != nil { - return nil, err - } - ret := make([]string, 0, len(scenarios)+len(appsecRules)) - ret = append(ret, scenarios...) - ret = append(ret, appsecRules...) - - return ret, nil + UpdateScenario: func(_ context.Context) ([]string, error) { + return itemsForAPI, nil }, }) if err != nil { return nil, fmt.Errorf("new client api: %w", err) } - authResp, _, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := client.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &credentials.Login, Password: &password, - Scenarios: installedScenariosAndAppsecRules, + Scenarios: itemsForAPI, }) if err != nil { return nil, fmt.Errorf("authenticate watcher (%s): %w", credentials.Login, err) diff --git a/cmd/crowdsec/lpmetrics.go b/cmd/crowdsec/lpmetrics.go index 0fd27054071..24842851294 100644 --- a/cmd/crowdsec/lpmetrics.go +++ b/cmd/crowdsec/lpmetrics.go @@ -7,7 +7,6 @@ import ( "time" "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -46,10 +45,8 @@ func getHubState(hub *cwhub.Hub) models.HubItems { for _, itemType := range cwhub.ItemTypes { ret[itemType] = []models.HubItem{} - items, _ := hub.GetInstalledItemsByType(itemType) - cwhub.SortItemSlice(items) - for _, item := range items { + for _, item := range hub.GetInstalledByType(itemType, true) { status := "official" if item.State.IsLocal() { status = "custom" @@ -90,7 +87,8 @@ func newStaticMetrics(consoleOptions []string, datasources []acquisition.DataSou } func NewMetricsProvider(apic *apiclient.ApiClient, interval time.Duration, logger *logrus.Entry, - consoleOptions []string, datasources []acquisition.DataSource, hub *cwhub.Hub) *MetricsProvider { + consoleOptions []string, datasources []acquisition.DataSource, hub *cwhub.Hub, +) *MetricsProvider { return &MetricsProvider{ apic: apic, interval: interval, diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go index 18416e044e7..6d8ca24c335 100644 --- a/cmd/crowdsec/main.go +++ b/cmd/crowdsec/main.go @@ -91,10 +91,8 @@ func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error { files []string ) - for _, hubScenarioItem := range hub.GetItemMap(cwhub.SCENARIOS) { - if hubScenarioItem.State.Installed { - files = append(files, hubScenarioItem.State.LocalPath) - } + for _, hubScenarioItem := range hub.GetInstalledByType(cwhub.SCENARIOS, false) { + files = append(files, hubScenarioItem.State.LocalPath) } buckets = leakybucket.NewBuckets() diff --git a/cmd/crowdsec/metrics.go b/cmd/crowdsec/metrics.go index d3c6e172091..ff280fc3512 100644 --- a/cmd/crowdsec/metrics.go +++ b/cmd/crowdsec/metrics.go @@ -118,7 +118,9 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha return } - decisions, err := dbClient.QueryDecisionCountByScenario() + ctx := r.Context() + + decisions, err := dbClient.QueryDecisionCountByScenario(ctx) if err != nil { log.Errorf("Error querying decisions for metrics: %v", err) next.ServeHTTP(w, r) @@ -138,7 +140,7 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha "include_capi": {"false"}, } - alerts, err := dbClient.AlertsCountPerScenario(alertsFilter) + alerts, err := dbClient.AlertsCountPerScenario(ctx, alertsFilter) if err != nil { log.Errorf("Error querying alerts for metrics: %v", err) next.ServeHTTP(w, r) diff --git a/cmd/crowdsec/pour.go b/cmd/crowdsec/pour.go index 388c7a6c1b3..2fc7d7e42c9 100644 --- a/cmd/crowdsec/pour.go +++ b/cmd/crowdsec/pour.go @@ -32,7 +32,7 @@ func runPour(input chan types.Event, holders []leaky.BucketFactory, buckets *lea if parsed.MarshaledTime != "" { z := &time.Time{} if err := z.UnmarshalText([]byte(parsed.MarshaledTime)); err != nil { - log.Warningf("Failed to unmarshal time from event '%s' : %s", parsed.MarshaledTime, err) + log.Warningf("Failed to parse time from event '%s' : %s", parsed.MarshaledTime, err) } else { log.Warning("Starting buckets garbage collection ...") @@ -59,9 +59,9 @@ func runPour(input chan types.Event, holders []leaky.BucketFactory, buckets *lea globalBucketPourKo.Inc() } - if len(parsed.MarshaledTime) != 0 { + if parsed.MarshaledTime != "" { if err := lastProcessedItem.UnmarshalText([]byte(parsed.MarshaledTime)); err != nil { - log.Warningf("failed to unmarshal time from event : %s", err) + log.Warningf("failed to parse time from event : %s", err) } } } diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index f1a658e9512..14602c425fe 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -52,6 +52,8 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error { func reloadHandler(sig os.Signal) (*csconfig.Config, error) { var tmpFile string + ctx := context.TODO() + // re-initialize tombs acquisTomb = tomb.Tomb{} parsersTomb = tomb.Tomb{} @@ -74,7 +76,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return nil, fmt.Errorf("unable to init api server: %w", err) } @@ -88,7 +90,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) { return nil, err } - if err := hub.Load(); err != nil { + if err = hub.Load(); err != nil { return nil, err } @@ -374,7 +376,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error { cConfig.API.Server.OnlineClient = nil } - apiServer, err := initAPIServer(cConfig) + apiServer, err := initAPIServer(ctx, cConfig) if err != nil { return fmt.Errorf("api server init: %w", err) } @@ -390,7 +392,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error { return err } - if err := hub.Load(); err != nil { + if err = hub.Load(); err != nil { return err } diff --git a/cmd/notification-dummy/main.go b/cmd/notification-dummy/main.go index 024a1eb81ba..7fbb10d4fca 100644 --- a/cmd/notification-dummy/main.go +++ b/cmd/notification-dummy/main.go @@ -9,6 +9,7 @@ import ( plugin "github.com/hashicorp/go-plugin" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) @@ -19,6 +20,7 @@ type PluginConfig struct { } type DummyPlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -84,7 +86,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "dummy": &protobufs.NotifierPlugin{ + "dummy": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/cmd/notification-email/main.go b/cmd/notification-email/main.go index 3b535ae7ffa..5fc02cdd1d7 100644 --- a/cmd/notification-email/main.go +++ b/cmd/notification-email/main.go @@ -12,6 +12,7 @@ import ( mail "github.com/xhit/go-simple-mail/v2" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) @@ -55,6 +56,7 @@ type PluginConfig struct { } type EmailPlugin struct { + protobufs.UnimplementedNotifierServer ConfigByName map[string]PluginConfig } @@ -81,7 +83,7 @@ func (n *EmailPlugin) Configure(ctx context.Context, config *protobufs.Config) ( return nil, errors.New("SMTP host is not set") } - if d.ReceiverEmails == nil || len(d.ReceiverEmails) == 0 { + if len(d.ReceiverEmails) == 0 { return nil, errors.New("receiver emails are not set") } @@ -170,7 +172,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "email": &protobufs.NotifierPlugin{ + "email": &csplugin.NotifierPlugin{ Impl: &EmailPlugin{ConfigByName: make(map[string]PluginConfig)}, }, }, diff --git a/cmd/notification-file/main.go b/cmd/notification-file/main.go index 7fc529cff41..a4dbb8ee5db 100644 --- a/cmd/notification-file/main.go +++ b/cmd/notification-file/main.go @@ -15,6 +15,7 @@ import ( plugin "github.com/hashicorp/go-plugin" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) @@ -52,6 +53,7 @@ type LogRotate struct { } type FilePlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -210,7 +212,7 @@ func (s *FilePlugin) Configure(ctx context.Context, config *protobufs.Config) (* d := PluginConfig{} err := yaml.Unmarshal(config.Config, &d) if err != nil { - logger.Error("Failed to unmarshal config", "error", err) + logger.Error("Failed to parse config", "error", err) return &protobufs.Empty{}, err } FileWriteMutex = &sync.Mutex{} @@ -241,7 +243,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "file": &protobufs.NotifierPlugin{ + "file": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/cmd/notification-http/main.go b/cmd/notification-http/main.go index 6b11a78ef86..3f84984315b 100644 --- a/cmd/notification-http/main.go +++ b/cmd/notification-http/main.go @@ -16,6 +16,7 @@ import ( plugin "github.com/hashicorp/go-plugin" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) @@ -34,6 +35,7 @@ type PluginConfig struct { } type HTTPPlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -190,7 +192,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "http": &protobufs.NotifierPlugin{ + "http": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/cmd/notification-sentinel/main.go b/cmd/notification-sentinel/main.go index a29e941f80c..0293d45b0a4 100644 --- a/cmd/notification-sentinel/main.go +++ b/cmd/notification-sentinel/main.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/go-plugin" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) @@ -27,6 +28,7 @@ type PluginConfig struct { } type SentinelPlugin struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig } @@ -122,7 +124,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "sentinel": &protobufs.NotifierPlugin{ + "sentinel": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/cmd/notification-slack/main.go b/cmd/notification-slack/main.go index fba1b33e334..34c7c0df361 100644 --- a/cmd/notification-slack/main.go +++ b/cmd/notification-slack/main.go @@ -10,6 +10,7 @@ import ( "github.com/slack-go/slack" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) @@ -23,6 +24,7 @@ type PluginConfig struct { LogLevel *string `yaml:"log_level"` } type Notify struct { + protobufs.UnimplementedNotifierServer ConfigByName map[string]PluginConfig } @@ -84,7 +86,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "slack": &protobufs.NotifierPlugin{ + "slack": &csplugin.NotifierPlugin{ Impl: &Notify{ConfigByName: make(map[string]PluginConfig)}, }, }, diff --git a/cmd/notification-splunk/main.go b/cmd/notification-splunk/main.go index 26190c58a89..e18f416c14a 100644 --- a/cmd/notification-splunk/main.go +++ b/cmd/notification-splunk/main.go @@ -14,6 +14,7 @@ import ( plugin "github.com/hashicorp/go-plugin" "gopkg.in/yaml.v3" + "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) @@ -32,6 +33,7 @@ type PluginConfig struct { } type Splunk struct { + protobufs.UnimplementedNotifierServer PluginConfigByName map[string]PluginConfig Client http.Client } @@ -117,7 +119,7 @@ func main() { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshake, Plugins: map[string]plugin.Plugin{ - "splunk": &protobufs.NotifierPlugin{ + "splunk": &csplugin.NotifierPlugin{ Impl: sp, }, }, diff --git a/debian/rules b/debian/rules index c11771282ea..5b8d6fc51f8 100755 --- a/debian/rules +++ b/debian/rules @@ -13,7 +13,7 @@ override_dh_auto_build: override_dh_auto_install: # just use the prebuilt binaries, otherwise: - # make build BUILD_RE_WASM=0 BUILD_STATIC=1 + # make build BUILD_STATIC=1 mkdir -p debian/crowdsec/usr/bin mkdir -p debian/crowdsec/etc/crowdsec diff --git a/docker/test/Pipfile.lock b/docker/test/Pipfile.lock index 2cb587b6b88..99184d9f2a2 100644 --- a/docker/test/Pipfile.lock +++ b/docker/test/Pipfile.lock @@ -18,69 +18,84 @@ "default": { "certifi": { "hashes": [ - "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b", - "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90" + "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", + "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9" ], "markers": "python_version >= '3.6'", - "version": "==2024.7.4" + "version": "==2024.8.30" }, "cffi": { "hashes": [ - "sha256:0c9ef6ff37e974b73c25eecc13952c55bceed9112be2d9d938ded8e856138bcc", - "sha256:131fd094d1065b19540c3d72594260f118b231090295d8c34e19a7bbcf2e860a", - "sha256:1b8ebc27c014c59692bb2664c7d13ce7a6e9a629be20e54e7271fa696ff2b417", - "sha256:2c56b361916f390cd758a57f2e16233eb4f64bcbeee88a4881ea90fca14dc6ab", - "sha256:2d92b25dbf6cae33f65005baf472d2c245c050b1ce709cc4588cdcdd5495b520", - "sha256:31d13b0f99e0836b7ff893d37af07366ebc90b678b6664c955b54561fc36ef36", - "sha256:32c68ef735dbe5857c810328cb2481e24722a59a2003018885514d4c09af9743", - "sha256:3686dffb02459559c74dd3d81748269ffb0eb027c39a6fc99502de37d501faa8", - "sha256:582215a0e9adbe0e379761260553ba11c58943e4bbe9c36430c4ca6ac74b15ed", - "sha256:5b50bf3f55561dac5438f8e70bfcdfd74543fd60df5fa5f62d94e5867deca684", - "sha256:5bf44d66cdf9e893637896c7faa22298baebcd18d1ddb6d2626a6e39793a1d56", - "sha256:6602bc8dc6f3a9e02b6c22c4fc1e47aa50f8f8e6d3f78a5e16ac33ef5fefa324", - "sha256:673739cb539f8cdaa07d92d02efa93c9ccf87e345b9a0b556e3ecc666718468d", - "sha256:68678abf380b42ce21a5f2abde8efee05c114c2fdb2e9eef2efdb0257fba1235", - "sha256:68e7c44931cc171c54ccb702482e9fc723192e88d25a0e133edd7aff8fcd1f6e", - "sha256:6b3d6606d369fc1da4fd8c357d026317fbb9c9b75d36dc16e90e84c26854b088", - "sha256:748dcd1e3d3d7cd5443ef03ce8685043294ad6bd7c02a38d1bd367cfd968e000", - "sha256:7651c50c8c5ef7bdb41108b7b8c5a83013bfaa8a935590c5d74627c047a583c7", - "sha256:7b78010e7b97fef4bee1e896df8a4bbb6712b7f05b7ef630f9d1da00f6444d2e", - "sha256:7e61e3e4fa664a8588aa25c883eab612a188c725755afff6289454d6362b9673", - "sha256:80876338e19c951fdfed6198e70bc88f1c9758b94578d5a7c4c91a87af3cf31c", - "sha256:8895613bcc094d4a1b2dbe179d88d7fb4a15cee43c052e8885783fac397d91fe", - "sha256:88e2b3c14bdb32e440be531ade29d3c50a1a59cd4e51b1dd8b0865c54ea5d2e2", - "sha256:8f8e709127c6c77446a8c0a8c8bf3c8ee706a06cd44b1e827c3e6a2ee6b8c098", - "sha256:9cb4a35b3642fc5c005a6755a5d17c6c8b6bcb6981baf81cea8bfbc8903e8ba8", - "sha256:9f90389693731ff1f659e55c7d1640e2ec43ff725cc61b04b2f9c6d8d017df6a", - "sha256:a09582f178759ee8128d9270cd1344154fd473bb77d94ce0aeb2a93ebf0feaf0", - "sha256:a6a14b17d7e17fa0d207ac08642c8820f84f25ce17a442fd15e27ea18d67c59b", - "sha256:a72e8961a86d19bdb45851d8f1f08b041ea37d2bd8d4fd19903bc3083d80c896", - "sha256:abd808f9c129ba2beda4cfc53bde801e5bcf9d6e0f22f095e45327c038bfe68e", - "sha256:ac0f5edd2360eea2f1daa9e26a41db02dd4b0451b48f7c318e217ee092a213e9", - "sha256:b29ebffcf550f9da55bec9e02ad430c992a87e5f512cd63388abb76f1036d8d2", - "sha256:b2ca4e77f9f47c55c194982e10f058db063937845bb2b7a86c84a6cfe0aefa8b", - "sha256:b7be2d771cdba2942e13215c4e340bfd76398e9227ad10402a8767ab1865d2e6", - "sha256:b84834d0cf97e7d27dd5b7f3aca7b6e9263c56308ab9dc8aae9784abb774d404", - "sha256:b86851a328eedc692acf81fb05444bdf1891747c25af7529e39ddafaf68a4f3f", - "sha256:bcb3ef43e58665bbda2fb198698fcae6776483e0c4a631aa5647806c25e02cc0", - "sha256:c0f31130ebc2d37cdd8e44605fb5fa7ad59049298b3f745c74fa74c62fbfcfc4", - "sha256:c6a164aa47843fb1b01e941d385aab7215563bb8816d80ff3a363a9f8448a8dc", - "sha256:d8a9d3ebe49f084ad71f9269834ceccbf398253c9fac910c4fd7053ff1386936", - "sha256:db8e577c19c0fda0beb7e0d4e09e0ba74b1e4c092e0e40bfa12fe05b6f6d75ba", - "sha256:dc9b18bf40cc75f66f40a7379f6a9513244fe33c0e8aa72e2d56b0196a7ef872", - "sha256:e09f3ff613345df5e8c3667da1d918f9149bd623cd9070c983c013792a9a62eb", - "sha256:e4108df7fe9b707191e55f33efbcb2d81928e10cea45527879a4749cbe472614", - "sha256:e6024675e67af929088fda399b2094574609396b1decb609c55fa58b028a32a1", - "sha256:e70f54f1796669ef691ca07d046cd81a29cb4deb1e5f942003f401c0c4a2695d", - "sha256:e715596e683d2ce000574bae5d07bd522c781a822866c20495e52520564f0969", - "sha256:e760191dd42581e023a68b758769e2da259b5d52e3103c6060ddc02c9edb8d7b", - "sha256:ed86a35631f7bfbb28e108dd96773b9d5a6ce4811cf6ea468bb6a359b256b1e4", - "sha256:ee07e47c12890ef248766a6e55bd38ebfb2bb8edd4142d56db91b21ea68b7627", - "sha256:fa3a0128b152627161ce47201262d3140edb5a5c3da88d73a1b790a959126956", - "sha256:fcc8eb6d5902bb1cf6dc4f187ee3ea80a1eba0a89aba40a5cb20a5087d961357" + "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8", + "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2", + "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1", + "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15", + "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36", + "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", + "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8", + "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36", + "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17", + "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf", + "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc", + "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3", + "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed", + "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702", + "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", + "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8", + "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903", + "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6", + "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d", + "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b", + "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e", + "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be", + "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c", + "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683", + "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9", + "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c", + "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8", + "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1", + "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4", + "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655", + "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67", + "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595", + "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0", + "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65", + "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", + "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6", + "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401", + "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", + "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3", + "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16", + "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93", + "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e", + "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4", + "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964", + "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c", + "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576", + "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0", + "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3", + "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662", + "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3", + "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff", + "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5", + "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd", + "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", + "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5", + "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14", + "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d", + "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9", + "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7", + "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382", + "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a", + "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e", + "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", + "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", + "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99", + "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87", + "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b" ], "markers": "platform_python_implementation != 'PyPy'", - "version": "==1.16.0" + "version": "==1.17.1" }, "charset-normalizer": { "hashes": [ @@ -180,36 +195,36 @@ }, "cryptography": { "hashes": [ - "sha256:0663585d02f76929792470451a5ba64424acc3cd5227b03921dab0e2f27b1709", - "sha256:08a24a7070b2b6804c1940ff0f910ff728932a9d0e80e7814234269f9d46d069", - "sha256:232ce02943a579095a339ac4b390fbbe97f5b5d5d107f8a08260ea2768be8cc2", - "sha256:2905ccf93a8a2a416f3ec01b1a7911c3fe4073ef35640e7ee5296754e30b762b", - "sha256:299d3da8e00b7e2b54bb02ef58d73cd5f55fb31f33ebbf33bd00d9aa6807df7e", - "sha256:2c6d112bf61c5ef44042c253e4859b3cbbb50df2f78fa8fae6747a7814484a70", - "sha256:31e44a986ceccec3d0498e16f3d27b2ee5fdf69ce2ab89b52eaad1d2f33d8778", - "sha256:3d9a1eca329405219b605fac09ecfc09ac09e595d6def650a437523fcd08dd22", - "sha256:3dcdedae5c7710b9f97ac6bba7e1052b95c7083c9d0e9df96e02a1932e777895", - "sha256:47ca71115e545954e6c1d207dd13461ab81f4eccfcb1345eac874828b5e3eaaf", - "sha256:4a997df8c1c2aae1e1e5ac49c2e4f610ad037fc5a3aadc7b64e39dea42249431", - "sha256:51956cf8730665e2bdf8ddb8da0056f699c1a5715648c1b0144670c1ba00b48f", - "sha256:5bcb8a5620008a8034d39bce21dc3e23735dfdb6a33a06974739bfa04f853947", - "sha256:64c3f16e2a4fc51c0d06af28441881f98c5d91009b8caaff40cf3548089e9c74", - "sha256:6e2b11c55d260d03a8cf29ac9b5e0608d35f08077d8c087be96287f43af3ccdc", - "sha256:7b3f5fe74a5ca32d4d0f302ffe6680fcc5c28f8ef0dc0ae8f40c0f3a1b4fca66", - "sha256:844b6d608374e7d08f4f6e6f9f7b951f9256db41421917dfb2d003dde4cd6b66", - "sha256:9a8d6802e0825767476f62aafed40532bd435e8a5f7d23bd8b4f5fd04cc80ecf", - "sha256:aae4d918f6b180a8ab8bf6511a419473d107df4dbb4225c7b48c5c9602c38c7f", - "sha256:ac1955ce000cb29ab40def14fd1bbfa7af2017cca696ee696925615cafd0dce5", - "sha256:b88075ada2d51aa9f18283532c9f60e72170041bba88d7f37e49cbb10275299e", - "sha256:cb013933d4c127349b3948aa8aaf2f12c0353ad0eccd715ca789c8a0f671646f", - "sha256:cc70b4b581f28d0a254d006f26949245e3657d40d8857066c2ae22a61222ef55", - "sha256:e9c5266c432a1e23738d178e51c2c7a5e2ddf790f248be939448c0ba2021f9d1", - "sha256:ea9e57f8ea880eeea38ab5abf9fbe39f923544d7884228ec67d666abd60f5a47", - "sha256:ee0c405832ade84d4de74b9029bedb7b31200600fa524d218fc29bfa371e97f5", - "sha256:fdcb265de28585de5b859ae13e3846a8e805268a823a12a4da2597f1f5afc9f0" + "sha256:014f58110f53237ace6a408b5beb6c427b64e084eb451ef25a28308270086494", + "sha256:1bbcce1a551e262dfbafb6e6252f1ae36a248e615ca44ba302df077a846a8806", + "sha256:203e92a75716d8cfb491dc47c79e17d0d9207ccffcbcb35f598fbe463ae3444d", + "sha256:27e613d7077ac613e399270253259d9d53872aaf657471473ebfc9a52935c062", + "sha256:2bd51274dcd59f09dd952afb696bf9c61a7a49dfc764c04dd33ef7a6b502a1e2", + "sha256:38926c50cff6f533f8a2dae3d7f19541432610d114a70808f0926d5aaa7121e4", + "sha256:511f4273808ab590912a93ddb4e3914dfd8a388fed883361b02dea3791f292e1", + "sha256:58d4e9129985185a06d849aa6df265bdd5a74ca6e1b736a77959b498e0505b85", + "sha256:5b43d1ea6b378b54a1dc99dd8a2b5be47658fe9a7ce0a58ff0b55f4b43ef2b84", + "sha256:61ec41068b7b74268fa86e3e9e12b9f0c21fcf65434571dbb13d954bceb08042", + "sha256:666ae11966643886c2987b3b721899d250855718d6d9ce41b521252a17985f4d", + "sha256:68aaecc4178e90719e95298515979814bda0cbada1256a4485414860bd7ab962", + "sha256:7c05650fe8023c5ed0d46793d4b7d7e6cd9c04e68eabe5b0aeea836e37bdcec2", + "sha256:80eda8b3e173f0f247f711eef62be51b599b5d425c429b5d4ca6a05e9e856baa", + "sha256:8385d98f6a3bf8bb2d65a73e17ed87a3ba84f6991c155691c51112075f9ffc5d", + "sha256:88cce104c36870d70c49c7c8fd22885875d950d9ee6ab54df2745f83ba0dc365", + "sha256:9d3cdb25fa98afdd3d0892d132b8d7139e2c087da1712041f6b762e4f807cc96", + "sha256:a575913fb06e05e6b4b814d7f7468c2c660e8bb16d8d5a1faf9b33ccc569dd47", + "sha256:ac119bb76b9faa00f48128b7f5679e1d8d437365c5d26f1c2c3f0da4ce1b553d", + "sha256:c1332724be35d23a854994ff0b66530119500b6053d0bd3363265f7e5e77288d", + "sha256:d03a475165f3134f773d1388aeb19c2d25ba88b6a9733c5c590b9ff7bbfa2e0c", + "sha256:d75601ad10b059ec832e78823b348bfa1a59f6b8d545db3a24fd44362a1564cb", + "sha256:de41fd81a41e53267cb020bb3a7212861da53a7d39f863585d13ea11049cf277", + "sha256:e710bf40870f4db63c3d7d929aa9e09e4e7ee219e703f949ec4073b4294f6172", + "sha256:ea25acb556320250756e53f9e20a4177515f012c9eaea17eb7587a8c4d8ae034", + "sha256:f98bf604c82c416bc829e490c700ca1553eafdf2912a91e23a79d97d9801372a", + "sha256:fba1007b3ef89946dbbb515aeeb41e30203b004f0b4b00e5e16078b518563289" ], "markers": "python_version >= '3.7'", - "version": "==43.0.0" + "version": "==43.0.1" }, "docker": { "hashes": [ @@ -229,11 +244,11 @@ }, "idna": { "hashes": [ - "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc", - "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0" + "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", + "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3" ], - "markers": "python_version >= '3.5'", - "version": "==3.7" + "markers": "python_version >= '3.6'", + "version": "==3.10" }, "iniconfig": { "hashes": [ @@ -292,11 +307,11 @@ }, "pytest": { "hashes": [ - "sha256:7e8e5c5abd6e93cb1cc151f23e57adc31fcf8cfd2a3ff2da63e23f732de35db6", - "sha256:e9600ccf4f563976e2c99fa02c7624ab938296551f280835ee6516df8bc4ae8c" + "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181", + "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2" ], "markers": "python_version >= '3.8'", - "version": "==8.3.1" + "version": "==8.3.3" }, "pytest-cs": { "git": "https://github.com/crowdsecurity/pytest-cs.git", @@ -337,60 +352,62 @@ }, "pyyaml": { "hashes": [ - "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5", - "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc", - "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df", - "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741", - "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206", - "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27", - "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595", - "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62", - "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98", - "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696", - "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290", - "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9", - "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d", - "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6", - "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867", - "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47", - "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486", - "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6", - "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3", - "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007", - "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938", - "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0", - "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c", - "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735", - "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d", - "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28", - "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4", - "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba", - "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8", - "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef", - "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5", - "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd", - "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3", - "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0", - "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515", - "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c", - "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c", - "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924", - "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34", - "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43", - "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859", - "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673", - "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54", - "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a", - "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b", - "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab", - "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa", - "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c", - "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585", - "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d", - "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f" + "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff", + "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", + "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086", + "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e", + "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", + "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5", + "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", + "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee", + "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", + "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", + "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a", + "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf", + "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99", + "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8", + "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", + "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19", + "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", + "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a", + "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", + "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", + "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", + "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631", + "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d", + "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", + "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", + "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", + "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", + "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", + "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", + "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706", + "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", + "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", + "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", + "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083", + "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", + "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", + "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e", + "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f", + "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725", + "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", + "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab", + "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774", + "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725", + "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", + "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5", + "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d", + "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290", + "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44", + "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", + "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", + "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", + "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12", + "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4" ], - "markers": "python_version >= '3.6'", - "version": "==6.0.1" + "markers": "python_version >= '3.8'", + "version": "==6.0.2" }, "requests": { "hashes": [ @@ -410,11 +427,11 @@ }, "urllib3": { "hashes": [ - "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472", - "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168" + "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", + "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9" ], "markers": "python_version >= '3.8'", - "version": "==2.2.2" + "version": "==2.2.3" } }, "develop": { @@ -435,11 +452,11 @@ }, "executing": { "hashes": [ - "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147", - "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc" + "sha256:8d63781349375b5ebccc3142f4b30350c0cd9c79f921cde38be2be4637e98eaf", + "sha256:8ea27ddd260da8150fa5a708269c4a10e76161e2496ec3e587da9e3c0fe4b9ab" ], - "markers": "python_version >= '3.5'", - "version": "==2.0.1" + "markers": "python_version >= '3.8'", + "version": "==2.1.0" }, "gnureadline": { "hashes": [ @@ -485,11 +502,11 @@ }, "ipython": { "hashes": [ - "sha256:1cec0fbba8404af13facebe83d04436a7434c7400e59f47acf467c64abd0956c", - "sha256:e6b347c27bdf9c32ee9d31ae85defc525755a1869f14057e900675b9e8d6e6ff" + "sha256:0d0d15ca1e01faeb868ef56bc7ee5a0de5bd66885735682e8a322ae289a13d1a", + "sha256:530ef1e7bb693724d3cdc37287c80b07ad9b25986c007a53aa1857272dac3f35" ], "markers": "python_version >= '3.11'", - "version": "==8.26.0" + "version": "==8.28.0" }, "jedi": { "hashes": [ @@ -525,11 +542,11 @@ }, "prompt-toolkit": { "hashes": [ - "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10", - "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360" + "sha256:d6623ab0477a80df74e646bdbc93621143f5caf104206aa29294d53de1a03d90", + "sha256:f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e" ], "markers": "python_full_version >= '3.7.0'", - "version": "==3.0.47" + "version": "==3.0.48" }, "ptyprocess": { "hashes": [ diff --git a/go.mod b/go.mod index ec8566db84a..f4bd9379a2d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/crowdsecurity/crowdsec -go 1.22 +go 1.23.3 // Don't use the toolchain directive to avoid uncontrolled downloads during // a build, especially in sandboxed environments (freebsd, gentoo...). @@ -16,12 +16,12 @@ require ( github.com/appleboy/gin-jwt/v2 v2.9.2 github.com/aws/aws-lambda-go v1.47.0 github.com/aws/aws-sdk-go v1.52.0 - github.com/beevik/etree v1.3.0 + github.com/beevik/etree v1.4.1 github.com/blackfireio/osinfo v1.0.5 github.com/bluele/gcache v0.0.2 github.com/buger/jsonparser v1.1.1 github.com/c-robinson/iplib v1.0.8 - github.com/cespare/xxhash/v2 v2.2.0 + github.com/cespare/xxhash/v2 v2.3.0 github.com/corazawaf/libinjection-go v0.1.2 github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 @@ -82,12 +82,12 @@ require ( github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 github.com/wasilibs/go-re2 v1.7.0 github.com/xhit/go-simple-mail/v2 v2.16.0 - golang.org/x/crypto v0.22.0 - golang.org/x/mod v0.15.0 + golang.org/x/crypto v0.26.0 + golang.org/x/mod v0.17.0 golang.org/x/sys v0.24.0 - golang.org/x/text v0.14.0 - google.golang.org/grpc v1.56.3 - google.golang.org/protobuf v1.33.0 + golang.org/x/text v0.17.0 + google.golang.org/grpc v1.67.1 + google.golang.org/protobuf v1.34.2 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 gopkg.in/yaml.v2 v2.4.0 @@ -128,7 +128,7 @@ require ( github.com/go-stack/stack v1.8.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/glog v1.1.0 // indirect + github.com/golang/glog v1.2.2 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect @@ -201,14 +201,14 @@ require ( go.mongodb.org/mongo-driver v1.9.4 // indirect go.uber.org/atomic v1.10.0 // indirect golang.org/x/arch v0.7.0 // indirect - golang.org/x/net v0.24.0 // indirect - golang.org/x/sync v0.6.0 // indirect - golang.org/x/term v0.19.0 // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/term v0.23.0 // indirect golang.org/x/time v0.3.0 // indirect - golang.org/x/tools v0.18.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gotest.tools/v3 v3.5.0 // indirect diff --git a/go.sum b/go.sum index ff73dc56332..b2bd77c9915 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ github.com/aws/aws-sdk-go v1.52.0 h1:ptgek/4B2v/ljsjYSEvLQ8LTD+SQyrqhOOWvHc/VGPI github.com/aws/aws-sdk-go v1.52.0/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/beevik/etree v1.3.0 h1:hQTc+pylzIKDb23yYprodCWWTt+ojFfUZyzU09a/hmU= github.com/beevik/etree v1.3.0/go.mod h1:aiPf89g/1k3AShMVAzriilpcE4R/Vuor90y83zVZWFc= +github.com/beevik/etree v1.4.1 h1:PmQJDDYahBGNKDcpdX8uPy1xRCwoCGVUiW669MEirVI= +github.com/beevik/etree v1.4.1/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -74,8 +76,8 @@ github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZF github.com/bytedance/sonic v1.10.2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4= github.com/c-robinson/iplib v1.0.8 h1:exDRViDyL9UBLcfmlxxkY5odWX5092nPsQIykHXhIn4= github.com/c-robinson/iplib v1.0.8/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0= @@ -294,8 +296,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/golang/glog v1.1.0 h1:/d3pCKDPWNnvIWe0vVUpNP32qc8U3PDVxySP/y360qE= -github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP3NQ= +github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY= +github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -763,8 +765,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= @@ -772,8 +774,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8= -golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -797,8 +799,8 @@ golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -808,8 +810,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -854,8 +856,8 @@ golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q= -golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -868,8 +870,8 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -893,8 +895,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ= -golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -906,14 +908,14 @@ google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCID google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= -google.golang.org/grpc v1.56.3 h1:8I4C0Yq1EjstUzUJzpcRVbuYA2mODtEmpWiQoN/b2nc= -google.golang.org/grpc v1.56.3/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= +google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 634557021f1..ef5a413b91f 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -1,6 +1,7 @@ package acquisition import ( + "context" "errors" "fmt" "io" @@ -18,19 +19,8 @@ import ( "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" - appsecacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec" - cloudwatchacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/cloudwatch" - dockeracquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/docker" - fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" - journalctlacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/journalctl" - kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka" - kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis" - k8sauditacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit" - lokiacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki" - s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3" - syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog" - wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog" "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -50,43 +40,76 @@ func (e *DataSourceUnavailableError) Unwrap() error { // The interface each datasource must implement type DataSource interface { - GetMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module - GetAggregMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module (aggregated mode, limits cardinality) - UnmarshalConfig([]byte) error // Decode and pre-validate the YAML datasource - anything that can be checked before runtime - Configure([]byte, *log.Entry, int) error // Complete the YAML datasource configuration and perform runtime checks. - ConfigureByDSN(string, map[string]string, *log.Entry, string) error // Configure the datasource - GetMode() string // Get the mode (TAIL, CAT or SERVER) - GetName() string // Get the name of the module - OneShotAcquisition(chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) - StreamingAcquisition(chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) - CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) - GetUuid() string // Get the unique identifier of the datasource + GetMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module + GetAggregMetrics() []prometheus.Collector // Returns pointers to metrics that are managed by the module (aggregated mode, limits cardinality) + UnmarshalConfig([]byte) error // Decode and pre-validate the YAML datasource - anything that can be checked before runtime + Configure([]byte, *log.Entry, int) error // Complete the YAML datasource configuration and perform runtime checks. + ConfigureByDSN(string, map[string]string, *log.Entry, string) error // Configure the datasource + GetMode() string // Get the mode (TAIL, CAT or SERVER) + GetName() string // Get the name of the module + OneShotAcquisition(context.Context, chan types.Event, *tomb.Tomb) error // Start one shot acquisition(eg, cat a file) + StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error // Start live acquisition (eg, tail a file) + CanRun() error // Whether the datasource can run or not (eg, journalctl on BSD is a non-sense) + GetUuid() string // Get the unique identifier of the datasource Dump() interface{} } -var AcquisitionSources = map[string]func() DataSource{ - "file": func() DataSource { return &fileacquisition.FileSource{} }, - "journalctl": func() DataSource { return &journalctlacquisition.JournalCtlSource{} }, - "cloudwatch": func() DataSource { return &cloudwatchacquisition.CloudwatchSource{} }, - "syslog": func() DataSource { return &syslogacquisition.SyslogSource{} }, - "docker": func() DataSource { return &dockeracquisition.DockerSource{} }, - "kinesis": func() DataSource { return &kinesisacquisition.KinesisSource{} }, - "wineventlog": func() DataSource { return &wineventlogacquisition.WinEventLogSource{} }, - "kafka": func() DataSource { return &kafkaacquisition.KafkaSource{} }, - "k8s-audit": func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} }, - "loki": func() DataSource { return &lokiacquisition.LokiSource{} }, - "s3": func() DataSource { return &s3acquisition.S3Source{} }, - "appsec": func() DataSource { return &appsecacquisition.AppsecSource{} }, +var ( + // We declare everything here so we can tell if they are unsupported, or excluded from the build + AcquisitionSources = map[string]func() DataSource{} + transformRuntimes = map[string]*vm.Program{} +) + +func GetDataSourceIface(dataSourceType string) (DataSource, error) { + source, registered := AcquisitionSources[dataSourceType] + if registered { + return source(), nil + } + + built, known := component.Built["datasource_"+dataSourceType] + + if !known { + return nil, fmt.Errorf("unknown data source %s", dataSourceType) + } + + if built { + panic("datasource " + dataSourceType + " is built but not registered") + } + + return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", dataSourceType) } -var transformRuntimes = map[string]*vm.Program{} +// registerDataSource registers a datasource in the AcquisitionSources map. +// It must be called in the init() function of the datasource package, and the datasource name +// must be declared with a nil value in the map, to allow for conditional compilation. +func registerDataSource(dataSourceType string, dsGetter func() DataSource) { + component.Register("datasource_" + dataSourceType) -func GetDataSourceIface(dataSourceType string) DataSource { - source := AcquisitionSources[dataSourceType] - if source == nil { - return nil + AcquisitionSources[dataSourceType] = dsGetter +} + +// setupLogger creates a logger for the datasource to use at runtime. +func setupLogger(source, name string, level *log.Level) (*log.Entry, error) { + clog := log.New() + if err := types.ConfigureLogger(clog); err != nil { + return nil, fmt.Errorf("while configuring datasource logger: %w", err) + } + + if level != nil { + clog.SetLevel(*level) } - return source() + + fields := log.Fields{ + "type": source, + } + + if name != "" { + fields["name"] = name + } + + subLogger := clog.WithFields(fields) + + return subLogger, nil } // DataSourceConfigure creates and returns a DataSource object from a configuration, @@ -98,35 +121,29 @@ func DataSourceConfigure(commonConfig configuration.DataSourceCommonCfg, metrics // once to DataSourceCommonCfg, and then later to the dedicated type of the datasource yamlConfig, err := yaml.Marshal(commonConfig) if err != nil { - return nil, fmt.Errorf("unable to marshal back interface: %w", err) + return nil, fmt.Errorf("unable to serialize back interface: %w", err) } - if dataSrc := GetDataSourceIface(commonConfig.Source); dataSrc != nil { - /* this logger will then be used by the datasource at runtime */ - clog := log.New() - if err := types.ConfigureLogger(clog); err != nil { - return nil, fmt.Errorf("while configuring datasource logger: %w", err) - } - if commonConfig.LogLevel != nil { - clog.SetLevel(*commonConfig.LogLevel) - } - customLog := log.Fields{ - "type": commonConfig.Source, - } - if commonConfig.Name != "" { - customLog["name"] = commonConfig.Name - } - subLogger := clog.WithFields(customLog) - /* check eventual dependencies are satisfied (ie. journald will check journalctl availability) */ - if err := dataSrc.CanRun(); err != nil { - return nil, &DataSourceUnavailableError{Name: commonConfig.Source, Err: err} - } - /* configure the actual datasource */ - if err := dataSrc.Configure(yamlConfig, subLogger, metricsLevel); err != nil { - return nil, fmt.Errorf("failed to configure datasource %s: %w", commonConfig.Source, err) - } - return &dataSrc, nil + + dataSrc, err := GetDataSourceIface(commonConfig.Source) + if err != nil { + return nil, err } - return nil, fmt.Errorf("cannot find source %s", commonConfig.Source) + + subLogger, err := setupLogger(commonConfig.Source, commonConfig.Name, commonConfig.LogLevel) + if err != nil { + return nil, err + } + + /* check eventual dependencies are satisfied (ie. journald will check journalctl availability) */ + if err := dataSrc.CanRun(); err != nil { + return nil, &DataSourceUnavailableError{Name: commonConfig.Source, Err: err} + } + /* configure the actual datasource */ + if err := dataSrc.Configure(yamlConfig, subLogger, metricsLevel); err != nil { + return nil, fmt.Errorf("failed to configure datasource %s: %w", commonConfig.Source, err) + } + + return &dataSrc, nil } // detectBackwardCompatAcquis: try to magically detect the type for backward compat (type was not mandatory then) @@ -134,12 +151,15 @@ func detectBackwardCompatAcquis(sub configuration.DataSourceCommonCfg) string { if _, ok := sub.Config["filename"]; ok { return "file" } + if _, ok := sub.Config["filenames"]; ok { return "file" } + if _, ok := sub.Config["journalctl_filter"]; ok { return "journalctl" } + return "" } @@ -150,29 +170,35 @@ func LoadAcquisitionFromDSN(dsn string, labels map[string]string, transformExpr if len(frags) == 1 { return nil, fmt.Errorf("%s isn't valid dsn (no protocol)", dsn) } - dataSrc := GetDataSourceIface(frags[0]) - if dataSrc == nil { - return nil, fmt.Errorf("no acquisition for protocol %s://", frags[0]) + + dataSrc, err := GetDataSourceIface(frags[0]) + if err != nil { + return nil, fmt.Errorf("no acquisition for protocol %s:// - %w", frags[0], err) } - /* this logger will then be used by the datasource at runtime */ - clog := log.New() - if err := types.ConfigureLogger(clog); err != nil { - return nil, fmt.Errorf("while configuring datasource logger: %w", err) + + subLogger, err := setupLogger(dsn, "", nil) + if err != nil { + return nil, err } - subLogger := clog.WithField("type", dsn) + uniqueId := uuid.NewString() + if transformExpr != "" { vm, err := expr.Compile(transformExpr, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return nil, fmt.Errorf("while compiling transform expression '%s': %w", transformExpr, err) } + transformRuntimes[uniqueId] = vm } - err := dataSrc.ConfigureByDSN(dsn, labels, subLogger, uniqueId) + + err = dataSrc.ConfigureByDSN(dsn, labels, subLogger, uniqueId) if err != nil { return nil, fmt.Errorf("while configuration datasource for %s: %w", dsn, err) } + sources = append(sources, dataSrc) + return sources, nil } @@ -180,9 +206,11 @@ func GetMetricsLevelFromPromCfg(prom *csconfig.PrometheusCfg) int { if prom == nil { return configuration.METRICS_FULL } + if !prom.Enabled { return configuration.METRICS_NONE } + if prom.Level == configuration.CFG_METRICS_AGGREGATE { return configuration.METRICS_AGGREGATE } @@ -190,6 +218,7 @@ func GetMetricsLevelFromPromCfg(prom *csconfig.PrometheusCfg) int { if prom.Level == configuration.CFG_METRICS_FULL { return configuration.METRICS_FULL } + return configuration.METRICS_FULL } @@ -198,50 +227,66 @@ func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg, prom *csconfig var sources []DataSource metrics_level := GetMetricsLevelFromPromCfg(prom) + for _, acquisFile := range config.AcquisitionFiles { log.Infof("loading acquisition file : %s", acquisFile) + yamlFile, err := os.Open(acquisFile) if err != nil { return nil, err } + dec := yaml.NewDecoder(yamlFile) dec.SetStrict(true) + idx := -1 + for { var sub configuration.DataSourceCommonCfg - err = dec.Decode(&sub) + idx += 1 + + err = dec.Decode(&sub) if err != nil { if !errors.Is(err, io.EOF) { return nil, fmt.Errorf("failed to yaml decode %s: %w", acquisFile, err) } + log.Tracef("End of yaml file") + break } - //for backward compat ('type' was not mandatory, detect it) + // for backward compat ('type' was not mandatory, detect it) if guessType := detectBackwardCompatAcquis(sub); guessType != "" { sub.Source = guessType } - //it's an empty item, skip it + // it's an empty item, skip it if len(sub.Labels) == 0 { if sub.Source == "" { log.Debugf("skipping empty item in %s", acquisFile) continue } + if sub.Source != "docker" { - //docker is the only source that can be empty + // docker is the only source that can be empty return nil, fmt.Errorf("missing labels in %s (position: %d)", acquisFile, idx) } } + if sub.Source == "" { return nil, fmt.Errorf("data source type is empty ('source') in %s (position: %d)", acquisFile, idx) } - if GetDataSourceIface(sub.Source) == nil { - return nil, fmt.Errorf("unknown data source %s in %s (position: %d)", sub.Source, acquisFile, idx) + + // pre-check that the source is valid + _, err := GetDataSourceIface(sub.Source) + if err != nil { + return nil, fmt.Errorf("in file %s (position: %d) - %w", acquisFile, idx, err) } + uniqueId := uuid.NewString() sub.UniqueId = uniqueId + src, err := DataSourceConfigure(sub, metrics_level) if err != nil { var dserr *DataSourceUnavailableError @@ -249,29 +294,36 @@ func LoadAcquisitionFromFile(config *csconfig.CrowdsecServiceCfg, prom *csconfig log.Error(err) continue } + return nil, fmt.Errorf("while configuring datasource of type %s from %s (position: %d): %w", sub.Source, acquisFile, idx, err) } + if sub.TransformExpr != "" { vm, err := expr.Compile(sub.TransformExpr, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { return nil, fmt.Errorf("while compiling transform expression '%s' for datasource %s in %s (position: %d): %w", sub.TransformExpr, sub.Source, acquisFile, idx, err) } + transformRuntimes[uniqueId] = vm } + sources = append(sources, *src) } } + return sources, nil } func GetMetrics(sources []DataSource, aggregated bool) error { var metrics []prometheus.Collector - for i := range len(sources) { + + for i := range sources { if aggregated { metrics = sources[i].GetMetrics() } else { metrics = sources[i].GetAggregMetrics() } + for _, metric := range metrics { if err := prometheus.Register(metric); err != nil { if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { @@ -281,12 +333,28 @@ func GetMetrics(sources []DataSource, aggregated bool) error { } } } + return nil } +// There's no need for an actual deep copy +// The event is almost empty, we are mostly interested in allocating new maps for Parsed/Meta/... +func copyEvent(evt types.Event, line string) types.Event { + evtCopy := types.MakeEvent(evt.ExpectMode == types.TIMEMACHINE, evt.Type, evt.Process) + evtCopy.Line = evt.Line + evtCopy.Line.Raw = line + evtCopy.Line.Labels = make(map[string]string) + for k, v := range evt.Line.Labels { + evtCopy.Line.Labels[k] = v + } + + return evtCopy +} + func transform(transformChan chan types.Event, output chan types.Event, AcquisTomb *tomb.Tomb, transformRuntime *vm.Program, logger *log.Entry) { defer trace.CatchPanic("crowdsec/acquis") logger.Infof("transformer started") + for { select { case <-AcquisTomb.Dying(): @@ -294,22 +362,25 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo return case evt := <-transformChan: logger.Tracef("Received event %s", evt.Line.Raw) + out, err := expr.Run(transformRuntime, map[string]interface{}{"evt": &evt}) if err != nil { logger.Errorf("while running transform expression: %s, sending event as-is", err) output <- evt } + if out == nil { logger.Errorf("transform expression returned nil, sending event as-is") output <- evt } + switch v := out.(type) { case string: logger.Tracef("transform expression returned %s", v) - evt.Line.Raw = v - output <- evt + output <- copyEvent(evt, v) case []interface{}: logger.Tracef("transform expression returned %v", v) //nolint:asasalint // We actually want to log the slice content + for _, line := range v { l, ok := line.(string) if !ok { @@ -317,14 +388,14 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo output <- evt continue } - evt.Line.Raw = l - output <- evt + + output <- copyEvent(evt, l) } case []string: logger.Tracef("transform expression returned %v", v) + for _, line := range v { - evt.Line.Raw = line - output <- evt + output <- copyEvent(evt, line) } default: logger.Errorf("transform expression returned an invalid type %T, sending event as-is", out) @@ -334,49 +405,58 @@ func transform(transformChan chan types.Event, output chan types.Event, AcquisTo } } -func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { +func StartAcquisition(ctx context.Context, sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { // Don't wait if we have no sources, as it will hang forever if len(sources) == 0 { return nil } - for i := range len(sources) { - subsrc := sources[i] //ensure its a copy + for i := range sources { + subsrc := sources[i] // ensure its a copy log.Debugf("starting one source %d/%d ->> %T", i, len(sources), subsrc) AcquisTomb.Go(func() error { defer trace.CatchPanic("crowdsec/acquis") + var err error outChan := output + log.Debugf("datasource %s UUID: %s", subsrc.GetName(), subsrc.GetUuid()) + if transformRuntime, ok := transformRuntimes[subsrc.GetUuid()]; ok { log.Infof("transform expression found for datasource %s", subsrc.GetName()) + transformChan := make(chan types.Event) outChan = transformChan transformLogger := log.WithFields(log.Fields{ "component": "transform", "datasource": subsrc.GetName(), }) + AcquisTomb.Go(func() error { transform(outChan, output, AcquisTomb, transformRuntime, transformLogger) return nil }) } + if subsrc.GetMode() == configuration.TAIL_MODE { - err = subsrc.StreamingAcquisition(outChan, AcquisTomb) + err = subsrc.StreamingAcquisition(ctx, outChan, AcquisTomb) } else { - err = subsrc.OneShotAcquisition(outChan, AcquisTomb) + err = subsrc.OneShotAcquisition(ctx, outChan, AcquisTomb) } + if err != nil { - //if one of the acqusition returns an error, we kill the others to properly shutdown + // if one of the acqusition returns an error, we kill the others to properly shutdown AcquisTomb.Kill(err) } + return nil }) } /*return only when acquisition is over (cat) or never (tail)*/ err := AcquisTomb.Wait() + return err } diff --git a/pkg/acquisition/acquisition_test.go b/pkg/acquisition/acquisition_test.go index a5eecbc20ed..dd70172cf62 100644 --- a/pkg/acquisition/acquisition_test.go +++ b/pkg/acquisition/acquisition_test.go @@ -1,6 +1,7 @@ package acquisition import ( + "context" "errors" "fmt" "strings" @@ -56,14 +57,19 @@ func (f *MockSource) Configure(cfg []byte, logger *log.Entry, metricsLevel int) return nil } -func (f *MockSource) GetMode() string { return f.Mode } -func (f *MockSource) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSource) CanRun() error { return nil } -func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } -func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } -func (f *MockSource) Dump() interface{} { return f } -func (f *MockSource) GetName() string { return "mock" } +func (f *MockSource) GetMode() string { return f.Mode } +func (f *MockSource) OneShotAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} + +func (f *MockSource) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} +func (f *MockSource) CanRun() error { return nil } +func (f *MockSource) GetMetrics() []prometheus.Collector { return nil } +func (f *MockSource) GetAggregMetrics() []prometheus.Collector { return nil } +func (f *MockSource) Dump() interface{} { return f } +func (f *MockSource) GetName() string { return "mock" } func (f *MockSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { return errors.New("not supported") } @@ -79,13 +85,8 @@ func (f *MockSourceCantRun) GetName() string { return "mock_cant_run" } // appendMockSource is only used to add mock source for tests func appendMockSource() { - if GetDataSourceIface("mock") == nil { - AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } - } - - if GetDataSourceIface("mock_cant_run") == nil { - AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } - } + AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } + AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } } func TestDataSourceConfigure(t *testing.T) { @@ -150,7 +151,7 @@ labels: log_level: debug source: tutu `, - ExpectedError: "cannot find source tutu", + ExpectedError: "unknown data source tutu", }, { TestName: "mismatch_config", @@ -184,6 +185,7 @@ wowo: ajsajasjas yaml.Unmarshal([]byte(tc.String), &common) ds, err := DataSourceConfigure(common, configuration.METRICS_NONE) cstest.RequireErrorContains(t, err, tc.ExpectedError) + if tc.ExpectedError != "" { return } @@ -270,7 +272,7 @@ func TestLoadAcquisitionFromFile(t *testing.T) { Config: csconfig.CrowdsecServiceCfg{ AcquisitionFiles: []string{"test_files/bad_source.yaml"}, }, - ExpectedError: "unknown data source does_not_exist in test_files/bad_source.yaml", + ExpectedError: "in file test_files/bad_source.yaml (position: 0) - unknown data source does_not_exist", }, { TestName: "invalid_filetype_config", @@ -284,6 +286,7 @@ func TestLoadAcquisitionFromFile(t *testing.T) { t.Run(tc.TestName, func(t *testing.T) { dss, err := LoadAcquisitionFromFile(&tc.Config, nil) cstest.RequireErrorContains(t, err, tc.ExpectedError) + if tc.ExpectedError != "" { return } @@ -320,7 +323,7 @@ func (f *MockCat) Configure(cfg []byte, logger *log.Entry, metricsLevel int) err func (f *MockCat) UnmarshalConfig(cfg []byte) error { return nil } func (f *MockCat) GetName() string { return "mock_cat" } func (f *MockCat) GetMode() string { return "cat" } -func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) error { +func (f *MockCat) OneShotAcquisition(ctx context.Context, out chan types.Event, tomb *tomb.Tomb) error { for range 10 { evt := types.Event{} evt.Line.Src = "test" @@ -329,7 +332,8 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro return nil } -func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { + +func (f *MockCat) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { return errors.New("can't run in tail") } func (f *MockCat) CanRun() error { return nil } @@ -364,15 +368,17 @@ func (f *MockTail) Configure(cfg []byte, logger *log.Entry, metricsLevel int) er func (f *MockTail) UnmarshalConfig(cfg []byte) error { return nil } func (f *MockTail) GetName() string { return "mock_tail" } func (f *MockTail) GetMode() string { return "tail" } -func (f *MockTail) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) error { +func (f *MockTail) OneShotAcquisition(_ context.Context, _ chan types.Event, _ *tomb.Tomb) error { return errors.New("can't run in cat mode") } -func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { + +func (f *MockTail) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { for range 10 { evt := types.Event{} evt.Line.Src = "test" out <- evt } + <-t.Dying() return nil @@ -386,9 +392,10 @@ func (f *MockTail) ConfigureByDSN(string, map[string]string, *log.Entry, string) } func (f *MockTail) GetUuid() string { return "" } -//func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { +// func StartAcquisition(sources []DataSource, output chan types.Event, AcquisTomb *tomb.Tomb) error { func TestStartAcquisitionCat(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockCat{}, } @@ -396,7 +403,7 @@ func TestStartAcquisitionCat(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -416,6 +423,7 @@ READLOOP: } func TestStartAcquisitionTail(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTail{}, } @@ -423,7 +431,7 @@ func TestStartAcquisitionTail(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil { t.Errorf("unexpected error") } }() @@ -450,18 +458,20 @@ type MockTailError struct { MockTail } -func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *MockTailError) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { for range 10 { evt := types.Event{} evt.Line.Src = "test" out <- evt } + t.Kill(errors.New("got error (tomb)")) return errors.New("got error") } func TestStartAcquisitionTailError(t *testing.T) { + ctx := context.Background() sources := []DataSource{ &MockTailError{}, } @@ -469,7 +479,7 @@ func TestStartAcquisitionTailError(t *testing.T) { acquisTomb := tomb.Tomb{} go func() { - if err := StartAcquisition(sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { + if err := StartAcquisition(ctx, sources, out, &acquisTomb); err != nil && err.Error() != "got error (tomb)" { t.Errorf("expected error, got '%s'", err) } }() @@ -485,7 +495,7 @@ READLOOP: } } assert.Equal(t, 10, count) - //acquisTomb.Kill(nil) + // acquisTomb.Kill(nil) time.Sleep(1 * time.Second) cstest.RequireErrorContains(t, acquisTomb.Err(), "got error (tomb)") } @@ -500,14 +510,19 @@ func (f *MockSourceByDSN) UnmarshalConfig(cfg []byte) error { return nil } func (f *MockSourceByDSN) Configure(cfg []byte, logger *log.Entry, metricsLevel int) error { return nil } -func (f *MockSourceByDSN) GetMode() string { return f.Mode } -func (f *MockSourceByDSN) OneShotAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { return nil } -func (f *MockSourceByDSN) CanRun() error { return nil } -func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } -func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } -func (f *MockSourceByDSN) Dump() interface{} { return f } -func (f *MockSourceByDSN) GetName() string { return "mockdsn" } +func (f *MockSourceByDSN) GetMode() string { return f.Mode } +func (f *MockSourceByDSN) OneShotAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} + +func (f *MockSourceByDSN) StreamingAcquisition(context.Context, chan types.Event, *tomb.Tomb) error { + return nil +} +func (f *MockSourceByDSN) CanRun() error { return nil } +func (f *MockSourceByDSN) GetMetrics() []prometheus.Collector { return nil } +func (f *MockSourceByDSN) GetAggregMetrics() []prometheus.Collector { return nil } +func (f *MockSourceByDSN) Dump() interface{} { return f } +func (f *MockSourceByDSN) GetName() string { return "mockdsn" } func (f *MockSourceByDSN) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { dsn = strings.TrimPrefix(dsn, "mockdsn://") if dsn != "test_expect" { @@ -542,9 +557,7 @@ func TestConfigureByDSN(t *testing.T) { }, } - if GetDataSourceIface("mockdsn") == nil { - AcquisitionSources["mockdsn"] = func() DataSource { return &MockSourceByDSN{} } - } + AcquisitionSources["mockdsn"] = func() DataSource { return &MockSourceByDSN{} } for _, tc := range tests { t.Run(tc.dsn, func(t *testing.T) { diff --git a/pkg/acquisition/appsec.go b/pkg/acquisition/appsec.go new file mode 100644 index 00000000000..81616d3d2b8 --- /dev/null +++ b/pkg/acquisition/appsec.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_appsec + +package acquisition + +import ( + appsecacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("appsec", func() DataSource { return &appsecacquisition.AppsecSource{} }) +} diff --git a/pkg/acquisition/cloudwatch.go b/pkg/acquisition/cloudwatch.go new file mode 100644 index 00000000000..e6b3d3e3e53 --- /dev/null +++ b/pkg/acquisition/cloudwatch.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_cloudwatch + +package acquisition + +import ( + cloudwatchacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/cloudwatch" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("cloudwatch", func() DataSource { return &cloudwatchacquisition.CloudwatchSource{} }) +} diff --git a/pkg/acquisition/docker.go b/pkg/acquisition/docker.go new file mode 100644 index 00000000000..3bf792a039a --- /dev/null +++ b/pkg/acquisition/docker.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_docker + +package acquisition + +import ( + dockeracquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/docker" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("docker", func() DataSource { return &dockeracquisition.DockerSource{} }) +} diff --git a/pkg/acquisition/file.go b/pkg/acquisition/file.go new file mode 100644 index 00000000000..1ff2e4a3c0e --- /dev/null +++ b/pkg/acquisition/file.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_file + +package acquisition + +import ( + fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("file", func() DataSource { return &fileacquisition.FileSource{} }) +} diff --git a/pkg/acquisition/http.go b/pkg/acquisition/http.go new file mode 100644 index 00000000000..59745772b62 --- /dev/null +++ b/pkg/acquisition/http.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_http + +package acquisition + +import ( + httpacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/http" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("http", func() DataSource { return &httpacquisition.HTTPSource{} }) +} diff --git a/pkg/acquisition/journalctl.go b/pkg/acquisition/journalctl.go new file mode 100644 index 00000000000..691f961ae77 --- /dev/null +++ b/pkg/acquisition/journalctl.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_journalctl + +package acquisition + +import ( + journalctlacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/journalctl" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("journalctl", func() DataSource { return &journalctlacquisition.JournalCtlSource{} }) +} diff --git a/pkg/acquisition/k8s.go b/pkg/acquisition/k8s.go new file mode 100644 index 00000000000..cb9446be285 --- /dev/null +++ b/pkg/acquisition/k8s.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_k8saudit + +package acquisition + +import ( + k8sauditacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("k8s-audit", func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} }) +} diff --git a/pkg/acquisition/kafka.go b/pkg/acquisition/kafka.go new file mode 100644 index 00000000000..7d315d87feb --- /dev/null +++ b/pkg/acquisition/kafka.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_kafka + +package acquisition + +import ( + kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("kafka", func() DataSource { return &kafkaacquisition.KafkaSource{} }) +} diff --git a/pkg/acquisition/kinesis.go b/pkg/acquisition/kinesis.go new file mode 100644 index 00000000000..b41372e7fb9 --- /dev/null +++ b/pkg/acquisition/kinesis.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_kinesis + +package acquisition + +import ( + kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("kinesis", func() DataSource { return &kinesisacquisition.KinesisSource{} }) +} diff --git a/pkg/acquisition/loki.go b/pkg/acquisition/loki.go new file mode 100644 index 00000000000..1eed6686591 --- /dev/null +++ b/pkg/acquisition/loki.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_loki + +package acquisition + +import ( + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("loki", func() DataSource { return &loki.LokiSource{} }) +} diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index 5b0661a21b7..2f7861b32ff 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -41,6 +41,7 @@ type AppsecSourceConfig struct { Path string `yaml:"path"` Routines int `yaml:"routines"` AppsecConfig string `yaml:"appsec_config"` + AppsecConfigs []string `yaml:"appsec_configs"` AppsecConfigPath string `yaml:"appsec_config_path"` AuthCacheDuration *time.Duration `yaml:"auth_cache_duration"` configuration.DataSourceCommonCfg `yaml:",inline"` @@ -59,7 +60,7 @@ type AppsecSource struct { AppsecConfigs map[string]appsec.AppsecConfig lapiURL string AuthCache AuthCache - AppsecRunners []AppsecRunner //one for each go-routine + AppsecRunners []AppsecRunner // one for each go-routine } // Struct to handle cache of authentication @@ -85,6 +86,7 @@ func (ac *AuthCache) Get(apiKey string) (time.Time, bool) { ac.mu.RLock() expiration, exists := ac.APIKeys[apiKey] ac.mu.RUnlock() + return expiration, exists } @@ -120,14 +122,19 @@ func (w *AppsecSource) UnmarshalConfig(yamlConfig []byte) error { w.config.Routines = 1 } - if w.config.AppsecConfig == "" && w.config.AppsecConfigPath == "" { + if w.config.AppsecConfig == "" && w.config.AppsecConfigPath == "" && len(w.config.AppsecConfigs) == 0 { return errors.New("appsec_config or appsec_config_path must be set") } + if (w.config.AppsecConfig != "" || w.config.AppsecConfigPath != "") && len(w.config.AppsecConfigs) != 0 { + return errors.New("appsec_config and appsec_config_path are mutually exclusive with appsec_configs") + } + if w.config.Name == "" { if w.config.ListenSocket != "" && w.config.ListenAddr == "" { w.config.Name = w.config.ListenSocket } + if w.config.ListenSocket == "" { w.config.Name = fmt.Sprintf("%s%s", w.config.ListenAddr, w.config.Path) } @@ -153,6 +160,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe if err != nil { return fmt.Errorf("unable to parse appsec configuration: %w", err) } + w.logger = logger w.metricsLevel = MetricsLevel w.logger.Tracef("Appsec configuration: %+v", w.config) @@ -172,7 +180,10 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe w.InChan = make(chan appsec.ParsedRequest) appsecCfg := appsec.AppsecConfig{Logger: w.logger.WithField("component", "appsec_config")} - //let's load the associated appsec_config: + //we keep the datasource name + appsecCfg.Name = w.config.Name + + // let's load the associated appsec_config: if w.config.AppsecConfigPath != "" { err := appsecCfg.LoadByPath(w.config.AppsecConfigPath) if err != nil { @@ -183,10 +194,20 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe if err != nil { return fmt.Errorf("unable to load appsec_config: %w", err) } + } else if len(w.config.AppsecConfigs) > 0 { + for _, appsecConfig := range w.config.AppsecConfigs { + err := appsecCfg.Load(appsecConfig) + if err != nil { + return fmt.Errorf("unable to load appsec_config: %w", err) + } + } } else { return errors.New("no appsec_config provided") } + // Now we can set up the logger + appsecCfg.SetUpLogger() + w.AppsecRuntime, err = appsecCfg.Build() if err != nil { return fmt.Errorf("unable to build appsec_config: %w", err) @@ -201,7 +222,7 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe for nbRoutine := range w.config.Routines { appsecRunnerUUID := uuid.New().String() - //we copy AppsecRutime for each runner + // we copy AppsecRutime for each runner wrt := *w.AppsecRuntime wrt.Logger = w.logger.Dup().WithField("runner_uuid", appsecRunnerUUID) runner := AppsecRunner{ @@ -211,17 +232,20 @@ func (w *AppsecSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe AppsecRuntime: &wrt, Labels: w.config.Labels, } + err := runner.Init(appsecCfg.GetDataDir()) if err != nil { return fmt.Errorf("unable to initialize runner: %w", err) } + w.AppsecRunners[nbRoutine] = runner } w.logger.Infof("Created %d appsec runners", len(w.AppsecRunners)) - //We don´t use the wrapper provided by coraza because we want to fully control what happens when a rule match to send the information in crowdsec + // We don´t use the wrapper provided by coraza because we want to fully control what happens when a rule match to send the information in crowdsec w.mux.HandleFunc(w.config.Path, w.appsecHandler) + return nil } @@ -237,16 +261,18 @@ func (w *AppsecSource) GetName() string { return "appsec" } -func (w *AppsecSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *AppsecSource) OneShotAcquisition(_ context.Context, _ chan types.Event, _ *tomb.Tomb) error { return errors.New("AppSec datasource does not support command line acquisition") } -func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *AppsecSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { w.outChan = out + t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/appsec/live") w.logger.Infof("%d appsec runner to start", len(w.AppsecRunners)) + for _, runner := range w.AppsecRunners { runner.outChan = out t.Go(func() error { @@ -254,6 +280,7 @@ func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) return runner.Run(t) }) } + t.Go(func() error { if w.config.ListenSocket != "" { w.logger.Infof("creating unix socket %s", w.config.ListenSocket) @@ -268,10 +295,11 @@ func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) } else { err = w.server.Serve(listener) } - if err != nil && err != http.ErrServerClosed { + if err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("appsec server failed: %w", err) } } + return nil }) t.Go(func() error { @@ -288,15 +316,17 @@ func (w *AppsecSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) return fmt.Errorf("appsec server failed: %w", err) } } + return nil }) <-t.Dying() w.logger.Info("Shutting down Appsec server") - //xx let's clean up the appsec runners :) + // xx let's clean up the appsec runners :) appsec.AppsecRulesDetails = make(map[int]appsec.RulesDetails) - w.server.Shutdown(context.TODO()) + w.server.Shutdown(ctx) return nil }) + return nil } @@ -391,9 +421,10 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) { logger.Debugf("Response: %+v", appsecResponse) rw.WriteHeader(statusCode) + body, err := json.Marshal(appsecResponse) if err != nil { - logger.Errorf("unable to marshal response: %s", err) + logger.Errorf("unable to serialize response: %s", err) rw.WriteHeader(http.StatusInternalServerError) } else { rw.Write(body) diff --git a/pkg/acquisition/modules/appsec/appsec_lnx_test.go b/pkg/acquisition/modules/appsec/appsec_lnx_test.go index 3e40a1f970c..61dfc536f5e 100644 --- a/pkg/acquisition/modules/appsec/appsec_lnx_test.go +++ b/pkg/acquisition/modules/appsec/appsec_lnx_test.go @@ -1,5 +1,4 @@ //go:build !windows -// +build !windows package appsecacquisition @@ -16,6 +15,7 @@ import ( func TestAppsecRuleTransformsOthers(t *testing.T) { log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ { name: "normalizepath", diff --git a/pkg/acquisition/modules/appsec/appsec_rules_test.go b/pkg/acquisition/modules/appsec/appsec_rules_test.go index 909f16357ed..00093c5a5ad 100644 --- a/pkg/acquisition/modules/appsec/appsec_rules_test.go +++ b/pkg/acquisition/modules/appsec/appsec_rules_test.go @@ -28,7 +28,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, @@ -59,7 +60,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"tutu"}}, @@ -84,7 +86,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, @@ -110,7 +113,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, @@ -136,7 +140,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"toto"}}, @@ -165,7 +170,8 @@ func TestAppsecRuleMatches(t *testing.T) { {Filter: "IsInBand == true", Apply: []string{"SetRemediation('captcha')"}}, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"bla"}}, @@ -192,7 +198,8 @@ func TestAppsecRuleMatches(t *testing.T) { {Filter: "IsInBand == true", Apply: []string{"SetReturnCode(418)"}}, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"bla"}}, @@ -219,7 +226,8 @@ func TestAppsecRuleMatches(t *testing.T) { {Filter: "IsInBand == true", Apply: []string{"SetRemediationByName('rule42', 'captcha')"}}, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Args: url.Values{"foo": []string{"bla"}}, @@ -243,7 +251,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Headers: http.Header{"Cookie": []string{"foo=toto"}}, @@ -273,7 +282,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Headers: http.Header{"Cookie": []string{"foo=toto; bar=tutu"}}, @@ -303,7 +313,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Headers: http.Header{"Cookie": []string{"bar=tutu; tututata=toto"}}, @@ -333,7 +344,8 @@ func TestAppsecRuleMatches(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/urllll", Headers: http.Header{"Content-Type": []string{"multipart/form-data; boundary=boundary"}}, @@ -354,6 +366,32 @@ toto require.Len(t, events[1].Appsec.MatchedRules, 1) require.Equal(t, "rule1", events[1].Appsec.MatchedRules[0]["msg"]) + require.Len(t, responses, 1) + require.True(t, responses[0].InBandInterrupt) + }, + }, + { + name: "Basic matching IP address", + expected_load_ok: true, + inband_native_rules: []string{ + "SecRule REMOTE_ADDR \"@ipMatch 1.2.3.4\" \"id:1,phase:1,log,deny,msg: 'block ip'\"", + }, + input_request: appsec.ParsedRequest{ + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", + Method: "GET", + URI: "/urllll", + Headers: http.Header{"Content-Type": []string{"multipart/form-data; boundary=boundary"}}, + }, + output_asserts: func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) { + require.Len(t, events, 2) + require.Equal(t, types.APPSEC, events[0].Type) + + require.Equal(t, types.LOG, events[1].Type) + require.True(t, events[1].Appsec.HasInBandMatches) + require.Len(t, events[1].Appsec.MatchedRules, 1) + require.Equal(t, "block ip", events[1].Appsec.MatchedRules[0]["msg"]) + require.Len(t, responses, 1) require.True(t, responses[0].InBandInterrupt) }, @@ -381,7 +419,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/toto", }, @@ -404,7 +443,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/TOTO", }, @@ -427,7 +467,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/toto", }, @@ -451,7 +492,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=dG90bw", }, @@ -475,7 +517,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=dG90bw===", }, @@ -499,7 +542,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=toto", }, @@ -523,7 +567,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=%42%42%2F%41", }, @@ -547,7 +592,8 @@ func TestAppsecRuleTransforms(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/?foo=%20%20%42%42%2F%41%20%20", }, @@ -585,7 +631,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/foobar?something=toto&foobar=smth", }, @@ -612,7 +659,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/foobar?something=toto&foobar=smth", }, @@ -639,7 +687,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Body: []byte("smth=toto&foobar=other"), @@ -668,7 +717,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Body: []byte("smth=toto&foobar=other"), @@ -697,7 +747,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Headers: http.Header{"foobar": []string{"toto"}}, @@ -725,7 +776,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Headers: http.Header{"foobar": []string{"toto"}}, @@ -748,7 +800,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", }, @@ -770,7 +823,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Proto: "HTTP/3.1", @@ -793,7 +847,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/foobar", }, @@ -815,7 +870,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/foobar?a=b", }, @@ -837,7 +893,8 @@ func TestAppsecRuleZones(t *testing.T) { }, }, input_request: appsec.ParsedRequest{ - RemoteAddr: "1.2.3.4", + ClientIP: "1.2.3.4", + RemoteAddr: "127.0.0.1", Method: "GET", URI: "/", Body: []byte("foobar=42421"), diff --git a/pkg/acquisition/modules/appsec/appsec_runner.go b/pkg/acquisition/modules/appsec/appsec_runner.go index ed49d6a7b41..7ce43779591 100644 --- a/pkg/acquisition/modules/appsec/appsec_runner.go +++ b/pkg/acquisition/modules/appsec/appsec_runner.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "slices" + "strings" "time" "github.com/prometheus/client_golang/prometheus" @@ -31,23 +32,38 @@ type AppsecRunner struct { logger *log.Entry } +func (r *AppsecRunner) MergeDedupRules(collections []appsec.AppsecCollection, logger *log.Entry) string { + var rulesArr []string + dedupRules := make(map[string]struct{}) + + for _, collection := range collections { + for _, rule := range collection.Rules { + if _, ok := dedupRules[rule]; !ok { + rulesArr = append(rulesArr, rule) + dedupRules[rule] = struct{}{} + } else { + logger.Debugf("Discarding duplicate rule : %s", rule) + } + } + } + if len(rulesArr) != len(dedupRules) { + logger.Warningf("%d rules were discarded as they were duplicates", len(rulesArr)-len(dedupRules)) + } + + return strings.Join(rulesArr, "\n") +} + func (r *AppsecRunner) Init(datadir string) error { var err error fs := os.DirFS(datadir) - inBandRules := "" - outOfBandRules := "" - - for _, collection := range r.AppsecRuntime.InBandRules { - inBandRules += collection.String() - } - - for _, collection := range r.AppsecRuntime.OutOfBandRules { - outOfBandRules += collection.String() - } inBandLogger := r.logger.Dup().WithField("band", "inband") outBandLogger := r.logger.Dup().WithField("band", "outband") + //While loading rules, we dedup rules based on their content, while keeping the order + inBandRules := r.MergeDedupRules(r.AppsecRuntime.InBandRules, inBandLogger) + outOfBandRules := r.MergeDedupRules(r.AppsecRuntime.OutOfBandRules, outBandLogger) + //setting up inband engine inbandCfg := coraza.NewWAFConfig().WithDirectives(inBandRules).WithRootFS(fs).WithDebugLogger(appsec.NewCrzLogger(inBandLogger)) if !r.AppsecRuntime.Config.InbandOptions.DisableBodyInspection { @@ -135,7 +151,7 @@ func (r *AppsecRunner) processRequest(tx appsec.ExtendedTransaction, request *ap //FIXME: should we abort here ? } - request.Tx.ProcessConnection(request.RemoteAddr, 0, "", 0) + request.Tx.ProcessConnection(request.ClientIP, 0, "", 0) for k, v := range request.Args { for _, vv := range v { @@ -167,7 +183,7 @@ func (r *AppsecRunner) processRequest(tx appsec.ExtendedTransaction, request *ap return nil } - if request.Body != nil && len(request.Body) > 0 { + if len(request.Body) > 0 { in, _, err = request.Tx.WriteRequestBody(request.Body) if err != nil { r.logger.Errorf("unable to write request body : %s", err) @@ -249,7 +265,7 @@ func (r *AppsecRunner) handleInBandInterrupt(request *appsec.ParsedRequest) { // Should the in band match trigger an overflow ? if r.AppsecRuntime.Response.SendAlert { - appsecOvlfw, err := AppsecEventGeneration(evt) + appsecOvlfw, err := AppsecEventGeneration(evt, request.HTTPRequest) if err != nil { r.logger.Errorf("unable to generate appsec event : %s", err) return @@ -293,7 +309,7 @@ func (r *AppsecRunner) handleOutBandInterrupt(request *appsec.ParsedRequest) { // Should the match trigger an overflow ? if r.AppsecRuntime.Response.SendAlert { - appsecOvlfw, err := AppsecEventGeneration(evt) + appsecOvlfw, err := AppsecEventGeneration(evt, request.HTTPRequest) if err != nil { r.logger.Errorf("unable to generate appsec event : %s", err) return diff --git a/pkg/acquisition/modules/appsec/appsec_runner_test.go b/pkg/acquisition/modules/appsec/appsec_runner_test.go new file mode 100644 index 00000000000..2027cf1d2c0 --- /dev/null +++ b/pkg/acquisition/modules/appsec/appsec_runner_test.go @@ -0,0 +1,139 @@ +package appsecacquisition + +import ( + "testing" + + "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func TestAppsecRuleLoad(t *testing.T) { + log.SetLevel(log.TraceLevel) + tests := []appsecRuleTest{ + { + name: "simple rule load", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 1) + }, + }, + { + name: "simple native rule load", + expected_load_ok: true, + inband_native_rules: []string{ + `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 1) + }, + }, + { + name: "simple native rule load (2)", + expected_load_ok: true, + inband_native_rules: []string{ + `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`, + `Secrule REQUEST_HEADERS:Content-Type "@rx ^multipart/form-data" "id:101,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=MULTIPART"`, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2) + }, + }, + { + name: "simple native rule load + dedup", + expected_load_ok: true, + inband_native_rules: []string{ + `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`, + `Secrule REQUEST_HEADERS:Content-Type "@rx ^multipart/form-data" "id:101,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=MULTIPART"`, + `Secrule REQUEST_HEADERS:Content-Type "@rx ^application/x-www-form-urlencoded" "id:100,phase:1,pass,nolog,noauditlog,ctl:requestBodyProcessor=URLENCODED"`, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2) + }, + }, + { + name: "multi simple rule load", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2) + }, + }, + { + name: "multi simple rule load", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + Name: "rule2", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 2) + }, + }, + { + name: "imbricated rule load", + expected_load_ok: true, + inband_rules: []appsec_rule.CustomRule{ + { + Name: "rule1", + + Or: []appsec_rule.CustomRule{ + { + //Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "toto"}, + }, + { + //Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "tutu"}, + }, + { + //Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "tata"}, + }, { + //Name: "rule1", + Zones: []string{"ARGS"}, + Match: appsec_rule.Match{Type: "equals", Value: "titi"}, + }, + }, + }, + }, + afterload_asserts: func(runner AppsecRunner) { + require.Len(t, runner.AppsecInbandEngine.GetRuleGroup().GetRules(), 4) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + loadAppSecEngine(test, t) + }) + } +} diff --git a/pkg/acquisition/modules/appsec/appsec_test.go b/pkg/acquisition/modules/appsec/appsec_test.go index d2079b43726..1534f5cb7fa 100644 --- a/pkg/acquisition/modules/appsec/appsec_test.go +++ b/pkg/acquisition/modules/appsec/appsec_test.go @@ -18,6 +18,8 @@ type appsecRuleTest struct { expected_load_ok bool inband_rules []appsec_rule.CustomRule outofband_rules []appsec_rule.CustomRule + inband_native_rules []string + outofband_native_rules []string on_load []appsec.Hook pre_eval []appsec.Hook post_eval []appsec.Hook @@ -28,6 +30,7 @@ type appsecRuleTest struct { DefaultRemediation string DefaultPassAction string input_request appsec.ParsedRequest + afterload_asserts func(runner AppsecRunner) output_asserts func(events []types.Event, responses []appsec.AppsecTempResponse, appsecResponse appsec.BodyResponse, statusCode int) } @@ -53,6 +56,8 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { inbandRules = append(inbandRules, strRule) } + inbandRules = append(inbandRules, test.inband_native_rules...) + outofbandRules = append(outofbandRules, test.outofband_native_rules...) for ridx, rule := range test.outofband_rules { strRule, _, err := rule.Convert(appsec_rule.ModsecurityRuleType, rule.Name) if err != nil { @@ -94,6 +99,13 @@ func loadAppSecEngine(test appsecRuleTest, t *testing.T) { t.Fatalf("unable to initialize runner : %s", err) } + if test.afterload_asserts != nil { + //afterload asserts are just to evaluate the state of the runner after the rules have been loaded + //if it's present, don't try to process requests + test.afterload_asserts(runner) + return + } + input := test.input_request input.ResponseChannel = make(chan appsec.AppsecTempResponse) OutputEvents := make([]types.Event, 0) diff --git a/pkg/acquisition/modules/appsec/appsec_win_test.go b/pkg/acquisition/modules/appsec/appsec_win_test.go index e85d75df251..a6b8f3a0340 100644 --- a/pkg/acquisition/modules/appsec/appsec_win_test.go +++ b/pkg/acquisition/modules/appsec/appsec_win_test.go @@ -1,5 +1,4 @@ //go:build windows -// +build windows package appsecacquisition diff --git a/pkg/acquisition/modules/appsec/utils.go b/pkg/acquisition/modules/appsec/utils.go index 15de8046716..8995b305680 100644 --- a/pkg/acquisition/modules/appsec/utils.go +++ b/pkg/acquisition/modules/appsec/utils.go @@ -1,10 +1,10 @@ package appsecacquisition import ( + "errors" "fmt" "net" - "slices" - "strconv" + "net/http" "time" "github.com/oschwald/geoip2-golang" @@ -22,32 +22,49 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var appsecMetaKeys = []string{ - "id", - "name", - "method", - "uri", - "matched_zones", - "msg", -} +func AppsecEventGenerationGeoIPEnrich(src *models.Source) error { -func appendMeta(meta models.Meta, key string, value string) models.Meta { - if value == "" { - return meta + if src == nil || src.Scope == nil || *src.Scope != types.Ip { + return errors.New("source is nil or not an IP") } - meta = append(meta, &models.MetaItems0{ - Key: key, - Value: value, - }) - return meta + //GeoIP enrich + asndata, err := exprhelpers.GeoIPASNEnrich(src.IP) + + if err != nil { + return err + } else if asndata != nil { + record := asndata.(*geoip2.ASN) + src.AsName = record.AutonomousSystemOrganization + src.AsNumber = fmt.Sprintf("%d", record.AutonomousSystemNumber) + } + + cityData, err := exprhelpers.GeoIPEnrich(src.IP) + if err != nil { + return err + } else if cityData != nil { + record := cityData.(*geoip2.City) + src.Cn = record.Country.IsoCode + src.Latitude = float32(record.Location.Latitude) + src.Longitude = float32(record.Location.Longitude) + } + + rangeData, err := exprhelpers.GeoIPRangeEnrich(src.IP) + if err != nil { + return err + } else if rangeData != nil { + record := rangeData.(*net.IPNet) + src.Range = record.String() + } + return nil } -func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { - //if the request didnd't trigger inband rules, we don't want to generate an event to LAPI/CAPI +func AppsecEventGeneration(inEvt types.Event, request *http.Request) (*types.Event, error) { + // if the request didnd't trigger inband rules, we don't want to generate an event to LAPI/CAPI if !inEvt.Appsec.HasInBandMatches { return nil, nil } + evt := types.Event{} evt.Type = types.APPSEC evt.Process = true @@ -58,34 +75,12 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { Scope: ptr.Of(types.Ip), } - asndata, err := exprhelpers.GeoIPASNEnrich(sourceIP) - - if err != nil { - log.Errorf("Unable to enrich ip '%s' for ASN: %s", sourceIP, err) - } else if asndata != nil { - record := asndata.(*geoip2.ASN) - source.AsName = record.AutonomousSystemOrganization - source.AsNumber = fmt.Sprintf("%d", record.AutonomousSystemNumber) - } - - cityData, err := exprhelpers.GeoIPEnrich(sourceIP) - if err != nil { - log.Errorf("Unable to enrich ip '%s' for geo data: %s", sourceIP, err) - } else if cityData != nil { - record := cityData.(*geoip2.City) - source.Cn = record.Country.IsoCode - source.Latitude = float32(record.Location.Latitude) - source.Longitude = float32(record.Location.Longitude) - } - - rangeData, err := exprhelpers.GeoIPRangeEnrich(sourceIP) - if err != nil { - log.Errorf("Unable to enrich ip '%s' for range: %s", sourceIP, err) - } else if rangeData != nil { - record := rangeData.(*net.IPNet) - source.Range = record.String() + // Enrich source with GeoIP data + if err := AppsecEventGenerationGeoIPEnrich(&source); err != nil { + log.Errorf("unable to enrich source with GeoIP data : %s", err) } + // Build overflow evt.Overflow.Sources = make(map[string]models.Source) evt.Overflow.Sources[sourceIP] = source @@ -93,80 +88,11 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { alert.Capacity = ptr.Of(int32(1)) alert.Events = make([]*models.Event, len(evt.Appsec.GetRuleIDs())) - now := ptr.Of(time.Now().UTC().Format(time.RFC3339)) - - tmpAppsecContext := make(map[string][]string) - - for _, matched_rule := range inEvt.Appsec.MatchedRules { - evtRule := models.Event{} - - evtRule.Timestamp = now - - evtRule.Meta = make(models.Meta, 0) - - for _, key := range appsecMetaKeys { - - if tmpAppsecContext[key] == nil { - tmpAppsecContext[key] = make([]string, 0) - } - - switch value := matched_rule[key].(type) { - case string: - evtRule.Meta = appendMeta(evtRule.Meta, key, value) - if value != "" && !slices.Contains(tmpAppsecContext[key], value) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], value) - } - case int: - val := strconv.Itoa(value) - evtRule.Meta = appendMeta(evtRule.Meta, key, val) - if val != "" && !slices.Contains(tmpAppsecContext[key], val) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], val) - } - case []string: - for _, v := range value { - evtRule.Meta = appendMeta(evtRule.Meta, key, v) - if v != "" && !slices.Contains(tmpAppsecContext[key], v) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], v) - } - } - case []int: - for _, v := range value { - val := strconv.Itoa(v) - evtRule.Meta = appendMeta(evtRule.Meta, key, val) - if val != "" && !slices.Contains(tmpAppsecContext[key], val) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], val) - } - - } - default: - val := fmt.Sprintf("%v", value) - evtRule.Meta = appendMeta(evtRule.Meta, key, val) - if val != "" && !slices.Contains(tmpAppsecContext[key], val) { - tmpAppsecContext[key] = append(tmpAppsecContext[key], val) - } - - } + metas, errors := alertcontext.AppsecEventToContext(inEvt.Appsec, request) + if len(errors) > 0 { + for _, err := range errors { + log.Errorf("failed to generate appsec context: %s", err) } - alert.Events = append(alert.Events, &evtRule) - } - - metas := make([]*models.MetaItems0, 0) - - for key, values := range tmpAppsecContext { - if len(values) == 0 { - continue - } - - valueStr, err := alertcontext.TruncateContext(values, alertcontext.MaxContextValueLen) - if err != nil { - log.Warningf(err.Error()) - } - - meta := models.MetaItems0{ - Key: key, - Value: valueStr, - } - metas = append(metas, &meta) } alert.Meta = metas @@ -185,15 +111,13 @@ func AppsecEventGeneration(inEvt types.Event) (*types.Event, error) { alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) evt.Overflow.APIAlerts = []models.Alert{alert} evt.Overflow.Alert = &alert + return &evt, nil } func EventFromRequest(r *appsec.ParsedRequest, labels map[string]string) (types.Event, error) { - evt := types.Event{} - //we might want to change this based on in-band vs out-of-band ? - evt.Type = types.LOG - evt.ExpectMode = types.LIVE - //def needs fixing + evt := types.MakeEvent(false, types.LOG, true) + // def needs fixing evt.Stage = "s00-raw" evt.Parsed = map[string]string{ "source_ip": r.ClientIP, @@ -203,19 +127,19 @@ func EventFromRequest(r *appsec.ParsedRequest, labels map[string]string) (types. "req_uuid": r.Tx.ID(), "source": "crowdsec-appsec", "remediation_cmpt_ip": r.RemoteAddrNormalized, - //TBD: - //http_status - //user_agent + // TBD: + // http_status + // user_agent } evt.Line = types.Line{ Time: time.Now(), - //should we add some info like listen addr/port/path ? + // should we add some info like listen addr/port/path ? Labels: labels, Process: true, Module: "appsec", Src: "appsec", - Raw: "dummy-appsec-data", //we discard empty Line.Raw items :) + Raw: "dummy-appsec-data", // we discard empty Line.Raw items :) } evt.Appsec = types.AppsecEvent{} @@ -247,29 +171,29 @@ func LogAppsecEvent(evt *types.Event, logger *log.Entry) { "target_uri": req, }).Debugf("%s triggered non-blocking rules on %s (%d rules) [%v]", evt.Parsed["source_ip"], req, len(evt.Appsec.MatchedRules), evt.Appsec.GetRuleIDs()) } - } func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedRequest) error { - if evt == nil { - //an error was already emitted, let's not spam the logs + // an error was already emitted, let's not spam the logs return nil } if !req.Tx.IsInterrupted() { - //if the phase didn't generate an interruption, we don't have anything to add to the event + // if the phase didn't generate an interruption, we don't have anything to add to the event return nil } - //if one interruption was generated, event is good for processing :) + // if one interruption was generated, event is good for processing :) evt.Process = true if evt.Meta == nil { evt.Meta = map[string]string{} } + if evt.Parsed == nil { evt.Parsed = map[string]string{} } + if req.IsInBand { evt.Meta["appsec_interrupted"] = "true" evt.Meta["appsec_action"] = req.Tx.Interruption().Action @@ -290,9 +214,11 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR if variable.Key() != "" { key += "." + variable.Key() } + if variable.Value() == "" { continue } + for _, collectionToKeep := range r.AppsecRuntime.CompiledVariablesTracking { match := collectionToKeep.MatchString(key) if match { @@ -303,6 +229,7 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR } } } + return true }) @@ -325,11 +252,12 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR ruleNameProm := fmt.Sprintf("%d", rule.Rule().ID()) if details, ok := appsec.AppsecRulesDetails[rule.Rule().ID()]; ok { - //Only set them for custom rules, not for rules written in seclang + // Only set them for custom rules, not for rules written in seclang name = details.Name version = details.Version hash = details.Hash ruleNameProm = details.Name + r.logger.Debugf("custom rule for event, setting name: %s, version: %s, hash: %s", name, version, hash) } else { name = fmt.Sprintf("native_rule:%d", rule.Rule().ID()) @@ -338,12 +266,15 @@ func (r *AppsecRunner) AccumulateTxToEvent(evt *types.Event, req *appsec.ParsedR AppsecRuleHits.With(prometheus.Labels{"rule_name": ruleNameProm, "type": kind, "source": req.RemoteAddrNormalized, "appsec_engine": req.AppsecEngine}).Inc() matchedZones := make([]string, 0) + for _, matchData := range rule.MatchedDatas() { zone := matchData.Variable().Name() + varName := matchData.Key() if varName != "" { zone += "." + varName } + matchedZones = append(matchedZones, zone) } diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch.go b/pkg/acquisition/modules/cloudwatch/cloudwatch.go index 1a78ae6fa7a..ba267c9050b 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch.go @@ -57,16 +57,16 @@ type CloudwatchSource struct { // CloudwatchSourceConfiguration allows user to define one or more streams to monitor within a cloudwatch log group type CloudwatchSourceConfiguration struct { configuration.DataSourceCommonCfg `yaml:",inline"` - GroupName string `yaml:"group_name"` //the group name to be monitored - StreamRegexp *string `yaml:"stream_regexp,omitempty"` //allow to filter specific streams + GroupName string `yaml:"group_name"` // the group name to be monitored + StreamRegexp *string `yaml:"stream_regexp,omitempty"` // allow to filter specific streams StreamName *string `yaml:"stream_name,omitempty"` StartTime, EndTime *time.Time `yaml:"-"` - DescribeLogStreamsLimit *int64 `yaml:"describelogstreams_limit,omitempty"` //batch size for DescribeLogStreamsPagesWithContext + DescribeLogStreamsLimit *int64 `yaml:"describelogstreams_limit,omitempty"` // batch size for DescribeLogStreamsPagesWithContext GetLogEventsPagesLimit *int64 `yaml:"getlogeventspages_limit,omitempty"` - PollNewStreamInterval *time.Duration `yaml:"poll_new_stream_interval,omitempty"` //frequency at which we poll for new streams within the log group - MaxStreamAge *time.Duration `yaml:"max_stream_age,omitempty"` //monitor only streams that have been updated within $duration - PollStreamInterval *time.Duration `yaml:"poll_stream_interval,omitempty"` //frequency at which we poll each stream - StreamReadTimeout *time.Duration `yaml:"stream_read_timeout,omitempty"` //stop monitoring streams that haven't been updated within $duration, might be reopened later tho + PollNewStreamInterval *time.Duration `yaml:"poll_new_stream_interval,omitempty"` // frequency at which we poll for new streams within the log group + MaxStreamAge *time.Duration `yaml:"max_stream_age,omitempty"` // monitor only streams that have been updated within $duration + PollStreamInterval *time.Duration `yaml:"poll_stream_interval,omitempty"` // frequency at which we poll each stream + StreamReadTimeout *time.Duration `yaml:"stream_read_timeout,omitempty"` // stop monitoring streams that haven't been updated within $duration, might be reopened later tho AwsApiCallTimeout *time.Duration `yaml:"aws_api_timeout,omitempty"` AwsProfile *string `yaml:"aws_profile,omitempty"` PrependCloudwatchTimestamp *bool `yaml:"prepend_cloudwatch_timestamp,omitempty"` @@ -86,7 +86,7 @@ type LogStreamTailConfig struct { logger *log.Entry ExpectMode int t tomb.Tomb - StartTime, EndTime time.Time //only used for CatMode + StartTime, EndTime time.Time // only used for CatMode } var ( @@ -111,7 +111,7 @@ func (cw *CloudwatchSource) UnmarshalConfig(yamlConfig []byte) error { return fmt.Errorf("cannot parse CloudwatchSource configuration: %w", err) } - if len(cw.Config.GroupName) == 0 { + if cw.Config.GroupName == "" { return errors.New("group_name is mandatory for CloudwatchSource") } @@ -159,6 +159,7 @@ func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry, Metr if err != nil { return err } + cw.metricsLevel = MetricsLevel cw.logger = logger.WithField("group", cw.Config.GroupName) @@ -175,16 +176,18 @@ func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry, Metr if *cw.Config.MaxStreamAge > *cw.Config.StreamReadTimeout { cw.logger.Warningf("max_stream_age > stream_read_timeout, stream might keep being opened/closed") } + cw.logger.Tracef("aws_config_dir set to %s", *cw.Config.AwsConfigDir) if *cw.Config.AwsConfigDir != "" { _, err := os.Stat(*cw.Config.AwsConfigDir) if err != nil { cw.logger.Errorf("can't read aws_config_dir '%s' got err %s", *cw.Config.AwsConfigDir, err) - return fmt.Errorf("can't read aws_config_dir %s got err %s ", *cw.Config.AwsConfigDir, err) + return fmt.Errorf("can't read aws_config_dir %s got err %w ", *cw.Config.AwsConfigDir, err) } + os.Setenv("AWS_SDK_LOAD_CONFIG", "1") - //as aws sdk relies on $HOME, let's allow the user to override it :) + // as aws sdk relies on $HOME, let's allow the user to override it :) os.Setenv("AWS_CONFIG_FILE", fmt.Sprintf("%s/config", *cw.Config.AwsConfigDir)) os.Setenv("AWS_SHARED_CREDENTIALS_FILE", fmt.Sprintf("%s/credentials", *cw.Config.AwsConfigDir)) } else { @@ -192,25 +195,30 @@ func (cw *CloudwatchSource) Configure(yamlConfig []byte, logger *log.Entry, Metr cw.logger.Errorf("aws_region is not specified, specify it or aws_config_dir") return errors.New("aws_region is not specified, specify it or aws_config_dir") } + os.Setenv("AWS_REGION", *cw.Config.AwsRegion) } if err := cw.newClient(); err != nil { return err } + cw.streamIndexes = make(map[string]string) targetStream := "*" + if cw.Config.StreamRegexp != nil { if _, err := regexp.Compile(*cw.Config.StreamRegexp); err != nil { return fmt.Errorf("while compiling regexp '%s': %w", *cw.Config.StreamRegexp, err) } + targetStream = *cw.Config.StreamRegexp } else if cw.Config.StreamName != nil { targetStream = *cw.Config.StreamName } cw.logger.Infof("Adding cloudwatch group '%s' (stream:%s) to datasources", cw.Config.GroupName, targetStream) + return nil } @@ -231,25 +239,30 @@ func (cw *CloudwatchSource) newClient() error { if sess == nil { return errors.New("failed to create aws session") } + if v := os.Getenv("AWS_ENDPOINT_FORCE"); v != "" { cw.logger.Debugf("[testing] overloading endpoint with %s", v) cw.cwClient = cloudwatchlogs.New(sess, aws.NewConfig().WithEndpoint(v)) } else { cw.cwClient = cloudwatchlogs.New(sess) } + if cw.cwClient == nil { return errors.New("failed to create cloudwatch client") } + return nil } -func (cw *CloudwatchSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (cw *CloudwatchSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { cw.t = t monitChan := make(chan LogStreamTailConfig) + t.Go(func() error { - return cw.LogStreamManager(monitChan, out) + return cw.LogStreamManager(ctx, monitChan, out) }) - return cw.WatchLogGroupForStreams(monitChan) + + return cw.WatchLogGroupForStreams(ctx, monitChan) } func (cw *CloudwatchSource) GetMetrics() []prometheus.Collector { @@ -276,9 +289,10 @@ func (cw *CloudwatchSource) Dump() interface{} { return cw } -func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig) error { +func (cw *CloudwatchSource) WatchLogGroupForStreams(ctx context.Context, out chan LogStreamTailConfig) error { cw.logger.Debugf("Starting to watch group (interval:%s)", cw.Config.PollNewStreamInterval) ticker := time.NewTicker(*cw.Config.PollNewStreamInterval) + var startFrom *string for { @@ -289,11 +303,11 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig case <-ticker.C: hasMoreStreams := true startFrom = nil + for hasMoreStreams { cw.logger.Tracef("doing the call to DescribeLogStreamsPagesWithContext") - ctx := context.Background() - //there can be a lot of streams in a group, and we're only interested in those recently written to, so we sort by LastEventTime + // there can be a lot of streams in a group, and we're only interested in those recently written to, so we sort by LastEventTime err := cw.cwClient.DescribeLogStreamsPagesWithContext( ctx, &cloudwatchlogs.DescribeLogStreamsInput{ @@ -305,13 +319,14 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig }, func(page *cloudwatchlogs.DescribeLogStreamsOutput, lastPage bool) bool { cw.logger.Tracef("in helper of DescribeLogStreamsPagesWithContext") + for _, event := range page.LogStreams { startFrom = page.NextToken - //we check if the stream has been written to recently enough to be monitored + // we check if the stream has been written to recently enough to be monitored if event.LastIngestionTime != nil { - //aws uses millisecond since the epoch + // aws uses millisecond since the epoch oldest := time.Now().UTC().Add(-*cw.Config.MaxStreamAge) - //TBD : verify that this is correct : Unix 2nd arg expects Nanoseconds, and have a code that is more explicit. + // TBD : verify that this is correct : Unix 2nd arg expects Nanoseconds, and have a code that is more explicit. LastIngestionTime := time.Unix(0, *event.LastIngestionTime*int64(time.Millisecond)) if LastIngestionTime.Before(oldest) { cw.logger.Tracef("stop iteration, %s reached oldest age, stop (%s < %s)", *event.LogStreamName, LastIngestionTime, time.Now().UTC().Add(-*cw.Config.MaxStreamAge)) @@ -319,7 +334,7 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig return false } cw.logger.Tracef("stream %s is elligible for monitoring", *event.LogStreamName) - //the stream has been updated recently, check if we should monitor it + // the stream has been updated recently, check if we should monitor it var expectMode int if !cw.Config.UseTimeMachine { expectMode = types.LIVE @@ -356,8 +371,7 @@ func (cw *CloudwatchSource) WatchLogGroupForStreams(out chan LogStreamTailConfig } // LogStreamManager receives the potential streams to monitor, and starts a go routine when needed -func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outChan chan types.Event) error { - +func (cw *CloudwatchSource) LogStreamManager(ctx context.Context, in chan LogStreamTailConfig, outChan chan types.Event) error { cw.logger.Debugf("starting to monitor streams for %s", cw.Config.GroupName) pollDeadStreamInterval := time.NewTicker(def_PollDeadStreamInterval) @@ -384,7 +398,7 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha for idx, stream := range cw.monitoredStreams { if newStream.GroupName == stream.GroupName && newStream.StreamName == stream.StreamName { - //stream exists, but is dead, remove it from list + // stream exists, but is dead, remove it from list if !stream.t.Alive() { cw.logger.Debugf("stream %s already exists, but is dead", newStream.StreamName) cw.monitoredStreams = append(cw.monitoredStreams[:idx], cw.monitoredStreams[idx+1:]...) @@ -398,7 +412,7 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha } } - //let's start watching this stream + // let's start watching this stream if shouldCreate { if cw.metricsLevel != configuration.METRICS_NONE { openedStreams.With(prometheus.Labels{"group": newStream.GroupName}).Inc() @@ -407,7 +421,7 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha newStream.logger = cw.logger.WithField("stream", newStream.StreamName) cw.logger.Debugf("starting tail of stream %s", newStream.StreamName) newStream.t.Go(func() error { - return cw.TailLogStream(&newStream, outChan) + return cw.TailLogStream(ctx, &newStream, outChan) }) cw.monitoredStreams = append(cw.monitoredStreams, &newStream) } @@ -442,11 +456,11 @@ func (cw *CloudwatchSource) LogStreamManager(in chan LogStreamTailConfig, outCha } } -func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan types.Event) error { +func (cw *CloudwatchSource) TailLogStream(ctx context.Context, cfg *LogStreamTailConfig, outChan chan types.Event) error { var startFrom *string lastReadMessage := time.Now().UTC() ticker := time.NewTicker(cfg.PollStreamInterval) - //resume at existing index if we already had + // resume at existing index if we already had streamIndexMutex.Lock() v := cw.streamIndexes[cfg.GroupName+"+"+cfg.StreamName] streamIndexMutex.Unlock() @@ -464,7 +478,6 @@ func (cw *CloudwatchSource) TailLogStream(cfg *LogStreamTailConfig, outChan chan for hasMorePages { /*for the first call, we only consume the last item*/ cfg.logger.Tracef("calling GetLogEventsPagesWithContext") - ctx := context.Background() err := cw.cwClient.GetLogEventsPagesWithContext(ctx, &cloudwatchlogs.GetLogEventsInput{ Limit: aws.Int64(cfg.GetLogEventsPagesLimit), @@ -567,7 +580,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, if len(v) != 1 { return errors.New("expected zero or one argument for 'start_date'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported strdate, startDate := parser.GenDateParse(v[0]) cw.logger.Debugf("parsed '%s' as '%s'", v[0], strdate) cw.Config.StartTime = &startDate @@ -575,7 +588,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, if len(v) != 1 { return errors.New("expected zero or one argument for 'end_date'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported strdate, endDate := parser.GenDateParse(v[0]) cw.logger.Debugf("parsed '%s' as '%s'", v[0], strdate) cw.Config.EndTime = &endDate @@ -583,7 +596,7 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, if len(v) != 1 { return errors.New("expected zero or one argument for 'backlog'") } - //let's reuse our parser helper so that a ton of date formats are supported + // let's reuse our parser helper so that a ton of date formats are supported duration, err := time.ParseDuration(v[0]) if err != nil { return fmt.Errorf("unable to parse '%s' as duration: %w", v[0], err) @@ -618,8 +631,8 @@ func (cw *CloudwatchSource) ConfigureByDSN(dsn string, labels map[string]string, return nil } -func (cw *CloudwatchSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { - //StreamName string, Start time.Time, End time.Time +func (cw *CloudwatchSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + // StreamName string, Start time.Time, End time.Time config := LogStreamTailConfig{ GroupName: cw.Config.GroupName, StreamName: *cw.Config.StreamName, @@ -633,12 +646,12 @@ func (cw *CloudwatchSource) OneShotAcquisition(out chan types.Event, t *tomb.Tom Labels: cw.Config.Labels, ExpectMode: types.TIMEMACHINE, } - return cw.CatLogStream(&config, out) + return cw.CatLogStream(ctx, &config, out) } -func (cw *CloudwatchSource) CatLogStream(cfg *LogStreamTailConfig, outChan chan types.Event) error { +func (cw *CloudwatchSource) CatLogStream(ctx context.Context, cfg *LogStreamTailConfig, outChan chan types.Event) error { var startFrom *string - var head = true + head := true /*convert the times*/ startTime := cfg.StartTime.UTC().Unix() * 1000 endTime := cfg.EndTime.UTC().Unix() * 1000 @@ -652,7 +665,6 @@ func (cw *CloudwatchSource) CatLogStream(cfg *LogStreamTailConfig, outChan chan if startFrom != nil { cfg.logger.Tracef("next_token: %s", *startFrom) } - ctx := context.Background() err := cw.cwClient.GetLogEventsPagesWithContext(ctx, &cloudwatchlogs.GetLogEventsInput{ Limit: aws.Int64(10), @@ -698,7 +710,7 @@ func (cw *CloudwatchSource) CatLogStream(cfg *LogStreamTailConfig, outChan chan func cwLogToEvent(log *cloudwatchlogs.OutputLogEvent, cfg *LogStreamTailConfig) (types.Event, error) { l := types.Line{} - evt := types.Event{} + evt := types.MakeEvent(cfg.ExpectMode == types.TIMEMACHINE, types.LOG, true) if log.Message == nil { return evt, errors.New("nil message") } @@ -714,9 +726,6 @@ func cwLogToEvent(log *cloudwatchlogs.OutputLogEvent, cfg *LogStreamTailConfig) l.Process = true l.Module = "cloudwatch" evt.Line = l - evt.Process = true - evt.Type = types.LOG - evt.ExpectMode = cfg.ExpectMode cfg.logger.Debugf("returned event labels : %+v", evt.Line.Labels) return evt, nil } diff --git a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go index bab7593f26f..3d638896537 100644 --- a/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go +++ b/pkg/acquisition/modules/cloudwatch/cloudwatch_test.go @@ -1,6 +1,7 @@ package cloudwatchacquisition import ( + "context" "errors" "fmt" "net" @@ -34,6 +35,7 @@ func deleteAllLogGroups(t *testing.T, cw *CloudwatchSource) { input := &cloudwatchlogs.DescribeLogGroupsInput{} result, err := cw.cwClient.DescribeLogGroups(input) require.NoError(t, err) + for _, group := range result.LogGroups { _, err := cw.cwClient.DeleteLogGroup(&cloudwatchlogs.DeleteLogGroupInput{ LogGroupName: group.LogGroupName, @@ -62,18 +64,22 @@ func TestMain(m *testing.M) { if runtime.GOOS == "windows" { os.Exit(0) } + if err := checkForLocalStackAvailability(); err != nil { log.Fatalf("local stack error : %s", err) } + def_PollNewStreamInterval = 1 * time.Second def_PollStreamInterval = 1 * time.Second def_StreamReadTimeout = 10 * time.Second def_MaxStreamAge = 5 * time.Second def_PollDeadStreamInterval = 5 * time.Second + os.Exit(m.Run()) } func TestWatchLogGroupForStreams(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -447,7 +453,7 @@ stream_name: test_stream`), dbgLogger.Infof("running StreamingAcquisition") actmb := tomb.Tomb{} actmb.Go(func() error { - err := cw.StreamingAcquisition(out, &actmb) + err := cw.StreamingAcquisition(ctx, out, &actmb) dbgLogger.Infof("acquis done") cstest.RequireErrorContains(t, err, tc.expectedStartErr) return nil @@ -503,7 +509,6 @@ stream_name: test_stream`), if len(res) != 0 { t.Fatalf("leftover unmatched results : %v", res) } - } if tc.teardown != nil { tc.teardown(t, &cw) @@ -513,6 +518,7 @@ stream_name: test_stream`), } func TestConfiguration(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -571,9 +577,9 @@ stream_name: test_stream`), switch cw.GetMode() { case "tail": - err = cw.StreamingAcquisition(out, &tmb) + err = cw.StreamingAcquisition(ctx, out, &tmb) case "cat": - err = cw.OneShotAcquisition(out, &tmb) + err = cw.OneShotAcquisition(ctx, out, &tmb) } cstest.RequireErrorContains(t, err, tc.expectedStartErr) @@ -631,6 +637,8 @@ func TestConfigureByDSN(t *testing.T) { } func TestOneShotAcquisition(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -762,7 +770,7 @@ func TestOneShotAcquisition(t *testing.T) { var rcvdEvts []types.Event dbgLogger.Infof("running StreamingAcquisition") - err = cw.OneShotAcquisition(out, &tmb) + err = cw.OneShotAcquisition(ctx, out, &tmb) dbgLogger.Infof("acquis done") cstest.RequireErrorContains(t, err, tc.expectedStartErr) close(out) @@ -798,7 +806,6 @@ func TestOneShotAcquisition(t *testing.T) { if len(res) != 0 { t.Fatalf("leftover unmatched results : %v", res) } - } if tc.teardown != nil { tc.teardown(t, &cw) diff --git a/pkg/acquisition/modules/docker/docker.go b/pkg/acquisition/modules/docker/docker.go index 9a6e13feee4..b27255ec13f 100644 --- a/pkg/acquisition/modules/docker/docker.go +++ b/pkg/acquisition/modules/docker/docker.go @@ -286,9 +286,9 @@ func (d *DockerSource) SupportedModes() []string { } // OneShotAcquisition reads a set of file and returns when done -func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (d *DockerSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { d.logger.Debug("In oneshot") - runningContainer, err := d.Client.ContainerList(context.Background(), dockerTypes.ContainerListOptions{}) + runningContainer, err := d.Client.ContainerList(ctx, dockerTypes.ContainerListOptions{}) if err != nil { return err } @@ -298,10 +298,10 @@ func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er d.logger.Debugf("container with id %s is already being read from", container.ID) continue } - if containerConfig := d.EvalContainer(container); containerConfig != nil { + if containerConfig := d.EvalContainer(ctx, container); containerConfig != nil { d.logger.Infof("reading logs from container %s", containerConfig.Name) d.logger.Debugf("logs options: %+v", *d.containerLogsOptions) - dockerReader, err := d.Client.ContainerLogs(context.Background(), containerConfig.ID, *d.containerLogsOptions) + dockerReader, err := d.Client.ContainerLogs(ctx, containerConfig.ID, *d.containerLogsOptions) if err != nil { d.logger.Errorf("unable to read logs from container: %+v", err) return err @@ -334,7 +334,10 @@ func (d *DockerSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) er if d.metricsLevel != configuration.METRICS_NONE { linesRead.With(prometheus.Labels{"source": containerConfig.Name}).Inc() } - evt := types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + evt := types.MakeEvent(true, types.LOG, true) + evt.Line = l + evt.Process = true + evt.Type = types.LOG out <- evt d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) } @@ -372,58 +375,56 @@ func (d *DockerSource) CanRun() error { return nil } -func (d *DockerSource) getContainerTTY(containerId string) bool { - containerDetails, err := d.Client.ContainerInspect(context.Background(), containerId) +func (d *DockerSource) getContainerTTY(ctx context.Context, containerId string) bool { + containerDetails, err := d.Client.ContainerInspect(ctx, containerId) if err != nil { return false } return containerDetails.Config.Tty } -func (d *DockerSource) getContainerLabels(containerId string) map[string]interface{} { - containerDetails, err := d.Client.ContainerInspect(context.Background(), containerId) +func (d *DockerSource) getContainerLabels(ctx context.Context, containerId string) map[string]interface{} { + containerDetails, err := d.Client.ContainerInspect(ctx, containerId) if err != nil { return map[string]interface{}{} } return parseLabels(containerDetails.Config.Labels) } -func (d *DockerSource) EvalContainer(container dockerTypes.Container) *ContainerConfig { +func (d *DockerSource) EvalContainer(ctx context.Context, container dockerTypes.Container) *ContainerConfig { for _, containerID := range d.Config.ContainerID { if containerID == container.ID { - return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(ctx, container.ID)} } } for _, containerName := range d.Config.ContainerName { for _, name := range container.Names { - if strings.HasPrefix(name, "/") && len(name) > 0 { + if strings.HasPrefix(name, "/") && name != "" { name = name[1:] } if name == containerName { - return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} + return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(ctx, container.ID)} } } - } for _, cont := range d.compiledContainerID { if matched := cont.MatchString(container.ID); matched { - return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: d.Config.Labels, Tty: d.getContainerTTY(ctx, container.ID)} } } for _, cont := range d.compiledContainerName { for _, name := range container.Names { if matched := cont.MatchString(name); matched { - return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(container.ID)} + return &ContainerConfig{ID: container.ID, Name: name, Labels: d.Config.Labels, Tty: d.getContainerTTY(ctx, container.ID)} } } - } if d.Config.UseContainerLabels { - parsedLabels := d.getContainerLabels(container.ID) + parsedLabels := d.getContainerLabels(ctx, container.ID) if len(parsedLabels) == 0 { d.logger.Tracef("container has no 'crowdsec' labels set, ignoring container: %s", container.ID) return nil @@ -460,13 +461,13 @@ func (d *DockerSource) EvalContainer(container dockerTypes.Container) *Container } d.logger.Errorf("label %s is not a string", k) } - return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: labels, Tty: d.getContainerTTY(container.ID)} + return &ContainerConfig{ID: container.ID, Name: container.Names[0], Labels: labels, Tty: d.getContainerTTY(ctx, container.ID)} } return nil } -func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteChan chan *ContainerConfig) error { +func (d *DockerSource) WatchContainer(ctx context.Context, monitChan chan *ContainerConfig, deleteChan chan *ContainerConfig) error { ticker := time.NewTicker(d.CheckIntervalDuration) d.logger.Infof("Container watcher started, interval: %s", d.CheckIntervalDuration.String()) for { @@ -477,7 +478,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha case <-ticker.C: // to track for garbage collection runningContainersID := make(map[string]bool) - runningContainer, err := d.Client.ContainerList(context.Background(), dockerTypes.ContainerListOptions{}) + runningContainer, err := d.Client.ContainerList(ctx, dockerTypes.ContainerListOptions{}) if err != nil { if strings.Contains(strings.ToLower(err.Error()), "cannot connect to the docker daemon at") { for idx, container := range d.runningContainerState { @@ -503,7 +504,7 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha if _, ok := d.runningContainerState[container.ID]; ok { continue } - if containerConfig := d.EvalContainer(container); containerConfig != nil { + if containerConfig := d.EvalContainer(ctx, container); containerConfig != nil { monitChan <- containerConfig } } @@ -520,16 +521,16 @@ func (d *DockerSource) WatchContainer(monitChan chan *ContainerConfig, deleteCha } } -func (d *DockerSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (d *DockerSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { d.t = t monitChan := make(chan *ContainerConfig) deleteChan := make(chan *ContainerConfig) d.logger.Infof("Starting docker acquisition") t.Go(func() error { - return d.DockerManager(monitChan, deleteChan, out) + return d.DockerManager(ctx, monitChan, deleteChan, out) }) - return d.WatchContainer(monitChan, deleteChan) + return d.WatchContainer(ctx, monitChan, deleteChan) } func (d *DockerSource) Dump() interface{} { @@ -543,9 +544,9 @@ func ReadTailScanner(scanner *bufio.Scanner, out chan string, t *tomb.Tomb) erro return scanner.Err() } -func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types.Event, deleteChan chan *ContainerConfig) error { +func (d *DockerSource) TailDocker(ctx context.Context, container *ContainerConfig, outChan chan types.Event, deleteChan chan *ContainerConfig) error { container.logger.Infof("start tail for container %s", container.Name) - dockerReader, err := d.Client.ContainerLogs(context.Background(), container.ID, *d.containerLogsOptions) + dockerReader, err := d.Client.ContainerLogs(ctx, container.ID, *d.containerLogsOptions) if err != nil { container.logger.Errorf("unable to read logs from container: %+v", err) return err @@ -581,21 +582,17 @@ func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types l.Src = container.Name l.Process = true l.Module = d.GetName() - var evt types.Event - if !d.Config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(d.Config.UseTimeMachine, types.LOG, true) + evt.Line = l linesRead.With(prometheus.Labels{"source": container.Name}).Inc() outChan <- evt d.logger.Debugf("Sent line to parsing: %+v", evt.Line.Raw) case <-readerTomb.Dying(): - //This case is to handle temporarily losing the connection to the docker socket - //The only known case currently is when using docker-socket-proxy (and maybe a docker daemon restart) + // This case is to handle temporarily losing the connection to the docker socket + // The only known case currently is when using docker-socket-proxy (and maybe a docker daemon restart) d.logger.Debugf("readerTomb dying for container %s, removing it from runningContainerState", container.Name) deleteChan <- container - //Also reset the Since to avoid re-reading logs + // Also reset the Since to avoid re-reading logs d.Config.Since = time.Now().UTC().Format(time.RFC3339) d.containerLogsOptions.Since = d.Config.Since return nil @@ -603,7 +600,7 @@ func (d *DockerSource) TailDocker(container *ContainerConfig, outChan chan types } } -func (d *DockerSource) DockerManager(in chan *ContainerConfig, deleteChan chan *ContainerConfig, outChan chan types.Event) error { +func (d *DockerSource) DockerManager(ctx context.Context, in chan *ContainerConfig, deleteChan chan *ContainerConfig, outChan chan types.Event) error { d.logger.Info("DockerSource Manager started") for { select { @@ -612,7 +609,7 @@ func (d *DockerSource) DockerManager(in chan *ContainerConfig, deleteChan chan * newContainer.t = &tomb.Tomb{} newContainer.logger = d.logger.WithField("container_name", newContainer.Name) newContainer.t.Go(func() error { - return d.TailDocker(newContainer, outChan, deleteChan) + return d.TailDocker(ctx, newContainer, outChan, deleteChan) }) d.runningContainerState[newContainer.ID] = newContainer } diff --git a/pkg/acquisition/modules/docker/docker_test.go b/pkg/acquisition/modules/docker/docker_test.go index e332569fb3a..5d8208637e8 100644 --- a/pkg/acquisition/modules/docker/docker_test.go +++ b/pkg/acquisition/modules/docker/docker_test.go @@ -120,6 +120,7 @@ type mockDockerCli struct { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() log.SetOutput(os.Stdout) log.SetLevel(log.InfoLevel) log.Info("Test 'TestStreamingAcquisition'") @@ -185,7 +186,7 @@ container_name_regexp: readerTomb := &tomb.Tomb{} streamTomb := tomb.Tomb{} streamTomb.Go(func() error { - return dockerSource.StreamingAcquisition(out, &dockerTomb) + return dockerSource.StreamingAcquisition(ctx, out, &dockerTomb) }) readerTomb.Go(func() error { time.Sleep(1 * time.Second) @@ -245,7 +246,7 @@ func (cli *mockDockerCli) ContainerLogs(ctx context.Context, container string, o for _, line := range data { startLineByte := make([]byte, 8) - binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream + binary.LittleEndian.PutUint32(startLineByte, 1) // stdout stream binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line))) ret += fmt.Sprintf("%s%s", startLineByte, line) } @@ -266,6 +267,8 @@ func (cli *mockDockerCli) ContainerInspect(ctx context.Context, c string) (docke } func TestOneShot(t *testing.T) { + ctx := context.Background() + log.Infof("Test 'TestOneShot'") tests := []struct { @@ -320,7 +323,7 @@ func TestOneShot(t *testing.T) { dockerClient.Client = new(mockDockerCli) out := make(chan types.Event, 100) tomb := tomb.Tomb{} - err := dockerClient.OneShotAcquisition(out, &tomb) + err := dockerClient.OneShotAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) // else we do the check before actualLines is incremented ... diff --git a/pkg/acquisition/modules/docker/utils.go b/pkg/acquisition/modules/docker/utils.go index c724f581194..6a0d494097f 100644 --- a/pkg/acquisition/modules/docker/utils.go +++ b/pkg/acquisition/modules/docker/utils.go @@ -22,7 +22,7 @@ func parseKeyToMap(m map[string]interface{}, key string, value string) { return } - for i := range len(parts) { + for i := range parts { if parts[i] == "" { return } diff --git a/pkg/acquisition/modules/file/file.go b/pkg/acquisition/modules/file/file.go index c36672507db..9f439b0c82e 100644 --- a/pkg/acquisition/modules/file/file.go +++ b/pkg/acquisition/modules/file/file.go @@ -3,6 +3,7 @@ package fileacquisition import ( "bufio" "compress/gzip" + "context" "errors" "fmt" "io" @@ -73,7 +74,7 @@ func (f *FileSource) UnmarshalConfig(yamlConfig []byte) error { f.logger.Tracef("FileAcquisition configuration: %+v", f.config) } - if len(f.config.Filename) != 0 { + if f.config.Filename != "" { f.config.Filenames = append(f.config.Filenames, f.config.Filename) } @@ -202,11 +203,11 @@ func (f *FileSource) ConfigureByDSN(dsn string, labels map[string]string, logger args := strings.Split(dsn, "?") - if len(args[0]) == 0 { + if args[0] == "" { return errors.New("empty file:// DSN") } - if len(args) == 2 && len(args[1]) != 0 { + if len(args) == 2 && args[1] != "" { params, err := url.ParseQuery(args[1]) if err != nil { return fmt.Errorf("could not parse file args: %w", err) @@ -279,7 +280,7 @@ func (f *FileSource) SupportedModes() []string { } // OneShotAcquisition reads a set of file and returns when done -func (f *FileSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *FileSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { f.logger.Debug("In oneshot") for _, file := range f.files { @@ -320,7 +321,7 @@ func (f *FileSource) CanRun() error { return nil } -func (f *FileSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (f *FileSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { f.logger.Debug("Starting live acquisition") t.Go(func() error { return f.monitorNewFiles(out, t) @@ -385,7 +386,6 @@ func (f *FileSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) er } filink, err := os.Lstat(file) - if err != nil { f.logger.Errorf("Could not lstat() new file %s, ignoring it : %s", file, err) continue @@ -427,118 +427,122 @@ func (f *FileSource) monitorNewFiles(out chan types.Event, t *tomb.Tomb) error { return nil } - if event.Op&fsnotify.Create == fsnotify.Create { - fi, err := os.Stat(event.Name) - if err != nil { - logger.Errorf("Could not stat() new file %s, ignoring it : %s", event.Name, err) - continue - } + if event.Op&fsnotify.Create != fsnotify.Create { + continue + } - if fi.IsDir() { - continue - } + fi, err := os.Stat(event.Name) + if err != nil { + logger.Errorf("Could not stat() new file %s, ignoring it : %s", event.Name, err) + continue + } - logger.Debugf("Detected new file %s", event.Name) + if fi.IsDir() { + continue + } - matched := false + logger.Debugf("Detected new file %s", event.Name) - for _, pattern := range f.config.Filenames { - logger.Debugf("Matching %s with %s", pattern, event.Name) + matched := false - matched, err = filepath.Match(pattern, event.Name) - if err != nil { - logger.Errorf("Could not match pattern : %s", err) - continue - } + for _, pattern := range f.config.Filenames { + logger.Debugf("Matching %s with %s", pattern, event.Name) - if matched { - logger.Debugf("Matched %s with %s", pattern, event.Name) - break - } + matched, err = filepath.Match(pattern, event.Name) + if err != nil { + logger.Errorf("Could not match pattern : %s", err) + continue } - if !matched { - continue + if matched { + logger.Debugf("Matched %s with %s", pattern, event.Name) + break } + } - // before opening the file, check if we need to specifically avoid it. (XXX) - skip := false + if !matched { + continue + } - for _, pattern := range f.exclude_regexps { - if pattern.MatchString(event.Name) { - f.logger.Infof("file %s matches exclusion pattern %s, skipping", event.Name, pattern.String()) + // before opening the file, check if we need to specifically avoid it. (XXX) + skip := false - skip = true + for _, pattern := range f.exclude_regexps { + if pattern.MatchString(event.Name) { + f.logger.Infof("file %s matches exclusion pattern %s, skipping", event.Name, pattern.String()) - break - } - } + skip = true - if skip { - continue + break } + } - f.tailMapMutex.RLock() - if f.tails[event.Name] { - f.tailMapMutex.RUnlock() - // we already have a tail on it, do not start a new one - logger.Debugf("Already tailing file %s, not creating a new tail", event.Name) + if skip { + continue + } - break - } + f.tailMapMutex.RLock() + if f.tails[event.Name] { f.tailMapMutex.RUnlock() - // cf. https://github.com/crowdsecurity/crowdsec/issues/1168 - // do not rely on stat, reclose file immediately as it's opened by Tail - fd, err := os.Open(event.Name) - if err != nil { - f.logger.Errorf("unable to read %s : %s", event.Name, err) - continue - } - if err := fd.Close(); err != nil { - f.logger.Errorf("unable to close %s : %s", event.Name, err) - continue - } + // we already have a tail on it, do not start a new one + logger.Debugf("Already tailing file %s, not creating a new tail", event.Name) - pollFile := false - if f.config.PollWithoutInotify != nil { - pollFile = *f.config.PollWithoutInotify - } else { - networkFS, fsType, err := types.IsNetworkFS(event.Name) - if err != nil { - f.logger.Warningf("Could not get fs type for %s : %s", event.Name, err) - } - f.logger.Debugf("fs for %s is network: %t (%s)", event.Name, networkFS, fsType) - if networkFS { - pollFile = true - } - } + break + } + f.tailMapMutex.RUnlock() + // cf. https://github.com/crowdsecurity/crowdsec/issues/1168 + // do not rely on stat, reclose file immediately as it's opened by Tail + fd, err := os.Open(event.Name) + if err != nil { + f.logger.Errorf("unable to read %s : %s", event.Name, err) + continue + } - filink, err := os.Lstat(event.Name) + if err = fd.Close(); err != nil { + f.logger.Errorf("unable to close %s : %s", event.Name, err) + continue + } + pollFile := false + if f.config.PollWithoutInotify != nil { + pollFile = *f.config.PollWithoutInotify + } else { + networkFS, fsType, err := types.IsNetworkFS(event.Name) if err != nil { - logger.Errorf("Could not lstat() new file %s, ignoring it : %s", event.Name, err) - continue + f.logger.Warningf("Could not get fs type for %s : %s", event.Name, err) } - if filink.Mode()&os.ModeSymlink == os.ModeSymlink && !pollFile { - logger.Warnf("File %s is a symlink, but inotify polling is enabled. Crowdsec will not be able to detect rotation. Consider setting poll_without_inotify to true in your configuration", event.Name) - } + f.logger.Debugf("fs for %s is network: %t (%s)", event.Name, networkFS, fsType) - //Slightly different parameters for Location, as we want to read the first lines of the newly created file - tail, err := tail.TailFile(event.Name, tail.Config{ReOpen: true, Follow: true, Poll: pollFile, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekStart}}) - if err != nil { - logger.Errorf("Could not start tailing file %s : %s", event.Name, err) - break + if networkFS { + pollFile = true } + } - f.tailMapMutex.Lock() - f.tails[event.Name] = true - f.tailMapMutex.Unlock() - t.Go(func() error { - defer trace.CatchPanic("crowdsec/acquis/tailfile") - return f.tailFile(out, t, tail) - }) + filink, err := os.Lstat(event.Name) + if err != nil { + logger.Errorf("Could not lstat() new file %s, ignoring it : %s", event.Name, err) + continue } + + if filink.Mode()&os.ModeSymlink == os.ModeSymlink && !pollFile { + logger.Warnf("File %s is a symlink, but inotify polling is enabled. Crowdsec will not be able to detect rotation. Consider setting poll_without_inotify to true in your configuration", event.Name) + } + + // Slightly different parameters for Location, as we want to read the first lines of the newly created file + tail, err := tail.TailFile(event.Name, tail.Config{ReOpen: true, Follow: true, Poll: pollFile, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekStart}}) + if err != nil { + logger.Errorf("Could not start tailing file %s : %s", event.Name, err) + break + } + + f.tailMapMutex.Lock() + f.tails[event.Name] = true + f.tailMapMutex.Unlock() + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/tailfile") + return f.tailFile(out, t, tail) + }) case err, ok := <-f.watcher.Errors: if !ok { return nil @@ -572,13 +576,14 @@ func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tai return nil case <-tail.Dying(): // our tailer is dying - err := tail.Err() errMsg := fmt.Sprintf("file reader of %s died", tail.Filename) + + err := tail.Err() if err != nil { errMsg = fmt.Sprintf(errMsg+" : %s", err) } - logger.Warningf(errMsg) + logger.Warning(errMsg) return nil case line := <-tail.Lines: @@ -616,11 +621,9 @@ func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tai // we're tailing, it must be real time logs logger.Debugf("pushing %+v", l) - expectMode := types.LIVE - if f.config.UseTimeMachine { - expectMode = types.TIMEMACHINE - } - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: expectMode} + evt := types.MakeEvent(f.config.UseTimeMachine, types.LOG, true) + evt.Line = l + out <- evt } } } @@ -629,8 +632,8 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom var scanner *bufio.Scanner logger := f.logger.WithField("oneshot", filename) - fd, err := os.Open(filename) + fd, err := os.Open(filename) if err != nil { return fmt.Errorf("failed opening %s: %w", filename, err) } @@ -679,7 +682,7 @@ func (f *FileSource) readFile(filename string, out chan types.Event, t *tomb.Tom linesRead.With(prometheus.Labels{"source": filename}).Inc() // we're reading logs at once, it must be time-machine buckets - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE, Unmarshaled: make(map[string]interface{})} } } diff --git a/pkg/acquisition/modules/file/file_test.go b/pkg/acquisition/modules/file/file_test.go index 5d38552b3c5..a26e44cc9c7 100644 --- a/pkg/acquisition/modules/file/file_test.go +++ b/pkg/acquisition/modules/file/file_test.go @@ -1,6 +1,7 @@ package fileacquisition_test import ( + "context" "fmt" "os" "runtime" @@ -100,6 +101,8 @@ func TestConfigureDSN(t *testing.T) { } func TestOneShot(t *testing.T) { + ctx := context.Background() + permDeniedFile := "/etc/shadow" permDeniedError := "failed opening /etc/shadow: open /etc/shadow: permission denied" @@ -223,7 +226,7 @@ filename: test_files/test_delete.log`, if tc.afterConfigure != nil { tc.afterConfigure() } - err = f.OneShotAcquisition(out, &tomb) + err = f.OneShotAcquisition(ctx, out, &tomb) actualLines := len(out) cstest.RequireErrorContains(t, err, tc.expectedErr) @@ -243,6 +246,7 @@ filename: test_files/test_delete.log`, } func TestLiveAcquisition(t *testing.T) { + ctx := context.Background() permDeniedFile := "/etc/shadow" permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied" testPattern := "test_files/*.log" @@ -394,7 +398,7 @@ force_inotify: true`, testPattern), }() } - err = f.StreamingAcquisition(out, &tomb) + err = f.StreamingAcquisition(ctx, out, &tomb) cstest.RequireErrorContains(t, err, tc.expectedErr) if tc.expectedLines != 0 { diff --git a/pkg/acquisition/modules/http/http.go b/pkg/acquisition/modules/http/http.go new file mode 100644 index 00000000000..98af134c84e --- /dev/null +++ b/pkg/acquisition/modules/http/http.go @@ -0,0 +1,416 @@ +package httpacquisition + +import ( + "compress/gzip" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "time" + + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + + "gopkg.in/tomb.v2" + "gopkg.in/yaml.v3" + + "github.com/crowdsecurity/go-cs-lib/trace" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +var ( + dataSourceName = "http" +) + +var linesRead = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_httpsource_hits_total", + Help: "Total lines that were read from http source", + }, + []string{"path", "src"}) + +type HttpConfiguration struct { + //IPFilter []string `yaml:"ip_filter"` + //ChunkSize *int64 `yaml:"chunk_size"` + ListenAddr string `yaml:"listen_addr"` + Path string `yaml:"path"` + AuthType string `yaml:"auth_type"` + BasicAuth *BasicAuthConfig `yaml:"basic_auth"` + Headers *map[string]string `yaml:"headers"` + TLS *TLSConfig `yaml:"tls"` + CustomStatusCode *int `yaml:"custom_status_code"` + CustomHeaders *map[string]string `yaml:"custom_headers"` + MaxBodySize *int64 `yaml:"max_body_size"` + Timeout *time.Duration `yaml:"timeout"` + configuration.DataSourceCommonCfg `yaml:",inline"` +} + +type BasicAuthConfig struct { + Username string `yaml:"username"` + Password string `yaml:"password"` +} + +type TLSConfig struct { + InsecureSkipVerify bool `yaml:"insecure_skip_verify"` + ServerCert string `yaml:"server_cert"` + ServerKey string `yaml:"server_key"` + CaCert string `yaml:"ca_cert"` +} + +type HTTPSource struct { + metricsLevel int + Config HttpConfiguration + logger *log.Entry + Server *http.Server +} + +func (h *HTTPSource) GetUuid() string { + return h.Config.UniqueId +} + +func (h *HTTPSource) UnmarshalConfig(yamlConfig []byte) error { + h.Config = HttpConfiguration{} + err := yaml.Unmarshal(yamlConfig, &h.Config) + if err != nil { + return fmt.Errorf("cannot parse %s datasource configuration: %w", dataSourceName, err) + } + + if h.Config.Mode == "" { + h.Config.Mode = configuration.TAIL_MODE + } + + return nil +} + +func (hc *HttpConfiguration) Validate() error { + if hc.ListenAddr == "" { + return errors.New("listen_addr is required") + } + + if hc.Path == "" { + hc.Path = "/" + } + if hc.Path[0] != '/' { + return errors.New("path must start with /") + } + + switch hc.AuthType { + case "basic_auth": + baseErr := "basic_auth is selected, but" + if hc.BasicAuth == nil { + return errors.New(baseErr + " basic_auth is not provided") + } + if hc.BasicAuth.Username == "" { + return errors.New(baseErr + " username is not provided") + } + if hc.BasicAuth.Password == "" { + return errors.New(baseErr + " password is not provided") + } + case "headers": + if hc.Headers == nil { + return errors.New("headers is selected, but headers is not provided") + } + case "mtls": + if hc.TLS == nil || hc.TLS.CaCert == "" { + return errors.New("mtls is selected, but ca_cert is not provided") + } + default: + return errors.New("invalid auth_type: must be one of basic_auth, headers, mtls") + } + + if hc.TLS != nil { + if hc.TLS.ServerCert == "" { + return errors.New("server_cert is required") + } + if hc.TLS.ServerKey == "" { + return errors.New("server_key is required") + } + } + + if hc.MaxBodySize != nil && *hc.MaxBodySize <= 0 { + return errors.New("max_body_size must be positive") + } + + /* + if hc.ChunkSize != nil && *hc.ChunkSize <= 0 { + return errors.New("chunk_size must be positive") + } + */ + + if hc.CustomStatusCode != nil { + statusText := http.StatusText(*hc.CustomStatusCode) + if statusText == "" { + return errors.New("invalid HTTP status code") + } + } + + return nil +} + +func (h *HTTPSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLevel int) error { + h.logger = logger + h.metricsLevel = MetricsLevel + err := h.UnmarshalConfig(yamlConfig) + if err != nil { + return err + } + + if err := h.Config.Validate(); err != nil { + return fmt.Errorf("invalid configuration: %w", err) + } + + return nil +} + +func (h *HTTPSource) ConfigureByDSN(string, map[string]string, *log.Entry, string) error { + return fmt.Errorf("%s datasource does not support command-line acquisition", dataSourceName) +} + +func (h *HTTPSource) GetMode() string { + return h.Config.Mode +} + +func (h *HTTPSource) GetName() string { + return dataSourceName +} + +func (h *HTTPSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + return fmt.Errorf("%s datasource does not support one-shot acquisition", dataSourceName) +} + +func (h *HTTPSource) CanRun() error { + return nil +} + +func (h *HTTPSource) GetMetrics() []prometheus.Collector { + return []prometheus.Collector{linesRead} +} + +func (h *HTTPSource) GetAggregMetrics() []prometheus.Collector { + return []prometheus.Collector{linesRead} +} + +func (h *HTTPSource) Dump() interface{} { + return h +} + +func (hc *HttpConfiguration) NewTLSConfig() (*tls.Config, error) { + tlsConfig := tls.Config{ + InsecureSkipVerify: hc.TLS.InsecureSkipVerify, + } + + if hc.TLS.ServerCert != "" && hc.TLS.ServerKey != "" { + cert, err := tls.LoadX509KeyPair(hc.TLS.ServerCert, hc.TLS.ServerKey) + if err != nil { + return nil, fmt.Errorf("failed to load server cert/key: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + if hc.AuthType == "mtls" && hc.TLS.CaCert != "" { + caCert, err := os.ReadFile(hc.TLS.CaCert) + if err != nil { + return nil, fmt.Errorf("failed to read ca cert: %w", err) + } + + caCertPool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("failed to load system cert pool: %w", err) + } + + if caCertPool == nil { + caCertPool = x509.NewCertPool() + } + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig.ClientCAs = caCertPool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + return &tlsConfig, nil +} + +func authorizeRequest(r *http.Request, hc *HttpConfiguration) error { + if hc.AuthType == "basic_auth" { + username, password, ok := r.BasicAuth() + if !ok { + return errors.New("missing basic auth") + } + if username != hc.BasicAuth.Username || password != hc.BasicAuth.Password { + return errors.New("invalid basic auth") + } + } + if hc.AuthType == "headers" { + for key, value := range *hc.Headers { + if r.Header.Get(key) != value { + return errors.New("invalid headers") + } + } + } + return nil +} + +func (h *HTTPSource) processRequest(w http.ResponseWriter, r *http.Request, hc *HttpConfiguration, out chan types.Event) error { + if hc.MaxBodySize != nil && r.ContentLength > *hc.MaxBodySize { + w.WriteHeader(http.StatusRequestEntityTooLarge) + return fmt.Errorf("body size exceeds max body size: %d > %d", r.ContentLength, *hc.MaxBodySize) + } + + srcHost, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return err + } + + defer r.Body.Close() + + reader := r.Body + + if r.Header.Get("Content-Encoding") == "gzip" { + reader, err = gzip.NewReader(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer reader.Close() + } + + decoder := json.NewDecoder(reader) + for { + var message json.RawMessage + + if err := decoder.Decode(&message); err != nil { + if err == io.EOF { + break + } + w.WriteHeader(http.StatusBadRequest) + return fmt.Errorf("failed to decode: %w", err) + } + + line := types.Line{ + Raw: string(message), + Src: srcHost, + Time: time.Now().UTC(), + Labels: hc.Labels, + Process: true, + Module: h.GetName(), + } + + if h.metricsLevel == configuration.METRICS_AGGREGATE { + line.Src = hc.Path + } + + evt := types.MakeEvent(h.Config.UseTimeMachine, types.LOG, true) + evt.Line = line + + if h.metricsLevel == configuration.METRICS_AGGREGATE { + linesRead.With(prometheus.Labels{"path": hc.Path, "src": ""}).Inc() + } else if h.metricsLevel == configuration.METRICS_FULL { + linesRead.With(prometheus.Labels{"path": hc.Path, "src": srcHost}).Inc() + } + + h.logger.Tracef("line to send: %+v", line) + out <- evt + } + + return nil +} + +func (h *HTTPSource) RunServer(out chan types.Event, t *tomb.Tomb) error { + mux := http.NewServeMux() + mux.HandleFunc(h.Config.Path, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + h.logger.Errorf("method not allowed: %s", r.Method) + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + if err := authorizeRequest(r, &h.Config); err != nil { + h.logger.Errorf("failed to authorize request from '%s': %s", r.RemoteAddr, err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + err := h.processRequest(w, r, &h.Config, out) + if err != nil { + h.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err) + return + } + + if h.Config.CustomHeaders != nil { + for key, value := range *h.Config.CustomHeaders { + w.Header().Set(key, value) + } + } + if h.Config.CustomStatusCode != nil { + w.WriteHeader(*h.Config.CustomStatusCode) + } else { + w.WriteHeader(http.StatusOK) + } + + w.Write([]byte("OK")) + }) + + h.Server = &http.Server{ + Addr: h.Config.ListenAddr, + Handler: mux, + } + + if h.Config.Timeout != nil { + h.Server.ReadTimeout = *h.Config.Timeout + } + + if h.Config.TLS != nil { + tlsConfig, err := h.Config.NewTLSConfig() + if err != nil { + return fmt.Errorf("failed to create tls config: %w", err) + } + h.logger.Tracef("tls config: %+v", tlsConfig) + h.Server.TLSConfig = tlsConfig + } + + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/http/server") + if h.Config.TLS != nil { + h.logger.Infof("start https server on %s", h.Config.ListenAddr) + err := h.Server.ListenAndServeTLS(h.Config.TLS.ServerCert, h.Config.TLS.ServerKey) + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("https server failed: %w", err) + } + } else { + h.logger.Infof("start http server on %s", h.Config.ListenAddr) + err := h.Server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("http server failed: %w", err) + } + } + return nil + }) + + //nolint //fp + for { + select { + case <-t.Dying(): + h.logger.Infof("%s datasource stopping", dataSourceName) + if err := h.Server.Close(); err != nil { + return fmt.Errorf("while closing %s server: %w", dataSourceName, err) + } + return nil + } + } +} + +func (h *HTTPSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + h.logger.Debugf("start http server on %s", h.Config.ListenAddr) + + t.Go(func() error { + defer trace.CatchPanic("crowdsec/acquis/http/live") + return h.RunServer(out, t) + }) + + return nil +} diff --git a/pkg/acquisition/modules/http/http_test.go b/pkg/acquisition/modules/http/http_test.go new file mode 100644 index 00000000000..f89ba7aa8ba --- /dev/null +++ b/pkg/acquisition/modules/http/http_test.go @@ -0,0 +1,785 @@ +package httpacquisition + +import ( + "compress/gzip" + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "gopkg.in/tomb.v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testHTTPServerAddr = "http://127.0.0.1:8080" + testHTTPServerAddrTLS = "https://127.0.0.1:8080" +) + +func TestConfigure(t *testing.T) { + tests := []struct { + config string + expectedErr string + }{ + { + config: ` +foobar: bla`, + expectedErr: "invalid configuration: listen_addr is required", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: wrongpath`, + expectedErr: "invalid configuration: path must start with /", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth`, + expectedErr: "invalid configuration: basic_auth is selected, but basic_auth is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers`, + expectedErr: "invalid configuration: headers is selected, but headers is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + username: 132`, + expectedErr: "invalid configuration: basic_auth is selected, but password is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + password: 132`, + expectedErr: "invalid configuration: basic_auth is selected, but username is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers:`, + expectedErr: "invalid configuration: headers is selected, but headers is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: toto`, + expectedErr: "invalid configuration: invalid auth_type: must be one of basic_auth, headers, mtls", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +tls: + server_key: key`, + expectedErr: "invalid configuration: server_cert is required", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +tls: + server_cert: cert`, + expectedErr: "invalid configuration: server_key is required", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: mtls +tls: + server_cert: cert + server_key: key`, + expectedErr: "invalid configuration: mtls is selected, but ca_cert is not provided", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +max_body_size: 0`, + expectedErr: "invalid configuration: max_body_size must be positive", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +timeout: toto`, + expectedErr: "cannot parse http datasource configuration: yaml: unmarshal errors:\n line 8: cannot unmarshal !!str `toto` into time.Duration", + }, + { + config: ` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: value +custom_status_code: 999`, + expectedErr: "invalid configuration: invalid HTTP status code", + }, + } + + subLogger := log.WithFields(log.Fields{ + "type": "http", + }) + + for _, test := range tests { + h := HTTPSource{} + err := h.Configure([]byte(test.config), subLogger, 0) + cstest.AssertErrorContains(t, err, test.expectedErr) + } +} + +func TestGetUuid(t *testing.T) { + h := HTTPSource{} + h.Config.UniqueId = "test" + assert.Equal(t, "test", h.GetUuid()) +} + +func TestUnmarshalConfig(t *testing.T) { + h := HTTPSource{} + err := h.UnmarshalConfig([]byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: 15 + auth_type: headers`)) + cstest.AssertErrorMessage(t, err, "cannot parse http datasource configuration: yaml: line 4: found a tab character that violates indentation") +} + +func TestConfigureByDSN(t *testing.T) { + h := HTTPSource{} + err := h.ConfigureByDSN("http://localhost:8080/test", map[string]string{}, log.WithFields(log.Fields{ + "type": "http", + }), "test") + cstest.AssertErrorMessage( + t, + err, + "http datasource does not support command-line acquisition", + ) +} + +func TestGetMode(t *testing.T) { + h := HTTPSource{} + h.Config.Mode = "test" + assert.Equal(t, "test", h.GetMode()) +} + +func TestGetName(t *testing.T) { + h := HTTPSource{} + assert.Equal(t, "http", h.GetName()) +} + +func SetupAndRunHTTPSource(t *testing.T, h *HTTPSource, config []byte, metricLevel int) (chan types.Event, *tomb.Tomb) { + ctx := context.Background() + subLogger := log.WithFields(log.Fields{ + "type": "http", + }) + err := h.Configure(config, subLogger, metricLevel) + require.NoError(t, err) + tomb := tomb.Tomb{} + out := make(chan types.Event) + err = h.StreamingAcquisition(ctx, out, &tomb) + require.NoError(t, err) + + for _, metric := range h.GetMetrics() { + prometheus.Register(metric) + } + + return out, &tomb +} + +func TestStreamingAcquisitionWrongHTTPMethod(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + username: test + password: test`), 0) + + time.Sleep(1 * time.Second) + + res, err := http.Get(fmt.Sprintf("%s/test", testHTTPServerAddr)) + require.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, res.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() + +} + +func TestStreamingAcquisitionUnknownPath(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + username: test + password: test`), 0) + + time.Sleep(1 * time.Second) + + res, err := http.Get(fmt.Sprintf("%s/unknown", testHTTPServerAddr)) + require.NoError(t, err) + assert.Equal(t, http.StatusNotFound, res.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionBasicAuth(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: basic_auth +basic_auth: + username: test + password: test`), 0) + + time.Sleep(1 * time.Second) + + client := &http.Client{} + + resp, err := http.Post(fmt.Sprintf("%s/test", testHTTPServerAddr), "application/json", strings.NewReader("test")) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test")) + require.NoError(t, err) + req.SetBasicAuth("test", "WrongPassword") + + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionBadHeaders(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test`), 0) + + time.Sleep(1 * time.Second) + + client := &http.Client{} + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("test")) + require.NoError(t, err) + + req.Header.Add("Key", "wrong") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionMaxBodySize(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test +max_body_size: 5`), 0) + + time.Sleep(1 * time.Second) + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader("testtest")) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionSuccess(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test`), 2) + + time.Sleep(1 * time.Second) + rawEvt := `{"test": "test"}` + + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt)) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 1) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test +custom_status_code: 201 +custom_headers: + success: true`), 2) + + time.Sleep(1 * time.Second) + + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(rawEvt)) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusCreated, resp.StatusCode) + assert.Equal(t, "true", resp.Header.Get("Success")) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 1) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +type slowReader struct { + delay time.Duration + body []byte + index int +} + +func (sr *slowReader) Read(p []byte) (int, error) { + if sr.index >= len(sr.body) { + return 0, io.EOF + } + time.Sleep(sr.delay) // Simulate a delay in reading + n := copy(p, sr.body[sr.index:]) + sr.index += n + return n, nil +} + +func assertEvents(out chan types.Event, expected []string, errChan chan error) { + readLines := []types.Event{} + + for i := 0; i < len(expected); i++ { + select { + case event := <-out: + readLines = append(readLines, event) + case <-time.After(2 * time.Second): + errChan <- errors.New("timeout waiting for event") + return + } + } + + if len(readLines) != len(expected) { + errChan <- fmt.Errorf("expected %d lines, got %d", len(expected), len(readLines)) + return + } + + for i, evt := range readLines { + if evt.Line.Raw != expected[i] { + errChan <- fmt.Errorf(`expected %s, got '%+v'`, expected, evt.Line.Raw) + return + } + if evt.Line.Src != "127.0.0.1" { + errChan <- fmt.Errorf("expected '127.0.0.1', got '%s'", evt.Line.Src) + return + } + if evt.Line.Module != "http" { + errChan <- fmt.Errorf("expected 'http', got '%s'", evt.Line.Module) + return + } + } + errChan <- nil +} + +func TestStreamingAcquisitionTimeout(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test +timeout: 1s`), 0) + + time.Sleep(1 * time.Second) + + slow := &slowReader{ + delay: 2 * time.Second, + body: []byte(`{"test": "delayed_payload"}`), + } + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), slow) + require.NoError(t, err) + + req.Header.Add("Key", "test") + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionTLSHTTPRequest(t *testing.T) { + h := &HTTPSource{} + _, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +auth_type: mtls +path: /test +tls: + server_cert: testdata/server.crt + server_key: testdata/server.key + ca_cert: testdata/ca.crt`), 0) + + time.Sleep(1 * time.Second) + + resp, err := http.Post(fmt.Sprintf("%s/test", testHTTPServerAddr), "application/json", strings.NewReader("test")) + require.NoError(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionTLSWithHeadersAuthSuccess(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test +tls: + server_cert: testdata/server.crt + server_key: testdata/server.key +`), 0) + + time.Sleep(1 * time.Second) + + caCert, err := os.ReadFile("testdata/server.crt") + require.NoError(t, err) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig := &tls.Config{ + RootCAs: caCertPool, + } + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt)) + require.NoError(t, err) + + req.Header.Add("Key", "test") + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 0) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionMTLS(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: mtls +tls: + server_cert: testdata/server.crt + server_key: testdata/server.key + ca_cert: testdata/ca.crt`), 0) + + time.Sleep(1 * time.Second) + + // init client cert + cert, err := tls.LoadX509KeyPair("testdata/client.crt", "testdata/client.key") + require.NoError(t, err) + + caCert, err := os.ReadFile("testdata/ca.crt") + require.NoError(t, err) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + } + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt}, errChan) + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddrTLS), strings.NewReader(rawEvt)) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 0) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionGzipData(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test`), 2) + + time.Sleep(1 * time.Second) + + rawEvt := `{"test": "test"}` + errChan := make(chan error) + go assertEvents(out, []string{rawEvt, rawEvt}, errChan) + + var b strings.Builder + gz := gzip.NewWriter(&b) + + _, err := gz.Write([]byte(rawEvt)) + require.NoError(t, err) + + _, err = gz.Write([]byte(rawEvt)) + require.NoError(t, err) + + err = gz.Close() + require.NoError(t, err) + + // send gzipped compressed data + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(b.String())) + require.NoError(t, err) + + req.Header.Add("Key", "test") + req.Header.Add("Content-Encoding", "gzip") + req.Header.Add("Content-Type", "application/json") + + resp, err := client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 2) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func TestStreamingAcquisitionNDJson(t *testing.T) { + h := &HTTPSource{} + out, tomb := SetupAndRunHTTPSource(t, h, []byte(` +source: http +listen_addr: 127.0.0.1:8080 +path: /test +auth_type: headers +headers: + key: test`), 2) + + time.Sleep(1 * time.Second) + rawEvt := `{"test": "test"}` + + errChan := make(chan error) + go assertEvents(out, []string{rawEvt, rawEvt}, errChan) + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/test", testHTTPServerAddr), strings.NewReader(fmt.Sprintf("%s\n%s\n", rawEvt, rawEvt))) + + require.NoError(t, err) + + req.Header.Add("Key", "test") + req.Header.Add("Content-Type", "application/x-ndjson") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + err = <-errChan + require.NoError(t, err) + + assertMetrics(t, h.GetMetrics(), 2) + + h.Server.Close() + tomb.Kill(nil) + tomb.Wait() +} + +func assertMetrics(t *testing.T, metrics []prometheus.Collector, expected int) { + promMetrics, err := prometheus.DefaultGatherer.Gather() + require.NoError(t, err) + + isExist := false + for _, metricFamily := range promMetrics { + if metricFamily.GetName() == "cs_httpsource_hits_total" { + isExist = true + assert.Len(t, metricFamily.GetMetric(), 1) + for _, metric := range metricFamily.GetMetric() { + assert.InDelta(t, float64(expected), metric.GetCounter().GetValue(), 0.000001) + labels := metric.GetLabel() + assert.Len(t, labels, 2) + assert.Equal(t, "path", labels[0].GetName()) + assert.Equal(t, "/test", labels[0].GetValue()) + assert.Equal(t, "src", labels[1].GetName()) + assert.Equal(t, "127.0.0.1", labels[1].GetValue()) + } + } + } + if !isExist && expected > 0 { + t.Fatalf("expected metric cs_httpsource_hits_total not found") + } + + for _, metric := range metrics { + metric.(*prometheus.CounterVec).Reset() + } +} diff --git a/pkg/acquisition/modules/http/testdata/ca.crt b/pkg/acquisition/modules/http/testdata/ca.crt new file mode 100644 index 00000000000..ac81b9db8a6 --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/ca.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIIDvzCCAqegAwIBAgIUHQfsFpWkCy7gAmDa3A6O+y5CvAswDQYJKoZIhvcNAQEL +BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV +BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx +EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDAxMDBaFw0yOTEwMjIxMDAx +MDBaMG8xCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD +VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxETAPBgNVBAsTCENyb3dkc2Vj +MRIwEAYDVQQDEwlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQCZSR2/A24bpVHSiEeSlelfdA32uhk9wHkauwy2qxos/G/UmKG/dgWrHzRh +LawlFVHtVn4u7Hjqz2y2EsH3bX42jC5NMVARgXIOBr1dE6F5/bPqA6SoVgkDm9wh +ZBigyAMxYsR4+3ahuf0pQflBShKrLZ1UYoe6tQXob7l3x5vThEhNkBawBkLfWpj7 +7Imm1tGyEZdxCMkT400KRtSmJRrnpiOCUosnacwgp7MCbKWOIOng07Eh16cVUiuI +BthWU/LycIuac2xaD9PFpeK/MpwASRRPXZgPUhiZuaa7vttD0phCdDaS46Oln5/7 +tFRZH0alBZmkpVZJCWAP4ujIA3vLAgMBAAGjUzBRMA4GA1UdDwEB/wQEAwIBBjAP +BgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBTwpg+WN1nZJs4gj5hfoa+fMSZjGTAP +BgNVHREECDAGhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQAZuOWT8zHcwbWvC6Jm +/ccgB/U7SbeIYFJrCZd9mTyqsgnkFNH8yJ5F4dXXtPXr+SO/uWWa3G5hams3qVFf +zWzzPDQdyhUhfh5fjUHR2RsSGBmCxcapYHpVvAP5aY1/ujYrXMvAJV0hfDO2tGHb +rveuJxhe8ymQ1Yb2u9NcmI1HG9IVt3Airz4gAIUJWbFvRigky0bukfddOkfiUiaF +DMPJQO6HAj8d8ctSHHVZWzhAInZ1pDg6HIHYF44m1tT27pSQoi0ZFocskDi/fC2f +EIF0nu5fRLUS6BZEfpnDi9U0lbJ/kUrgT5IFHMFqXdRpDqcnXpJZhYtp5l6GoqjM +gT33 +-----END CERTIFICATE----- diff --git a/pkg/acquisition/modules/http/testdata/client.crt b/pkg/acquisition/modules/http/testdata/client.crt new file mode 100644 index 00000000000..55efdddad09 --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/client.crt @@ -0,0 +1,24 @@ +-----BEGIN CERTIFICATE----- +MIID7jCCAtagAwIBAgIUJMTPh3oPJLPgsnb9T85ieb4EuOQwDQYJKoZIhvcNAQEL +BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV +BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx +EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDQ2MDBaFw0yNTEwMjMxMDQ2 +MDBaMHIxCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD +VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxFzAVBgNVBAsTDlRlc3Rpbmcg +Y2xpZW50MQ8wDQYDVQQDEwZjbGllbnQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQDAUOdpRieRrrH6krUjgcjLgJg6TzoWAb/iv6rfcioX1L9bj9fZSkwu +GqKzXX/PceIXElzQgiGJZErbJtnTzhGS80QgtAB8BwWQIT2zgoGcYJf7pPFvmcMM +qMGFwK0dMC+LHPk+ePtFz8dskI2XJ8jgBdtuZcnDblMuVGtjYT6n0rszvRdo118+ +mlGCLPzOfsO1JdOqLWAR88yZfqCFt1TrwmzpRT1crJQeM6i7muw4aO0L7uSek9QM +6APHz0QexSq7/zHOtRjA4jnJbDzZJHRlwOdlsNU9cmTz6uWIQXlg+2ovD55YurNy ++jYfmfDYpimhoeGf54zaETp1fTuTJYpxAgMBAAGjfzB9MA4GA1UdDwEB/wQEAwIF +oDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAd +BgNVHQ4EFgQUmH0/7RuKnoW7sEK4Cr8eVNGbb8swHwYDVR0jBBgwFoAU8KYPljdZ +2SbOII+YX6GvnzEmYxkwDQYJKoZIhvcNAQELBQADggEBAHVn9Zuoyxu9iTFoyJ50 +e/XKcmt2uK2M1x+ap2Av7Wb/Omikx/R2YPq7994BfiUCAezY2YtreZzkE6Io1wNM +qApijEJnlqEmOXiYJqlF89QrCcsAsz6lfaqitYBZSL3o4KT+7/uUDVxgNEjEksRz +9qy6DFBLvyhxbOM2zDEV+MVfemBWSvNiojHqXzDBkZnBHHclJLuIKsXDZDGhKbNd +hsoGU00RLevvcUpUJ3a68ekgwiYFJifm0uyfmao9lmiB3i+8ZW3Q4rbwHtD+U7U2 +3n+U5PkhiUAveuMfrvUMzsTolZiop9ZLtcALDUFaqyr4tjfVOf5+CGjiluio7oE1 +UYg= +-----END CERTIFICATE----- diff --git a/pkg/acquisition/modules/http/testdata/client.key b/pkg/acquisition/modules/http/testdata/client.key new file mode 100644 index 00000000000..f8ef2efbd58 --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/client.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAwFDnaUYnka6x+pK1I4HIy4CYOk86FgG/4r+q33IqF9S/W4/X +2UpMLhqis11/z3HiFxJc0IIhiWRK2ybZ084RkvNEILQAfAcFkCE9s4KBnGCX+6Tx +b5nDDKjBhcCtHTAvixz5Pnj7Rc/HbJCNlyfI4AXbbmXJw25TLlRrY2E+p9K7M70X +aNdfPppRgiz8zn7DtSXTqi1gEfPMmX6ghbdU68Js6UU9XKyUHjOou5rsOGjtC+7k +npPUDOgDx89EHsUqu/8xzrUYwOI5yWw82SR0ZcDnZbDVPXJk8+rliEF5YPtqLw+e +WLqzcvo2H5nw2KYpoaHhn+eM2hE6dX07kyWKcQIDAQABAoIBAQChriKuza0MfBri +9x3UCRN/is/wDZVe1P+2KL8F9ZvPxytNVeP4qM7c38WzF8MQ6sRR8z0WiqCZOjj4 +f3QX7iG2MlAvUkUqAFk778ZIuUov5sE/bU8RLOrfJKz1vqOLa2w8/xHH5LwS1/jn +m6t9zZPCSwpMiMSUSZci1xQlS6b6POZMjeqLPqv9cP8PJNv9UNrHFcQnQi1iwKJH +MJ7CQI3R8FSeGad3P7tB9YDaBm7hHmd/TevuFkymcKNT44XBSgddPDfgKui6sHTY +QQWgWI9VGVO350ZBLRLkrk8wboY4vc15qbBzYFG66WiR/tNdLt3rDYxpcXaDvcQy +e47mYNVxAoGBAMFsUmPDssqzmOkmZxHDM+VmgHYPXjDqQdE299FtuClobUW4iU4g +By7o84aCIBQz2sp9f1KM+10lr+Bqw3s7QBbR5M67PA8Zm45DL9t70NR/NZHGzFRD +BR/NMbwzCqNtY2UGDhYQLGhW8heAwsYwir8ZqmOfKTd9aY1pu/S8m9AlAoGBAP6I +483EIN8R5y+beGcGynYeIrH5Gc+W2FxWIW9jh/G7vRbhMlW4z0GxV3uEAYmOlBH2 +AqUkV6+uzU0P4B/m3vCYqLycBVDwifJazDj9nskVL5kGMxia62iwDMXs5nqNS4WJ +ZM5Gl2xIiwmgWnYnujM3eKF2wbm439wj4na80SldAoGANdIqatA9o+GtntKsw2iJ +vD91Z2SHVR0aC1k8Q+4/3GXOYiQjMLYAybDQcpEq0/RJ4SZik1nfZ9/gvJV4p4Wp +I7Br9opq/9ikTEWtv2kIhtiO02151ciAWIUEXdXmE+uQSMASk1kUwkPPQXL2v6cq +NFqz6tyS33nqMQtG3abNxHECgYA4AEA2nmcpDRRTSh50dG8JC9pQU+EU5jhWIHEc +w8Y+LjMNHKDpcU7QQkdgGowICsGTLhAo61ULhycORGboPfBg+QVu8djNlQ6Urttt +0ocj8LBXN6D4UeVnVAyLY3LWFc4+5Bq0s51PKqrEhG5Cvrzd1d+JjspSpVVDZvXF +cAeI1QKBgC/cMN3+2Sc+2biu46DnkdYpdF/N0VGMOgzz+unSVD4RA2mEJ9UdwGga +feshtrtcroHtEmc+WDYgTTnAq1MbsVFQYIwZ5fL/GJ1R8ccaWiPuX2HrKALKG4Y3 +CMFpDUWhRgtaBsmuOpUq3FeS5cyPNMHk6axL1KyFoJk9AgfhqhTp +-----END RSA PRIVATE KEY----- diff --git a/pkg/acquisition/modules/http/testdata/server.crt b/pkg/acquisition/modules/http/testdata/server.crt new file mode 100644 index 00000000000..7a02c606c9d --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/server.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID5jCCAs6gAwIBAgIUU3F6URi0oTe9ontkf7JqXOo89QYwDQYJKoZIhvcNAQEL +BQAwbzELMAkGA1UEBhMCRlIxFjAUBgNVBAgTDUlsZS1kZS1GcmFuY2UxDjAMBgNV +BAcTBVBhcmlzMREwDwYDVQQKEwhDcm93ZHNlYzERMA8GA1UECxMIQ3Jvd2RzZWMx +EjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0yNDEwMjMxMDAzMDBaFw0yNTEwMjMxMDAz +MDBaMG8xCzAJBgNVBAYTAkZSMRYwFAYDVQQIEw1JbGUtZGUtRnJhbmNlMQ4wDAYD +VQQHEwVQYXJpczERMA8GA1UEChMIQ3Jvd2RzZWMxETAPBgNVBAsTCENyb3dkc2Vj +MRIwEAYDVQQDEwlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK +AoIBAQC/lnUubjBGe5x0LgIE5GeG52LRzj99iLWuvey4qbSwFZ07ECgv+JttVwDm +AjEeakj2ZR46WHvHAR9eBNkRCORyWX0iKVIzm09PXYi80KtwGLaA8YMEio9/08Cc ++LS0TuP0yiOcw+btrhmvvauDzcQhA6u55q8anCZiF2BlHfX9Sh6QKewA3NhOkzbU +VTxqrOqfcRsGNub7dheqfP5bfrPkF6Y6l/0Fhyx0NMsu1zaQ0hCls2hkTf0Y3XGt +IQNWoN22seexR3qRmPf0j3jBa0qOmGgd6kAd+YpsjDblgCNUIJZiVj51fVb0sGRx +ShkfKGU6t0eznTWPCqswujO/sn+pAgMBAAGjejB4MA4GA1UdDwEB/wQEAwIFoDAd +BgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNV +HQ4EFgQUOiIF+7Wzx1J8Ki3DiBfx+E6zlSUwGgYDVR0RBBMwEYIJbG9jYWxob3N0 +hwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQA0dzlhBr/0wXPyj/iWxMOXxZ1FNJ9f +lxBMhLAgX0WrT2ys+284J7Hcn0lJeqelluYpmeKn9vmCAEj3MmUmHzZyf//lhuUJ +0DlYWIHUsGaJHJ7A+1hQqrcXHhkcRy5WGIM9VoddKbBbg2b6qzTSvxn8EnuD7H4h +28wLyGLCzsSXoVcAB8u+svYt29TPuy6xmMAokyIShV8FsE77fjVTgtCuxmx1PKv3 +zd6+uEae7bbZ+GJH1zKF0vokejQvmByt+YuIXlNbMseaMUeDdpy+6qlRvbbN1dyp +rkQXfWvidMfSue5nH/akAn83v/CdKxG6tfW83d9Rud3naabUkywALDng +-----END CERTIFICATE----- diff --git a/pkg/acquisition/modules/http/testdata/server.key b/pkg/acquisition/modules/http/testdata/server.key new file mode 100644 index 00000000000..4d0ee53b4c2 --- /dev/null +++ b/pkg/acquisition/modules/http/testdata/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAv5Z1Lm4wRnucdC4CBORnhudi0c4/fYi1rr3suKm0sBWdOxAo +L/ibbVcA5gIxHmpI9mUeOlh7xwEfXgTZEQjkcll9IilSM5tPT12IvNCrcBi2gPGD +BIqPf9PAnPi0tE7j9MojnMPm7a4Zr72rg83EIQOrueavGpwmYhdgZR31/UoekCns +ANzYTpM21FU8aqzqn3EbBjbm+3YXqnz+W36z5BemOpf9BYcsdDTLLtc2kNIQpbNo +ZE39GN1xrSEDVqDdtrHnsUd6kZj39I94wWtKjphoHepAHfmKbIw25YAjVCCWYlY+ +dX1W9LBkcUoZHyhlOrdHs501jwqrMLozv7J/qQIDAQABAoIBAF1Vd/rJlV0Q5RQ4 +QaWOe9zdpmedeZK3YgMh5UvE6RCLRxC5+0n7bASlSPvEf5dYofjfJA26g3pcUqKj +6/d/hIMsk2hsBu67L7TzVSTe51XxxB8nCPPSaLwWNZSDGM1qTWU4gIbjbQHHOh5C +YWcRfAW1WxhyiEWHYq+QwdYg9XCRrSg1UzvVvW1Yt2wDGcSZP5whbXipfw3BITDs +XU7ODYNkU1sjIzQZwzVGxOf9qKdhZFZ26Vhoz8OJNMLyJxY7EspuwR7HbDGt11Pb +CxOt/BV44LwdVYeqh57oIKtckQW33W/6EeaWr7GfMzyH5WSrsOJoK5IJVrZaPTcS +QiMYLA0CgYEA9vMVsGshBl3TeRGaU3XLHqooXD4kszbdnjfPrwGlfCO/iybhDqo5 +WFypM/bYcIWzbTez/ihufHEHPSCUbFEcN4B+oczGcuxTcZjFyvJYvq2ycxPUiDIi +JnVUcVxgh1Yn39+CsQ/b6meP7MumTD2P3I87CeQGlWTO5Ys9mdw0BjcCgYEAxpv1 +64l5UoFJGr4yElNKDIKnhEFbJZsLGKiiuVXcS1QVHW5Az5ar9fPxuepyHpz416l3 +ppncuhJiUIP+jbu5e0s0LsN46mLS3wkHLgYJj06CNT3uOSLSg1iFl7DusdbyiaA7 +wEJ/aotS1NZ4XaeryAWHwYJ6Kag3nz6NV3ZYuR8CgYEAxAFCuMj+6F+2RsTa+d1n +v8oMyNImLPyiQD9KHzyuTW7OTDMqtIoVg/Xf8re9KOpl9I0e1t7eevT3auQeCi8C +t2bMm7290V+UB3jbnO5n08hn+ADIUuV/x4ie4m8QyrpuYbm0sLbGtTFHwgoNzzuZ +oNUqZfpP42mk8fpnhWSLAlcCgYEAgpY7XRI4HkJ5ocbav2faMV2a7X/XgWNvKViA +HeJRhYoUlBRRMuz7xi0OjFKVlIFbsNlxna5fDk1WLWCMd/6tl168Qd8u2tX9lr6l +5OH9WSeiv4Un5JN73PbQaAvi9jXBpTIg92oBwzk2TlFyNQoxDcRtHZQ/5LIBWIhV +gOOEtLsCgYEA1wbGc4XlH+/nXVsvx7gmfK8pZG8XA4/ToeIEURwPYrxtQZLB4iZs +aqWGgIwiB4F4UkuKZIjMrgInU9y0fG6EL96Qty7Yjh7dGy1vJTZl6C+QU6o4sEwl +r5Id5BNLEaqISWQ0LvzfwdfABYlvFfBdaGbzUzLEitD79eyhxuNEOBw= +-----END RSA PRIVATE KEY----- diff --git a/pkg/acquisition/modules/journalctl/journalctl.go b/pkg/acquisition/modules/journalctl/journalctl.go index 762dfe9ba12..27f20b9f446 100644 --- a/pkg/acquisition/modules/journalctl/journalctl.go +++ b/pkg/acquisition/modules/journalctl/journalctl.go @@ -65,8 +65,8 @@ func readLine(scanner *bufio.Scanner, out chan string, errChan chan error) error return nil } -func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) error { - ctx, cancel := context.WithCancel(context.Background()) +func (j *JournalCtlSource) runJournalCtl(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + ctx, cancel := context.WithCancel(ctx) cmd := exec.CommandContext(ctx, journalctlCmd, j.args...) stdout, err := cmd.StdoutPipe() @@ -113,7 +113,7 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err return readLine(stdoutscanner, stdoutChan, errChan) }) t.Go(func() error { - //looks like journalctl closes stderr quite early, so ignore its status (but not its output) + // looks like journalctl closes stderr quite early, so ignore its status (but not its output) return readLine(stderrScanner, stderrChan, nil) }) @@ -122,7 +122,7 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err case <-t.Dying(): logger.Infof("journalctl datasource %s stopping", j.src) cancel() - cmd.Wait() //avoid zombie process + cmd.Wait() // avoid zombie process return nil case stdoutLine := <-stdoutChan: l := types.Line{} @@ -136,12 +136,9 @@ func (j *JournalCtlSource) runJournalCtl(out chan types.Event, t *tomb.Tomb) err if j.metricsLevel != configuration.METRICS_NONE { linesRead.With(prometheus.Labels{"source": j.src}).Inc() } - var evt types.Event - if !j.config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + + evt := types.MakeEvent(j.config.UseTimeMachine, types.LOG, true) + evt.Line = l out <- evt case stderrLine := <-stderrChan: logger.Warnf("Got stderr message : %s", stderrLine) @@ -217,13 +214,13 @@ func (j *JournalCtlSource) ConfigureByDSN(dsn string, labels map[string]string, j.config.Labels = labels j.config.UniqueId = uuid - //format for the DSN is : journalctl://filters=FILTER1&filters=FILTER2 + // format for the DSN is : journalctl://filters=FILTER1&filters=FILTER2 if !strings.HasPrefix(dsn, "journalctl://") { return fmt.Errorf("invalid DSN %s for journalctl source, must start with journalctl://", dsn) } qs := strings.TrimPrefix(dsn, "journalctl://") - if len(qs) == 0 { + if qs == "" { return errors.New("empty journalctl:// DSN") } @@ -262,26 +259,27 @@ func (j *JournalCtlSource) GetName() string { return "journalctl" } -func (j *JournalCtlSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (j *JournalCtlSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { defer trace.CatchPanic("crowdsec/acquis/journalctl/oneshot") - err := j.runJournalCtl(out, t) + err := j.runJournalCtl(ctx, out, t) j.logger.Debug("Oneshot journalctl acquisition is done") return err - } -func (j *JournalCtlSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (j *JournalCtlSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/journalctl/streaming") - return j.runJournalCtl(out, t) + return j.runJournalCtl(ctx, out, t) }) return nil } + func (j *JournalCtlSource) CanRun() error { - //TODO: add a more precise check on version or something ? + // TODO: add a more precise check on version or something ? _, err := exec.LookPath(journalctlCmd) return err } + func (j *JournalCtlSource) Dump() interface{} { return j } diff --git a/pkg/acquisition/modules/journalctl/journalctl_test.go b/pkg/acquisition/modules/journalctl/journalctl_test.go index 53e2d0802ad..687067c1881 100644 --- a/pkg/acquisition/modules/journalctl/journalctl_test.go +++ b/pkg/acquisition/modules/journalctl/journalctl_test.go @@ -1,6 +1,7 @@ package journalctlacquisition import ( + "context" "os" "os/exec" "path/filepath" @@ -106,6 +107,8 @@ func TestConfigureDSN(t *testing.T) { } func TestOneShot(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -164,7 +167,7 @@ journalctl_filter: t.Fatalf("Unexpected error : %s", err) } - err = j.OneShotAcquisition(out, &tomb) + err = j.OneShotAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if err != nil { @@ -187,6 +190,7 @@ journalctl_filter: } func TestStreaming(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -250,7 +254,7 @@ journalctl_filter: }() } - err = j.StreamingAcquisition(out, &tomb) + err = j.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if err != nil { diff --git a/pkg/acquisition/modules/kafka/kafka.go b/pkg/acquisition/modules/kafka/kafka.go index ca0a7556fca..77fc44e310d 100644 --- a/pkg/acquisition/modules/kafka/kafka.go +++ b/pkg/acquisition/modules/kafka/kafka.go @@ -23,9 +23,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var ( - dataSourceName = "kafka" -) +var dataSourceName = "kafka" var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -82,7 +80,7 @@ func (k *KafkaSource) UnmarshalConfig(yamlConfig []byte) error { k.Config.Mode = configuration.TAIL_MODE } - k.logger.Debugf("successfully unmarshaled kafka configuration : %+v", k.Config) + k.logger.Debugf("successfully parsed kafka configuration : %+v", k.Config) return err } @@ -129,7 +127,7 @@ func (k *KafkaSource) GetName() string { return dataSourceName } -func (k *KafkaSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KafkaSource) OneShotAcquisition(_ context.Context, _ chan types.Event, _ *tomb.Tomb) error { return fmt.Errorf("%s datasource does not support one-shot acquisition", dataSourceName) } @@ -149,12 +147,12 @@ func (k *KafkaSource) Dump() interface{} { return k } -func (k *KafkaSource) ReadMessage(out chan types.Event) error { +func (k *KafkaSource) ReadMessage(ctx context.Context, out chan types.Event) error { // Start processing from latest Offset - k.Reader.SetOffsetAt(context.Background(), time.Now()) + k.Reader.SetOffsetAt(ctx, time.Now()) for { k.logger.Tracef("reading message from topic '%s'", k.Config.Topic) - m, err := k.Reader.ReadMessage(context.Background()) + m, err := k.Reader.ReadMessage(ctx) if err != nil { if errors.Is(err, io.EOF) { return nil @@ -175,21 +173,16 @@ func (k *KafkaSource) ReadMessage(out chan types.Event) error { if k.metricsLevel != configuration.METRICS_NONE { linesRead.With(prometheus.Labels{"topic": k.Config.Topic}).Inc() } - var evt types.Event - - if !k.Config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(k.Config.UseTimeMachine, types.LOG, true) + evt.Line = l out <- evt } } -func (k *KafkaSource) RunReader(out chan types.Event, t *tomb.Tomb) error { +func (k *KafkaSource) RunReader(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { k.logger.Debugf("starting %s datasource reader goroutine with configuration %+v", dataSourceName, k.Config) t.Go(func() error { - return k.ReadMessage(out) + return k.ReadMessage(ctx, out) }) //nolint //fp for { @@ -204,12 +197,12 @@ func (k *KafkaSource) RunReader(out chan types.Event, t *tomb.Tomb) error { } } -func (k *KafkaSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KafkaSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { k.logger.Infof("start reader on brokers '%+v' with topic '%s'", k.Config.Brokers, k.Config.Topic) t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kafka/live") - return k.RunReader(out, t) + return k.RunReader(ctx, out, t) }) return nil diff --git a/pkg/acquisition/modules/kafka/kafka_test.go b/pkg/acquisition/modules/kafka/kafka_test.go index 7b467142cc9..d796166a6ca 100644 --- a/pkg/acquisition/modules/kafka/kafka_test.go +++ b/pkg/acquisition/modules/kafka/kafka_test.go @@ -80,9 +80,9 @@ group_id: crowdsec`, } } -func writeToKafka(w *kafka.Writer, logs []string) { +func writeToKafka(ctx context.Context, w *kafka.Writer, logs []string) { for idx, log := range logs { - err := w.WriteMessages(context.Background(), kafka.Message{ + err := w.WriteMessages(ctx, kafka.Message{ Key: []byte(strconv.Itoa(idx)), // create an arbitrary message payload for the value Value: []byte(log), @@ -128,6 +128,7 @@ func createTopic(topic string, broker string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -176,12 +177,12 @@ topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE) tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w, ts.logs) + go writeToKafka(ctx, w, ts.logs) READLOOP: for { select { @@ -199,6 +200,7 @@ topic: crowdsecplaintext`), subLogger, configuration.METRICS_NONE) } func TestStreamingAcquisitionWithSSL(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -252,12 +254,12 @@ tls: tomb := tomb.Tomb{} out := make(chan types.Event) - err = k.StreamingAcquisition(out, &tomb) + err = k.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) actualLines := 0 - go writeToKafka(w2, ts.logs) + go writeToKafka(ctx, w2, ts.logs) READLOOP: for { select { diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go index 0e6c1980fa9..3744e43f38d 100644 --- a/pkg/acquisition/modules/kinesis/kinesis.go +++ b/pkg/acquisition/modules/kinesis/kinesis.go @@ -3,6 +3,7 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" "errors" "fmt" @@ -29,7 +30,7 @@ type KinesisConfiguration struct { configuration.DataSourceCommonCfg `yaml:",inline"` StreamName string `yaml:"stream_name"` StreamARN string `yaml:"stream_arn"` - UseEnhancedFanOut bool `yaml:"use_enhanced_fanout"` //Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords + UseEnhancedFanOut bool `yaml:"use_enhanced_fanout"` // Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords AwsProfile *string `yaml:"aws_profile"` AwsRegion string `yaml:"aws_region"` AwsEndpoint string `yaml:"aws_endpoint"` @@ -114,8 +115,8 @@ func (k *KinesisSource) newClient() error { func (k *KinesisSource) GetMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, linesReadShards} - } + func (k *KinesisSource) GetAggregMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, linesReadShards} } @@ -181,14 +182,13 @@ func (k *KinesisSource) GetName() string { return "kinesis" } -func (k *KinesisSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KinesisSource) OneShotAcquisition(_ context.Context, _ chan types.Event, _ *tomb.Tomb) error { return errors.New("kinesis datasource does not support one-shot acquisition") } func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubscriptionLogEvent, error) { b := bytes.NewBuffer(record) r, err := gzip.NewReader(b) - if err != nil { k.logger.Error(err) return nil, err @@ -299,8 +299,8 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan var data []CloudwatchSubscriptionLogEvent var err error if k.Config.FromSubscription { - //The AWS docs says that the data is base64 encoded - //but apparently GetRecords decodes it for us ? + // The AWS docs says that the data is base64 encoded + // but apparently GetRecords decodes it for us ? data, err = k.decodeFromSubscription(record.Data) if err != nil { logger.Errorf("Cannot decode data: %s", err) @@ -322,12 +322,8 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan } else { l.Src = k.Config.StreamName } - var evt types.Event - if !k.Config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(k.Config.UseTimeMachine, types.LOG, true) + evt.Line = l out <- evt } } @@ -335,9 +331,9 @@ func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan func (k *KinesisSource) ReadFromSubscription(reader kinesis.SubscribeToShardEventStreamReader, out chan types.Event, shardId string, streamName string) error { logger := k.logger.WithField("shard_id", shardId) - //ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately - //and we won't be able to start a new one if this is the first one started by the tomb - //TODO: look into parent shards to see if a shard is closed before starting to read it ? + // ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately + // and we won't be able to start a new one if this is the first one started by the tomb + // TODO: look into parent shards to see if a shard is closed before starting to read it ? time.Sleep(time.Second) for { select { @@ -420,7 +416,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { case <-t.Dying(): k.logger.Infof("Kinesis source is dying") k.shardReaderTomb.Kill(nil) - _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves err = k.DeregisterConsumer() if err != nil { return fmt.Errorf("cannot deregister consumer: %w", err) @@ -431,7 +427,7 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { if k.shardReaderTomb.Err() != nil { return k.shardReaderTomb.Err() } - //All goroutines have exited without error, so a resharding event, start again + // All goroutines have exited without error, so a resharding event, start again k.logger.Debugf("All reader goroutines have exited, resharding event or periodic resubscribe") continue } @@ -441,15 +437,17 @@ func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) error { logger := k.logger.WithField("shard", shardId) logger.Debugf("Starting to read shard") - sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ShardId: aws.String(shardId), + sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ + ShardId: aws.String(shardId), StreamName: &k.Config.StreamName, - ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest)}) + ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest), + }) if err != nil { logger.Errorf("Cannot get shard iterator: %s", err) return fmt.Errorf("cannot get shard iterator: %w", err) } it := sharIt.ShardIterator - //AWS recommends to wait for a second between calls to GetRecords for a given shard + // AWS recommends to wait for a second between calls to GetRecords for a given shard ticker := time.NewTicker(time.Second) for { select { @@ -460,7 +458,7 @@ func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) erro switch err.(type) { case *kinesis.ProvisionedThroughputExceededException: logger.Warn("Provisioned throughput exceeded") - //TODO: implement exponential backoff + // TODO: implement exponential backoff continue case *kinesis.ExpiredIteratorException: logger.Warn("Expired iterator") @@ -506,7 +504,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error case <-t.Dying(): k.logger.Info("kinesis source is dying") k.shardReaderTomb.Kill(nil) - _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + _ = k.shardReaderTomb.Wait() // we don't care about the error as we kill the tomb ourselves return nil case <-k.shardReaderTomb.Dying(): reason := k.shardReaderTomb.Err() @@ -520,7 +518,7 @@ func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error } } -func (k *KinesisSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (k *KinesisSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/kinesis/streaming") if k.Config.UseEnhancedFanOut { diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go index 46e404aa49b..027cbde9240 100644 --- a/pkg/acquisition/modules/kinesis/kinesis_test.go +++ b/pkg/acquisition/modules/kinesis/kinesis_test.go @@ -3,6 +3,7 @@ package kinesisacquisition import ( "bytes" "compress/gzip" + "context" "encoding/json" "fmt" "net" @@ -60,8 +61,8 @@ func GenSubObject(i int) []byte { gz := gzip.NewWriter(&b) gz.Write(body) gz.Close() - //AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point - //localstack does not do it, so let's just write a raw gzipped stream + // AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point + // localstack does not do it, so let's just write a raw gzipped stream return b.Bytes() } @@ -99,10 +100,10 @@ func TestMain(m *testing.M) { os.Setenv("AWS_ACCESS_KEY_ID", "foobar") os.Setenv("AWS_SECRET_ACCESS_KEY", "foobar") - //delete_streams() - //create_streams() + // delete_streams() + // create_streams() code := m.Run() - //delete_streams() + // delete_streams() os.Exit(code) } @@ -149,6 +150,7 @@ stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, } func TestReadFromStream(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -176,11 +178,11 @@ stream_name: stream-1-shard`, } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, false) for i := range test.count { @@ -193,6 +195,7 @@ stream_name: stream-1-shard`, } func TestReadFromMultipleShards(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -220,11 +223,11 @@ stream_name: stream-2-shards`, } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, false) c := 0 @@ -239,6 +242,7 @@ stream_name: stream-2-shards`, } func TestFromSubscription(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -267,11 +271,11 @@ from_subscription: true`, } tomb := &tomb.Tomb{} out := make(chan types.Event) - err = f.StreamingAcquisition(out, tomb) + err = f.StreamingAcquisition(ctx, out, tomb) if err != nil { t.Fatalf("Error starting source: %s", err) } - //Allow the datasource to start listening to the stream + // Allow the datasource to start listening to the stream time.Sleep(4 * time.Second) WriteToStream(f.Config.StreamName, test.count, test.shards, true) for i := range test.count { diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go index e48a074b764..1fa6c894a32 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit.go @@ -131,11 +131,11 @@ func (ka *KubernetesAuditSource) GetName() string { return "k8s-audit" } -func (ka *KubernetesAuditSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (ka *KubernetesAuditSource) OneShotAcquisition(_ context.Context, _ chan types.Event, _ *tomb.Tomb) error { return errors.New("k8s-audit datasource does not support one-shot acquisition") } -func (ka *KubernetesAuditSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (ka *KubernetesAuditSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { ka.outChan = out t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/k8s-audit/live") @@ -149,7 +149,7 @@ func (ka *KubernetesAuditSource) StreamingAcquisition(out chan types.Event, t *t }) <-t.Dying() ka.logger.Infof("Stopping k8s-audit server on %s:%d%s", ka.config.ListenAddr, ka.config.ListenPort, ka.config.WebhookPath) - ka.server.Shutdown(context.TODO()) + ka.server.Shutdown(ctx) return nil }) return nil @@ -164,7 +164,6 @@ func (ka *KubernetesAuditSource) Dump() interface{} { } func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.Request) { - if ka.metricsLevel != configuration.METRICS_NONE { requestCount.WithLabelValues(ka.addr).Inc() } @@ -196,7 +195,7 @@ func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.R } bytesEvent, err := json.Marshal(auditEvent) if err != nil { - ka.logger.Errorf("Error marshaling audit event: %s", err) + ka.logger.Errorf("Error serializing audit event: %s", err) continue } ka.logger.Tracef("Got audit event: %s", string(bytesEvent)) @@ -208,11 +207,8 @@ func (ka *KubernetesAuditSource) webhookHandler(w http.ResponseWriter, r *http.R Process: true, Module: ka.GetName(), } - ka.outChan <- types.Event{ - Line: l, - Process: true, - Type: types.LOG, - ExpectMode: types.LIVE, - } + evt := types.MakeEvent(ka.config.UseTimeMachine, types.LOG, true) + evt.Line = l + ka.outChan <- evt } } diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index 020bd4c91a0..a086a756e4a 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -1,6 +1,7 @@ package kubernetesauditacquisition import ( + "context" "net/http/httptest" "strings" "testing" @@ -52,6 +53,7 @@ listen_addr: 0.0.0.0`, } func TestInvalidConfig(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -83,7 +85,7 @@ webhook_path: /k8s-audit`, err = f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) require.NoError(t, err) - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) time.Sleep(1 * time.Second) tb.Kill(nil) @@ -98,6 +100,7 @@ webhook_path: /k8s-audit`, } func TestHandler(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -257,14 +260,14 @@ webhook_path: /k8s-audit`, req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() - f.StreamingAcquisition(out, tb) + f.StreamingAcquisition(ctx, out, tb) f.webhookHandler(w, req) res := w.Result() assert.Equal(t, test.expectedStatusCode, res.StatusCode) - //time.Sleep(1 * time.Second) + // time.Sleep(1 * time.Second) require.NoError(t, err) tb.Kill(nil) diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go index 420da6e391c..fce199c5708 100644 --- a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go +++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go @@ -16,7 +16,7 @@ import ( log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) type LokiClient struct { @@ -119,7 +119,7 @@ func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQu case <-lc.t.Dying(): return lc.t.Err() case <-ticker.C: - resp, err := lc.Get(uri) + resp, err := lc.Get(ctx, uri) if err != nil { if ok := lc.shouldRetry(); !ok { return fmt.Errorf("error querying range: %w", err) @@ -215,7 +215,7 @@ func (lc *LokiClient) Ready(ctx context.Context) error { return lc.t.Err() case <-tick.C: lc.Logger.Debug("Checking if Loki is ready") - resp, err := lc.Get(url) + resp, err := lc.Get(ctx, url) if err != nil { lc.Logger.Warnf("Error checking if Loki is ready: %s", err) continue @@ -300,8 +300,8 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ } // Create a wrapper for http.Get to be able to set headers and auth -func (lc *LokiClient) Get(url string) (*http.Response, error) { - request, err := http.NewRequest(http.MethodGet, url, nil) +func (lc *LokiClient) Get(ctx context.Context, url string) (*http.Response, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } @@ -319,6 +319,6 @@ func NewLokiClient(config Config) *LokiClient { if config.Username != "" || config.Password != "" { headers["Authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(config.Username+":"+config.Password)) } - headers["User-Agent"] = cwversion.UserAgent() + headers["User-Agent"] = useragent.Default() return &LokiClient{Logger: log.WithField("component", "lokiclient"), config: config, requestHeaders: headers} } diff --git a/pkg/acquisition/modules/loki/loki.go b/pkg/acquisition/modules/loki/loki.go index 15c454723ee..c57e6a67c94 100644 --- a/pkg/acquisition/modules/loki/loki.go +++ b/pkg/acquisition/modules/loki/loki.go @@ -53,6 +53,7 @@ type LokiConfiguration struct { WaitForReady time.Duration `yaml:"wait_for_ready"` // Retry interval, default is 10 seconds Auth LokiAuthConfiguration `yaml:"auth"` MaxFailureDuration time.Duration `yaml:"max_failure_duration"` // Max duration of failure before stopping the source + NoReadyCheck bool `yaml:"no_ready_check"` // Bypass /ready check before starting configuration.DataSourceCommonCfg `yaml:",inline"` } @@ -229,6 +230,14 @@ func (l *LokiSource) ConfigureByDSN(dsn string, labels map[string]string, logger l.logger.Logger.SetLevel(level) } + if noReadyCheck := params.Get("no_ready_check"); noReadyCheck != "" { + noReadyCheck, err := strconv.ParseBool(noReadyCheck) + if err != nil { + return fmt.Errorf("invalid no_ready_check in dsn: %w", err) + } + l.Config.NoReadyCheck = noReadyCheck + } + l.Config.URL = fmt.Sprintf("%s://%s", scheme, u.Host) if u.User != nil { l.Config.Auth.Username = u.User.Username() @@ -261,29 +270,31 @@ func (l *LokiSource) GetName() string { } // OneShotAcquisition reads a set of file and returns when done -func (l *LokiSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (l *LokiSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { l.logger.Debug("Loki one shot acquisition") l.Client.SetTomb(t) - readyCtx, cancel := context.WithTimeout(context.Background(), l.Config.WaitForReady) - defer cancel() - err := l.Client.Ready(readyCtx) - if err != nil { - return fmt.Errorf("loki is not ready: %w", err) + + if !l.Config.NoReadyCheck { + readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady) + defer readyCancel() + err := l.Client.Ready(readyCtx) + if err != nil { + return fmt.Errorf("loki is not ready: %w", err) + } } - ctx, cancel := context.WithCancel(context.Background()) - c := l.Client.QueryRange(ctx, false) + lokiCtx, cancel := context.WithCancel(ctx) + defer cancel() + c := l.Client.QueryRange(lokiCtx, false) for { select { case <-t.Dying(): l.logger.Debug("Loki one shot acquisition stopped") - cancel() return nil case resp, ok := <-c: if !ok { l.logger.Info("Loki acquisition done, chan closed") - cancel() return nil } for _, stream := range resp.Data.Result { @@ -307,41 +318,33 @@ func (l *LokiSource) readOneEntry(entry lokiclient.Entry, labels map[string]stri if l.metricsLevel != configuration.METRICS_NONE { linesRead.With(prometheus.Labels{"source": l.Config.URL}).Inc() } - expectMode := types.LIVE - if l.Config.UseTimeMachine { - expectMode = types.TIMEMACHINE - } - out <- types.Event{ - Line: ll, - Process: true, - Type: types.LOG, - ExpectMode: expectMode, - } + evt := types.MakeEvent(l.Config.UseTimeMachine, types.LOG, true) + evt.Line = ll + out <- evt } -func (l *LokiSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (l *LokiSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { l.Client.SetTomb(t) - readyCtx, cancel := context.WithTimeout(context.Background(), l.Config.WaitForReady) - defer cancel() - err := l.Client.Ready(readyCtx) - if err != nil { - return fmt.Errorf("loki is not ready: %w", err) + + if !l.Config.NoReadyCheck { + readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady) + defer readyCancel() + err := l.Client.Ready(readyCtx) + if err != nil { + return fmt.Errorf("loki is not ready: %w", err) + } } ll := l.logger.WithField("websocket_url", l.lokiWebsocket) t.Go(func() error { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() respChan := l.Client.QueryRange(ctx, true) - if err != nil { - ll.Errorf("could not start loki tail: %s", err) - return fmt.Errorf("while starting loki tail: %w", err) - } for { select { case resp, ok := <-respChan: if !ok { ll.Warnf("loki channel closed") - return err + return errors.New("loki channel closed") } for _, stream := range resp.Data.Result { for _, entry := range stream.Entries { diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index 5f41cd4c62e..643aefad715 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -34,6 +34,7 @@ func TestConfiguration(t *testing.T) { password string waitForReady time.Duration delayFor time.Duration + noReadyCheck bool testName string }{ { @@ -95,7 +96,19 @@ query: > delayFor: 1 * time.Second, }, { - + config: ` +mode: tail +source: loki +url: http://localhost:3100/ +no_ready_check: true +query: > + {server="demo"} +`, + expectedErr: "", + testName: "Correct config with no_ready_check", + noReadyCheck: true, + }, + { config: ` mode: tail source: loki @@ -111,7 +124,6 @@ query: > testName: "Correct config with password", }, { - config: ` mode: tail source: loki @@ -150,6 +162,8 @@ query: > t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor) } } + + assert.Equal(t, test.noReadyCheck, lokiSource.Config.NoReadyCheck) }) } } @@ -166,6 +180,7 @@ func TestConfigureDSN(t *testing.T) { scheme string waitForReady time.Duration delayFor time.Duration + noReadyCheck bool }{ { name: "Wrong scheme", @@ -204,10 +219,11 @@ func TestConfigureDSN(t *testing.T) { }, { name: "Correct DSN", - dsn: `loki://localhost:3100/?query={server="demo"}&wait_for_ready=5s&delay_for=1s`, + dsn: `loki://localhost:3100/?query={server="demo"}&wait_for_ready=5s&delay_for=1s&no_ready_check=true`, expectedErr: "", waitForReady: 5 * time.Second, delayFor: 1 * time.Second, + noReadyCheck: true, }, { name: "SSL DSN", @@ -258,10 +274,13 @@ func TestConfigureDSN(t *testing.T) { t.Fatalf("Wrong DelayFor %v != %v", lokiSource.Config.DelayFor, test.delayFor) } } + + assert.Equal(t, test.noReadyCheck, lokiSource.Config.NoReadyCheck) + } } -func feedLoki(logger *log.Entry, n int, title string) error { +func feedLoki(ctx context.Context, logger *log.Entry, n int, title string) error { streams := LogStreams{ Streams: []LogStream{ { @@ -286,7 +305,7 @@ func feedLoki(logger *log.Entry, n int, title string) error { return err } - req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://127.0.0.1:3100/loki/api/v1/push", bytes.NewBuffer(buff)) if err != nil { return err } @@ -314,6 +333,8 @@ func feedLoki(logger *log.Entry, n int, title string) error { } func TestOneShotAcquisition(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -344,12 +365,11 @@ since: 1h subLogger := logger.WithField("type", "loki") lokiSource := loki.LokiSource{} err := lokiSource.Configure([]byte(ts.config), subLogger, configuration.METRICS_NONE) - if err != nil { t.Fatalf("Unexpected error : %s", err) } - err = feedLoki(subLogger, 20, title) + err = feedLoki(ctx, subLogger, 20, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -367,7 +387,7 @@ since: 1h lokiTomb := tomb.Tomb{} - err = lokiSource.OneShotAcquisition(out, &lokiTomb) + err = lokiSource.OneShotAcquisition(ctx, out, &lokiTomb) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -421,6 +441,8 @@ query: > }, } + ctx := context.Background() + for _, ts := range tests { t.Run(ts.name, func(t *testing.T) { logger := log.New() @@ -438,7 +460,7 @@ query: > t.Fatalf("Unexpected error : %s", err) } - err = lokiSource.StreamingAcquisition(out, &lokiTomb) + err = lokiSource.StreamingAcquisition(ctx, out, &lokiTomb) cstest.AssertErrorContains(t, err, ts.streamErr) if ts.streamErr != "" { @@ -448,7 +470,7 @@ query: > time.Sleep(time.Second * 2) // We need to give time to start reading from the WS readTomb := tomb.Tomb{} - readCtx, cancel := context.WithTimeout(context.Background(), time.Second*10) + readCtx, cancel := context.WithTimeout(ctx, time.Second*10) count := 0 readTomb.Go(func() error { @@ -472,7 +494,7 @@ query: > } }) - err = feedLoki(subLogger, ts.expectedLines, title) + err = feedLoki(ctx, subLogger, ts.expectedLines, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -491,6 +513,7 @@ query: > } func TestStopStreaming(t *testing.T) { + ctx := context.Background() if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } @@ -518,14 +541,14 @@ query: > lokiTomb := &tomb.Tomb{} - err = lokiSource.StreamingAcquisition(out, lokiTomb) + err = lokiSource.StreamingAcquisition(ctx, out, lokiTomb) if err != nil { t.Fatalf("Unexpected error : %s", err) } time.Sleep(time.Second * 2) - err = feedLoki(subLogger, 1, title) + err = feedLoki(ctx, subLogger, 1, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go index 9ef4d2ba757..cdc84a8a3ca 100644 --- a/pkg/acquisition/modules/s3/s3.go +++ b/pkg/acquisition/modules/s3/s3.go @@ -38,7 +38,7 @@ type S3Configuration struct { AwsEndpoint string `yaml:"aws_endpoint"` BucketName string `yaml:"bucket_name"` Prefix string `yaml:"prefix"` - Key string `yaml:"-"` //Only for DSN acquisition + Key string `yaml:"-"` // Only for DSN acquisition PollingMethod string `yaml:"polling_method"` PollingInterval int `yaml:"polling_interval"` SQSName string `yaml:"sqs_name"` @@ -93,10 +93,12 @@ type S3Event struct { } `json:"detail"` } -const PollMethodList = "list" -const PollMethodSQS = "sqs" -const SQSFormatEventBridge = "eventbridge" -const SQSFormatS3Notification = "s3notification" +const ( + PollMethodList = "list" + PollMethodSQS = "sqs" + SQSFormatEventBridge = "eventbridge" + SQSFormatS3Notification = "s3notification" +) var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -336,7 +338,7 @@ func (s *S3Source) sqsPoll() error { out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{ QueueUrl: aws.String(s.Config.SQSName), MaxNumberOfMessages: aws.Int64(10), - WaitTimeSeconds: aws.Int64(20), //Probably no need to make it configurable ? + WaitTimeSeconds: aws.Int64(20), // Probably no need to make it configurable ? }) if err != nil { logger.Errorf("Error while polling SQS: %s", err) @@ -351,7 +353,7 @@ func (s *S3Source) sqsPoll() error { bucket, key, err := s.extractBucketAndPrefix(message.Body) if err != nil { logger.Errorf("Error while parsing SQS message: %s", err) - //Always delete the message to avoid infinite loop + // Always delete the message to avoid infinite loop _, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{ QueueUrl: aws.String(s.Config.SQSName), ReceiptHandle: message.ReceiptHandle, @@ -377,7 +379,7 @@ func (s *S3Source) sqsPoll() error { } func (s *S3Source) readFile(bucket string, key string) error { - //TODO: Handle SSE-C + // TODO: Handle SSE-C var scanner *bufio.Scanner logger := s.logger.WithFields(log.Fields{ @@ -390,14 +392,13 @@ func (s *S3Source) readFile(bucket string, key string) error { Bucket: aws.String(bucket), Key: aws.String(key), }) - if err != nil { return fmt.Errorf("failed to get object %s/%s: %w", bucket, key, err) } defer output.Body.Close() if strings.HasSuffix(key, ".gz") { - //This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs) + // This *might* be a gzipped file, but sometimes the SDK will decompress the data for us (it's not clear when it happens, only had the issue with cloudtrail logs) header := make([]byte, 2) _, err := output.Body.Read(header) if err != nil { @@ -442,12 +443,8 @@ func (s *S3Source) readFile(bucket string, key string) error { } else if s.MetricsLevel == configuration.METRICS_AGGREGATE { l.Src = bucket } - var evt types.Event - if !s.Config.UseTimeMachine { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(s.Config.UseTimeMachine, types.LOG, true) + evt.Line = l s.out <- evt } } @@ -467,6 +464,7 @@ func (s *S3Source) GetUuid() string { func (s *S3Source) GetMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived} } + func (s *S3Source) GetAggregMetrics() []prometheus.Collector { return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived} } @@ -567,11 +565,11 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger * }) dsn = strings.TrimPrefix(dsn, "s3://") args := strings.Split(dsn, "?") - if len(args[0]) == 0 { + if args[0] == "" { return errors.New("empty s3:// DSN") } - if len(args) == 2 && len(args[1]) != 0 { + if len(args) == 2 && args[1] != "" { params, err := url.ParseQuery(args[1]) if err != nil { return fmt.Errorf("could not parse s3 args: %w", err) @@ -610,7 +608,7 @@ func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger * pathParts := strings.Split(args[0], "/") s.logger.Debugf("pathParts: %v", pathParts) - //FIXME: handle s3://bucket/ + // FIXME: handle s3://bucket/ if len(pathParts) == 1 { s.Config.BucketName = pathParts[0] s.Config.Prefix = "" @@ -641,10 +639,10 @@ func (s *S3Source) GetName() string { return "s3" } -func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *S3Source) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key) s.out = out - s.ctx, s.cancel = context.WithCancel(context.Background()) + s.ctx, s.cancel = context.WithCancel(ctx) s.Config.UseTimeMachine = true s.t = t if s.Config.Key != "" { @@ -653,7 +651,7 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error return err } } else { - //No key, get everything in the bucket based on the prefix + // No key, get everything in the bucket based on the prefix objects, err := s.getBucketContent() if err != nil { return err @@ -669,11 +667,11 @@ func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error return nil } -func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *S3Source) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { s.t = t s.out = out - s.readerChan = make(chan S3Object, 100) //FIXME: does this needs to be buffered? - s.ctx, s.cancel = context.WithCancel(context.Background()) + s.readerChan = make(chan S3Object, 100) // FIXME: does this needs to be buffered? + s.ctx, s.cancel = context.WithCancel(ctx) s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix) t.Go(func() error { s.readManager() diff --git a/pkg/acquisition/modules/s3/s3_test.go b/pkg/acquisition/modules/s3/s3_test.go index 93e166dfec5..367048aa33a 100644 --- a/pkg/acquisition/modules/s3/s3_test.go +++ b/pkg/acquisition/modules/s3/s3_test.go @@ -208,6 +208,7 @@ func (msqs mockSQSClientNotif) DeleteMessage(input *sqs.DeleteMessageInput) (*sq } func TestDSNAcquis(t *testing.T) { + ctx := context.Background() tests := []struct { name string dsn string @@ -260,7 +261,7 @@ func TestDSNAcquis(t *testing.T) { f.s3Client = mockS3Client{} tmb := tomb.Tomb{} - err = f.OneShotAcquisition(out, &tmb) + err = f.OneShotAcquisition(ctx, out, &tmb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -272,6 +273,7 @@ func TestDSNAcquis(t *testing.T) { } func TestListPolling(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -331,7 +333,7 @@ prefix: foo/ } }() - err = f.StreamingAcquisition(out, &tb) + err = f.StreamingAcquisition(ctx, out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } @@ -348,6 +350,7 @@ prefix: foo/ } func TestSQSPoll(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -411,7 +414,7 @@ sqs_name: test } }() - err = f.StreamingAcquisition(out, &tb) + err = f.StreamingAcquisition(ctx, out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } diff --git a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go index 8fb5089a61f..3af6614bce6 100644 --- a/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go +++ b/pkg/acquisition/modules/syslog/internal/parser/rfc3164/parse_test.go @@ -4,6 +4,10 @@ import ( "fmt" "testing" "time" + + "github.com/stretchr/testify/assert" + + "github.com/crowdsecurity/go-cs-lib/cstest" ) func TestPri(t *testing.T) { @@ -26,28 +30,20 @@ func TestPri(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parsePRI() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.PRI != test.expected { - t.Errorf("expected %d, got %d", test.expected, r.PRI) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.PRI) }) } } func TestTimestamp(t *testing.T) { - tests := []struct { input string expected string @@ -68,25 +64,19 @@ func TestTimestamp(t *testing.T) { if test.currentYear { opts = append(opts, WithCurrentYear()) } + r := NewRFC3164Parser(opts...) r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseTimestamp() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.Timestamp.Format(time.RFC3339) != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Timestamp.Format(time.RFC3339)) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Timestamp.Format(time.RFC3339)) }) } } @@ -121,25 +111,19 @@ func TestHostname(t *testing.T) { if test.strictHostname { opts = append(opts, WithStrictHostname()) } + r := NewRFC3164Parser(opts...) r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseHostname() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.Hostname != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Hostname) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Hostname) }) } } @@ -164,27 +148,16 @@ func TestTag(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseTag() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else { - if r.Tag != test.expected { - t.Errorf("expected %s, got %s", test.expected, r.Tag) - } - if r.PID != test.expectedPID { - t.Errorf("expected %s, got %s", test.expected, r.Message) - } - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Tag) + assert.Equal(t, test.expectedPID, r.PID) }) } } @@ -207,22 +180,15 @@ func TestMessage(t *testing.T) { r := &RFC3164{} r.buf = []byte(test.input) r.len = len(r.buf) + err := r.parseMessage() - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error %s, got %s", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: %s", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error %s, got no error", test.expectedErr) - } else if r.Message != test.expected { - t.Errorf("expected message %s, got %s", test.expected, r.Tag) - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected, r.Message) }) } } @@ -236,6 +202,7 @@ func TestParse(t *testing.T) { Message string PRI int } + tests := []struct { input string expected expected @@ -326,39 +293,20 @@ func TestParse(t *testing.T) { for _, test := range tests { t.Run(test.input, func(t *testing.T) { r := NewRFC3164Parser(test.opts...) + err := r.Parse([]byte(test.input)) - if err != nil { - if test.expectedErr != "" { - if err.Error() != test.expectedErr { - t.Errorf("expected error '%s', got '%s'", test.expectedErr, err) - } - } else { - t.Errorf("unexpected error: '%s'", err) - } - } else { - if test.expectedErr != "" { - t.Errorf("expected error '%s', got no error", test.expectedErr) - } else { - if r.Timestamp != test.expected.Timestamp { - t.Errorf("expected timestamp '%s', got '%s'", test.expected.Timestamp, r.Timestamp) - } - if r.Hostname != test.expected.Hostname { - t.Errorf("expected hostname '%s', got '%s'", test.expected.Hostname, r.Hostname) - } - if r.Tag != test.expected.Tag { - t.Errorf("expected tag '%s', got '%s'", test.expected.Tag, r.Tag) - } - if r.PID != test.expected.PID { - t.Errorf("expected pid '%s', got '%s'", test.expected.PID, r.PID) - } - if r.Message != test.expected.Message { - t.Errorf("expected message '%s', got '%s'", test.expected.Message, r.Message) - } - if r.PRI != test.expected.PRI { - t.Errorf("expected pri '%d', got '%d'", test.expected.PRI, r.PRI) - } - } + cstest.RequireErrorContains(t, err, test.expectedErr) + + if test.expectedErr != "" { + return } + + assert.Equal(t, test.expected.Timestamp, r.Timestamp) + assert.Equal(t, test.expected.Hostname, r.Hostname) + assert.Equal(t, test.expected.Tag, r.Tag) + assert.Equal(t, test.expected.PID, r.PID) + assert.Equal(t, test.expected.Message, r.Message) + assert.Equal(t, test.expected.PRI, r.PRI) }) } } diff --git a/pkg/acquisition/modules/syslog/syslog.go b/pkg/acquisition/modules/syslog/syslog.go index 06c32e62f77..fb6a04600c1 100644 --- a/pkg/acquisition/modules/syslog/syslog.go +++ b/pkg/acquisition/modules/syslog/syslog.go @@ -1,6 +1,7 @@ package syslogacquisition import ( + "context" "errors" "fmt" "net" @@ -83,7 +84,7 @@ func (s *SyslogSource) ConfigureByDSN(dsn string, labels map[string]string, logg return errors.New("syslog datasource does not support one shot acquisition") } -func (s *SyslogSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *SyslogSource) OneShotAcquisition(_ context.Context, _ chan types.Event, _ *tomb.Tomb) error { return errors.New("syslog datasource does not support one shot acquisition") } @@ -105,7 +106,7 @@ func (s *SyslogSource) UnmarshalConfig(yamlConfig []byte) error { } if s.config.Addr == "" { - s.config.Addr = "127.0.0.1" //do we want a usable or secure default ? + s.config.Addr = "127.0.0.1" // do we want a usable or secure default ? } if s.config.Port == 0 { s.config.Port = 514 @@ -135,7 +136,7 @@ func (s *SyslogSource) Configure(yamlConfig []byte, logger *log.Entry, MetricsLe return nil } -func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (s *SyslogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { c := make(chan syslogserver.SyslogMessage) s.server = &syslogserver.SyslogServer{Logger: s.logger.WithField("syslog", "internal"), MaxMessageLen: s.config.MaxMessageLen} s.server.SetChannel(c) @@ -152,7 +153,8 @@ func (s *SyslogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) } func (s *SyslogSource) buildLogFromSyslog(ts time.Time, hostname string, - appname string, pid string, msg string) string { + appname string, pid string, msg string, +) string { ret := "" if !ts.IsZero() { ret += ts.Format("Jan 2 15:04:05") @@ -178,7 +180,6 @@ func (s *SyslogSource) buildLogFromSyslog(ts time.Time, hostname string, ret += msg } return ret - } func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c chan syslogserver.SyslogMessage) error { @@ -234,11 +235,9 @@ func (s *SyslogSource) handleSyslogMsg(out chan types.Event, t *tomb.Tomb, c cha l.Time = ts l.Src = syslogLine.Client l.Process = true - if !s.config.UseTimeMachine { - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} - } else { - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} - } + evt := types.MakeEvent(s.config.UseTimeMachine, types.LOG, true) + evt.Line = l + out <- evt } } } diff --git a/pkg/acquisition/modules/syslog/syslog_test.go b/pkg/acquisition/modules/syslog/syslog_test.go index 1750f375138..57fa3e8747b 100644 --- a/pkg/acquisition/modules/syslog/syslog_test.go +++ b/pkg/acquisition/modules/syslog/syslog_test.go @@ -1,6 +1,7 @@ package syslogacquisition import ( + "context" "fmt" "net" "runtime" @@ -80,6 +81,7 @@ func writeToSyslog(logs []string) { } func TestStreamingAcquisition(t *testing.T) { + ctx := context.Background() tests := []struct { name string config string @@ -100,8 +102,10 @@ listen_addr: 127.0.0.1`, listen_port: 4242 listen_addr: 127.0.0.1`, expectedLines: 2, - logs: []string{`<13>1 2021-05-18T11:58:40.828081+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, - `<13>1 2021-05-18T12:12:37.560695+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla2[foobar]`}, + logs: []string{ + `<13>1 2021-05-18T11:58:40.828081+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla`, + `<13>1 2021-05-18T12:12:37.560695+02:00 mantis sshd 49340 - [timeQuality isSynced="0" tzKnown="1"] blabla2[foobar]`, + }, }, { name: "RFC3164", @@ -109,10 +113,12 @@ listen_addr: 127.0.0.1`, listen_port: 4242 listen_addr: 127.0.0.1`, expectedLines: 3, - logs: []string{`<13>May 18 12:37:56 mantis sshd[49340]: blabla2[foobar]`, + logs: []string{ + `<13>May 18 12:37:56 mantis sshd[49340]: blabla2[foobar]`, `<13>May 18 12:37:56 mantis sshd[49340]: blabla2`, `<13>May 18 12:37:56 mantis sshd: blabla2`, - `<13>May 18 12:37:56 mantis sshd`}, + `<13>May 18 12:37:56 mantis sshd`, + }, }, } if runtime.GOOS != "windows" { @@ -139,7 +145,7 @@ listen_addr: 127.0.0.1`, } tomb := tomb.Tomb{} out := make(chan types.Event) - err = s.StreamingAcquisition(out, &tomb) + err = s.StreamingAcquisition(ctx, out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) if ts.expectedErr != "" { return diff --git a/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx b/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx new file mode 100644 index 00000000000..2c4f8b0f680 Binary files /dev/null and b/pkg/acquisition/modules/wineventlog/test_files/Setup.evtx differ diff --git a/pkg/acquisition/modules/wineventlog/wineventlog.go b/pkg/acquisition/modules/wineventlog/wineventlog.go index 44035d0a708..3023a371576 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog.go @@ -3,6 +3,7 @@ package wineventlogacquisition import ( + "context" "errors" "github.com/prometheus/client_golang/prometheus" @@ -39,7 +40,7 @@ func (w *WinEventLogSource) SupportedModes() []string { return []string{configuration.TAIL_MODE, configuration.CAT_MODE} } -func (w *WinEventLogSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) OneShotAcquisition(_ context.Context, _ chan types.Event, _ *tomb.Tomb) error { return nil } @@ -59,7 +60,7 @@ func (w *WinEventLogSource) CanRun() error { return errors.New("windows event log acquisition is only supported on Windows") } -func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { return nil } diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go index c6b10b7c38c..8283bcc21a2 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_windows.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows.go @@ -1,10 +1,13 @@ package wineventlogacquisition import ( + "context" "encoding/xml" "errors" "fmt" + "net/url" "runtime" + "strconv" "strings" "syscall" "time" @@ -29,7 +32,7 @@ type WinEventLogConfiguration struct { EventLevel string `yaml:"event_level"` EventIDs []int `yaml:"event_ids"` XPathQuery string `yaml:"xpath_query"` - EventFile string `yaml:"event_file"` + EventFile string PrettyName string `yaml:"pretty_name"` } @@ -47,10 +50,13 @@ type QueryList struct { } type Select struct { - Path string `xml:"Path,attr"` + Path string `xml:"Path,attr,omitempty"` Query string `xml:",chardata"` } +// 0 identifies the local machine in windows APIs +const localMachine = 0 + var linesRead = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "cs_winevtlogsource_hits_total", @@ -77,7 +83,7 @@ func logLevelToInt(logLevel string) ([]string, error) { // This is lifted from winops/winlog, but we only want to render the basic XML string, we don't need the extra fluff func (w *WinEventLogSource) getXMLEvents(config *winlog.SubscribeConfig, publisherCache map[string]windows.Handle, resultSet windows.Handle, maxEvents int) ([]string, error) { - var events = make([]windows.Handle, maxEvents) + events := make([]windows.Handle, maxEvents) var returned uint32 // Get handles to events from the result set. @@ -88,7 +94,7 @@ func (w *WinEventLogSource) getXMLEvents(config *winlog.SubscribeConfig, publish 2000, // Timeout in milliseconds to wait. 0, // Reserved. Must be zero. &returned) // The number of handles in the array that are set by the API. - if err == windows.ERROR_NO_MORE_ITEMS { + if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) { return nil, err } else if err != nil { return nil, fmt.Errorf("wevtapi.EvtNext failed: %v", err) @@ -149,7 +155,7 @@ func (w *WinEventLogSource) buildXpathQuery() (string, error) { queryList := QueryList{Select: Select{Path: w.config.EventChannel, Query: query}} xpathQuery, err := xml.Marshal(queryList) if err != nil { - w.logger.Errorf("Marshal failed: %v", err) + w.logger.Errorf("Serialize failed: %v", err) return "", err } w.logger.Debugf("xpathQuery: %s", xpathQuery) @@ -182,7 +188,7 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error } if status == syscall.WAIT_OBJECT_0 { renderedEvents, err := w.getXMLEvents(w.evtConfig, publisherCache, subscription, 500) - if err == windows.ERROR_NO_MORE_ITEMS { + if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) { windows.ResetEvent(w.evtConfig.SignalEvent) } else if err != nil { w.logger.Errorf("getXMLEvents failed: %v", err) @@ -200,9 +206,9 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error l.Src = w.name l.Process = true if !w.config.UseTimeMachine { - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE} + out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE, Unmarshaled: make(map[string]interface{})} } else { - out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE} + out <- types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE, Unmarshaled: make(map[string]interface{})} } } } @@ -211,20 +217,28 @@ func (w *WinEventLogSource) getEvents(out chan types.Event, t *tomb.Tomb) error } } -func (w *WinEventLogSource) generateConfig(query string) (*winlog.SubscribeConfig, error) { +func (w *WinEventLogSource) generateConfig(query string, live bool) (*winlog.SubscribeConfig, error) { var config winlog.SubscribeConfig var err error - // Create a subscription signaler. - config.SignalEvent, err = windows.CreateEvent( - nil, // Default security descriptor. - 1, // Manual reset. - 1, // Initial state is signaled. - nil) // Optional name. - if err != nil { - return &config, fmt.Errorf("windows.CreateEvent failed: %v", err) + if live { + // Create a subscription signaler. + config.SignalEvent, err = windows.CreateEvent( + nil, // Default security descriptor. + 1, // Manual reset. + 1, // Initial state is signaled. + nil) // Optional name. + if err != nil { + return &config, fmt.Errorf("windows.CreateEvent failed: %v", err) + } + config.Flags = wevtapi.EvtSubscribeToFutureEvents + } else { + config.ChannelPath, err = syscall.UTF16PtrFromString(w.config.EventFile) + if err != nil { + return &config, fmt.Errorf("syscall.UTF16PtrFromString failed: %v", err) + } + config.Flags = wevtapi.EvtQueryFilePath | wevtapi.EvtQueryForwardDirection } - config.Flags = wevtapi.EvtSubscribeToFutureEvents config.Query, err = syscall.UTF16PtrFromString(query) if err != nil { return &config, fmt.Errorf("syscall.UTF16PtrFromString failed: %v", err) @@ -282,7 +296,7 @@ func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry, Metr return err } - w.evtConfig, err = w.generateConfig(w.query) + w.evtConfig, err = w.generateConfig(w.query, true) if err != nil { return err } @@ -291,6 +305,78 @@ func (w *WinEventLogSource) Configure(yamlConfig []byte, logger *log.Entry, Metr } func (w *WinEventLogSource) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry, uuid string) error { + if !strings.HasPrefix(dsn, "wineventlog://") { + return fmt.Errorf("invalid DSN %s for wineventlog source, must start with wineventlog://", dsn) + } + + w.logger = logger + w.config = WinEventLogConfiguration{} + + dsn = strings.TrimPrefix(dsn, "wineventlog://") + + args := strings.Split(dsn, "?") + + if args[0] == "" { + return errors.New("empty wineventlog:// DSN") + } + + if len(args) > 2 { + return errors.New("too many arguments in DSN") + } + + w.config.EventFile = args[0] + + if len(args) == 2 && args[1] != "" { + params, err := url.ParseQuery(args[1]) + if err != nil { + return fmt.Errorf("failed to parse DSN parameters: %w", err) + } + + for key, value := range params { + switch key { + case "log_level": + if len(value) != 1 { + return errors.New("log_level must be a single value") + } + lvl, err := log.ParseLevel(value[0]) + if err != nil { + return fmt.Errorf("failed to parse log_level: %s", err) + } + w.logger.Logger.SetLevel(lvl) + case "event_id": + for _, id := range value { + evtid, err := strconv.Atoi(id) + if err != nil { + return fmt.Errorf("failed to parse event_id: %s", err) + } + w.config.EventIDs = append(w.config.EventIDs, evtid) + } + case "event_level": + if len(value) != 1 { + return errors.New("event_level must be a single value") + } + w.config.EventLevel = value[0] + } + } + } + + var err error + + // FIXME: handle custom xpath query + w.query, err = w.buildXpathQuery() + + if err != nil { + return fmt.Errorf("buildXpathQuery failed: %w", err) + } + + w.logger.Debugf("query: %s\n", w.query) + + w.evtConfig, err = w.generateConfig(w.query, false) + + if err != nil { + return fmt.Errorf("generateConfig failed: %w", err) + } + return nil } @@ -299,10 +385,58 @@ func (w *WinEventLogSource) GetMode() string { } func (w *WinEventLogSource) SupportedModes() []string { - return []string{configuration.TAIL_MODE} + return []string{configuration.TAIL_MODE, configuration.CAT_MODE} } -func (w *WinEventLogSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) OneShotAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { + handle, err := wevtapi.EvtQuery(localMachine, w.evtConfig.ChannelPath, w.evtConfig.Query, w.evtConfig.Flags) + if err != nil { + return fmt.Errorf("EvtQuery failed: %v", err) + } + + defer winlog.Close(handle) + + publisherCache := make(map[string]windows.Handle) + defer func() { + for _, h := range publisherCache { + winlog.Close(h) + } + }() + +OUTER_LOOP: + for { + select { + case <-t.Dying(): + w.logger.Infof("wineventlog is dying") + return nil + default: + evts, err := w.getXMLEvents(w.evtConfig, publisherCache, handle, 500) + if errors.Is(err, windows.ERROR_NO_MORE_ITEMS) { + log.Info("No more items") + break OUTER_LOOP + } else if err != nil { + return fmt.Errorf("getXMLEvents failed: %v", err) + } + w.logger.Debugf("Got %d events", len(evts)) + for _, evt := range evts { + w.logger.Tracef("Event: %s", evt) + if w.metricsLevel != configuration.METRICS_NONE { + linesRead.With(prometheus.Labels{"source": w.name}).Inc() + } + l := types.Line{} + l.Raw = evt + l.Module = w.GetName() + l.Labels = w.config.Labels + l.Time = time.Now() + l.Src = w.name + l.Process = true + csevt := types.MakeEvent(w.config.UseTimeMachine, types.LOG, true) + csevt.Line = l + out <- csevt + } + } + } + return nil } @@ -325,7 +459,7 @@ func (w *WinEventLogSource) CanRun() error { return nil } -func (w *WinEventLogSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { +func (w *WinEventLogSource) StreamingAcquisition(ctx context.Context, out chan types.Event, t *tomb.Tomb) error { t.Go(func() error { defer trace.CatchPanic("crowdsec/acquis/wineventlog/streaming") return w.getEvents(out, t) diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go similarity index 71% rename from pkg/acquisition/modules/wineventlog/wineventlog_test.go rename to pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go index 2ea0e365be5..2f6fe15450f 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go @@ -3,7 +3,7 @@ package wineventlogacquisition import ( - "runtime" + "context" "testing" "time" @@ -18,9 +18,8 @@ import ( ) func TestBadConfiguration(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + tests := []struct { config string expectedErr string @@ -63,9 +62,8 @@ xpath_query: test`, } func TestQueryBuilder(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + tests := []struct { config string expectedQuery string @@ -129,9 +127,8 @@ event_level: bla`, } func TestLiveAcquisition(t *testing.T) { - if runtime.GOOS != "windows" { - t.Skip("Skipping test on non-windows OS") - } + exprhelpers.Init(nil) + ctx := context.Background() tests := []struct { config string @@ -180,7 +177,6 @@ event_ids: subLogger := log.WithField("type", "windowseventlog") evthandler, err := eventlog.Open("Application") - if err != nil { t.Fatalf("failed to open event log: %s", err) } @@ -190,7 +186,7 @@ event_ids: c := make(chan types.Event) f := WinEventLogSource{} f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE) - f.StreamingAcquisition(c, to) + f.StreamingAcquisition(ctx, c, to) time.Sleep(time.Second) lines := test.expectedLines go func() { @@ -225,3 +221,83 @@ event_ids: to.Wait() } } + +func TestOneShotAcquisition(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + dsn string + expectedCount int + expectedErr string + expectedConfigureErr string + }{ + { + name: "non-existing file", + dsn: `wineventlog://foo.evtx`, + expectedCount: 0, + expectedErr: "The system cannot find the file specified.", + }, + { + name: "empty DSN", + dsn: `wineventlog://`, + expectedCount: 0, + expectedConfigureErr: "empty wineventlog:// DSN", + }, + { + name: "existing file", + dsn: `wineventlog://test_files/Setup.evtx`, + expectedCount: 24, + expectedErr: "", + }, + { + name: "filter on event_id", + dsn: `wineventlog://test_files/Setup.evtx?event_id=2`, + expectedCount: 1, + }, + { + name: "filter on event_id", + dsn: `wineventlog://test_files/Setup.evtx?event_id=2&event_id=3`, + expectedCount: 24, + }, + } + + exprhelpers.Init(nil) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + lineCount := 0 + to := &tomb.Tomb{} + c := make(chan types.Event) + f := WinEventLogSource{} + err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "wineventlog"}, log.WithField("type", "windowseventlog"), "") + + if test.expectedConfigureErr != "" { + assert.Contains(t, err.Error(), test.expectedConfigureErr) + return + } + + require.NoError(t, err) + + go func() { + for { + select { + case <-c: + lineCount++ + case <-to.Dying(): + return + } + } + }() + + err = f.OneShotAcquisition(ctx, c, to) + if test.expectedErr != "" { + assert.Contains(t, err.Error(), test.expectedErr) + } else { + require.NoError(t, err) + + time.Sleep(2 * time.Second) + assert.Equal(t, test.expectedCount, lineCount) + } + }) + } +} diff --git a/pkg/acquisition/s3.go b/pkg/acquisition/s3.go new file mode 100644 index 00000000000..73343b0408d --- /dev/null +++ b/pkg/acquisition/s3.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_s3 + +package acquisition + +import ( + s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("s3", func() DataSource { return &s3acquisition.S3Source{} }) +} diff --git a/pkg/acquisition/syslog.go b/pkg/acquisition/syslog.go new file mode 100644 index 00000000000..f62cc23b916 --- /dev/null +++ b/pkg/acquisition/syslog.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_syslog + +package acquisition + +import ( + syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("syslog", func() DataSource { return &syslogacquisition.SyslogSource{} }) +} diff --git a/pkg/acquisition/wineventlog.go b/pkg/acquisition/wineventlog.go new file mode 100644 index 00000000000..0c4889a3f5c --- /dev/null +++ b/pkg/acquisition/wineventlog.go @@ -0,0 +1,12 @@ +//go:build !no_datasource_wineventlog + +package acquisition + +import ( + wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog" +) + +//nolint:gochecknoinits +func init() { + registerDataSource("wineventlog", func() DataSource { return &wineventlogacquisition.WinEventLogSource{} }) +} diff --git a/pkg/alertcontext/alertcontext.go b/pkg/alertcontext/alertcontext.go index c502def32cd..1b7d1e20018 100644 --- a/pkg/alertcontext/alertcontext.go +++ b/pkg/alertcontext/alertcontext.go @@ -3,6 +3,7 @@ package alertcontext import ( "encoding/json" "fmt" + "net/http" "slices" "strconv" @@ -30,9 +31,12 @@ type Context struct { func ValidateContextExpr(key string, expressions []string) error { for _, expression := range expressions { - _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + _, err := expr.Compile(expression, exprhelpers.GetExprOptions(map[string]interface{}{ + "evt": &types.Event{}, + "match": &types.MatchedRule{}, + "req": &http.Request{}})...) if err != nil { - return fmt.Errorf("compilation of '%s' failed: %v", expression, err) + return fmt.Errorf("compilation of '%s' failed: %w", expression, err) } } @@ -72,9 +76,12 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error { } for _, value := range values { - valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) + valueCompiled, err := expr.Compile(value, exprhelpers.GetExprOptions(map[string]interface{}{ + "evt": &types.Event{}, + "match": &types.MatchedRule{}, + "req": &http.Request{}})...) if err != nil { - return fmt.Errorf("compilation of '%s' context value failed: %v", value, err) + return fmt.Errorf("compilation of '%s' context value failed: %w", value, err) } alertContext.ContextToSendCompiled[key] = append(alertContext.ContextToSendCompiled[key], valueCompiled) @@ -85,6 +92,32 @@ func NewAlertContext(contextToSend map[string][]string, valueLength int) error { return nil } +// Truncate the context map to fit in the context value length +func TruncateContextMap(contextMap map[string][]string, contextValueLen int) ([]*models.MetaItems0, []error) { + metas := make([]*models.MetaItems0, 0) + errors := make([]error, 0) + + for key, values := range contextMap { + if len(values) == 0 { + continue + } + + valueStr, err := TruncateContext(values, alertContext.ContextValueLen) + if err != nil { + errors = append(errors, fmt.Errorf("error truncating content for %s: %w", key, err)) + continue + } + + meta := models.MetaItems0{ + Key: key, + Value: valueStr, + } + metas = append(metas, &meta) + } + return metas, errors +} + +// Truncate an individual []string to fit in the context value length func TruncateContext(values []string, contextValueLen int) (string, error) { valueByte, err := json.Marshal(values) if err != nil { @@ -116,61 +149,102 @@ func TruncateContext(values []string, contextValueLen int) (string, error) { return ret, nil } -func EventToContext(events []types.Event) (models.Meta, []error) { +func EvalAlertContextRules(evt types.Event, match *types.MatchedRule, request *http.Request, tmpContext map[string][]string) []error { + var errors []error - metas := make([]*models.MetaItems0, 0) - tmpContext := make(map[string][]string) + //if we're evaluating context for appsec event, match and request will be present. + //otherwise, only evt will be. + if match == nil { + match = types.NewMatchedRule() + } + if request == nil { + request = &http.Request{} + } - for _, evt := range events { - for key, values := range alertContext.ContextToSendCompiled { - if _, ok := tmpContext[key]; !ok { - tmpContext[key] = make([]string, 0) - } + for key, values := range alertContext.ContextToSendCompiled { - for _, value := range values { - var val string + if _, ok := tmpContext[key]; !ok { + tmpContext[key] = make([]string, 0) + } - output, err := expr.Run(value, map[string]interface{}{"evt": evt}) - if err != nil { - errors = append(errors, fmt.Errorf("failed to get value for %s : %v", key, err)) - continue - } + for _, value := range values { + var val string - switch out := output.(type) { - case string: - val = out - case int: - val = strconv.Itoa(out) - default: - errors = append(errors, fmt.Errorf("unexpected return type for %s : %T", key, output)) - continue + output, err := expr.Run(value, map[string]interface{}{"match": match, "evt": evt, "req": request}) + if err != nil { + errors = append(errors, fmt.Errorf("failed to get value for %s: %w", key, err)) + continue + } + switch out := output.(type) { + case string: + val = out + if val != "" && !slices.Contains(tmpContext[key], val) { + tmpContext[key] = append(tmpContext[key], val) } - + case []string: + for _, v := range out { + if v != "" && !slices.Contains(tmpContext[key], v) { + tmpContext[key] = append(tmpContext[key], v) + } + } + case int: + val = strconv.Itoa(out) + if val != "" && !slices.Contains(tmpContext[key], val) { + tmpContext[key] = append(tmpContext[key], val) + } + case []int: + for _, v := range out { + val = strconv.Itoa(v) + if val != "" && !slices.Contains(tmpContext[key], val) { + tmpContext[key] = append(tmpContext[key], val) + } + } + default: + val := fmt.Sprintf("%v", output) if val != "" && !slices.Contains(tmpContext[key], val) { tmpContext[key] = append(tmpContext[key], val) } } } } + return errors +} - for key, values := range tmpContext { - if len(values) == 0 { - continue - } +// Iterate over the individual appsec matched rules to create the needed alert context. +func AppsecEventToContext(event types.AppsecEvent, request *http.Request) (models.Meta, []error) { + var errors []error - valueStr, err := TruncateContext(values, alertContext.ContextValueLen) - if err != nil { - log.Warningf(err.Error()) - } + tmpContext := make(map[string][]string) - meta := models.MetaItems0{ - Key: key, - Value: valueStr, - } - metas = append(metas, &meta) + evt := types.MakeEvent(false, types.LOG, false) + for _, matched_rule := range event.MatchedRules { + tmpErrors := EvalAlertContextRules(evt, &matched_rule, request, tmpContext) + errors = append(errors, tmpErrors...) } + metas, truncErrors := TruncateContextMap(tmpContext, alertContext.ContextValueLen) + errors = append(errors, truncErrors...) + + ret := models.Meta(metas) + + return ret, errors +} + +// Iterate over the individual events to create the needed alert context. +func EventToContext(events []types.Event) (models.Meta, []error) { + var errors []error + + tmpContext := make(map[string][]string) + + for _, evt := range events { + tmpErrors := EvalAlertContextRules(evt, nil, nil, tmpContext) + errors = append(errors, tmpErrors...) + } + + metas, truncErrors := TruncateContextMap(tmpContext, alertContext.ContextValueLen) + errors = append(errors, truncErrors...) + ret := models.Meta(metas) return ret, errors diff --git a/pkg/alertcontext/alertcontext_test.go b/pkg/alertcontext/alertcontext_test.go index c111d1bbcfb..dc752ba8b09 100644 --- a/pkg/alertcontext/alertcontext_test.go +++ b/pkg/alertcontext/alertcontext_test.go @@ -2,6 +2,7 @@ package alertcontext import ( "fmt" + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -9,6 +10,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestNewAlertContext(t *testing.T) { @@ -200,3 +202,163 @@ func TestEventToContext(t *testing.T) { assert.ElementsMatch(t, test.expectedResult, metas) } } + +func TestValidateContextExpr(t *testing.T) { + tests := []struct { + name string + key string + exprs []string + expectedErr *string + }{ + { + name: "basic config", + key: "source_ip", + exprs: []string{ + "evt.Parsed.source_ip", + }, + expectedErr: nil, + }, + { + name: "basic config with non existent field", + key: "source_ip", + exprs: []string{ + "evt.invalid.source_ip", + }, + expectedErr: ptr.Of("compilation of 'evt.invalid.source_ip' failed: type types.Event has no field invalid"), + }, + } + for _, test := range tests { + fmt.Printf("Running test '%s'\n", test.name) + err := ValidateContextExpr(test.key, test.exprs) + if test.expectedErr == nil { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, *test.expectedErr) + } + } +} + +func TestAppsecEventToContext(t *testing.T) { + + tests := []struct { + name string + contextToSend map[string][]string + match types.AppsecEvent + req *http.Request + expectedResult models.Meta + expectedErrLen int + }{ + { + name: "basic test on match", + contextToSend: map[string][]string{ + "id": {"match.id"}, + }, + match: types.AppsecEvent{ + MatchedRules: types.MatchedRules{ + { + "id": "test", + }, + }, + }, + req: &http.Request{}, + expectedResult: []*models.MetaItems0{ + { + Key: "id", + Value: "[\"test\"]", + }, + }, + expectedErrLen: 0, + }, + { + name: "basic test on req", + contextToSend: map[string][]string{ + "ua": {"req.UserAgent()"}, + }, + match: types.AppsecEvent{ + MatchedRules: types.MatchedRules{ + { + "id": "test", + }, + }, + }, + req: &http.Request{ + Header: map[string][]string{ + "User-Agent": {"test"}, + }, + }, + expectedResult: []*models.MetaItems0{ + { + Key: "ua", + Value: "[\"test\"]", + }, + }, + expectedErrLen: 0, + }, + { + name: "test on req -> []string", + contextToSend: map[string][]string{ + "foobarxx": {"req.Header.Values('Foobar')"}, + }, + match: types.AppsecEvent{ + MatchedRules: types.MatchedRules{ + { + "id": "test", + }, + }, + }, + req: &http.Request{ + Header: map[string][]string{ + "User-Agent": {"test"}, + "Foobar": {"test1", "test2"}, + }, + }, + expectedResult: []*models.MetaItems0{ + { + Key: "foobarxx", + Value: "[\"test1\",\"test2\"]", + }, + }, + expectedErrLen: 0, + }, + { + name: "test on type int", + contextToSend: map[string][]string{ + "foobarxx": {"len(req.Header.Values('Foobar'))"}, + }, + match: types.AppsecEvent{ + MatchedRules: types.MatchedRules{ + { + "id": "test", + }, + }, + }, + req: &http.Request{ + Header: map[string][]string{ + "User-Agent": {"test"}, + "Foobar": {"test1", "test2"}, + }, + }, + expectedResult: []*models.MetaItems0{ + { + Key: "foobarxx", + Value: "[\"2\"]", + }, + }, + expectedErrLen: 0, + }, + } + + for _, test := range tests { + //reset cache + alertContext = Context{} + //compile + if err := NewAlertContext(test.contextToSend, 100); err != nil { + t.Fatalf("failed to compile %s: %s", test.name, err) + } + //run + + metas, errors := AppsecEventToContext(test.match, test.req) + assert.Len(t, errors, test.expectedErrLen) + assert.ElementsMatch(t, test.expectedResult, metas) + } +} diff --git a/pkg/alertcontext/config.go b/pkg/alertcontext/config.go index 21d16db3972..6ef877619e4 100644 --- a/pkg/alertcontext/config.go +++ b/pkg/alertcontext/config.go @@ -98,20 +98,14 @@ func addContextFromFile(toSend map[string][]string, filePath string) error { return nil } - // LoadConsoleContext loads the context from the hub (if provided) and the file console_context_path. func LoadConsoleContext(c *csconfig.Config, hub *cwhub.Hub) error { c.Crowdsec.ContextToSend = make(map[string][]string, 0) if hub != nil { - items, err := hub.GetInstalledItemsByType(cwhub.CONTEXTS) - if err != nil { - return err - } - - for _, item := range items { + for _, item := range hub.GetInstalledByType(cwhub.CONTEXTS, true) { // context in item files goes under the key 'context' - if err = addContextFromItem(c.Crowdsec.ContextToSend, item); err != nil { + if err := addContextFromItem(c.Crowdsec.ContextToSend, item); err != nil { return err } } @@ -139,7 +133,7 @@ func LoadConsoleContext(c *csconfig.Config, hub *cwhub.Hub) error { feedback, err := json.Marshal(c.Crowdsec.ContextToSend) if err != nil { - return fmt.Errorf("marshaling console context: %s", err) + return fmt.Errorf("serializing console context: %s", err) } log.Debugf("console context to send: %s", feedback) diff --git a/pkg/apiclient/alerts_service_test.go b/pkg/apiclient/alerts_service_test.go index 12ef2d295f4..0d1ff41685f 100644 --- a/pkg/apiclient/alerts_service_test.go +++ b/pkg/apiclient/alerts_service_test.go @@ -14,7 +14,6 @@ import ( "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -35,7 +34,6 @@ func TestAlertsListAsMachine(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -180,16 +178,16 @@ func TestAlertsListAsMachine(t *testing.T) { }, } - //log.Debugf("data : -> %s", spew.Sdump(alerts)) - //log.Debugf("resp : -> %s", spew.Sdump(resp)) - //log.Debugf("expected : -> %s", spew.Sdump(expected)) - //first one returns data + // log.Debugf("data : -> %s", spew.Sdump(alerts)) + // log.Debugf("resp : -> %s", spew.Sdump(resp)) + // log.Debugf("expected : -> %s", spew.Sdump(expected)) + // first one returns data alerts, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, expected, *alerts) - //this one doesn't + // this one doesn't filter := AlertsListOpts{IPEquals: ptr.Of("1.2.3.4")} alerts, resp, err = client.Alerts.List(context.Background(), filter) @@ -214,7 +212,6 @@ func TestAlertsGetAsMachine(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -360,7 +357,7 @@ func TestAlertsGetAsMachine(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, *expected, *alerts) - //fail + // fail _, _, err = client.Alerts.GetByID(context.Background(), 2) cstest.RequireErrorMessage(t, err, "API error: object not found") } @@ -388,7 +385,6 @@ func TestAlertsCreateAsMachine(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -430,7 +426,6 @@ func TestAlertsDeleteAsMachine(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) diff --git a/pkg/apiclient/auth_jwt.go b/pkg/apiclient/auth_jwt.go index b202e382842..193486ff065 100644 --- a/pkg/apiclient/auth_jwt.go +++ b/pkg/apiclient/auth_jwt.go @@ -2,6 +2,7 @@ package apiclient import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -30,15 +31,17 @@ type JWTTransport struct { // Transport is the underlying HTTP transport to use when making requests. // It will default to http.DefaultTransport if nil. Transport http.RoundTripper - UpdateScenario func() ([]string, error) + UpdateScenario func(context.Context) ([]string, error) refreshTokenMutex sync.Mutex } func (t *JWTTransport) refreshJwtToken() error { var err error + ctx := context.TODO() + if t.UpdateScenario != nil { - t.Scenarios, err = t.UpdateScenario() + t.Scenarios, err = t.UpdateScenario(ctx) if err != nil { return fmt.Errorf("can't update scenario list: %w", err) } diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index 3e887149a98..d22c9394014 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -36,11 +35,13 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) newStr := buf.String() var payload BasicMockPayload + err := json.Unmarshal([]byte(newStr), &payload) if err != nil || payload.MachineID == "" || payload.Password == "" { log.Printf("Bad payload") @@ -48,8 +49,8 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { } var responseBody string - responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] + responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] if !hasFoundErrorMock { responseCode = http.StatusOK responseBody = `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}` @@ -76,7 +77,7 @@ func TestWatcherRegister(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers") log.Printf("URL is %s", urlx) @@ -87,12 +88,13 @@ func TestWatcherRegister(t *testing.T) { clientconfig := Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", } - client, err := RegisterClient(&clientconfig, &http.Client{}) + ctx := context.Background() + + client, err := RegisterClient(ctx, &clientconfig, &http.Client{}) require.NoError(t, err) log.Printf("->%T", client) @@ -102,7 +104,7 @@ func TestWatcherRegister(t *testing.T) { for _, errorCodeToTest := range errorCodesToTest { clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) - client, err = RegisterClient(&clientconfig, &http.Client{}) + client, err = RegisterClient(ctx, &clientconfig, &http.Client{}) require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest) require.Error(t, err, "error expected for the response code %d", errorCodeToTest) } @@ -113,7 +115,7 @@ func TestWatcherAuth(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers/login") log.Printf("URL is %s", urlx) @@ -121,11 +123,10 @@ func TestWatcherAuth(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok auth + // ok auth clientConfig := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, @@ -161,7 +162,7 @@ func TestWatcherAuth(t *testing.T) { bodyBytes, err := io.ReadAll(resp.Response.Body) require.NoError(t, err) - log.Printf(string(bodyBytes)) + log.Print(string(bodyBytes)) t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest) } @@ -174,7 +175,7 @@ func TestWatcherUnregister(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + // body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") @@ -184,6 +185,7 @@ func TestWatcherUnregister(t *testing.T) { mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) @@ -206,7 +208,6 @@ func TestWatcherUnregister(t *testing.T) { mycfg := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, @@ -229,6 +230,7 @@ func TestWatcherEnroll(t *testing.T) { mux.HandleFunc("/watchers/enroll", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") + buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) newStr := buf.String() @@ -260,7 +262,6 @@ func TestWatcherEnroll(t *testing.T) { mycfg := &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index 2cb68f597f3..47d97a28344 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -12,6 +12,7 @@ import ( "github.com/golang-jwt/jwt/v4" + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -66,11 +67,16 @@ type service struct { } func NewClient(config *Config) (*ApiClient, error) { + userAgent := config.UserAgent + if userAgent == "" { + userAgent = useragent.Default() + } + t := &JWTTransport{ MachineID: &config.MachineID, Password: &config.Password, Scenarios: config.Scenarios, - UserAgent: config.UserAgent, + UserAgent: userAgent, VersionPrefix: config.VersionPrefix, UpdateScenario: config.UpdateScenario, RetryConfig: NewRetryConfig( @@ -105,7 +111,7 @@ func NewClient(config *Config) (*ApiClient, error) { t.Transport.(*http.Transport).TLSClientConfig = &tlsconfig } - c := &ApiClient{client: t.Client(), BaseURL: baseURL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL} + c := &ApiClient{client: t.Client(), BaseURL: baseURL, UserAgent: userAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) @@ -143,6 +149,10 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt } } + if userAgent == "" { + userAgent = useragent.Default() + } + c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: userAgent, URLPrefix: prefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) @@ -157,7 +167,7 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt return c, nil } -func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { +func RegisterClient(ctx context.Context, config *Config, client *http.Client) (*ApiClient, error) { transport, baseURL := createTransport(config.URL) if client == nil { @@ -178,15 +188,20 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { client.Transport = transport } - c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} + userAgent := config.UserAgent + if userAgent == "" { + userAgent = useragent.Default() + } + + c := &ApiClient{client: client, BaseURL: baseURL, UserAgent: userAgent, URLPrefix: config.VersionPrefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) c.Alerts = (*AlertsService)(&c.common) c.Auth = (*AuthService)(&c.common) - resp, err := c.Auth.RegisterWatcher(context.Background(), models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password, RegistrationToken: config.RegistrationToken}) - /*if we have http status, return it*/ + resp, err := c.Auth.RegisterWatcher(ctx, models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password, RegistrationToken: config.RegistrationToken}) if err != nil { + /*if we have http status, return it*/ if resp != nil && resp.Response != nil { return nil, fmt.Errorf("api register (%s) http %s: %w", c.BaseURL, resp.Response.Status, err) } diff --git a/pkg/apiclient/client_http.go b/pkg/apiclient/client_http.go index 0240618f535..eeca929ea6e 100644 --- a/pkg/apiclient/client_http.go +++ b/pkg/apiclient/client_http.go @@ -61,9 +61,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* req.Header.Add("User-Agent", c.UserAgent) } - if log.GetLevel() >= log.DebugLevel { - log.Debugf("[URL] %s %s", req.Method, req.URL) - } + log.Debugf("[URL] %s %s", req.Method, req.URL) resp, err := c.client.Do(req) if resp != nil && resp.Body != nil { diff --git a/pkg/apiclient/client_http_test.go b/pkg/apiclient/client_http_test.go index 4bdfe1d0da5..45cd8410a8e 100644 --- a/pkg/apiclient/client_http_test.go +++ b/pkg/apiclient/client_http_test.go @@ -10,22 +10,19 @@ import ( "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/cstest" - - "github.com/crowdsecurity/crowdsec/pkg/cwversion" ) func TestNewRequestInvalid(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //missing slash in uri + // missing slash in uri apiURL, err := url.Parse(urlx) require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -57,7 +54,6 @@ func TestNewRequestTimeout(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) diff --git a/pkg/apiclient/client_test.go b/pkg/apiclient/client_test.go index bd83e512afc..d1f58f33ad2 100644 --- a/pkg/apiclient/client_test.go +++ b/pkg/apiclient/client_test.go @@ -17,8 +17,6 @@ import ( "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/cstest" - - "github.com/crowdsecurity/crowdsec/pkg/cwversion" ) /*this is a ripoff of google/go-github approach : @@ -97,7 +95,6 @@ func TestNewClientOk(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -134,7 +131,6 @@ func TestNewClientOk_UnixSocket(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -172,7 +168,6 @@ func TestNewClientKo(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -247,10 +242,11 @@ func TestNewClientRegisterKO(t *testing.T) { apiURL, err := url.Parse("http://127.0.0.1:4242/") require.NoError(t, err) - _, err = RegisterClient(&Config{ + ctx := context.Background() + + _, err = RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) @@ -278,10 +274,11 @@ func TestNewClientRegisterOK(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - client, err := RegisterClient(&Config{ + ctx := context.Background() + + client, err := RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) @@ -311,10 +308,11 @@ func TestNewClientRegisterOK_UnixSocket(t *testing.T) { t.Fatalf("parsing api url: %s", apiURL) } - client, err := RegisterClient(&Config{ + ctx := context.Background() + + client, err := RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) @@ -341,10 +339,11 @@ func TestNewClientBadAnswer(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - _, err = RegisterClient(&Config{ + ctx := context.Background() + + _, err = RegisterClient(ctx, &Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) diff --git a/pkg/apiclient/config.go b/pkg/apiclient/config.go index b08452e74e0..29a8acf185e 100644 --- a/pkg/apiclient/config.go +++ b/pkg/apiclient/config.go @@ -1,6 +1,7 @@ package apiclient import ( + "context" "net/url" "github.com/go-openapi/strfmt" @@ -15,5 +16,5 @@ type Config struct { VersionPrefix string UserAgent string RegistrationToken string - UpdateScenario func() ([]string, error) + UpdateScenario func(context.Context) ([]string, error) } diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index 388a870f999..fea2f39072d 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -31,6 +31,8 @@ type DecisionsListOpts struct { type DecisionsStreamOpts struct { Startup bool `url:"startup,omitempty"` + CommunityPull bool `url:"community_pull"` + AdditionalPull bool `url:"additional_pull"` Scopes string `url:"scopes,omitempty"` ScenariosContaining string `url:"scenarios_containing,omitempty"` ScenariosNotContaining string `url:"scenarios_not_containing,omitempty"` @@ -43,6 +45,17 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) { return "", err } + //Those 2 are a bit different + //They default to true, and we only want to include them if they are false + + if params.Get("community_pull") == "true" { + params.Del("community_pull") + } + + if params.Get("additional_pull") == "true" { + params.Del("additional_pull") + } + return fmt.Sprintf("%s?%s", url, params.Encode()), nil } @@ -144,7 +157,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions)) for idx, decision := range decisionsGroup.Decisions { - decision := decision // fix exportloopref linter message + decision := decision //nolint:copyloopvar // fix exportloopref linter message partialDecisions[idx] = &models.Decision{ Scenario: &scenarioDeleted, Scope: decisionsGroup.Scope, diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index 6942cfc9d85..942d14689ff 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "strings" "testing" log "github.com/sirupsen/logrus" @@ -13,7 +14,6 @@ import ( "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/modelscapi" ) @@ -26,6 +26,7 @@ func TestDecisionsList(t *testing.T) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") + if r.URL.RawQuery == "ip=1.2.3.4" { assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) @@ -34,14 +35,14 @@ func TestDecisionsList(t *testing.T) { } else { w.WriteHeader(http.StatusOK) w.Write([]byte(`null`)) - //no results + // no results } }) apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -68,7 +69,7 @@ func TestDecisionsList(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, *expected, *decisions) - //Empty return + // Empty return decisionsFilter = DecisionsListOpts{IPEquals: ptr.Of("1.2.3.5")} decisions, resp, err = newcli.Decisions.List(context.Background(), decisionsFilter) require.NoError(t, err) @@ -85,8 +86,9 @@ func TestDecisionsStream(t *testing.T) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { - if r.URL.RawQuery == "startup=true" { + if strings.Contains(r.URL.RawQuery, "startup=true") { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"deleted":null,"new":[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]}`)) } else { @@ -99,6 +101,7 @@ func TestDecisionsStream(t *testing.T) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodDelete) + if r.Method == http.MethodDelete { w.WriteHeader(http.StatusOK) } @@ -107,7 +110,7 @@ func TestDecisionsStream(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -134,14 +137,14 @@ func TestDecisionsStream(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Equal(t, *expected, *decisions) - //and second call, we get empty lists + // and second call, we get empty lists decisions, resp, err = newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: false}) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Empty(t, decisions.New) assert.Empty(t, decisions.Deleted) - //delete stream + // delete stream resp, err = newcli.Decisions.StopStream(context.Background()) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.Response.StatusCode) @@ -156,8 +159,9 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { - if r.URL.RawQuery == "startup=true" { + if strings.Contains(r.URL.RawQuery, "startup=true") { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"deleted":[{"scope":"ip","decisions":["1.2.3.5"]}],"new":[{"scope":"ip", "scenario": "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'", "decisions":[{"duration":"3h59m55.756182786s","value":"1.2.3.4"}]}]}`)) } else { @@ -170,7 +174,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -220,6 +224,7 @@ func TestDecisionsStreamV3(t *testing.T) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodGet) + if r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"deleted":[{"scope":"ip","decisions":["1.2.3.5"]}], @@ -231,7 +236,7 @@ func TestDecisionsStreamV3(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -305,7 +310,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { apiURL, err := url.Parse(urlx + "/") require.NoError(t, err) - //ok answer + // ok answer auth := &APIKeyTransport{ APIKey: "ixu", } @@ -391,7 +396,7 @@ func TestDeleteDecisions(t *testing.T) { assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"nbDeleted":"1"}`)) - //w.Write([]byte(`{"message":"0 deleted alerts"}`)) + // w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) log.Printf("URL is %s", urlx) @@ -402,7 +407,6 @@ func TestDeleteDecisions(t *testing.T) { client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", - UserAgent: cwversion.UserAgent(), URL: apiURL, VersionPrefix: "v1", }) @@ -426,6 +430,8 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { Scopes string ScenariosContaining string ScenariosNotContaining string + CommunityPull bool + AdditionalPull bool } tests := []struct { @@ -437,11 +443,17 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { { name: "no filter", expected: baseURLString + "?", + fields: fields{ + CommunityPull: true, + AdditionalPull: true, + }, }, { name: "startup=true", fields: fields{ - Startup: true, + Startup: true, + CommunityPull: true, + AdditionalPull: true, }, expected: baseURLString + "?startup=true", }, @@ -452,9 +464,19 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { Scopes: "ip,range", ScenariosContaining: "ssh", ScenariosNotContaining: "bf", + CommunityPull: true, + AdditionalPull: true, }, expected: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", }, + { + name: "pull options", + fields: fields{ + CommunityPull: false, + AdditionalPull: false, + }, + expected: baseURLString + "?additional_pull=false&community_pull=false", + }, } for _, tt := range tests { @@ -464,10 +486,13 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { Scopes: tt.fields.Scopes, ScenariosContaining: tt.fields.ScenariosContaining, ScenariosNotContaining: tt.fields.ScenariosNotContaining, + CommunityPull: tt.fields.CommunityPull, + AdditionalPull: tt.fields.AdditionalPull, } got, err := o.addQueryParamsToURL(baseURLString) cstest.RequireErrorContains(t, err, tt.expectedErr) + if tt.expectedErr != "" { return } @@ -502,7 +527,6 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { // client, err := NewClient(&Config{ // MachineID: "test_login", // Password: "test_password", -// UserAgent: cwversion.UserAgent(), // URL: apiURL, // VersionPrefix: "v1", // }) diff --git a/pkg/apiclient/resperr.go b/pkg/apiclient/resperr.go index 00689147332..1b0786f9882 100644 --- a/pkg/apiclient/resperr.go +++ b/pkg/apiclient/resperr.go @@ -19,7 +19,7 @@ func (e *ErrorResponse) Error() string { message := ptr.OrEmpty(e.Message) errors := "" - if len(e.Errors) > 0 { + if e.Errors != "" { errors = fmt.Sprintf(" (%s)", e.Errors) } @@ -51,7 +51,7 @@ func CheckResponse(r *http.Response) error { // try to unmarshal and if there are no 'message' or 'errors' fields, display the body as is, // the API is following a different convention err := json.Unmarshal(data, ret) - if err != nil || (ret.Message == nil && len(ret.Errors) == 0) { + if err != nil || (ret.Message == nil && ret.Errors == "") { ret.Message = ptr.Of(fmt.Sprintf("http code %d, response: %s", r.StatusCode, string(data))) return ret } diff --git a/pkg/apiclient/useragent/useragent.go b/pkg/apiclient/useragent/useragent.go new file mode 100644 index 00000000000..5a62ce1ac06 --- /dev/null +++ b/pkg/apiclient/useragent/useragent.go @@ -0,0 +1,9 @@ +package useragent + +import ( + "github.com/crowdsecurity/go-cs-lib/version" +) + +func Default() string { + return "crowdsec/" + version.String() + "-" + version.System +} diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 891eb3a8f4a..d86234e4813 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "net/http" @@ -25,11 +26,11 @@ type LAPI struct { DBConfig *csconfig.DatabaseCfg } -func SetupLAPITest(t *testing.T) LAPI { +func SetupLAPITest(t *testing.T, ctx context.Context) LAPI { t.Helper() - router, loginResp, config := InitMachineTest(t) + router, loginResp, config := InitMachineTest(t, ctx) - APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) + APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) return LAPI{ router: router, @@ -39,14 +40,14 @@ func SetupLAPITest(t *testing.T) LAPI { } } -func (l *LAPI) InsertAlertFromFile(t *testing.T, path string) *httptest.ResponseRecorder { +func (l *LAPI) InsertAlertFromFile(t *testing.T, ctx context.Context, path string) *httptest.ResponseRecorder { alertReader := GetAlertReaderFromFile(t, path) - return l.RecordResponse(t, http.MethodPost, "/v1/alerts", alertReader, "password") + return l.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertReader, "password") } -func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { +func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { w := httptest.NewRecorder() - req, err := http.NewRequest(verb, url, body) + req, err := http.NewRequestWithContext(ctx, verb, url, body) require.NoError(t, err) switch authType { @@ -58,24 +59,27 @@ func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strin t.Fatal("auth type not supported") } + // Port is required for gin to properly parse the client IP + req.RemoteAddr = "127.0.0.1:1234" + l.router.ServeHTTP(w, req) return w } -func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { - router, config := NewAPITest(t) - loginResp := LoginToTestAPI(t, router, config) +func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { + router, config := NewAPITest(t, ctx) + loginResp := LoginToTestAPI(t, ctx, router, config) return router, loginResp, config } -func LoginToTestAPI(t *testing.T, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { - body := CreateTestMachine(t, router, "") - ValidateMachine(t, "test", config.API.Server.DbConfig) +func LoginToTestAPI(t *testing.T, ctx context.Context, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { + body := CreateTestMachine(t, ctx, router, "") + ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -92,50 +96,55 @@ func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthRespon } func TestSimulatedAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk+simul.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk+simul.json") alertContent := GetAlertReaderFromFile(t, "./tests/alert_minibulk+simul.json") - //exclude decision in simulation mode + // exclude decision in simulation mode - w := lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) - //include decision in simulation mode + // include decision in simulation mode - w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", alertContent, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) } func TestCreateAlert(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Alert with invalid format - w := lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") assert.Equal(t, 400, w.Code) assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create Alert with invalid input alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json") - w = lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", alertContent, "password") + w = lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/alerts", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, w.Body.String()) + assert.Equal(t, + `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, + w.Body.String()) // Create Valid Alert - w = lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + w = lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assert.Equal(t, 201, w.Code) assert.Equal(t, `["1"]`, w.Body.String()) } func TestCreateAlertChannels(t *testing.T) { - apiServer, config := NewAPIServer(t) + ctx := context.Background() + apiServer, config := NewAPIServer(t, ctx) apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.InitController() - loginResp := LoginToTestAPI(t, apiServer.router, config) + loginResp := LoginToTestAPI(t, ctx, apiServer.router, config) lapi := LAPI{router: apiServer.router, loginResp: loginResp} var ( @@ -151,221 +160,225 @@ func TestCreateAlertChannels(t *testing.T) { wg.Done() }() - lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_ssh-bf.json") wg.Wait() assert.Len(t, pd.Alert.Decisions, 1) apiServer.Close() } func TestAlertListFilters(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_ssh-bf.json") alertContent := GetAlertReaderFromFile(t, "./tests/alert_ssh-bf.json") - //bad filter + // bad filter - w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", alertContent, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) - //get without filters + // get without filters - w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) - //check alert and decision + // check alert and decision assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test decision_type filter (ok) + // test decision_type filter (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test decision_type filter (bad value) + // test decision_type filter (bad value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test scope (ok) + // test scope (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=Ip", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=Ip", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test scope (bad value) + // test scope (bad value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=rarara", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scope=rarara", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test scenario (ok) + // test scenario (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test scenario (bad value) + // test scenario (bad value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test ip (ok) + // test ip (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test ip (bad value) + // test ip (bad value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test ip (invalid value) + // test ip (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) - //test range (ok) + // test range (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test range + // test range - w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test range (invalid value) + // test range (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?range=ratata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) - //test since (ok) + // test since (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1h", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1h", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test since (ok but yields no results) + // test since (ok but yields no results) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test since (invalid value) + // test since (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) - //test until (ok) + // test until (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test until (ok but no return) + // test until (ok but no return) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1m", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1m", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test until (invalid value) + // test until (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) - //test simulated (ok) + // test simulated (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test simulated (ok) + // test simulated (ok) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?simulated=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test has active decision + // test has active decision - w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) - //test has active decision + // test has active decision - w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) - //test has active decision (invalid value) + // test has active decision (invalid value) - w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) } func TestAlertBulkInsert(t *testing.T) { - lapi := SetupLAPITest(t) - //insert a bulk of 20 alerts to trigger bulk insert - lapi.InsertAlertFromFile(t, "./tests/alert_bulk.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + // insert a bulk of 20 alerts to trigger bulk insert + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_bulk.json") alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json") - w := lapi.RecordResponse(t, "GET", "/v1/alerts", alertContent, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", alertContent, "password") assert.Equal(t, 200, w.Code) } func TestListAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // List Alert with invalid filter - w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", emptyBody, "password") + w := lapi.RecordResponse(t, ctx, "GET", "/v1/alerts?test=test", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // List Alert - w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "crowdsecurity/test") } func TestCreateAlertErrors(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json") - //test invalid bearer + // test invalid bearer w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", "ratata")) lapi.router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - //test invalid bearer + // test invalid bearer w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/alerts", alertContent) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/alerts", alertContent) req.Header.Add("User-Agent", UserAgent) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s")) lapi.router.ServeHTTP(w, req) @@ -373,12 +386,13 @@ func TestCreateAlertErrors(t *testing.T) { } func TestDeleteAlert(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -387,7 +401,7 @@ func TestDeleteAlert(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -396,12 +410,13 @@ func TestDeleteAlert(t *testing.T) { } func TestDeleteAlertByID(t *testing.T) { - lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.2:4242" lapi.router.ServeHTTP(w, req) @@ -410,7 +425,7 @@ func TestDeleteAlertByID(t *testing.T) { // Delete Alert w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts/1", strings.NewReader("")) AddAuthHeaders(req, lapi.loginResp) req.RemoteAddr = "127.0.0.1:4242" lapi.router.ServeHTTP(w, req) @@ -419,12 +434,13 @@ func TestDeleteAlertByID(t *testing.T) { } func TestDeleteAlertTrustedIPS(t *testing.T) { + ctx := context.Background() cfg := LoadTestConfig(t) // IPv6 mocking doesn't seem to work. // cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"} cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"} cfg.API.Server.ListenURI = "::8080" - server, err := NewServer(cfg.API.Server) + server, err := NewServer(ctx, cfg.API.Server) require.NoError(t, err) err = server.InitController() @@ -433,7 +449,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { router, err := server.Router() require.NoError(t, err) - loginResp := LoginToTestAPI(t, router, cfg) + loginResp := LoginToTestAPI(t, ctx, router, cfg) lapi := LAPI{ router: router, loginResp: loginResp, @@ -441,7 +457,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assertAlertDeleteFailedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" @@ -453,7 +469,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assertAlertDeletedFromIP := func(ip string) { w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodDelete, "/v1/alerts", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodDelete, "/v1/alerts", strings.NewReader("")) AddAuthHeaders(req, loginResp) req.RemoteAddr = ip + ":1234" @@ -462,17 +478,17 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeleteFailedFromIP("4.3.2.1") assertAlertDeletedFromIP("1.2.3.4") - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.0") - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.1") - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.255") - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") assertAlertDeletedFromIP("127.0.0.1") } diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index 883ff21298d..45c02c806e7 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" @@ -10,36 +11,83 @@ import ( ) func TestAPIKey(t *testing.T) { - router, config := NewAPITest(t) + ctx := context.Background() + router, config := NewAPITest(t, ctx) - APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) + APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig) // Login with empty token w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) - assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.JSONEq(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with invalid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", "a1b2c3d4e5f6") + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 403, w.Code) - assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.JSONEq(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with valid token w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "127.0.0.1:1234" router.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "null", w.Body.String()) + + // Login with valid token from another IP + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "4.3.2.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Make the requests multiple times to make sure we only create one + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "4.3.2.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Use the original bouncer again + w = httptest.NewRecorder() + req, _ = http.NewRequestWithContext(ctx, http.MethodGet, "/v1/decisions", strings.NewReader("")) + req.Header.Add("User-Agent", UserAgent) + req.Header.Add("X-Api-Key", APIKey) + req.RemoteAddr = "127.0.0.1:1234" + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "null", w.Body.String()) + + // Check if our second bouncer was properly created + bouncers := GetBouncers(t, config.API.Server.DbConfig) + + assert.Len(t, bouncers, 2) + assert.Equal(t, "test@4.3.2.1", bouncers[1].Name) + assert.Equal(t, bouncers[0].APIKey, bouncers[1].APIKey) + assert.Equal(t, bouncers[0].AuthType, bouncers[1].AuthType) + assert.False(t, bouncers[0].AutoCreated) + assert.True(t, bouncers[1].AutoCreated) } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 5b850cbff0d..51a85b1ea23 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -23,7 +23,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" @@ -70,6 +69,10 @@ type apic struct { consoleConfig *csconfig.ConsoleConfig isPulling chan bool whitelists *csconfig.CapiWhitelist + + pullBlocklists bool + pullCommunity bool + shareSignals bool } // randomDuration returns a duration value between d-delta and d+delta @@ -83,10 +86,10 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration { return ret } -func (a *apic) FetchScenariosListFromDB() ([]string, error) { +func (a *apic) FetchScenariosListFromDB(ctx context.Context) ([]string, error) { scenarios := make([]string, 0) - machines, err := a.dbClient.ListMachines() + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, fmt.Errorf("while listing machines: %w", err) } @@ -175,7 +178,7 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool) return signal } -func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { +func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { var err error ret := &apic{ @@ -199,6 +202,9 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con usageMetricsIntervalFirst: randomDuration(usageMetricsInterval, usageMetricsIntervalDelta), isPulling: make(chan bool, 1), whitelists: apicWhitelist, + pullBlocklists: *config.PullConfig.Blocklists, + pullCommunity: *config.PullConfig.Community, + shareSignals: *config.Sharing, } password := strfmt.Password(config.Credentials.Password) @@ -213,7 +219,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err) } - ret.scenarioList, err = ret.FetchScenariosListFromDB() + ret.scenarioList, err = ret.FetchScenariosListFromDB(ctx) if err != nil { return nil, fmt.Errorf("while fetching scenarios from db: %w", err) } @@ -221,7 +227,6 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con ret.apiClient, err = apiclient.NewClient(&apiclient.Config{ MachineID: config.Credentials.Login, Password: password, - UserAgent: cwversion.UserAgent(), URL: apiURL, PapiURL: papiURL, VersionPrefix: "v3", @@ -234,12 +239,12 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con // The watcher will be authenticated by the RoundTripper the first time it will call CAPI // Explicit authentication will provoke a useless supplementary call to CAPI - scenarios, err := ret.FetchScenariosListFromDB() + scenarios, err := ret.FetchScenariosListFromDB(ctx) if err != nil { return ret, fmt.Errorf("get scenario in db: %w", err) } - authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(ctx, models.WatcherAuthRequest{ MachineID: &config.Credentials.Login, Password: &password, Scenarios: scenarios, @@ -258,7 +263,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con } // keep track of all alerts in cache and push it to CAPI every PushInterval. -func (a *apic) Push() error { +func (a *apic) Push(ctx context.Context) error { defer trace.CatchPanic("lapi/pushToAPIC") var cache models.AddSignalsRequest @@ -278,7 +283,7 @@ func (a *apic) Push() error { return nil } - go a.Send(&cache) + go a.Send(ctx, &cache) return nil case <-ticker.C: @@ -291,13 +296,13 @@ func (a *apic) Push() error { a.mu.Unlock() log.Infof("Signal push: %d signals to push", len(cacheCopy)) - go a.Send(&cacheCopy) + go a.Send(ctx, &cacheCopy) } case alerts := <-a.AlertsAddChan: var signals []*models.AddSignalsRequestItem for _, alert := range alerts { - if ok := shouldShareAlert(alert, a.consoleConfig); ok { + if ok := shouldShareAlert(alert, a.consoleConfig, a.shareSignals); ok { signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext)) } } @@ -326,7 +331,13 @@ func getScenarioTrustOfAlert(alert *models.Alert) string { return scenarioTrust } -func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig) bool { +func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig, shareSignals bool) bool { + + if !shareSignals { + log.Debugf("sharing signals is disabled") + return false + } + if *alert.Simulated { log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID) return false @@ -353,7 +364,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig return true } -func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { +func (a *apic) Send(ctx context.Context, cacheOrig *models.AddSignalsRequest) { /*we do have a problem with this : The apic.Push background routine reads from alertToPush chan. This chan is filled by Controller.CreateAlert @@ -377,7 +388,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { for { if pageEnd >= len(cache) { send = cache[pageStart:] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() @@ -391,7 +402,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { } send = cache[pageStart:pageEnd] - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() @@ -406,13 +417,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { } } -func (a *apic) CAPIPullIsOld() (bool, error) { +func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) { /*only pull community blocklist if it's older than 1h30 */ alerts := a.dbClient.Ent.Alert.Query() alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert - count, err := alerts.Count(a.dbClient.CTX) + count, err := alerts.Count(ctx) if err != nil { return false, fmt.Errorf("while looking for CAPI alert: %w", err) } @@ -425,37 +436,7 @@ func (a *apic) CAPIPullIsOld() (bool, error) { return true, nil } -func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) { - nbDeleted := 0 - - for _, decision := range deletedDecisions { - filter := map[string][]string{ - "value": {*decision.Value}, - "origin": {*decision.Origin}, - } - if strings.ToLower(*decision.Scope) != "ip" { - filter["type"] = []string{*decision.Type} - filter["scopes"] = []string{*decision.Scope} - } - - dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter) - if err != nil { - return 0, fmt.Errorf("expiring decisions error: %w", err) - } - - dbCliDel, err := strconv.Atoi(dbCliRet) - if err != nil { - return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err) - } - - updateCounterForDecision(deleteCounters, decision.Origin, decision.Scenario, dbCliDel) - nbDeleted += dbCliDel - } - - return nbDeleted, nil -} - -func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) { +func (a *apic) HandleDeletedDecisionsV3(ctx context.Context, deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) { var nbDeleted int for _, decisions := range deletedDecisions { @@ -470,7 +451,7 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi filter["scopes"] = []string{*scope} } - dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return 0, fmt.Errorf("expiring decisions error: %w", err) } @@ -616,7 +597,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio // we receive a list of decisions and links for blocklist and we need to create a list of alerts : // one alert for "community blocklist" // one alert per list we're subscribed to -func (a *apic) PullTop(forcePull bool) error { +func (a *apic) PullTop(ctx context.Context, forcePull bool) error { var err error // A mutex with TryLock would be a bit simpler @@ -631,7 +612,7 @@ func (a *apic) PullTop(forcePull bool) error { } if !forcePull { - if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil { + if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil { return err } else if !lastPullIsOld { return nil @@ -640,7 +621,7 @@ func (a *apic) PullTop(forcePull bool) error { log.Debug("Acquiring lock for pullCAPI") - err = a.dbClient.AcquirePullCAPILock() + err = a.dbClient.AcquirePullCAPILock(ctx) if a.dbClient.IsLocked(err) { log.Info("PullCAPI is already running, skipping") return nil @@ -650,14 +631,16 @@ func (a *apic) PullTop(forcePull bool) error { defer func() { log.Debug("Releasing lock for pullCAPI") - if err := a.dbClient.ReleasePullCAPILock(); err != nil { + if err := a.dbClient.ReleasePullCAPILock(ctx); err != nil { log.Errorf("while releasing lock: %v", err) } }() log.Infof("Starting community-blocklist update") - data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup}) + log.Debugf("Community pull: %t | Blocklist pull: %t", a.pullCommunity, a.pullBlocklists) + + data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup, CommunityPull: a.pullCommunity, AdditionalPull: a.pullBlocklists}) if err != nil { return fmt.Errorf("get stream: %w", err) } @@ -675,34 +658,37 @@ func (a *apic) PullTop(forcePull bool) error { addCounters, deleteCounters := makeAddAndDeleteCounters() // process deleted decisions - nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters) + nbDeleted, err := a.HandleDeletedDecisionsV3(ctx, data.Deleted, deleteCounters) if err != nil { return err } log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted) - if len(data.New) == 0 { - log.Infof("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)") - return nil - } + if len(data.New) > 0 { + // create one alert for community blocklist using the first decision + decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New) + // apply APIC specific whitelists + decisions = a.ApplyApicWhitelists(decisions) - // create one alert for community blocklist using the first decision - decisions := a.apiClient.Decisions.GetDecisionsFromGroups(data.New) - // apply APIC specific whitelists - decisions = a.ApplyApicWhitelists(decisions) + alert := createAlertForDecision(decisions[0]) + alertsFromCapi := []*models.Alert{alert} + alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) - alert := createAlertForDecision(decisions[0]) - alertsFromCapi := []*models.Alert{alert} - alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) - - err = a.SaveAlerts(alertsFromCapi, addCounters, deleteCounters) - if err != nil { - return fmt.Errorf("while saving alerts: %w", err) + err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, deleteCounters) + if err != nil { + return fmt.Errorf("while saving alerts: %w", err) + } + } else { + if a.pullCommunity { + log.Info("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)") + } else { + log.Debug("capi/community-blocklist : community blocklist pull is disabled") + } } // update blocklists - if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil { + if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil { return fmt.Errorf("while updating blocklists: %w", err) } @@ -710,9 +696,9 @@ func (a *apic) PullTop(forcePull bool) error { } // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert -func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error { +func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error { addCounters, _ := makeAddAndDeleteCounters() - if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{ + if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{blocklist}, }, addCounters, forcePull); err != nil { return fmt.Errorf("while pulling blocklist: %w", err) @@ -765,7 +751,7 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis return decisions[:outIdx] } -func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { +func (a *apic) SaveAlerts(ctx context.Context, alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error { for _, alert := range alertsFromCapi { setAlertScenario(alert, addCounters, deleteCounters) log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions)) @@ -774,7 +760,7 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") } - alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert) + alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(ctx, alert) if err != nil { return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) } @@ -785,13 +771,13 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string return nil } -func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bool, error) { +func (a *apic) ShouldForcePullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink) (bool, error) { // we should force pull if the blocklist decisions are about to expire or there's no decision in the db alertQuery := a.dbClient.Ent.Alert.Query() alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name))) alertQuery.Order(ent.Desc(alert.FieldCreatedAt)) - alertInstance, err := alertQuery.First(context.Background()) + alertInstance, err := alertQuery.First(ctx) if err != nil { if ent.IsNotFound(err) { log.Debugf("no alert found for %s, force refresh", *blocklist.Name) @@ -804,7 +790,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo decisionQuery := a.dbClient.Ent.Decision.Query() decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID))) - firstDecision, err := decisionQuery.First(context.Background()) + firstDecision, err := decisionQuery.First(ctx) if err != nil { if ent.IsNotFound(err) { log.Debugf("no decision found for %s, force refresh", *blocklist.Name) @@ -822,7 +808,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo return false, nil } -func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { if blocklist.Scope == nil { log.Warningf("blocklist has no scope") return nil @@ -834,7 +820,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap } if !forcePull { - _forcePull, err := a.ShouldForcePullBlocklist(blocklist) + _forcePull, err := a.ShouldForcePullBlocklist(ctx, blocklist) if err != nil { return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) } @@ -850,13 +836,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap ) if !forcePull { - lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName) if err != nil { return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } } - decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) + decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp) if err != nil { return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) } @@ -871,7 +857,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } - err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) + err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) if err != nil { return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } @@ -886,7 +872,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap alertsFromCapi := []*models.Alert{alert} alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters) - err = a.SaveAlerts(alertsFromCapi, addCounters, nil) + err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, nil) if err != nil { return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err) } @@ -894,7 +880,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } -func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { if links == nil { return nil } @@ -910,7 +896,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } for _, blocklist := range links.Blocklists { - if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil { + if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil { return err } } @@ -933,13 +919,13 @@ func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int } } -func (a *apic) Pull() error { +func (a *apic) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/pullFromAPIC") toldOnce := false for { - scenario, err := a.FetchScenariosListFromDB() + scenario, err := a.FetchScenariosListFromDB(ctx) if err != nil { log.Errorf("unable to fetch scenarios from db: %s", err) } @@ -957,7 +943,7 @@ func (a *apic) Pull() error { time.Sleep(1 * time.Second) } - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) } @@ -969,7 +955,7 @@ func (a *apic) Pull() error { case <-ticker.C: ticker.Reset(a.pullInterval) - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) continue } diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index 176984f1ad6..fe0dfd55821 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -23,22 +23,22 @@ type dbPayload struct { Metrics []*models.DetailedMetrics `json:"metrics"` } -func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { +func (a *apic) GetUsageMetrics(ctx context.Context) (*models.AllMetrics, []int, error) { allMetrics := &models.AllMetrics{} metricsIds := make([]int, 0) - lps, err := a.dbClient.ListMachines() + lps, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, nil, err } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, nil, err } for _, bouncer := range bouncers { - dbMetrics, err := a.dbClient.GetBouncerUsageMetricsByName(bouncer.Name) + dbMetrics, err := a.dbClient.GetBouncerUsageMetricsByName(ctx, bouncer.Name) if err != nil { log.Errorf("unable to get bouncer usage metrics: %s", err) continue @@ -70,7 +70,7 @@ func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { err := json.Unmarshal([]byte(dbMetric.Payload), dbPayload) if err != nil { - log.Errorf("unable to unmarshal bouncer metric (%s)", err) + log.Errorf("unable to parse bouncer metric (%s)", err) continue } @@ -81,7 +81,7 @@ func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { } for _, lp := range lps { - dbMetrics, err := a.dbClient.GetLPUsageMetricsByMachineID(lp.MachineId) + dbMetrics, err := a.dbClient.GetLPUsageMetricsByMachineID(ctx, lp.MachineId) if err != nil { log.Errorf("unable to get LP usage metrics: %s", err) continue @@ -132,7 +132,7 @@ func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { err := json.Unmarshal([]byte(dbMetric.Payload), dbPayload) if err != nil { - log.Errorf("unable to unmarshal log processor metric (%s)", err) + log.Errorf("unable to parse log processor metric (%s)", err) continue } @@ -181,12 +181,12 @@ func (a *apic) GetUsageMetrics() (*models.AllMetrics, []int, error) { return allMetrics, metricsIds, nil } -func (a *apic) MarkUsageMetricsAsSent(ids []int) error { - return a.dbClient.MarkUsageMetricsAsSent(ids) +func (a *apic) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { + return a.dbClient.MarkUsageMetricsAsSent(ctx, ids) } -func (a *apic) GetMetrics() (*models.Metrics, error) { - machines, err := a.dbClient.ListMachines() +func (a *apic) GetMetrics(ctx context.Context) (*models.Metrics, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -202,7 +202,7 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { } } - bouncers, err := a.dbClient.ListBouncers() + bouncers, err := a.dbClient.ListBouncers(ctx) if err != nil { return nil, err } @@ -230,8 +230,8 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { }, nil } -func (a *apic) fetchMachineIDs() ([]string, error) { - machines, err := a.dbClient.ListMachines() +func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) { + machines, err := a.dbClient.ListMachines(ctx) if err != nil { return nil, err } @@ -251,7 +251,7 @@ func (a *apic) fetchMachineIDs() ([]string, error) { // Metrics are sent at start, then at the randomized metricsIntervalFirst, // then at regular metricsInterval. If a change is detected in the list // of machines, the next metrics are sent immediately. -func (a *apic) SendMetrics(stop chan (bool)) { +func (a *apic) SendMetrics(ctx context.Context, stop chan (bool)) { defer trace.CatchPanic("lapi/metricsToAPIC") // verify the list of machines every interval @@ -275,7 +275,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { machineIDs := []string{} reloadMachineIDs := func() { - ids, err := a.fetchMachineIDs() + ids, err := a.fetchMachineIDs(ctx) if err != nil { log.Debugf("unable to get machines (%s), will retry", err) @@ -311,7 +311,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { case <-metTicker.C: metTicker.Stop() - metrics, err := a.GetMetrics() + metrics, err := a.GetMetrics(ctx) if err != nil { log.Errorf("unable to get metrics (%s)", err) } @@ -319,7 +319,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { if metrics != nil { log.Info("capi metrics: sending") - _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) + _, _, err = a.apiClient.Metrics.Add(ctx, metrics) if err != nil { log.Errorf("capi metrics: failed: %s", err) } @@ -337,7 +337,7 @@ func (a *apic) SendMetrics(stop chan (bool)) { } } -func (a *apic) SendUsageMetrics() { +func (a *apic) SendUsageMetrics(ctx context.Context) { defer trace.CatchPanic("lapi/usageMetricsToAPIC") firstRun := true @@ -358,16 +358,21 @@ func (a *apic) SendUsageMetrics() { ticker.Reset(a.usageMetricsInterval) } - metrics, metricsId, err := a.GetUsageMetrics() + metrics, metricsId, err := a.GetUsageMetrics(ctx) if err != nil { log.Errorf("unable to get usage metrics: %s", err) continue } - _, resp, err := a.apiClient.UsageMetrics.Add(context.Background(), metrics) + _, resp, err := a.apiClient.UsageMetrics.Add(ctx, metrics) if err != nil { log.Errorf("unable to send usage metrics: %s", err) + if resp == nil || resp.Response == nil { + // Most likely a transient network error, it will be retried later + continue + } + if resp.Response.StatusCode >= http.StatusBadRequest && resp.Response.StatusCode != http.StatusUnprocessableEntity { // In case of 422, mark the metrics as sent anyway, the API did not like what we sent, // and it's unlikely we'll be able to fix it @@ -375,7 +380,7 @@ func (a *apic) SendUsageMetrics() { } } - err = a.MarkUsageMetricsAsSent(metricsId) + err = a.MarkUsageMetricsAsSent(ctx, metricsId) if err != nil { log.Errorf("unable to mark usage metrics as sent: %s", err) continue diff --git a/pkg/apiserver/apic_metrics_test.go b/pkg/apiserver/apic_metrics_test.go index d1e48ac90a3..d81af03f710 100644 --- a/pkg/apiserver/apic_metrics_test.go +++ b/pkg/apiserver/apic_metrics_test.go @@ -11,10 +11,11 @@ import ( "github.com/stretchr/testify/require" "github.com/crowdsecurity/crowdsec/pkg/apiclient" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" ) func TestAPICSendMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string duration time.Duration @@ -35,7 +36,7 @@ func TestAPICSendMetrics(t *testing.T) { metricsInterval: time.Millisecond * 20, expectedCalls: 5, setUp: func(api *apic) { - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) api.dbClient.Ent.Machine.Create(). SetMachineId("1234"). SetPassword(testPassword.String()). @@ -43,16 +44,16 @@ func TestAPICSendMetrics(t *testing.T) { SetScenarios("crowdsecurity/test"). SetLastPush(time.Time{}). SetUpdatedAt(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) api.dbClient.Ent.Bouncer.Create(). SetIPAddress("1.2.3.6"). SetName("someBouncer"). SetAPIKey("foobar"). SetRevoked(false). SetLastPull(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) }, }, } @@ -70,12 +71,12 @@ func TestAPICSendMetrics(t *testing.T) { apiClient, err := apiclient.NewDefaultClient( url, "/api", - cwversion.UserAgent(), + "", nil, ) require.NoError(t, err) - api := getAPIC(t) + api := getAPIC(t, ctx) api.pushInterval = time.Millisecond api.pushIntervalFirst = time.Millisecond api.apiClient = apiClient @@ -87,7 +88,7 @@ func TestAPICSendMetrics(t *testing.T) { httpmock.ZeroCallCounters() - go api.SendMetrics(stop) + go api.SendMetrics(ctx, stop) time.Sleep(tc.duration) stop <- true diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 546a236251f..a8fbb40c4fa 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -26,7 +26,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" @@ -35,11 +34,9 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -func getDBClient(t *testing.T) *database.Client { +func getDBClient(t *testing.T, ctx context.Context) *database.Client { t.Helper() - ctx := context.Background() - dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) dbClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{ @@ -52,9 +49,9 @@ func getDBClient(t *testing.T) *database.Client { return dbClient } -func getAPIC(t *testing.T) *apic { +func getAPIC(t *testing.T, ctx context.Context) *apic { t.Helper() - dbClient := getDBClient(t) + dbClient := getDBClient(t, ctx) return &apic{ AlertsAddChan: make(chan []*models.Alert), @@ -72,7 +69,10 @@ func getAPIC(t *testing.T) *apic { ShareCustomScenarios: ptr.Of(false), ShareContext: ptr.Of(false), }, - isPulling: make(chan bool, 1), + isPulling: make(chan bool, 1), + shareSignals: true, + pullBlocklists: true, + pullCommunity: true, } } @@ -85,8 +85,8 @@ func absDiff(a int, b int) int { return c } -func assertTotalDecisionCount(t *testing.T, dbClient *database.Client, count int) { - d := dbClient.Ent.Decision.Query().AllX(context.Background()) +func assertTotalDecisionCount(t *testing.T, ctx context.Context, dbClient *database.Client, count int) { + d := dbClient.Ent.Decision.Query().AllX(ctx) assert.Len(t, d, count) } @@ -112,9 +112,10 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) { } func TestAPICCAPIPullIsOld(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) - isOld, err := api.CAPIPullIsOld() + isOld, err := api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.True(t, isOld) @@ -125,7 +126,7 @@ func TestAPICCAPIPullIsOld(t *testing.T) { SetScope("Country"). SetValue("Blah"). SetOrigin(types.CAPIOrigin). - SaveX(context.Background()) + SaveX(ctx) api.dbClient.Ent.Alert.Create(). SetCreatedAt(time.Now()). @@ -133,15 +134,17 @@ func TestAPICCAPIPullIsOld(t *testing.T) { AddDecisions( decision, ). - SaveX(context.Background()) + SaveX(ctx) - isOld, err = api.CAPIPullIsOld() + isOld, err = api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.False(t, isOld) } func TestAPICFetchScenariosListFromDB(t *testing.T) { + ctx := context.Background() + tests := []struct { name string machineIDsWithScenarios map[string]string @@ -166,21 +169,21 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - api := getAPIC(t) + api := getAPIC(t, ctx) for machineID, scenarios := range tc.machineIDsWithScenarios { api.dbClient.Ent.Machine.Create(). SetMachineId(machineID). SetPassword(testPassword.String()). SetIpAddress("1.2.3.4"). SetScenarios(scenarios). - ExecX(context.Background()) + ExecX(ctx) } - scenarios, err := api.FetchScenariosListFromDB() + scenarios, err := api.FetchScenariosListFromDB(ctx) require.NoError(t, err) for machineID := range tc.machineIDsWithScenarios { - api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background()) + api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(ctx) } assert.ElementsMatch(t, tc.expectedScenarios, scenarios) @@ -189,6 +192,8 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { } func TestNewAPIC(t *testing.T) { + ctx := context.Background() + var testConfig *csconfig.OnlineApiClientCfg setConfig := func() { @@ -198,6 +203,11 @@ func TestNewAPIC(t *testing.T) { Login: "foo", Password: "bar", }, + Sharing: ptr.Of(true), + PullConfig: csconfig.CapiPullConfig{ + Community: ptr.Of(true), + Blocklists: ptr.Of(true), + }, } } @@ -216,7 +226,7 @@ func TestNewAPIC(t *testing.T) { name: "simple", action: func() {}, args: args{ - dbClient: getDBClient(t), + dbClient: getDBClient(t, ctx), consoleConfig: LoadTestConfig(t).API.Server.ConsoleConfig, }, }, @@ -224,7 +234,7 @@ func TestNewAPIC(t *testing.T) { name: "error in parsing URL", action: func() { testConfig.Credentials.URL = "foobar http://" }, args: args{ - dbClient: getDBClient(t), + dbClient: getDBClient(t, ctx), consoleConfig: LoadTestConfig(t).API.Server.ConsoleConfig, }, expectedErr: "first path segment in URL cannot contain colon", @@ -247,53 +257,18 @@ func TestNewAPIC(t *testing.T) { ), )) tc.action() - _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) + _, err := NewAPIC(ctx, testConfig, tc.args.dbClient, tc.args.consoleConfig, nil) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } -func TestAPICHandleDeletedDecisions(t *testing.T) { - api := getAPIC(t) - _, deleteCounters := makeAddAndDeleteCounters() - - decision1 := api.dbClient.Ent.Decision.Create(). - SetUntil(time.Now().Add(time.Hour)). - SetScenario("crowdsec/test"). - SetType("ban"). - SetScope("IP"). - SetValue("1.2.3.4"). - SetOrigin(types.CAPIOrigin). - SaveX(context.Background()) - - api.dbClient.Ent.Decision.Create(). - SetUntil(time.Now().Add(time.Hour)). - SetScenario("crowdsec/test"). - SetType("ban"). - SetScope("IP"). - SetValue("1.2.3.4"). - SetOrigin(types.CAPIOrigin). - SaveX(context.Background()) - - assertTotalDecisionCount(t, api.dbClient, 2) - - nbDeleted, err := api.HandleDeletedDecisions([]*models.Decision{{ - Value: ptr.Of("1.2.3.4"), - Origin: ptr.Of(types.CAPIOrigin), - Type: &decision1.Type, - Scenario: ptr.Of("crowdsec/test"), - Scope: ptr.Of("IP"), - }}, deleteCounters) - - require.NoError(t, err) - assert.Equal(t, 2, nbDeleted) - assert.Equal(t, 2, deleteCounters[types.CAPIOrigin]["all"]) -} - func TestAPICGetMetrics(t *testing.T) { + ctx := context.Background() + cleanUp := func(api *apic) { - api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) - api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(ctx) + api.dbClient.Ent.Machine.Delete().ExecX(ctx) } tests := []struct { name string @@ -352,7 +327,7 @@ func TestAPICGetMetrics(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - apiClient := getAPIC(t) + apiClient := getAPIC(t, ctx) cleanUp(apiClient) for i, machineID := range tc.machineIDs { @@ -363,7 +338,7 @@ func TestAPICGetMetrics(t *testing.T) { SetScenarios("crowdsecurity/test"). SetLastPush(time.Time{}). SetUpdatedAt(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) } for i, bouncerName := range tc.bouncers { @@ -373,10 +348,10 @@ func TestAPICGetMetrics(t *testing.T) { SetAPIKey("foobar"). SetRevoked(false). SetLastPull(time.Time{}). - ExecX(context.Background()) + ExecX(ctx) } - foundMetrics, err := apiClient.GetMetrics() + foundMetrics, err := apiClient.GetMetrics(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) @@ -547,7 +522,8 @@ func TestFillAlertsWithDecisions(t *testing.T) { } func TestAPICWhitelists(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) // one whitelist on IP, one on CIDR api.whitelists = &csconfig.CapiWhitelist{} api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4")) @@ -570,7 +546,7 @@ func TestAPICWhitelists(t *testing.T) { SetScenario("crowdsecurity/ssh-bf"). SetUntil(time.Now().Add(time.Hour)). ExecX(context.Background()) - assertTotalDecisionCount(t, api.dbClient, 1) + assertTotalDecisionCount(t, ctx, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1) httpmock.Activate() @@ -676,16 +652,16 @@ func TestAPICWhitelists(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - cwversion.UserAgent(), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing + assertTotalDecisionCount(t, ctx, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing assertTotalValidDecisionCount(t, api.dbClient, 4) assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list. alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) @@ -733,7 +709,8 @@ func TestAPICWhitelists(t *testing.T) { } func TestAPICPullTop(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). SetType("ban"). @@ -741,8 +718,8 @@ func TestAPICPullTop(t *testing.T) { SetScope("Ip"). SetScenario("crowdsecurity/ssh-bf"). SetUntil(time.Now().Add(time.Hour)). - ExecX(context.Background()) - assertTotalDecisionCount(t, api.dbClient, 1) + ExecX(ctx) + assertTotalDecisionCount(t, ctx, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1) httpmock.Activate() @@ -817,23 +794,22 @@ func TestAPICPullTop(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - cwversion.UserAgent(), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - assertTotalDecisionCount(t, api.dbClient, 5) + assertTotalDecisionCount(t, ctx, api.dbClient, 5) assertTotalValidDecisionCount(t, api.dbClient, 4) assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list. alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) validDecisions := api.dbClient.Ent.Decision.Query().Where( decision.UntilGT(time.Now())). - AllX(context.Background(), - ) + AllX(context.Background()) decisionScenarioFreq := make(map[string]int) alertScenario := make(map[string]int) @@ -858,8 +834,9 @@ func TestAPICPullTop(t *testing.T) { } func TestAPICPullTopBLCacheFirstCall(t *testing.T) { + ctx := context.Background() // no decision in db, no last modified parameter. - api := getAPIC(t) + api := getAPIC(t, ctx) httpmock.Activate() defer httpmock.DeactivateAndReset() @@ -905,17 +882,17 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - cwversion.UserAgent(), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) blocklistConfigItemName := "blocklist:blocklist1:last_pull" - lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.NotEqual(t, "", *lastPullTimestamp) @@ -925,15 +902,16 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { return httpmock.NewStringResponse(304, ""), nil }) - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) } func TestAPICPullTopBLCacheForceCall(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) httpmock.Activate() defer httpmock.DeactivateAndReset() @@ -997,18 +975,19 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - cwversion.UserAgent(), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) } func TestAPICPullBlocklistCall(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) httpmock.Activate() defer httpmock.DeactivateAndReset() @@ -1024,13 +1003,13 @@ func TestAPICPullBlocklistCall(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - cwversion.UserAgent(), + "", nil, ) require.NoError(t, err) api.apiClient = apic - err = api.PullBlocklist(&modelscapi.BlocklistLink{ + err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{ URL: ptr.Of("http://api.crowdsec.net/blocklist1"), Name: ptr.Of("blocklist1"), Scope: ptr.Of("Ip"), @@ -1041,6 +1020,7 @@ func TestAPICPullBlocklistCall(t *testing.T) { } func TestAPICPush(t *testing.T) { + ctx := context.Background() tests := []struct { name string alerts []*models.Alert @@ -1093,9 +1073,8 @@ func TestAPICPush(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { - api := getAPIC(t) + api := getAPIC(t, ctx) api.pushInterval = time.Millisecond api.pushIntervalFirst = time.Millisecond url, err := url.ParseRequestURI("http://api.crowdsec.net/") @@ -1107,7 +1086,7 @@ func TestAPICPush(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - cwversion.UserAgent(), + "", nil, ) require.NoError(t, err) @@ -1116,14 +1095,16 @@ func TestAPICPush(t *testing.T) { httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{})) + // capture the alerts to avoid datarace + alerts := tc.alerts go func() { - api.AlertsAddChan <- tc.alerts + api.AlertsAddChan <- alerts time.Sleep(time.Second) api.Shutdown() }() - err = api.Push() + err = api.Push(ctx) require.NoError(t, err) assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount()) }) @@ -1131,7 +1112,8 @@ func TestAPICPush(t *testing.T) { } func TestAPICPull(t *testing.T) { - api := getAPIC(t) + ctx := context.Background() + api := getAPIC(t, ctx) tests := []struct { name string setUp func() @@ -1159,7 +1141,7 @@ func TestAPICPull(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - api = getAPIC(t) + api = getAPIC(t, ctx) api.pullInterval = time.Millisecond api.pullIntervalFirst = time.Millisecond url, err := url.ParseRequestURI("http://api.crowdsec.net/") @@ -1171,7 +1153,7 @@ func TestAPICPull(t *testing.T) { apic, err := apiclient.NewDefaultClient( url, "/api", - cwversion.UserAgent(), + "", nil, ) require.NoError(t, err) @@ -1201,7 +1183,7 @@ func TestAPICPull(t *testing.T) { go func() { logrus.SetOutput(&buf) - if err := api.Pull(); err != nil { + if err := api.Pull(ctx); err != nil { panic(err) } }() @@ -1210,7 +1192,7 @@ func TestAPICPull(t *testing.T) { time.Sleep(time.Millisecond * 500) logrus.SetOutput(os.Stderr) assert.Contains(t, buf.String(), tc.logContains) - assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount) + assertTotalDecisionCount(t, ctx, api.dbClient, tc.expectedDecisionCount) }) } } @@ -1219,6 +1201,7 @@ func TestShouldShareAlert(t *testing.T) { tests := []struct { name string consoleConfig *csconfig.ConsoleConfig + shareSignals bool alert *models.Alert expectedRet bool expectedTrust string @@ -1229,6 +1212,7 @@ func TestShouldShareAlert(t *testing.T) { ShareCustomScenarios: ptr.Of(true), }, alert: &models.Alert{Simulated: ptr.Of(false)}, + shareSignals: true, expectedRet: true, expectedTrust: "custom", }, @@ -1238,6 +1222,7 @@ func TestShouldShareAlert(t *testing.T) { ShareCustomScenarios: ptr.Of(false), }, alert: &models.Alert{Simulated: ptr.Of(false)}, + shareSignals: true, expectedRet: false, expectedTrust: "custom", }, @@ -1246,6 +1231,7 @@ func TestShouldShareAlert(t *testing.T) { consoleConfig: &csconfig.ConsoleConfig{ ShareManualDecisions: ptr.Of(true), }, + shareSignals: true, alert: &models.Alert{ Simulated: ptr.Of(false), Decisions: []*models.Decision{{Origin: ptr.Of(types.CscliOrigin)}}, @@ -1258,6 +1244,7 @@ func TestShouldShareAlert(t *testing.T) { consoleConfig: &csconfig.ConsoleConfig{ ShareManualDecisions: ptr.Of(false), }, + shareSignals: true, alert: &models.Alert{ Simulated: ptr.Of(false), Decisions: []*models.Decision{{Origin: ptr.Of(types.CscliOrigin)}}, @@ -1270,6 +1257,7 @@ func TestShouldShareAlert(t *testing.T) { consoleConfig: &csconfig.ConsoleConfig{ ShareTaintedScenarios: ptr.Of(true), }, + shareSignals: true, alert: &models.Alert{ Simulated: ptr.Of(false), ScenarioHash: ptr.Of("whateverHash"), @@ -1282,6 +1270,7 @@ func TestShouldShareAlert(t *testing.T) { consoleConfig: &csconfig.ConsoleConfig{ ShareTaintedScenarios: ptr.Of(false), }, + shareSignals: true, alert: &models.Alert{ Simulated: ptr.Of(false), ScenarioHash: ptr.Of("whateverHash"), @@ -1289,11 +1278,24 @@ func TestShouldShareAlert(t *testing.T) { expectedRet: false, expectedTrust: "tainted", }, + { + name: "manual alert should not be shared if global sharing is disabled", + consoleConfig: &csconfig.ConsoleConfig{ + ShareManualDecisions: ptr.Of(true), + }, + shareSignals: false, + alert: &models.Alert{ + Simulated: ptr.Of(false), + ScenarioHash: ptr.Of("whateverHash"), + }, + expectedRet: false, + expectedTrust: "manual", + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - ret := shouldShareAlert(tc.alert, tc.consoleConfig) + ret := shouldShareAlert(tc.alert, tc.consoleConfig, tc.shareSignals) assert.Equal(t, tc.expectedRet, ret) }) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 31b31bcb82d..05f9150b037 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -46,20 +46,11 @@ type APIServer struct { consoleConfig *csconfig.ConsoleConfig } -func recoverFromPanic(c *gin.Context) { - err := recover() - if err == nil { - return - } - - // Check for a broken connection, as it is not really a - // condition that warrants a panic stack trace. - brokenPipe := false - +func isBrokenConnection(err any) bool { if ne, ok := err.(*net.OpError); ok { if se, ok := ne.Err.(*os.SyscallError); ok { if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { - brokenPipe = true + return true } } } @@ -79,11 +70,22 @@ func recoverFromPanic(c *gin.Context) { errors.Is(strErr, errClosedBody) || errors.Is(strErr, errHandlerComplete) || errors.Is(strErr, errStreamClosed) { - brokenPipe = true + return true } } - if brokenPipe { + return false +} + +func recoverFromPanic(c *gin.Context) { + err := recover() + if err == nil { + return + } + + // Check for a broken connection, as it is not really a + // condition that warrants a panic stack trace. + if isBrokenConnection(err) { log.Warningf("client %s disconnected: %s", c.ClientIP(), err) c.Abort() } else { @@ -159,18 +161,16 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro // NewServer creates a LAPI server. // It sets up a gin router, a database client, and a controller. -func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { +func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APIServer, error) { var flushScheduler *gocron.Scheduler - ctx := context.TODO() - dbClient, err := database.NewClient(ctx, config.DbConfig) if err != nil { return nil, fmt.Errorf("unable to init database client: %w", err) } if config.DbConfig.Flush != nil { - flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush) + flushScheduler, err = dbClient.StartFlushScheduler(ctx, config.DbConfig.Flush) if err != nil { return nil, err } @@ -229,7 +229,6 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { controller := &controllers.Controller{ DBClient: dbClient, - Ectx: ctx, Router: router, Profiles: config.Profiles, Log: clog, @@ -249,7 +248,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { if config.OnlineClient != nil && config.OnlineClient.Credentials != nil { log.Printf("Loading CAPI manager") - apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) + apiClient, err = NewAPIC(ctx, config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) if err != nil { return nil, err } @@ -258,7 +257,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { controller.AlertsAddChan = apiClient.AlertsAddChan - if config.ConsoleConfig.IsPAPIEnabled() { + if config.ConsoleConfig.IsPAPIEnabled() && config.OnlineClient.Credentials.PapiURL != "" { if apiClient.apiClient.IsEnrolled() { log.Info("Machine is enrolled in the console, Loading PAPI Client") @@ -301,6 +300,72 @@ func (s *APIServer) Router() (*gin.Engine, error) { return s.router, nil } +func (s *APIServer) apicPush(ctx context.Context) error { + if err := s.apic.Push(ctx); err != nil { + log.Errorf("capi push: %s", err) + return err + } + + return nil +} + +func (s *APIServer) apicPull(ctx context.Context) error { + if err := s.apic.Pull(ctx); err != nil { + log.Errorf("capi pull: %s", err) + return err + } + + return nil +} + +func (s *APIServer) papiPull(ctx context.Context) error { + if err := s.papi.Pull(ctx); err != nil { + log.Errorf("papi pull: %s", err) + return err + } + + return nil +} + +func (s *APIServer) papiSync() error { + if err := s.papi.SyncDecisions(); err != nil { + log.Errorf("capi decisions sync: %s", err) + return err + } + + return nil +} + +func (s *APIServer) initAPIC(ctx context.Context) { + s.apic.pushTomb.Go(func() error { return s.apicPush(ctx) }) + s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) }) + + // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios + if s.apic.apiClient.IsEnrolled() { + if s.consoleConfig.IsPAPIEnabled() && s.papi != nil { + if s.papi.URL != "" { + log.Info("Starting PAPI decision receiver") + s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) }) + s.papi.syncTomb.Go(s.papiSync) + } else { + log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") + } + } else { + log.Warningf("Machine is not allowed to synchronize decisions, you can enable it with `cscli console enable console_management`") + } + } + + s.apic.metricsTomb.Go(func() error { + s.apic.SendMetrics(ctx, make(chan bool)) + return nil + }) + + s.apic.metricsTomb.Go(func() error { + s.apic.SendUsageMetrics(ctx) + return nil + }) +} + func (s *APIServer) Run(apiReady chan bool) error { defer trace.CatchPanic("lapi/runServer") @@ -315,64 +380,10 @@ func (s *APIServer) Run(apiReady chan bool) error { TLSConfig: tlsCfg, } - if s.apic != nil { - s.apic.pushTomb.Go(func() error { - if err := s.apic.Push(); err != nil { - log.Errorf("capi push: %s", err) - return err - } - - return nil - }) - - s.apic.pullTomb.Go(func() error { - if err := s.apic.Pull(); err != nil { - log.Errorf("capi pull: %s", err) - return err - } - - return nil - }) - - // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios - if s.apic.apiClient.IsEnrolled() { - if s.consoleConfig.IsPAPIEnabled() { - if s.papi.URL != "" { - log.Info("Starting PAPI decision receiver") - s.papi.pullTomb.Go(func() error { - if err := s.papi.Pull(); err != nil { - log.Errorf("papi pull: %s", err) - return err - } - - return nil - }) - - s.papi.syncTomb.Go(func() error { - if err := s.papi.SyncDecisions(); err != nil { - log.Errorf("capi decisions sync: %s", err) - return err - } - - return nil - }) - } else { - log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") - } - } else { - log.Warningf("Machine is not allowed to synchronize decisions, you can enable it with `cscli console enable console_management`") - } - } - - s.apic.metricsTomb.Go(func() error { - s.apic.SendMetrics(make(chan bool)) - return nil - }) + ctx := context.TODO() - s.apic.metricsTomb.Go(func() error { - s.apic.SendUsageMetrics() - return nil - }) + if s.apic != nil { + s.initAPIC(ctx) } s.httpServerTomb.Go(func() error { diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index f48791ebcb8..cf4c91dedda 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -3,7 +3,6 @@ package apiserver import ( "context" "encoding/json" - "fmt" "net/http" "net/http/httptest" "os" @@ -25,6 +24,7 @@ import ( middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -41,7 +41,7 @@ var ( MachineID: &testMachineID, Password: &testPassword, } - UserAgent = fmt.Sprintf("crowdsec-test/%s", version.Version) + UserAgent = "crowdsec-test/" + version.Version emptyBody = strings.NewReader("") ) @@ -63,6 +63,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config { } apiServerConfig := csconfig.LocalApiServerCfg{ ListenURI: "http://127.0.0.1:8080", + LogLevel: ptr.Of(log.DebugLevel), DbConfig: &dbconfig, ProfilesPath: "./tests/profiles.yaml", ConsoleConfig: &csconfig.ConsoleConfig{ @@ -135,12 +136,12 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { return config } -func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { +func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Config) { config := LoadTestConfig(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) require.NoError(t, err) log.Printf("Creating new API server") @@ -149,8 +150,8 @@ func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { return apiServer, config } -func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { - apiServer, config := NewAPIServer(t) +func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { + apiServer, config := NewAPIServer(t, ctx) err := apiServer.InitController() require.NoError(t, err) @@ -161,12 +162,12 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { return router, config } -func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { +func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) { config := LoadTestConfigForwardedFor(t) os.Remove("./ent") - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) require.NoError(t, err) err = apiServer.InitController() @@ -181,13 +182,11 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { return router, config } -func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) { - ctx := context.Background() - +func ValidateMachine(t *testing.T, ctx context.Context, machineID string, config *csconfig.DatabaseCfg) { dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - err = dbClient.ValidateMachine(machineID) + err = dbClient.ValidateMachine(ctx, machineID) require.NoError(t, err) } @@ -197,7 +196,7 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) - machines, err := dbClient.ListMachines() + machines, err := dbClient.ListMachines(ctx) require.NoError(t, err) for _, machine := range machines { @@ -209,6 +208,18 @@ func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) return "" } +func GetBouncers(t *testing.T, config *csconfig.DatabaseCfg) []*ent.Bouncer { + ctx := context.Background() + + dbClient, err := database.NewClient(ctx, config) + require.NoError(t, err) + + bouncers, err := dbClient.ListBouncers(ctx) + require.NoError(t, err) + + return bouncers +} + func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader { alertContentBytes, err := os.ReadFile(path) require.NoError(t, err) @@ -270,7 +281,7 @@ func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map return response, resp.Code } -func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string { +func CreateTestMachine(t *testing.T, ctx context.Context, router *gin.Engine, token string) string { regReq := MachineTest regReq.RegistrationToken = token b, err := json.Marshal(regReq) @@ -279,56 +290,57 @@ func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) return body } -func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string { - ctx := context.Background() - +func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.DatabaseCfg) string { dbClient, err := database.NewClient(ctx, config) require.NoError(t, err) apiKey, err := middlewares.GenerateAPIKey(keyLength) require.NoError(t, err) - _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) + _, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType, false) require.NoError(t, err) return apiKey } func TestWithWrongDBConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) config.API.Server.DbConfig.Type = "test" - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'") assert.Nil(t, apiServer) } func TestWithWrongFlushConfig(t *testing.T) { + ctx := context.Background() config := LoadTestConfig(t) maxItems := -1 config.API.Server.DbConfig.Flush.MaxItems = &maxItems - apiServer, err := NewServer(config.API.Server) + apiServer, err := NewServer(ctx, config.API.Server) cstest.RequireErrorContains(t, err, "max_items can't be zero or negative") assert.Nil(t, apiServer) } func TestUnknownPath(t *testing.T) { - router, _ := NewAPITest(t) + ctx := context.Background() + router, _ := NewAPITest(t, ctx) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil) req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) + assert.Equal(t, http.StatusNotFound, w.Code) } /* @@ -347,6 +359,8 @@ ListenURI string `yaml:"listen_uri,omitempty"` //127.0 */ func TestLoggingDebugToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -368,7 +382,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) + expectedFile := filepath.Join(tempDir, "crowdsec_api.log") expectedLines := []string{"/test42"} cfg.LogLevel = ptr.Of(log.DebugLevel) @@ -376,15 +390,15 @@ func TestLoggingDebugToFileConfig(t *testing.T) { err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) require.NoError(t, err) - api, err := NewServer(&cfg) + api, err := NewServer(ctx, &cfg) require.NoError(t, err) require.NotNil(t, api) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) + assert.Equal(t, http.StatusNotFound, w.Code) // wait for the request to happen time.Sleep(500 * time.Millisecond) @@ -398,6 +412,8 @@ func TestLoggingDebugToFileConfig(t *testing.T) { } func TestLoggingErrorToFileConfig(t *testing.T) { + ctx := context.Background() + /*declare settings*/ maxAge := "1h" flushConfig := csconfig.FlushDBCfg{ @@ -419,19 +435,19 @@ func TestLoggingErrorToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) + expectedFile := filepath.Join(tempDir, "crowdsec_api.log") cfg.LogLevel = ptr.Of(log.ErrorLevel) // Configure logging err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) require.NoError(t, err) - api, err := NewServer(&cfg) + api, err := NewServer(ctx, &cfg) require.NoError(t, err) require.NotNil(t, api) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodGet, "/test42", nil) + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index 29f02723b70..719bb231006 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -1,7 +1,6 @@ package controllers import ( - "context" "net" "net/http" "strings" @@ -18,7 +17,6 @@ import ( ) type Controller struct { - Ectx context.Context DBClient *database.Client Router *gin.Engine Profiles []*csconfig.ProfileCfg @@ -83,7 +81,6 @@ func (c *Controller) NewV1() error { v1Config := v1.ControllerV1Config{ DbClient: c.DBClient, - Ctx: c.Ectx, ProfilesCfg: c.Profiles, DecisionDeleteChan: c.DecisionDeleteChan, AlertsAddChan: c.AlertsAddChan, diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 82dc51d6879..d1f93228512 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -6,7 +6,6 @@ import ( "net" "net/http" "strconv" - "strings" "time" "github.com/gin-gonic/gin" @@ -64,7 +63,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { var Metas models.Meta if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil { - log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err) + log.Errorf("unable to parse events meta '%s' : %s", eventItem.Serialized, err) } outputAlert.Events = append(outputAlert.Events, &models.Event{ @@ -124,25 +123,11 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin } } -func normalizeScope(scope string) string { - switch strings.ToLower(scope) { - case "ip": - return types.Ip - case "range": - return types.Range - case "as": - return types.AS - case "country": - return types.Country - default: - return scope - } -} - // CreateAlert writes the alerts received in the body to the database func (c *Controller) CreateAlert(gctx *gin.Context) { var input models.AddAlertsRequest + ctx := gctx.Request.Context() machineID, _ := getMachineIDFromContext(gctx) if err := gctx.ShouldBindJSON(&input); err != nil { @@ -160,12 +145,12 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { for _, alert := range input { // normalize scope for alert.Source and decisions if alert.Source.Scope != nil { - *alert.Source.Scope = normalizeScope(*alert.Source.Scope) + *alert.Source.Scope = types.NormalizeScope(*alert.Source.Scope) } for _, decision := range alert.Decisions { if decision.Scope != nil { - *decision.Scope = normalizeScope(*decision.Scope) + *decision.Scope = types.NormalizeScope(*decision.Scope) } } @@ -255,7 +240,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { c.DBClient.CanFlush = false } - alerts, err := c.DBClient.CreateAlert(machineID, input) + alerts, err := c.DBClient.CreateAlert(ctx, machineID, input) c.DBClient.CanFlush = true if err != nil { @@ -277,7 +262,9 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { // FindAlerts: returns alerts from the database based on the specified filter func (c *Controller) FindAlerts(gctx *gin.Context) { - result, err := c.DBClient.QueryAlertWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + result, err := c.DBClient.QueryAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return @@ -295,15 +282,16 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { // FindAlertByID returns the alert associated with the ID func (c *Controller) FindAlertByID(gctx *gin.Context) { + ctx := gctx.Request.Context() alertIDStr := gctx.Param("alert_id") - alertID, err := strconv.Atoi(alertIDStr) + alertID, err := strconv.Atoi(alertIDStr) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } - result, err := c.DBClient.GetAlertByID(alertID) + result, err := c.DBClient.GetAlertByID(ctx, alertID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -323,6 +311,8 @@ func (c *Controller) FindAlertByID(gctx *gin.Context) { func (c *Controller) DeleteAlertByID(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) @@ -337,7 +327,7 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { return } - err = c.DBClient.DeleteAlertByID(decisionID) + err = c.DBClient.DeleteAlertByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) return @@ -350,13 +340,15 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { // DeleteAlerts deletes alerts from the database based on the specified filter func (c *Controller) DeleteAlerts(gctx *gin.Context) { + ctx := gctx.Request.Context() + incomingIP := gctx.ClientIP() if incomingIP != "127.0.0.1" && incomingIP != "::1" && !networksContainIP(c.TrustedIPs, incomingIP) && !isUnixSocket(gctx) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return diff --git a/pkg/apiserver/controllers/v1/controller.go b/pkg/apiserver/controllers/v1/controller.go index 6de4abe3b3b..f8b6aa76ea5 100644 --- a/pkg/apiserver/controllers/v1/controller.go +++ b/pkg/apiserver/controllers/v1/controller.go @@ -1,7 +1,6 @@ package v1 import ( - "context" "fmt" "net" @@ -14,7 +13,6 @@ import ( ) type Controller struct { - Ectx context.Context DBClient *database.Client APIKeyHeader string Middlewares *middlewares.Middlewares @@ -31,7 +29,6 @@ type Controller struct { type ControllerV1Config struct { DbClient *database.Client - Ctx context.Context ProfilesCfg []*csconfig.ProfileCfg AlertsAddChan chan []*models.Alert @@ -52,7 +49,6 @@ func New(cfg *ControllerV1Config) (*Controller, error) { } v1 := &Controller{ - Ectx: cfg.Ctx, DBClient: cfg.DbClient, APIKeyHeader: middlewares.APIKeyHeader, Profiles: profiles, diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 3d8e0232224..ffefffc226b 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -1,8 +1,8 @@ package v1 import ( + "context" "encoding/json" - "fmt" "net/http" "strconv" "time" @@ -43,6 +43,8 @@ func (c *Controller) GetDecision(gctx *gin.Context) { data []*ent.Decision ) + ctx := gctx.Request.Context() + bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) @@ -50,7 +52,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { return } - data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query()) + data, err = c.DBClient.QueryDecisionWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) @@ -73,7 +75,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { } if bouncerInfo.LastPull == nil || time.Now().UTC().Sub(*bouncerInfo.LastPull) >= time.Minute { - if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(ctx, time.Now().UTC(), bouncerInfo.ID); err != nil { log.Errorf("failed to update bouncer last pull: %v", err) } } @@ -91,7 +93,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { return } - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(decisionID) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) @@ -113,7 +117,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { } func (c *Controller) DeleteDecisions(gctx *gin.Context) { - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) @@ -134,33 +140,38 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { gctx.JSON(http.StatusOK, deleteDecisionResp) } -func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(map[string][]string) ([]*ent.Decision, error)) error { +func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFunc func(context.Context, map[string][]string) ([]*ent.Decision, error)) error { // respBuffer := bytes.NewBuffer([]byte{}) - limit := 30000 //FIXME : make it configurable + limit := 30000 // FIXME : make it configurable needComma := false lastId := 0 - limitStr := fmt.Sprintf("%d", limit) + ctx := gctx.Request.Context() + + limitStr := strconv.Itoa(limit) filters["limit"] = []string{limitStr} + for { if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) + lastIdStr := strconv.Itoa(lastId) filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(filters) + data, err := dbFunc(ctx, filters) if err != nil { return err } + if len(data) > 0 { lastId = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) if needComma { - //respBuffer.Write([]byte(",")) - gctx.Writer.Write([]byte(",")) + // respBuffer.Write([]byte(",")) + gctx.Writer.WriteString(",") } else { needComma = true } @@ -172,10 +183,12 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun return err } - //respBuffer.Reset() + // respBuffer.Reset() } } + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + if len(data) < limit { gctx.Writer.Flush() @@ -186,33 +199,38 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun return nil } -func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(*time.Time, map[string][]string) ([]*ent.Decision, error)) error { - //respBuffer := bytes.NewBuffer([]byte{}) - limit := 30000 //FIXME : make it configurable +func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPull *time.Time, dbFunc func(context.Context, *time.Time, map[string][]string) ([]*ent.Decision, error)) error { + // respBuffer := bytes.NewBuffer([]byte{}) + limit := 30000 // FIXME : make it configurable needComma := false lastId := 0 - limitStr := fmt.Sprintf("%d", limit) + ctx := gctx.Request.Context() + + limitStr := strconv.Itoa(limit) filters["limit"] = []string{limitStr} + for { if lastId > 0 { - lastIdStr := fmt.Sprintf("%d", lastId) + lastIdStr := strconv.Itoa(lastId) filters["id_gt"] = []string{lastIdStr} } - data, err := dbFunc(lastPull, filters) + data, err := dbFunc(ctx, lastPull, filters) if err != nil { return err } + if len(data) > 0 { lastId = data[len(data)-1].ID + results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) if needComma { - //respBuffer.Write([]byte(",")) - gctx.Writer.Write([]byte(",")) + // respBuffer.Write([]byte(",")) + gctx.Writer.WriteString(",") } else { needComma = true } @@ -224,10 +242,12 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul return err } - //respBuffer.Reset() + // respBuffer.Reset() } } + log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) + if len(data) < limit { gctx.Writer.Flush() @@ -244,7 +264,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B gctx.Writer.Header().Set("Content-Type", "application/json") gctx.Writer.Header().Set("Transfer-Encoding", "chunked") gctx.Writer.WriteHeader(http.StatusOK) - gctx.Writer.Write([]byte(`{"new": [`)) //No need to check for errors, the doc says it always returns nil + gctx.Writer.WriteString(`{"new": [`) // No need to check for errors, the doc says it always returns nil // if the blocker just started, return all decisions if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" { @@ -252,48 +272,47 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B err := writeStartupDecisions(gctx, filters, c.DBClient.QueryAllDecisionsWithFilters) if err != nil { log.Errorf("failed sending new decisions for startup: %v", err) - gctx.Writer.Write([]byte(`], "deleted": []}`)) + gctx.Writer.WriteString(`], "deleted": []}`) gctx.Writer.Flush() return err } - gctx.Writer.Write([]byte(`], "deleted": [`)) - //Expired decisions + gctx.Writer.WriteString(`], "deleted": [`) + // Expired decisions err = writeStartupDecisions(gctx, filters, c.DBClient.QueryExpiredDecisionsWithFilters) if err != nil { log.Errorf("failed sending expired decisions for startup: %v", err) - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString(`]}`) gctx.Writer.Flush() return err } - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString(`]}`) gctx.Writer.Flush() } else { err = writeDeltaDecisions(gctx, filters, bouncerInfo.LastPull, c.DBClient.QueryNewDecisionsSinceWithFilters) if err != nil { log.Errorf("failed sending new decisions for delta: %v", err) - gctx.Writer.Write([]byte(`], "deleted": []}`)) + gctx.Writer.WriteString(`], "deleted": []}`) gctx.Writer.Flush() return err } - gctx.Writer.Write([]byte(`], "deleted": [`)) + gctx.Writer.WriteString(`], "deleted": [`) err = writeDeltaDecisions(gctx, filters, bouncerInfo.LastPull, c.DBClient.QueryExpiredDecisionsSinceWithFilters) - if err != nil { log.Errorf("failed sending expired decisions for delta: %v", err) - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString("]}") gctx.Writer.Flush() return err } - gctx.Writer.Write([]byte(`]}`)) + gctx.Writer.WriteString("]}") gctx.Writer.Flush() } @@ -301,8 +320,12 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B } func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { - var data []*ent.Decision - var err error + var ( + data []*ent.Decision + err error + ) + + ctx := gctx.Request.Context() ret := make(map[string][]*models.Decision, 0) ret["new"] = []*models.Decision{} @@ -310,18 +333,18 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if val, ok := gctx.Request.URL.Query()["startup"]; ok { if val[0] == "true" { - data, err = c.DBClient.QueryAllDecisionsWithFilters(filters) + data, err = c.DBClient.QueryAllDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("failed querying decisions: %v", err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return err } - //data = KeepLongestDecision(data) + // data = KeepLongestDecision(data) ret["new"] = FormatDecisions(data) // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters) + data, err = c.DBClient.QueryExpiredDecisionsWithFilters(ctx, filters) if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -338,14 +361,14 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en } // getting new decisions - data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters) + data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(ctx, bouncerInfo.LastPull, filters) if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return err } - //data = KeepLongestDecision(data) + // data = KeepLongestDecision(data) ret["new"] = FormatDecisions(data) since := time.Time{} @@ -354,7 +377,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en } // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(&since, filters) // do we want to give exactly lastPull time ? + data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(ctx, &since, filters) // do we want to give exactly lastPull time ? if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -371,6 +394,8 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en func (c *Controller) StreamDecision(gctx *gin.Context) { var err error + ctx := gctx.Request.Context() + streamStartTime := time.Now().UTC() bouncerInfo, err := getBouncerFromContext(gctx) @@ -381,8 +406,8 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { } if gctx.Request.Method == http.MethodHead { - //For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db - //We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) + // For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db + // We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true) gctx.String(http.StatusOK, "") return @@ -400,8 +425,8 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { } if err == nil { - //Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions - if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil { + // Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions + if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) } } diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index e1231eaa9ec..799b736ccfe 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -9,7 +9,9 @@ import ( func (c *Controller) HeartBeat(gctx *gin.Context) { machineID, _ := getMachineIDFromContext(gctx) - if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { + ctx := gctx.Request.Context() + + if err := c.DBClient.UpdateMachineLastHeartBeat(ctx, machineID); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 0030f7d3b39..ff59e389cb1 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -46,6 +46,8 @@ func (c *Controller) shouldAutoRegister(token string, gctx *gin.Context) (bool, } func (c *Controller) CreateMachine(gctx *gin.Context) { + ctx := gctx.Request.Context() + var input models.WatcherRegistrationRequest if err := gctx.ShouldBindJSON(&input); err != nil { @@ -66,7 +68,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) { return } - if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { + if _, err := c.DBClient.CreateMachine(ctx, input.MachineID, input.Password, gctx.ClientIP(), autoRegister, false, types.PasswordAuthType); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index ddb38512a11..4f6ee0986eb 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -68,7 +68,8 @@ func PrometheusBouncersHasEmptyDecision(c *gin.Context) { bouncer, _ := getBouncerFromContext(c) if bouncer != nil { LapiNilDecisions.With(prometheus.Labels{ - "bouncer": bouncer.Name}).Inc() + "bouncer": bouncer.Name, + }).Inc() } } @@ -76,7 +77,8 @@ func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { bouncer, _ := getBouncerFromContext(c) if bouncer != nil { LapiNonNilDecisions.With(prometheus.Labels{ - "bouncer": bouncer.Name}).Inc() + "bouncer": bouncer.Name, + }).Inc() } } @@ -87,7 +89,8 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc { LapiMachineHits.With(prometheus.Labels{ "machine": machineID, "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() } c.Next() @@ -101,7 +104,8 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc { LapiBouncerHits.With(prometheus.Labels{ "bouncer": bouncer.Name, "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() } c.Next() @@ -114,7 +118,8 @@ func PrometheusMiddleware() gin.HandlerFunc { LapiRouteHits.With(prometheus.Labels{ "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() + "method": c.Request.Method, + }).Inc() c.Next() elapsed := time.Since(startTime) diff --git a/pkg/apiserver/controllers/v1/usagemetrics.go b/pkg/apiserver/controllers/v1/usagemetrics.go index 74f27bb6cf4..5b2c3e3b1a9 100644 --- a/pkg/apiserver/controllers/v1/usagemetrics.go +++ b/pkg/apiserver/controllers/v1/usagemetrics.go @@ -1,6 +1,7 @@ package v1 import ( + "context" "encoding/json" "errors" "net/http" @@ -18,17 +19,15 @@ import ( ) // updateBaseMetrics updates the base metrics for a machine or bouncer -func (c *Controller) updateBaseMetrics(machineID string, bouncer *ent.Bouncer, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { +func (c *Controller) updateBaseMetrics(ctx context.Context, machineID string, bouncer *ent.Bouncer, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { switch { case machineID != "": - c.DBClient.MachineUpdateBaseMetrics(machineID, baseMetrics, hubItems, datasources) + return c.DBClient.MachineUpdateBaseMetrics(ctx, machineID, baseMetrics, hubItems, datasources) case bouncer != nil: - c.DBClient.BouncerUpdateBaseMetrics(bouncer.Name, bouncer.Type, baseMetrics) + return c.DBClient.BouncerUpdateBaseMetrics(ctx, bouncer.Name, bouncer.Type, baseMetrics) default: return errors.New("no machineID or bouncerName set") } - - return nil } // UsageMetrics receives metrics from log processors and remediation components @@ -172,7 +171,9 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { } } - err := c.updateBaseMetrics(machineID, bouncer, baseMetrics, hubItems, datasources) + ctx := gctx.Request.Context() + + err := c.updateBaseMetrics(ctx, machineID, bouncer, baseMetrics, hubItems, datasources) if err != nil { logger.Errorf("Failed to update base metrics: %s", err) c.HandleDBErrors(gctx, err) @@ -182,7 +183,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { jsonPayload, err := json.Marshal(payload) if err != nil { - logger.Errorf("Failed to marshal usage metrics: %s", err) + logger.Errorf("Failed to serialize usage metrics: %s", err) c.HandleDBErrors(gctx, err) return @@ -190,7 +191,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) { receivedAt := time.Now().UTC() - if _, err := c.DBClient.CreateMetric(generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil { + if _, err := c.DBClient.CreateMetric(ctx, generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil { logger.Error(err) c.HandleDBErrors(gctx, err) diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index e4c9dda47ce..a0af6956443 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -12,82 +13,86 @@ const ( ) func TestDeleteDecisionRange(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by range - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) // delete by range : ensure it was already deleted - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) } func TestDeleteDecisionFilter(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by ip good - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) // delete by scope/value - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } func TestDeleteDecisionFilterByScenario(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // delete by wrong scenario - w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by scenario good - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) } func TestGetDecisionFilters(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_minibulk.json") // Get Decision - w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -101,7 +106,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : type filter - w = lapi.RecordResponse(t, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -118,7 +123,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : scope/value - w = lapi.RecordResponse(t, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -132,7 +137,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : ip filter - w = lapi.RecordResponse(t, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -145,7 +150,7 @@ func TestGetDecisionFilters(t *testing.T) { // assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`) // Get decision : by range - w = lapi.RecordResponse(t, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -155,13 +160,14 @@ func TestGetDecisionFilters(t *testing.T) { } func TestGetDecision(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Get Decision - w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) @@ -180,51 +186,52 @@ func TestGetDecision(t *testing.T) { assert.Equal(t, int64(3), decisions[2].ID) // Get Decision with invalid filter. It should ignore this filter - w = lapi.RecordResponse(t, "GET", "/v1/decisions?test=test", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions?test=test", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) assert.Len(t, decisions, 3) } func TestDeleteDecisionByID(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") - //Have one alerts - w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + // Have one alert + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) // Delete alert with Invalid ID - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD) assert.Equal(t, 400, w.Code) errResp, _ := readDecisionsErrorResp(t, w) assert.Equal(t, "decision_id must be valid integer", errResp["message"]) // Delete alert with ID that not exist - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) errResp, _ = readDecisionsErrorResp(t, w) assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"]) - //Have one alerts - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + // Have one alert + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) // Delete alert with valid ID - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "1", resp.NbDeleted) - //Have one alert (because we delete an alert that has dup targets) - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + // Have one alert (because we delete an alert that has dup targets) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) @@ -232,33 +239,35 @@ func TestDeleteDecisionByID(t *testing.T) { } func TestDeleteDecision(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) // Create Valid Alert - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Delete alert with Invalid filter - w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) errResp, _ := readDecisionsErrorResp(t, w) assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"]) // Delete all alert - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "3", resp.NbDeleted) } func TestStreamStartDecisionDedup(t *testing.T) { - //Ensure that at stream startup we only get the longest decision - lapi := SetupLAPITest(t) + ctx := context.Background() + // Ensure that at stream startup we only get the longest decision + lapi := SetupLAPITest(t, ctx) // Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3 - lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, ctx, "./tests/alert_sample.json") // Get Stream, we only get one decision (the longest one) - w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + w := lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) @@ -268,11 +277,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // Get Stream, we only get one decision (the longest one, id=2) - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) @@ -282,11 +291,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete another decision, yet don't receive it in stream, since there's another decision on same IP - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // And get the remaining decision (1) - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) @@ -296,11 +305,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete the last decision, we receive the delete order - w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, ctx, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - //and now we only get a deleted decision - w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + // and now we only get a deleted decision + w = lapi.RecordResponse(t, ctx, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions["deleted"], 1) diff --git a/pkg/apiserver/heartbeat_test.go b/pkg/apiserver/heartbeat_test.go index fbf01c7fb8e..db051566f75 100644 --- a/pkg/apiserver/heartbeat_test.go +++ b/pkg/apiserver/heartbeat_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "testing" @@ -8,11 +9,12 @@ import ( ) func TestHeartBeat(t *testing.T) { - lapi := SetupLAPITest(t) + ctx := context.Background() + lapi := SetupLAPITest(t, ctx) - w := lapi.RecordResponse(t, http.MethodGet, "/v1/heartbeat", emptyBody, "password") + w := lapi.RecordResponse(t, ctx, http.MethodGet, "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse(t, "POST", "/v1/heartbeat", emptyBody, "password") + w = lapi.RecordResponse(t, ctx, "POST", "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 405, w.Code) } diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index aa6e84e416b..f6f51763975 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "net/http" "net/http/httptest" "strings" @@ -10,13 +11,14 @@ import ( ) func TestLogin(t *testing.T) { - router, config := NewAPITest(t) + ctx := context.Background() + router, config := NewAPITest(t, ctx) - body := CreateTestMachine(t, router, "") + body := CreateTestMachine(t, ctx, router, "") // Login with machine not validated yet w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -25,7 +27,7 @@ func TestLogin(t *testing.T) { // Login with machine not exist w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -34,7 +36,7 @@ func TestLogin(t *testing.T) { // Login with invalid body w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -43,19 +45,19 @@ func TestLogin(t *testing.T) { // Login with invalid format w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String()) - //Validate machine - ValidateMachine(t, "test", config.API.Server.DbConfig) + // Validate machine + ValidateMachine(t, ctx, "test", config.API.Server.DbConfig) // Login with invalid password w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -64,7 +66,7 @@ func TestLogin(t *testing.T) { // Login with valid machine w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -74,7 +76,7 @@ func TestLogin(t *testing.T) { // Login with valid machine + scenarios w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 041a6bee528..969f75707d6 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -14,11 +15,12 @@ import ( ) func TestCreateMachine(t *testing.T) { - router, _ := NewAPITest(t) + ctx := context.Background() + router, _ := NewAPITest(t, ctx) // Create machine with invalid format w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("test")) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader("test")) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -27,7 +29,7 @@ func TestCreateMachine(t *testing.T) { // Create machine with invalid input w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -41,7 +43,7 @@ func TestCreateMachine(t *testing.T) { body := string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -50,8 +52,10 @@ func TestCreateMachine(t *testing.T) { } func TestCreateMachineWithForwardedFor(t *testing.T) { - router, config := NewAPITestForwardedFor(t) + ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) router.TrustedPlatform = "X-Real-IP" + // Create machine b, err := json.Marshal(MachineTest) require.NoError(t, err) @@ -59,7 +63,7 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-Ip", "1.1.1.1") router.ServeHTTP(w, req) @@ -73,7 +77,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { } func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { - router, config := NewAPITest(t) + ctx := context.Background() + router, config := NewAPITest(t, ctx) // Create machine b, err := json.Marshal(MachineTest) @@ -82,7 +87,7 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.Header.Add("X-Real-IP", "1.1.1.1") router.ServeHTTP(w, req) @@ -92,13 +97,14 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) - //For some reason, the IP is empty when running tests - //if no forwarded-for headers are present + // For some reason, the IP is empty when running tests + // if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineWithoutForwardedFor(t *testing.T) { - router, config := NewAPITestForwardedFor(t) + ctx := context.Background() + router, config := NewAPITestForwardedFor(t, ctx) // Create machine b, err := json.Marshal(MachineTest) @@ -107,7 +113,7 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -116,23 +122,24 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) - //For some reason, the IP is empty when running tests - //if no forwarded-for headers are present + // For some reason, the IP is empty when running tests + // if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineAlreadyExist(t *testing.T) { - router, _ := NewAPITest(t) + ctx := context.Background() + router, _ := NewAPITest(t, ctx) - body := CreateTestMachine(t, router, "") + body := CreateTestMachine(t, ctx, router, "") w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) @@ -141,9 +148,10 @@ func TestCreateMachineAlreadyExist(t *testing.T) { } func TestAutoRegistration(t *testing.T) { - router, _ := NewAPITest(t) + ctx := context.Background() + router, _ := NewAPITest(t, ctx) - //Invalid registration token / valid source IP + // Invalid registration token / valid source IP regReq := MachineTest regReq.RegistrationToken = invalidRegistrationToken b, err := json.Marshal(regReq) @@ -152,14 +160,14 @@ func TestAutoRegistration(t *testing.T) { body := string(b) w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) assert.Equal(t, http.StatusUnauthorized, w.Code) - //Invalid registration token / invalid source IP + // Invalid registration token / invalid source IP regReq = MachineTest regReq.RegistrationToken = invalidRegistrationToken b, err = json.Marshal(regReq) @@ -168,14 +176,14 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "42.42.42.42:4242" router.ServeHTTP(w, req) assert.Equal(t, http.StatusUnauthorized, w.Code) - //valid registration token / invalid source IP + // valid registration token / invalid source IP regReq = MachineTest regReq.RegistrationToken = validRegistrationToken b, err = json.Marshal(regReq) @@ -184,14 +192,14 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "42.42.42.42:4242" router.ServeHTTP(w, req) assert.Equal(t, http.StatusUnauthorized, w.Code) - //Valid registration token / valid source IP + // Valid registration token / valid source IP regReq = MachineTest regReq.RegistrationToken = validRegistrationToken b, err = json.Marshal(regReq) @@ -200,14 +208,14 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) assert.Equal(t, http.StatusAccepted, w.Code) - //No token / valid source IP + // No token / valid source IP regReq = MachineTest regReq.MachineID = ptr.Of("test2") b, err = json.Marshal(regReq) @@ -216,7 +224,7 @@ func TestAutoRegistration(t *testing.T) { body = string(b) w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) + req, _ = http.NewRequestWithContext(ctx, http.MethodPost, "/v1/watchers", strings.NewReader(body)) req.Header.Add("User-Agent", UserAgent) req.RemoteAddr = "127.0.0.1:4242" router.ServeHTTP(w, req) diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index e822666db0f..3c154be4fab 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -64,6 +64,8 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + ctx := c.Request.Context() + extractedCN, err := a.TlsAuth.ValidateCert(c) if err != nil { logger.Warn(err) @@ -73,7 +75,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger = logger.WithField("cn", extractedCN) bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - bouncer, err := a.DbClient.SelectBouncerByName(bouncerName) + bouncer, err := a.DbClient.SelectBouncerByName(ctx, bouncerName) // This is likely not the proper way, but isNotFound does not seem to work if err != nil && strings.Contains(err.Error(), "bouncer not found") { @@ -87,7 +89,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType, true) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) return nil @@ -112,16 +114,69 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { return nil } + clientIP := c.ClientIP() + + ctx := c.Request.Context() + hashStr := HashSHA512(val[0]) - bouncer, err := a.DbClient.SelectBouncer(hashStr) + // Appsec case, we only care if the key is valid + // No content is returned, no last_pull update or anything + if c.Request.Method == http.MethodHead { + bouncer, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType) + if err != nil { + logger.Errorf("while fetching bouncer info: %s", err) + return nil + } + return bouncer[0] + } + + // most common case, check if this specific bouncer exists + bouncer, err := a.DbClient.SelectBouncerWithIP(ctx, hashStr, clientIP) + if err != nil && !ent.IsNotFound(err) { + logger.Errorf("while fetching bouncer info: %s", err) + return nil + } + + // We found the bouncer with key and IP, we can use it + if bouncer != nil { + if bouncer.AuthType != types.ApiKeyAuthType { + logger.Errorf("bouncer isn't allowed to auth by API key") + return nil + } + return bouncer + } + + // We didn't find the bouncer with key and IP, let's try to find it with the key only + bouncers, err := a.DbClient.SelectBouncers(ctx, hashStr, types.ApiKeyAuthType) if err != nil { logger.Errorf("while fetching bouncer info: %s", err) return nil } - if bouncer.AuthType != types.ApiKeyAuthType { - logger.Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType) + if len(bouncers) == 0 { + logger.Debugf("no bouncer found with this key") + return nil + } + + logger.Debugf("found %d bouncers with this key", len(bouncers)) + + // We only have one bouncer with this key and no IP + // This is the first request made by this bouncer, keep this one + if len(bouncers) == 1 && bouncers[0].IPAddress == "" { + return bouncers[0] + } + + // Bouncers are ordered by ID, first one *should* be the manually created one + // Can probably get a bit weird if the user deletes the manually created one + bouncerName := fmt.Sprintf("%s@%s", bouncers[0].Name, clientIP) + + logger.Infof("Creating bouncer %s", bouncerName) + + bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, clientIP, hashStr, types.ApiKeyAuthType, true) + + if err != nil { + logger.Errorf("while creating bouncer db entry: %s", err) return nil } @@ -132,6 +187,8 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return func(c *gin.Context) { var bouncer *ent.Bouncer + ctx := c.Request.Context() + clientIP := c.ClientIP() logger := log.WithField("ip", clientIP) @@ -150,27 +207,20 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return } - logger = logger.WithField("name", bouncer.Name) - - if bouncer.IPAddress == "" { - if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil { - logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - - return - } + // Appsec request, return immediately if we found something + if c.Request.Method == http.MethodHead { + c.Set(BouncerContextKey, bouncer) + return } - // Don't update IP on HEAD request, as it's used by the appsec to check the validity of the API key provided - if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead { - log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress) + logger = logger.WithField("name", bouncer.Name) - if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil { + // 1st time we see this bouncer, we update its IP + if bouncer.IPAddress == "" { + if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() - return } } @@ -182,7 +232,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { - if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { + if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil { logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() diff --git a/pkg/apiserver/middlewares/v1/cache.go b/pkg/apiserver/middlewares/v1/cache.go index a058ec40393..b0037bc4fa4 100644 --- a/pkg/apiserver/middlewares/v1/cache.go +++ b/pkg/apiserver/middlewares/v1/cache.go @@ -9,7 +9,7 @@ import ( ) type cacheEntry struct { - err error // if nil, the certificate is not revocated + err error // if nil, the certificate is not revocated timestamp time.Time } diff --git a/pkg/apiserver/middlewares/v1/crl.go b/pkg/apiserver/middlewares/v1/crl.go index f85a410998e..64d7d3f0d96 100644 --- a/pkg/apiserver/middlewares/v1/crl.go +++ b/pkg/apiserver/middlewares/v1/crl.go @@ -12,13 +12,13 @@ import ( ) type CRLChecker struct { - path string // path to the CRL file - fileInfo os.FileInfo // last stat of the CRL file - crls []*x509.RevocationList // parsed CRLs + path string // path to the CRL file + fileInfo os.FileInfo // last stat of the CRL file + crls []*x509.RevocationList // parsed CRLs logger *log.Entry mu sync.RWMutex - lastLoad time.Time // time when the CRL file was last read successfully - onLoad func() // called when the CRL file changes (and is read successfully) + lastLoad time.Time // time when the CRL file was last read successfully + onLoad func() // called when the CRL file changes (and is read successfully) } func NewCRLChecker(crlPath string, onLoad func(), logger *log.Entry) (*CRLChecker, error) { diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 64406deff3e..9171e9fce06 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -55,6 +55,7 @@ type authInput struct { } func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { + ctx := c.Request.Context() ret := authInput{} if j.TlsAuth == nil { @@ -76,7 +77,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if ent.IsNotFound(err) { // Machine was not found, let's create it logger.Infof("machine %s not found, create it", ret.machineID) @@ -91,7 +92,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { password := strfmt.Password(pwd) - ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType) + ret.clientMachine, err = j.DbClient.CreateMachine(ctx, &ret.machineID, &password, "", true, true, types.TlsAuthType) if err != nil { return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err) } @@ -127,6 +128,8 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { err error ) + ctx := c.Request.Context() + ret := authInput{} if err = c.ShouldBindJSON(&loginInput); err != nil { @@ -143,7 +146,7 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if err != nil { log.Infof("Error machine login for %s : %+v ", ret.machineID, err) return nil, err @@ -175,6 +178,8 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { auth *authInput ) + ctx := c.Request.Context() + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { auth, err = j.authTLS(c) if err != nil { @@ -198,7 +203,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { } } - err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineScenarios(ctx, scenarios, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -208,7 +213,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { clientIP := c.ClientIP() if auth.clientMachine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err) return nil, jwt.ErrFailedAuthentication @@ -218,7 +223,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" { log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress) - err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID) + err = j.DbClient.UpdateMachineIP(ctx, clientIP, auth.clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication @@ -231,7 +236,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil { + if err := j.DbClient.UpdateMachineVersion(ctx, useragent[1], auth.clientMachine.ID); err != nil { log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err) log.Errorf("bad user agent from : %s", clientIP) diff --git a/pkg/apiserver/middlewares/v1/ocsp.go b/pkg/apiserver/middlewares/v1/ocsp.go index 24557bfda7b..0b6406ad0e7 100644 --- a/pkg/apiserver/middlewares/v1/ocsp.go +++ b/pkg/apiserver/middlewares/v1/ocsp.go @@ -70,7 +70,7 @@ func (oc *OCSPChecker) query(server string, cert *x509.Certificate, issuer *x509 // It returns a boolean indicating if the certificate is revoked and a boolean indicating // if the OCSP check was successful and could be cached. func (oc *OCSPChecker) isRevokedBy(cert *x509.Certificate, issuer *x509.Certificate) (bool, bool) { - if cert.OCSPServer == nil || len(cert.OCSPServer) == 0 { + if len(cert.OCSPServer) == 0 { oc.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") return false, true } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 0d0fd0ecd42..83ba13843b9 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -156,11 +156,11 @@ func (p *Papi) handleEvent(event longpollclient.Event, sync bool) error { return nil } -func (p *Papi) GetPermissions() (PapiPermCheckSuccess, error) { +func (p *Papi) GetPermissions(ctx context.Context) (PapiPermCheckSuccess, error) { httpClient := p.apiClient.GetClient() papiCheckUrl := fmt.Sprintf("%s%s%s", p.URL, types.PAPIVersion, types.PAPIPermissionsUrl) - req, err := http.NewRequest(http.MethodGet, papiCheckUrl, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, papiCheckUrl, nil) if err != nil { return PapiPermCheckSuccess{}, fmt.Errorf("failed to create request: %w", err) } @@ -205,8 +205,8 @@ func reverse(s []longpollclient.Event) []longpollclient.Event { return a } -func (p *Papi) PullOnce(since time.Time, sync bool) error { - events, err := p.Client.PullOnce(since) +func (p *Papi) PullOnce(ctx context.Context, since time.Time, sync bool) error { + events, err := p.Client.PullOnce(ctx, since) if err != nil { return err } @@ -230,13 +230,13 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { } // PullPAPI is the long polling client for real-time decisions from PAPI -func (p *Papi) Pull() error { +func (p *Papi) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/PullPAPI") p.Logger.Infof("Starting Polling API Pull") lastTimestamp := time.Time{} - lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) + lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } @@ -245,30 +245,30 @@ func (p *Papi) Pull() error { if lastTimestampStr == nil { binTime, err := lastTimestamp.MarshalText() if err != nil { - return fmt.Errorf("failed to marshal last timestamp: %w", err) + return fmt.Errorf("failed to serialize last timestamp: %w", err) } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime)) } } else { if err := lastTimestamp.UnmarshalText([]byte(*lastTimestampStr)); err != nil { - return fmt.Errorf("failed to unmarshal last timestamp: %w", err) + return fmt.Errorf("failed to parse last timestamp: %w", err) } } p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp) - for event := range p.Client.Start(lastTimestamp) { + for event := range p.Client.Start(ctx, lastTimestamp) { logger := p.Logger.WithField("request-id", event.RequestId) // update last timestamp in database newTime := time.Now().UTC() binTime, err := newTime.MarshalText() if err != nil { - return fmt.Errorf("failed to marshal last timestamp: %w", err) + return fmt.Errorf("failed to serialize last timestamp: %w", err) } err = p.handleEvent(event, false) @@ -277,7 +277,7 @@ func (p *Papi) Pull() error { continue } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index a1137161698..78f5dc9b0fe 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "time" @@ -42,6 +43,8 @@ type listUnsubscribe struct { } func DecisionCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "delete": data, err := json.Marshal(message.Data) @@ -64,7 +67,7 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { filter := make(map[string][]string) filter["uuid"] = UUIDs - _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err) } @@ -94,6 +97,8 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { } func AlertCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "add": data, err := json.Marshal(message.Data) @@ -152,7 +157,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } // use a different method: alert and/or decision might already be partially present in the database - _, err = p.DBClient.CreateOrUpdateAlert("", alert) + _, err = p.DBClient.CreateOrUpdateAlert(ctx, "", alert) if err != nil { log.Errorf("Failed to create alerts in DB: %s", err) } else { @@ -167,6 +172,8 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } func ManagementCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + if sync { p.Logger.Infof("Ignoring management command from PAPI in sync mode") return nil @@ -194,7 +201,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { filter["origin"] = []string{types.ListOrigin} filter["scenario"] = []string{unsubscribeMsg.Name} - _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err) } @@ -215,17 +222,19 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } + ctx := context.TODO() + if forcePullMsg.Blocklist == nil { p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err = p.apic.PullTop(true) + err = p.apic.PullTop(ctx, true) if err != nil { return fmt.Errorf("failed to force pull operation: %w", err) } } else { p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) - err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{ + err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{ Name: &forcePullMsg.Blocklist.Name, URL: &forcePullMsg.Blocklist.Url, Remediation: &forcePullMsg.Blocklist.Remediation, diff --git a/pkg/apiserver/usage_metrics_test.go b/pkg/apiserver/usage_metrics_test.go index 41dd0ccdc2c..32aeb7d9a5a 100644 --- a/pkg/apiserver/usage_metrics_test.go +++ b/pkg/apiserver/usage_metrics_test.go @@ -13,6 +13,8 @@ import ( ) func TestLPMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string body string @@ -28,7 +30,7 @@ func TestLPMetrics(t *testing.T) { name: "empty metrics for LP", body: `{ }`, - expectedStatusCode: 400, + expectedStatusCode: http.StatusBadRequest, expectedResponse: "Missing log processor data", authType: PASSWORD, }, @@ -48,7 +50,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedResponse: "", expectedOSName: "foo", @@ -72,7 +74,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedResponse: "", expectedOSName: "foo", @@ -96,7 +98,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 400, + expectedStatusCode: http.StatusBadRequest, expectedResponse: "Missing remediation component data", authType: APIKEY, }, @@ -115,7 +117,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedResponse: "", expectedMetricsCount: 1, expectedFeatureFlags: "a,b,c", @@ -136,7 +138,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 422, + expectedStatusCode: http.StatusUnprocessableEntity, expectedResponse: "log_processors.0.datasources in body is required", authType: PASSWORD, }, @@ -155,7 +157,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedOSName: "foo", expectedOSVersion: "42", @@ -177,7 +179,7 @@ func TestLPMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 422, + expectedStatusCode: http.StatusUnprocessableEntity, expectedResponse: "log_processors.0.os.name in body is required", authType: PASSWORD, }, @@ -185,20 +187,20 @@ func TestLPMetrics(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - lapi := SetupLAPITest(t) + lapi := SetupLAPITest(t, ctx) - dbClient, err := database.NewClient(context.Background(), lapi.DBConfig) + dbClient, err := database.NewClient(ctx, lapi.DBConfig) if err != nil { t.Fatalf("unable to create database client: %s", err) } - w := lapi.RecordResponse(t, http.MethodPost, "/v1/usage-metrics", strings.NewReader(tt.body), tt.authType) + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/usage-metrics", strings.NewReader(tt.body), tt.authType) assert.Equal(t, tt.expectedStatusCode, w.Code) assert.Contains(t, w.Body.String(), tt.expectedResponse) - machine, _ := dbClient.QueryMachineByID("test") - metrics, _ := dbClient.GetLPUsageMetricsByMachineID("test") + machine, _ := dbClient.QueryMachineByID(ctx, "test") + metrics, _ := dbClient.GetLPUsageMetricsByMachineID(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) assert.Equal(t, tt.expectedOSName, machine.Osname) @@ -214,6 +216,8 @@ func TestLPMetrics(t *testing.T) { } func TestRCMetrics(t *testing.T) { + ctx := context.Background() + tests := []struct { name string body string @@ -229,7 +233,7 @@ func TestRCMetrics(t *testing.T) { name: "empty metrics for RC", body: `{ }`, - expectedStatusCode: 400, + expectedStatusCode: http.StatusBadRequest, expectedResponse: "Missing remediation component data", authType: APIKEY, }, @@ -247,7 +251,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedResponse: "", expectedOSName: "foo", @@ -269,7 +273,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedResponse: "", expectedOSName: "foo", @@ -291,7 +295,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 400, + expectedStatusCode: http.StatusBadRequest, expectedResponse: "Missing log processor data", authType: PASSWORD, }, @@ -308,7 +312,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedResponse: "", expectedMetricsCount: 1, expectedFeatureFlags: "a,b,c", @@ -327,7 +331,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 201, + expectedStatusCode: http.StatusCreated, expectedMetricsCount: 1, expectedOSName: "foo", expectedOSVersion: "42", @@ -347,7 +351,7 @@ func TestRCMetrics(t *testing.T) { } ] }`, - expectedStatusCode: 422, + expectedStatusCode: http.StatusUnprocessableEntity, expectedResponse: "remediation_components.0.os.name in body is required", authType: APIKEY, }, @@ -355,20 +359,20 @@ func TestRCMetrics(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - lapi := SetupLAPITest(t) + lapi := SetupLAPITest(t, ctx) - dbClient, err := database.NewClient(context.Background(), lapi.DBConfig) + dbClient, err := database.NewClient(ctx, lapi.DBConfig) if err != nil { t.Fatalf("unable to create database client: %s", err) } - w := lapi.RecordResponse(t, http.MethodPost, "/v1/usage-metrics", strings.NewReader(tt.body), tt.authType) + w := lapi.RecordResponse(t, ctx, http.MethodPost, "/v1/usage-metrics", strings.NewReader(tt.body), tt.authType) assert.Equal(t, tt.expectedStatusCode, w.Code) assert.Contains(t, w.Body.String(), tt.expectedResponse) - bouncer, _ := dbClient.SelectBouncerByName("test") - metrics, _ := dbClient.GetBouncerUsageMetricsByName("test") + bouncer, _ := dbClient.SelectBouncerByName(ctx, "test") + metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test") assert.Len(t, metrics, tt.expectedMetricsCount) assert.Equal(t, tt.expectedOSName, bouncer.Osname) diff --git a/pkg/appsec/appsec.go b/pkg/appsec/appsec.go index 96f977b4738..553db205b5d 100644 --- a/pkg/appsec/appsec.go +++ b/pkg/appsec/appsec.go @@ -1,7 +1,6 @@ package appsec import ( - "errors" "fmt" "net/http" "os" @@ -40,7 +39,6 @@ const ( ) func (h *Hook) Build(hookStage int) error { - ctx := map[string]interface{}{} switch hookStage { case hookOnLoad: @@ -54,7 +52,7 @@ func (h *Hook) Build(hookStage int) error { } opts := exprhelpers.GetExprOptions(ctx) if h.Filter != "" { - program, err := expr.Compile(h.Filter, opts...) //FIXME: opts + program, err := expr.Compile(h.Filter, opts...) // FIXME: opts if err != nil { return fmt.Errorf("unable to compile filter %s : %w", h.Filter, err) } @@ -73,11 +71,11 @@ func (h *Hook) Build(hookStage int) error { type AppsecTempResponse struct { InBandInterrupt bool OutOfBandInterrupt bool - Action string //allow, deny, captcha, log - UserHTTPResponseCode int //The response code to send to the user - BouncerHTTPResponseCode int //The response code to send to the remediation component - SendEvent bool //do we send an internal event on rule match - SendAlert bool //do we send an alert on rule match + Action string // allow, deny, captcha, log + UserHTTPResponseCode int // The response code to send to the user + BouncerHTTPResponseCode int // The response code to send to the remediation component + SendEvent bool // do we send an internal event on rule match + SendAlert bool // do we send an alert on rule match } type AppsecSubEngineOpts struct { @@ -93,7 +91,7 @@ type AppsecRuntimeConfig struct { InBandRules []AppsecCollection DefaultRemediation string - RemediationByTag map[string]string //Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME + RemediationByTag map[string]string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME RemediationById map[int]string CompiledOnLoad []Hook CompiledPreEval []Hook @@ -101,22 +99,22 @@ type AppsecRuntimeConfig struct { CompiledOnMatch []Hook CompiledVariablesTracking []*regexp.Regexp Config *AppsecConfig - //CorazaLogger debuglog.Logger + // CorazaLogger debuglog.Logger - //those are ephemeral, created/destroyed with every req - OutOfBandTx ExtendedTransaction //is it a good idea ? - InBandTx ExtendedTransaction //is it a good idea ? + // those are ephemeral, created/destroyed with every req + OutOfBandTx ExtendedTransaction // is it a good idea ? + InBandTx ExtendedTransaction // is it a good idea ? Response AppsecTempResponse - //should we store matched rules here ? + // should we store matched rules here ? Logger *log.Entry - //Set by on_load to ignore some rules on loading + // Set by on_load to ignore some rules on loading DisabledInBandRuleIds []int - DisabledInBandRulesTags []string //Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME + DisabledInBandRulesTags []string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME DisabledOutOfBandRuleIds []int - DisabledOutOfBandRulesTags []string //Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME + DisabledOutOfBandRulesTags []string // Also used for ByName, as the name (for modsec rules) is a tag crowdsec-NAME } type AppsecConfig struct { @@ -125,10 +123,10 @@ type AppsecConfig struct { InBandRules []string `yaml:"inband_rules"` DefaultRemediation string `yaml:"default_remediation"` DefaultPassAction string `yaml:"default_pass_action"` - BouncerBlockedHTTPCode int `yaml:"blocked_http_code"` //returned to the bouncer - BouncerPassedHTTPCode int `yaml:"passed_http_code"` //returned to the bouncer - UserBlockedHTTPCode int `yaml:"user_blocked_http_code"` //returned to the user - UserPassedHTTPCode int `yaml:"user_passed_http_code"` //returned to the user + BouncerBlockedHTTPCode int `yaml:"blocked_http_code"` // returned to the bouncer + BouncerPassedHTTPCode int `yaml:"passed_http_code"` // returned to the bouncer + UserBlockedHTTPCode int `yaml:"user_blocked_http_code"` // returned to the user + UserPassedHTTPCode int `yaml:"user_passed_http_code"` // returned to the user OnLoad []Hook `yaml:"on_load"` PreEval []Hook `yaml:"pre_eval"` @@ -151,45 +149,95 @@ func (w *AppsecRuntimeConfig) ClearResponse() { w.Response.SendAlert = true } -func (wc *AppsecConfig) LoadByPath(file string) error { +func (wc *AppsecConfig) SetUpLogger() { + if wc.LogLevel == nil { + lvl := wc.Logger.Logger.GetLevel() + wc.LogLevel = &lvl + } + /* wc.Name is actually the datasource name.*/ + wc.Logger = wc.Logger.Dup().WithField("name", wc.Name) + wc.Logger.Logger.SetLevel(*wc.LogLevel) + +} + +func (wc *AppsecConfig) LoadByPath(file string) error { wc.Logger.Debugf("loading config %s", file) yamlFile, err := os.ReadFile(file) if err != nil { return fmt.Errorf("unable to read file %s : %s", file, err) } - err = yaml.UnmarshalStrict(yamlFile, wc) + + //as LoadByPath can be called several time, we append rules/hooks, but override other options + var tmp AppsecConfig + + err = yaml.UnmarshalStrict(yamlFile, &tmp) if err != nil { return fmt.Errorf("unable to parse yaml file %s : %s", file, err) } - if wc.Name == "" { - return errors.New("name cannot be empty") + if wc.Name == "" && tmp.Name != "" { + wc.Name = tmp.Name } - if wc.LogLevel == nil { - lvl := wc.Logger.Logger.GetLevel() - wc.LogLevel = &lvl + + //We can append rules/hooks + if tmp.OutOfBandRules != nil { + wc.OutOfBandRules = append(wc.OutOfBandRules, tmp.OutOfBandRules...) } - wc.Logger = wc.Logger.Dup().WithField("name", wc.Name) - wc.Logger.Logger.SetLevel(*wc.LogLevel) + if tmp.InBandRules != nil { + wc.InBandRules = append(wc.InBandRules, tmp.InBandRules...) + } + if tmp.OnLoad != nil { + wc.OnLoad = append(wc.OnLoad, tmp.OnLoad...) + } + if tmp.PreEval != nil { + wc.PreEval = append(wc.PreEval, tmp.PreEval...) + } + if tmp.PostEval != nil { + wc.PostEval = append(wc.PostEval, tmp.PostEval...) + } + if tmp.OnMatch != nil { + wc.OnMatch = append(wc.OnMatch, tmp.OnMatch...) + } + if tmp.VariablesTracking != nil { + wc.VariablesTracking = append(wc.VariablesTracking, tmp.VariablesTracking...) + } + + //override other options + wc.LogLevel = tmp.LogLevel + + wc.DefaultRemediation = tmp.DefaultRemediation + wc.DefaultPassAction = tmp.DefaultPassAction + wc.BouncerBlockedHTTPCode = tmp.BouncerBlockedHTTPCode + wc.BouncerPassedHTTPCode = tmp.BouncerPassedHTTPCode + wc.UserBlockedHTTPCode = tmp.UserBlockedHTTPCode + wc.UserPassedHTTPCode = tmp.UserPassedHTTPCode + + if tmp.InbandOptions.DisableBodyInspection { + wc.InbandOptions.DisableBodyInspection = true + } + if tmp.InbandOptions.RequestBodyInMemoryLimit != nil { + wc.InbandOptions.RequestBodyInMemoryLimit = tmp.InbandOptions.RequestBodyInMemoryLimit + } + if tmp.OutOfBandOptions.DisableBodyInspection { + wc.OutOfBandOptions.DisableBodyInspection = true + } + if tmp.OutOfBandOptions.RequestBodyInMemoryLimit != nil { + wc.OutOfBandOptions.RequestBodyInMemoryLimit = tmp.OutOfBandOptions.RequestBodyInMemoryLimit + } + return nil } func (wc *AppsecConfig) Load(configName string) error { - appsecConfigs := hub.GetItemMap(cwhub.APPSEC_CONFIGS) + item := hub.GetItem(cwhub.APPSEC_CONFIGS, configName) - for _, hubAppsecConfigItem := range appsecConfigs { - if !hubAppsecConfigItem.State.Installed { - continue - } - if hubAppsecConfigItem.Name != configName { - continue - } - wc.Logger.Infof("loading %s", hubAppsecConfigItem.State.LocalPath) - err := wc.LoadByPath(hubAppsecConfigItem.State.LocalPath) + if item != nil && item.State.Installed { + wc.Logger.Infof("loading %s", item.State.LocalPath) + err := wc.LoadByPath(item.State.LocalPath) if err != nil { - return fmt.Errorf("unable to load appsec-config %s : %s", hubAppsecConfigItem.State.LocalPath, err) + return fmt.Errorf("unable to load appsec-config %s : %s", item.State.LocalPath, err) } return nil } @@ -224,10 +272,10 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { wc.DefaultRemediation = BanRemediation } - //set the defaults + // set the defaults switch wc.DefaultRemediation { case BanRemediation, CaptchaRemediation, AllowRemediation: - //those are the officially supported remediation(s) + // those are the officially supported remediation(s) default: wc.Logger.Warningf("default '%s' remediation of %s is none of [%s,%s,%s] ensure bouncer compatbility!", wc.DefaultRemediation, wc.Name, BanRemediation, CaptchaRemediation, AllowRemediation) } @@ -237,7 +285,7 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { ret.DefaultRemediation = wc.DefaultRemediation wc.Logger.Tracef("Loading config %+v", wc) - //load rules + // load rules for _, rule := range wc.OutOfBandRules { wc.Logger.Infof("loading outofband rule %s", rule) collections, err := LoadCollection(rule, wc.Logger.WithField("component", "appsec_collection_loader")) @@ -259,7 +307,7 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { wc.Logger.Infof("Loaded %d inband rules", len(ret.InBandRules)) - //load hooks + // load hooks for _, hook := range wc.OnLoad { if hook.OnSuccess != "" && hook.OnSuccess != "continue" && hook.OnSuccess != "break" { return nil, fmt.Errorf("invalid 'on_success' for on_load hook : %s", hook.OnSuccess) @@ -304,7 +352,7 @@ func (wc *AppsecConfig) Build() (*AppsecRuntimeConfig, error) { ret.CompiledOnMatch = append(ret.CompiledOnMatch, hook) } - //variable tracking + // variable tracking for _, variable := range wc.VariablesTracking { compiledVariableRule, err := regexp.Compile(variable) if err != nil { @@ -460,7 +508,6 @@ func (w *AppsecRuntimeConfig) ProcessPostEvalRules(request *ParsedRequest) error // here means there is no filter or the filter matched for _, applyExpr := range rule.ApplyExpr { o, err := exprhelpers.Run(applyExpr, GetPostEvalEnv(w, request), w.Logger, w.Logger.Level >= log.DebugLevel) - if err != nil { w.Logger.Errorf("unable to apply appsec post_eval expr: %s", err) continue @@ -604,7 +651,7 @@ func (w *AppsecRuntimeConfig) SetActionByName(name string, action string) error } func (w *AppsecRuntimeConfig) SetAction(action string) error { - //log.Infof("setting to %s", action) + // log.Infof("setting to %s", action) w.Logger.Debugf("setting action to %s", action) w.Response.Action = action return nil @@ -628,7 +675,7 @@ func (w *AppsecRuntimeConfig) GenerateResponse(response AppsecTempResponse, logg if response.Action == AllowRemediation { resp.HTTPStatus = w.Config.UserPassedHTTPCode bouncerStatusCode = w.Config.BouncerPassedHTTPCode - } else { //ban, captcha and anything else + } else { // ban, captcha and anything else resp.HTTPStatus = response.UserHTTPResponseCode if resp.HTTPStatus == 0 { resp.HTTPStatus = w.Config.UserBlockedHTTPCode diff --git a/pkg/appsec/appsec_rules_collection.go b/pkg/appsec/appsec_rules_collection.go index 09c1670de70..d283f95cb19 100644 --- a/pkg/appsec/appsec_rules_collection.go +++ b/pkg/appsec/appsec_rules_collection.go @@ -29,11 +29,11 @@ type AppsecCollectionConfig struct { SecLangRules []string `yaml:"seclang_rules"` Rules []appsec_rule.CustomRule `yaml:"rules"` - Labels map[string]interface{} `yaml:"labels"` //Labels is K:V list aiming at providing context the overflow + Labels map[string]interface{} `yaml:"labels"` // Labels is K:V list aiming at providing context the overflow - Data interface{} `yaml:"data"` //Ignore it - hash string `yaml:"-"` - version string `yaml:"-"` + Data interface{} `yaml:"data"` // Ignore it + hash string + version string } type RulesDetails struct { @@ -108,7 +108,7 @@ func LoadCollection(pattern string, logger *log.Entry) ([]AppsecCollection, erro logger.Debugf("Adding rule %s", strRule) appsecCol.Rules = append(appsecCol.Rules, strRule) - //We only take the first id, as it's the one of the "main" rule + // We only take the first id, as it's the one of the "main" rule if _, ok := AppsecRulesDetails[int(rulesId[0])]; !ok { AppsecRulesDetails[int(rulesId[0])] = RulesDetails{ LogLevel: log.InfoLevel, diff --git a/pkg/appsec/loader.go b/pkg/appsec/loader.go index 56ec23e3671..c724010cec2 100644 --- a/pkg/appsec/loader.go +++ b/pkg/appsec/loader.go @@ -9,19 +9,15 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) -var appsecRules = make(map[string]AppsecCollectionConfig) //FIXME: would probably be better to have a struct for this +var appsecRules = make(map[string]AppsecCollectionConfig) // FIXME: would probably be better to have a struct for this -var hub *cwhub.Hub //FIXME: this is a temporary hack to make the hub available in the package +var hub *cwhub.Hub // FIXME: this is a temporary hack to make the hub available in the package func LoadAppsecRules(hubInstance *cwhub.Hub) error { hub = hubInstance appsecRules = make(map[string]AppsecCollectionConfig) - for _, hubAppsecRuleItem := range hub.GetItemMap(cwhub.APPSEC_RULES) { - if !hubAppsecRuleItem.State.Installed { - continue - } - + for _, hubAppsecRuleItem := range hub.GetInstalledByType(cwhub.APPSEC_RULES, false) { content, err := os.ReadFile(hubAppsecRuleItem.State.LocalPath) if err != nil { log.Warnf("unable to read file %s : %s", hubAppsecRuleItem.State.LocalPath, err) @@ -32,7 +28,7 @@ func LoadAppsecRules(hubInstance *cwhub.Hub) error { err = yaml.UnmarshalStrict(content, &rule) if err != nil { - log.Warnf("unable to unmarshal file %s : %s", hubAppsecRuleItem.State.LocalPath, err) + log.Warnf("unable to parse file %s : %s", hubAppsecRuleItem.State.LocalPath, err) continue } diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go index 4a28b590e80..5f2f8f9248b 100644 --- a/pkg/csconfig/api.go +++ b/pkg/csconfig/api.go @@ -38,10 +38,17 @@ type ApiCredentialsCfg struct { CertPath string `yaml:"cert_path,omitempty"` } -/*global api config (for lapi->oapi)*/ +type CapiPullConfig struct { + Community *bool `yaml:"community,omitempty"` + Blocklists *bool `yaml:"blocklists,omitempty"` +} + +/*global api config (for lapi->capi)*/ type OnlineApiClientCfg struct { CredentialsFilePath string `yaml:"credentials_path,omitempty"` // credz will be edited by software, store in diff file Credentials *ApiCredentialsCfg `yaml:"-"` + PullConfig CapiPullConfig `yaml:"pull,omitempty"` + Sharing *bool `yaml:"sharing,omitempty"` } /*local api config (for crowdsec/cscli->lapi)*/ @@ -99,7 +106,7 @@ func (o *OnlineApiClientCfg) Load() error { err = dec.Decode(o.Credentials) if err != nil { if !errors.Is(err, io.EOF) { - return fmt.Errorf("failed unmarshaling api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) + return fmt.Errorf("failed to parse api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) } } @@ -134,7 +141,7 @@ func (l *LocalApiClientCfg) Load() error { err = dec.Decode(&l.Credentials) if err != nil { if !errors.Is(err, io.EOF) { - return fmt.Errorf("failed unmarshaling api client credential configuration file '%s': %w", l.CredentialsFilePath, err) + return fmt.Errorf("failed to parse api client credential configuration file '%s': %w", l.CredentialsFilePath, err) } } @@ -344,6 +351,21 @@ func (c *Config) LoadAPIServer(inCli bool) error { log.Printf("push and pull to Central API disabled") } + //Set default values for CAPI push/pull + if c.API.Server.OnlineClient != nil { + if c.API.Server.OnlineClient.PullConfig.Community == nil { + c.API.Server.OnlineClient.PullConfig.Community = ptr.Of(true) + } + + if c.API.Server.OnlineClient.PullConfig.Blocklists == nil { + c.API.Server.OnlineClient.PullConfig.Blocklists = ptr.Of(true) + } + + if c.API.Server.OnlineClient.Sharing == nil { + c.API.Server.OnlineClient.Sharing = ptr.Of(true) + } + } + if err := c.LoadDBConfig(inCli); err != nil { return err } diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go index 96945202aa8..17802ba31dd 100644 --- a/pkg/csconfig/api_test.go +++ b/pkg/csconfig/api_test.go @@ -101,7 +101,7 @@ func TestLoadOnlineApiClientCfg(t *testing.T) { CredentialsFilePath: "./testdata/bad_lapi-secrets.yaml", }, expected: &ApiCredentialsCfg{}, - expectedErr: "failed unmarshaling api server credentials", + expectedErr: "failed to parse api server credentials", }, { name: "missing field configuration", @@ -212,6 +212,11 @@ func TestLoadAPIServer(t *testing.T) { Login: "test", Password: "testpassword", }, + Sharing: ptr.Of(true), + PullConfig: CapiPullConfig{ + Community: ptr.Of(true), + Blocklists: ptr.Of(true), + }, }, Profiles: tmpLAPI.Profiles, ProfilesPath: "./testdata/profiles.yaml", diff --git a/pkg/csconfig/config_paths.go b/pkg/csconfig/config_paths.go index 7675b90d7dd..a8d39a664f3 100644 --- a/pkg/csconfig/config_paths.go +++ b/pkg/csconfig/config_paths.go @@ -10,7 +10,7 @@ type ConfigurationPaths struct { ConfigDir string `yaml:"config_dir"` DataDir string `yaml:"data_dir,omitempty"` SimulationFilePath string `yaml:"simulation_path,omitempty"` - HubIndexFile string `yaml:"index_path,omitempty"` //path of the .index.json + HubIndexFile string `yaml:"index_path,omitempty"` // path of the .index.json HubDir string `yaml:"hub_dir,omitempty"` PluginDir string `yaml:"plugin_dir,omitempty"` NotificationDir string `yaml:"notification_dir,omitempty"` @@ -28,18 +28,18 @@ func (c *Config) loadConfigurationPaths() error { } if c.ConfigPaths.HubDir == "" { - c.ConfigPaths.HubDir = filepath.Clean(c.ConfigPaths.ConfigDir + "/hub") + c.ConfigPaths.HubDir = filepath.Join(c.ConfigPaths.ConfigDir, "hub") } if c.ConfigPaths.HubIndexFile == "" { - c.ConfigPaths.HubIndexFile = filepath.Clean(c.ConfigPaths.HubDir + "/.index.json") + c.ConfigPaths.HubIndexFile = filepath.Join(c.ConfigPaths.HubDir, ".index.json") } if c.ConfigPaths.PatternDir == "" { - c.ConfigPaths.PatternDir = filepath.Join(c.ConfigPaths.ConfigDir, "patterns/") + c.ConfigPaths.PatternDir = filepath.Join(c.ConfigPaths.ConfigDir, "patterns") } - var configPathsCleanup = []*string{ + configPathsCleanup := []*string{ &c.ConfigPaths.HubDir, &c.ConfigPaths.HubIndexFile, &c.ConfigPaths.ConfigDir, diff --git a/pkg/csconfig/config_test.go b/pkg/csconfig/config_test.go index 11f1f0cf68d..b69954de178 100644 --- a/pkg/csconfig/config_test.go +++ b/pkg/csconfig/config_test.go @@ -42,5 +42,5 @@ func TestNewCrowdSecConfig(t *testing.T) { func TestDefaultConfig(t *testing.T) { x := NewDefaultConfig() _, err := yaml.Marshal(x) - require.NoError(t, err, "failed marshaling config: %s", err) + require.NoError(t, err, "failed to serialize config: %s", err) } diff --git a/pkg/csconfig/console.go b/pkg/csconfig/console.go index 4c14f5f7d49..21ecbf3d736 100644 --- a/pkg/csconfig/console.go +++ b/pkg/csconfig/console.go @@ -95,7 +95,7 @@ func (c *LocalApiServerCfg) LoadConsoleConfig() error { err = yaml.Unmarshal(yamlFile, c.ConsoleConfig) if err != nil { - return fmt.Errorf("unmarshaling console config file '%s': %w", c.ConsoleConfigPath, err) + return fmt.Errorf("parsing console config file '%s': %w", c.ConsoleConfigPath, err) } if c.ConsoleConfig.ShareCustomScenarios == nil { diff --git a/pkg/csconfig/crowdsec_service.go b/pkg/csconfig/crowdsec_service.go index 7820595b46f..cf796805dee 100644 --- a/pkg/csconfig/crowdsec_service.go +++ b/pkg/csconfig/crowdsec_service.go @@ -143,14 +143,14 @@ func (c *CrowdsecServiceCfg) DumpContextConfigFile() error { // XXX: MakeDirs out, err := yaml.Marshal(c.ContextToSend) if err != nil { - return fmt.Errorf("while marshaling ConsoleConfig (for %s): %w", c.ConsoleContextPath, err) + return fmt.Errorf("while serializing ConsoleConfig (for %s): %w", c.ConsoleContextPath, err) } - if err = os.MkdirAll(filepath.Dir(c.ConsoleContextPath), 0700); err != nil { + if err = os.MkdirAll(filepath.Dir(c.ConsoleContextPath), 0o700); err != nil { return fmt.Errorf("while creating directories for %s: %w", c.ConsoleContextPath, err) } - if err := os.WriteFile(c.ConsoleContextPath, out, 0600); err != nil { + if err := os.WriteFile(c.ConsoleContextPath, out, 0o600); err != nil { return fmt.Errorf("while dumping console config to %s: %w", c.ConsoleContextPath, err) } diff --git a/pkg/csconfig/simulation.go b/pkg/csconfig/simulation.go index 947b47e3c1e..c9041df464a 100644 --- a/pkg/csconfig/simulation.go +++ b/pkg/csconfig/simulation.go @@ -37,7 +37,7 @@ func (c *Config) LoadSimulation() error { simCfg := SimulationConfig{} if c.ConfigPaths.SimulationFilePath == "" { - c.ConfigPaths.SimulationFilePath = filepath.Clean(c.ConfigPaths.ConfigDir + "/simulation.yaml") + c.ConfigPaths.SimulationFilePath = filepath.Join(c.ConfigPaths.ConfigDir, "simulation.yaml") } patcher := yamlpatch.NewPatcher(c.ConfigPaths.SimulationFilePath, ".local") @@ -52,7 +52,7 @@ func (c *Config) LoadSimulation() error { if err := dec.Decode(&simCfg); err != nil { if !errors.Is(err, io.EOF) { - return fmt.Errorf("while unmarshaling simulation file '%s': %w", c.ConfigPaths.SimulationFilePath, err) + return fmt.Errorf("while parsing simulation file '%s': %w", c.ConfigPaths.SimulationFilePath, err) } } diff --git a/pkg/csconfig/simulation_test.go b/pkg/csconfig/simulation_test.go index a678d7edd49..a1e5f0a5b02 100644 --- a/pkg/csconfig/simulation_test.go +++ b/pkg/csconfig/simulation_test.go @@ -60,7 +60,7 @@ func TestSimulationLoading(t *testing.T) { }, Crowdsec: &CrowdsecServiceCfg{}, }, - expectedErr: "while unmarshaling simulation file './testdata/config.yaml': yaml: unmarshal errors", + expectedErr: "while parsing simulation file './testdata/config.yaml': yaml: unmarshal errors", }, { name: "basic bad file content", @@ -71,7 +71,7 @@ func TestSimulationLoading(t *testing.T) { }, Crowdsec: &CrowdsecServiceCfg{}, }, - expectedErr: "while unmarshaling simulation file './testdata/config.yaml': yaml: unmarshal errors", + expectedErr: "while parsing simulation file './testdata/config.yaml': yaml: unmarshal errors", }, } diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index f6629b2609e..e996fa9b68c 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -45,7 +45,7 @@ type PluginBroker struct { pluginConfigByName map[string]PluginConfig pluginMap map[string]plugin.Plugin notificationConfigsByPluginType map[string][][]byte // "slack" -> []{config1, config2} - notificationPluginByName map[string]Notifier + notificationPluginByName map[string]protobufs.NotifierServer watcher PluginWatcher pluginKillMethods []func() pluginProcConfig *csconfig.PluginCfg @@ -72,10 +72,10 @@ type ProfileAlert struct { Alert *models.Alert } -func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error { +func (pb *PluginBroker) Init(ctx context.Context, pluginCfg *csconfig.PluginCfg, profileConfigs []*csconfig.ProfileCfg, configPaths *csconfig.ConfigurationPaths) error { pb.PluginChannel = make(chan ProfileAlert) pb.notificationConfigsByPluginType = make(map[string][][]byte) - pb.notificationPluginByName = make(map[string]Notifier) + pb.notificationPluginByName = make(map[string]protobufs.NotifierServer) pb.pluginMap = make(map[string]plugin.Plugin) pb.pluginConfigByName = make(map[string]PluginConfig) pb.alertsByPluginName = make(map[string][]*models.Alert) @@ -85,7 +85,7 @@ func (pb *PluginBroker) Init(pluginCfg *csconfig.PluginCfg, profileConfigs []*cs if err := pb.loadConfig(configPaths.NotificationDir); err != nil { return fmt.Errorf("while loading plugin config: %w", err) } - if err := pb.loadPlugins(configPaths.PluginDir); err != nil { + if err := pb.loadPlugins(ctx, configPaths.PluginDir); err != nil { return fmt.Errorf("while loading plugin: %w", err) } pb.watcher = PluginWatcher{} @@ -230,7 +230,7 @@ func (pb *PluginBroker) verifyPluginBinaryWithProfile() error { return nil } -func (pb *PluginBroker) loadPlugins(path string) error { +func (pb *PluginBroker) loadPlugins(ctx context.Context, path string) error { binaryPaths, err := listFilesAtPath(path) if err != nil { return err @@ -265,7 +265,7 @@ func (pb *PluginBroker) loadPlugins(path string) error { return err } data = []byte(csstring.StrictExpand(string(data), os.LookupEnv)) - _, err = pluginClient.Configure(context.Background(), &protobufs.Config{Config: data}) + _, err = pluginClient.Configure(ctx, &protobufs.Config{Config: data}) if err != nil { return fmt.Errorf("while configuring %s: %w", pc.Name, err) } @@ -276,7 +276,7 @@ func (pb *PluginBroker) loadPlugins(path string) error { return pb.verifyPluginBinaryWithProfile() } -func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) (Notifier, error) { +func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) (protobufs.NotifierServer, error) { handshake, err := getHandshake() if err != nil { @@ -313,7 +313,7 @@ func (pb *PluginBroker) loadNotificationPlugin(name string, binaryPath string) ( return nil, err } pb.pluginKillMethods = append(pb.pluginKillMethods, c.Kill) - return raw.(Notifier), nil + return raw.(protobufs.NotifierServer), nil } func (pb *PluginBroker) pushNotificationsToPlugin(pluginName string, alerts []*models.Alert) error { diff --git a/pkg/csplugin/broker_suite_test.go b/pkg/csplugin/broker_suite_test.go index 778bb2dfe2e..1210c67058a 100644 --- a/pkg/csplugin/broker_suite_test.go +++ b/pkg/csplugin/broker_suite_test.go @@ -1,6 +1,7 @@ package csplugin import ( + "context" "io" "os" "os/exec" @@ -96,6 +97,7 @@ func (s *PluginSuite) TearDownTest() { func (s *PluginSuite) SetupSubTest() { var err error + t := s.T() s.runDir, err = os.MkdirTemp("", "cs_plugin_test") @@ -127,6 +129,7 @@ func (s *PluginSuite) SetupSubTest() { func (s *PluginSuite) TearDownSubTest() { t := s.T() + if s.pluginBroker != nil { s.pluginBroker.Kill() s.pluginBroker = nil @@ -140,19 +143,24 @@ func (s *PluginSuite) TearDownSubTest() { os.Remove("./out") } -func (s *PluginSuite) InitBroker(procCfg *csconfig.PluginCfg) (*PluginBroker, error) { +func (s *PluginSuite) InitBroker(ctx context.Context, procCfg *csconfig.PluginCfg) (*PluginBroker, error) { pb := PluginBroker{} + if procCfg == nil { procCfg = &csconfig.PluginCfg{} } + profiles := csconfig.NewDefaultConfig().API.Server.Profiles profiles = append(profiles, &csconfig.ProfileCfg{ Notifications: []string{"dummy_default"}, }) - err := pb.Init(procCfg, profiles, &csconfig.ConfigurationPaths{ + + err := pb.Init(ctx, procCfg, profiles, &csconfig.ConfigurationPaths{ PluginDir: s.pluginDir, NotificationDir: s.notifDir, }) + s.pluginBroker = &pb + return s.pluginBroker, err } diff --git a/pkg/csplugin/broker_test.go b/pkg/csplugin/broker_test.go index f2179acb2c1..ae5a615b489 100644 --- a/pkg/csplugin/broker_test.go +++ b/pkg/csplugin/broker_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -38,7 +39,7 @@ func (s *PluginSuite) readconfig() PluginConfig { require.NoError(t, err, "unable to read config file %s", s.pluginConfig) err = yaml.Unmarshal(orig, &config) - require.NoError(t, err, "unable to unmarshal config file") + require.NoError(t, err, "unable to parse config file") return config } @@ -46,13 +47,14 @@ func (s *PluginSuite) readconfig() PluginConfig { func (s *PluginSuite) writeconfig(config PluginConfig) { t := s.T() data, err := yaml.Marshal(&config) - require.NoError(t, err, "unable to marshal config file") + require.NoError(t, err, "unable to serialize config file") err = os.WriteFile(s.pluginConfig, data, 0o644) require.NoError(t, err, "unable to write config file %s", s.pluginConfig) } func (s *PluginSuite) TestBrokerInit() { + ctx := context.Background() tests := []struct { name string action func(*testing.T) @@ -135,20 +137,22 @@ func (s *PluginSuite) TestBrokerInit() { tc.action(t) } - _, err := s.InitBroker(&tc.procCfg) + _, err := s.InitBroker(ctx, &tc.procCfg) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func (s *PluginSuite) TestBrokerNoThreshold() { + ctx := context.Background() + var alerts []models.Alert DefaultEmptyTicker = 50 * time.Millisecond t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -187,6 +191,8 @@ func (s *PluginSuite) TestBrokerNoThreshold() { } func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { + ctx := context.Background() + // test grouping by "time" DefaultEmptyTicker = 50 * time.Millisecond @@ -198,7 +204,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { cfg.GroupWait = 1 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -224,6 +230,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { } func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() @@ -234,7 +241,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { cfg.GroupWait = 4 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -264,6 +271,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { } func (s *PluginSuite) TestBrokerRunGroupThreshold() { + ctx := context.Background() // test grouping by "size" DefaultEmptyTicker = 50 * time.Millisecond @@ -274,7 +282,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { cfg.GroupThreshold = 4 s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -318,6 +326,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { } func (s *PluginSuite) TestBrokerRunTimeThreshold() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() @@ -327,7 +336,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { cfg.GroupWait = 1 * time.Second s.writeconfig(cfg) - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} @@ -353,11 +362,12 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { } func (s *PluginSuite) TestBrokerRunSimple() { + ctx := context.Background() DefaultEmptyTicker = 50 * time.Millisecond t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} diff --git a/pkg/csplugin/broker_win_test.go b/pkg/csplugin/broker_win_test.go index 97a3ad33deb..570f23e5015 100644 --- a/pkg/csplugin/broker_win_test.go +++ b/pkg/csplugin/broker_win_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -26,6 +27,7 @@ not if it will actually reject plugins with invalid permissions */ func (s *PluginSuite) TestBrokerInit() { + ctx := context.Background() tests := []struct { name string action func(*testing.T) @@ -54,22 +56,22 @@ func (s *PluginSuite) TestBrokerInit() { } for _, tc := range tests { - tc := tc s.Run(tc.name, func() { t := s.T() if tc.action != nil { tc.action(t) } - _, err := s.InitBroker(&tc.procCfg) + _, err := s.InitBroker(ctx, &tc.procCfg) cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } func (s *PluginSuite) TestBrokerRun() { + ctx := context.Background() t := s.T() - pb, err := s.InitBroker(nil) + pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) tomb := tomb.Tomb{} diff --git a/pkg/csplugin/listfiles_test.go b/pkg/csplugin/listfiles_test.go index a4188804149..c476d7a4e4a 100644 --- a/pkg/csplugin/listfiles_test.go +++ b/pkg/csplugin/listfiles_test.go @@ -21,7 +21,7 @@ func TestListFilesAtPath(t *testing.T) { require.NoError(t, err) _, err = os.Create(filepath.Join(dir, "slack")) require.NoError(t, err) - err = os.Mkdir(filepath.Join(dir, "somedir"), 0755) + err = os.Mkdir(filepath.Join(dir, "somedir"), 0o755) require.NoError(t, err) _, err = os.Create(filepath.Join(dir, "somedir", "inner")) require.NoError(t, err) diff --git a/pkg/csplugin/notifier.go b/pkg/csplugin/notifier.go index 2b5d57fbcff..615322ac0c3 100644 --- a/pkg/csplugin/notifier.go +++ b/pkg/csplugin/notifier.go @@ -10,17 +10,15 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/protobufs" ) -type Notifier interface { - Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) - Configure(ctx context.Context, cfg *protobufs.Config) (*protobufs.Empty, error) -} - type NotifierPlugin struct { plugin.Plugin - Impl Notifier + Impl protobufs.NotifierServer } -type GRPCClient struct{ client protobufs.NotifierClient } +type GRPCClient struct{ + protobufs.UnimplementedNotifierServer + client protobufs.NotifierClient +} func (m *GRPCClient) Notify(ctx context.Context, notification *protobufs.Notification) (*protobufs.Empty, error) { done := make(chan error) @@ -40,14 +38,12 @@ func (m *GRPCClient) Notify(ctx context.Context, notification *protobufs.Notific } func (m *GRPCClient) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) { - _, err := m.client.Configure( - context.Background(), config, - ) + _, err := m.client.Configure(ctx, config) return &protobufs.Empty{}, err } type GRPCServer struct { - Impl Notifier + Impl protobufs.NotifierServer } func (p *NotifierPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { diff --git a/pkg/csplugin/utils.go b/pkg/csplugin/utils.go index 2e7f0c80528..571d78add56 100644 --- a/pkg/csplugin/utils.go +++ b/pkg/csplugin/utils.go @@ -123,10 +123,10 @@ func pluginIsValid(path string) error { mode := details.Mode() perm := uint32(mode) - if (perm & 00002) != 0 { + if (perm & 0o0002) != 0 { return fmt.Errorf("plugin at %s is world writable, world writable plugins are invalid", path) } - if (perm & 00020) != 0 { + if (perm & 0o0020) != 0 { return fmt.Errorf("plugin at %s is group writable, group writable plugins are invalid", path) } if (mode & os.ModeSetgid) != 0 { diff --git a/pkg/csplugin/utils_windows.go b/pkg/csplugin/utils_windows.go index 8d4956ceeeb..91002079398 100644 --- a/pkg/csplugin/utils_windows.go +++ b/pkg/csplugin/utils_windows.go @@ -116,7 +116,7 @@ func CheckPerms(path string) error { */ aceCount := rs.Field(3).Uint() - for i := uint64(0); i < aceCount; i++ { + for i := range aceCount { ace := &AccessAllowedAce{} ret, _, _ := procGetAce.Call(uintptr(unsafe.Pointer(dacl)), uintptr(i), uintptr(unsafe.Pointer(&ace))) if ret == 0 { diff --git a/pkg/csplugin/utils_windows_test.go b/pkg/csplugin/utils_windows_test.go index 6a76e1215e5..1eb4dfb9033 100644 --- a/pkg/csplugin/utils_windows_test.go +++ b/pkg/csplugin/utils_windows_test.go @@ -37,7 +37,6 @@ func TestGetPluginNameAndTypeFromPath(t *testing.T) { }, } for _, tc := range tests { - tc := tc t.Run(tc.name, func(t *testing.T) { got, got1, err := getPluginTypeAndSubtypeFromPath(tc.path) cstest.RequireErrorContains(t, err, tc.expectedErr) diff --git a/pkg/csplugin/watcher_test.go b/pkg/csplugin/watcher_test.go index b76c3c4eadd..84e63ec6493 100644 --- a/pkg/csplugin/watcher_test.go +++ b/pkg/csplugin/watcher_test.go @@ -15,11 +15,10 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" ) -var ctx = context.Background() - func resetTestTomb(testTomb *tomb.Tomb, pw *PluginWatcher) { testTomb.Kill(nil) <-pw.PluginEvents + if err := testTomb.Wait(); err != nil { log.Fatal(err) } @@ -46,13 +45,17 @@ func listenChannelWithTimeout(ctx context.Context, channel chan string) error { case <-ctx.Done(): return ctx.Err() } + return nil } func TestPluginWatcherInterval(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows because timing is not reliable") } + pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) testTomb := tomb.Tomb{} @@ -66,6 +69,7 @@ func TestPluginWatcherInterval(t *testing.T) { ct, cancel := context.WithTimeout(ctx, time.Microsecond) defer cancel() + err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") resetTestTomb(&testTomb, &pw) @@ -74,6 +78,7 @@ func TestPluginWatcherInterval(t *testing.T) { ct, cancel = context.WithTimeout(ctx, time.Millisecond*5) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) resetTestTomb(&testTomb, &pw) @@ -81,9 +86,12 @@ func TestPluginWatcherInterval(t *testing.T) { } func TestPluginAlertCountWatcher(t *testing.T) { + ctx := context.Background() + if runtime.GOOS == "windows" { t.Skip("Skipping test on windows because timing is not reliable") } + pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) configs := map[string]PluginConfig{ @@ -92,28 +100,34 @@ func TestPluginAlertCountWatcher(t *testing.T) { }, } testTomb := tomb.Tomb{} + pw.Init(configs, alertsByPluginName) pw.Start(&testTomb) // Channel won't contain any events since threshold is not crossed. ct, cancel := context.WithTimeout(ctx, time.Second) defer cancel() + err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") // Channel won't contain any events since threshold is not crossed. resetWatcherAlertCounter(&pw) insertNAlertsToPlugin(&pw, 4, "testPlugin") + ct, cancel = context.WithTimeout(ctx, time.Second) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") // Channel will contain an event since threshold is crossed. resetWatcherAlertCounter(&pw) insertNAlertsToPlugin(&pw, 5, "testPlugin") + ct, cancel = context.WithTimeout(ctx, time.Second) defer cancel() + err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) resetTestTomb(&testTomb, &pw) diff --git a/pkg/cticlient/client.go b/pkg/cticlient/client.go index b817121e222..90112d80abf 100644 --- a/pkg/cticlient/client.go +++ b/pkg/cticlient/client.go @@ -8,8 +8,9 @@ import ( "net/http" "strings" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" log "github.com/sirupsen/logrus" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) const ( @@ -46,7 +47,7 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map } req.Header.Set("X-Api-Key", c.apiKey) - req.Header.Set("User-Agent", cwversion.UserAgent()) + req.Header.Set("User-Agent", useragent.Default()) resp, err := c.httpClient.Do(req) if err != nil { diff --git a/pkg/cwhub/cwhub.go b/pkg/cwhub/cwhub.go index 0a9cc443ce0..683f1853b43 100644 --- a/pkg/cwhub/cwhub.go +++ b/pkg/cwhub/cwhub.go @@ -4,11 +4,10 @@ import ( "fmt" "net/http" "path/filepath" - "sort" "strings" "time" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) // hubTransport wraps a Transport to set a custom User-Agent. @@ -17,7 +16,7 @@ type hubTransport struct { } func (t *hubTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("User-Agent", cwversion.UserAgent()) + req.Header.Set("User-Agent", useragent.Default()) return t.RoundTripper.RoundTrip(req) } @@ -45,10 +44,3 @@ func safePath(dir, filePath string) (string, error) { return absFilePath, nil } - -// SortItemSlice sorts a slice of items by name, case insensitive. -func SortItemSlice(items []*Item) { - sort.Slice(items, func(i, j int) bool { - return strings.ToLower(items[i].Name) < strings.ToLower(items[j].Name) - }) -} diff --git a/pkg/cwhub/cwhub_test.go b/pkg/cwhub/cwhub_test.go index a4641483622..17e7a0dc723 100644 --- a/pkg/cwhub/cwhub_test.go +++ b/pkg/cwhub/cwhub_test.go @@ -146,7 +146,7 @@ func setResponseByPath() { "/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_parser.yaml": fileToStringX("./testdata/foobar_parser.yaml"), "/crowdsecurity/master/parsers/s01-parse/crowdsecurity/foobar_subparser.yaml": fileToStringX("./testdata/foobar_parser.yaml"), "/crowdsecurity/master/collections/crowdsecurity/test_collection.yaml": fileToStringX("./testdata/collection_v1.yaml"), - "/crowdsecurity/master/.index.json": fileToStringX("./testdata/index1.json"), + "/crowdsecurity/master/.index.json": fileToStringX("./testdata/index1.json"), "/crowdsecurity/master/scenarios/crowdsecurity/foobar_scenario.yaml": `filter: true name: crowdsecurity/foobar_scenario`, "/crowdsecurity/master/scenarios/crowdsecurity/barfoo_scenario.yaml": `filter: true diff --git a/pkg/cwhub/doc.go b/pkg/cwhub/doc.go index 89d8de0fa8b..f86b95c6454 100644 --- a/pkg/cwhub/doc.go +++ b/pkg/cwhub/doc.go @@ -74,7 +74,7 @@ // Now you can use the hub object to access the existing items: // // // list all the parsers -// for _, parser := range hub.GetItemMap(cwhub.PARSERS) { +// for _, parser := range hub.GetItemsByType(cwhub.PARSERS, false) { // fmt.Printf("parser: %s\n", parser.Name) // } // diff --git a/pkg/cwhub/errors.go b/pkg/cwhub/errors.go index f1e779b5476..b0be444fcba 100644 --- a/pkg/cwhub/errors.go +++ b/pkg/cwhub/errors.go @@ -5,10 +5,8 @@ import ( "fmt" ) -var ( - // ErrNilRemoteHub is returned when trying to download with a local-only configuration. - ErrNilRemoteHub = errors.New("remote hub configuration is not provided. Please report this issue to the developers") -) +// ErrNilRemoteHub is returned when trying to download with a local-only configuration. +var ErrNilRemoteHub = errors.New("remote hub configuration is not provided. Please report this issue to the developers") // IndexNotFoundError is returned when the remote hub index is not found. type IndexNotFoundError struct { diff --git a/pkg/cwhub/hub.go b/pkg/cwhub/hub.go index 1293d6fa235..f74a794a512 100644 --- a/pkg/cwhub/hub.go +++ b/pkg/cwhub/hub.go @@ -8,11 +8,12 @@ import ( "io" "os" "path" - "slices" "strings" "github.com/sirupsen/logrus" + "github.com/crowdsecurity/go-cs-lib/maptools" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) @@ -78,7 +79,7 @@ func (h *Hub) parseIndex() error { } if err := json.Unmarshal(bidx, &h.items); err != nil { - return fmt.Errorf("failed to unmarshal index: %w", err) + return fmt.Errorf("failed to parse index: %w", err) } h.logger.Debugf("%d item types in hub index", len(ItemTypes)) @@ -117,13 +118,14 @@ func (h *Hub) ItemStats() []string { tainted := 0 for _, itemType := range ItemTypes { - if len(h.GetItemMap(itemType)) == 0 { + items := h.GetItemsByType(itemType, false) + if len(items) == 0 { continue } - loaded += fmt.Sprintf("%d %s, ", len(h.GetItemMap(itemType)), itemType) + loaded += fmt.Sprintf("%d %s, ", len(items), itemType) - for _, item := range h.GetItemMap(itemType) { + for _, item := range items { if item.State.IsLocal() { local++ } @@ -153,7 +155,7 @@ func (h *Hub) ItemStats() []string { // Update downloads the latest version of the index and writes it to disk if it changed. It cannot be called after Load() // unless the hub is completely empty. func (h *Hub) Update(ctx context.Context) error { - if h.pathIndex != nil && len(h.pathIndex) > 0 { + if len(h.pathIndex) > 0 { // if this happens, it's a bug. return errors.New("cannot update hub after items have been loaded") } @@ -218,73 +220,62 @@ func (h *Hub) GetItemFQ(itemFQName string) (*Item, error) { return i, nil } -// GetNamesByType returns a slice of (full) item names for a given type -// (eg. for collections: crowdsecurity/apache2 crowdsecurity/nginx). -func (h *Hub) GetNamesByType(itemType string) []string { - m := h.GetItemMap(itemType) - if m == nil { - return nil - } +// GetItemsByType returns a slice of all the items of a given type, installed or not, optionally sorted by case-insensitive name. +// A non-existent type will silently return an empty slice. +func (h *Hub) GetItemsByType(itemType string, sorted bool) []*Item { + items := h.items[itemType] - names := make([]string, 0, len(m)) - for k := range m { - names = append(names, k) - } + ret := make([]*Item, len(items)) - return names -} + if sorted { + for idx, name := range maptools.SortedKeysNoCase(items) { + ret[idx] = items[name] + } -// GetItemsByType returns a slice of all the items of a given type, installed or not. -func (h *Hub) GetItemsByType(itemType string) ([]*Item, error) { - if !slices.Contains(ItemTypes, itemType) { - return nil, fmt.Errorf("invalid item type %s", itemType) + return ret } - items := h.items[itemType] - - ret := make([]*Item, len(items)) - idx := 0 - for _, item := range items { ret[idx] = item - idx++ + idx += 1 } - return ret, nil + return ret } -// GetInstalledItemsByType returns a slice of the installed items of a given type. -func (h *Hub) GetInstalledItemsByType(itemType string) ([]*Item, error) { - if !slices.Contains(ItemTypes, itemType) { - return nil, fmt.Errorf("invalid item type %s", itemType) - } - - items := h.items[itemType] +// GetInstalledByType returns a slice of all the installed items of a given type, optionally sorted by case-insensitive name. +// A non-existent type will silently return an empty slice. +func (h *Hub) GetInstalledByType(itemType string, sorted bool) []*Item { + ret := make([]*Item, 0) - retItems := make([]*Item, 0) - - for _, item := range items { + for _, item := range h.GetItemsByType(itemType, sorted) { if item.State.Installed { - retItems = append(retItems, item) + ret = append(ret, item) } } - return retItems, nil + return ret } -// GetInstalledNamesByType returns the names of the installed items of a given type. -func (h *Hub) GetInstalledNamesByType(itemType string) ([]string, error) { - items, err := h.GetInstalledItemsByType(itemType) - if err != nil { - return nil, err - } +// GetInstalledListForAPI returns a slice of names of all the installed scenarios and appsec-rules. +// The returned list is sorted by type (scenarios first) and case-insensitive name. +func (h *Hub) GetInstalledListForAPI() []string { + scenarios := h.GetInstalledByType(SCENARIOS, true) + appsecRules := h.GetInstalledByType(APPSEC_RULES, true) + + ret := make([]string, len(scenarios)+len(appsecRules)) - retStr := make([]string, len(items)) + idx := 0 + for _, item := range scenarios { + ret[idx] = item.Name + idx += 1 + } - for idx, it := range items { - retStr[idx] = it.Name + for _, item := range appsecRules { + ret[idx] = item.Name + idx += 1 } - return retStr, nil + return ret } diff --git a/pkg/cwhub/relativepath.go b/pkg/cwhub/relativepath.go new file mode 100644 index 00000000000..bcd4c576840 --- /dev/null +++ b/pkg/cwhub/relativepath.go @@ -0,0 +1,28 @@ +package cwhub + +import ( + "path/filepath" + "strings" +) + +// relativePathComponents returns the list of path components after baseDir. +// If path is not inside baseDir, it returns an empty slice. +func relativePathComponents(path string, baseDir string) []string { + absPath, err := filepath.Abs(path) + if err != nil { + return []string{} + } + + absBaseDir, err := filepath.Abs(baseDir) + if err != nil { + return []string{} + } + + // is path inside baseDir? + relPath, err := filepath.Rel(absBaseDir, absPath) + if err != nil || strings.HasPrefix(relPath, "..") || relPath == "." { + return []string{} + } + + return strings.Split(relPath, string(filepath.Separator)) +} diff --git a/pkg/cwhub/relativepath_test.go b/pkg/cwhub/relativepath_test.go new file mode 100644 index 00000000000..11eba566064 --- /dev/null +++ b/pkg/cwhub/relativepath_test.go @@ -0,0 +1,72 @@ +package cwhub + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRelativePathComponents(t *testing.T) { + tests := []struct { + name string + path string + baseDir string + expected []string + }{ + { + name: "Path within baseDir", + path: "/home/user/project/src/file.go", + baseDir: "/home/user/project", + expected: []string{"src", "file.go"}, + }, + { + name: "Path is baseDir", + path: "/home/user/project", + baseDir: "/home/user/project", + expected: []string{}, + }, + { + name: "Path outside baseDir", + path: "/home/user/otherproject/src/file.go", + baseDir: "/home/user/project", + expected: []string{}, + }, + { + name: "Path is subdirectory of baseDir", + path: "/home/user/project/src/", + baseDir: "/home/user/project", + expected: []string{"src"}, + }, + { + name: "Relative paths", + path: "project/src/file.go", + baseDir: "project", + expected: []string{"src", "file.go"}, + }, + { + name: "BaseDir with trailing slash", + path: "/home/user/project/src/file.go", + baseDir: "/home/user/project/", + expected: []string{"src", "file.go"}, + }, + { + name: "Empty baseDir", + path: "/home/user/project/src/file.go", + baseDir: "", + expected: []string{}, + }, + { + name: "Empty path", + path: "", + baseDir: "/home/user/project", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := relativePathComponents(tt.path, tt.baseDir) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/cwhub/sync.go b/pkg/cwhub/sync.go index 38bb376ae3b..c82822e64ef 100644 --- a/pkg/cwhub/sync.go +++ b/pkg/cwhub/sync.go @@ -20,22 +20,49 @@ func isYAMLFileName(path string) bool { return strings.HasSuffix(path, ".yaml") || strings.HasSuffix(path, ".yml") } -// linkTarget returns the target of a symlink, or empty string if it's dangling. -func linkTarget(path string, logger *logrus.Logger) (string, error) { - hubpath, err := os.Readlink(path) - if err != nil { - return "", fmt.Errorf("unable to read symlink: %s", path) +// resolveSymlink returns the ultimate target path of a symlink +// returns error if the symlink is dangling or too many symlinks are followed +func resolveSymlink(path string) (string, error) { + const maxSymlinks = 10 // Prevent infinite loops + for range maxSymlinks { + fi, err := os.Lstat(path) + if err != nil { + return "", err // dangling link + } + + if fi.Mode()&os.ModeSymlink == 0 { + // found the target + return path, nil + } + + path, err = os.Readlink(path) + if err != nil { + return "", err + } + + // relative to the link's directory? + if !filepath.IsAbs(path) { + path = filepath.Join(filepath.Dir(path), path) + } } - logger.Tracef("symlink %s -> %s", path, hubpath) + return "", errors.New("too many levels of symbolic links") +} - _, err = os.Lstat(hubpath) - if os.IsNotExist(err) { - logger.Warningf("link target does not exist: %s -> %s", path, hubpath) - return "", nil +// isPathInside checks if a path is inside the given directory +// it can return false negatives if the filesystem is case insensitive +func isPathInside(path, dir string) (bool, error) { + absFilePath, err := filepath.Abs(path) + if err != nil { + return false, err + } + + absDir, err := filepath.Abs(dir) + if err != nil { + return false, err } - return hubpath, nil + return strings.HasPrefix(absFilePath, absDir), nil } // information used to create a new Item, from a file path. @@ -53,58 +80,76 @@ func (h *Hub) getItemFileInfo(path string, logger *logrus.Logger) (*itemFileInfo hubDir := h.local.HubDir installDir := h.local.InstallDir - subs := strings.Split(path, string(os.PathSeparator)) + subsHub := relativePathComponents(path, hubDir) + subsInstall := relativePathComponents(path, installDir) - logger.Tracef("path:%s, hubdir:%s, installdir:%s", path, hubDir, installDir) - logger.Tracef("subs:%v", subs) - // we're in hub (~/.hub/hub/) - if strings.HasPrefix(path, hubDir) { + switch { + case len(subsHub) > 0: logger.Tracef("in hub dir") - // .../hub/parsers/s00-raw/crowdsec/skip-pretag.yaml - // .../hub/scenarios/crowdsec/ssh_bf.yaml - // .../hub/profiles/crowdsec/linux.yaml - if len(subs) < 4 { - return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subs)) + // .../hub/parsers/s00-raw/crowdsecurity/skip-pretag.yaml + // .../hub/scenarios/crowdsecurity/ssh_bf.yaml + // .../hub/profiles/crowdsecurity/linux.yaml + if len(subsHub) < 3 { + return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subsHub)) + } + + ftype := subsHub[0] + if !slices.Contains(ItemTypes, ftype) { + // this doesn't really happen anymore, because we only scan the {hubtype} directories + return nil, fmt.Errorf("unknown configuration type '%s'", ftype) + } + + stage := "" + fauthor := subsHub[1] + fname := subsHub[2] + + if ftype == PARSERS || ftype == POSTOVERFLOWS { + stage = subsHub[1] + fauthor = subsHub[2] + fname = subsHub[3] } ret = &itemFileInfo{ inhub: true, - fname: subs[len(subs)-1], - fauthor: subs[len(subs)-2], - stage: subs[len(subs)-3], - ftype: subs[len(subs)-4], + ftype: ftype, + stage: stage, + fauthor: fauthor, + fname: fname, } - } else if strings.HasPrefix(path, installDir) { // we're in install /etc/crowdsec//... + + case len(subsInstall) > 0: logger.Tracef("in install dir") - if len(subs) < 3 { - return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subs)) - } // .../config/parser/stage/file.yaml // .../config/postoverflow/stage/file.yaml // .../config/scenarios/scenar.yaml // .../config/collections/linux.yaml //file is empty - ret = &itemFileInfo{ - inhub: false, - fname: subs[len(subs)-1], - stage: subs[len(subs)-2], - ftype: subs[len(subs)-3], - fauthor: "", + + if len(subsInstall) < 2 { + return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subsInstall)) } - } else { - return nil, fmt.Errorf("file '%s' is not from hub '%s' nor from the configuration directory '%s'", path, hubDir, installDir) - } - logger.Tracef("stage:%s ftype:%s", ret.stage, ret.ftype) + // this can be in any number of subdirs, we join them to compose the item name + + ftype := subsInstall[0] + stage := "" + fname := strings.Join(subsInstall[1:], "/") - if ret.ftype != PARSERS && ret.ftype != POSTOVERFLOWS { - if !slices.Contains(ItemTypes, ret.stage) { - return nil, errors.New("unknown configuration type") + if ftype == PARSERS || ftype == POSTOVERFLOWS { + stage = subsInstall[1] + fname = strings.Join(subsInstall[2:], "/") } - ret.ftype = ret.stage - ret.stage = "" + ret = &itemFileInfo{ + inhub: false, + ftype: ftype, + stage: stage, + fauthor: "", + fname: fname, + } + default: + return nil, fmt.Errorf("file '%s' is not from hub '%s' nor from the configuration directory '%s'", path, hubDir, installDir) } logger.Tracef("CORRECTED [%s] by [%s] in stage [%s] of type [%s]", ret.fname, ret.fauthor, ret.stage, ret.ftype) @@ -165,7 +210,7 @@ func newLocalItem(h *Hub, path string, info *itemFileInfo) (*Item, error) { err = yaml.Unmarshal(itemContent, &itemName) if err != nil { - return nil, fmt.Errorf("failed to unmarshal %s: %w", path, err) + return nil, fmt.Errorf("failed to parse %s: %w", path, err) } if itemName.Name != "" { @@ -176,8 +221,6 @@ func newLocalItem(h *Hub, path string, info *itemFileInfo) (*Item, error) { } func (h *Hub) itemVisit(path string, f os.DirEntry, err error) error { - hubpath := "" - if err != nil { h.logger.Debugf("while syncing hub dir: %s", err) // there is a path error, we ignore the file @@ -190,8 +233,26 @@ func (h *Hub) itemVisit(path string, f os.DirEntry, err error) error { return err } + // permission errors, files removed while reading, etc. + if f == nil { + return nil + } + + if f.IsDir() { + // if a directory starts with a dot, we don't traverse it + // - single dot prefix is hidden by unix convention + // - double dot prefix is used by k8s to mount config maps + if strings.HasPrefix(f.Name(), ".") { + h.logger.Tracef("skipping hidden directory %s", path) + return filepath.SkipDir + } + + // keep traversing + return nil + } + // we only care about YAML files - if f == nil || f.IsDir() || !isYAMLFileName(f.Name()) { + if !isYAMLFileName(f.Name()) { return nil } @@ -201,35 +262,38 @@ func (h *Hub) itemVisit(path string, f os.DirEntry, err error) error { return nil } - // non symlinks are local user files or hub files - if f.Type()&os.ModeSymlink == 0 { - h.logger.Tracef("%s is not a symlink", path) - - if !info.inhub { - h.logger.Tracef("%s is a local file, skip", path) + // follow the link to see if it falls in the hub directory + // if it's not a link, target == path + target, err := resolveSymlink(path) + if err != nil { + // target does not exist, the user might have removed the file + // or switched to a hub branch without it; or symlink loop + h.logger.Warningf("Ignoring file %s: %s", path, err) + return nil + } - item, err := newLocalItem(h, path, info) - if err != nil { - return err - } + targetInHub, err := isPathInside(target, h.local.HubDir) + if err != nil { + h.logger.Warningf("Ignoring file %s: %s", path, err) + return nil + } - h.addItem(item) + // local (custom) item if the file or link target is not inside the hub dir + if !targetInHub { + h.logger.Tracef("%s is a local file, skip", path) - return nil - } - } else { - hubpath, err = linkTarget(path, h.logger) + item, err := newLocalItem(h, path, info) if err != nil { return err } - if hubpath == "" { - // target does not exist, the user might have removed the file - // or switched to a hub branch without it - return nil - } + h.addItem(item) + + return nil } + hubpath := target + // try to find which configuration item it is h.logger.Tracef("check [%s] of %s", info.fname, info.ftype) diff --git a/pkg/cwversion/component/component.go b/pkg/cwversion/component/component.go new file mode 100644 index 00000000000..7ed596525e0 --- /dev/null +++ b/pkg/cwversion/component/component.go @@ -0,0 +1,35 @@ +package component + +// Package component provides functionality for managing the registration of +// optional, compile-time components in the system. This is meant as a space +// saving measure, separate from feature flags (package pkg/fflag) which are +// only enabled/disabled at runtime. + +// Built is a map of all the known components, and whether they are built-in or not. +// This is populated as soon as possible by the respective init() functions +var Built = map[string]bool{ + "datasource_appsec": false, + "datasource_cloudwatch": false, + "datasource_docker": false, + "datasource_file": false, + "datasource_journalctl": false, + "datasource_k8s-audit": false, + "datasource_kafka": false, + "datasource_kinesis": false, + "datasource_loki": false, + "datasource_s3": false, + "datasource_syslog": false, + "datasource_wineventlog": false, + "datasource_http": false, + "cscli_setup": false, +} + +func Register(name string) { + if _, ok := Built[name]; !ok { + // having a list of the disabled components is essential + // to debug users' issues + panic("cannot register unknown compile-time component: " + name) + } + + Built[name] = true +} diff --git a/pkg/cwversion/constraint/constraint.go b/pkg/cwversion/constraint/constraint.go new file mode 100644 index 00000000000..67593f9ebbc --- /dev/null +++ b/pkg/cwversion/constraint/constraint.go @@ -0,0 +1,32 @@ +package constraint + +import ( + "fmt" + + goversion "github.com/hashicorp/go-version" +) + +const ( + Parser = ">= 1.0, <= 3.0" + Scenario = ">= 1.0, <= 3.0" + API = "v1" + Acquis = ">= 1.0, < 2.0" +) + +func Satisfies(strvers string, constraint string) (bool, error) { + vers, err := goversion.NewVersion(strvers) + if err != nil { + return false, fmt.Errorf("failed to parse '%s': %w", strvers, err) + } + + constraints, err := goversion.NewConstraint(constraint) + if err != nil { + return false, fmt.Errorf("failed to parse constraint '%s'", constraint) + } + + if !constraints.Check(vers) { + return false, nil + } + + return true, nil +} diff --git a/pkg/cwversion/version.go b/pkg/cwversion/version.go index 28d5c2a621c..2cb7de13e18 100644 --- a/pkg/cwversion/version.go +++ b/pkg/cwversion/version.go @@ -4,9 +4,12 @@ import ( "fmt" "strings" - goversion "github.com/hashicorp/go-version" - + "github.com/crowdsecurity/go-cs-lib/maptools" "github.com/crowdsecurity/go-cs-lib/version" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" ) var ( @@ -14,31 +17,44 @@ var ( Libre2 = "WebAssembly" ) -const ( - Constraint_parser = ">= 1.0, <= 3.0" - Constraint_scenario = ">= 1.0, <= 3.0" - Constraint_api = "v1" - Constraint_acquis = ">= 1.0, < 2.0" -) - func FullString() string { + dsBuilt := map[string]struct{}{} + dsExcluded := map[string]struct{}{} + + for ds, built := range component.Built { + if built { + dsBuilt[ds] = struct{}{} + continue + } + + dsExcluded[ds] = struct{}{} + } + ret := fmt.Sprintf("version: %s\n", version.String()) ret += fmt.Sprintf("Codename: %s\n", Codename) ret += fmt.Sprintf("BuildDate: %s\n", version.BuildDate) ret += fmt.Sprintf("GoVersion: %s\n", version.GoVersion) ret += fmt.Sprintf("Platform: %s\n", version.System) ret += fmt.Sprintf("libre2: %s\n", Libre2) - ret += fmt.Sprintf("User-Agent: %s\n", UserAgent()) - ret += fmt.Sprintf("Constraint_parser: %s\n", Constraint_parser) - ret += fmt.Sprintf("Constraint_scenario: %s\n", Constraint_scenario) - ret += fmt.Sprintf("Constraint_api: %s\n", Constraint_api) - ret += fmt.Sprintf("Constraint_acquis: %s\n", Constraint_acquis) + ret += fmt.Sprintf("User-Agent: %s\n", useragent.Default()) + ret += fmt.Sprintf("Constraint_parser: %s\n", constraint.Parser) + ret += fmt.Sprintf("Constraint_scenario: %s\n", constraint.Scenario) + ret += fmt.Sprintf("Constraint_api: %s\n", constraint.API) + ret += fmt.Sprintf("Constraint_acquis: %s\n", constraint.Acquis) - return ret -} + built := "(none)" + + if len(dsBuilt) > 0 { + built = strings.Join(maptools.SortedKeys(dsBuilt), ", ") + } -func UserAgent() string { - return "crowdsec/" + version.String() + "-" + version.System + ret += fmt.Sprintf("Built-in optional components: %s\n", built) + + if len(dsExcluded) > 0 { + ret += fmt.Sprintf("Excluded components: %s\n", strings.Join(maptools.SortedKeys(dsExcluded), ", ")) + } + + return ret } // VersionStrip remove the tag from the version string, used to match with a hub branch @@ -48,21 +64,3 @@ func VersionStrip() string { return ret[0] } - -func Satisfies(strvers string, constraint string) (bool, error) { - vers, err := goversion.NewVersion(strvers) - if err != nil { - return false, fmt.Errorf("failed to parse '%s': %w", strvers, err) - } - - constraints, err := goversion.NewConstraint(constraint) - if err != nil { - return false, fmt.Errorf("failed to parse constraint '%s'", constraint) - } - - if !constraints.Check(vers) { - return false, nil - } - - return true, nil -} diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 0f6d87fb1b6..ede9c89fe9a 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -35,12 +35,12 @@ const ( // CreateOrUpdateAlert is specific to PAPI : It checks if alert already exists, otherwise inserts it // if alert already exists, it checks it associated decisions already exists // if some associated decisions are missing (ie. previous insert ended up in error) it inserts them -func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) (string, error) { +func (c *Client) CreateOrUpdateAlert(ctx context.Context, machineID string, alertItem *models.Alert) (string, error) { if alertItem.UUID == "" { return "", errors.New("alert UUID is empty") } - alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(c.CTX) + alerts, err := c.Ent.Alert.Query().Where(alert.UUID(alertItem.UUID)).WithDecisions().All(ctx) if err != nil && !ent.IsNotFound(err) { return "", fmt.Errorf("unable to query alerts for uuid %s: %w", alertItem.UUID, err) @@ -48,7 +48,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) // alert wasn't found, insert it (expected hotpath) if ent.IsNotFound(err) || len(alerts) == 0 { - alertIDs, err := c.CreateAlert(machineID, []*models.Alert{alertItem}) + alertIDs, err := c.CreateAlert(ctx, machineID, []*models.Alert{alertItem}) if err != nil { return "", fmt.Errorf("unable to create alert: %w", err) } @@ -165,7 +165,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(c.CTX) + decisionsCreateRet, err := c.Ent.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { return "", fmt.Errorf("creating alert decisions: %w", err) } @@ -178,7 +178,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) decisionChunks := slicetools.Chunks(decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(c.CTX) + err = c.Ent.Alert.Update().Where(alert.UUID(alertItem.UUID)).AddDecisions(decisionChunk...).Exec(ctx) if err != nil { return "", fmt.Errorf("updating alert %s: %w", alertItem.UUID, err) } @@ -191,7 +191,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) // it takes care of creating the new alert with the associated decisions, and it will as well deleted the "older" overlapping decisions: // 1st pull, you get decisions [1,2,3]. it inserts [1,2,3] // 2nd pull, you get decisions [1,2,3,4]. it inserts [1,2,3,4] and will try to delete [1,2,3,4] with a different alert ID and same origin -func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, int, error) { +func (c *Client) UpdateCommunityBlocklist(ctx context.Context, alertItem *models.Alert) (int, int, int, error) { if alertItem == nil { return 0, 0, 0, errors.New("nil alert") } @@ -244,7 +244,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in SetScenarioHash(*alertItem.ScenarioHash). SetRemediation(true) // it's from CAPI, we always have decisions - alertRef, err := alertB.Save(c.CTX) + alertRef, err := alertB.Save(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating alert : %s", err) } @@ -253,7 +253,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, 0, 0, nil } - txClient, err := c.Ent.Tx(c.CTX) + txClient, err := c.Ent.Tx(ctx) if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) } @@ -347,7 +347,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in decision.OriginEQ(DecOrigin), decision.Not(decision.HasOwnerWith(alert.IDEQ(alertRef.ID))), decision.ValueIn(deleteChunk...), - )).Exec(c.CTX) + )).Exec(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -363,7 +363,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in builderChunks := slicetools.Chunks(decisionBuilders, c.decisionBulkSize) for _, builderChunk := range builderChunks { - insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(c.CTX) + insertedDecisions, err := txClient.Decision.CreateBulk(builderChunk...).Save(ctx) if err != nil { rollbackErr := txClient.Rollback() if rollbackErr != nil { @@ -391,7 +391,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in return alertRef.ID, inserted, deleted, nil } -func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { +func (c *Client) createDecisionChunk(ctx context.Context, simulated bool, stopAtTime time.Time, decisions []*models.Decision) ([]*ent.Decision, error) { decisionCreate := []*ent.DecisionCreate{} for _, decisionItem := range decisions { @@ -436,7 +436,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return nil, nil } - ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(c.CTX) + ret, err := c.Ent.Decision.CreateBulk(decisionCreate...).Save(ctx) if err != nil { return nil, err } @@ -444,7 +444,7 @@ func (c *Client) createDecisionChunk(simulated bool, stopAtTime time.Time, decis return ret, nil } -func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { +func (c *Client) createAlertChunk(ctx context.Context, machineID string, owner *ent.Machine, alerts []*models.Alert) ([]string, error) { alertBuilders := []*ent.AlertCreate{} alertDecisions := [][]*ent.Decision{} @@ -456,14 +456,14 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ startAtTime, err := time.Parse(time.RFC3339, *alertItem.StartAt) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) + c.Log.Errorf("creating alert: Failed to parse startAtTime '%s', defaulting to now: %s", *alertItem.StartAt, err) startAtTime = time.Now().UTC() } stopAtTime, err := time.Parse(time.RFC3339, *alertItem.StopAt) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) + c.Log.Errorf("creating alert: Failed to parse stopAtTime '%s', defaulting to now: %s", *alertItem.StopAt, err) stopAtTime = time.Now().UTC() } @@ -483,7 +483,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ for i, eventItem := range alertItem.Events { ts, err := time.Parse(time.RFC3339, *eventItem.Timestamp) if err != nil { - c.Log.Errorf("CreateAlertBulk: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) + c.Log.Errorf("creating alert: Failed to parse event timestamp '%s', defaulting to now: %s", *eventItem.Timestamp, err) ts = time.Now().UTC() } @@ -540,7 +540,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ c.Log.Warningf("dropped 'serialized' field (machine %s / scenario %s)", machineID, *alertItem.Scenario) } - events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(c.CTX) + events, err = c.Ent.Event.CreateBulk(eventBulk...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert events: %s", err) } @@ -554,12 +554,14 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ value := metaItem.Value if len(metaItem.Value) > 4095 { - c.Log.Warningf("truncated meta %s : value too long", metaItem.Key) + c.Log.Warningf("truncated meta %s: value too long", metaItem.Key) + value = value[:4095] } if len(metaItem.Key) > 255 { - c.Log.Warningf("truncated meta %s : key too long", metaItem.Key) + c.Log.Warningf("truncated meta %s: key too long", metaItem.Key) + key = key[:255] } @@ -568,7 +570,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ SetValue(value) } - metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(c.CTX) + metas, err = c.Ent.Meta.CreateBulk(metaBulk...).Save(ctx) if err != nil { c.Log.Warningf("error creating alert meta: %s", err) } @@ -578,7 +580,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ decisionChunks := slicetools.Chunks(alertItem.Decisions, c.decisionBulkSize) for _, decisionChunk := range decisionChunks { - decisionRet, err := c.createDecisionChunk(*alertItem.Simulated, stopAtTime, decisionChunk) + decisionRet, err := c.createDecisionChunk(ctx, *alertItem.Simulated, stopAtTime, decisionChunk) if err != nil { return nil, fmt.Errorf("creating alert decisions: %w", err) } @@ -636,7 +638,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ return nil, nil } - alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(c.CTX) + alertsCreateBulk, err := c.Ent.Alert.CreateBulk(alertBuilders...).Save(ctx) if err != nil { return nil, errors.Wrapf(BulkError, "bulk creating alert : %s", err) } @@ -653,7 +655,7 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ for retry < maxLockRetries { // so much for the happy path... but sqlite3 errors work differently - _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(c.CTX) + _, err := c.Ent.Alert.Update().Where(alert.IDEQ(a.ID)).AddDecisions(d2...).Save(ctx) if err == nil { break } @@ -678,23 +680,24 @@ func (c *Client) createAlertChunk(machineID string, owner *ent.Machine, alerts [ } } } + return ret, nil } -func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]string, error) { +func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList []*models.Alert) ([]string, error) { var ( owner *ent.Machine err error ) if machineID != "" { - owner, err = c.QueryMachineByID(machineID) + owner, err = c.QueryMachineByID(ctx, machineID) if err != nil { if !errors.Is(err, UserNotExists) { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } - c.Log.Debugf("CreateAlertBulk: Machine Id %s doesn't exist", machineID) + c.Log.Debugf("creating alert: machine %s doesn't exist", machineID) owner = nil } @@ -706,7 +709,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str alertIDs := []string{} for _, alertChunk := range alertChunks { - ids, err := c.createAlertChunk(machineID, owner, alertChunk) + ids, err := c.createAlertChunk(ctx, machineID, owner, alertChunk) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } @@ -715,7 +718,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str } if owner != nil { - err = owner.Update().SetLastPush(time.Now().UTC()).Exec(c.CTX) + err = owner.Update().SetLastPush(time.Now().UTC()).Exec(ctx) if err != nil { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } @@ -724,6 +727,160 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str return alertIDs, nil } +func handleSimulatedFilter(filter map[string][]string, predicates *[]predicate.Alert) { + /* the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ + if v, ok := filter["simulated"]; ok && v[0] == "false" { + *predicates = append(*predicates, alert.SimulatedEQ(false)) + } +} + +func handleOriginFilter(filter map[string][]string, predicates *[]predicate.Alert) { + if _, ok := filter["origin"]; ok { + filter["include_capi"] = []string{"true"} + } +} + +func handleScopeFilter(scope string, predicates *[]predicate.Alert) { + if strings.ToLower(scope) == "ip" { + scope = types.Ip + } else if strings.ToLower(scope) == "range" { + scope = types.Range + } + + *predicates = append(*predicates, alert.SourceScopeEQ(scope)) +} + +func handleTimeFilters(param, value string, predicates *[]predicate.Alert) error { + duration, err := ParseDuration(value) + if err != nil { + return fmt.Errorf("while parsing duration: %w", err) + } + + timePoint := time.Now().UTC().Add(-duration) + if timePoint.IsZero() { + return fmt.Errorf("empty time now() - %s", timePoint.String()) + } + + switch param { + case "since": + *predicates = append(*predicates, alert.StartedAtGTE(timePoint)) + case "created_before": + *predicates = append(*predicates, alert.CreatedAtLTE(timePoint)) + case "until": + *predicates = append(*predicates, alert.StartedAtLTE(timePoint)) + } + + return nil +} + +func handleIPv4Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) { + if contains { // decision contains {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + alert.HasDecisionsWith(decision.StartIPLTE(start_ip)), + alert.HasDecisionsWith(decision.EndIPGTE(end_ip)), + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + )) + } else { // decision is contained within {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + alert.HasDecisionsWith(decision.StartIPGTE(start_ip)), + alert.HasDecisionsWith(decision.EndIPLTE(end_ip)), + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + )) + } +} + +func handleIPv6Predicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) { + if contains { // decision contains {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + // matching addr size + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + alert.Or( + // decision.start_ip < query.start_ip + alert.HasDecisionsWith(decision.StartIPLT(start_ip)), + alert.And( + // decision.start_ip == query.start_ip + alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), + // decision.start_suffix <= query.start_suffix + alert.HasDecisionsWith(decision.StartSuffixLTE(start_sfx)), + ), + ), + alert.Or( + // decision.end_ip > query.end_ip + alert.HasDecisionsWith(decision.EndIPGT(end_ip)), + alert.And( + // decision.end_ip == query.end_ip + alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), + // decision.end_suffix >= query.end_suffix + alert.HasDecisionsWith(decision.EndSuffixGTE(end_sfx)), + ), + ), + )) + } else { // decision is contained within {start_ip,end_ip} + *predicates = append(*predicates, alert.And( + // matching addr size + alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), + alert.Or( + // decision.start_ip > query.start_ip + alert.HasDecisionsWith(decision.StartIPGT(start_ip)), + alert.And( + // decision.start_ip == query.start_ip + alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), + // decision.start_suffix >= query.start_suffix + alert.HasDecisionsWith(decision.StartSuffixGTE(start_sfx)), + ), + ), + alert.Or( + // decision.end_ip < query.end_ip + alert.HasDecisionsWith(decision.EndIPLT(end_ip)), + alert.And( + // decision.end_ip == query.end_ip + alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), + // decision.end_suffix <= query.end_suffix + alert.HasDecisionsWith(decision.EndSuffixLTE(end_sfx)), + ), + ), + )) + } +} + +func handleIPPredicates(ip_sz int, contains bool, start_ip, start_sfx, end_ip, end_sfx int64, predicates *[]predicate.Alert) error { + if ip_sz == 4 { + handleIPv4Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates) + } else if ip_sz == 16 { + handleIPv6Predicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, predicates) + } else if ip_sz != 0 { + return errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + } + + return nil +} + +func handleIncludeCapiFilter(value string, predicates *[]predicate.Alert) error { + if value == "false" { + *predicates = append(*predicates, alert.And( + // do not show alerts with active decisions having origin CAPI or lists + alert.And( + alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.CAPIOrigin))), + alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.ListOrigin))), + ), + alert.Not( + alert.And( + // do not show neither alerts with no decisions if the Source Scope is lists: or CAPI + alert.Not(alert.HasDecisions()), + alert.Or( + alert.SourceScopeHasPrefix(types.ListOrigin+":"), + alert.SourceScopeEQ(types.CommunityBlocklistPullSourceScope), + ), + ), + ), + )) + } else if value != "true" { + log.Errorf("invalid bool '%s' for include_capi", value) + } + + return nil +} + func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, error) { predicates := make([]predicate.Alert, 0) @@ -739,16 +896,8 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ - /*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ - if v, ok := filter["simulated"]; ok { - if v[0] == "false" { - predicates = append(predicates, alert.SimulatedEQ(false)) - } - } - - if _, ok := filter["origin"]; ok { - filter["include_capi"] = []string{"true"} - } + handleSimulatedFilter(filter, &predicates) + handleOriginFilter(filter, &predicates) for param, value := range filter { switch param { @@ -758,14 +907,7 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) } case "scope": - scope := value[0] - if strings.ToLower(scope) == "ip" { - scope = types.Ip - } else if strings.ToLower(scope) == "range" { - scope = types.Range - } - - predicates = append(predicates, alert.SourceScopeEQ(scope)) + handleScopeFilter(value[0], &predicates) case "value": predicates = append(predicates, alert.SourceValueEQ(value[0])) case "scenario": @@ -775,68 +917,17 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e if err != nil { return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err) } - case "since": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) + case "since", "created_before", "until": + if err := handleTimeFilters(param, value[0], &predicates); err != nil { + return nil, err } - - since := time.Now().UTC().Add(-duration) - if since.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", since.String()) - } - - predicates = append(predicates, alert.StartedAtGTE(since)) - case "created_before": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - - since := time.Now().UTC().Add(-duration) - if since.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", since.String()) - } - - predicates = append(predicates, alert.CreatedAtLTE(since)) - case "until": - duration, err := ParseDuration(value[0]) - if err != nil { - return nil, fmt.Errorf("while parsing duration: %w", err) - } - - until := time.Now().UTC().Add(-duration) - if until.IsZero() { - return nil, fmt.Errorf("empty time now() - %s", until.String()) - } - - predicates = append(predicates, alert.StartedAtLTE(until)) case "decision_type": predicates = append(predicates, alert.HasDecisionsWith(decision.TypeEQ(value[0]))) case "origin": predicates = append(predicates, alert.HasDecisionsWith(decision.OriginEQ(value[0]))) case "include_capi": // allows to exclude one or more specific origins - if value[0] == "false" { - predicates = append(predicates, alert.And( - // do not show alerts with active decisions having origin CAPI or lists - alert.And( - alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.CAPIOrigin))), - alert.Not(alert.HasDecisionsWith(decision.OriginEQ(types.ListOrigin))), - ), - alert.Not( - alert.And( - // do not show neither alerts with no decisions if the Source Scope is lists: or CAPI - alert.Not(alert.HasDecisions()), - alert.Or( - alert.SourceScopeHasPrefix(types.ListOrigin+":"), - alert.SourceScopeEQ(types.CommunityBlocklistPullSourceScope), - ), - ), - ), - ), - ) - } else if value[0] != "true" { - log.Errorf("Invalid bool '%s' for include_capi", value[0]) + if err = handleIncludeCapiFilter(value[0], &predicates); err != nil { + return nil, err } case "has_active_decision": if hasActiveDecision, err = strconv.ParseBool(value[0]); err != nil { @@ -861,72 +952,8 @@ func AlertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e } } - if ip_sz == 4 { - if contains { /*decision contains {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - alert.HasDecisionsWith(decision.StartIPLTE(start_ip)), - alert.HasDecisionsWith(decision.EndIPGTE(end_ip)), - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - )) - } else { /*decision is contained within {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - alert.HasDecisionsWith(decision.StartIPGTE(start_ip)), - alert.HasDecisionsWith(decision.EndIPLTE(end_ip)), - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - )) - } - } else if ip_sz == 16 { - if contains { /*decision contains {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - // matching addr size - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - alert.Or( - // decision.start_ip < query.start_ip - alert.HasDecisionsWith(decision.StartIPLT(start_ip)), - alert.And( - // decision.start_ip == query.start_ip - alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), - // decision.start_suffix <= query.start_suffix - alert.HasDecisionsWith(decision.StartSuffixLTE(start_sfx)), - )), - alert.Or( - // decision.end_ip > query.end_ip - alert.HasDecisionsWith(decision.EndIPGT(end_ip)), - alert.And( - // decision.end_ip == query.end_ip - alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), - // decision.end_suffix >= query.end_suffix - alert.HasDecisionsWith(decision.EndSuffixGTE(end_sfx)), - ), - ), - )) - } else { /*decision is contained within {start_ip,end_ip}*/ - predicates = append(predicates, alert.And( - // matching addr size - alert.HasDecisionsWith(decision.IPSizeEQ(int64(ip_sz))), - alert.Or( - // decision.start_ip > query.start_ip - alert.HasDecisionsWith(decision.StartIPGT(start_ip)), - alert.And( - // decision.start_ip == query.start_ip - alert.HasDecisionsWith(decision.StartIPEQ(start_ip)), - // decision.start_suffix >= query.start_suffix - alert.HasDecisionsWith(decision.StartSuffixGTE(start_sfx)), - )), - alert.Or( - // decision.end_ip < query.end_ip - alert.HasDecisionsWith(decision.EndIPLT(end_ip)), - alert.And( - // decision.end_ip == query.end_ip - alert.HasDecisionsWith(decision.EndIPEQ(end_ip)), - // decision.end_suffix <= query.end_suffix - alert.HasDecisionsWith(decision.EndSuffixLTE(end_sfx)), - ), - ), - )) - } - } else if ip_sz != 0 { - return nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) + if err := handleIPPredicates(ip_sz, contains, start_ip, start_sfx, end_ip, end_sfx, &predicates); err != nil { + return nil, err } return predicates, nil @@ -941,14 +968,12 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str return alerts.Where(preds...), nil } -func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string]int, error) { +func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string][]string) (map[string]int, error) { var res []struct { Scenario string Count int } - ctx := context.TODO() - query := c.Ent.Alert.Query() query, err := BuildAlertRequestFromFilter(query, filters) @@ -970,11 +995,11 @@ func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string return counts, nil } -func (c *Client) TotalAlerts() (int, error) { - return c.Ent.Alert.Query().Count(c.CTX) +func (c *Client) TotalAlerts(ctx context.Context) (int, error) { + return c.Ent.Alert.Query().Count(ctx) } -func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, error) { +func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Alert, error) { sort := "DESC" // we sort by desc by default if val, ok := filter["sort"]; ok { @@ -1021,7 +1046,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, WithOwner() if limit == 0 { - limit, err = alerts.Count(c.CTX) + limit, err = alerts.Count(ctx) if err != nil { return nil, fmt.Errorf("unable to count nb alerts: %w", err) } @@ -1033,7 +1058,7 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID)) } - result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX) + result, err := alerts.Limit(paginationSize).Offset(offset).All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "pagination size: %d, offset: %d: %s", paginationSize, offset, err) } @@ -1062,35 +1087,35 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, return ret, nil } -func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { +func (c *Client) DeleteAlertGraphBatch(ctx context.Context, alertItems []*ent.Alert) (int, error) { idList := make([]int, 0) for _, alert := range alertItems { idList = append(idList, alert.ID) } _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch events") } _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch meta") } _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDIn(idList...))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return 0, errors.Wrapf(DeleteFail, "alert graph delete batch decisions") } deleted, err := c.Ent.Alert.Delete(). - Where(alert.IDIn(idList...)).Exec(c.CTX) + Where(alert.IDIn(idList...)).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraphBatch : %s", err) return deleted, errors.Wrapf(DeleteFail, "alert graph delete batch") @@ -1101,10 +1126,10 @@ func (c *Client) DeleteAlertGraphBatch(alertItems []*ent.Alert) (int, error) { return deleted, nil } -func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { +func (c *Client) DeleteAlertGraph(ctx context.Context, alertItem *ent.Alert) error { // delete the associated events _, err := c.Ent.Event.Delete(). - Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(event.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "event with alert ID '%d'", alertItem.ID) @@ -1112,7 +1137,7 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated meta _, err = c.Ent.Meta.Delete(). - Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(meta.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "meta with alert ID '%d'", alertItem.ID) @@ -1120,14 +1145,14 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { // delete the associated decisions _, err = c.Ent.Decision.Delete(). - Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(c.CTX) + Where(decision.HasOwnerWith(alert.IDEQ(alertItem.ID))).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "decision with alert ID '%d'", alertItem.ID) } // delete the alert - err = c.Ent.Alert.DeleteOne(alertItem).Exec(c.CTX) + err = c.Ent.Alert.DeleteOne(alertItem).Exec(ctx) if err != nil { c.Log.Warningf("DeleteAlertGraph : %s", err) return errors.Wrapf(DeleteFail, "alert with ID '%d'", alertItem.ID) @@ -1136,26 +1161,26 @@ func (c *Client) DeleteAlertGraph(alertItem *ent.Alert) error { return nil } -func (c *Client) DeleteAlertByID(id int) error { - alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(c.CTX) +func (c *Client) DeleteAlertByID(ctx context.Context, id int) error { + alertItem, err := c.Ent.Alert.Query().Where(alert.IDEQ(id)).Only(ctx) if err != nil { return err } - return c.DeleteAlertGraph(alertItem) + return c.DeleteAlertGraph(ctx, alertItem) } -func (c *Client) DeleteAlertWithFilter(filter map[string][]string) (int, error) { +func (c *Client) DeleteAlertWithFilter(ctx context.Context, filter map[string][]string) (int, error) { preds, err := AlertPredicatesFromFilter(filter) if err != nil { return 0, err } - return c.Ent.Alert.Delete().Where(preds...).Exec(c.CTX) + return c.Ent.Alert.Delete().Where(preds...).Exec(ctx) } -func (c *Client) GetAlertByID(alertID int) (*ent.Alert, error) { - alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(c.CTX) +func (c *Client) GetAlertByID(ctx context.Context, alertID int) (*ent.Alert, error) { + alert, err := c.Ent.Alert.Query().Where(alert.IDEQ(alertID)).WithDecisions().WithEvents().WithMetas().WithOwner().First(ctx) if err != nil { /*record not found, 404*/ if ent.IsNotFound(err) { diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index f79e9580afe..f9e62bc6522 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strings" "time" @@ -20,7 +21,7 @@ func (e *BouncerNotFoundError) Error() string { return fmt.Sprintf("'%s' does not exist", e.BouncerName) } -func (c *Client) BouncerUpdateBaseMetrics(bouncerName string, bouncerType string, baseMetrics models.BaseMetrics) error { +func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName string, bouncerType string, baseMetrics models.BaseMetrics) error { os := baseMetrics.Os features := strings.Join(baseMetrics.FeatureFlags, ",") @@ -32,7 +33,7 @@ func (c *Client) BouncerUpdateBaseMetrics(bouncerName string, bouncerType string SetOsversion(*os.Version). SetFeatureflags(features). SetType(bouncerType). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update base bouncer metrics in database: %w", err) } @@ -40,8 +41,10 @@ func (c *Client) BouncerUpdateBaseMetrics(bouncerName string, bouncerType string return nil } -func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX) +func (c *Client) SelectBouncers(ctx context.Context, apiKeyHash string, authType string) ([]*ent.Bouncer, error) { + //Order by ID so manually created bouncer will be first in the list to use as the base name + //when automatically creating a new entry if API keys are shared + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash), bouncer.AuthTypeEQ(authType)).Order(ent.Asc(bouncer.FieldID)).All(ctx) if err != nil { return nil, err } @@ -49,8 +52,8 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) +func (c *Client) SelectBouncerWithIP(ctx context.Context, apiKeyHash string, clientIP string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash), bouncer.IPAddressEQ(clientIP)).First(ctx) if err != nil { return nil, err } @@ -58,8 +61,17 @@ func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { return result, nil } -func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { - result, err := c.Ent.Bouncer.Query().All(c.CTX) +func (c *Client) SelectBouncerByName(ctx context.Context, bouncerName string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(ctx) + if err != nil { + return nil, err + } + + return result, nil +} + +func (c *Client) ListBouncers(ctx context.Context) ([]*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err) } @@ -67,14 +79,16 @@ func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { return result, nil } -func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { +func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string, autoCreated bool) (*ent.Bouncer, error) { bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). SetAuthType(authType). - Save(c.CTX) + SetIPAddress(ipAddr). + SetAutoCreated(autoCreated). + Save(ctx) if err != nil { if ent.IsConstraintError(err) { return nil, fmt.Errorf("bouncer %s already exists", name) @@ -86,11 +100,11 @@ func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authTy return bouncer, nil } -func (c *Client) DeleteBouncer(name string) error { +func (c *Client) DeleteBouncer(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Bouncer. Delete(). Where(bouncer.NameEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } @@ -102,13 +116,13 @@ func (c *Client) DeleteBouncer(name string) error { return nil } -func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { +func (c *Client) BulkDeleteBouncers(ctx context.Context, bouncers []*ent.Bouncer) (int, error) { ids := make([]int, len(bouncers)) for i, b := range bouncers { ids[i] = b.ID } - nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(c.CTX) + nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, fmt.Errorf("unable to delete bouncers: %w", err) } @@ -116,10 +130,10 @@ func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) { return nbDeleted, nil } -func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error { +func (c *Client) UpdateBouncerLastPull(ctx context.Context, lastPull time.Time, id int) error { _, err := c.Ent.Bouncer.UpdateOneID(id). SetLastPull(lastPull). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine last pull in database: %w", err) } @@ -127,8 +141,8 @@ func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error { return nil } -func (c *Client) UpdateBouncerIP(ipAddr string, id int) error { - _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(c.CTX) +func (c *Client) UpdateBouncerIP(ctx context.Context, ipAddr string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(ctx) if err != nil { return fmt.Errorf("unable to update bouncer ip address in database: %w", err) } @@ -136,8 +150,8 @@ func (c *Client) UpdateBouncerIP(ipAddr string, id int) error { return nil } -func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id int) error { - _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(c.CTX) +func (c *Client) UpdateBouncerTypeAndVersion(ctx context.Context, bType string, version string, id int) error { + _, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(ctx) if err != nil { return fmt.Errorf("unable to update bouncer type and version in database: %w", err) } @@ -145,7 +159,7 @@ func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id in return nil } -func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error) { +func (c *Client) QueryBouncersInactiveSince(ctx context.Context, t time.Time) ([]*ent.Bouncer, error) { return c.Ent.Bouncer.Query().Where( // poor man's coalesce bouncer.Or( @@ -155,5 +169,5 @@ func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error) bouncer.CreatedAtLT(t), ), ), - ).All(c.CTX) + ).All(ctx) } diff --git a/pkg/database/config.go b/pkg/database/config.go index 8c3578ad596..89ccb1e1b28 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -1,17 +1,20 @@ package database import ( + "context" + "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) -func (c *Client) GetConfigItem(key string) (*string, error) { - result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX) +func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) { + result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx) if err != nil && ent.IsNotFound(err) { return nil, nil } + if err != nil { return nil, errors.Wrapf(QueryFail, "select config item: %s", err) } @@ -19,16 +22,16 @@ func (c *Client) GetConfigItem(key string) (*string, error) { return &result.Value, nil } -func (c *Client) SetConfigItem(key string, value string) error { - - nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX) - if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create - err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX) +func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error { + nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx) + if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { // not found, create + err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx) if err != nil { return errors.Wrapf(QueryFail, "insert config item: %s", err) } } else if err != nil { return errors.Wrapf(QueryFail, "update config item: %s", err) } + return nil } diff --git a/pkg/database/database.go b/pkg/database/database.go index e513459199f..bb41dd3b645 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -21,7 +21,6 @@ import ( type Client struct { Ent *ent.Client - CTX context.Context Log *log.Logger CanFlush bool Type string @@ -106,7 +105,6 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro return &Client{ Ent: client, - CTX: ctx, Log: clog, CanFlush: true, Type: config.Type, diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index fc582247e59..7522a272799 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strconv" "strings" @@ -30,7 +31,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ @@ -120,7 +121,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] return query, nil } -func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -137,7 +138,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") @@ -146,7 +147,7 @@ func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*e return data, nil } -func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) @@ -164,7 +165,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters") } - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions") @@ -173,7 +174,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( return data, nil } -func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) { +func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*DecisionsByScenario, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -186,7 +187,7 @@ func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) var r []*DecisionsByScenario - err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(c.CTX, &r) + err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(ctx, &r) if err != nil { c.Log.Warningf("QueryDecisionCountByScenario : %s", err) return nil, errors.Wrap(QueryFail, "count all decisions with filters") @@ -195,7 +196,7 @@ func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) return r, nil } -func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) { var data []*ent.Decision var err error @@ -217,7 +218,7 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec decision.FieldValue, decision.FieldScope, decision.FieldOrigin, - ).Scan(c.CTX, &data) + ).Scan(ctx, &data) if err != nil { c.Log.Warningf("QueryDecisionWithFilter : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "query decision failed") @@ -254,7 +255,7 @@ func longestDecisionForScopeTypeValue(s *sql.Selector) { ) } -func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) @@ -276,7 +277,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") @@ -285,7 +286,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since *time.Time, filters return data, nil } -func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -307,7 +308,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map query = query.Order(ent.Asc(decision.FieldID)) - data, err := query.All(c.CTX) + data, err := query.All(ctx) if err != nil { c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String()) @@ -316,24 +317,11 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(since *time.Time, filters map return data, nil } -func (c *Client) DeleteDecisionById(decisionID int) ([]*ent.Decision, error) { - toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) - if err != nil { - c.Log.Warningf("DeleteDecisionById : %s", err) - return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID) - } - - count, err := c.DeleteDecisions(toDelete) - c.Log.Debugf("deleted %d decisions", count) - - return toDelete, err -} - -func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer) */ @@ -432,13 +420,13 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - toDelete, err := decisions.All(c.CTX) + toDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("DeleteDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") } - count, err := c.DeleteDecisions(toDelete) + count, err := c.DeleteDecisions(ctx, toDelete) if err != nil { c.Log.Warningf("While deleting decisions : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") @@ -448,11 +436,11 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, } // ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items -func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int - var contains = true + contains := true /*if contains is true, return bans that *contains* the given value (value is the inner) else, return bans that are *contained* by the given value (value is the outer)*/ decisions := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now().UTC())) @@ -557,13 +545,13 @@ func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - DecisionsToDelete, err := decisions.All(c.CTX) + DecisionsToDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("ExpireDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter") } - count, err := c.ExpireDecisions(DecisionsToDelete) + count, err := c.ExpireDecisions(ctx, DecisionsToDelete) if err != nil { return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err) } @@ -582,13 +570,13 @@ func decisionIDs(decisions []*ent.Decision) []int { // ExpireDecisions sets the expiration of a list of decisions to now() // It returns the number of impacted decisions for the CAPI/PAPI -func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { +func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { if len(decisions) <= decisionDeleteBulkSize { ids := decisionIDs(decisions) rows, err := c.Ent.Decision.Update().Where( decision.IDIn(ids...), - ).SetUntil(time.Now().UTC()).Save(c.CTX) + ).SetUntil(time.Now().UTC()).Save(ctx) if err != nil { return 0, fmt.Errorf("expire decisions with provided filter: %w", err) } @@ -601,7 +589,7 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { total := 0 for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { - rows, err := c.ExpireDecisions(chunk) + rows, err := c.ExpireDecisions(ctx, chunk) if err != nil { return total, err } @@ -614,13 +602,13 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { // DeleteDecisions removes a list of decisions from the database // It returns the number of impacted decisions for the CAPI/PAPI -func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { +func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { if len(decisions) < decisionDeleteBulkSize { ids := decisionIDs(decisions) rows, err := c.Ent.Decision.Delete().Where( decision.IDIn(ids...), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err) } @@ -633,7 +621,7 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { tot := 0 for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { - rows, err := c.DeleteDecisions(chunk) + rows, err := c.DeleteDecisions(ctx, chunk) if err != nil { return tot, err } @@ -645,8 +633,8 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { } // ExpireDecision set the expiration of a decision to now() -func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error) { - toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) +func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) { + toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx) // XXX: do we want 500 or 404 here? if err != nil || len(toUpdate) == 0 { @@ -658,12 +646,12 @@ func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error return 0, nil, ItemNotFound } - count, err := c.ExpireDecisions(toUpdate) + count, err := c.ExpireDecisions(ctx, toUpdate) return count, toUpdate, err } -func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int @@ -681,7 +669,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -689,7 +677,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return count, nil } -func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int @@ -709,7 +697,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) decisions = decisions.Where(decision.UntilGT(time.Now().UTC())) - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, fmt.Errorf("fail to count decisions: %w", err) } @@ -717,7 +705,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) return count, nil } -func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.Duration, error) { +func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -739,7 +727,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D decisions = decisions.Order(ent.Desc(decision.FieldUntil)) - decision, err := decisions.First(c.CTX) + decision, err := decisions.First(ctx) if err != nil && !ent.IsNotFound(err) { return 0, fmt.Errorf("fail to get decision: %w", err) } @@ -751,7 +739,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D return decision.Until.Sub(time.Now().UTC()), nil } -func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { +func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) { ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) @@ -767,7 +755,7 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err := decisions.Count(c.CTX) + count, err := decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } diff --git a/pkg/database/ent/bouncer.go b/pkg/database/ent/bouncer.go index 3b4d619e384..197f61cde19 100644 --- a/pkg/database/ent/bouncer.go +++ b/pkg/database/ent/bouncer.go @@ -43,6 +43,8 @@ type Bouncer struct { Osversion string `json:"osversion,omitempty"` // Featureflags holds the value of the "featureflags" field. Featureflags string `json:"featureflags,omitempty"` + // AutoCreated holds the value of the "auto_created" field. + AutoCreated bool `json:"auto_created"` selectValues sql.SelectValues } @@ -51,7 +53,7 @@ func (*Bouncer) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case bouncer.FieldRevoked: + case bouncer.FieldRevoked, bouncer.FieldAutoCreated: values[i] = new(sql.NullBool) case bouncer.FieldID: values[i] = new(sql.NullInt64) @@ -159,6 +161,12 @@ func (b *Bouncer) assignValues(columns []string, values []any) error { } else if value.Valid { b.Featureflags = value.String } + case bouncer.FieldAutoCreated: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field auto_created", values[i]) + } else if value.Valid { + b.AutoCreated = value.Bool + } default: b.selectValues.Set(columns[i], values[i]) } @@ -234,6 +242,9 @@ func (b *Bouncer) String() string { builder.WriteString(", ") builder.WriteString("featureflags=") builder.WriteString(b.Featureflags) + builder.WriteString(", ") + builder.WriteString("auto_created=") + builder.WriteString(fmt.Sprintf("%v", b.AutoCreated)) builder.WriteByte(')') return builder.String() } diff --git a/pkg/database/ent/bouncer/bouncer.go b/pkg/database/ent/bouncer/bouncer.go index a6f62aeadd5..f25b5a5815a 100644 --- a/pkg/database/ent/bouncer/bouncer.go +++ b/pkg/database/ent/bouncer/bouncer.go @@ -39,6 +39,8 @@ const ( FieldOsversion = "osversion" // FieldFeatureflags holds the string denoting the featureflags field in the database. FieldFeatureflags = "featureflags" + // FieldAutoCreated holds the string denoting the auto_created field in the database. + FieldAutoCreated = "auto_created" // Table holds the table name of the bouncer in the database. Table = "bouncers" ) @@ -59,6 +61,7 @@ var Columns = []string{ FieldOsname, FieldOsversion, FieldFeatureflags, + FieldAutoCreated, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -82,6 +85,8 @@ var ( DefaultIPAddress string // DefaultAuthType holds the default value on creation for the "auth_type" field. DefaultAuthType string + // DefaultAutoCreated holds the default value on creation for the "auto_created" field. + DefaultAutoCreated bool ) // OrderOption defines the ordering options for the Bouncer queries. @@ -156,3 +161,8 @@ func ByOsversion(opts ...sql.OrderTermOption) OrderOption { func ByFeatureflags(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldFeatureflags, opts...).ToFunc() } + +// ByAutoCreated orders the results by the auto_created field. +func ByAutoCreated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAutoCreated, opts...).ToFunc() +} diff --git a/pkg/database/ent/bouncer/where.go b/pkg/database/ent/bouncer/where.go index e02199bc0a9..79b8999354f 100644 --- a/pkg/database/ent/bouncer/where.go +++ b/pkg/database/ent/bouncer/where.go @@ -119,6 +119,11 @@ func Featureflags(v string) predicate.Bouncer { return predicate.Bouncer(sql.FieldEQ(FieldFeatureflags, v)) } +// AutoCreated applies equality check predicate on the "auto_created" field. It's identical to AutoCreatedEQ. +func AutoCreated(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldAutoCreated, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Bouncer { return predicate.Bouncer(sql.FieldEQ(FieldCreatedAt, v)) @@ -904,6 +909,16 @@ func FeatureflagsContainsFold(v string) predicate.Bouncer { return predicate.Bouncer(sql.FieldContainsFold(FieldFeatureflags, v)) } +// AutoCreatedEQ applies the EQ predicate on the "auto_created" field. +func AutoCreatedEQ(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldEQ(FieldAutoCreated, v)) +} + +// AutoCreatedNEQ applies the NEQ predicate on the "auto_created" field. +func AutoCreatedNEQ(v bool) predicate.Bouncer { + return predicate.Bouncer(sql.FieldNEQ(FieldAutoCreated, v)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Bouncer) predicate.Bouncer { return predicate.Bouncer(sql.AndPredicates(predicates...)) diff --git a/pkg/database/ent/bouncer_create.go b/pkg/database/ent/bouncer_create.go index 29b23f87cf1..9ff4c0e0820 100644 --- a/pkg/database/ent/bouncer_create.go +++ b/pkg/database/ent/bouncer_create.go @@ -178,6 +178,20 @@ func (bc *BouncerCreate) SetNillableFeatureflags(s *string) *BouncerCreate { return bc } +// SetAutoCreated sets the "auto_created" field. +func (bc *BouncerCreate) SetAutoCreated(b bool) *BouncerCreate { + bc.mutation.SetAutoCreated(b) + return bc +} + +// SetNillableAutoCreated sets the "auto_created" field if the given value is not nil. +func (bc *BouncerCreate) SetNillableAutoCreated(b *bool) *BouncerCreate { + if b != nil { + bc.SetAutoCreated(*b) + } + return bc +} + // Mutation returns the BouncerMutation object of the builder. func (bc *BouncerCreate) Mutation() *BouncerMutation { return bc.mutation @@ -229,6 +243,10 @@ func (bc *BouncerCreate) defaults() { v := bouncer.DefaultAuthType bc.mutation.SetAuthType(v) } + if _, ok := bc.mutation.AutoCreated(); !ok { + v := bouncer.DefaultAutoCreated + bc.mutation.SetAutoCreated(v) + } } // check runs all checks and user-defined validators on the builder. @@ -251,6 +269,9 @@ func (bc *BouncerCreate) check() error { if _, ok := bc.mutation.AuthType(); !ok { return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Bouncer.auth_type"`)} } + if _, ok := bc.mutation.AutoCreated(); !ok { + return &ValidationError{Name: "auto_created", err: errors.New(`ent: missing required field "Bouncer.auto_created"`)} + } return nil } @@ -329,6 +350,10 @@ func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) { _spec.SetField(bouncer.FieldFeatureflags, field.TypeString, value) _node.Featureflags = value } + if value, ok := bc.mutation.AutoCreated(); ok { + _spec.SetField(bouncer.FieldAutoCreated, field.TypeBool, value) + _node.AutoCreated = value + } return _node, _spec } diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go index 986f5bc8c67..dae248c7f38 100644 --- a/pkg/database/ent/migrate/schema.go +++ b/pkg/database/ent/migrate/schema.go @@ -74,6 +74,7 @@ var ( {Name: "osname", Type: field.TypeString, Nullable: true}, {Name: "osversion", Type: field.TypeString, Nullable: true}, {Name: "featureflags", Type: field.TypeString, Nullable: true}, + {Name: "auto_created", Type: field.TypeBool, Default: false}, } // BouncersTable holds the schema information for the "bouncers" table. BouncersTable = &schema.Table{ diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go index 5c6596f3db4..fa1ccb3da58 100644 --- a/pkg/database/ent/mutation.go +++ b/pkg/database/ent/mutation.go @@ -2471,6 +2471,7 @@ type BouncerMutation struct { osname *string osversion *string featureflags *string + auto_created *bool clearedFields map[string]struct{} done bool oldValue func(context.Context) (*Bouncer, error) @@ -3134,6 +3135,42 @@ func (m *BouncerMutation) ResetFeatureflags() { delete(m.clearedFields, bouncer.FieldFeatureflags) } +// SetAutoCreated sets the "auto_created" field. +func (m *BouncerMutation) SetAutoCreated(b bool) { + m.auto_created = &b +} + +// AutoCreated returns the value of the "auto_created" field in the mutation. +func (m *BouncerMutation) AutoCreated() (r bool, exists bool) { + v := m.auto_created + if v == nil { + return + } + return *v, true +} + +// OldAutoCreated returns the old "auto_created" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldAutoCreated(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAutoCreated is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAutoCreated requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAutoCreated: %w", err) + } + return oldValue.AutoCreated, nil +} + +// ResetAutoCreated resets all changes to the "auto_created" field. +func (m *BouncerMutation) ResetAutoCreated() { + m.auto_created = nil +} + // Where appends a list predicates to the BouncerMutation builder. func (m *BouncerMutation) Where(ps ...predicate.Bouncer) { m.predicates = append(m.predicates, ps...) @@ -3168,7 +3205,7 @@ func (m *BouncerMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *BouncerMutation) Fields() []string { - fields := make([]string, 0, 13) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, bouncer.FieldCreatedAt) } @@ -3208,6 +3245,9 @@ func (m *BouncerMutation) Fields() []string { if m.featureflags != nil { fields = append(fields, bouncer.FieldFeatureflags) } + if m.auto_created != nil { + fields = append(fields, bouncer.FieldAutoCreated) + } return fields } @@ -3242,6 +3282,8 @@ func (m *BouncerMutation) Field(name string) (ent.Value, bool) { return m.Osversion() case bouncer.FieldFeatureflags: return m.Featureflags() + case bouncer.FieldAutoCreated: + return m.AutoCreated() } return nil, false } @@ -3277,6 +3319,8 @@ func (m *BouncerMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldOsversion(ctx) case bouncer.FieldFeatureflags: return m.OldFeatureflags(ctx) + case bouncer.FieldAutoCreated: + return m.OldAutoCreated(ctx) } return nil, fmt.Errorf("unknown Bouncer field %s", name) } @@ -3377,6 +3421,13 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error { } m.SetFeatureflags(v) return nil + case bouncer.FieldAutoCreated: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAutoCreated(v) + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -3510,6 +3561,9 @@ func (m *BouncerMutation) ResetField(name string) error { case bouncer.FieldFeatureflags: m.ResetFeatureflags() return nil + case bouncer.FieldAutoCreated: + m.ResetAutoCreated() + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go index 15413490633..49921a17b03 100644 --- a/pkg/database/ent/runtime.go +++ b/pkg/database/ent/runtime.go @@ -76,6 +76,10 @@ func init() { bouncerDescAuthType := bouncerFields[9].Descriptor() // bouncer.DefaultAuthType holds the default value on creation for the auth_type field. bouncer.DefaultAuthType = bouncerDescAuthType.Default.(string) + // bouncerDescAutoCreated is the schema descriptor for auto_created field. + bouncerDescAutoCreated := bouncerFields[13].Descriptor() + // bouncer.DefaultAutoCreated holds the default value on creation for the auto_created field. + bouncer.DefaultAutoCreated = bouncerDescAutoCreated.Default.(bool) configitemFields := schema.ConfigItem{}.Fields() _ = configitemFields // configitemDescCreatedAt is the schema descriptor for created_at field. diff --git a/pkg/database/ent/schema/bouncer.go b/pkg/database/ent/schema/bouncer.go index 599c4c404fc..c176bf0f766 100644 --- a/pkg/database/ent/schema/bouncer.go +++ b/pkg/database/ent/schema/bouncer.go @@ -33,6 +33,8 @@ func (Bouncer) Fields() []ent.Field { field.String("osname").Optional(), field.String("osversion").Optional(), field.String("featureflags").Optional(), + // Old auto-created TLS bouncers will have a wrong value for this field + field.Bool("auto_created").StructTag(`json:"auto_created"`).Default(false).Immutable(), } } diff --git a/pkg/database/errors.go b/pkg/database/errors.go index 8e96f52d7ce..77f92707e51 100644 --- a/pkg/database/errors.go +++ b/pkg/database/errors.go @@ -13,8 +13,8 @@ var ( ItemNotFound = errors.New("object not found") ParseTimeFail = errors.New("unable to parse time") ParseDurationFail = errors.New("unable to parse duration") - MarshalFail = errors.New("unable to marshal") - UnmarshalFail = errors.New("unable to unmarshal") + MarshalFail = errors.New("unable to serialize") + UnmarshalFail = errors.New("unable to parse") BulkError = errors.New("unable to insert bulk") ParseType = errors.New("unable to parse type") InvalidIPOrRange = errors.New("invalid ip address / range") diff --git a/pkg/database/flush.go b/pkg/database/flush.go index 5d53d10c942..8f646ddc961 100644 --- a/pkg/database/flush.go +++ b/pkg/database/flush.go @@ -1,6 +1,7 @@ package database import ( + "context" "errors" "fmt" "time" @@ -26,7 +27,7 @@ const ( flushInterval = 1 * time.Minute ) -func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { +func (c *Client) StartFlushScheduler(ctx context.Context, config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) { maxItems := 0 maxAge := "" @@ -45,7 +46,7 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched // Init & Start cronjob every minute for alerts scheduler := gocron.NewScheduler(time.UTC) - job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) + job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, ctx, maxAge, maxItems) if err != nil { return nil, fmt.Errorf("while starting FlushAlerts scheduler: %w", err) } @@ -100,14 +101,14 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched } } - baJob, err := scheduler.Every(flushInterval).Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) + baJob, err := scheduler.Every(flushInterval).Do(c.FlushAgentsAndBouncers, ctx, config.AgentsGC, config.BouncersGC) if err != nil { return nil, fmt.Errorf("while starting FlushAgentsAndBouncers scheduler: %w", err) } baJob.SingletonMode() - metricsJob, err := scheduler.Every(flushInterval).Do(c.flushMetrics, config.MetricsMaxAge) + metricsJob, err := scheduler.Every(flushInterval).Do(c.flushMetrics, ctx, config.MetricsMaxAge) if err != nil { return nil, fmt.Errorf("while starting flushMetrics scheduler: %w", err) } @@ -120,7 +121,7 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched } // flushMetrics deletes metrics older than maxAge, regardless if they have been pushed to CAPI or not -func (c *Client) flushMetrics(maxAge *time.Duration) { +func (c *Client) flushMetrics(ctx context.Context, maxAge *time.Duration) { if maxAge == nil { maxAge = ptr.Of(defaultMetricsMaxAge) } @@ -129,7 +130,7 @@ func (c *Client) flushMetrics(maxAge *time.Duration) { deleted, err := c.Ent.Metric.Delete().Where( metric.ReceivedAtLTE(time.Now().UTC().Add(-*maxAge)), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while flushing metrics: %s", err) return @@ -140,10 +141,10 @@ func (c *Client) flushMetrics(maxAge *time.Duration) { } } -func (c *Client) FlushOrphans() { +func (c *Client) FlushOrphans(ctx context.Context) { /* While it has only been linked to some very corner-case bug : https://github.com/crowdsecurity/crowdsec/issues/778 */ /* We want to take care of orphaned events for which the parent alert/decision has been deleted */ - eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(c.CTX) + eventsCount, err := c.Ent.Event.Delete().Where(event.Not(event.HasOwner())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan events: %s", err) return @@ -154,7 +155,7 @@ func (c *Client) FlushOrphans() { } eventsCount, err = c.Ent.Decision.Delete().Where( - decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(c.CTX) + decision.Not(decision.HasOwner())).Where(decision.UntilLTE(time.Now().UTC())).Exec(ctx) if err != nil { c.Log.Warningf("error while deleting orphan decisions: %s", err) return @@ -165,7 +166,7 @@ func (c *Client) FlushOrphans() { } } -func (c *Client) flushBouncers(authType string, duration *time.Duration) { +func (c *Client) flushBouncers(ctx context.Context, authType string, duration *time.Duration) { if duration == nil { return } @@ -174,7 +175,7 @@ func (c *Client) flushBouncers(authType string, duration *time.Duration) { bouncer.LastPullLTE(time.Now().UTC().Add(-*duration)), ).Where( bouncer.AuthTypeEQ(authType), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while auto-deleting expired bouncers (%s): %s", authType, err) return @@ -185,7 +186,7 @@ func (c *Client) flushBouncers(authType string, duration *time.Duration) { } } -func (c *Client) flushAgents(authType string, duration *time.Duration) { +func (c *Client) flushAgents(ctx context.Context, authType string, duration *time.Duration) { if duration == nil { return } @@ -194,7 +195,7 @@ func (c *Client) flushAgents(authType string, duration *time.Duration) { machine.LastHeartbeatLTE(time.Now().UTC().Add(-*duration)), machine.Not(machine.HasAlerts()), machine.AuthTypeEQ(authType), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { c.Log.Errorf("while auto-deleting expired machines (%s): %s", authType, err) return @@ -205,23 +206,23 @@ func (c *Client) flushAgents(authType string, duration *time.Duration) { } } -func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { +func (c *Client) FlushAgentsAndBouncers(ctx context.Context, agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { log.Debug("starting FlushAgentsAndBouncers") if agentsCfg != nil { - c.flushAgents(types.TlsAuthType, agentsCfg.CertDuration) - c.flushAgents(types.PasswordAuthType, agentsCfg.LoginPasswordDuration) + c.flushAgents(ctx, types.TlsAuthType, agentsCfg.CertDuration) + c.flushAgents(ctx, types.PasswordAuthType, agentsCfg.LoginPasswordDuration) } if bouncersCfg != nil { - c.flushBouncers(types.TlsAuthType, bouncersCfg.CertDuration) - c.flushBouncers(types.ApiKeyAuthType, bouncersCfg.ApiDuration) + c.flushBouncers(ctx, types.TlsAuthType, bouncersCfg.CertDuration) + c.flushBouncers(ctx, types.ApiKeyAuthType, bouncersCfg.ApiDuration) } return nil } -func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { +func (c *Client) FlushAlerts(ctx context.Context, MaxAge string, MaxItems int) error { var ( deletedByAge int deletedByNbItem int @@ -235,10 +236,10 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { } c.Log.Debug("Flushing orphan alerts") - c.FlushOrphans() + c.FlushOrphans(ctx) c.Log.Debug("Done flushing orphan alerts") - totalAlerts, err = c.TotalAlerts() + totalAlerts, err = c.TotalAlerts(ctx) if err != nil { c.Log.Warningf("FlushAlerts (max items count): %s", err) return fmt.Errorf("unable to get alerts count: %w", err) @@ -251,7 +252,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { "created_before": {MaxAge}, } - nbDeleted, err := c.DeleteAlertWithFilter(filter) + nbDeleted, err := c.DeleteAlertWithFilter(ctx, filter) if err != nil { c.Log.Warningf("FlushAlerts (max age): %s", err) return fmt.Errorf("unable to flush alerts with filter until=%s: %w", MaxAge, err) @@ -267,7 +268,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { // This gives us the oldest alert that we want to keep // We then delete all the alerts with an id lower than this one // We can do this because the id is auto-increment, and the database won't reuse the same id twice - lastAlert, err := c.QueryAlertWithFilter(map[string][]string{ + lastAlert, err := c.QueryAlertWithFilter(ctx, map[string][]string{ "sort": {"DESC"}, "limit": {"1"}, // we do not care about fetching the edges, we just want the id @@ -287,7 +288,7 @@ func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { if maxid > 0 { // This may lead to orphan alerts (at least on MySQL), but the next time the flush job will run, they will be deleted - deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(c.CTX) + deletedByNbItem, err = c.Ent.Alert.Delete().Where(alert.IDLT(maxid)).Exec(ctx) if err != nil { c.Log.Errorf("FlushAlerts: Could not delete alerts: %s", err) return fmt.Errorf("could not delete alerts: %w", err) diff --git a/pkg/database/lock.go b/pkg/database/lock.go index d25b71870f0..474228a069c 100644 --- a/pkg/database/lock.go +++ b/pkg/database/lock.go @@ -1,6 +1,7 @@ package database import ( + "context" "time" "github.com/pkg/errors" @@ -16,40 +17,45 @@ const ( CapiPullLockName = "pullCAPI" ) -func (c *Client) AcquireLock(name string) error { +func (c *Client) AcquireLock(ctx context.Context, name string) error { log.Debugf("acquiring lock %s", name) _, err := c.Ent.Lock.Create(). SetName(name). SetCreatedAt(types.UtcNow()). - Save(c.CTX) + Save(ctx) + if ent.IsConstraintError(err) { return err } + if err != nil { return errors.Wrapf(InsertFail, "insert lock: %s", err) } + return nil } -func (c *Client) ReleaseLock(name string) error { +func (c *Client) ReleaseLock(ctx context.Context, name string) error { log.Debugf("releasing lock %s", name) - _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(c.CTX) + _, err := c.Ent.Lock.Delete().Where(lock.NameEQ(name)).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) } + return nil } -func (c *Client) ReleaseLockWithTimeout(name string, timeout int) error { +func (c *Client) ReleaseLockWithTimeout(ctx context.Context, name string, timeout int) error { log.Debugf("releasing lock %s with timeout of %d minutes", name, timeout) + _, err := c.Ent.Lock.Delete().Where( lock.NameEQ(name), lock.CreatedAtLT(time.Now().UTC().Add(-time.Duration(timeout)*time.Minute)), - ).Exec(c.CTX) - + ).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) } + return nil } @@ -57,23 +63,25 @@ func (c *Client) IsLocked(err error) bool { return ent.IsConstraintError(err) } -func (c *Client) AcquirePullCAPILock() error { - - /*delete orphan "old" lock if present*/ - err := c.ReleaseLockWithTimeout(CapiPullLockName, CAPIPullLockTimeout) +func (c *Client) AcquirePullCAPILock(ctx context.Context) error { + // delete orphan "old" lock if present + err := c.ReleaseLockWithTimeout(ctx, CapiPullLockName, CAPIPullLockTimeout) if err != nil { log.Errorf("unable to release pullCAPI lock: %s", err) } - return c.AcquireLock(CapiPullLockName) + + return c.AcquireLock(ctx, CapiPullLockName) } -func (c *Client) ReleasePullCAPILock() error { +func (c *Client) ReleasePullCAPILock(ctx context.Context) error { log.Debugf("deleting lock %s", CapiPullLockName) + _, err := c.Ent.Lock.Delete().Where( lock.NameEQ(CapiPullLockName), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return errors.Wrapf(DeleteFail, "delete lock: %s", err) } + return nil } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 75b0ee5fdaa..d8c02825312 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "strings" "time" @@ -29,13 +30,13 @@ func (e *MachineNotFoundError) Error() string { return fmt.Sprintf("'%s' does not exist", e.MachineID) } -func (c *Client) MachineUpdateBaseMetrics(machineID string, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { +func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error { os := baseMetrics.Os features := strings.Join(baseMetrics.FeatureFlags, ",") var heartbeat time.Time - if baseMetrics.Metrics == nil || len(baseMetrics.Metrics) == 0 { + if len(baseMetrics.Metrics) == 0 { heartbeat = time.Now().UTC() } else { heartbeat = time.Unix(*baseMetrics.Metrics[0].Meta.UtcNowTimestamp, 0) @@ -63,7 +64,7 @@ func (c *Client) MachineUpdateBaseMetrics(machineID string, baseMetrics models.B SetLastHeartbeat(heartbeat). SetHubstate(hubState). SetDatasources(datasources). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update base machine metrics in database: %w", err) } @@ -71,7 +72,7 @@ func (c *Client) MachineUpdateBaseMetrics(machineID string, baseMetrics models.B return nil } -func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { +func (c *Client) CreateMachine(ctx context.Context, machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { c.Log.Warningf("CreateMachine: %s", err) @@ -81,20 +82,20 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA machineExist, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(*machineID)). - Select(machine.FieldMachineId).Strings(c.CTX) + Select(machine.FieldMachineId).Strings(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } if len(machineExist) > 0 { if force { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) } - machine, err := c.QueryMachineByID(*machineID) + machine, err := c.QueryMachineByID(ctx, *machineID) if err != nil { return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } @@ -112,7 +113,7 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA SetIpAddress(ipAddress). SetIsValidated(isValidated). SetAuthType(authType). - Save(c.CTX) + Save(ctx) if err != nil { c.Log.Warningf("CreateMachine : %s", err) return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) @@ -121,11 +122,11 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA return machine, nil } -func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { +func (c *Client) QueryMachineByID(ctx context.Context, machineID string) (*ent.Machine, error) { machine, err := c.Ent.Machine. Query(). Where(machine.MachineIdEQ(machineID)). - Only(c.CTX) + Only(ctx) if err != nil { c.Log.Warningf("QueryMachineByID : %s", err) return &ent.Machine{}, errors.Wrapf(UserNotExists, "user '%s'", machineID) @@ -134,8 +135,8 @@ func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { return machine, nil } -func (c *Client) ListMachines() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().All(c.CTX) +func (c *Client) ListMachines(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().All(ctx) if err != nil { return nil, errors.Wrapf(QueryFail, "listing machines: %s", err) } @@ -143,8 +144,8 @@ func (c *Client) ListMachines() ([]*ent.Machine, error) { return machines, nil } -func (c *Client) ValidateMachine(machineID string) error { - rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(c.CTX) +func (c *Client) ValidateMachine(ctx context.Context, machineID string) error { + rets, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetIsValidated(true).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "validating machine: %s", err) } @@ -156,8 +157,8 @@ func (c *Client) ValidateMachine(machineID string) error { return nil } -func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { - machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) +func (c *Client) QueryPendingMachine(ctx context.Context) ([]*ent.Machine, error) { + machines, err := c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(ctx) if err != nil { c.Log.Warningf("QueryPendingMachine : %s", err) return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err) @@ -166,11 +167,11 @@ func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { return machines, nil } -func (c *Client) DeleteWatcher(name string) error { +func (c *Client) DeleteWatcher(ctx context.Context, name string) error { nbDeleted, err := c.Ent.Machine. Delete(). Where(machine.MachineIdEQ(name)). - Exec(c.CTX) + Exec(ctx) if err != nil { return err } @@ -182,13 +183,13 @@ func (c *Client) DeleteWatcher(name string) error { return nil } -func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) { +func (c *Client) BulkDeleteWatchers(ctx context.Context, machines []*ent.Machine) (int, error) { ids := make([]int, len(machines)) for i, b := range machines { ids[i] = b.ID } - nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(c.CTX) + nbDeleted, err := c.Ent.Machine.Delete().Where(machine.IDIn(ids...)).Exec(ctx) if err != nil { return nbDeleted, err } @@ -196,8 +197,8 @@ func (c *Client) BulkDeleteWatchers(machines []*ent.Machine) (int, error) { return nbDeleted, nil } -func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { - _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(c.CTX) +func (c *Client) UpdateMachineLastHeartBeat(ctx context.Context, machineID string) error { + _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(machineID)).SetLastHeartbeat(time.Now().UTC()).Save(ctx) if err != nil { return errors.Wrapf(UpdateFail, "updating machine last_heartbeat: %s", err) } @@ -205,11 +206,11 @@ func (c *Client) UpdateMachineLastHeartBeat(machineID string) error { return nil } -func (c *Client) UpdateMachineScenarios(scenarios string, id int) error { +func (c *Client) UpdateMachineScenarios(ctx context.Context, scenarios string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetUpdatedAt(time.Now().UTC()). SetScenarios(scenarios). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine in database: %w", err) } @@ -217,10 +218,10 @@ func (c *Client) UpdateMachineScenarios(scenarios string, id int) error { return nil } -func (c *Client) UpdateMachineIP(ipAddr string, id int) error { +func (c *Client) UpdateMachineIP(ctx context.Context, ipAddr string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetIpAddress(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine IP in database: %w", err) } @@ -228,10 +229,10 @@ func (c *Client) UpdateMachineIP(ipAddr string, id int) error { return nil } -func (c *Client) UpdateMachineVersion(ipAddr string, id int) error { +func (c *Client) UpdateMachineVersion(ctx context.Context, ipAddr string, id int) error { _, err := c.Ent.Machine.UpdateOneID(id). SetVersion(ipAddr). - Save(c.CTX) + Save(ctx) if err != nil { return fmt.Errorf("unable to update machine version in database: %w", err) } @@ -239,8 +240,8 @@ func (c *Client) UpdateMachineVersion(ipAddr string, id int) error { return nil } -func (c *Client) IsMachineRegistered(machineID string) (bool, error) { - exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(c.CTX) +func (c *Client) IsMachineRegistered(ctx context.Context, machineID string) (bool, error) { + exist, err := c.Ent.Machine.Query().Where().Select(machine.FieldMachineId).Strings(ctx) if err != nil { return false, err } @@ -256,11 +257,11 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) { return false, nil } -func (c *Client) QueryMachinesInactiveSince(t time.Time) ([]*ent.Machine, error) { +func (c *Client) QueryMachinesInactiveSince(ctx context.Context, t time.Time) ([]*ent.Machine, error) { return c.Ent.Machine.Query().Where( machine.Or( machine.And(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)), machine.And(machine.LastHeartbeatIsNil(), machine.CreatedAtLT(t)), ), - ).All(c.CTX) + ).All(ctx) } diff --git a/pkg/database/metrics.go b/pkg/database/metrics.go index 7626c39f6f1..eb4c472821e 100644 --- a/pkg/database/metrics.go +++ b/pkg/database/metrics.go @@ -1,6 +1,7 @@ package database import ( + "context" "fmt" "time" @@ -8,15 +9,15 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/metric" ) -func (c *Client) CreateMetric(generatedType metric.GeneratedType, generatedBy string, receivedAt time.Time, payload string) (*ent.Metric, error) { +func (c *Client) CreateMetric(ctx context.Context, generatedType metric.GeneratedType, generatedBy string, receivedAt time.Time, payload string) (*ent.Metric, error) { metric, err := c.Ent.Metric. Create(). SetGeneratedType(generatedType). SetGeneratedBy(generatedBy). SetReceivedAt(receivedAt). SetPayload(payload). - Save(c.CTX) - if err != nil { + Save(ctx) + if err != nil { c.Log.Warningf("CreateMetric: %s", err) return nil, fmt.Errorf("storing metrics snapshot for '%s' at %s: %w", generatedBy, receivedAt, InsertFail) } @@ -24,14 +25,14 @@ func (c *Client) CreateMetric(generatedType metric.GeneratedType, generatedBy st return metric, nil } -func (c *Client) GetLPUsageMetricsByMachineID(machineId string) ([]*ent.Metric, error) { +func (c *Client) GetLPUsageMetricsByMachineID(ctx context.Context, machineId string) ([]*ent.Metric, error) { metrics, err := c.Ent.Metric.Query(). Where( metric.GeneratedTypeEQ(metric.GeneratedTypeLP), metric.GeneratedByEQ(machineId), metric.PushedAtIsNil(), ). - All(c.CTX) + All(ctx) if err != nil { c.Log.Warningf("GetLPUsageMetricsByOrigin: %s", err) return nil, fmt.Errorf("getting LP usage metrics by origin %s: %w", machineId, err) @@ -40,14 +41,14 @@ func (c *Client) GetLPUsageMetricsByMachineID(machineId string) ([]*ent.Metric, return metrics, nil } -func (c *Client) GetBouncerUsageMetricsByName(bouncerName string) ([]*ent.Metric, error) { +func (c *Client) GetBouncerUsageMetricsByName(ctx context.Context, bouncerName string) ([]*ent.Metric, error) { metrics, err := c.Ent.Metric.Query(). Where( metric.GeneratedTypeEQ(metric.GeneratedTypeRC), metric.GeneratedByEQ(bouncerName), metric.PushedAtIsNil(), ). - All(c.CTX) + All(ctx) if err != nil { c.Log.Warningf("GetBouncerUsageMetricsByName: %s", err) return nil, fmt.Errorf("getting bouncer usage metrics by name %s: %w", bouncerName, err) @@ -56,11 +57,11 @@ func (c *Client) GetBouncerUsageMetricsByName(bouncerName string) ([]*ent.Metric return metrics, nil } -func (c *Client) MarkUsageMetricsAsSent(ids []int) error { +func (c *Client) MarkUsageMetricsAsSent(ctx context.Context, ids []int) error { _, err := c.Ent.Metric.Update(). Where(metric.IDIn(ids...)). SetPushedAt(time.Now().UTC()). - Save(c.CTX) + Save(ctx) if err != nil { c.Log.Warningf("MarkUsageMetricsAsSent: %s", err) return fmt.Errorf("marking usage metrics as sent: %w", err) diff --git a/pkg/database/utils.go b/pkg/database/utils.go index f1c06565635..8148df56f24 100644 --- a/pkg/database/utils.go +++ b/pkg/database/utils.go @@ -42,7 +42,8 @@ func LastAddress(n *net.IPNet) net.IP { ip[6] | ^n.Mask[6], ip[7] | ^n.Mask[7], ip[8] | ^n.Mask[8], ip[9] | ^n.Mask[9], ip[10] | ^n.Mask[10], ip[11] | ^n.Mask[11], ip[12] | ^n.Mask[12], ip[13] | ^n.Mask[13], ip[14] | ^n.Mask[14], - ip[15] | ^n.Mask[15]} + ip[15] | ^n.Mask[15], + } } return net.IPv4( @@ -74,7 +75,7 @@ func ParseDuration(d string) (time.Duration, error) { if strings.HasSuffix(d, "d") { days := strings.Split(d, "d")[0] - if len(days) == 0 { + if days == "" { return 0, fmt.Errorf("'%s' can't be parsed as duration", d) } diff --git a/pkg/dumps/parser_dump.go b/pkg/dumps/parser_dump.go index d43f3cdc1b9..bc8f78dc203 100644 --- a/pkg/dumps/parser_dump.go +++ b/pkg/dumps/parser_dump.go @@ -259,7 +259,7 @@ func (t *tree) displayResults(opts DumpOpts) { } if updated > 0 { - if len(changeStr) > 0 { + if changeStr != "" { changeStr += " " } @@ -267,7 +267,7 @@ func (t *tree) displayResults(opts DumpOpts) { } if deleted > 0 { - if len(changeStr) > 0 { + if changeStr != "" { changeStr += " " } @@ -275,7 +275,7 @@ func (t *tree) displayResults(opts DumpOpts) { } if whitelisted { - if len(changeStr) > 0 { + if changeStr != "" { changeStr += " " } diff --git a/pkg/exprhelpers/debugger.go b/pkg/exprhelpers/debugger.go index 711aa491078..2e47af6d1de 100644 --- a/pkg/exprhelpers/debugger.go +++ b/pkg/exprhelpers/debugger.go @@ -53,9 +53,8 @@ type OpOutput struct { } func (o *OpOutput) String() string { - ret := fmt.Sprintf("%*c", o.CodeDepth, ' ') - if len(o.Code) != 0 { + if o.Code != "" { ret += fmt.Sprintf("[%s]", o.Code) } ret += " " @@ -70,7 +69,7 @@ func (o *OpOutput) String() string { indent = 0 } ret = fmt.Sprintf("%*cBLOCK_END [%s]", indent, ' ', o.Code) - if len(o.StrConditionResult) > 0 { + if o.StrConditionResult != "" { ret += fmt.Sprintf(" -> %s", o.StrConditionResult) } return ret diff --git a/pkg/exprhelpers/debugger_test.go b/pkg/exprhelpers/debugger_test.go index efdcbc1a769..32144454084 100644 --- a/pkg/exprhelpers/debugger_test.go +++ b/pkg/exprhelpers/debugger_test.go @@ -26,6 +26,7 @@ type ExprDbgTest struct { func UpperTwo(params ...any) (any, error) { s := params[0].(string) v := params[1].(string) + return strings.ToUpper(s) + strings.ToUpper(v), nil } @@ -33,6 +34,7 @@ func UpperThree(params ...any) (any, error) { s := params[0].(string) v := params[1].(string) x := params[2].(string) + return strings.ToUpper(s) + strings.ToUpper(v) + strings.ToUpper(x), nil } @@ -41,6 +43,7 @@ func UpperN(params ...any) (any, error) { v := params[1].(string) x := params[2].(string) y := params[3].(string) + return strings.ToUpper(s) + strings.ToUpper(v) + strings.ToUpper(x) + strings.ToUpper(y), nil } @@ -76,9 +79,9 @@ func TestBaseDbg(t *testing.T) { // use '%#v' to dump in golang syntax // use regexp to clear empty/default fields: // [a-z]+: (false|\[\]string\(nil\)|""), - //ConditionResult:(*bool) + // ConditionResult:(*bool) - //Missing multi parametes function + // Missing multi parametes function tests := []ExprDbgTest{ { Name: "nil deref", @@ -272,6 +275,7 @@ func TestBaseDbg(t *testing.T) { } logger := log.WithField("test", "exprhelpers") + for _, test := range tests { if test.LogLevel != 0 { log.SetLevel(test.LogLevel) @@ -308,10 +312,13 @@ func TestBaseDbg(t *testing.T) { t.Fatalf("test %s : unexpected compile error : %s", test.Name, err) } } + if test.Name == "nil deref" { test.Env["nilvar"] = nil } + outdbg, ret, err := RunWithDebug(prog, test.Env, logger) + if test.ExpectedFailRuntime { if err == nil { t.Fatalf("test %s : expected runtime error", test.Name) @@ -321,25 +328,30 @@ func TestBaseDbg(t *testing.T) { t.Fatalf("test %s : unexpected runtime error : %s", test.Name, err) } } + log.SetLevel(log.DebugLevel) DisplayExprDebug(prog, outdbg, logger, ret) + if len(outdbg) != len(test.ExpectedOutputs) { t.Errorf("failed test %s", test.Name) t.Errorf("%#v", outdbg) - //out, _ := yaml.Marshal(outdbg) - //fmt.Printf("%s", string(out)) + // out, _ := yaml.Marshal(outdbg) + // fmt.Printf("%s", string(out)) t.Fatalf("test %s : expected %d outputs, got %d", test.Name, len(test.ExpectedOutputs), len(outdbg)) - } + for i, out := range outdbg { - if !reflect.DeepEqual(out, test.ExpectedOutputs[i]) { - spew.Config.DisableMethods = true - t.Errorf("failed test %s", test.Name) - t.Errorf("expected : %#v", test.ExpectedOutputs[i]) - t.Errorf("got : %#v", out) - t.Fatalf("%d/%d : mismatch", i, len(outdbg)) + if reflect.DeepEqual(out, test.ExpectedOutputs[i]) { + // DisplayExprDebug(prog, outdbg, logger, ret) + continue } - //DisplayExprDebug(prog, outdbg, logger, ret) + + spew.Config.DisableMethods = true + + t.Errorf("failed test %s", test.Name) + t.Errorf("expected : %#v", test.ExpectedOutputs[i]) + t.Errorf("got : %#v", out) + t.Fatalf("%d/%d : mismatch", i, len(outdbg)) } } } diff --git a/pkg/exprhelpers/helpers.go b/pkg/exprhelpers/helpers.go index 17ce468f623..9bc991a8f2d 100644 --- a/pkg/exprhelpers/helpers.go +++ b/pkg/exprhelpers/helpers.go @@ -2,6 +2,7 @@ package exprhelpers import ( "bufio" + "context" "encoding/base64" "errors" "fmt" @@ -128,7 +129,7 @@ func Init(databaseClient *database.Client) error { dataFileRegex = make(map[string][]*regexp.Regexp) dataFileRe2 = make(map[string][]*re2.Regexp) dbClient = databaseClient - + XMLCacheInit() return nil } @@ -213,7 +214,7 @@ func FileInit(fileFolder string, filename string, fileType string) error { if strings.HasPrefix(scanner.Text(), "#") { // allow comments continue } - if len(scanner.Text()) == 0 { //skip empty lines + if scanner.Text() == "" { //skip empty lines continue } @@ -254,7 +255,6 @@ func Distinct(params ...any) (any, error) { } } return ret, nil - } func FlattenDistinct(params ...any) (any, error) { @@ -280,6 +280,7 @@ func flatten(args []interface{}, v reflect.Value) []interface{} { return args } + func existsInFileMaps(filename string, ftype string) (bool, error) { ok := false var err error @@ -592,7 +593,10 @@ func GetDecisionsCount(params ...any) (any, error) { return 0, nil } - count, err := dbClient.CountDecisionsByValue(value) + + ctx := context.TODO() + + count, err := dbClient.CountDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -613,8 +617,11 @@ func GetDecisionsSinceCount(params ...any) (any, error) { log.Errorf("Failed to parse since parameter '%s' : %s", since, err) return 0, nil } + + ctx := context.TODO() sinceTime := time.Now().UTC().Add(-sinceDuration) - count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime) + + count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -628,7 +635,8 @@ func GetActiveDecisionsCount(params ...any) (any, error) { log.Error("No database config to call GetActiveDecisionsCount()") return 0, nil } - count, err := dbClient.CountActiveDecisionsByValue(value) + ctx := context.TODO() + count, err := dbClient.CountActiveDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get active decisions count from value '%s'", value) return 0, err @@ -642,7 +650,8 @@ func GetActiveDecisionsTimeLeft(params ...any) (any, error) { log.Error("No database config to call GetActiveDecisionsTimeLeft()") return 0, nil } - timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(value) + ctx := context.TODO() + timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value) if err != nil { log.Errorf("Failed to get active decisions time left from value '%s'", value) return 0, err @@ -765,7 +774,6 @@ func B64Decode(params ...any) (any, error) { } func ParseKV(params ...any) (any, error) { - blob := params[0].(string) target := params[1].(map[string]interface{}) prefix := params[2].(string) diff --git a/pkg/exprhelpers/xml.go b/pkg/exprhelpers/xml.go index 75758e18316..0b550bdb641 100644 --- a/pkg/exprhelpers/xml.go +++ b/pkg/exprhelpers/xml.go @@ -1,43 +1,103 @@ package exprhelpers import ( + "errors" + "sync" + "time" + "github.com/beevik/etree" + "github.com/bluele/gcache" + "github.com/cespare/xxhash/v2" log "github.com/sirupsen/logrus" ) -var pathCache = make(map[string]etree.Path) +var ( + pathCache = make(map[string]etree.Path) + rwMutex = sync.RWMutex{} + xmlDocumentCache gcache.Cache +) + +func compileOrGetPath(path string) (etree.Path, error) { + rwMutex.RLock() + compiledPath, ok := pathCache[path] + rwMutex.RUnlock() + + if !ok { + var err error + compiledPath, err = etree.CompilePath(path) + if err != nil { + return etree.Path{}, err + } + + rwMutex.Lock() + pathCache[path] = compiledPath + rwMutex.Unlock() + } + + return compiledPath, nil +} + +func getXMLDocumentFromCache(xmlString string) (*etree.Document, error) { + cacheKey := xxhash.Sum64String(xmlString) + cacheObj, err := xmlDocumentCache.Get(cacheKey) + + if err != nil && !errors.Is(err, gcache.KeyNotFoundError) { + return nil, err + } + + doc, ok := cacheObj.(*etree.Document) + if !ok || cacheObj == nil { + doc = etree.NewDocument() + if err := doc.ReadFromString(xmlString); err != nil { + return nil, err + } + if err := xmlDocumentCache.Set(cacheKey, doc); err != nil { + log.Warnf("Could not set XML document in cache: %s", err) + } + } + + return doc, nil +} + +func XMLCacheInit() { + gc := gcache.New(50) + // Short cache expiration because we each line we read is different, but we can call multiple times XML helpers on each of them + gc.Expiration(5 * time.Second) + gc = gc.LRU() + + xmlDocumentCache = gc.Build() +} // func XMLGetAttributeValue(xmlString string, path string, attributeName string) string { func XMLGetAttributeValue(params ...any) (any, error) { xmlString := params[0].(string) path := params[1].(string) attributeName := params[2].(string) - if _, ok := pathCache[path]; !ok { - compiledPath, err := etree.CompilePath(path) - if err != nil { - log.Errorf("Could not compile path %s: %s", path, err) - return "", nil - } - pathCache[path] = compiledPath + + compiledPath, err := compileOrGetPath(path) + if err != nil { + log.Errorf("Could not compile path %s: %s", path, err) + return "", nil } - compiledPath := pathCache[path] - doc := etree.NewDocument() - err := doc.ReadFromString(xmlString) + doc, err := getXMLDocumentFromCache(xmlString) if err != nil { log.Tracef("Could not parse XML: %s", err) return "", nil } + elem := doc.FindElementPath(compiledPath) if elem == nil { log.Debugf("Could not find element %s", path) return "", nil } + attr := elem.SelectAttr(attributeName) if attr == nil { log.Debugf("Could not find attribute %s", attributeName) return "", nil } + return attr.Value, nil } @@ -45,26 +105,24 @@ func XMLGetAttributeValue(params ...any) (any, error) { func XMLGetNodeValue(params ...any) (any, error) { xmlString := params[0].(string) path := params[1].(string) - if _, ok := pathCache[path]; !ok { - compiledPath, err := etree.CompilePath(path) - if err != nil { - log.Errorf("Could not compile path %s: %s", path, err) - return "", nil - } - pathCache[path] = compiledPath + + compiledPath, err := compileOrGetPath(path) + if err != nil { + log.Errorf("Could not compile path %s: %s", path, err) + return "", nil } - compiledPath := pathCache[path] - doc := etree.NewDocument() - err := doc.ReadFromString(xmlString) + doc, err := getXMLDocumentFromCache(xmlString) if err != nil { log.Tracef("Could not parse XML: %s", err) return "", nil } + elem := doc.FindElementPath(compiledPath) if elem == nil { log.Debugf("Could not find element %s", path) return "", nil } + return elem.Text(), nil } diff --git a/pkg/fflag/features.go b/pkg/fflag/features.go index 3a106984a66..c8a3d7755ea 100644 --- a/pkg/fflag/features.go +++ b/pkg/fflag/features.go @@ -97,7 +97,7 @@ type FeatureRegister struct { features map[string]*Feature } -var featureNameRexp = regexp.MustCompile(`^[a-z0-9_\.]+$`) +var featureNameRexp = regexp.MustCompile(`^[a-z0-9_.]+$`) func validateFeatureName(featureName string) error { if featureName == "" { diff --git a/pkg/hubtest/coverage.go b/pkg/hubtest/coverage.go index 4156def06d7..e42c1e23455 100644 --- a/pkg/hubtest/coverage.go +++ b/pkg/hubtest/coverage.go @@ -57,7 +57,7 @@ func (h *HubTest) GetAppsecCoverage() ([]Coverage, error) { err = yaml.Unmarshal(yamlFile, configFileData) if err != nil { - return nil, fmt.Errorf("unmarshal: %v", err) + return nil, fmt.Errorf("parsing: %v", err) } for _, appsecRulesFile := range configFileData.AppsecRules { @@ -70,7 +70,7 @@ func (h *HubTest) GetAppsecCoverage() ([]Coverage, error) { err = yaml.Unmarshal(yamlFile, appsecRuleData) if err != nil { - return nil, fmt.Errorf("unmarshal: %v", err) + return nil, fmt.Errorf("parsing: %v", err) } appsecRuleName := appsecRuleData.Name diff --git a/pkg/hubtest/hubtest.go b/pkg/hubtest/hubtest.go index a4ca275c310..93f5abaa879 100644 --- a/pkg/hubtest/hubtest.go +++ b/pkg/hubtest/hubtest.go @@ -83,7 +83,7 @@ func NewHubTest(hubPath string, crowdsecPath string, cscliPath string, isAppsecT } if isAppsecTest { - HubTestPath := filepath.Join(hubPath, "./.appsec-tests/") + HubTestPath := filepath.Join(hubPath, ".appsec-tests") hubIndexFile := filepath.Join(hubPath, ".index.json") local := &csconfig.LocalHubCfg{ @@ -119,7 +119,7 @@ func NewHubTest(hubPath string, crowdsecPath string, cscliPath string, isAppsecT }, nil } - HubTestPath := filepath.Join(hubPath, "./.tests/") + HubTestPath := filepath.Join(hubPath, ".tests") hubIndexFile := filepath.Join(hubPath, ".index.json") diff --git a/pkg/hubtest/hubtest_item.go b/pkg/hubtest/hubtest_item.go index da4969ee8dd..bc9c8955d0d 100644 --- a/pkg/hubtest/hubtest_item.go +++ b/pkg/hubtest/hubtest_item.go @@ -111,7 +111,7 @@ func NewTest(name string, hubTest *HubTest) (*HubTestItem, error) { err = yaml.Unmarshal(yamlFile, configFileData) if err != nil { - return nil, fmt.Errorf("unmarshal: %w", err) + return nil, fmt.Errorf("parsing: %w", err) } parserAssertFilePath := filepath.Join(testPath, ParserAssertFileName) @@ -201,7 +201,7 @@ func (t *HubTestItem) InstallHub() error { b, err := yaml.Marshal(n) if err != nil { - return fmt.Errorf("unable to marshal overrides: %w", err) + return fmt.Errorf("unable to serialize overrides: %w", err) } tgtFilename := fmt.Sprintf("%s/parsers/s00-raw/00_overrides.yaml", t.RuntimePath) @@ -223,39 +223,30 @@ func (t *HubTestItem) InstallHub() error { ctx := context.Background() // install data for parsers if needed - ret := hub.GetItemMap(cwhub.PARSERS) - for parserName, item := range ret { - if item.State.Installed { - if err := item.DownloadDataIfNeeded(ctx, true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", parserName, err) - } - - log.Debugf("parser '%s' installed successfully in runtime environment", parserName) + for _, item := range hub.GetInstalledByType(cwhub.PARSERS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for parser '%s': %+v", item.Name, err) } + + log.Debugf("parser '%s' installed successfully in runtime environment", item.Name) } // install data for scenarios if needed - ret = hub.GetItemMap(cwhub.SCENARIOS) - for scenarioName, item := range ret { - if item.State.Installed { - if err := item.DownloadDataIfNeeded(ctx, true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", scenarioName, err) - } - - log.Debugf("scenario '%s' installed successfully in runtime environment", scenarioName) + for _, item := range hub.GetInstalledByType(cwhub.SCENARIOS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for parser '%s': %+v", item.Name, err) } + + log.Debugf("scenario '%s' installed successfully in runtime environment", item.Name) } // install data for postoverflows if needed - ret = hub.GetItemMap(cwhub.POSTOVERFLOWS) - for postoverflowName, item := range ret { - if item.State.Installed { - if err := item.DownloadDataIfNeeded(ctx, true); err != nil { - return fmt.Errorf("unable to download data for parser '%s': %+v", postoverflowName, err) - } - - log.Debugf("postoverflow '%s' installed successfully in runtime environment", postoverflowName) + for _, item := range hub.GetInstalledByType(cwhub.POSTOVERFLOWS, true) { + if err := item.DownloadDataIfNeeded(ctx, true); err != nil { + return fmt.Errorf("unable to download data for parser '%s': %+v", item.Name, err) } + + log.Debugf("postoverflow '%s' installed successfully in runtime environment", item.Name) } return nil diff --git a/pkg/hubtest/nucleirunner.go b/pkg/hubtest/nucleirunner.go index 0bf2013dd8d..32c81eb64d8 100644 --- a/pkg/hubtest/nucleirunner.go +++ b/pkg/hubtest/nucleirunner.go @@ -42,11 +42,11 @@ func (nc *NucleiConfig) RunNucleiTemplate(testName string, templatePath string, err := cmd.Run() - if err := os.WriteFile(outputPrefix+"_stdout.txt", out.Bytes(), 0644); err != nil { + if err := os.WriteFile(outputPrefix+"_stdout.txt", out.Bytes(), 0o644); err != nil { log.Warningf("Error writing stdout: %s", err) } - if err := os.WriteFile(outputPrefix+"_stderr.txt", outErr.Bytes(), 0644); err != nil { + if err := os.WriteFile(outputPrefix+"_stderr.txt", outErr.Bytes(), 0o644); err != nil { log.Warningf("Error writing stderr: %s", err) } @@ -56,7 +56,7 @@ func (nc *NucleiConfig) RunNucleiTemplate(testName string, templatePath string, log.Warningf("Stderr saved to %s", outputPrefix+"_stderr.txt") log.Warningf("Nuclei generated output saved to %s", outputPrefix+".json") return err - } else if len(out.String()) == 0 { + } else if out.String() == "" { log.Warningf("Stdout saved to %s", outputPrefix+"_stdout.txt") log.Warningf("Stderr saved to %s", outputPrefix+"_stderr.txt") log.Warningf("Nuclei generated output saved to %s", outputPrefix+".json") diff --git a/pkg/hubtest/regexp.go b/pkg/hubtest/regexp.go index f9165eae3d1..8b2fcc928dd 100644 --- a/pkg/hubtest/regexp.go +++ b/pkg/hubtest/regexp.go @@ -5,7 +5,7 @@ import ( ) var ( - variableRE = regexp.MustCompile(`(?P[^ =]+) == .*`) - parserResultRE = regexp.MustCompile(`^results\["[^"]+"\]\["(?P[^"]+)"\]\[[0-9]+\]\.Evt\..*`) + variableRE = regexp.MustCompile(`(?P[^ =]+) == .*`) + parserResultRE = regexp.MustCompile(`^results\["[^"]+"\]\["(?P[^"]+)"\]\[[0-9]+\]\.Evt\..*`) scenarioResultRE = regexp.MustCompile(`^results\[[0-9]+\].Overflow.Alert.GetScenario\(\) == "(?P[^"]+)"`) ) diff --git a/pkg/hubtest/utils.go b/pkg/hubtest/utils.go index a7373fcc0bf..b42a73461f3 100644 --- a/pkg/hubtest/utils.go +++ b/pkg/hubtest/utils.go @@ -91,7 +91,7 @@ func CopyDir(src string, dest string) error { return errors.New("Source " + file.Name() + " is not a directory!") } - err = os.MkdirAll(dest, 0755) + err = os.MkdirAll(dest, 0o755) if err != nil { return err } diff --git a/pkg/leakybucket/buckets_test.go b/pkg/leakybucket/buckets_test.go index 989e03944c3..1da906cb555 100644 --- a/pkg/leakybucket/buckets_test.go +++ b/pkg/leakybucket/buckets_test.go @@ -136,7 +136,7 @@ func testOneBucket(t *testing.T, hub *cwhub.Hub, dir string, tomb *tomb.Tomb) er } if err := yaml.UnmarshalStrict(out.Bytes(), &stages); err != nil { - t.Fatalf("failed unmarshaling %s : %s", stagecfg, err) + t.Fatalf("failed to parse %s : %s", stagecfg, err) } files := []string{} @@ -201,7 +201,7 @@ func testFile(t *testing.T, file string, bs string, holders []BucketFactory, res var ts time.Time if err := ts.UnmarshalText([]byte(in.MarshaledTime)); err != nil { - t.Fatalf("Failed to unmarshal time from input event : %s", err) + t.Fatalf("Failed to parse time from input event : %s", err) } if latest_ts.IsZero() { diff --git a/pkg/leakybucket/manager_load.go b/pkg/leakybucket/manager_load.go index ca2e4d17d99..b8310b8cb17 100644 --- a/pkg/leakybucket/manager_load.go +++ b/pkg/leakybucket/manager_load.go @@ -22,7 +22,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/alertcontext" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwhub" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -45,12 +45,12 @@ type BucketFactory struct { Debug bool `yaml:"debug"` // Debug, when set to true, will enable debugging for _this_ scenario specifically Labels map[string]interface{} `yaml:"labels"` // Labels is K:V list aiming at providing context the overflow Blackhole string `yaml:"blackhole,omitempty"` // Blackhole is a duration that, if present, will prevent same bucket partition to overflow more often than $duration - logger *log.Entry `yaml:"-"` // logger is bucket-specific logger (used by Debug as well) - Reprocess bool `yaml:"reprocess"` // Reprocess, if true, will for the bucket to be re-injected into processing chain - CacheSize int `yaml:"cache_size"` // CacheSize, if > 0, limits the size of in-memory cache of the bucket - Profiling bool `yaml:"profiling"` // Profiling, if true, will make the bucket record pours/overflows/etc. - OverflowFilter string `yaml:"overflow_filter"` // OverflowFilter if present, is a filter that must return true for the overflow to go through - ConditionalOverflow string `yaml:"condition"` // condition if present, is an expression that must return true for the bucket to overflow + logger *log.Entry // logger is bucket-specific logger (used by Debug as well) + Reprocess bool `yaml:"reprocess"` // Reprocess, if true, will for the bucket to be re-injected into processing chain + CacheSize int `yaml:"cache_size"` // CacheSize, if > 0, limits the size of in-memory cache of the bucket + Profiling bool `yaml:"profiling"` // Profiling, if true, will make the bucket record pours/overflows/etc. + OverflowFilter string `yaml:"overflow_filter"` // OverflowFilter if present, is a filter that must return true for the overflow to go through + ConditionalOverflow string `yaml:"condition"` // condition if present, is an expression that must return true for the bucket to overflow BayesianPrior float32 `yaml:"bayesian_prior"` BayesianThreshold float32 `yaml:"bayesian_threshold"` BayesianConditions []RawBayesianCondition `yaml:"bayesian_conditions"` // conditions for the bayesian bucket @@ -68,95 +68,136 @@ type BucketFactory struct { processors []Processor // processors is the list of hooks for pour/overflow/create (cf. uniq, blackhole etc.) output bool // ?? ScenarioVersion string `yaml:"version,omitempty"` - hash string `yaml:"-"` - Simulated bool `yaml:"simulated"` // Set to true if the scenario instantiating the bucket was in the exclusion list - tomb *tomb.Tomb `yaml:"-"` - wgPour *sync.WaitGroup `yaml:"-"` - wgDumpState *sync.WaitGroup `yaml:"-"` + hash string + Simulated bool `yaml:"simulated"` // Set to true if the scenario instantiating the bucket was in the exclusion list + tomb *tomb.Tomb + wgPour *sync.WaitGroup + wgDumpState *sync.WaitGroup orderEvent bool } // we use one NameGenerator for all the future buckets var seed namegenerator.Generator = namegenerator.NewNameGenerator(time.Now().UTC().UnixNano()) -func ValidateFactory(bucketFactory *BucketFactory) error { - if bucketFactory.Name == "" { - return errors.New("bucket must have name") +func validateLeakyType(bucketFactory *BucketFactory) error { + if bucketFactory.Capacity <= 0 { // capacity must be a positive int + return fmt.Errorf("bad capacity for leaky '%d'", bucketFactory.Capacity) } - if bucketFactory.Description == "" { - return errors.New("description is mandatory") + if bucketFactory.LeakSpeed == "" { + return errors.New("leakspeed can't be empty for leaky") } - if bucketFactory.Type == "leaky" { - if bucketFactory.Capacity <= 0 { // capacity must be a positive int - return fmt.Errorf("bad capacity for leaky '%d'", bucketFactory.Capacity) - } + if bucketFactory.leakspeed == 0 { + return fmt.Errorf("bad leakspeed for leaky '%s'", bucketFactory.LeakSpeed) + } - if bucketFactory.LeakSpeed == "" { - return errors.New("leakspeed can't be empty for leaky") - } + return nil +} - if bucketFactory.leakspeed == 0 { - return fmt.Errorf("bad leakspeed for leaky '%s'", bucketFactory.LeakSpeed) - } - } else if bucketFactory.Type == "counter" { - if bucketFactory.Duration == "" { - return errors.New("duration can't be empty for counter") - } +func validateCounterType(bucketFactory *BucketFactory) error { + if bucketFactory.Duration == "" { + return errors.New("duration can't be empty for counter") + } - if bucketFactory.duration == 0 { - return fmt.Errorf("bad duration for counter bucket '%d'", bucketFactory.duration) - } + if bucketFactory.duration == 0 { + return fmt.Errorf("bad duration for counter bucket '%d'", bucketFactory.duration) + } - if bucketFactory.Capacity != -1 { - return errors.New("counter bucket must have -1 capacity") - } - } else if bucketFactory.Type == "trigger" { - if bucketFactory.Capacity != 0 { - return errors.New("trigger bucket must have 0 capacity") - } - } else if bucketFactory.Type == "conditional" { - if bucketFactory.ConditionalOverflow == "" { - return errors.New("conditional bucket must have a condition") - } + if bucketFactory.Capacity != -1 { + return errors.New("counter bucket must have -1 capacity") + } - if bucketFactory.Capacity != -1 { - bucketFactory.logger.Warnf("Using a value different than -1 as capacity for conditional bucket, this may lead to unexpected overflows") - } + return nil +} - if bucketFactory.LeakSpeed == "" { - return errors.New("leakspeed can't be empty for conditional bucket") - } +func validateTriggerType(bucketFactory *BucketFactory) error { + if bucketFactory.Capacity != 0 { + return errors.New("trigger bucket must have 0 capacity") + } - if bucketFactory.leakspeed == 0 { - return fmt.Errorf("bad leakspeed for conditional bucket '%s'", bucketFactory.LeakSpeed) - } - } else if bucketFactory.Type == "bayesian" { - if bucketFactory.BayesianConditions == nil { - return errors.New("bayesian bucket must have bayesian conditions") - } + return nil +} - if bucketFactory.BayesianPrior == 0 { - return errors.New("bayesian bucket must have a valid, non-zero prior") - } +func validateConditionalType(bucketFactory *BucketFactory) error { + if bucketFactory.ConditionalOverflow == "" { + return errors.New("conditional bucket must have a condition") + } - if bucketFactory.BayesianThreshold == 0 { - return errors.New("bayesian bucket must have a valid, non-zero threshold") - } + if bucketFactory.Capacity != -1 { + bucketFactory.logger.Warnf("Using a value different than -1 as capacity for conditional bucket, this may lead to unexpected overflows") + } - if bucketFactory.BayesianPrior > 1 { - return errors.New("bayesian bucket must have a valid, non-zero prior") - } + if bucketFactory.LeakSpeed == "" { + return errors.New("leakspeed can't be empty for conditional bucket") + } - if bucketFactory.BayesianThreshold > 1 { - return errors.New("bayesian bucket must have a valid, non-zero threshold") - } + if bucketFactory.leakspeed == 0 { + return fmt.Errorf("bad leakspeed for conditional bucket '%s'", bucketFactory.LeakSpeed) + } + + return nil +} + +func validateBayesianType(bucketFactory *BucketFactory) error { + if bucketFactory.BayesianConditions == nil { + return errors.New("bayesian bucket must have bayesian conditions") + } + + if bucketFactory.BayesianPrior == 0 { + return errors.New("bayesian bucket must have a valid, non-zero prior") + } + + if bucketFactory.BayesianThreshold == 0 { + return errors.New("bayesian bucket must have a valid, non-zero threshold") + } + + if bucketFactory.BayesianPrior > 1 { + return errors.New("bayesian bucket must have a valid, non-zero prior") + } + + if bucketFactory.BayesianThreshold > 1 { + return errors.New("bayesian bucket must have a valid, non-zero threshold") + } + + if bucketFactory.Capacity != -1 { + return errors.New("bayesian bucket must have capacity -1") + } + + return nil +} + +func ValidateFactory(bucketFactory *BucketFactory) error { + if bucketFactory.Name == "" { + return errors.New("bucket must have name") + } + + if bucketFactory.Description == "" { + return errors.New("description is mandatory") + } - if bucketFactory.Capacity != -1 { - return errors.New("bayesian bucket must have capacity -1") + switch bucketFactory.Type { + case "leaky": + if err := validateLeakyType(bucketFactory); err != nil { + return err } - } else { + case "counter": + if err := validateCounterType(bucketFactory); err != nil { + return err + } + case "trigger": + if err := validateTriggerType(bucketFactory); err != nil { + return err + } + case "conditional": + if err := validateConditionalType(bucketFactory); err != nil { + return err + } + case "bayesian": + if err := validateBayesianType(bucketFactory); err != nil { + return err + } + default: return fmt.Errorf("unknown bucket type '%s'", bucketFactory.Type) } @@ -230,8 +271,8 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str err = dec.Decode(&bucketFactory) if err != nil { if !errors.Is(err, io.EOF) { - log.Errorf("Bad yaml in %s : %v", f, err) - return nil, nil, fmt.Errorf("bad yaml in %s : %v", f, err) + log.Errorf("Bad yaml in %s: %v", f, err) + return nil, nil, fmt.Errorf("bad yaml in %s: %w", f, err) } log.Tracef("End of yaml file") @@ -251,13 +292,13 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str bucketFactory.FormatVersion = "1.0" } - ok, err := cwversion.Satisfies(bucketFactory.FormatVersion, cwversion.Constraint_scenario) + ok, err := constraint.Satisfies(bucketFactory.FormatVersion, constraint.Scenario) if err != nil { return nil, nil, fmt.Errorf("failed to check version: %w", err) } if !ok { - log.Errorf("can't load %s : %s doesn't satisfy scenario format %s, skip", bucketFactory.Name, bucketFactory.FormatVersion, cwversion.Constraint_scenario) + log.Errorf("can't load %s : %s doesn't satisfy scenario format %s, skip", bucketFactory.Name, bucketFactory.FormatVersion, constraint.Scenario) continue } @@ -282,8 +323,8 @@ func LoadBuckets(cscfg *csconfig.CrowdsecServiceCfg, hub *cwhub.Hub, files []str err = LoadBucket(&bucketFactory, tomb) if err != nil { - log.Errorf("Failed to load bucket %s : %v", bucketFactory.Name, err) - return nil, nil, fmt.Errorf("loading of %s failed : %v", bucketFactory.Name, err) + log.Errorf("Failed to load bucket %s: %v", bucketFactory.Name, err) + return nil, nil, fmt.Errorf("loading of %s failed: %w", bucketFactory.Name, err) } bucketFactory.orderEvent = orderEvent @@ -326,7 +367,7 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { if bucketFactory.LeakSpeed != "" { if bucketFactory.leakspeed, err = time.ParseDuration(bucketFactory.LeakSpeed); err != nil { - return fmt.Errorf("bad leakspeed '%s' in %s : %v", bucketFactory.LeakSpeed, bucketFactory.Filename, err) + return fmt.Errorf("bad leakspeed '%s' in %s: %w", bucketFactory.LeakSpeed, bucketFactory.Filename, err) } } else { bucketFactory.leakspeed = time.Duration(0) @@ -334,7 +375,7 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { if bucketFactory.Duration != "" { if bucketFactory.duration, err = time.ParseDuration(bucketFactory.Duration); err != nil { - return fmt.Errorf("invalid Duration '%s' in %s : %v", bucketFactory.Duration, bucketFactory.Filename, err) + return fmt.Errorf("invalid Duration '%s' in %s: %w", bucketFactory.Duration, bucketFactory.Filename, err) } } @@ -345,13 +386,13 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { bucketFactory.RunTimeFilter, err = expr.Compile(bucketFactory.Filter, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { - return fmt.Errorf("invalid filter '%s' in %s : %v", bucketFactory.Filter, bucketFactory.Filename, err) + return fmt.Errorf("invalid filter '%s' in %s: %w", bucketFactory.Filter, bucketFactory.Filename, err) } if bucketFactory.GroupBy != "" { bucketFactory.RunTimeGroupBy, err = expr.Compile(bucketFactory.GroupBy, exprhelpers.GetExprOptions(map[string]interface{}{"evt": &types.Event{}})...) if err != nil { - return fmt.Errorf("invalid groupby '%s' in %s : %v", bucketFactory.GroupBy, bucketFactory.Filename, err) + return fmt.Errorf("invalid groupby '%s' in %s: %w", bucketFactory.GroupBy, bucketFactory.Filename, err) } } @@ -370,7 +411,7 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { case "bayesian": bucketFactory.processors = append(bucketFactory.processors, &DumbProcessor{}) default: - return fmt.Errorf("invalid type '%s' in %s : %v", bucketFactory.Type, bucketFactory.Filename, err) + return fmt.Errorf("invalid type '%s' in %s: %w", bucketFactory.Type, bucketFactory.Filename, err) } if bucketFactory.Distinct != "" { @@ -435,7 +476,7 @@ func LoadBucket(bucketFactory *BucketFactory, tomb *tomb.Tomb) error { bucketFactory.output = false if err := ValidateFactory(bucketFactory); err != nil { - return fmt.Errorf("invalid bucket from %s : %v", bucketFactory.Filename, err) + return fmt.Errorf("invalid bucket from %s: %w", bucketFactory.Filename, err) } bucketFactory.tomb = tomb @@ -452,7 +493,7 @@ func LoadBucketsState(file string, buckets *Buckets, bucketFactories []BucketFac } if err := json.Unmarshal(body, &state); err != nil { - return fmt.Errorf("can't unmarshal state file %s: %w", file, err) + return fmt.Errorf("can't parse state file %s: %w", file, err) } for k, v := range state { @@ -468,37 +509,39 @@ func LoadBucketsState(file string, buckets *Buckets, bucketFactories []BucketFac found := false for _, h := range bucketFactories { - if h.Name == v.Name { - log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description) - // check in which mode the bucket was - if v.Mode == types.TIMEMACHINE { - tbucket = NewTimeMachine(h) - } else if v.Mode == types.LIVE { - tbucket = NewLeaky(h) - } else { - log.Errorf("Unknown bucket type : %d", v.Mode) - } - /*Trying to restore queue state*/ - tbucket.Queue = v.Queue - /*Trying to set the limiter to the saved values*/ - tbucket.Limiter.Load(v.SerializedState) - tbucket.In = make(chan *types.Event) - tbucket.Mapkey = k - tbucket.Signal = make(chan bool, 1) - tbucket.First_ts = v.First_ts - tbucket.Last_ts = v.Last_ts - tbucket.Ovflw_ts = v.Ovflw_ts - tbucket.Total_count = v.Total_count - buckets.Bucket_map.Store(k, tbucket) - h.tomb.Go(func() error { - return LeakRoutine(tbucket) - }) - <-tbucket.Signal - - found = true + if h.Name != v.Name { + continue + } - break + log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description) + // check in which mode the bucket was + if v.Mode == types.TIMEMACHINE { + tbucket = NewTimeMachine(h) + } else if v.Mode == types.LIVE { + tbucket = NewLeaky(h) + } else { + log.Errorf("Unknown bucket type : %d", v.Mode) } + /*Trying to restore queue state*/ + tbucket.Queue = v.Queue + /*Trying to set the limiter to the saved values*/ + tbucket.Limiter.Load(v.SerializedState) + tbucket.In = make(chan *types.Event) + tbucket.Mapkey = k + tbucket.Signal = make(chan bool, 1) + tbucket.First_ts = v.First_ts + tbucket.Last_ts = v.Last_ts + tbucket.Ovflw_ts = v.Ovflw_ts + tbucket.Total_count = v.Total_count + buckets.Bucket_map.Store(k, tbucket) + h.tomb.Go(func() error { + return LeakRoutine(tbucket) + }) + <-tbucket.Signal + + found = true + + break } if !found { diff --git a/pkg/leakybucket/manager_run.go b/pkg/leakybucket/manager_run.go index 673b372d81e..2858d8b5635 100644 --- a/pkg/leakybucket/manager_run.go +++ b/pkg/leakybucket/manager_run.go @@ -132,7 +132,7 @@ func DumpBucketsStateAt(deadline time.Time, outputdir string, buckets *Buckets) }) bbuckets, err := json.MarshalIndent(serialized, "", " ") if err != nil { - return "", fmt.Errorf("failed to unmarshal buckets: %s", err) + return "", fmt.Errorf("failed to parse buckets: %s", err) } size, err := tmpFd.Write(bbuckets) if err != nil { @@ -203,7 +203,7 @@ func PourItemToBucket(bucket *Leaky, holder BucketFactory, buckets *Buckets, par var d time.Time err = d.UnmarshalText([]byte(parsed.MarshaledTime)) if err != nil { - holder.logger.Warningf("Failed unmarshaling event time (%s) : %v", parsed.MarshaledTime, err) + holder.logger.Warningf("Failed to parse event time (%s) : %v", parsed.MarshaledTime, err) } if d.After(lastTs.Add(bucket.Duration)) { bucket.logger.Tracef("bucket is expired (curr event: %s, bucket deadline: %s), kill", d, lastTs.Add(bucket.Duration)) @@ -298,7 +298,7 @@ func PourItemToHolders(parsed types.Event, holders []BucketFactory, buckets *Buc BucketPourCache["OK"] = append(BucketPourCache["OK"], evt.(types.Event)) } //find the relevant holders (scenarios) - for idx := range len(holders) { + for idx := range holders { //for idx, holder := range holders { //evaluate bucket's condition diff --git a/pkg/leakybucket/overflows.go b/pkg/leakybucket/overflows.go index 3ee067177ef..39b0e6a0ec4 100644 --- a/pkg/leakybucket/overflows.go +++ b/pkg/leakybucket/overflows.go @@ -19,66 +19,77 @@ import ( // SourceFromEvent extracts and formats a valid models.Source object from an Event func SourceFromEvent(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { - srcs := make(map[string]models.Source) /*if it's already an overflow, we have properly formatted sources. we can just twitch them to reflect the requested scope*/ if evt.Type == types.OVFLW { - for k, v := range evt.Overflow.Sources { - /*the scopes are already similar, nothing to do*/ - if leaky.scopeType.Scope == *v.Scope { - srcs[k] = v - continue - } + return overflowEventSources(evt, leaky) + } - /*The bucket requires a decision on scope Range */ - if leaky.scopeType.Scope == types.Range { - /*the original bucket was target IPs, check that we do have range*/ - if *v.Scope == types.Ip { - src := models.Source{} - src.AsName = v.AsName - src.AsNumber = v.AsNumber - src.Cn = v.Cn - src.Latitude = v.Latitude - src.Longitude = v.Longitude - src.Range = v.Range - src.Value = new(string) - src.Scope = new(string) - *src.Scope = leaky.scopeType.Scope - *src.Value = "" - - if v.Range != "" { - *src.Value = v.Range - } + return eventSources(evt, leaky) +} - if leaky.scopeType.RunTimeFilter != nil { - retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) - if err != nil { - return srcs, fmt.Errorf("while running scope filter: %w", err) - } +func overflowEventSources(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { + srcs := make(map[string]models.Source) - value, ok := retValue.(string) - if !ok { - value = "" - } + for k, v := range evt.Overflow.Sources { + /*the scopes are already similar, nothing to do*/ + if leaky.scopeType.Scope == *v.Scope { + srcs[k] = v + continue + } - src.Value = &value + /*The bucket requires a decision on scope Range */ + if leaky.scopeType.Scope == types.Range { + /*the original bucket was target IPs, check that we do have range*/ + if *v.Scope == types.Ip { + src := models.Source{} + src.AsName = v.AsName + src.AsNumber = v.AsNumber + src.Cn = v.Cn + src.Latitude = v.Latitude + src.Longitude = v.Longitude + src.Range = v.Range + src.Value = new(string) + src.Scope = new(string) + *src.Scope = leaky.scopeType.Scope + *src.Value = "" + + if v.Range != "" { + *src.Value = v.Range + } + + if leaky.scopeType.RunTimeFilter != nil { + retValue, err := exprhelpers.Run(leaky.scopeType.RunTimeFilter, map[string]interface{}{"evt": &evt}, leaky.logger, leaky.BucketConfig.Debug) + if err != nil { + return srcs, fmt.Errorf("while running scope filter: %w", err) } - if *src.Value != "" { - srcs[*src.Value] = src - } else { - log.Warningf("bucket %s requires scope Range, but none was provided. It seems that the %s wasn't enriched to include its range.", leaky.Name, *v.Value) + value, ok := retValue.(string) + if !ok { + value = "" } + + src.Value = &value + } + + if *src.Value != "" { + srcs[*src.Value] = src } else { - log.Warningf("bucket %s requires scope Range, but can't extrapolate from %s (%s)", - leaky.Name, *v.Scope, *v.Value) + log.Warningf("bucket %s requires scope Range, but none was provided. It seems that the %s wasn't enriched to include its range.", leaky.Name, *v.Value) } + } else { + log.Warningf("bucket %s requires scope Range, but can't extrapolate from %s (%s)", + leaky.Name, *v.Scope, *v.Value) } } - - return srcs, nil } + return srcs, nil +} + +func eventSources(evt types.Event, leaky *Leaky) (map[string]models.Source, error) { + srcs := make(map[string]models.Source) + src := models.Source{} switch leaky.scopeType.Scope { @@ -220,7 +231,7 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { raw, err := evt.Time.MarshalText() if err != nil { - log.Warningf("while marshaling time '%s' : %s", evt.Time.String(), err) + log.Warningf("while serializing time '%s' : %s", evt.Time.String(), err) } else { *ovflwEvent.Timestamp = string(raw) } @@ -236,9 +247,10 @@ func EventsFromQueue(queue *types.Queue) []*models.Event { // alertFormatSource iterates over the queue to collect sources func alertFormatSource(leaky *Leaky, queue *types.Queue) (map[string]models.Source, string, error) { - var sources = make(map[string]models.Source) var source_type string + sources := make(map[string]models.Source) + log.Debugf("Formatting (%s) - scope Info : scope_type:%s / scope_filter:%s", leaky.Name, leaky.scopeType.Scope, leaky.scopeType.Filter) for _, evt := range queue.Queue { @@ -274,12 +286,12 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { */ start_at, err := leaky.First_ts.MarshalText() if err != nil { - log.Warningf("failed to marshal start ts %s : %s", leaky.First_ts.String(), err) + log.Warningf("failed to serialize start ts %s : %s", leaky.First_ts.String(), err) } stop_at, err := leaky.Ovflw_ts.MarshalText() if err != nil { - log.Warningf("failed to marshal ovflw ts %s : %s", leaky.First_ts.String(), err) + log.Warningf("failed to serialize ovflw ts %s : %s", leaky.First_ts.String(), err) } capacity := int32(leaky.Capacity) @@ -299,6 +311,7 @@ func NewAlert(leaky *Leaky, queue *types.Queue) (types.RuntimeAlert, error) { StopAt: &stopAt, Simulated: &leaky.Simulated, } + if leaky.BucketConfig == nil { return runtimeAlert, errors.New("leaky.BucketConfig is nil") } diff --git a/pkg/leakybucket/timemachine.go b/pkg/leakybucket/timemachine.go index e72bb1a464c..34073d1cc5c 100644 --- a/pkg/leakybucket/timemachine.go +++ b/pkg/leakybucket/timemachine.go @@ -24,7 +24,7 @@ func TimeMachinePour(l *Leaky, msg types.Event) { err = d.UnmarshalText([]byte(msg.MarshaledTime)) if err != nil { - log.Warningf("Failed unmarshaling event time (%s) : %v", msg.MarshaledTime, err) + log.Warningf("Failed to parse event time (%s) : %v", msg.MarshaledTime, err) return } diff --git a/pkg/leakybucket/trigger.go b/pkg/leakybucket/trigger.go index b6af1431888..d13e57856f9 100644 --- a/pkg/leakybucket/trigger.go +++ b/pkg/leakybucket/trigger.go @@ -16,25 +16,31 @@ func (t *Trigger) OnBucketPour(b *BucketFactory) func(types.Event, *Leaky) *type // Pour makes the bucket overflow all the time // TriggerPour unconditionally overflows return func(msg types.Event, l *Leaky) *types.Event { + now := time.Now().UTC() + if l.Mode == types.TIMEMACHINE { var d time.Time + err := d.UnmarshalText([]byte(msg.MarshaledTime)) if err != nil { - log.Warningf("Failed unmarshaling event time (%s) : %v", msg.MarshaledTime, err) - d = time.Now().UTC() + log.Warningf("Failed to parse event time (%s) : %v", msg.MarshaledTime, err) + + d = now } + l.logger.Debugf("yay timemachine overflow time : %s --> %s", d, msg.MarshaledTime) l.Last_ts = d l.First_ts = d l.Ovflw_ts = d } else { - l.Last_ts = time.Now().UTC() - l.First_ts = time.Now().UTC() - l.Ovflw_ts = time.Now().UTC() + l.Last_ts = now + l.First_ts = now + l.Ovflw_ts = now } + l.Total_count = 1 - l.logger.Infof("Bucket overflow") + l.logger.Debug("Bucket overflow") l.Queue.Add(msg) l.Out <- l.Queue diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go index 9fa3b4b3f9a..5c395185b20 100644 --- a/pkg/longpollclient/client.go +++ b/pkg/longpollclient/client.go @@ -1,6 +1,7 @@ package longpollclient import ( + "context" "encoding/json" "errors" "fmt" @@ -50,7 +51,7 @@ var errUnauthorized = errors.New("user is not authorized to use PAPI") const timeoutMessage = "no events before timeout" -func (c *LongPollClient) doQuery() (*http.Response, error) { +func (c *LongPollClient) doQuery(ctx context.Context) (*http.Response, error) { logger := c.logger.WithField("method", "doQuery") query := c.url.Query() query.Set("since_time", fmt.Sprintf("%d", c.since)) @@ -59,7 +60,7 @@ func (c *LongPollClient) doQuery() (*http.Response, error) { logger.Debugf("Query parameters: %s", c.url.RawQuery) - req, err := http.NewRequest(http.MethodGet, c.url.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url.String(), nil) if err != nil { logger.Errorf("failed to create request: %s", err) return nil, err @@ -73,12 +74,10 @@ func (c *LongPollClient) doQuery() (*http.Response, error) { return resp, nil } -func (c *LongPollClient) poll() error { - +func (c *LongPollClient) poll(ctx context.Context) error { logger := c.logger.WithField("method", "poll") - resp, err := c.doQuery() - + resp, err := c.doQuery(ctx) if err != nil { return err } @@ -95,7 +94,7 @@ func (c *LongPollClient) poll() error { logger.Errorf("failed to read response body: %s", err) return err } - logger.Errorf(string(bodyContent)) + logger.Error(string(bodyContent)) return errUnauthorized } return fmt.Errorf("unexpected status code: %d", resp.StatusCode) @@ -122,7 +121,7 @@ func (c *LongPollClient) poll() error { logger.Tracef("got response: %+v", pollResp) - if len(pollResp.ErrorMessage) > 0 { + if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { logger.Debugf("got timeout message") return nil @@ -148,7 +147,7 @@ func (c *LongPollClient) poll() error { } } -func (c *LongPollClient) pollEvents() error { +func (c *LongPollClient) pollEvents(ctx context.Context) error { for { select { case <-c.t.Dying(): @@ -156,7 +155,7 @@ func (c *LongPollClient) pollEvents() error { return nil default: c.logger.Debug("Polling PAPI") - err := c.poll() + err := c.poll(ctx) if err != nil { c.logger.Errorf("failed to poll: %s", err) if errors.Is(err, errUnauthorized) { @@ -170,12 +169,12 @@ func (c *LongPollClient) pollEvents() error { } } -func (c *LongPollClient) Start(since time.Time) chan Event { +func (c *LongPollClient) Start(ctx context.Context, since time.Time) chan Event { c.logger.Infof("starting polling client") c.c = make(chan Event) c.since = since.Unix() * 1000 c.timeout = "45" - c.t.Go(c.pollEvents) + c.t.Go(func() error {return c.pollEvents(ctx)}) return c.c } @@ -184,11 +183,11 @@ func (c *LongPollClient) Stop() error { return nil } -func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) { +func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event, error) { c.logger.Debug("Pulling PAPI once") c.since = since.Unix() * 1000 c.timeout = "1" - resp, err := c.doQuery() + resp, err := c.doQuery(ctx) if err != nil { return nil, err } @@ -209,7 +208,7 @@ func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) { c.logger.Tracef("got response: %+v", pollResp) - if len(pollResp.ErrorMessage) > 0 { + if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { c.logger.Debugf("got timeout message") break diff --git a/pkg/metabase/api.go b/pkg/metabase/api.go index 387e8d151e0..08e10188678 100644 --- a/pkg/metabase/api.go +++ b/pkg/metabase/api.go @@ -9,7 +9,7 @@ import ( "github.com/dghubble/sling" log "github.com/sirupsen/logrus" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" ) type MBClient struct { @@ -38,7 +38,7 @@ var ( func NewMBClient(url string) (*MBClient, error) { httpClient := &http.Client{Timeout: 20 * time.Second} return &MBClient{ - CTX: sling.New().Client(httpClient).Base(url).Set("User-Agent", cwversion.UserAgent()), + CTX: sling.New().Client(httpClient).Base(url).Set("User-Agent", useragent.Default()), Client: httpClient, }, nil } diff --git a/pkg/metabase/metabase.go b/pkg/metabase/metabase.go index 837bab796d5..324a05666a1 100644 --- a/pkg/metabase/metabase.go +++ b/pkg/metabase/metabase.go @@ -70,12 +70,12 @@ func (m *Metabase) Init(containerName string, image string) error { switch m.Config.Database.Type { case "mysql": - return fmt.Errorf("'mysql' is not supported yet for cscli dashboard") + return errors.New("'mysql' is not supported yet for cscli dashboard") //DBConnectionURI = fmt.Sprintf("MB_DB_CONNECTION_URI=mysql://%s:%d/%s?user=%s&password=%s&allowPublicKeyRetrieval=true", remoteDBAddr, m.Config.Database.Port, m.Config.Database.DbName, m.Config.Database.User, m.Config.Database.Password) case "sqlite": m.InternalDBURL = metabaseSQLiteDBURL case "postgresql", "postgres", "pgsql": - return fmt.Errorf("'postgresql' is not supported yet by cscli dashboard") + return errors.New("'postgresql' is not supported yet by cscli dashboard") default: return fmt.Errorf("database '%s' not supported", m.Config.Database.Type) } diff --git a/pkg/models/generate.go b/pkg/models/generate.go index ccacc409ab5..502d6f3d2cf 100644 --- a/pkg/models/generate.go +++ b/pkg/models/generate.go @@ -1,4 +1,4 @@ package models -//go:generate go run -mod=mod github.com/go-swagger/go-swagger/cmd/swagger@v0.30.5 generate model --spec=./localapi_swagger.yaml --target=../ +//go:generate go run -mod=mod github.com/go-swagger/go-swagger/cmd/swagger@v0.31.0 generate model --spec=./localapi_swagger.yaml --target=../ diff --git a/pkg/modelscapi/add_signals_request.go b/pkg/modelscapi/add_signals_request.go index 62fe590cb79..7bfe6ae80e0 100644 --- a/pkg/modelscapi/add_signals_request.go +++ b/pkg/modelscapi/add_signals_request.go @@ -56,6 +56,11 @@ func (m AddSignalsRequest) ContextValidate(ctx context.Context, formats strfmt.R for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/add_signals_request_item.go b/pkg/modelscapi/add_signals_request_item.go index f9c865b4c68..5f63b542d5a 100644 --- a/pkg/modelscapi/add_signals_request_item.go +++ b/pkg/modelscapi/add_signals_request_item.go @@ -65,6 +65,9 @@ type AddSignalsRequestItem struct { // stop at // Required: true StopAt *string `json:"stop_at"` + + // UUID of the alert + UUID string `json:"uuid,omitempty"` } // Validate validates this add signals request item @@ -257,6 +260,11 @@ func (m *AddSignalsRequestItem) contextValidateContext(ctx context.Context, form for i := 0; i < len(m.Context); i++ { if m.Context[i] != nil { + + if swag.IsZero(m.Context[i]) { // not required + return nil + } + if err := m.Context[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("context" + "." + strconv.Itoa(i)) @@ -289,6 +297,7 @@ func (m *AddSignalsRequestItem) contextValidateDecisions(ctx context.Context, fo func (m *AddSignalsRequestItem) contextValidateSource(ctx context.Context, formats strfmt.Registry) error { if m.Source != nil { + if err := m.Source.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("source") diff --git a/pkg/modelscapi/add_signals_request_item_decisions.go b/pkg/modelscapi/add_signals_request_item_decisions.go index 54e123ab3f8..11ed27a496d 100644 --- a/pkg/modelscapi/add_signals_request_item_decisions.go +++ b/pkg/modelscapi/add_signals_request_item_decisions.go @@ -54,6 +54,11 @@ func (m AddSignalsRequestItemDecisions) ContextValidate(ctx context.Context, for for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/add_signals_request_item_decisions_item.go b/pkg/modelscapi/add_signals_request_item_decisions_item.go index 34dfeb5bce5..797c517e33f 100644 --- a/pkg/modelscapi/add_signals_request_item_decisions_item.go +++ b/pkg/modelscapi/add_signals_request_item_decisions_item.go @@ -49,6 +49,9 @@ type AddSignalsRequestItemDecisionsItem struct { // until Until string `json:"until,omitempty"` + // UUID of the decision + UUID string `json:"uuid,omitempty"` + // the value of the decision scope : an IP, a range, a username, etc // Required: true Value *string `json:"value"` diff --git a/pkg/modelscapi/centralapi_swagger.yaml b/pkg/modelscapi/centralapi_swagger.yaml new file mode 100644 index 00000000000..c75233809c8 --- /dev/null +++ b/pkg/modelscapi/centralapi_swagger.yaml @@ -0,0 +1,888 @@ +swagger: "2.0" +info: + description: + "API to manage machines using [crowdsec](https://github.com/crowdsecurity/crowdsec)\ + \ and bouncers.\n" + version: "2023-01-23T11:16:39Z" + title: "prod-capi-v3" + contact: + name: "Crowdsec team" + url: "https://github.com/crowdsecurity/crowdsec" + email: "support@crowdsec.net" +host: "api.crowdsec.net" +basePath: "/v3" +tags: + - name: "watchers" + description: "Operations about watchers: crowdsec & cscli" + - name: "bouncers" + description: "Operations about decisions : bans, captcha, rate-limit etc." +schemes: + - "https" +paths: + /decisions/delete: + post: + tags: + - "watchers" + summary: "delete decisions" + description: "delete provided decisions" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "DecisionsDeleteRequest" + required: true + schema: + $ref: "#/definitions/DecisionsDeleteRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /decisions/stream: + get: + tags: + - "bouncers" + - "watchers" + summary: "returns list of top decisions" + description: "returns list of top decisions to add or delete" + produces: + - "application/json" + parameters: + - in: query + name: "community_pull" + type: "boolean" + default: true + required: false + description: "Fetch the community blocklist content" + - in: query + name: "additional_pull" + type: "boolean" + default: true + required: false + description: "Fetch additional blocklists content" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/GetDecisionsStreamResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "404": + description: "404 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + options: + consumes: + - "application/json" + produces: + - "application/json" + responses: + "200": + description: "200 response" + headers: + Access-Control-Allow-Origin: + type: "string" + Access-Control-Allow-Methods: + type: "string" + Access-Control-Allow-Headers: + type: "string" + /decisions/sync: + post: + tags: + - "watchers" + summary: "sync decisions" + description: "sync provided decisions" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "DecisionsSyncRequest" + required: true + schema: + $ref: "#/definitions/DecisionsSyncRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /metrics: + post: + tags: + - "watchers" + summary: "receive metrics about enrolled machines and bouncers in APIL" + description: "receive metrics about enrolled machines and bouncers in APIL" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "MetricsRequest" + required: true + schema: + $ref: "#/definitions/MetricsRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /signals: + post: + tags: + - "watchers" + summary: "Push signals" + description: "to push signals" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "AddSignalsRequest" + required: true + schema: + $ref: "#/definitions/AddSignalsRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /watchers: + post: + tags: + - "watchers" + summary: "Register watcher" + description: "Register a watcher" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "RegisterRequest" + required: true + schema: + $ref: "#/definitions/RegisterRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + /watchers/enroll: + post: + tags: + - "watchers" + summary: "watcher enrollment" + description: "watcher enrollment : enroll watcher to crowdsec backoffice account" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "EnrollRequest" + required: true + schema: + $ref: "#/definitions/EnrollRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + security: + - UserPoolAuthorizer: [] + /watchers/login: + post: + tags: + - "watchers" + summary: "watcher login" + description: "Sign-in to get a valid token" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "LoginRequest" + required: true + schema: + $ref: "#/definitions/LoginRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/LoginResponse" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + /watchers/reset: + post: + tags: + - "watchers" + summary: "Reset Password" + description: "to reset a watcher password" + consumes: + - "application/json" + produces: + - "application/json" + parameters: + - in: "body" + name: "ResetPasswordRequest" + required: true + schema: + $ref: "#/definitions/ResetPasswordRequest" + responses: + "200": + description: "200 response" + schema: + $ref: "#/definitions/SuccessResponse" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + "400": + description: "400 response" + schema: + $ref: "#/definitions/ErrorResponse" + "500": + description: "500 response" + schema: + $ref: "#/definitions/ErrorResponse" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + "403": + description: "403 response" + schema: + $ref: "#/definitions/ErrorResponse" + "404": + description: "404 response" + headers: + Content-type: + type: "string" + Access-Control-Allow-Origin: + type: "string" + options: + consumes: + - "application/json" + produces: + - "application/json" + responses: + "200": + description: "200 response" + headers: + Access-Control-Allow-Origin: + type: "string" + Access-Control-Allow-Methods: + type: "string" + Access-Control-Allow-Headers: + type: "string" +securityDefinitions: + UserPoolAuthorizer: + type: "apiKey" + name: "Authorization" + in: "header" + x-amazon-apigateway-authtype: "cognito_user_pools" +definitions: + DecisionsDeleteRequest: + title: "delete decisions" + type: "array" + description: "delete decision model" + items: + $ref: "#/definitions/DecisionsDeleteRequestItem" + DecisionsSyncRequestItem: + type: "object" + required: + - "message" + - "scenario" + - "scenario_hash" + - "scenario_version" + - "source" + - "start_at" + - "stop_at" + properties: + scenario_trust: + type: "string" + scenario_hash: + type: "string" + scenario: + type: "string" + alert_id: + type: "integer" + created_at: + type: "string" + machine_id: + type: "string" + decisions: + $ref: "#/definitions/DecisionsSyncRequestItemDecisions" + source: + $ref: "#/definitions/DecisionsSyncRequestItemSource" + scenario_version: + type: "string" + message: + type: "string" + description: "a human readable message" + start_at: + type: "string" + stop_at: + type: "string" + title: "Signal" + AddSignalsRequestItem: + type: "object" + required: + - "message" + - "scenario" + - "scenario_hash" + - "scenario_version" + - "source" + - "start_at" + - "stop_at" + properties: + created_at: + type: "string" + machine_id: + type: "string" + source: + $ref: "#/definitions/AddSignalsRequestItemSource" + scenario_version: + type: "string" + message: + type: "string" + description: "a human readable message" + uuid: + type: "string" + description: "UUID of the alert" + start_at: + type: "string" + scenario_trust: + type: "string" + scenario_hash: + type: "string" + scenario: + type: "string" + alert_id: + type: "integer" + context: + type: "array" + items: + type: "object" + properties: + value: + type: "string" + key: + type: "string" + decisions: + $ref: "#/definitions/AddSignalsRequestItemDecisions" + stop_at: + type: "string" + title: "Signal" + DecisionsSyncRequest: + title: "sync decisions request" + type: "array" + description: "sync decision model" + items: + $ref: "#/definitions/DecisionsSyncRequestItem" + LoginRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + minLength: 48 + maxLength: 48 + pattern: "^[a-zA-Z0-9]+$" + scenarios: + type: "array" + description: "all scenarios installed" + items: + type: "string" + title: "login request" + description: "Login request model" + GetDecisionsStreamResponseNewItem: + type: "object" + required: + - "scenario" + - "scope" + - "decisions" + properties: + scenario: + type: "string" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + decisions: + type: array + items: + type: object + required: + - value + - duration + properties: + duration: + type: "string" + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "New Decisions" + GetDecisionsStreamResponseDeletedItem: + type: object + required: + - scope + - decisions + properties: + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + decisions: + type: array + items: + type: string + BlocklistLink: + type: object + required: + - name + - url + - remediation + - scope + - duration + properties: + name: + type: string + description: "the name of the blocklist" + url: + type: string + description: "the url from which the blocklist content can be downloaded" + remediation: + type: string + description: "the remediation that should be used for the blocklist" + scope: + type: string + description: "the scope of decisions in the blocklist" + duration: + type: string + AddSignalsRequestItemDecisionsItem: + type: "object" + required: + - "duration" + - "id" + - "origin" + - "scenario" + - "scope" + - "type" + - "value" + properties: + duration: + type: "string" + uuid: + type: "string" + description: "UUID of the decision" + scenario: + type: "string" + origin: + type: "string" + description: "the origin of the decision : cscli, crowdsec" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + simulated: + type: "boolean" + until: + type: "string" + id: + type: "integer" + description: "(only relevant for GET ops) the unique id" + type: + type: "string" + description: + "the type of decision, might be 'ban', 'captcha' or something\ + \ custom. Ignored when watcher (cscli/crowdsec) is pushing to APIL." + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "Decision" + EnrollRequest: + type: "object" + required: + - "attachment_key" + properties: + name: + type: "string" + description: "The name that will be display in the console for the instance" + overwrite: + type: "boolean" + description: "To force enroll the instance" + attachment_key: + type: "string" + description: + "attachment_key is generated in your crowdsec backoffice account\ + \ and allows you to enroll your machines to your BO account" + pattern: "^[a-zA-Z0-9]+$" + tags: + type: "array" + description: "Tags to apply on the console for the instance" + items: + type: "string" + title: "enroll request" + description: "enroll request model" + ResetPasswordRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + minLength: 48 + maxLength: 48 + pattern: "^[a-zA-Z0-9]+$" + title: "resetPassword" + description: "ResetPassword request model" + MetricsRequestBouncersItem: + type: "object" + properties: + last_pull: + type: "string" + description: "last bouncer pull date" + custom_name: + type: "string" + description: "bouncer name" + name: + type: "string" + description: "bouncer type (firewall, php...)" + version: + type: "string" + description: "bouncer version" + title: "MetricsBouncerInfo" + AddSignalsRequestItemSource: + type: "object" + required: + - "scope" + - "value" + properties: + scope: + type: "string" + description: "the scope of a source : ip,range,username,etc" + ip: + type: "string" + description: "provided as a convenience when the source is an IP" + latitude: + type: "number" + format: "float" + as_number: + type: "string" + description: "provided as a convenience when the source is an IP" + range: + type: "string" + description: "provided as a convenience when the source is an IP" + cn: + type: "string" + value: + type: "string" + description: "the value of a source : the ip, the range, the username,etc" + as_name: + type: "string" + description: "provided as a convenience when the source is an IP" + longitude: + type: "number" + format: "float" + title: "Source" + DecisionsSyncRequestItemDecisions: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/DecisionsSyncRequestItemDecisionsItem" + RegisterRequest: + type: "object" + required: + - "machine_id" + - "password" + properties: + password: + type: "string" + description: "Password, should respect the password policy (link to add)" + machine_id: + type: "string" + description: "machine_id is a (username) generated by crowdsec" + pattern: "^[a-zA-Z0-9]+$" + title: "register request" + description: "Register request model" + SuccessResponse: + type: "object" + required: + - "message" + properties: + message: + type: "string" + description: "message" + title: "success response" + description: "success response return by the API" + LoginResponse: + type: "object" + properties: + code: + type: "integer" + expire: + type: "string" + token: + type: "string" + title: "login response" + description: "Login request model" + DecisionsSyncRequestItemDecisionsItem: + type: "object" + required: + - "duration" + - "id" + - "origin" + - "scenario" + - "scope" + - "type" + - "value" + properties: + duration: + type: "string" + scenario: + type: "string" + origin: + type: "string" + description: "the origin of the decision : cscli, crowdsec" + scope: + type: "string" + description: + "the scope of decision : does it apply to an IP, a range, a username,\ + \ etc" + simulated: + type: "boolean" + until: + type: "string" + id: + type: "integer" + description: "(only relevant for GET ops) the unique id" + type: + type: "string" + description: + "the type of decision, might be 'ban', 'captcha' or something\ + \ custom. Ignored when watcher (cscli/crowdsec) is pushing to APIL." + value: + type: "string" + description: + "the value of the decision scope : an IP, a range, a username,\ + \ etc" + title: "Decision" + GetDecisionsStreamResponse: + type: "object" + properties: + new: + $ref: "#/definitions/GetDecisionsStreamResponseNew" + deleted: + $ref: "#/definitions/GetDecisionsStreamResponseDeleted" + links: + $ref: "#/definitions/GetDecisionsStreamResponseLinks" + title: "get decisions stream response" + description: "get decision response model" + DecisionsSyncRequestItemSource: + type: "object" + required: + - "scope" + - "value" + properties: + scope: + type: "string" + description: "the scope of a source : ip,range,username,etc" + ip: + type: "string" + description: "provided as a convenience when the source is an IP" + latitude: + type: "number" + format: "float" + as_number: + type: "string" + description: "provided as a convenience when the source is an IP" + range: + type: "string" + description: "provided as a convenience when the source is an IP" + cn: + type: "string" + value: + type: "string" + description: "the value of a source : the ip, the range, the username,etc" + as_name: + type: "string" + description: "provided as a convenience when the source is an IP" + longitude: + type: "number" + format: "float" + title: "Source" + AddSignalsRequestItemDecisions: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/AddSignalsRequestItemDecisionsItem" + MetricsRequestMachinesItem: + type: "object" + properties: + last_update: + type: "string" + description: "last agent update date" + name: + type: "string" + description: "agent name" + last_push: + type: "string" + description: "last agent push date" + version: + type: "string" + description: "agent version" + title: "MetricsAgentInfo" + MetricsRequest: + type: "object" + required: + - "bouncers" + - "machines" + properties: + bouncers: + type: "array" + items: + $ref: "#/definitions/MetricsRequestBouncersItem" + machines: + type: "array" + items: + $ref: "#/definitions/MetricsRequestMachinesItem" + title: "metrics" + description: "push metrics model" + ErrorResponse: + type: "object" + required: + - "message" + properties: + message: + type: "string" + description: "Error message" + errors: + type: "string" + description: "more detail on individual errors" + title: "error response" + description: "error response return by the API" + AddSignalsRequest: + title: "add signals request" + type: "array" + description: "All signals request model" + items: + $ref: "#/definitions/AddSignalsRequestItem" + DecisionsDeleteRequestItem: + type: "string" + title: "decisionsIDs" + GetDecisionsStreamResponseNew: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/GetDecisionsStreamResponseNewItem" + GetDecisionsStreamResponseDeleted: + title: "Decisions list" + type: "array" + items: + $ref: "#/definitions/GetDecisionsStreamResponseDeletedItem" + GetDecisionsStreamResponseLinks: + title: "Decisions list" + type: "object" + properties: + blocklists: + type: array + items: + $ref: "#/definitions/BlocklistLink" + diff --git a/pkg/modelscapi/decisions_delete_request.go b/pkg/modelscapi/decisions_delete_request.go index e8718835027..0c93558adf1 100644 --- a/pkg/modelscapi/decisions_delete_request.go +++ b/pkg/modelscapi/decisions_delete_request.go @@ -11,6 +11,7 @@ import ( "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" ) // DecisionsDeleteRequest delete decisions @@ -49,6 +50,10 @@ func (m DecisionsDeleteRequest) ContextValidate(ctx context.Context, formats str for i := 0; i < len(m); i++ { + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/decisions_sync_request.go b/pkg/modelscapi/decisions_sync_request.go index e3a95162519..c087d39ff62 100644 --- a/pkg/modelscapi/decisions_sync_request.go +++ b/pkg/modelscapi/decisions_sync_request.go @@ -56,6 +56,11 @@ func (m DecisionsSyncRequest) ContextValidate(ctx context.Context, formats strfm for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/decisions_sync_request_item.go b/pkg/modelscapi/decisions_sync_request_item.go index 5139ea2de4b..460fe4d430e 100644 --- a/pkg/modelscapi/decisions_sync_request_item.go +++ b/pkg/modelscapi/decisions_sync_request_item.go @@ -231,6 +231,7 @@ func (m *DecisionsSyncRequestItem) contextValidateDecisions(ctx context.Context, func (m *DecisionsSyncRequestItem) contextValidateSource(ctx context.Context, formats strfmt.Registry) error { if m.Source != nil { + if err := m.Source.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("source") diff --git a/pkg/modelscapi/decisions_sync_request_item_decisions.go b/pkg/modelscapi/decisions_sync_request_item_decisions.go index 76316e43c5e..bdc8e77e2b6 100644 --- a/pkg/modelscapi/decisions_sync_request_item_decisions.go +++ b/pkg/modelscapi/decisions_sync_request_item_decisions.go @@ -54,6 +54,11 @@ func (m DecisionsSyncRequestItemDecisions) ContextValidate(ctx context.Context, for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/generate.go b/pkg/modelscapi/generate.go new file mode 100644 index 00000000000..66dc2a34b7e --- /dev/null +++ b/pkg/modelscapi/generate.go @@ -0,0 +1,4 @@ +package modelscapi + +//go:generate go run -mod=mod github.com/go-swagger/go-swagger/cmd/swagger@v0.31.0 generate model --spec=./centralapi_swagger.yaml --target=../ --model-package=modelscapi + diff --git a/pkg/modelscapi/get_decisions_stream_response.go b/pkg/modelscapi/get_decisions_stream_response.go index af19b85c4d3..5ebf29c5d93 100644 --- a/pkg/modelscapi/get_decisions_stream_response.go +++ b/pkg/modelscapi/get_decisions_stream_response.go @@ -144,6 +144,11 @@ func (m *GetDecisionsStreamResponse) contextValidateDeleted(ctx context.Context, func (m *GetDecisionsStreamResponse) contextValidateLinks(ctx context.Context, formats strfmt.Registry) error { if m.Links != nil { + + if swag.IsZero(m.Links) { // not required + return nil + } + if err := m.Links.ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("links") diff --git a/pkg/modelscapi/get_decisions_stream_response_deleted.go b/pkg/modelscapi/get_decisions_stream_response_deleted.go index d218bf87e4e..78292860f22 100644 --- a/pkg/modelscapi/get_decisions_stream_response_deleted.go +++ b/pkg/modelscapi/get_decisions_stream_response_deleted.go @@ -54,6 +54,11 @@ func (m GetDecisionsStreamResponseDeleted) ContextValidate(ctx context.Context, for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_links.go b/pkg/modelscapi/get_decisions_stream_response_links.go index 85cc9af9b48..6b9054574f1 100644 --- a/pkg/modelscapi/get_decisions_stream_response_links.go +++ b/pkg/modelscapi/get_decisions_stream_response_links.go @@ -82,6 +82,11 @@ func (m *GetDecisionsStreamResponseLinks) contextValidateBlocklists(ctx context. for i := 0; i < len(m.Blocklists); i++ { if m.Blocklists[i] != nil { + + if swag.IsZero(m.Blocklists[i]) { // not required + return nil + } + if err := m.Blocklists[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("blocklists" + "." + strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_new.go b/pkg/modelscapi/get_decisions_stream_response_new.go index e9525bf6fa7..8e09f1b20e7 100644 --- a/pkg/modelscapi/get_decisions_stream_response_new.go +++ b/pkg/modelscapi/get_decisions_stream_response_new.go @@ -54,6 +54,11 @@ func (m GetDecisionsStreamResponseNew) ContextValidate(ctx context.Context, form for i := 0; i < len(m); i++ { if m[i] != nil { + + if swag.IsZero(m[i]) { // not required + return nil + } + if err := m[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName(strconv.Itoa(i)) diff --git a/pkg/modelscapi/get_decisions_stream_response_new_item.go b/pkg/modelscapi/get_decisions_stream_response_new_item.go index a3592d0ab61..77cc06732ce 100644 --- a/pkg/modelscapi/get_decisions_stream_response_new_item.go +++ b/pkg/modelscapi/get_decisions_stream_response_new_item.go @@ -119,6 +119,11 @@ func (m *GetDecisionsStreamResponseNewItem) contextValidateDecisions(ctx context for i := 0; i < len(m.Decisions); i++ { if m.Decisions[i] != nil { + + if swag.IsZero(m.Decisions[i]) { // not required + return nil + } + if err := m.Decisions[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("decisions" + "." + strconv.Itoa(i)) diff --git a/pkg/modelscapi/metrics_request.go b/pkg/modelscapi/metrics_request.go index d5b7d058fc1..5d663cf1750 100644 --- a/pkg/modelscapi/metrics_request.go +++ b/pkg/modelscapi/metrics_request.go @@ -126,6 +126,11 @@ func (m *MetricsRequest) contextValidateBouncers(ctx context.Context, formats st for i := 0; i < len(m.Bouncers); i++ { if m.Bouncers[i] != nil { + + if swag.IsZero(m.Bouncers[i]) { // not required + return nil + } + if err := m.Bouncers[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("bouncers" + "." + strconv.Itoa(i)) @@ -146,6 +151,11 @@ func (m *MetricsRequest) contextValidateMachines(ctx context.Context, formats st for i := 0; i < len(m.Machines); i++ { if m.Machines[i] != nil { + + if swag.IsZero(m.Machines[i]) { // not required + return nil + } + if err := m.Machines[i].ContextValidate(ctx, formats); err != nil { if ve, ok := err.(*errors.Validation); ok { return ve.ValidateName("machines" + "." + strconv.Itoa(i)) diff --git a/pkg/parser/enrich_date.go b/pkg/parser/enrich_date.go index 748a466d7c3..40c8de39da5 100644 --- a/pkg/parser/enrich_date.go +++ b/pkg/parser/enrich_date.go @@ -18,7 +18,7 @@ func parseDateWithFormat(date, format string) (string, time.Time) { } retstr, err := t.MarshalText() if err != nil { - log.Warningf("Failed marshaling '%v'", t) + log.Warningf("Failed to serialize '%v'", t) return "", time.Time{} } return string(retstr), t @@ -98,7 +98,7 @@ func ParseDate(in string, p *types.Event, plog *log.Entry) (map[string]string, e now := time.Now().UTC() retstr, err := now.MarshalText() if err != nil { - plog.Warning("Failed marshaling current time") + plog.Warning("Failed to serialize current time") return ret, err } ret["MarshaledTime"] = string(retstr) diff --git a/pkg/parser/enrich_unmarshal.go b/pkg/parser/enrich_unmarshal.go index 7ff91b70aea..dbdd9d3f583 100644 --- a/pkg/parser/enrich_unmarshal.go +++ b/pkg/parser/enrich_unmarshal.go @@ -11,7 +11,7 @@ import ( func unmarshalJSON(field string, p *types.Event, plog *log.Entry) (map[string]string, error) { err := json.Unmarshal([]byte(p.Line.Raw), &p.Unmarshaled) if err != nil { - plog.Errorf("could not unmarshal JSON: %s", err) + plog.Errorf("could not parse JSON: %s", err) return nil, err } plog.Tracef("unmarshaled JSON: %+v", p.Unmarshaled) diff --git a/pkg/parser/parsing_test.go b/pkg/parser/parsing_test.go index 0542c69c049..269d51a1ba2 100644 --- a/pkg/parser/parsing_test.go +++ b/pkg/parser/parsing_test.go @@ -132,7 +132,7 @@ func testOneParser(pctx *UnixParserCtx, ectx EnricherCtx, dir string, b *testing } if err = yaml.UnmarshalStrict(out.Bytes(), &parser_configs); err != nil { - return fmt.Errorf("failed unmarshaling %s: %w", parser_cfg_file, err) + return fmt.Errorf("failed to parse %s: %w", parser_cfg_file, err) } pnodes, err = LoadStages(parser_configs, pctx, ectx) diff --git a/pkg/parser/stage.go b/pkg/parser/stage.go index fe538023b61..b98db350254 100644 --- a/pkg/parser/stage.go +++ b/pkg/parser/stage.go @@ -21,7 +21,7 @@ import ( log "github.com/sirupsen/logrus" yaml "gopkg.in/yaml.v2" - "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/crowdsecurity/crowdsec/pkg/cwversion/constraint" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" ) @@ -85,12 +85,12 @@ func LoadStages(stageFiles []Stagefile, pctx *UnixParserCtx, ectx EnricherCtx) ( log.Tracef("no version in %s, assuming '1.0'", node.Name) node.FormatVersion = "1.0" } - ok, err := cwversion.Satisfies(node.FormatVersion, cwversion.Constraint_parser) + ok, err := constraint.Satisfies(node.FormatVersion, constraint.Parser) if err != nil { return nil, fmt.Errorf("failed to check version : %s", err) } if !ok { - log.Errorf("%s : %s doesn't satisfy parser format %s, skip", node.Name, node.FormatVersion, cwversion.Constraint_parser) + log.Errorf("%s : %s doesn't satisfy parser format %s, skip", node.Name, node.FormatVersion, constraint.Parser) continue } diff --git a/pkg/parser/unix_parser.go b/pkg/parser/unix_parser.go index 280d122ecc1..f0f26a06645 100644 --- a/pkg/parser/unix_parser.go +++ b/pkg/parser/unix_parser.go @@ -43,7 +43,7 @@ func Init(c map[string]interface{}) (*UnixParserCtx, error) { } r.DataFolder = c["data"].(string) for _, f := range files { - if strings.Contains(f.Name(), ".") { + if strings.Contains(f.Name(), ".") || f.IsDir() { continue } if err := r.Grok.AddFromFile(filepath.Join(c["patterns"].(string), f.Name())); err != nil { @@ -66,21 +66,20 @@ func NewParsers(hub *cwhub.Hub) *Parsers { } for _, itemType := range []string{cwhub.PARSERS, cwhub.POSTOVERFLOWS} { - for _, hubParserItem := range hub.GetItemMap(itemType) { - if hubParserItem.State.Installed { - stagefile := Stagefile{ - Filename: hubParserItem.State.LocalPath, - Stage: hubParserItem.Stage, - } - if itemType == cwhub.PARSERS { - parsers.StageFiles = append(parsers.StageFiles, stagefile) - } - if itemType == cwhub.POSTOVERFLOWS { - parsers.PovfwStageFiles = append(parsers.PovfwStageFiles, stagefile) - } + for _, hubParserItem := range hub.GetInstalledByType(itemType, false) { + stagefile := Stagefile{ + Filename: hubParserItem.State.LocalPath, + Stage: hubParserItem.Stage, + } + if itemType == cwhub.PARSERS { + parsers.StageFiles = append(parsers.StageFiles, stagefile) + } + if itemType == cwhub.POSTOVERFLOWS { + parsers.PovfwStageFiles = append(parsers.PovfwStageFiles, stagefile) } } } + if parsers.StageFiles != nil { sort.Slice(parsers.StageFiles, func(i, j int) bool { return parsers.StageFiles[i].Filename < parsers.StageFiles[j].Filename @@ -101,13 +100,17 @@ func LoadParsers(cConfig *csconfig.Config, parsers *Parsers) (*Parsers, error) { patternsDir := cConfig.ConfigPaths.PatternDir log.Infof("Loading grok library %s", patternsDir) /* load base regexps for two grok parsers */ - parsers.Ctx, err = Init(map[string]interface{}{"patterns": patternsDir, - "data": cConfig.ConfigPaths.DataDir}) + parsers.Ctx, err = Init(map[string]interface{}{ + "patterns": patternsDir, + "data": cConfig.ConfigPaths.DataDir, + }) if err != nil { return parsers, fmt.Errorf("failed to load parser patterns : %v", err) } - parsers.Povfwctx, err = Init(map[string]interface{}{"patterns": patternsDir, - "data": cConfig.ConfigPaths.DataDir}) + parsers.Povfwctx, err = Init(map[string]interface{}{ + "patterns": patternsDir, + "data": cConfig.ConfigPaths.DataDir, + }) if err != nil { return parsers, fmt.Errorf("failed to load postovflw parser patterns : %v", err) } diff --git a/pkg/protobufs/generate.go b/pkg/protobufs/generate.go new file mode 100644 index 00000000000..0e90d65b643 --- /dev/null +++ b/pkg/protobufs/generate.go @@ -0,0 +1,14 @@ +package protobufs + +// Dependencies: +// +// apt install protobuf-compiler +// +// keep this in sync with go.mod +// go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 +// +// Not the same versions as google.golang.org/grpc +// go list -m -versions google.golang.org/grpc/cmd/protoc-gen-go-grpc +// go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1 + +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative notifier.proto diff --git a/pkg/protobufs/notifier.pb.go b/pkg/protobufs/notifier.pb.go index b5dc8113568..8c4754da773 100644 --- a/pkg/protobufs/notifier.pb.go +++ b/pkg/protobufs/notifier.pb.go @@ -1,16 +1,12 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 -// protoc v3.12.4 +// protoc-gen-go v1.34.2 +// protoc v3.21.12 // source: notifier.proto package protobufs import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -198,7 +194,7 @@ func file_notifier_proto_rawDescGZIP() []byte { } var file_notifier_proto_msgTypes = make([]protoimpl.MessageInfo, 3) -var file_notifier_proto_goTypes = []interface{}{ +var file_notifier_proto_goTypes = []any{ (*Notification)(nil), // 0: proto.Notification (*Config)(nil), // 1: proto.Config (*Empty)(nil), // 2: proto.Empty @@ -221,7 +217,7 @@ func file_notifier_proto_init() { return } if !protoimpl.UnsafeEnabled { - file_notifier_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_notifier_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*Notification); i { case 0: return &v.state @@ -233,7 +229,7 @@ func file_notifier_proto_init() { return nil } } - file_notifier_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_notifier_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*Config); i { case 0: return &v.state @@ -245,7 +241,7 @@ func file_notifier_proto_init() { return nil } } - file_notifier_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + file_notifier_proto_msgTypes[2].Exporter = func(v any, i int) any { switch v := v.(*Empty); i { case 0: return &v.state @@ -277,119 +273,3 @@ func file_notifier_proto_init() { file_notifier_proto_goTypes = nil file_notifier_proto_depIdxs = nil } - -// Reference imports to suppress errors if they are not otherwise used. -var _ context.Context -var _ grpc.ClientConnInterface - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -const _ = grpc.SupportPackageIsVersion6 - -// NotifierClient is the client API for Notifier service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. -type NotifierClient interface { - Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) - Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) -} - -type notifierClient struct { - cc grpc.ClientConnInterface -} - -func NewNotifierClient(cc grpc.ClientConnInterface) NotifierClient { - return ¬ifierClient{cc} -} - -func (c *notifierClient) Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) { - out := new(Empty) - err := c.cc.Invoke(ctx, "/proto.Notifier/Notify", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *notifierClient) Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) { - out := new(Empty) - err := c.cc.Invoke(ctx, "/proto.Notifier/Configure", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// NotifierServer is the server API for Notifier service. -type NotifierServer interface { - Notify(context.Context, *Notification) (*Empty, error) - Configure(context.Context, *Config) (*Empty, error) -} - -// UnimplementedNotifierServer can be embedded to have forward compatible implementations. -type UnimplementedNotifierServer struct { -} - -func (*UnimplementedNotifierServer) Notify(context.Context, *Notification) (*Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "method Notify not implemented") -} -func (*UnimplementedNotifierServer) Configure(context.Context, *Config) (*Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "method Configure not implemented") -} - -func RegisterNotifierServer(s *grpc.Server, srv NotifierServer) { - s.RegisterService(&_Notifier_serviceDesc, srv) -} - -func _Notifier_Notify_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Notification) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(NotifierServer).Notify(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/proto.Notifier/Notify", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NotifierServer).Notify(ctx, req.(*Notification)) - } - return interceptor(ctx, in, info, handler) -} - -func _Notifier_Configure_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(Config) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(NotifierServer).Configure(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/proto.Notifier/Configure", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(NotifierServer).Configure(ctx, req.(*Config)) - } - return interceptor(ctx, in, info, handler) -} - -var _Notifier_serviceDesc = grpc.ServiceDesc{ - ServiceName: "proto.Notifier", - HandlerType: (*NotifierServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "Notify", - Handler: _Notifier_Notify_Handler, - }, - { - MethodName: "Configure", - Handler: _Notifier_Configure_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "notifier.proto", -} diff --git a/pkg/protobufs/notifier_grpc.pb.go b/pkg/protobufs/notifier_grpc.pb.go new file mode 100644 index 00000000000..5141e83f98b --- /dev/null +++ b/pkg/protobufs/notifier_grpc.pb.go @@ -0,0 +1,159 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v3.21.12 +// source: notifier.proto + +package protobufs + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + Notifier_Notify_FullMethodName = "/proto.Notifier/Notify" + Notifier_Configure_FullMethodName = "/proto.Notifier/Configure" +) + +// NotifierClient is the client API for Notifier service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type NotifierClient interface { + Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) + Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) +} + +type notifierClient struct { + cc grpc.ClientConnInterface +} + +func NewNotifierClient(cc grpc.ClientConnInterface) NotifierClient { + return ¬ifierClient{cc} +} + +func (c *notifierClient) Notify(ctx context.Context, in *Notification, opts ...grpc.CallOption) (*Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(Empty) + err := c.cc.Invoke(ctx, Notifier_Notify_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *notifierClient) Configure(ctx context.Context, in *Config, opts ...grpc.CallOption) (*Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(Empty) + err := c.cc.Invoke(ctx, Notifier_Configure_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// NotifierServer is the server API for Notifier service. +// All implementations must embed UnimplementedNotifierServer +// for forward compatibility. +type NotifierServer interface { + Notify(context.Context, *Notification) (*Empty, error) + Configure(context.Context, *Config) (*Empty, error) + mustEmbedUnimplementedNotifierServer() +} + +// UnimplementedNotifierServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedNotifierServer struct{} + +func (UnimplementedNotifierServer) Notify(context.Context, *Notification) (*Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Notify not implemented") +} +func (UnimplementedNotifierServer) Configure(context.Context, *Config) (*Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Configure not implemented") +} +func (UnimplementedNotifierServer) mustEmbedUnimplementedNotifierServer() {} +func (UnimplementedNotifierServer) testEmbeddedByValue() {} + +// UnsafeNotifierServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to NotifierServer will +// result in compilation errors. +type UnsafeNotifierServer interface { + mustEmbedUnimplementedNotifierServer() +} + +func RegisterNotifierServer(s grpc.ServiceRegistrar, srv NotifierServer) { + // If the following call pancis, it indicates UnimplementedNotifierServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&Notifier_ServiceDesc, srv) +} + +func _Notifier_Notify_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Notification) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(NotifierServer).Notify(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Notifier_Notify_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(NotifierServer).Notify(ctx, req.(*Notification)) + } + return interceptor(ctx, in, info, handler) +} + +func _Notifier_Configure_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Config) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(NotifierServer).Configure(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Notifier_Configure_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(NotifierServer).Configure(ctx, req.(*Config)) + } + return interceptor(ctx, in, info, handler) +} + +// Notifier_ServiceDesc is the grpc.ServiceDesc for Notifier service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Notifier_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "proto.Notifier", + HandlerType: (*NotifierServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Notify", + Handler: _Notifier_Notify_Handler, + }, + { + MethodName: "Configure", + Handler: _Notifier_Configure_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "notifier.proto", +} diff --git a/pkg/protobufs/plugin_interface.go b/pkg/protobufs/plugin_interface.go deleted file mode 100644 index fc89b2fa009..00000000000 --- a/pkg/protobufs/plugin_interface.go +++ /dev/null @@ -1,47 +0,0 @@ -package protobufs - -import ( - "context" - - plugin "github.com/hashicorp/go-plugin" - "google.golang.org/grpc" -) - -type Notifier interface { - Notify(ctx context.Context, notification *Notification) (*Empty, error) - Configure(ctx context.Context, config *Config) (*Empty, error) -} - -// This is the implementation of plugin.NotifierPlugin so we can serve/consume this. -type NotifierPlugin struct { - // GRPCPlugin must still implement the Plugin interface - plugin.Plugin - // Concrete implementation, written in Go. This is only used for plugins - // that are written in Go. - Impl Notifier -} - -type GRPCClient struct{ client NotifierClient } - -func (m *GRPCClient) Notify(ctx context.Context, notification *Notification) (*Empty, error) { - _, err := m.client.Notify(context.Background(), notification) - return &Empty{}, err -} - -func (m *GRPCClient) Configure(ctx context.Context, config *Config) (*Empty, error) { - _, err := m.client.Configure(context.Background(), config) - return &Empty{}, err -} - -type GRPCServer struct { - Impl Notifier -} - -func (p *NotifierPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { - RegisterNotifierServer(s, p.Impl) - return nil -} - -func (p *NotifierPlugin) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { - return &GRPCClient{client: NewNotifierClient(c)}, nil -} diff --git a/pkg/setup/detect.go b/pkg/setup/detect.go index 55af951bf89..073b221b10c 100644 --- a/pkg/setup/detect.go +++ b/pkg/setup/detect.go @@ -73,9 +73,9 @@ func validateDataSource(opaqueDS DataSourceItem) error { // source must be known - ds := acquisition.GetDataSourceIface(commonDS.Source) - if ds == nil { - return fmt.Errorf("unknown source '%s'", commonDS.Source) + ds, err := acquisition.GetDataSourceIface(commonDS.Source) + if err != nil { + return err } // unmarshal and validate the rest with the specific implementation @@ -545,7 +545,7 @@ func Detect(detectReader io.Reader, opts DetectOptions) (Setup, error) { // } // err = yaml.Unmarshal(svc.AcquisYAML, svc.DataSource) // if err != nil { - // return Setup{}, fmt.Errorf("while unmarshaling datasource for service %s: %w", name, err) + // return Setup{}, fmt.Errorf("while parsing datasource for service %s: %w", name, err) // } // } diff --git a/pkg/setup/detect_test.go b/pkg/setup/detect_test.go index c744e7d6796..588e74dab54 100644 --- a/pkg/setup/detect_test.go +++ b/pkg/setup/detect_test.go @@ -184,7 +184,6 @@ func TestNormalizeVersion(t *testing.T) { } for _, tc := range tests { - tc := tc t.Run(tc.version, func(t *testing.T) { t.Parallel() actual := setup.NormalizeVersion(tc.version) @@ -871,7 +870,7 @@ func TestDetectDatasourceValidation(t *testing.T) { datasource: source: wombat`, expected: setup.Setup{Setup: []setup.ServiceSetup{}}, - expectedErr: "invalid datasource for foobar: unknown source 'wombat'", + expectedErr: "invalid datasource for foobar: unknown data source wombat", }, { name: "source is misplaced", config: ` diff --git a/pkg/setup/install.go b/pkg/setup/install.go index fc5bd380fd9..d63a1ee1775 100644 --- a/pkg/setup/install.go +++ b/pkg/setup/install.go @@ -40,7 +40,7 @@ func decodeSetup(input []byte, fancyErrors bool) (Setup, error) { dec2.KnownFields(true) if err := dec2.Decode(&ret); err != nil { - return ret, fmt.Errorf("while unmarshaling setup file: %w", err) + return ret, fmt.Errorf("while parsing setup file: %w", err) } return ret, nil diff --git a/pkg/setup/units.go b/pkg/setup/units.go index ab1eec6f33e..861513d3f1d 100644 --- a/pkg/setup/units.go +++ b/pkg/setup/units.go @@ -35,7 +35,7 @@ func systemdUnitList() ([]string, error) { for scanner.Scan() { line := scanner.Text() - if len(line) == 0 { + if line == "" { break // the rest of the output is footer } diff --git a/pkg/types/appsec_event.go b/pkg/types/appsec_event.go index dc81c63b344..11d70ad368d 100644 --- a/pkg/types/appsec_event.go +++ b/pkg/types/appsec_event.go @@ -18,7 +18,9 @@ len(evt.Waf.ByTagRx("*CVE*").ByConfidence("high").ByAction("block")) > 1 */ -type MatchedRules []map[string]interface{} +type MatchedRules []MatchedRule + +type MatchedRule map[string]interface{} type AppsecEvent struct { HasInBandMatches, HasOutBandMatches bool @@ -45,6 +47,10 @@ const ( Kind Field = "kind" ) +func NewMatchedRule() *MatchedRule { + return &MatchedRule{} +} + func (w AppsecEvent) GetVar(varName string) string { if w.Vars == nil { return "" diff --git a/pkg/types/event.go b/pkg/types/event.go index 76a447bdc8c..9300626b927 100644 --- a/pkg/types/event.go +++ b/pkg/types/event.go @@ -2,6 +2,7 @@ package types import ( "net" + "strings" "time" "github.com/expr-lang/expr/vm" @@ -19,11 +20,11 @@ const ( // Event is the structure representing a runtime event (log or overflow) type Event struct { /* is it a log or an overflow */ - Type int `yaml:"Type,omitempty" json:"Type,omitempty"` //Can be types.LOG (0) or types.OVFLOW (1) - ExpectMode int `yaml:"ExpectMode,omitempty" json:"ExpectMode,omitempty"` //how to buckets should handle event : types.TIMEMACHINE or types.LIVE + Type int `yaml:"Type,omitempty" json:"Type,omitempty"` // Can be types.LOG (0) or types.OVFLOW (1) + ExpectMode int `yaml:"ExpectMode,omitempty" json:"ExpectMode,omitempty"` // how to buckets should handle event : types.TIMEMACHINE or types.LIVE Whitelisted bool `yaml:"Whitelisted,omitempty" json:"Whitelisted,omitempty"` WhitelistReason string `yaml:"WhitelistReason,omitempty" json:"whitelist_reason,omitempty"` - //should add whitelist reason ? + // should add whitelist reason ? /* the current stage of the line being parsed */ Stage string `yaml:"Stage,omitempty" json:"Stage,omitempty"` /* original line (produced by acquisition) */ @@ -36,21 +37,39 @@ type Event struct { Unmarshaled map[string]interface{} `yaml:"Unmarshaled,omitempty" json:"Unmarshaled,omitempty"` /* Overflow */ Overflow RuntimeAlert `yaml:"Overflow,omitempty" json:"Alert,omitempty"` - Time time.Time `yaml:"Time,omitempty" json:"Time,omitempty"` //parsed time `json:"-"` `` + Time time.Time `yaml:"Time,omitempty" json:"Time,omitempty"` // parsed time `json:"-"` `` StrTime string `yaml:"StrTime,omitempty" json:"StrTime,omitempty"` StrTimeFormat string `yaml:"StrTimeFormat,omitempty" json:"StrTimeFormat,omitempty"` MarshaledTime string `yaml:"MarshaledTime,omitempty" json:"MarshaledTime,omitempty"` - Process bool `yaml:"Process,omitempty" json:"Process,omitempty"` //can be set to false to avoid processing line + Process bool `yaml:"Process,omitempty" json:"Process,omitempty"` // can be set to false to avoid processing line Appsec AppsecEvent `yaml:"Appsec,omitempty" json:"Appsec,omitempty"` /* Meta is the only part that will make it to the API - it should be normalized */ Meta map[string]string `yaml:"Meta,omitempty" json:"Meta,omitempty"` } +func MakeEvent(timeMachine bool, evtType int, process bool) Event { + evt := Event{ + Parsed: make(map[string]string), + Meta: make(map[string]string), + Unmarshaled: make(map[string]interface{}), + Enriched: make(map[string]string), + ExpectMode: LIVE, + Process: process, + Type: evtType, + } + if timeMachine { + evt.ExpectMode = TIMEMACHINE + } + return evt +} + func (e *Event) SetMeta(key string, value string) bool { if e.Meta == nil { e.Meta = make(map[string]string) } + e.Meta[key] = value + return true } @@ -58,7 +77,9 @@ func (e *Event) SetParsed(key string, value string) bool { if e.Parsed == nil { e.Parsed = make(map[string]string) } + e.Parsed[key] = value + return true } @@ -90,11 +111,13 @@ func (e *Event) GetMeta(key string) string { } } } + return "" } func (e *Event) ParseIPSources() []net.IP { var srcs []net.IP + switch e.Type { case LOG: if _, ok := e.Meta["source_ip"]; ok { @@ -105,6 +128,7 @@ func (e *Event) ParseIPSources() []net.IP { srcs = append(srcs, net.ParseIP(k)) } } + return srcs } @@ -131,8 +155,8 @@ type RuntimeAlert struct { Whitelisted bool `yaml:"Whitelisted,omitempty" json:"Whitelisted,omitempty"` Reprocess bool `yaml:"Reprocess,omitempty" json:"Reprocess,omitempty"` Sources map[string]models.Source `yaml:"Sources,omitempty" json:"Sources,omitempty"` - Alert *models.Alert `yaml:"Alert,omitempty" json:"Alert,omitempty"` //this one is a pointer to APIAlerts[0] for convenience. - //APIAlerts will be populated at the end when there is more than one source + Alert *models.Alert `yaml:"Alert,omitempty" json:"Alert,omitempty"` // this one is a pointer to APIAlerts[0] for convenience. + // APIAlerts will be populated at the end when there is more than one source APIAlerts []models.Alert `yaml:"APIAlerts,omitempty" json:"APIAlerts,omitempty"` } @@ -141,5 +165,21 @@ func (r RuntimeAlert) GetSources() []string { for key := range r.Sources { ret = append(ret, key) } + return ret } + +func NormalizeScope(scope string) string { + switch strings.ToLower(scope) { + case "ip": + return Ip + case "range": + return Range + case "as": + return AS + case "country": + return Country + default: + return scope + } +} diff --git a/rpm/SPECS/crowdsec.spec b/rpm/SPECS/crowdsec.spec index ab71b650d11..ac438ad0c14 100644 --- a/rpm/SPECS/crowdsec.spec +++ b/rpm/SPECS/crowdsec.spec @@ -12,7 +12,7 @@ Patch0: user.patch BuildRoot: %{_tmppath}/%{name}-%{version}-%{release}-root-%(%{__id_u} -n) BuildRequires: systemd -Requires: crontabs +Requires: (crontabs or cron) %{?fc33:BuildRequires: systemd-rpm-macros} %{?fc34:BuildRequires: systemd-rpm-macros} %{?fc35:BuildRequires: systemd-rpm-macros} diff --git a/test/ansible/vagrant/fedora-40/Vagrantfile b/test/ansible/vagrant/fedora-40/Vagrantfile index ec03661fe39..5541d453acf 100644 --- a/test/ansible/vagrant/fedora-40/Vagrantfile +++ b/test/ansible/vagrant/fedora-40/Vagrantfile @@ -1,7 +1,7 @@ # frozen_string_literal: true Vagrant.configure('2') do |config| - config.vm.box = "fedora/39-cloud-base" + config.vm.box = "fedora/40-cloud-base" config.vm.provision "shell", inline: <<-SHELL SHELL end diff --git a/test/ansible/vagrant/fedora-41/Vagrantfile b/test/ansible/vagrant/fedora-41/Vagrantfile new file mode 100644 index 00000000000..3f905f51671 --- /dev/null +++ b/test/ansible/vagrant/fedora-41/Vagrantfile @@ -0,0 +1,13 @@ +# frozen_string_literal: true + +Vagrant.configure('2') do |config| + config.vm.box = "fedora/40-cloud-base" + config.vm.provision "shell", inline: <<-SHELL + SHELL + config.vm.provision "shell" do |s| + s.inline = "sudo dnf upgrade --refresh -y && sudo dnf install dnf-plugin-system-upgrade -y && sudo dnf system-upgrade download --releasever=41 -y && sudo dnf system-upgrade reboot -y" + end +end + +common = '../common' +load common if File.exist?(common) diff --git a/test/ansible/vagrant/fedora-41/skip b/test/ansible/vagrant/fedora-41/skip new file mode 100644 index 00000000000..4f1a9063d2b --- /dev/null +++ b/test/ansible/vagrant/fedora-41/skip @@ -0,0 +1,9 @@ +#!/bin/sh + +die() { + echo "$@" >&2 + exit 1 +} + +[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution" +exit 0 diff --git a/test/ansible/vagrant/opensuse-leap-15/Vagrantfile b/test/ansible/vagrant/opensuse-leap-15/Vagrantfile new file mode 100644 index 00000000000..d10e68a50a7 --- /dev/null +++ b/test/ansible/vagrant/opensuse-leap-15/Vagrantfile @@ -0,0 +1,10 @@ +# frozen_string_literal: true + +Vagrant.configure('2') do |config| + config.vm.box = "opensuse/Leap-15.6.x86_64" + config.vm.provision "shell", inline: <<-SHELL + SHELL +end + +common = '../common' +load common if File.exist?(common) diff --git a/test/ansible/vagrant/opensuse-leap-15/skip b/test/ansible/vagrant/opensuse-leap-15/skip new file mode 100644 index 00000000000..4f1a9063d2b --- /dev/null +++ b/test/ansible/vagrant/opensuse-leap-15/skip @@ -0,0 +1,9 @@ +#!/bin/sh + +die() { + echo "$@" >&2 + exit 1 +} + +[ "${DB_BACKEND}" = "mysql" ] && die "mysql role does not support this distribution" +exit 0 diff --git a/test/bats.mk b/test/bats.mk index 8f507cb659b..72ac8863f72 100644 --- a/test/bats.mk +++ b/test/bats.mk @@ -38,6 +38,7 @@ define ENV := export TEST_DIR="$(TEST_DIR)" export LOCAL_DIR="$(LOCAL_DIR)" export BIN_DIR="$(BIN_DIR)" +# append .min to the binary names to use the minimal profile export CROWDSEC="$(CROWDSEC)" export CSCLI="$(CSCLI)" export CONFIG_YAML="$(CONFIG_DIR)/config.yaml" @@ -66,8 +67,8 @@ bats-check-requirements: ## Check dependencies for functional tests @$(TEST_DIR)/bin/check-requirements bats-update-tools: ## Install/update tools required for functional tests - # yq v4.43.1 - GOBIN=$(TEST_DIR)/tools go install github.com/mikefarah/yq/v4@c35ec752e38ea0c096d3c44e13cfc0797ac394d8 + # yq v4.44.3 + GOBIN=$(TEST_DIR)/tools go install github.com/mikefarah/yq/v4@bbdd97482f2d439126582a59689eb1c855944955 # cfssl v1.6.5 GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssl@96259aa29c9cc9b2f4e04bad7d4bc152e5405dda GOBIN=$(TEST_DIR)/tools go install github.com/cloudflare/cfssl/cmd/cfssljson@96259aa29c9cc9b2f4e04bad7d4bc152e5405dda @@ -75,6 +76,11 @@ bats-update-tools: ## Install/update tools required for functional tests # Build and installs crowdsec in a local directory. Rebuilds if already exists. bats-build: bats-environment ## Build binaries for functional tests @$(MKDIR) $(BIN_DIR) $(LOG_DIR) $(PID_DIR) $(BATS_PLUGIN_DIR) + # minimal profile + @$(MAKE) build DEBUG=1 TEST_COVERAGE=$(TEST_COVERAGE) DEFAULT_CONFIGDIR=$(CONFIG_DIR) DEFAULT_DATADIR=$(DATA_DIR) BUILD_PROFILE=minimal + @install -m 0755 cmd/crowdsec/crowdsec $(BIN_DIR)/crowdsec.min + @install -m 0755 cmd/crowdsec-cli/cscli $(BIN_DIR)/cscli.min + # default profile @$(MAKE) build DEBUG=1 TEST_COVERAGE=$(TEST_COVERAGE) DEFAULT_CONFIGDIR=$(CONFIG_DIR) DEFAULT_DATADIR=$(DATA_DIR) @install -m 0755 cmd/crowdsec/crowdsec cmd/crowdsec-cli/cscli $(BIN_DIR)/ @install -m 0755 cmd/notification-*/notification-* $(BATS_PLUGIN_DIR)/ diff --git a/test/bats/01_crowdsec.bats b/test/bats/01_crowdsec.bats index 83072b0f159..aa5830a6bae 100644 --- a/test/bats/01_crowdsec.bats +++ b/test/bats/01_crowdsec.bats @@ -199,7 +199,42 @@ teardown() { assert_stderr --partial "crowdsec init: while loading acquisition config: no datasource enabled" } -@test "crowdsec (disabled datasources)" { +@test "crowdsec (datasource not built)" { + config_set '.common.log_media="stdout"' + + # a datasource cannot run - it's not built in the log processor executable + + ACQUIS_DIR=$(config_get '.crowdsec_service.acquisition_dir') + mkdir -p "$ACQUIS_DIR" + cat >"$ACQUIS_DIR"/foo.yaml <<-EOT + source: journalctl + journalctl_filter: + - "_SYSTEMD_UNIT=ssh.service" + labels: + type: syslog + EOT + + #shellcheck disable=SC2016 + rune -1 wait-for \ + --err "crowdsec init: while loading acquisition config: in file $ACQUIS_DIR/foo.yaml (position: 0) - data source journalctl is not built in this version of crowdsec" \ + env PATH='' "$CROWDSEC".min + + # auto-detection of journalctl_filter still works + cat >"$ACQUIS_DIR"/foo.yaml <<-EOT + source: whatever + journalctl_filter: + - "_SYSTEMD_UNIT=ssh.service" + labels: + type: syslog + EOT + + #shellcheck disable=SC2016 + rune -1 wait-for \ + --err "crowdsec init: while loading acquisition config: in file $ACQUIS_DIR/foo.yaml (position: 0) - data source journalctl is not built in this version of crowdsec" \ + env PATH='' "$CROWDSEC".min +} + +@test "crowdsec (disabled datasource)" { if is_package_testing; then # we can't hide journalctl in package testing # because crowdsec is run from systemd diff --git a/test/bats/01_cscli_lapi.bats b/test/bats/01_cscli_lapi.bats index a503dfff8cf..6e876576a6e 100644 --- a/test/bats/01_cscli_lapi.bats +++ b/test/bats/01_cscli_lapi.bats @@ -29,9 +29,9 @@ teardown() { rune -0 ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --partial "Loaded credentials from" - assert_stderr --partial "Trying to authenticate with username" - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "Loaded credentials from" + assert_output --partial "Trying to authenticate with username" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "cscli - missing LAPI credentials file" { @@ -76,7 +76,7 @@ teardown() { rune -0 ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "You can successfully interact with Local API (LAPI)" rm "$LOCAL_API_CREDENTIALS".local @@ -88,7 +88,7 @@ teardown() { config_set "$LOCAL_API_CREDENTIALS" '.password="$PASSWORD"' rune -0 cscli lapi status - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "You can successfully interact with Local API (LAPI)" # but if a variable is not defined, there is no specific error message unset URL @@ -115,7 +115,7 @@ teardown() { rune -1 cscli lapi status -o json rune -0 jq -r '.msg' <(stderr) - assert_output 'failed to authenticate to Local API (LAPI): parsing api url: parse "http://127.0.0.1:-80/": invalid port ":-80" after host' + assert_output 'failed to authenticate to Local API (LAPI): parse "http://127.0.0.1:-80/": invalid port ":-80" after host' } @test "cscli - bad LAPI password" { diff --git a/test/bats/03_noagent.bats b/test/bats/03_noagent.bats index 60731b90713..6be5101cee2 100644 --- a/test/bats/03_noagent.bats +++ b/test/bats/03_noagent.bats @@ -76,7 +76,7 @@ teardown() { config_disable_agent ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "cscli metrics" { diff --git a/test/bats/04_capi.bats b/test/bats/04_capi.bats index d5154c1a0d7..7ba6bfa4428 100644 --- a/test/bats/04_capi.bats +++ b/test/bats/04_capi.bats @@ -46,19 +46,32 @@ setup() { assert_stderr --regexp "no configuration for Central API \(CAPI\) in '$(echo $CONFIG_YAML|sed s#//#/#g)'" } -@test "cscli capi status" { +@test "cscli {capi,papi} status" { ./instance-data load config_enable_capi + + # should not panic with no credentials, but return an error + rune -1 cscli papi status + assert_stderr --partial "the Central API (CAPI) must be configured with 'cscli capi register'" + rune -0 cscli capi register --schmilblick githubciXXXXXXXXXXXXXXXXXXXXXXXX rune -1 cscli capi status - assert_stderr --partial "no scenarios installed, abort" + assert_stderr --partial "no scenarios or appsec-rules installed, abort" + + rune -1 cscli papi status + assert_stderr --partial "no PAPI URL in configuration" + + rune -0 cscli console enable console_management + rune -1 cscli papi status + assert_stderr --partial "unable to get PAPI permissions" + assert_stderr --partial "Forbidden for plan" rune -0 cscli scenarios install crowdsecurity/ssh-bf rune -0 cscli capi status - assert_stderr --partial "Loaded credentials from" - assert_stderr --partial "Trying to authenticate with username" - assert_stderr --partial " on https://api.crowdsec.net/" - assert_stderr --partial "You can successfully interact with Central API (CAPI)" + assert_output --partial "Loaded credentials from" + assert_output --partial "Trying to authenticate with username" + assert_output --partial " on https://api.crowdsec.net/" + assert_output --partial "You can successfully interact with Central API (CAPI)" } @test "cscli alerts list: receive a community pull when capi is enabled" { @@ -85,7 +98,7 @@ setup() { config_disable_agent ./instance-crowdsec start rune -0 cscli capi status - assert_stderr --partial "You can successfully interact with Central API (CAPI)" + assert_output --partial "You can successfully interact with Central API (CAPI)" } @test "capi register must be run from lapi" { diff --git a/test/bats/04_nocapi.bats b/test/bats/04_nocapi.bats index c02a75810b9..d22a6f0a953 100644 --- a/test/bats/04_nocapi.bats +++ b/test/bats/04_nocapi.bats @@ -66,7 +66,7 @@ teardown() { config_disable_capi ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "cscli metrics" { diff --git a/test/bats/07_setup.bats b/test/bats/07_setup.bats index 2106d3ab6b2..f832ac572d2 100644 --- a/test/bats/07_setup.bats +++ b/test/bats/07_setup.bats @@ -819,6 +819,6 @@ update-notifier-motd.timer enabled enabled setup: alsdk al; sdf EOT - assert_output "while unmarshaling setup file: yaml: line 2: could not find expected ':'" + assert_output "while parsing setup file: yaml: line 2: could not find expected ':'" assert_stderr --partial "invalid setup file" } diff --git a/test/bats/09_socket.bats b/test/bats/09_socket.bats index f770abaad2e..f861d8a40dc 100644 --- a/test/bats/09_socket.bats +++ b/test/bats/09_socket.bats @@ -37,22 +37,22 @@ teardown() { ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --regexp "Trying to authenticate with username .* on $socket" - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --regexp "Trying to authenticate with username .* on $socket" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "crowdsec - listen on both socket and TCP" { ./instance-crowdsec start rune -0 cscli lapi status - assert_stderr --regexp "Trying to authenticate with username .* on http://127.0.0.1:8080/" - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --regexp "Trying to authenticate with username .* on http://127.0.0.1:8080/" + assert_output --partial "You can successfully interact with Local API (LAPI)" config_set "$LOCAL_API_CREDENTIALS" ".url=strenv(socket)" rune -0 cscli lapi status - assert_stderr --regexp "Trying to authenticate with username .* on $socket" - assert_stderr --partial "You can successfully interact with Local API (LAPI)" + assert_output --regexp "Trying to authenticate with username .* on $socket" + assert_output --partial "You can successfully interact with Local API (LAPI)" } @test "cscli - authenticate new machine with socket" { diff --git a/test/bats/10_bouncers.bats b/test/bats/10_bouncers.bats index f99913dcee5..b1c90116dd2 100644 --- a/test/bats/10_bouncers.bats +++ b/test/bats/10_bouncers.bats @@ -63,7 +63,7 @@ teardown() { @test "delete non-existent bouncer" { # this is a fatal error, which is not consistent with "machines delete" rune -1 cscli bouncers delete something - assert_stderr --partial "unable to delete bouncer: 'something' does not exist" + assert_stderr --partial "unable to delete bouncer something: ent: bouncer not found" rune -0 cscli bouncers delete something --ignore-missing refute_stderr } @@ -144,3 +144,56 @@ teardown() { rune -0 cscli bouncers prune assert_output 'No bouncers to prune.' } + +curl_localhost() { + [[ -z "$API_KEY" ]] && { fail "${FUNCNAME[0]}: missing API_KEY"; } + local path=$1 + shift + curl "localhost:8080$path" -sS --fail-with-body -H "X-Api-Key: $API_KEY" "$@" +} + +# We can't use curl-with-key here, as we want to query localhost, not 127.0.0.1 +@test "multiple bouncers sharing api key" { + export API_KEY=bouncerkey + + # crowdsec needs to listen on all interfaces + rune -0 ./instance-crowdsec stop + rune -0 config_set 'del(.api.server.listen_socket) | del(.api.server.listen_uri)' + echo "{'api':{'server':{'listen_uri':0.0.0.0:8080}}}" >"${CONFIG_YAML}.local" + + rune -0 ./instance-crowdsec start + + # add a decision for our bouncers + rune -0 cscli decisions add -i '1.2.3.5' + + rune -0 cscli bouncers add test-auto -k "$API_KEY" + + # query with 127.0.0.1 as source ip + rune -0 curl_localhost "/v1/decisions/stream" -4 + rune -0 jq -r '.new' <(output) + assert_output --partial '1.2.3.5' + + # now with ::1, we should get the same IP, even though we are using the same key + rune -0 curl_localhost "/v1/decisions/stream" -6 + rune -0 jq -r '.new' <(output) + assert_output --partial '1.2.3.5' + + rune -0 cscli bouncers list -o json + rune -0 jq -c '[.[] | [.name,.revoked,.ip_address,.auto_created]]' <(output) + assert_json '[["test-auto",false,"127.0.0.1",false],["test-auto@::1",false,"::1",true]]' + + # check the 2nd bouncer was created automatically + rune -0 cscli bouncers inspect "test-auto@::1" -o json + rune -0 jq -r '.ip_address' <(output) + assert_output --partial '::1' + + # attempt to delete the auto-created bouncer, it should fail + rune -0 cscli bouncers delete 'test-auto@::1' + assert_stderr --partial 'cannot be deleted' + + # delete the "real" bouncer, it should delete both + rune -0 cscli bouncers delete 'test-auto' + + rune -0 cscli bouncers list -o json + assert_json [] +} diff --git a/test/bats/20_hub_items.bats b/test/bats/20_hub_items.bats index 214d07d927f..4b390c90ed4 100644 --- a/test/bats/20_hub_items.bats +++ b/test/bats/20_hub_items.bats @@ -176,7 +176,7 @@ teardown() { rune -0 mkdir -p "$CONFIG_DIR/collections" rune -0 ln -s /this/does/not/exist.yaml "$CONFIG_DIR/collections/foobar.yaml" rune -0 cscli hub list - assert_stderr --partial "link target does not exist: $CONFIG_DIR/collections/foobar.yaml -> /this/does/not/exist.yaml" + assert_stderr --partial "Ignoring file $CONFIG_DIR/collections/foobar.yaml: lstat /this/does/not/exist.yaml: no such file or directory" rune -0 cscli hub list -o json rune -0 jq '.collections' <(output) assert_json '[]' @@ -194,9 +194,89 @@ teardown() { assert_output 'false' } -@test "skip files if we can't guess their type" { - rune -0 mkdir -p "$CONFIG_DIR/scenarios/foo" - rune -0 touch "$CONFIG_DIR/scenarios/foo/bar.yaml" - rune -0 cscli hub list - assert_stderr --partial "Ignoring file $CONFIG_DIR/scenarios/foo/bar.yaml: unknown configuration type" +@test "don't traverse hidden directories (starting with a dot)" { + rune -0 mkdir -p "$CONFIG_DIR/scenarios/.foo" + rune -0 touch "$CONFIG_DIR/scenarios/.foo/bar.yaml" + rune -0 cscli hub list --trace + assert_stderr --partial "skipping hidden directory $CONFIG_DIR/scenarios/.foo" +} + +@test "allow symlink to target inside a hidden directory" { + # k8s config maps use hidden directories and links when mounted + rune -0 mkdir -p "$CONFIG_DIR/scenarios/.foo" + + # ignored + rune -0 touch "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # real file + rune -0 touch "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 1 + + rune -0 rm "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # link to ignored is not ignored, and the name comes from the link + rune -0 ln -s "$CONFIG_DIR/scenarios/.foo/hidden.yaml" "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq -c '[.scenarios[].name] | sort' <(output) + assert_json '["myfoo.yaml"]' +} + +@test "item files can be links to links" { + rune -0 mkdir -p "$CONFIG_DIR"/scenarios/{.foo,.bar} + + rune -0 ln -s "$CONFIG_DIR/scenarios/.foo/hidden.yaml" "$CONFIG_DIR/scenarios/.bar/hidden.yaml" + + # link to a danling link + rune -0 ln -s "$CONFIG_DIR/scenarios/.bar/hidden.yaml" "$CONFIG_DIR/scenarios/myfoo.yaml" + rune -0 cscli scenarios list + assert_stderr --partial "Ignoring file $CONFIG_DIR/scenarios/myfoo.yaml: lstat $CONFIG_DIR/scenarios/.foo/hidden.yaml: no such file or directory" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 0 + + # detect link loops + rune -0 ln -s "$CONFIG_DIR/scenarios/.bar/hidden.yaml" "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list + assert_stderr --partial "Ignoring file $CONFIG_DIR/scenarios/myfoo.yaml: too many levels of symbolic links" + + rune -0 rm "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 touch "$CONFIG_DIR/scenarios/.foo/hidden.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq '.scenarios | length' <(output) + assert_output 1 +} + +@test "item files can be in a subdirectory" { + rune -0 mkdir -p "$CONFIG_DIR/scenarios/sub/sub2/sub3" + rune -0 touch "$CONFIG_DIR/scenarios/sub/imlocal.yaml" + # subdir name is now part of the item name + rune -0 cscli scenarios inspect sub/imlocal.yaml -o json + rune -0 jq -e '[.tainted,.local==false,true]' <(output) + rune -0 rm "$CONFIG_DIR/scenarios/sub/imlocal.yaml" + + rune -0 ln -s "$HUB_DIR/scenarios/crowdsecurity/smb-bf.yaml" "$CONFIG_DIR/scenarios/sub/smb-bf.yaml" + rune -0 cscli scenarios inspect crowdsecurity/smb-bf -o json + rune -0 jq -e '[.tainted,.local==false,false]' <(output) + rune -0 rm "$CONFIG_DIR/scenarios/sub/smb-bf.yaml" + + rune -0 ln -s "$HUB_DIR/scenarios/crowdsecurity/smb-bf.yaml" "$CONFIG_DIR/scenarios/sub/sub2/sub3/smb-bf.yaml" + rune -0 cscli scenarios inspect crowdsecurity/smb-bf -o json + rune -0 jq -e '[.tainted,.local==false,false]' <(output) +} + +@test "same file name for local items in different subdirectories" { + rune -0 mkdir -p "$CONFIG_DIR"/scenarios/{foo,bar} + rune -0 touch "$CONFIG_DIR/scenarios/foo/local.yaml" + rune -0 touch "$CONFIG_DIR/scenarios/bar/local.yaml" + rune -0 cscli scenarios list -o json + rune -0 jq -c '[.scenarios[].name] | sort' <(output) + assert_json '["bar/local.yaml","foo/local.yaml"]' } diff --git a/test/bats/90_decisions.bats b/test/bats/90_decisions.bats index c7ed214ffc9..8601414db48 100644 --- a/test/bats/90_decisions.bats +++ b/test/bats/90_decisions.bats @@ -78,13 +78,13 @@ teardown() { # invalid defaults rune -1 cscli decisions import --duration "" -i - <<<'value\n5.6.7.8' --format csv - assert_stderr --partial "--duration cannot be empty" + assert_stderr --partial "default duration cannot be empty" rune -1 cscli decisions import --scope "" -i - <<<'value\n5.6.7.8' --format csv - assert_stderr --partial "--scope cannot be empty" + assert_stderr --partial "default scope cannot be empty" rune -1 cscli decisions import --reason "" -i - <<<'value\n5.6.7.8' --format csv - assert_stderr --partial "--reason cannot be empty" + assert_stderr --partial "default reason cannot be empty" rune -1 cscli decisions import --type "" -i - <<<'value\n5.6.7.8' --format csv - assert_stderr --partial "--type cannot be empty" + assert_stderr --partial "default type cannot be empty" #---------- # JSON @@ -108,12 +108,12 @@ teardown() { # invalid json rune -1 cscli decisions import -i - <<<'{"blah":"blah"}' --format json assert_stderr --partial 'Parsing json' - assert_stderr --partial 'json: cannot unmarshal object into Go value of type []main.decisionRaw' + assert_stderr --partial 'json: cannot unmarshal object into Go value of type []clidecision.decisionRaw' # json with extra data rune -1 cscli decisions import -i - <<<'{"values":"1.2.3.4","blah":"blah"}' --format json assert_stderr --partial 'Parsing json' - assert_stderr --partial 'json: cannot unmarshal object into Go value of type []main.decisionRaw' + assert_stderr --partial 'json: cannot unmarshal object into Go value of type []clidecision.decisionRaw' #---------- # CSV diff --git a/test/instance-data b/test/instance-data index e4e76d3980a..e7fd05a9e54 100755 --- a/test/instance-data +++ b/test/instance-data @@ -1,16 +1,26 @@ #!/usr/bin/env bash +set -eu + +die() { + echo >&2 "$@" + exit 1 +} + #shellcheck disable=SC1007 THIS_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) cd "$THIS_DIR" || exit 1 # shellcheck disable=SC1091 . ./.environment.sh +if [[ -f "$LOCAL_INIT_DIR/.lock" ]] && [[ "$1" != "unlock" ]]; then + die "init data is locked: are you doing some manual test? if so, please finish what you are doing, run 'instance-data unlock' and retry" +fi + backend_script="./lib/config/config-${CONFIG_BACKEND}" if [[ ! -x "$backend_script" ]]; then - echo "unknown config backend '${CONFIG_BACKEND}'" >&2 - exit 1 + die "unknown config backend '${CONFIG_BACKEND}'" fi exec "$backend_script" "$@" diff --git a/test/lib/init/crowdsec-daemon b/test/lib/init/crowdsec-daemon index a232f344b6a..ba8e98992db 100755 --- a/test/lib/init/crowdsec-daemon +++ b/test/lib/init/crowdsec-daemon @@ -51,7 +51,11 @@ stop() { PGID="$(ps -o pgid= -p "$(cat "${DAEMON_PID}")" | tr -d ' ')" # ps above should work on linux, freebsd, busybox.. if [[ -n "${PGID}" ]]; then - kill -- "-${PGID}" + kill -- "-${PGID}" + + while pgrep -g "${PGID}" >/dev/null; do + sleep .05 + done fi rm -f -- "${DAEMON_PID}" diff --git a/test/run-tests b/test/run-tests index 6fe3bd004e2..957eb663b9c 100755 --- a/test/run-tests +++ b/test/run-tests @@ -10,12 +10,12 @@ die() { # shellcheck disable=SC1007 TEST_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) # shellcheck source=./.environment.sh -. "${TEST_DIR}/.environment.sh" +. "$TEST_DIR/.environment.sh" -"${TEST_DIR}/bin/check-requirements" +"$TEST_DIR/bin/check-requirements" echo "Running tests..." -echo "DB_BACKEND: ${DB_BACKEND}" +echo "DB_BACKEND: $DB_BACKEND" if [[ -z "$TEST_COVERAGE" ]]; then echo "Coverage report: no" else @@ -24,23 +24,23 @@ fi [[ -f "$LOCAL_INIT_DIR/.lock" ]] && die "init data is locked: are you doing some manual test? if so, please finish what you are doing, run 'instance-data unlock' and retry" -dump_backend="$(cat "${LOCAL_INIT_DIR}/.backend")" +dump_backend="$(cat "$LOCAL_INIT_DIR/.backend")" if [[ "$DB_BACKEND" != "$dump_backend" ]]; then - die "Can't run with backend '${DB_BACKEND}' because the test data was build with '${dump_backend}'" + die "Can't run with backend '$DB_BACKEND' because the test data was build with '$dump_backend'" fi if [[ $# -ge 1 ]]; then echo "test files: $*" - "${TEST_DIR}/lib/bats-core/bin/bats" \ + "$TEST_DIR/lib/bats-core/bin/bats" \ --jobs 1 \ --timing \ --print-output-on-failure \ "$@" else - echo "test files: ${TEST_DIR}/bats ${TEST_DIR}/dyn-bats" - "${TEST_DIR}/lib/bats-core/bin/bats" \ + echo "test files: $TEST_DIR/bats $TEST_DIR/dyn-bats" + "$TEST_DIR/lib/bats-core/bin/bats" \ --jobs 1 \ --timing \ --print-output-on-failure \ - "${TEST_DIR}/bats" "${TEST_DIR}/dyn-bats" + "$TEST_DIR/bats" "$TEST_DIR/dyn-bats" fi