diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3a03574..197ebe5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,6 +11,10 @@ jobs: name: lint runs-on: ubuntu-latest steps: + - name: setup go + uses: actions/setup-go@v3 + with: + go-version: 1.19 - uses: actions/checkout@v3 - name: golangci-lint uses: golangci/golangci-lint-action@v3 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b291078..0a2dd1e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -5,7 +5,7 @@ jobs: test: strategy: matrix: - go-version: [1.17, 1.18, 1.19] + go-version: [1.18, 1.19] os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: diff --git a/Dockerfile b/Dockerfile index b35054e..edbfda7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.18-alpine AS build +FROM golang:1.19-alpine AS build COPY . /go/src/jacobbednarz/go-csp-collector WORKDIR /go/src/jacobbednarz/go-csp-collector RUN set -ex \ diff --git a/README.md b/README.md index 95900e9..a160468 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,10 @@ $ CGO_ENABLED=0 go build csp_collector.go |port |Port to run on, default 8080| |filter-file|Reads the blocked URI filter list from the specified file. Note one filter per line| |health-check-path|Sets path for health checkers to use, default \/_healthcheck| +|log-client-ip|Include a field in the log with the IP delivering the report, or the value of the `X-Forwarded-For` header, if present.| +|log-truncated-client-ip|Include a field in the log with the truncated IP (to /24 for IPv4, /64 for IPv6) delivering the report, or the value of the `X-Forwarded-For` header, if present. Conflicts with `log-client-ip`. 
+|truncate-query-fragment|Remove all query strings and fragments (if set) from all URLs transmitted by the client| +|query-params-metadata|Log all query parameters of the report URL as a map in the `metadata` field| See the sample.filterlist.txt file as an example of the filter list in a file @@ -54,6 +58,11 @@ logged report. For example a report sent to `https://collector.example.com/?metadata=foobar` will include field `metadata` with value `foobar`. +If `query-params-metadata` is set, instead all query parameters are logged as a +map, e.g. `https://collector.example.com/?env=production&mode=enforce` will +result in `"metadata": {"env": "production", "mode": "enforce"}` in JSON +format, and `metadata="map[env:production mode:enforce]"` in default format. + ### Output formats The output format can be controlled by passing `--output-format ` diff --git a/csp_collector.go b/csp_collector.go index a9f4a25..79cb270 100644 --- a/csp_collector.go +++ b/csp_collector.go @@ -4,8 +4,8 @@ import ( "encoding/json" "flag" "fmt" - "io/ioutil" "net/http" + "net/netip" "os" "strconv" "strings" @@ -32,20 +32,16 @@ type CSPReportBody struct { StatusCode interface{} `json:"status-code"` } +const ( + // Default health check url. + defaultHealthCheckPath = "/_healthcheck" +) + var ( // Rev is set at build time and holds the revision that the package // was created at. Rev = "dev" - // Flag for toggling verbose output. - debugFlag bool - - // Flag for toggling output format. - outputFormat string - - // Flag for health check url. - healthCheckPath = "/_healthcheck" - // Shared defaults for the logger output. This ensures that we are // using the same keys for the `FieldKey` values across both formatters. logFieldMapDefaults = log.FieldMap{ @@ -54,11 +50,8 @@ var ( log.FieldKeyMsg: "message", } - // Path to file which has blocked URI's per line. - blockedURIfile string - // Default URI Filter list. 
- ignoredBlockedURIs = []string{ + defaultIgnoredBlockedURIs = []string{ "resource://", "chromenull://", "chrome-extension://", @@ -84,9 +77,6 @@ var ( "nativebaiduhd://adblock", "bdvideo://error", } - - // TCP Port to listen on. - listenPort int ) func init() { @@ -113,11 +103,17 @@ func trimEmptyAndComments(s []string) []string { func main() { version := flag.Bool("version", false, "Display the version") - flag.BoolVar(&debugFlag, "debug", false, "Output additional logging for debugging") - flag.StringVar(&outputFormat, "output-format", "text", "Define how the violation reports are formatted for output.\nDefaults to 'text'. Valid options are 'text' or 'json'") - flag.StringVar(&blockedURIfile, "filter-file", "", "Blocked URI Filter file") - flag.IntVar(&listenPort, "port", 8080, "Port to listen on") - flag.StringVar(&healthCheckPath, "health-check-path", healthCheckPath, "Health checker path") + debugFlag := flag.Bool("debug", false, "Output additional logging for debugging") + outputFormat := flag.String("output-format", "text", "Define how the violation reports are formatted for output.\nDefaults to 'text'. 
Valid options are 'text' or 'json'") + blockedURIFile := flag.String("filter-file", "", "Blocked URI Filter file") + listenPort := flag.Int("port", 8080, "Port to listen on") + healthCheckPath := flag.String("health-check-path", defaultHealthCheckPath, "Health checker path") + truncateQueryStringFragment := flag.Bool("truncate-query-fragment", false, "Truncate query string and fragment from document-uri, referrer and blocked-uri before logging (to reduce chances of accidentally logging sensitive data)") + + logClientIP := flag.Bool("log-client-ip", false, "Log the reporting client IP address") + logTruncatedClientIP := flag.Bool("log-truncated-client-ip", false, "Log the truncated client IP address (IPv4: /24, IPv6: /64") + + metadataObject := flag.Bool("query-params-metadata", false, "Write query parameters of the report URI as JSON object under metadata instead of the single metadata string") flag.Parse() @@ -126,19 +122,11 @@ func main() { os.Exit(0) } - if blockedURIfile != "" { - content, err := ioutil.ReadFile(blockedURIfile) - if err != nil { - fmt.Printf("Error reading Blocked File list: %s", blockedURIfile) - } - ignoredBlockedURIs = trimEmptyAndComments(strings.Split(string(content), "\n")) - } - - if debugFlag { + if *debugFlag { log.SetLevel(log.DebugLevel) } - if outputFormat == "json" { + if *outputFormat == "json" { log.SetFormatter(&log.JSONFormatter{ FieldMap: logFieldMapDefaults, }) @@ -153,29 +141,58 @@ func main() { } log.Debug("Starting up...") - if blockedURIfile != "" { - log.Debugf("Using Filter list from file at: %s\n", blockedURIfile) + ignoredBlockedURIs := defaultIgnoredBlockedURIs + if *blockedURIFile != "" { + log.Debugf("Using Filter list from file at: %s\n", *blockedURIFile) + + content, err := os.ReadFile(*blockedURIFile) + if err != nil { + log.Fatalf("Error reading Blocked File list: %s", *blockedURIFile) + } + ignoredBlockedURIs = trimEmptyAndComments(strings.Split(string(content), "\n")) } else { log.Debug("Using Filter list 
from internal list") } + log.Debugf("Blocked URI List: %s", ignoredBlockedURIs) - log.Debugf("Listening on TCP Port: %s", strconv.Itoa(listenPort)) + log.Debugf("Listening on TCP Port: %s", strconv.Itoa(*listenPort)) - http.HandleFunc("/", handleViolationReport) - log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", strconv.Itoa(listenPort)), nil)) -} + http.HandleFunc(*healthCheckPath, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } -func handleViolationReport(w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" && r.URL.Path == healthCheckPath { w.WriteHeader(http.StatusOK) - return - } + }) + + http.Handle("/", &violationReportHandler{ + blockedURIs: ignoredBlockedURIs, + truncateQueryStringFragment: *truncateQueryStringFragment, + + logClientIP: *logClientIP, + logTruncatedClientIP: *logTruncatedClientIP, + metadataObject: *metadataObject, + }) + log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", strconv.Itoa(*listenPort)), nil)) +} - if r.Method != "POST" { +type violationReportHandler struct { + truncateQueryStringFragment bool + blockedURIs []string + + logClientIP bool + logTruncatedClientIP bool + metadataObject bool +} + +func (vrh *violationReportHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) log.WithFields(log.Fields{ "http_method": r.Method, }).Debug("Received invalid HTTP method") + return } @@ -185,25 +202,36 @@ func handleViolationReport(w http.ResponseWriter, r *http.Request) { err := decoder.Decode(&report) if err != nil { w.WriteHeader(http.StatusUnprocessableEntity) - log.Debug(fmt.Sprintf("Unable to decode invalid JSON payload: %s", err)) + log.Debugf("Unable to decode invalid JSON payload: %s", err) return } defer r.Body.Close() - reportValidation := validateViolation(report) + reportValidation := vrh.validateViolation(report) if 
reportValidation != nil { http.Error(w, reportValidation.Error(), http.StatusBadRequest) - log.Debug(fmt.Sprintf("Received invalid payload: %s", reportValidation.Error())) + log.Debugf("Received invalid payload: %s", reportValidation.Error()) return } - metadatas, gotMetadata := r.URL.Query()["metadata"] - var metadata string - if gotMetadata { - metadata = metadatas[0] + var metadata interface{} + if vrh.metadataObject { + metadataMap := make(map[string]string) + query := r.URL.Query() + + for k, v := range query { + metadataMap[k] = v[0] + } + + metadata = metadataMap + } else { + metadatas, gotMetadata := r.URL.Query()["metadata"] + if gotMetadata { + metadata = metadatas[0] + } } - log.WithFields(log.Fields{ + lf := log.Fields{ "document_uri": report.Body.DocumentURI, "referrer": report.Body.Referrer, "blocked_uri": report.Body.BlockedURI, @@ -214,11 +242,36 @@ func handleViolationReport(w http.ResponseWriter, r *http.Request) { "script_sample": report.Body.ScriptSample, "status_code": report.Body.StatusCode, "metadata": metadata, - }).Info() + "path": r.URL.Path, + } + + if vrh.truncateQueryStringFragment { + lf["document_uri"] = truncateQueryStringFragment(report.Body.DocumentURI) + lf["referrer"] = truncateQueryStringFragment(report.Body.Referrer) + lf["blocked_uri"] = truncateQueryStringFragment(report.Body.BlockedURI) + } + + if vrh.logClientIP { + ip, err := getClientIP(r) + if err != nil { + log.Warnf("unable to parse client ip: %s", err) + } + lf["client_ip"] = ip.String() + } + + if vrh.logTruncatedClientIP { + ip, err := getClientIP(r) + if err != nil { + log.Warnf("unable to parse client ip: %s", err) + } + lf["client_ip"] = truncateClientIP(ip) + } + + log.WithFields(lf).Info() } -func validateViolation(r CSPReport) error { - for _, value := range ignoredBlockedURIs { +func (vrh *violationReportHandler) validateViolation(r CSPReport) error { + for _, value := range vrh.blockedURIs { if strings.HasPrefix(r.Body.BlockedURI, value) { err := 
fmt.Errorf("blocked URI ('%s') is an invalid resource", value) return err @@ -231,3 +284,45 @@ func validateViolation(r CSPReport) error { return nil } + +func truncateQueryStringFragment(uri string) string { + idx := strings.IndexAny(uri, "#?") + if idx != -1 { + return uri[:idx] + } + + return uri +} + +func truncateClientIP(addr netip.Addr) string { + // Ignoring the error is statically safe, as there are always enough bits. + if addr.Is4() { + p, _ := addr.Prefix(24) + return p.String() + } + + if addr.Is6() { + p, _ := addr.Prefix(64) + return p.String() + } + + return "unknown-address" +} + +func getClientIP(r *http.Request) (netip.Addr, error) { + if s := r.Header.Get("X-Forwarded-For"); s != "" { + addr, err := netip.ParseAddr(s) + if err != nil { + return netip.Addr{}, fmt.Errorf("unable to parse address from X-Forwarded-For=%s: %w", s, err) + } + + return addr, nil + } + + addrp, err := netip.ParseAddrPort(r.RemoteAddr) + if err != nil { + return netip.Addr{}, fmt.Errorf("unable to parse remote address %s: %w", r.RemoteAddr, err) + } + + return addrp.Addr(), nil +} diff --git a/csp_collector_test.go b/csp_collector_test.go index 81a420c..258d402 100644 --- a/csp_collector_test.go +++ b/csp_collector_test.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -13,6 +13,11 @@ import ( log "github.com/sirupsen/logrus" ) +var defaultViolationReportHandler = violationReportHandler{ + blockedURIs: defaultIgnoredBlockedURIs, + truncateQueryStringFragment: false, +} + func TestHandlerForDisallowedMethods(t *testing.T) { disallowedMethods := []string{"GET", "DELETE", "PUT", "TRACE", "PATCH"} randomUrls := []string{"/", "/blah"} @@ -25,7 +30,7 @@ func TestHandlerForDisallowedMethods(t *testing.T) { t.Fatalf("failed to create request: %v", err) } recorder := httptest.NewRecorder() - handleViolationReport(recorder, request) + defaultViolationReportHandler.ServeHTTP(recorder, request) response := 
recorder.Result() defer response.Body.Close() @@ -38,23 +43,6 @@ func TestHandlerForDisallowedMethods(t *testing.T) { } } -func TestHandlerForAllowingHealthcheck(t *testing.T) { - request, err := http.NewRequest("GET", "/_healthcheck", nil) - if err != nil { - t.Fatalf("failed to create request: %v", err) - } - recorder := httptest.NewRecorder() - - handleViolationReport(recorder, request) - - response := recorder.Result() - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - t.Errorf("expected HTTP status %v; got %v", http.StatusOK, response.StatusCode) - } -} - func TestHandlerWithMetadata(t *testing.T) { csp := CSPReport{ CSPReportBody{ @@ -80,7 +68,7 @@ func TestHandlerWithMetadata(t *testing.T) { } recorder := httptest.NewRecorder() - handleViolationReport(recorder, request) + defaultViolationReportHandler.ServeHTTP(recorder, request) response := recorder.Result() defer response.Body.Close() @@ -99,6 +87,42 @@ func TestHandlerWithMetadata(t *testing.T) { } } +func TestHandlerWithMetadataObject(t *testing.T) { + csp := CSPReport{ + CSPReportBody{ + DocumentURI: "http://example.com", + BlockedURI: "http://example.com", + }, + } + + payload, _ := json.Marshal(csp) + + var logBuffer bytes.Buffer + log.SetOutput(&logBuffer) + + request, err := http.NewRequest("POST", "/path?a=b&c=d", bytes.NewBuffer(payload)) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + recorder := httptest.NewRecorder() + + objectHandler := defaultViolationReportHandler + objectHandler.metadataObject = true + objectHandler.ServeHTTP(recorder, request) + + response := recorder.Result() + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + t.Errorf("expected HTTP status %v; got %v", http.StatusOK, response.StatusCode) + } + + log := logBuffer.String() + if !strings.Contains(log, "metadata=\"map[a:b c:d]\"") { + t.Fatalf("Logged result should contain metadata map '%s'", log) + } +} + func 
TestValidateViolationWithInvalidBlockedURIs(t *testing.T) { invalidBlockedURIs := []string{ "resource://", @@ -130,7 +154,7 @@ func TestValidateViolationWithInvalidBlockedURIs(t *testing.T) { testName := strings.Replace(blockedURI, "://", "", -1) t.Run(testName, func(t *testing.T) { - var rawReport = []byte(fmt.Sprintf(`{ + rawReport := []byte(fmt.Sprintf(`{ "csp-report": { "document-uri": "https://example.com", "blocked-uri": "%s" @@ -143,7 +167,7 @@ func TestValidateViolationWithInvalidBlockedURIs(t *testing.T) { fmt.Println("error:", jsonErr) } - validateErr := validateViolation(report) + validateErr := defaultViolationReportHandler.validateViolation(report) if validateErr == nil { t.Errorf("expected error to be raised but it didn't") } @@ -156,7 +180,7 @@ func TestValidateViolationWithInvalidBlockedURIs(t *testing.T) { } func TestValidateViolationWithValidBlockedURIs(t *testing.T) { - var rawReport = []byte(`{ + rawReport := []byte(`{ "csp-report": { "document-uri": "https://example.com", "blocked-uri": "https://google.com/example.css" @@ -169,20 +193,21 @@ func TestValidateViolationWithValidBlockedURIs(t *testing.T) { fmt.Println("error:", jsonErr) } - validateErr := validateViolation(report) + validateErr := defaultViolationReportHandler.validateViolation(report) if validateErr != nil { t.Errorf("expected error not be raised") } } func TestValidateNonHttpDocumentURI(t *testing.T) { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) report := CSPReport{Body: CSPReportBody{ BlockedURI: "http://example.com/", DocumentURI: "about", }} - validateErr := validateViolation(report) + + validateErr := defaultViolationReportHandler.validateViolation(report) if validateErr.Error() != "document URI ('about') is invalid" { t.Errorf("expected error to include correct message string but it didn't") } @@ -190,7 +215,7 @@ func TestValidateNonHttpDocumentURI(t *testing.T) { func TestHandleViolationReportMultipleTypeStatusCode(t *testing.T) { // Discard the output we 
create from the calls here. - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) statusCodeValues := []interface{}{"200", 200} @@ -214,7 +239,7 @@ func TestHandleViolationReportMultipleTypeStatusCode(t *testing.T) { } recorder := httptest.NewRecorder() - handleViolationReport(recorder, request) + defaultViolationReportHandler.ServeHTTP(recorder, request) response := recorder.Result() defer response.Body.Close() @@ -228,7 +253,7 @@ func TestHandleViolationReportMultipleTypeStatusCode(t *testing.T) { func TestFilterListProcessing(t *testing.T) { // Discard the output we create from the calls here. - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) blockList := []string{ "resource://", @@ -250,3 +275,65 @@ func TestFilterListProcessing(t *testing.T) { t.Errorf("unexpected list entry; got %v", trimmed[1]) } } + +func TestLogsPath(t *testing.T) { + var logBuffer bytes.Buffer + log.SetOutput(&logBuffer) + + csp := CSPReport{ + CSPReportBody{ + DocumentURI: "http://example.com", + BlockedURI: "http://example.com", + }, + } + + payload, _ := json.Marshal(csp) + + url := "/deep/link" + + request, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + recorder := httptest.NewRecorder() + + defaultViolationReportHandler.ServeHTTP(recorder, request) + + response := recorder.Result() + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + t.Errorf("expected HTTP status %v; got %v", http.StatusOK, response.StatusCode) + } + + log := logBuffer.String() + if !strings.Contains(log, "path=/deep/link") { + t.Fatalf("Logged result should contain path value in '%s'", log) + } +} + +func TestTruncateQueryStringFragment(t *testing.T) { + t.Parallel() + + cases := []struct { + original string + expected string + }{ + {"http://localhost.com/?test#anchor", "http://localhost.com/"}, + {"http://example.invalid", "http://example.invalid"}, + {"http://example.invalid#a", 
"http://example.invalid"}, + {"http://example.invalid?a", "http://example.invalid"}, + {"http://example.invalid#b?a", "http://example.invalid"}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.original, func(t *testing.T) { + t.Parallel() + actual := truncateQueryStringFragment(tc.original) + if actual != tc.expected { + t.Errorf("truncating '%s' yielded '%s', expected '%s'", tc.original, actual, tc.expected) + } + }) + } +}