Skip to content

Commit

Permalink
chore: Improved testing with more coverage (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsumners-nr authored May 6, 2024
1 parent 718d51f commit 8e8f9f7
Show file tree
Hide file tree
Showing 8 changed files with 429 additions and 38 deletions.
2 changes: 2 additions & 0 deletions flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ type StringEnumValue struct {
value string
}

// NewStringEnumValue creates a new [StringEnumValue] with a defined set of
// allowed values and a sets the initial value to a default value (`def`).
func NewStringEnumValue(allowed []string, def string) *StringEnumValue {
return &StringEnumValue{
allowed: allowed,
Expand Down
35 changes: 35 additions & 0 deletions flags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package main

import (
"github.com/stretchr/testify/assert"
"testing"
)

func Test_StringEnumValue(t *testing.T) {
t.Run("only allows specified values", func(t *testing.T) {
sev := NewStringEnumValue([]string{"foo", "bar"}, "foo")
expected := &StringEnumValue{
allowed: []string{"foo", "bar"},
value: "foo",
}
assert.Equal(t, expected, sev)

err := sev.Set("baz")
assert.ErrorContains(t, err, "baz is not an allowed value")

err = sev.Set("bar")
assert.Nil(t, err)
expected.value = "bar"
assert.Equal(t, expected, sev)
})

t.Run("supports interface methods", func(t *testing.T) {
sev := NewStringEnumValue([]string{"foo", "bar"}, "foo")
assert.Equal(t, "string", sev.Type())
assert.Equal(t, "foo", sev.String())

err := sev.Set("bar")
assert.Nil(t, err)
assert.Equal(t, "bar", sev.String())
})
}
20 changes: 18 additions & 2 deletions npm.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

type NpmClient struct {
baseUrl string
http *http.Client
}

type NpmPackage struct {
Expand All @@ -32,6 +33,7 @@ type NpmClientOption func(*NpmClient)
func NewNpmClient(options ...NpmClientOption) *NpmClient {
client := &NpmClient{
baseUrl: "https://registry.npmjs.com",
http: http.DefaultClient,
}

for _, opt := range options {
Expand All @@ -50,6 +52,12 @@ func WithBaseUrl(url string) NpmClientOption {
}
}

func WithHttpClient(c *http.Client) NpmClientOption {
return func(client *NpmClient) {
client.http = c
}
}

// GetDetailedInfo gets the full detailed information about a package from the
// NPM registry.
func (nc *NpmClient) GetDetailedInfo(packageName string) (*NpmDetailedPackage, error) {
Expand All @@ -62,12 +70,16 @@ func (nc *NpmClient) GetDetailedInfo(packageName string) (*NpmDetailedPackage, e
return nil, err
}

res, err := http.DefaultClient.Do(req)
res, err := nc.http.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()

if res.StatusCode != 200 {
return nil, fmt.Errorf("expected response code 200 but got %d: %s", res.StatusCode, res.Status)
}

var body NpmDetailedPackage
err = json.NewDecoder(res.Body).Decode(&body)
if err != nil {
Expand All @@ -88,12 +100,16 @@ func (nc *NpmClient) GetLatest(packageName string) (string, error) {
return "", err
}

res, err := http.DefaultClient.Do(req)
res, err := nc.http.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()

if res.StatusCode != 200 {
return "", fmt.Errorf("expected response code 200 but got %d: %s", res.StatusCode, res.Status)
}

var body NpmPackage
err = json.NewDecoder(res.Body).Decode(&body)
if err != nil {
Expand Down
145 changes: 145 additions & 0 deletions npm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package main

import (
"errors"
"github.com/jsumners/go-rfc3339"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
)

type RequestErrorTripper struct{}

func (ret *RequestErrorTripper) RoundTrip(*http.Request) (*http.Response, error) {
return nil, errors.New("bad request")
}

func Test_WithBaseUrl(t *testing.T) {
npm := NewNpmClient(WithBaseUrl("http://127.0.0.1/"))
assert.Equal(t, "http://127.0.0.1", npm.baseUrl)
}

func Test_GetDetailedInfo(t *testing.T) {
t.Run("returns error for bad request construction", func(t *testing.T) {
npm := NewNpmClient(WithBaseUrl("http://127.0.0.1"))
result, err := npm.GetDetailedInfo("foo#%0x24")
assert.Empty(t, result)
assert.ErrorContains(t, err, "invalid URL escape")
})

t.Run("returns error for server error", func(t *testing.T) {
client := &http.Client{
Transport: &RequestErrorTripper{},
}
npm := NewNpmClient(WithBaseUrl("http://127.0.0.1"), WithHttpClient(client))

result, err := npm.GetDetailedInfo("anything")
assert.Empty(t, result)
assert.ErrorContains(t, err, "bad request")
})

t.Run("returns error for bad payload", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
io.WriteString(res, `{"foo":"bar"`)
}))
npm := NewNpmClient(WithBaseUrl(ts.URL))

result, err := npm.GetDetailedInfo("anything")
assert.Empty(t, result)
assert.ErrorContains(t, err, "unexpected EOF")
})

t.Run("handles error codes", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
res.WriteHeader(500)
io.WriteString(res, `failed`)
}))
npm := NewNpmClient(WithBaseUrl(ts.URL))

result, err := npm.GetDetailedInfo("anything")
assert.Empty(t, result)
assert.ErrorContains(t, err, "expected response code 200 but got 500: 500")
})

t.Run("returns a success response", func(t *testing.T) {
payload := `{
"versions": {
"1.0.0": {}
},
"time": {
"1.0.0": "2024-05-03T13:00:00.000-04:00"
}
}`
ts := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
io.WriteString(res, payload)
}))
npm := NewNpmClient(WithBaseUrl(ts.URL))

dt, _ := rfc3339.NewDateTimeFromString("2024-05-03T13:00:00.000-04:00")
expected := &NpmDetailedPackage{
Versions: map[string]any{"1.0.0": make(map[string]any)},
Time: map[string]rfc3339.DateTime{
"1.0.0": dt,
},
}
result, err := npm.GetDetailedInfo("anything")
assert.Nil(t, err)
assert.Equal(t, expected, result)
})
}

func Test_GetLatest(t *testing.T) {
t.Run("returns error for bad request construction", func(t *testing.T) {
npm := NewNpmClient(WithBaseUrl("http://127.0.0.1"))
result, err := npm.GetLatest("foo#%0x24")
assert.Empty(t, result)
assert.ErrorContains(t, err, "invalid URL escape")
})

t.Run("returns error for server error", func(t *testing.T) {
client := &http.Client{
Transport: &RequestErrorTripper{},
}
npm := NewNpmClient(WithBaseUrl("http://127.0.0.1"), WithHttpClient(client))

result, err := npm.GetLatest("anything")
assert.Empty(t, result)
assert.ErrorContains(t, err, "bad request")
})

t.Run("returns error for bad payload", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
io.WriteString(res, `{"foo":"bar"`)
}))
npm := NewNpmClient(WithBaseUrl(ts.URL))

result, err := npm.GetLatest("anything")
assert.Empty(t, result)
assert.ErrorContains(t, err, "unexpected EOF")
})

t.Run("handles error codes", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
res.WriteHeader(500)
io.WriteString(res, `failed`)
}))
npm := NewNpmClient(WithBaseUrl(ts.URL))

result, err := npm.GetLatest("anything")
assert.Empty(t, result)
assert.ErrorContains(t, err, "expected response code 200 but got 500: 500")
})

t.Run("returns a success response", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
io.WriteString(res, `{"version":"1.0.0"}`)
}))
npm := NewNpmClient(WithBaseUrl(ts.URL))

result, err := npm.GetLatest("anything")
assert.Nil(t, err)
assert.Equal(t, "1.0.0", result)
})
}
66 changes: 44 additions & 22 deletions parse-package.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ type PkgInfo struct {
MinAgentVersion string
}

// parsePackage parses a versioned test `package.json` into the components
// required for the target's inclusion in the compatibility list. Which is
// to say, it pulls out the target module name, the minimum supported version
// of that module, and the minimum version of the agent that supports the
// module.
func parsePackage(pkg *VersionedTestPackageJson) ([]PkgInfo, error) {
var lastVersion *semver.Range
targets := pkg.Targets
Expand All @@ -41,37 +46,20 @@ func parsePackage(pkg *VersionedTestPackageJson) ([]PkgInfo, error) {
continue
}

var currentVersion semver.Range

// The semver library does not parse strings like `>1.0.0 <2.0.0 || >3.0.0`.
// So we need to split it up and normalize the pieces into range strings
// it can understand.
var currentVersion semver.Range
rangeStrings := strings.Split(val.Versions, "||")
for k, v := range rangeStrings {
// Oh, Go, why no slices.Map?
rangeStrings[k] = normalizeRangeString(v)
}

if len(rangeStrings) == 1 {
r, err := semver.NewRange([]byte(rangeStrings[0]))
if err != nil {
return nil, fmt.Errorf("failed to parse version string `%s` (from `%s`) for `%s`: %w", rangeStrings[0], val.Versions, targets, err)
}
currentVersion = r
} else {
ranges := make([]semver.Range, 0)
for _, rangeString := range rangeStrings {
r, err := semver.NewRange([]byte(rangeString))
if err != nil {
return nil, fmt.Errorf("failed to parse version string `%s` (from `%s`) for `%s`: %w", rangeString, val.Versions, targets, err)
}
ranges = append(ranges, r)
}
currentVersion = ranges[0]
for _, r := range ranges[1:] {
if isRangeLower(r, currentVersion) == true {
currentVersion = r
}
}
currentVersion, err := processRangeStrings(rangeStrings)
if err != nil {
return nil, fmt.Errorf("`%s` => `%s`: %w", target, val.Versions, err)
}

if lastVersion == nil {
Expand Down Expand Up @@ -107,6 +95,40 @@ func parsePackage(pkg *VersionedTestPackageJson) ([]PkgInfo, error) {
return results, nil
}

// processRangeStrings iterates a slice of semver range strings and returns
// the range with the lowest minimum version. The provided range strings
// should be normalized.
func processRangeStrings(rangeStrings []string) (semver.Range, error) {
var result semver.Range

if len(rangeStrings) == 1 {
r, err := semver.NewRange([]byte(rangeStrings[0]))
if err != nil {
return result, fmt.Errorf("failed to parse version string `%s`: %w", rangeStrings[0], err)
}
result = r
return result, nil
}

ranges := make([]semver.Range, 0)
for _, rangeString := range rangeStrings {
r, err := semver.NewRange([]byte(rangeString))
if err != nil {
return result, fmt.Errorf("failed to parse version string `%s`: %w", rangeString, err)
}
ranges = append(ranges, r)
}

result = ranges[0]
for _, r := range ranges[1:] {
if isRangeLower(r, result) == true {
result = r
}
}

return result, nil
}

// normalizeRangeString massages range strings into a format that the
// semver library recognizes as a valid range string.
func normalizeRangeString(input string) string {
Expand Down
Loading

0 comments on commit 8e8f9f7

Please sign in to comment.