diff --git a/cmd/osv-scanner/main.go b/cmd/osv-scanner/main.go index 8fc7989a72..ea0569e361 100644 --- a/cmd/osv-scanner/main.go +++ b/cmd/osv-scanner/main.go @@ -6,6 +6,7 @@ import ( "io" "os" + "github.com/google/osv-scanner/pkg/osv" "github.com/google/osv-scanner/pkg/osvscanner" "github.com/google/osv-scanner/pkg/reporter" @@ -27,6 +28,8 @@ func run(args []string, stdout, stderr io.Writer) int { r.PrintText(fmt.Sprintf("osv-scanner version: %s\ncommit: %s\nbuilt at: %s\n", ctx.App.Version, commit, date)) } + osv.RequestUserAgent = "osv-scanner/" + version + app := &cli.App{ Name: "osv-scanner", Version: version, diff --git a/pkg/osv/osv.go b/pkg/osv/osv.go index 134d09428f..8e7e503f74 100644 --- a/pkg/osv/osv.go +++ b/pkg/osv/osv.go @@ -28,6 +28,8 @@ const ( maxConcurrentRequests = 25 ) +var RequestUserAgent = "" + // Package represents a package identifier for OSV. type Package struct { PURL string `json:"purl,omitempty"` @@ -146,7 +148,16 @@ func MakeRequestWithClient(request BatchedQuery, client *http.Client) (*BatchedR resp, err := makeRetryRequest(func() (*http.Response, error) { // We do not need a specific context //nolint:noctx - return client.Post(QueryEndpoint, "application/json", requestBuf) + req, err := http.NewRequest(http.MethodPost, QueryEndpoint, requestBuf) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + if RequestUserAgent != "" { + req.Header.Set("User-Agent", RequestUserAgent) + } + + return client.Do(req) }) if err != nil { return nil, err @@ -179,8 +190,17 @@ func Get(id string) (*models.Vulnerability, error) { // client. func GetWithClient(id string, client *http.Client) (*models.Vulnerability, error) { resp, err := makeRetryRequest(func() (*http.Response, error) { + // We do not need a specific context //nolint:noctx - return client.Get(GetEndpoint + "/" + id) + req, err := http.NewRequest(http.MethodGet, GetEndpoint+"/"+id, nil) + if err != nil { + return nil, err + } + if RequestUserAgent != "" { + req.Header.Set("User-Agent", RequestUserAgent) + } + + return client.Do(req) }) if err != nil { return nil, err diff --git a/pkg/osvscanner/osvscanner.go b/pkg/osvscanner/osvscanner.go index 5b667a57f1..586b258889 100644 --- a/pkg/osvscanner/osvscanner.go +++ b/pkg/osvscanner/osvscanner.go @@ -536,6 +536,10 @@ func DoScan(actions ScannerActions, r reporter.Reporter) (models.VulnerabilityRe return models.VulnerabilityResults{}, NoPackagesFoundErr } + if osv.RequestUserAgent == "" { + osv.RequestUserAgent = "osv-scanner-api" + } + resp, err := osv.MakeRequest(query) if err != nil { return models.VulnerabilityResults{}, fmt.Errorf("scan failed %w", err)