Skip to content

Commit

Permalink
internal/frontend, internal/vuln: replace getVulnEntries with vuln.Cl…
Browse files Browse the repository at this point in the history
…ient

Instead of passing around a function, getVulnEntries, pass the actual
vuln client and call it directly.

Update the TestClient to implement the GetByModules function so that
tests can use it.

The purpose of this change is to further isolate calls to the vulndb
Client to the internal/vuln package, and to make the code easier to
understand by removing a function parameter.

For golang/go#58928

Change-Id: I8bef528034a1caa44b99da2f185990338ec9cd5f
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/474537
Reviewed-by: Jamal Carvalho <jamal@golang.org>
Run-TryBot: Tatiana Bradley <tatianabradley@google.com>
TryBot-Result: kokoro <noreply+kokoro@google.com>
  • Loading branch information
tatianab committed Mar 9, 2023
1 parent 0817681 commit 8e09d06
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 73 deletions.
20 changes: 8 additions & 12 deletions internal/frontend/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,7 @@ func determineSearchAction(r *http.Request, ds internal.DataSource, vulnClient *
if len(filters) > 0 {
symbol = filters[0]
}
var getVulnEntries vuln.VulnEntriesFunc
if vulnClient != nil {
getVulnEntries = vulnClient.ByModule
}
page, err := fetchSearchPage(ctx, db, cq, symbol, pageParams, mode == searchModeSymbol, getVulnEntries)
page, err := fetchSearchPage(ctx, db, cq, symbol, pageParams, mode == searchModeSymbol, vulnClient)
if err != nil {
// Instead of returning a 500, return a 408, since symbol searches may
// timeout for very popular symbols.
Expand Down Expand Up @@ -236,7 +232,7 @@ type subResult struct {
// fetchSearchPage fetches data matching the search query from the database and
// returns a SearchPage.
func fetchSearchPage(ctx context.Context, db *postgres.DB, cq, symbol string,
pageParams paginationParams, searchSymbols bool, getVulnEntries vuln.VulnEntriesFunc) (*SearchPage, error) {
pageParams paginationParams, searchSymbols bool, vulnClient *vuln.Client) (*SearchPage, error) {
maxResultCount := maxSearchOffset + pageParams.limit

// Pageless search: always start from the beginning.
Expand All @@ -258,8 +254,8 @@ func fetchSearchPage(ctx context.Context, db *postgres.DB, cq, symbol string,
results = append(results, sr)
}

if getVulnEntries != nil {
addVulns(ctx, results, getVulnEntries)
if vulnClient != nil {
addVulns(ctx, results, vulnClient)
}

var numResults int
Expand Down Expand Up @@ -400,13 +396,13 @@ EntryLoop:
}, nil
}

func searchVulnAlias(ctx context.Context, mode, cq string, vulnClient *vuln.Client) (_ *searchAction, err error) {
func searchVulnAlias(ctx context.Context, mode, cq string, vc *vuln.Client) (_ *searchAction, err error) {
defer derrors.Wrap(&err, "searchVulnAlias(%q, %q)", mode, cq)

if mode != searchModeVuln || !isVulnAlias(cq) {
return nil, nil
}
aliasEntries, err := vulnClient.ByAlias(ctx, cq)
aliasEntries, err := vc.ByAlias(ctx, cq)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -607,7 +603,7 @@ func elapsedTime(date time.Time) string {

// addVulns adds vulnerability information to search results by consulting the
// vulnerability database.
func addVulns(ctx context.Context, rs []*SearchResult, getVulnEntries vuln.VulnEntriesFunc) {
func addVulns(ctx context.Context, rs []*SearchResult, vc *vuln.Client) {
// Get all vulns concurrently.
var wg sync.WaitGroup
// TODO(golang/go#48223): throttle concurrency?
Expand All @@ -616,7 +612,7 @@ func addVulns(ctx context.Context, rs []*SearchResult, getVulnEntries vuln.VulnE
wg.Add(1)
go func() {
defer wg.Done()
r.Vulns = vuln.VulnsForPackage(ctx, r.ModulePath, r.Version, r.PackagePath, getVulnEntries)
r.Vulns = vuln.VulnsForPackage(ctx, r.ModulePath, r.Version, r.PackagePath, vc)
}()
}
wg.Wait()
Expand Down
9 changes: 2 additions & 7 deletions internal/frontend/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,7 @@ func TestFetchSearchPage(t *testing.T) {
}},
}}

getVulnEntries = func(_ context.Context, modulePath string) ([]*osv.Entry, error) {
if modulePath == moduleFoo.ModulePath {
return vulnEntries, nil
}
return nil, nil
}
vc = vuln.NewTestClient(vulnEntries)
)

for _, m := range []*internal.Module{moduleFoo, moduleBar} {
Expand Down Expand Up @@ -392,7 +387,7 @@ func TestFetchSearchPage(t *testing.T) {
},
} {
t.Run(test.name, func(t *testing.T) {
got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false, getVulnEntries)
got, err := fetchSearchPage(ctx, testDB, test.query, "", paginationParams{limit: 20, page: 1}, false, vc)
if err != nil {
t.Fatalf("fetchSearchPage(db, %q): %v", test.query, err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/frontend/tabs.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ func init() {
// handler.
func fetchDetailsForUnit(ctx context.Context, r *http.Request, tab string, ds internal.DataSource, um *internal.UnitMeta,
requestedVersion string, bc internal.BuildContext,
getVulnEntries vuln.VulnEntriesFunc) (_ any, err error) {
vc *vuln.Client) (_ any, err error) {
defer derrors.Wrap(&err, "fetchDetailsForUnit(r, %q, ds, um=%q,%q,%q)", tab, um.Path, um.ModulePath, um.Version)
switch tab {
case tabMain:
_, expandReadme := r.URL.Query()["readme"]
return fetchMainDetails(ctx, ds, um, requestedVersion, expandReadme, bc)
case tabVersions:
return fetchVersionsDetails(ctx, ds, um, getVulnEntries)
return fetchVersionsDetails(ctx, ds, um, vc)
case tabImports:
return fetchImportsDetails(ctx, ds, um.Path, um.ModulePath, um.Version)
case tabImportedBy:
Expand Down
11 changes: 3 additions & 8 deletions internal/frontend/unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,7 @@ func (s *Server) serveUnitPage(ctx context.Context, w http.ResponseWriter, r *ht
// It's also okay to provide just one (e.g. GOOS=windows), which will select
// the first doc with that value, ignoring the other one.
bc := internal.BuildContext{GOOS: r.FormValue("GOOS"), GOARCH: r.FormValue("GOARCH")}
var getVulnEntries vuln.VulnEntriesFunc
if s.vulnClient != nil {
getVulnEntries = s.vulnClient.ByModule
}
d, err := fetchDetailsForUnit(ctx, r, tab, ds, um, info.requestedVersion, bc, getVulnEntries)
d, err := fetchDetailsForUnit(ctx, r, tab, ds, um, info.requestedVersion, bc, s.vulnClient)
if err != nil {
return err
}
Expand Down Expand Up @@ -240,9 +236,8 @@ func (s *Server) serveUnitPage(ctx context.Context, w http.ResponseWriter, r *ht
}

// Get vulnerability information.
if s.vulnClient != nil {
page.Vulns = vuln.VulnsForPackage(ctx, um.ModulePath, um.Version, um.Path, s.vulnClient.ByModule)
}
page.Vulns = vuln.VulnsForPackage(ctx, um.ModulePath, um.Version, um.Path, s.vulnClient)

s.servePage(ctx, w, tabSettings.TemplateName, page)
return nil
}
Expand Down
8 changes: 4 additions & 4 deletions internal/frontend/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type VersionSummary struct {
Vulns []vuln.Vuln
}

func fetchVersionsDetails(ctx context.Context, ds internal.DataSource, um *internal.UnitMeta, getVulnEntries vuln.VulnEntriesFunc) (*VersionsDetails, error) {
func fetchVersionsDetails(ctx context.Context, ds internal.DataSource, um *internal.UnitMeta, vc *vuln.Client) (*VersionsDetails, error) {
db, ok := ds.(*postgres.DB)
if !ok {
// The proxydatasource does not support the imported by page.
Expand Down Expand Up @@ -114,7 +114,7 @@ func fetchVersionsDetails(ctx context.Context, ds internal.DataSource, um *inter
}
return constructUnitURL(versionPath, mi.ModulePath, linkVersion(mi.ModulePath, mi.Version, mi.Version))
}
return buildVersionDetails(ctx, um.ModulePath, um.Path, versions, sh, linkify, getVulnEntries), nil
return buildVersionDetails(ctx, um.ModulePath, um.Path, versions, sh, linkify, vc), nil
}

// pathInVersion constructs the full import path of the package corresponding
Expand Down Expand Up @@ -146,7 +146,7 @@ func buildVersionDetails(ctx context.Context, currentModulePath, packagePath str
modInfos []*internal.ModuleInfo,
sh *internal.SymbolHistory,
linkify func(v *internal.ModuleInfo) string,
getVulnEntries vuln.VulnEntriesFunc,
vc *vuln.Client,
) *VersionsDetails {
// lists organizes versions by VersionListKey.
lists := make(map[VersionListKey]*VersionList)
Expand Down Expand Up @@ -201,7 +201,7 @@ func buildVersionDetails(ctx context.Context, currentModulePath, packagePath str
if mi.ModulePath == stdlib.ModulePath {
pkg = packagePath
}
vs.Vulns = vuln.VulnsForPackage(ctx, mi.ModulePath, mi.Version, pkg, getVulnEntries)
vs.Vulns = vuln.VulnsForPackage(ctx, mi.ModulePath, mi.Version, pkg, vc)
vl := lists[key]
if vl == nil {
seenLists = append(seenLists, key)
Expand Down
9 changes: 2 additions & 7 deletions internal/frontend/versions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,7 @@ func TestFetchPackageVersionsDetails(t *testing.T) {
},
}},
}
getVulnEntries := func(_ context.Context, m string) ([]*osv.Entry, error) {
if m == modulePath1 {
return []*osv.Entry{vulnEntry}, nil
}
return nil, nil
}
vc := vuln.NewTestClient([]*osv.Entry{vulnEntry})

for _, tc := range []struct {
name string
Expand Down Expand Up @@ -201,7 +196,7 @@ func TestFetchPackageVersionsDetails(t *testing.T) {
postgres.MustInsertModule(ctx, t, testDB, v)
}

got, err := fetchVersionsDetails(ctx, testDB, &tc.pkg.UnitMeta, getVulnEntries)
got, err := fetchVersionsDetails(ctx, testDB, &tc.pkg.UnitMeta, vc)
if err != nil {
t.Fatalf("fetchVersionsDetails(ctx, db, %q, %q): %v", tc.pkg.Path, tc.pkg.ModulePath, err)
}
Expand Down
19 changes: 12 additions & 7 deletions internal/vuln/test_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,38 @@ package vuln

import (
"context"
"errors"

vulnc "golang.org/x/vuln/client"
"golang.org/x/vuln/osv"
)

// NewTestClient creates an in-memory client for use in tests.
func NewTestClient(entries []*osv.Entry) *Client {
c := &vulndbTestClient{
entries: entries,
aliasToIDs: map[string][]string{},
entries: entries,
aliasToIDs: map[string][]string{},
modulesToEntries: map[string][]*osv.Entry{},
}
for _, e := range entries {
for _, a := range e.Aliases {
c.aliasToIDs[a] = append(c.aliasToIDs[a], e.ID)
}
for _, affected := range e.Affected {
c.modulesToEntries[affected.Package.Name] = append(c.modulesToEntries[affected.Package.Name], e)
}
}
return &Client{c: c}
}

type vulndbTestClient struct {
vulnc.Client
entries []*osv.Entry
aliasToIDs map[string][]string
entries []*osv.Entry
aliasToIDs map[string][]string
modulesToEntries map[string][]*osv.Entry
}

func (c *vulndbTestClient) GetByModule(context.Context, string) ([]*osv.Entry, error) {
return nil, errors.New("unimplemented")
func (c *vulndbTestClient) GetByModule(_ context.Context, module string) ([]*osv.Entry, error) {
return c.modulesToEntries[module], nil
}

func (c *vulndbTestClient) GetByID(_ context.Context, id string) (*osv.Entry, error) {
Expand Down
21 changes: 9 additions & 12 deletions internal/vuln/vulns.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,24 @@ type Vuln struct {
Details string
}

type VulnEntriesFunc func(context.Context, string) ([]*osv.Entry, error)

// VulnsForPackage obtains vulnerability information for the given package.
// If packagePath is empty, it returns all entries for the module at version.
// The getVulnEntries function should retrieve all entries for the given module path.
// It is passed to facilitate testing.
// If there is an error, VulnsForPackage returns a single Vuln that describes the error.
func VulnsForPackage(ctx context.Context, modulePath, version, packagePath string, getVulnEntries VulnEntriesFunc) []Vuln {
vs, err := vulnsForPackage(ctx, modulePath, version, packagePath, getVulnEntries)
func VulnsForPackage(ctx context.Context, modulePath, version, packagePath string, vc *Client) []Vuln {
if vc == nil {
return nil
}

vs, err := vulnsForPackage(ctx, modulePath, version, packagePath, vc)
if err != nil {
return []Vuln{{Details: fmt.Sprintf("could not get vulnerability data: %v", err)}}
}
return vs
}

func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string, getVulnEntries VulnEntriesFunc) (_ []Vuln, err error) {
defer derrors.Wrap(&err, "vulns(%q, %q, %q)", modulePath, vers, packagePath)
func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string, vc *Client) (_ []Vuln, err error) {
defer derrors.Wrap(&err, "vulnsForPackage(%q, %q, %q)", modulePath, vers, packagePath)

if getVulnEntries == nil {
return nil, nil
}
// Stdlib pages requested at master will map to a pseudo version that puts
// all vulns in range. We can't really tell you're at master so version.IsPseudo
// is the best we can do. The result is vulns won't be reported for a pseudoversion
Expand All @@ -68,7 +65,7 @@ func vulnsForPackage(ctx context.Context, modulePath, vers, packagePath string,
modulePath = vulnStdlibModulePath
}
// Get all the vulns for this module.
entries, err := getVulnEntries(ctx, modulePath)
entries, err := vc.ByModule(ctx, modulePath)
if err != nil {
return nil, err
}
Expand Down
16 changes: 2 additions & 14 deletions internal/vuln/vulns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package vuln

import (
"context"
"fmt"
"reflect"
"testing"

Expand Down Expand Up @@ -60,18 +59,7 @@ func TestVulnsForPackage(t *testing.T) {
}},
}

get := func(_ context.Context, modulePath string) ([]*osv.Entry, error) {
switch modulePath {
case "good.com":
return nil, nil
case "bad.com", "unfixable.com":
return []*osv.Entry{&e}, nil
case "stdlib":
return []*osv.Entry{&stdlib}, nil
default:
return nil, fmt.Errorf("unknown module %q", modulePath)
}
}
vc := NewTestClient([]*osv.Entry{&e, &stdlib})

testCases := []struct {
mod, pkg, version string
Expand Down Expand Up @@ -118,7 +106,7 @@ func TestVulnsForPackage(t *testing.T) {
},
}
for _, tc := range testCases {
got := VulnsForPackage(ctx, tc.mod, tc.version, tc.pkg, get)
got := VulnsForPackage(ctx, tc.mod, tc.version, tc.pkg, vc)
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Errorf("VulnsForPackage(%q, %q, %q) = %+v, mismatch (-want, +got):\n%s", tc.mod, tc.version, tc.pkg, tc.want, diff)
}
Expand Down

0 comments on commit 8e09d06

Please sign in to comment.