diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a7baee5..e9fdfe70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add bucket validation. [#114](https://github.com/xmidt-org/argus/pull/114) - Remove stored loggers. [#118](https://github.com/xmidt-org/argus/pull/118) - Drop use of admin token headers from client. [#118](https://github.com/xmidt-org/argus/pull/118) +- Refactor client code and add unit tests around item CRUD operations [#119](https://github.com/xmidt-org/argus/pull/119) ### Fixed - Fix behavior in which the owner of an existing item was overwritten in super user mode. [#116](https://github.com/xmidt-org/argus/pull/116) diff --git a/chrysom/client.go b/chrysom/client.go index 120d894f..7c6e83a8 100644 --- a/chrysom/client.go +++ b/chrysom/client.go @@ -23,12 +23,15 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "net/http" + "strings" "time" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" + "github.com/go-kit/kit/metrics" "github.com/go-kit/kit/metrics/provider" "github.com/xmidt-org/argus/model" "github.com/xmidt-org/argus/store" @@ -36,6 +39,39 @@ import ( "github.com/xmidt-org/themis/xlog" ) +const ( + storeAPIPath = "/api/v1/store" + errWrappedFmt = "%w: %s" + errStatusCodeFmt = "statusCode %v: %w" +) + +// Errors that can be returned by this package. Since some of these errors are returned wrapped, it +// is safest to use errors.Is() to check for them. +// Some internal errors might be unwrapped from output errors but unless these errors become exported, +// they are not part of the library API and may change in future versions. +var ( + ErrAddressEmpty = errors.New("argus address is required") + ErrBucketEmpty = errors.New("bucket name is required") + ErrItemIDEmpty = errors.New("item ID is required") + ErrItemIDMismatch = errors.New("item ID must match that in payload") + ErrItemDataEmpty = errors.New("data field in item is required") + ErrUndefinedMetricsProvider = errors.New("a metrics provider is required") + ErrUndefinedIntervalTicker = errors.New("interval ticker is nil. Can't listen for updates") + ErrAuthAcquirerFailure = errors.New("failed acquiring auth token") + + ErrFailedAuthentication = errors.New("failed to authentication with argus") + ErrBadRequest = errors.New("argus rejected the request as invalid") +) + +var ( + errNonSuccessResponse = errors.New("argus responded with a non-success status code") + errNewRequestFailure = errors.New("failed creating an HTTP request") + errDoRequestFailure = errors.New("http client failed while sending request") + errReadingBodyFailure = errors.New("failed while reading http response body") + errJSONUnmarshal = errors.New("failed unmarshaling JSON response payload") + errJSONMarshal = errors.New("failed marshaling item as JSON payload") +) + // PushResult is a simple type to indicate the result type for the // PushItem operation. type PushResult string @@ -47,14 +83,43 @@ const ( ) type ClientConfig struct { - HTTPClient *http.Client - Bucket string - PullInterval time.Duration - Address string - Auth Auth + // HTTPClient refers to the client that will be used to send + // HTTP requests. + // (Optional) http.DefaultClient is used if left empty. + HTTPClient *http.Client + + // Address is the Argus URL (i.e. https://example-argus.io:8090) + Address string + + // Auth provides the mechanism to add auth headers to outgoing + // requests + // (Optional) If not provided, no auth headers are added. + Auth Auth + + // MetricsProvider allows measures updated by the client to be collected. MetricsProvider provider.Provider - Logger log.Logger - Listener Listener + + Logger log.Logger + + // Listener is the component that consumes the latest list of owned items in a + // bucket. + Listener Listener + + // PullInterval is how often listeners should get updates. + PullInterval time.Duration + + // Bucket to be used in listener requests. + Bucket string + + // Owner to be used in listener requests. + // (Optional) If left empty, items without an owner will be watched. + Owner string +} + +type response struct { + Body []byte + ArgusErrorHeader string + Code int } type Auth struct { @@ -62,62 +127,82 @@ type Auth struct { Basic string } +type Items []model.Item + type Client struct { - client *http.Client - ticker *time.Ticker - auth acquire.Acquirer - metrics *measures - listener Listener - bucketName string - remoteStoreAddress string - logger log.Logger + client *http.Client + auth acquire.Acquirer + storeBaseURL string + logger log.Logger + observer *listenerConfig } -func initMetrics(p provider.Provider) *measures { - return &measures{ - pollCount: p.NewCounter(PollCounter), - } +type listenerConfig struct { + listener Listener + ticker *time.Ticker + pollCount metrics.Counter + bucket string + owner string } -func CreateClient(config ClientConfig) (*Client, error) { +func NewClient(config ClientConfig) (*Client, error) { err := validateConfig(&config) if err != nil { return nil, err } - auth, err := determineTokenAcquirer(config) + tokenAcquirer, err := buildTokenAcquirer(&config.Auth) if err != nil { return nil, err } + clientStore := &Client{ - client: config.HTTPClient, - ticker: time.NewTicker(config.PullInterval), - auth: auth, - metrics: initMetrics(config.MetricsProvider), - logger: config.Logger, - listener: config.Listener, - remoteStoreAddress: config.Address, - bucketName: config.Bucket, + client: config.HTTPClient, + auth: tokenAcquirer, + logger: config.Logger, + observer: createObserver(config.Logger, config), + storeBaseURL: config.Address + storeAPIPath, } - if config.PullInterval > 0 { - clientStore.ticker = time.NewTicker(config.PullInterval) + return clientStore, nil +} + +// translateNonSuccessStatusCode returns as specific error +// for known Argus status codes. +func translateNonSuccessStatusCode(code int) error { + switch code { + case http.StatusBadRequest: + return ErrBadRequest + case http.StatusUnauthorized, http.StatusForbidden: + return ErrFailedAuthentication + default: + return errNonSuccessResponse } +} - return clientStore, nil +func createObserver(logger log.Logger, config ClientConfig) *listenerConfig { + if config.Listener == nil { + return nil + } + return &listenerConfig{ + listener: config.Listener, + ticker: time.NewTicker(config.PullInterval), + pollCount: config.MetricsProvider.NewCounter(PollCounter), + bucket: config.Bucket, + owner: config.Owner, + } } func validateConfig(config *ClientConfig) error { if config.HTTPClient == nil { config.HTTPClient = http.DefaultClient } + if config.Address == "" { - return errors.New("address can't be empty") - } - if config.Bucket == "" { - config.Bucket = "testing" + return ErrAddressEmpty } + if config.MetricsProvider == nil { - return errors.New("a metrics provider is required") + return ErrUndefinedMetricsProvider } if config.PullInterval == 0 { @@ -130,144 +215,157 @@ func validateConfig(config *ClientConfig) error { return nil } -func determineTokenAcquirer(config ClientConfig) (acquire.Acquirer, error) { - defaultAcquirer := &acquire.DefaultAcquirer{} - if config.Auth.JWT.AuthURL != "" && config.Auth.JWT.Buffer != 0 && config.Auth.JWT.Timeout != 0 { - return acquire.NewRemoteBearerTokenAcquirer(config.Auth.JWT) - } +func isEmpty(options acquire.RemoteBearerTokenAcquirerOptions) bool { + return len(options.AuthURL) < 1 || options.Buffer == 0 || options.Timeout == 0 +} - if config.Auth.Basic != "" { - return acquire.NewFixedAuthAcquirer(config.Auth.Basic) +func buildTokenAcquirer(auth *Auth) (acquire.Acquirer, error) { + if !isEmpty(auth.JWT) { + return acquire.NewRemoteBearerTokenAcquirer(auth.JWT) + } else if len(auth.Basic) > 0 { + return acquire.NewFixedAuthAcquirer(auth.Basic) } - - return defaultAcquirer, nil + return &acquire.DefaultAcquirer{}, nil } -func (c *Client) GetItems(owner string) ([]model.Item, error) { - request, err := http.NewRequest("GET", fmt.Sprintf("%s/api/v1/store/%s", c.remoteStoreAddress, c.bucketName), nil) +func (c Client) sendRequest(owner, method, url string, body io.Reader) (response, error) { + r, err := http.NewRequest(method, url, body) if err != nil { - return nil, err + return response{}, fmt.Errorf(errWrappedFmt, errNewRequestFailure, err.Error()) } - err = acquire.AddAuth(request, c.auth) + err = acquire.AddAuth(r, c.auth) if err != nil { - return nil, err + return response{}, fmt.Errorf(errWrappedFmt, ErrAuthAcquirerFailure, err.Error()) } + if len(owner) > 0 { + r.Header.Set(store.ItemOwnerHeaderKey, owner) + } + resp, err := c.client.Do(r) + if err != nil { + return response{}, fmt.Errorf(errWrappedFmt, errDoRequestFailure, err.Error()) + } + defer resp.Body.Close() - request.Header.Set(store.ItemOwnerHeaderKey, owner) - - response, err := c.client.Do(request) + var sqResp = response{ + Code: resp.StatusCode, + ArgusErrorHeader: resp.Header.Get(store.XmidtErrorHeaderKey), + } + bodyBytes, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, err + return sqResp, fmt.Errorf(errWrappedFmt, errReadingBodyFailure, err.Error()) } - defer response.Body.Close() + sqResp.Body = bodyBytes + return sqResp, nil +} - if response.StatusCode != http.StatusOK { - level.Error(c.logger).Log(xlog.MessageKey(), "Argus responded with non-200 response for GetItems request", "code", response.StatusCode) - return nil, errors.New("failed to get items, non 200 statuscode") +// GetItems fetches all items in a bucket that belong to a given owner. +func (c *Client) GetItems(bucket, owner string) (Items, error) { + if len(bucket) < 1 { + return nil, ErrBucketEmpty } - items := []model.Item{} - err = json.NewDecoder(response.Body).Decode(&items) + response, err := c.sendRequest(owner, http.MethodGet, fmt.Sprintf("%s/%s", c.storeBaseURL, bucket), nil) if err != nil { return nil, err } - return items, nil -} - -func (c *Client) Push(item model.Item, owner string) (PushResult, error) { - if item.ID == "" { - return "", errors.New("id can't be empty") + if response.Code != http.StatusOK { + level.Error(c.logger).Log(xlog.MessageKey(), "Argus responded with non-200 response for GetItems request", + "code", response.Code, "ErrorHeader", response.ArgusErrorHeader) + return nil, fmt.Errorf(errStatusCodeFmt, response.Code, translateNonSuccessStatusCode(response.Code)) } - if item.TTL != nil && *item.TTL < 1 { - return "", errors.New("when provided, TTL must be > 0") - } + var items Items - data, err := json.Marshal(&item) + err = json.Unmarshal(response.Body, &items) if err != nil { - return "", err + return nil, fmt.Errorf("GetItems: %w: %s", errJSONUnmarshal, err.Error()) } - request, err := http.NewRequest("PUT", fmt.Sprintf("%s/api/v1/store/%s/%s", c.remoteStoreAddress, c.bucketName, item.ID), bytes.NewReader(data)) + + return items, nil +} + +// PushItem creates a new item if one doesn't already exist at +// the resource path '{BUCKET}/{ID}'. If an item exists and the ownership matches, +// the item is simply updated. +func (c *Client) PushItem(id, bucket, owner string, item model.Item) (PushResult, error) { + err := validatePushItemInput(bucket, owner, id, item) if err != nil { return "", err } - err = acquire.AddAuth(request, c.auth) + + data, err := json.Marshal(item) if err != nil { - return "", err + return "", fmt.Errorf(errWrappedFmt, errJSONMarshal, err.Error()) } - request.Header.Add(store.ItemOwnerHeaderKey, owner) - response, err := c.client.Do(request) + response, err := c.sendRequest(owner, http.MethodPut, fmt.Sprintf("%s/%s/%s", c.storeBaseURL, bucket, id), bytes.NewReader(data)) if err != nil { return "", err } - defer response.Body.Close() - - switch response.StatusCode { - case http.StatusCreated: + if response.Code == http.StatusCreated { return CreatedPushResult, nil - case http.StatusOK: + } + + if response.Code == http.StatusOK { return UpdatedPushResult, nil } - level.Error(c.logger).Log(xlog.MessageKey(), "Argus responded with a non-successful status code for a Push request", "code", response.StatusCode) - return "", errors.New("Failed to set item as DB responded with non-success statuscode") + + level.Error(c.logger).Log(xlog.MessageKey(), "Argus responded with a non-successful status code for a PushItem request", + "code", response.Code, "ErrorHeader", response.ArgusErrorHeader) + + return "", fmt.Errorf(errStatusCodeFmt, response.Code, translateNonSuccessStatusCode(response.Code)) } -func (c *Client) Remove(id string, owner string) (model.Item, error) { - if id == "" { - return model.Item{}, errors.New("id can't be empty") - } - request, err := http.NewRequest("DELETE", fmt.Sprintf("%s/api/v1/store/%s/%s", c.remoteStoreAddress, c.bucketName, id), nil) - if err != nil { - return model.Item{}, err - } - err = acquire.AddAuth(request, c.auth) +// RemoveItem removes the item if it exists and returns the data associated to it. +func (c *Client) RemoveItem(id, bucket, owner string) (model.Item, error) { + err := validateRemoveItemInput(bucket, id) if err != nil { return model.Item{}, err } - request.Header.Add(store.ItemOwnerHeaderKey, owner) - - response, err := c.client.Do(request) + resp, err := c.sendRequest(owner, http.MethodDelete, fmt.Sprintf("%s/%s/%s", c.storeBaseURL, bucket, id), nil) if err != nil { return model.Item{}, err } - if response.StatusCode != 200 { - return model.Item{}, errors.New("failed to delete item, non 200 statuscode") + + if resp.Code != http.StatusOK { + level.Error(c.logger).Log(xlog.MessageKey(), "Argus responded with a non-successful status code for a RemoveItem request", + "code", resp.Code, "ErrorHeader", resp.ArgusErrorHeader) + return model.Item{}, fmt.Errorf(errStatusCodeFmt, resp.Code, translateNonSuccessStatusCode(resp.Code)) } - defer response.Body.Close() - responsePayload, _ := ioutil.ReadAll(response.Body) - item := model.Item{} - err = json.Unmarshal(responsePayload, &item) + + var item model.Item + err = json.Unmarshal(resp.Body, &item) if err != nil { - return model.Item{}, err + return item, fmt.Errorf("RemoveItem: %w: %s", errJSONUnmarshal, err.Error()) } return item, nil } func (c *Client) Start(ctx context.Context) error { - if c.ticker == nil { - return errors.New("interval ticker is nil") + if c.observer == nil { + level.Warn(c.logger).Log(xlog.MessageKey(), "No listener was setup to receive updates.") + return nil } - if c.listener == nil { - level.Info(c.logger).Log(xlog.MessageKey(), "No listener was setup to receive updates.") - return nil + if c.observer.ticker == nil { + return ErrUndefinedIntervalTicker } go func() { - for range c.ticker.C { + observer := c.observer + for range observer.ticker.C { outcome := SuccessOutcome - items, err := c.GetItems("") + items, err := c.GetItems(observer.bucket, observer.owner) if err == nil { - c.listener.Update(items) + observer.listener.Update(items) } else { outcome = FailureOutcome level.Error(c.logger).Log(xlog.MessageKey(), "Failed to get items for listeners", xlog.ErrorKey(), err) } - c.metrics.pollCount.With(OutcomeLabel, outcome).Add(1) + observer.pollCount.With(OutcomeLabel, outcome).Add(1) } }() @@ -275,8 +373,39 @@ func (c *Client) Start(ctx context.Context) error { } func (c *Client) Stop(ctx context.Context) error { - if c.ticker != nil { - c.ticker.Stop() + if c.observer != nil && c.observer.ticker != nil { + c.observer.ticker.Stop() + } + return nil +} + +func validatePushItemInput(bucket, owner, id string, item model.Item) error { + if len(bucket) < 1 { + return ErrBucketEmpty + } + + if len(id) < 1 || len(item.ID) < 1 { + return ErrItemIDEmpty + } + + if !strings.EqualFold(id, item.ID) { + return ErrItemIDMismatch + } + + if len(item.Data) < 1 { + return ErrItemDataEmpty + } + + return nil +} + +func validateRemoveItemInput(bucket, id string) error { + if len(bucket) < 1 { + return ErrBucketEmpty + } + + if len(id) < 1 { + return ErrItemIDEmpty } return nil } diff --git a/chrysom/client_test.go b/chrysom/client_test.go index 03aa1622..5f399494 100644 --- a/chrysom/client_test.go +++ b/chrysom/client_test.go @@ -1,17 +1,30 @@ package chrysom import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" "testing" + "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/metrics/provider" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xmidt-org/argus/model" + "github.com/xmidt-org/argus/store" ) +const failingURL = "nowhere://" + func TestInterface(t *testing.T) { assert := assert.New(t) - var ( - client interface{} - ) - client = &Client{} + var client interface{} = &Client{} _, ok := client.(Pusher) assert.True(ok, "not a pusher") _, ok = client.(Reader) @@ -19,3 +32,644 @@ func TestInterface(t *testing.T) { _, ok = client.(PushReader) assert.True(ok, "not a PushReader") } + +func TestValidateConfig(t *testing.T) { + type testCase struct { + Description string + Input *ClientConfig + ExpectedErr error + ExpectedConfig *ClientConfig + } + + allDefaultsCaseConfig := &ClientConfig{ + HTTPClient: http.DefaultClient, + PullInterval: time.Second * 5, + Logger: log.NewNopLogger(), + Address: "http://awesome-argus-hostname.io", + MetricsProvider: provider.NewDiscardProvider(), + } + + myAmazingClient := &http.Client{Timeout: time.Hour} + allDefinedCaseConfig := &ClientConfig{ + HTTPClient: myAmazingClient, + PullInterval: time.Hour * 24, + Address: "http://legit-argus-hostname.io", + Auth: Auth{}, + MetricsProvider: provider.NewDiscardProvider(), + Logger: log.NewJSONLogger(ioutil.Discard), + } + + tcs := []testCase{ + { + Description: "All default values", + Input: &ClientConfig{ + Address: "http://awesome-argus-hostname.io", + MetricsProvider: provider.NewDiscardProvider(), + }, + ExpectedConfig: allDefaultsCaseConfig, + }, + + { + Description: "No metrics provider", + Input: &ClientConfig{ + Address: "http://awesome-argus-hostname.io", + }, + ExpectedErr: ErrUndefinedMetricsProvider, + }, + { + Description: "No address", + Input: &ClientConfig{ + MetricsProvider: provider.NewDiscardProvider(), + }, + ExpectedErr: ErrAddressEmpty, + }, + + { + Description: "All defined", + Input: &ClientConfig{ + MetricsProvider: provider.NewDiscardProvider(), + Address: "http://legit-argus-hostname.io", + HTTPClient: myAmazingClient, + PullInterval: time.Hour * 24, + Logger: log.NewJSONLogger(ioutil.Discard), + }, + ExpectedConfig: allDefinedCaseConfig, + }, + } + + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + assert := assert.New(t) + err := validateConfig(tc.Input) + assert.Equal(tc.ExpectedErr, err) + if tc.ExpectedErr == nil { + assert.Equal(tc.ExpectedConfig, tc.Input) + } + }) + } +} + +func TestSendRequest(t *testing.T) { + type testCase struct { + Description string + Owner string + Method string + URL string + Body []byte + AcquirerFails bool + ClientDoFails bool + ExpectedResponse response + ExpectedErr error + } + + tcs := []testCase{ + { + Description: "New Request fails", + Method: "what method?", + URL: "http://argus-hostname.io", + ExpectedErr: errNewRequestFailure, + }, + { + Description: "Auth acquirer fails", + Method: http.MethodGet, + URL: "http://argus-hostname.io", + AcquirerFails: true, + ExpectedErr: ErrAuthAcquirerFailure, + }, + { + Description: "Client Do fails", + Method: http.MethodPut, + ClientDoFails: true, + ExpectedErr: errDoRequestFailure, + }, + { + Description: "Happy path", + Method: http.MethodPut, + URL: "http://argus-hostname.io", + Body: []byte("testing"), + Owner: "HappyCaseOwner", + ExpectedResponse: response{ + Code: http.StatusOK, + Body: []byte("testing"), + }, + }, + } + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + echoHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + assert.Equal(tc.Owner, r.Header.Get(store.ItemOwnerHeaderKey)) + rw.WriteHeader(http.StatusOK) + bodyBytes, err := ioutil.ReadAll(r.Body) + require.Nil(err) + rw.Write(bodyBytes) + }) + + server := httptest.NewServer(echoHandler) + defer server.Close() + + client, err := NewClient(ClientConfig{ + HTTPClient: server.Client(), + Address: "http://argus-hostname.io", + MetricsProvider: provider.NewDiscardProvider(), + }) + + if tc.AcquirerFails { + client.auth = acquirerFunc(failAcquirer) + } + + var URL = server.URL + if tc.ClientDoFails { + URL = "http://should-definitely-fail.net" + } + + assert.Nil(err) + resp, err := client.sendRequest(tc.Owner, tc.Method, URL, bytes.NewBuffer(tc.Body)) + + if tc.ExpectedErr == nil { + assert.Equal(http.StatusOK, resp.Code) + assert.Equal(tc.ExpectedResponse, resp) + } else { + assert.True(errors.Is(err, tc.ExpectedErr)) + } + }) + } +} + +func TestGetItems(t *testing.T) { + type testCase struct { + Description string + ResponsePayload []byte + ResponseCode int + ShouldEraseBucket bool + ShouldMakeRequestFail bool + ShouldDoRequestFail bool + ExpectedErr error + ExpectedOutput Items + } + + tcs := []testCase{ + { + Description: "Bucket is required", + ShouldEraseBucket: true, + ExpectedErr: ErrBucketEmpty, + }, + { + + Description: "Make request fails", + ShouldMakeRequestFail: true, + ExpectedErr: ErrAuthAcquirerFailure, + }, + { + Description: "Do request fails", + ShouldDoRequestFail: true, + ExpectedErr: errDoRequestFailure, + }, + { + Description: "Unauthorized", + ResponseCode: http.StatusForbidden, + ExpectedErr: ErrFailedAuthentication, + }, + { + Description: "Bad request", + ResponseCode: http.StatusBadRequest, + ExpectedErr: ErrBadRequest, + }, + { + Description: "Other non-success", + ResponseCode: http.StatusInternalServerError, + ExpectedErr: errNonSuccessResponse, + }, + { + Description: "Payload unmarshal error", + ResponseCode: http.StatusOK, + ResponsePayload: []byte("[{}"), + ExpectedErr: errJSONUnmarshal, + }, + { + Description: "Happy path", + ResponseCode: http.StatusOK, + ResponsePayload: getItemsValidPayload(), + ExpectedOutput: getItemsHappyOutput(), + }, + } + + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + var ( + assert = assert.New(t) + require = require.New(t) + bucket = "bucket-name" + owner = "owner-name" + ) + + if tc.ShouldEraseBucket { + bucket = "" + } + + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + assert.Equal(http.MethodGet, r.Method) + assert.Equal(owner, r.Header.Get(store.ItemOwnerHeaderKey)) + assert.Equal(fmt.Sprintf("%s/%s", storeAPIPath, bucket), r.URL.Path) + + rw.WriteHeader(tc.ResponseCode) + rw.Write(tc.ResponsePayload) + })) + + client, err := NewClient(ClientConfig{ + HTTPClient: server.Client(), + Address: server.URL, + MetricsProvider: provider.NewDiscardProvider(), + }) + + require.Nil(err) + + if tc.ShouldMakeRequestFail { + client.auth = acquirerFunc(failAcquirer) + } + + if tc.ShouldDoRequestFail { + client.storeBaseURL = failingURL + } + + output, err := client.GetItems(bucket, "owner-name") + + assert.True(errors.Is(err, tc.ExpectedErr)) + if tc.ExpectedErr == nil { + assert.EqualValues(tc.ExpectedOutput, output) + } + }) + } +} + +func TestPushItem(t *testing.T) { + type testCase struct { + Description string + Item model.Item + Owner string + ResponseCode int + ShouldEraseID bool + ShouldEraseBucket bool + ShouldRespNonSuccess bool + ShouldMakeRequestFail bool + ShouldDoRequestFail bool + ExpectedErr error + ExpectedOutput PushResult + } + + validItem := model.Item{ + ID: "252f10c83610ebca1a059c0bae8255eba2f95be4d1d7bcfa89d7248a82d9f111", + Data: map[string]interface{}{ + "field0": float64(0), + "nested": map[string]interface{}{ + "response": "wow", + }, + }} + + tcs := []testCase{ + { + Description: "Bucket is required", + Item: validItem, + ShouldEraseBucket: true, + ExpectedErr: ErrBucketEmpty, + }, + { + Description: "Item ID Missing", + Item: validItem, + ShouldEraseID: true, + ExpectedErr: ErrItemIDEmpty, + }, + { + Description: "Item ID Missing from payload", + Item: model.Item{Data: validItem.Data}, + ExpectedErr: ErrItemIDEmpty, + }, + { + Description: "Item ID Mismatch", + Item: model.Item{ID: "752f10c83610ebca1a059c0bae8255eba2f95be4d1d7bcfa89d7248a82d9f119", Data: validItem.Data}, + ExpectedErr: ErrItemIDMismatch, + }, + { + Description: "Item Data missing", + Item: model.Item{ID: validItem.ID}, + ExpectedErr: ErrItemDataEmpty, + }, + { + Description: "Make request fails", + Item: validItem, + ShouldMakeRequestFail: true, + ExpectedErr: ErrAuthAcquirerFailure, + }, + { + Description: "Do request fails", + Item: validItem, + ShouldDoRequestFail: true, + ExpectedErr: errDoRequestFailure, + }, + { + Description: "Unauthorized", + Item: validItem, + ResponseCode: http.StatusForbidden, + ExpectedErr: ErrFailedAuthentication, + }, + { + Description: "Bad request", + Item: validItem, + ResponseCode: http.StatusBadRequest, + ExpectedErr: ErrBadRequest, + }, + { + Description: "Other non-success", + Item: validItem, + ResponseCode: http.StatusInternalServerError, + ExpectedErr: errNonSuccessResponse, + }, + { + Description: "Create success", + Item: validItem, + ResponseCode: http.StatusCreated, + ExpectedOutput: CreatedPushResult, + }, + { + Description: "Update success", + Item: validItem, + ResponseCode: http.StatusOK, + ExpectedOutput: UpdatedPushResult, + }, + + { + Description: "Update success with owner", + Item: validItem, + ResponseCode: http.StatusOK, + Owner: "owner-name", + ExpectedOutput: UpdatedPushResult, + }, + } + + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + var ( + assert = assert.New(t) + require = require.New(t) + bucket = "bucket-name" + id = "252f10c83610ebca1a059c0bae8255eba2f95be4d1d7bcfa89d7248a82d9f111" + ) + + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + assert.Equal(fmt.Sprintf("%s/%s/%s", storeAPIPath, bucket, id), r.URL.Path) + assert.Equal(tc.Owner, r.Header.Get(store.ItemOwnerHeaderKey)) + rw.WriteHeader(tc.ResponseCode) + + if tc.ResponseCode == http.StatusCreated || tc.ResponseCode == http.StatusOK { + payload, err := ioutil.ReadAll(r.Body) + require.Nil(err) + var item model.Item + err = json.Unmarshal(payload, &item) + require.Nil(err) + assert.EqualValues(tc.Item, item) + } + })) + + client, err := NewClient(ClientConfig{ + HTTPClient: server.Client(), + Address: server.URL, + MetricsProvider: provider.NewDiscardProvider(), + }) + + if tc.ShouldMakeRequestFail { + client.auth = acquirerFunc(failAcquirer) + } + + if tc.ShouldDoRequestFail { + client.storeBaseURL = failingURL + } + + if tc.ShouldEraseBucket { + bucket = "" + } + + if tc.ShouldEraseID { + id = "" + } + + require.Nil(err) + output, err := client.PushItem(id, bucket, tc.Owner, tc.Item) + + if tc.ExpectedErr == nil { + assert.EqualValues(tc.ExpectedOutput, output) + } else { + assert.True(errors.Is(err, tc.ExpectedErr)) + } + }) + } +} + +func TestRemoveItem(t *testing.T) { + type testCase struct { + Description string + ResponsePayload []byte + ResponseCode int + Owner string + ShouldEraseBucket bool + ShouldEraseID bool + ShouldRespNonSuccess bool + ShouldMakeRequestFail bool + ShouldDoRequestFail bool + ExpectedErr error + ExpectedOutput model.Item + } + + tcs := []testCase{ + { + Description: "Bucket is required", + ShouldEraseBucket: true, + ExpectedErr: ErrBucketEmpty, + }, + { + Description: "Item ID Missing", + ShouldEraseID: true, + ExpectedErr: ErrItemIDEmpty, + }, + { + Description: "Make request fails", + ShouldMakeRequestFail: true, + ExpectedErr: ErrAuthAcquirerFailure, + }, + { + Description: "Do request fails", + ShouldDoRequestFail: true, + ExpectedErr: errDoRequestFailure, + }, + { + Description: "Unauthorized", + ResponseCode: http.StatusForbidden, + ExpectedErr: ErrFailedAuthentication, + }, + { + Description: "Bad request", + ResponseCode: http.StatusBadRequest, + ExpectedErr: ErrBadRequest, + }, + { + Description: "Other non-success", + ResponseCode: http.StatusInternalServerError, + ExpectedErr: errNonSuccessResponse, + }, + { + Description: "Unmarshal failure", + ResponseCode: http.StatusOK, + ResponsePayload: []byte("{{}"), + ExpectedErr: errJSONUnmarshal, + }, + { + Description: "Succcess", + ResponseCode: http.StatusOK, + ResponsePayload: getRemoveItemValidPayload(), + ExpectedOutput: getRemoveItemHappyOutput(), + }, + } + + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + var ( + assert = assert.New(t) + require = require.New(t) + bucket = "bucket-name" + id = "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7" + ) + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + assert.Equal(fmt.Sprintf("%s/%s/%s", storeAPIPath, bucket, id), r.URL.Path) + assert.Equal(http.MethodDelete, r.Method) + rw.WriteHeader(tc.ResponseCode) + rw.Write(tc.ResponsePayload) + })) + + client, err := NewClient(ClientConfig{ + HTTPClient: server.Client(), + Address: server.URL, + MetricsProvider: provider.NewDiscardProvider(), + }) + + if tc.ShouldMakeRequestFail { + client.auth = acquirerFunc(failAcquirer) + } + + if tc.ShouldDoRequestFail { + client.storeBaseURL = failingURL + } + + if tc.ShouldEraseID { + id = "" + } + + if tc.ShouldEraseBucket { + bucket = "" + } + + require.Nil(err) + output, err := client.RemoveItem(id, bucket, tc.Owner) + + if tc.ExpectedErr == nil { + assert.EqualValues(tc.ExpectedOutput, output) + } else { + assert.True(errors.Is(err, tc.ExpectedErr)) + } + }) + } +} + +func TestTranslateStatusCode(t *testing.T) { + type testCase struct { + Description string + Code int + ExpectedErr error + } + + tcs := []testCase{ + { + Code: http.StatusForbidden, + ExpectedErr: ErrFailedAuthentication, + }, + { + Code: http.StatusUnauthorized, + ExpectedErr: ErrFailedAuthentication, + }, + { + Code: http.StatusBadRequest, + ExpectedErr: ErrBadRequest, + }, + { + Code: http.StatusInternalServerError, + ExpectedErr: errNonSuccessResponse, + }, + } + + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + assert := assert.New(t) + assert.Equal(tc.ExpectedErr, translateNonSuccessStatusCode(tc.Code)) + }) + } +} +func failAcquirer() (string, error) { + return "", errors.New("always fail") +} + +type acquirerFunc func() (string, error) + +func (a acquirerFunc) Acquire() (string, error) { + return a() +} + +func getItemsValidPayload() []byte { + return []byte(`[{ + "id": "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7", + "data": { + "words": [ + "Hello","World" + ], + "year": 2021 + }, + "ttl": 255 + }]`) +} + +func getItemsHappyOutput() Items { + return []model.Item{ + { + ID: "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7", + Data: map[string]interface{}{ + "words": []interface{}{"Hello", "World"}, + "year": float64(2021), + }, + TTL: aws.Int64(255), + }, + } +} + +func getRemoveItemValidPayload() []byte { + return []byte(` + { + "id": "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7", + "data": { + "words": [ + "Hello","World" + ], + "year": 2021 + }, + "ttl": 100 + }`) +} + +func getRemoveItemHappyOutput() model.Item { + return model.Item{ + ID: "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7", + Data: map[string]interface{}{ + "words": []interface{}{"Hello", "World"}, + "year": float64(2021), + }, + TTL: aws.Int64(100), + } +} diff --git a/chrysom/store.go b/chrysom/store.go index bb12043e..7fcb22ff 100644 --- a/chrysom/store.go +++ b/chrysom/store.go @@ -32,10 +32,10 @@ type PushReader interface { type Pusher interface { // Push applies user configurable for registering an item returning the id // i.e. updated the storage with said item. - Push(item model.Item, owner string) (PushResult, error) + PushItem(id, bucket, owner string, item model.Item) (PushResult, error) // Remove will remove the item from the store - Remove(id string, owner string) (model.Item, error) + RemoveItem(id, bucket string, owner string) (model.Item, error) } type Listener interface { @@ -43,7 +43,7 @@ type Listener interface { // additions, or updates. // // The list of hooks must contain only the current items. - Update(items []model.Item) + Update(items Items) } type ListenerFunc func(items []model.Item) @@ -54,7 +54,7 @@ func (listener ListenerFunc) Update(items []model.Item) { type Reader interface { // GeItems will return all the current items or an error. - GetItems(owner string) ([]model.Item, error) + GetItems(bucket, owner string) (Items, error) Start(ctx context.Context) error