diff --git a/auth/api/http/keys/endpoint.go b/auth/api/http/keys/endpoint.go index 4c3d1b7ecca..6aa1788b0fa 100644 --- a/auth/api/http/keys/endpoint.go +++ b/auth/api/http/keys/endpoint.go @@ -85,3 +85,18 @@ func revokeEndpoint(svc auth.Service) endpoint.Endpoint { return revokeKeyRes{}, nil } } + +func revokeTokenEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(revokeTokenReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.RevokeToken(ctx, req.token); err != nil { + return nil, err + } + + return revokeKeyRes{}, nil + } +} diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index a1f4fe448d6..d248bc1474a 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -4,7 +4,6 @@ package keys_test import ( - "context" "encoding/json" "fmt" "io" @@ -16,12 +15,11 @@ import ( "github.com/absmach/magistrala/auth" httpapi "github.com/absmach/magistrala/auth/api/http" - "github.com/absmach/magistrala/auth/jwt" "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/apiutil" + "github.com/absmach/magistrala/internal/testsutil" mglog "github.com/absmach/magistrala/logger" svcerr "github.com/absmach/magistrala/pkg/errors/service" - "github.com/absmach/magistrala/pkg/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -66,17 +64,6 @@ func (tr testRequest) make() (*http.Response, error) { return tr.client.Do(req) } -func newService() (auth.Service, *mocks.KeyRepository) { - krepo := new(mocks.KeyRepository) - prepo := new(mocks.PolicyAgent) - drepo := new(mocks.DomainsRepository) - idProvider := uuid.NewMock() - - t := jwt.New([]byte(secret)) - - return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), krepo -} - func newServer(svc auth.Service) *httptest.Server { mux := httpapi.MakeHandler(svc, mglog.NewMock(), "") return httptest.NewServer(mux) @@ -91,9 +78,7 @@ func toJSON(data interface{}) string { } func TestIssue(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -108,11 +93,14 @@ func TestIssue(t *testing.T) { req string ct string token string + resp auth.Token + err error status int }{ { desc: "issue login key with empty token", req: toJSON(lk), + resp: auth.Token{AccessToken: "token"}, ct: contentType, token: "", status: http.StatusUnauthorized, @@ -120,29 +108,30 @@ func TestIssue(t *testing.T) { { desc: "issue API key", req: toJSON(ak), + resp: auth.Token{AccessToken: "token"}, ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusCreated, }, { desc: "issue recovery key", req: toJSON(rk), ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusCreated, }, { desc: "issue login key wrong content type", req: toJSON(lk), ct: "", - token: token.AccessToken, + token: "token", status: http.StatusUnsupportedMediaType, }, { desc: "issue recovery key wrong content type", req: toJSON(rk), ct: "", - token: token.AccessToken, + token: "token", status: http.StatusUnsupportedMediaType, }, { @@ -150,6 +139,7 @@ func TestIssue(t *testing.T) { req: toJSON(ak), ct: contentType, token: "wrong", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { @@ -157,27 +147,28 @@ func TestIssue(t *testing.T) { req: toJSON(rk), ct: contentType, token: "", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { desc: "issue key with invalid request", req: "{", ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, { desc: "issue key with invalid JSON", req: "{invalid}", ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, { desc: "issue key with invalid JSON content", req: `{"Type":{"key":"AccessToken"}}`, ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, } @@ -191,24 +182,16 @@ func TestIssue(t *testing.T) { token: tc.token, body: strings.NewReader(tc.req), } - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return("", nil) + svcCall := svc.On("Issue", mock.Anything, tc.token, mock.Anything).Return(tc.resp, tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } func TestRetrieve(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id} - - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) - k, err := svc.Issue(context.Background(), token.AccessToken, key) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - repocall.Unset() + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -224,8 +207,8 @@ func TestRetrieve(t *testing.T) { }{ { desc: "retrieve an existing key", - id: k.AccessToken, - token: token.AccessToken, + id: testsutil.GenerateUUID(t), + token: "token", key: auth.Key{ Subject: id, Type: auth.AccessKey, @@ -238,13 +221,13 @@ func TestRetrieve(t *testing.T) { { desc: "retrieve a non-existing key", id: "non-existing", - token: token.AccessToken, - status: http.StatusBadRequest, + token: "token", + status: http.StatusNotFound, err: svcerr.ErrNotFound, }, { desc: "retrieve a key with an invalid token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "wrong", status: http.StatusUnauthorized, err: svcerr.ErrAuthentication, @@ -252,7 +235,7 @@ func TestRetrieve(t *testing.T) { { desc: "retrieve a key with an empty token", token: "", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), status: http.StatusUnauthorized, err: svcerr.ErrAuthentication, }, @@ -265,24 +248,16 @@ func TestRetrieve(t *testing.T) { url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id), token: tc.token, } - repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(tc.key, tc.err) + svcCall := svc.On("RetrieveKey", mock.Anything, tc.token, tc.id).Return(tc.key, tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } func TestRevoke(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id} - - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) - k, err := svc.Issue(context.Background(), token.AccessToken, key) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - repocall.Unset() + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -292,29 +267,31 @@ func TestRevoke(t *testing.T) { desc string id string token string + err error status int }{ { desc: "revoke an existing key", - id: k.AccessToken, - token: token.AccessToken, + id: testsutil.GenerateUUID(t), + token: "token", status: http.StatusNoContent, }, { desc: "revoke a non-existing key", id: "non-existing", - token: token.AccessToken, + token: "token", status: http.StatusNoContent, }, { desc: "revoke key with invalid token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "wrong", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { desc: "revoke key with empty token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "", status: http.StatusUnauthorized, }, @@ -327,10 +304,63 @@ func TestRevoke(t *testing.T) { url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id), token: tc.token, } - repocall := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(nil) + svcCall := svc.On("Revoke", mock.Anything, tc.token, tc.id).Return(tc.err) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + } +} + +func TestRevokeToken(t *testing.T) { + svc := new(mocks.Service) + + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + cases := []struct { + desc string + id string + token string + err error + status int + }{ + { + desc: "revoke an existing token", + token: "token", + status: http.StatusNoContent, + }, + { + desc: "revoke a non-existing token", + token: "token", + err: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "revoke invalid token", + token: "wrong", + err: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "revoke empty token", + token: "", + status: http.StatusUnauthorized, + }, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodDelete, + url: fmt.Sprintf("%s/keys/", ts.URL), + token: tc.token, + } + svcCall := svc.On("RevokeToken", mock.Anything, tc.token).Return(tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } diff --git a/auth/api/http/keys/requests.go b/auth/api/http/keys/requests.go index a9954a7351a..ceb3cc41c64 100644 --- a/auth/api/http/keys/requests.go +++ b/auth/api/http/keys/requests.go @@ -46,3 +46,15 @@ func (req keyReq) validate() error { } return nil } + +type revokeTokenReq struct { + token string +} + +func (req revokeTokenReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + + return nil +} diff --git a/auth/api/http/keys/requests_test.go b/auth/api/http/keys/requests_test.go index 7ab8ae70416..bcc7fe82a9e 100644 --- a/auth/api/http/keys/requests_test.go +++ b/auth/api/http/keys/requests_test.go @@ -86,3 +86,30 @@ func TestKeyReqValidate(t *testing.T) { assert.Equal(t, tc.err, err) } } + +func TestRevokeTokenReqValidate(t *testing.T) { + cases := []struct { + desc string + req revokeTokenReq + err error + }{ + { + desc: "valid request", + req: revokeTokenReq{ + token: valid, + }, + err: nil, + }, + { + desc: "empty token", + req: revokeTokenReq{ + token: "", + }, + err: apiutil.ErrBearerToken, + }, + } + for _, tc := range cases { + err := tc.req.validate() + assert.Equal(t, tc.err, err) + } +} diff --git a/auth/api/http/keys/transport.go b/auth/api/http/keys/transport.go index c66c15c0e4e..31c4c6d2a27 100644 --- a/auth/api/http/keys/transport.go +++ b/auth/api/http/keys/transport.go @@ -33,6 +33,13 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux { opts..., ).ServeHTTP) + r.Delete("/", kithttp.NewServer( + revokeTokenEndpoint(svc), + decodeRevokeTokenReq, + api.EncodeResponse, + opts..., + ).ServeHTTP) + r.Get("/{id}", kithttp.NewServer( (retrieveEndpoint(svc)), decodeKeyReq, @@ -70,3 +77,11 @@ func decodeKeyReq(_ context.Context, r *http.Request) (interface{}, error) { } return req, nil } + +func decodeRevokeTokenReq(_ context.Context, r *http.Request) (interface{}, error) { + req := revokeTokenReq{ + token: apiutil.ExtractBearerToken(r), + } + + return req, nil +} diff --git a/auth/api/logging.go b/auth/api/logging.go index 3f2c7537c10..466e545c4ac 100644 --- a/auth/api/logging.go +++ b/auth/api/logging.go @@ -193,6 +193,22 @@ func (lm *loggingMiddleware) Revoke(ctx context.Context, token, id string) (err return lm.svc.Revoke(ctx, token, id) } +func (lm *loggingMiddleware) RevokeToken(ctx context.Context, token string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Revoke token failed to complete successfully", args...) + return + } + lm.logger.Info("Revoke token completed successfully", args...) + }(time.Now()) + + return lm.svc.RevokeToken(ctx, token) +} + func (lm *loggingMiddleware) RetrieveKey(ctx context.Context, token, id string) (key auth.Key, err error) { defer func(begin time.Time) { args := []any{ diff --git a/auth/api/metrics.go b/auth/api/metrics.go index c7e63e2a603..de823b53ab5 100644 --- a/auth/api/metrics.go +++ b/auth/api/metrics.go @@ -109,6 +109,15 @@ func (ms *metricsMiddleware) Revoke(ctx context.Context, token, id string) error return ms.svc.Revoke(ctx, token, id) } +func (ms *metricsMiddleware) RevokeToken(ctx context.Context, token string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_token").Add(1) + ms.latency.With("method", "revoke_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RevokeToken(ctx, token) +} + func (ms *metricsMiddleware) RetrieveKey(ctx context.Context, token, id string) (auth.Key, error) { defer func(begin time.Time) { ms.counter.With("method", "retrieve_key").Add(1) diff --git a/auth/cache/policies_test.go b/auth/cache/policies_test.go index 54a65957a4a..b82020cbcaa 100644 --- a/auth/cache/policies_test.go +++ b/auth/cache/policies_test.go @@ -27,7 +27,7 @@ var policy = auth.PolicyReq{ Permission: auth.ViewPermission, } -func setupRedisClient(t *testing.T) auth.Cache { +func setupRedisCacheClient(t *testing.T) auth.Cache { opts, err := redis.ParseURL(redisURL) assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err)) redisClient := redis.NewClient(opts) @@ -35,7 +35,7 @@ func setupRedisClient(t *testing.T) auth.Cache { } func TestSave(t *testing.T) { - authCache := setupRedisClient(t) + authCache := setupRedisCacheClient(t) cases := []struct { desc string @@ -153,7 +153,7 @@ func TestSave(t *testing.T) { } func TestContains(t *testing.T) { - authCache := setupRedisClient(t) + authCache := setupRedisCacheClient(t) key, val := policy.KV() err := authCache.Save(context.Background(), key, val) @@ -237,7 +237,7 @@ func TestContains(t *testing.T) { } func TestRemove(t *testing.T) { - authCache := setupRedisClient(t) + authCache := setupRedisCacheClient(t) subject := policy.Subject object := policy.Object diff --git a/auth/cache/tokens.go b/auth/cache/tokens.go new file mode 100644 index 00000000000..fa54e1184e0 --- /dev/null +++ b/auth/cache/tokens.go @@ -0,0 +1,56 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/redis/go-redis/v9" +) + +const defKey = "revoked_tokens" + +var _ auth.Cache = (*tokensCache)(nil) + +type tokensCache struct { + client *redis.Client + keyDuration time.Duration +} + +// NewTokensCache returns redis auth cache implementation. +func NewTokensCache(client *redis.Client, duration time.Duration) auth.Cache { + return &tokensCache{ + client: client, + keyDuration: duration, + } +} + +func (tc *tokensCache) Save(ctx context.Context, _, value string) error { + if err := tc.client.SAdd(ctx, defKey, value, tc.keyDuration).Err(); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (tc *tokensCache) Contains(ctx context.Context, _, value string) bool { + ok, err := tc.client.SIsMember(ctx, defKey, value).Result() + if err != nil { + return false + } + + return ok +} + +func (tc *tokensCache) Remove(ctx context.Context, value string) error { + if err := tc.client.SRem(ctx, defKey, value).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} diff --git a/auth/cache/tokens_test.go b/auth/cache/tokens_test.go new file mode 100644 index 00000000000..8f9902073f5 --- /dev/null +++ b/auth/cache/tokens_test.go @@ -0,0 +1,184 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/auth/cache" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var key = auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), +} + +func setupRedisTokensClient(t *testing.T) auth.Cache { + opts, err := redis.ParseURL(redisURL) + assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err)) + redisClient := redis.NewClient(opts) + return cache.NewPoliciesCache(redisClient, 10*time.Minute) +} + +func TestTokenSave(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + cases := []struct { + desc string + key auth.Key + err error + }{ + { + desc: "Save token", + key: key, + err: nil, + }, + { + desc: "Save already cached policy", + key: key, + err: nil, + }, + { + desc: "Save another policy", + key: auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), + }, + err: nil, + }, + { + desc: "Save policy with long key", + key: auth.Key{ + ID: strings.Repeat("a", 513*1024*1024), + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "Save policy with empty key", + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.Save(context.Background(), "", tc.key.ID) + if err == nil { + ok := tokensCache.Contains(context.Background(), "", tc.key.ID) + assert.True(t, ok) + } + assert.True(t, errors.Contains(err, tc.err)) + }) + } +} + +func TestTokenContains(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + err := tokensCache.Save(context.Background(), "", key.ID) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + key auth.Key + ok bool + }{ + { + desc: "Contains existing key", + key: key, + ok: true, + }, + { + desc: "Contains non existing key", + key: auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), + }, + }, + { + desc: "Contains key with long id", + key: auth.Key{ + ID: strings.Repeat("a", 513*1024*1024), + }, + }, + { + desc: "Contains key with empty id", + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ok := tokensCache.Contains(context.Background(), "", tc.key.ID) + assert.Equal(t, tc.ok, ok) + }) + } +} + +func TestTokenRemove(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + num := 1000 + var ids []string + for i := 0; i < num; i++ { + id := testsutil.GenerateUUID(&testing.T{}) + err := tokensCache.Save(context.Background(), "", id) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + ids = append(ids, id) + } + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "Remove an existing id from cache", + id: ids[0], + err: nil, + }, + { + desc: "Remove multiple existing id from cache", + id: "*", + err: nil, + }, + { + desc: "Remove non existing id from cache", + id: testsutil.GenerateUUID(&testing.T{}), + err: nil, + }, + { + desc: "Remove policy with empty id from cache", + err: nil, + }, + { + desc: "Remove policy with long id from cache", + id: strings.Repeat("a", 513*1024*1024), + err: repoerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.Remove(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err)) + if tc.id == "*" { + for _, id := range ids { + ok := tokensCache.Contains(context.Background(), "", id) + assert.False(t, ok) + } + return + } + if err == nil { + ok := tokensCache.Contains(context.Background(), "", tc.id) + assert.False(t, ok) + } + }) + } +} diff --git a/auth/jwt/token_test.go b/auth/jwt/token_test.go index 461adb95be9..dafc9904894 100644 --- a/auth/jwt/token_test.go +++ b/auth/jwt/token_test.go @@ -4,14 +4,17 @@ package jwt_test import ( + "context" "fmt" "testing" "time" "github.com/absmach/magistrala/auth" authjwt "github.com/absmach/magistrala/auth/jwt" + "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/testsutil" "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" svcerr "github.com/absmach/magistrala/pkg/errors/service" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" @@ -55,7 +58,9 @@ func newToken(issuerName string, key auth.Key) string { } func TestIssue(t *testing.T) { - tokenizer := authjwt.New([]byte(secret)) + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) cases := []struct { desc string @@ -128,7 +133,9 @@ func TestIssue(t *testing.T) { } func TestParse(t *testing.T) { - tokenizer := authjwt.New([]byte(secret)) + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) token, err := tokenizer.Issue(key()) require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) @@ -162,11 +169,19 @@ func TestParse(t *testing.T) { inValidToken := newToken("invalid", key()) + refreshKey := key() + refreshKey.Type = auth.RefreshKey + refreshToken, err := tokenizer.Issue(refreshKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + cases := []struct { - desc string - key auth.Key - token string - err error + desc string + key auth.Key + token string + cacheContains bool + repoContains bool + cacheSave error + err error }{ { desc: "parse valid key", @@ -222,14 +237,191 @@ func TestParse(t *testing.T) { token: emptyToken, err: nil, }, + { + desc: "parse refresh token", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: false, + err: nil, + }, + { + desc: "parse revoked refresh token in cache", + key: refreshKey, + token: refreshToken, + cacheContains: true, + repoContains: false, + err: svcerr.ErrAuthentication, + }, + { + desc: "parse revoked refresh token not in cache", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: true, + err: svcerr.ErrAuthentication, + }, + { + desc: "parse revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: true, + cacheSave: repoerr.ErrCreateEntity, + err: svcerr.ErrAuthentication, + }, } for _, tc := range cases { - key, err := tokenizer.Parse(tc.token) + cacheCall := cache.On("Contains", context.Background(), "", tc.key.ID).Return(tc.cacheContains) + repoCall := repo.On("Contains", context.Background(), tc.key.ID).Return(tc.repoContains) + cacheCall1 := cache.On("Save", context.Background(), "", tc.key.ID).Return(tc.cacheSave) + key, err := tokenizer.Parse(context.Background(), tc.token) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) if err == nil { assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key)) } + cacheCall.Unset() + repoCall.Unset() + cacheCall1.Unset() + } +} + +func TestRevoke(t *testing.T) { + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) + + token, err := tokenizer.Issue(key()) + require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) + + apiKey := key() + apiKey.Type = auth.APIKey + apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) + apiToken, err := tokenizer.Issue(apiKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + expKey := key() + expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) + expToken, err := tokenizer.Issue(expKey) + require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err)) + + emptyDomainKey := key() + emptyDomainKey.Domain = "" + emptyDomainToken, err := tokenizer.Issue(emptyDomainKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + emptySubjectKey := key() + emptySubjectKey.Subject = "" + emptySubjectToken, err := tokenizer.Issue(emptySubjectKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + emptyKey := key() + emptyKey.Domain = "" + emptyKey.Subject = "" + emptyToken, err := tokenizer.Issue(emptyKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + inValidToken := newToken("invalid", key()) + + refreshKey := key() + refreshKey.Type = auth.RefreshKey + refreshToken, err := tokenizer.Issue(refreshKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + cases := []struct { + desc string + key auth.Key + token string + repoErr error + cacheErr error + err error + }{ + { + desc: "revoke valid key", + key: key(), + token: token, + err: nil, + }, + { + desc: "revoke invalid key", + key: auth.Key{}, + token: "invalid", + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke expired key", + key: auth.Key{}, + token: expToken, + err: auth.ErrExpiry, + }, + { + desc: "revoke expired API key", + key: apiKey, + token: apiToken, + err: auth.ErrExpiry, + }, + { + desc: "revoke token with invalid issuer", + key: auth.Key{}, + token: inValidToken, + err: errInvalidIssuer, + }, + { + desc: "revoke token with invalid content", + key: auth.Key{}, + token: newToken(issuerName, key()), + err: authjwt.ErrJSONHandle, + }, + { + desc: "revoke token with empty domain", + key: emptyDomainKey, + token: emptyDomainToken, + err: nil, + }, + { + desc: "revoke token with empty subject", + key: emptySubjectKey, + token: emptySubjectToken, + err: nil, + }, + { + desc: "revoke token with empty domain and subject", + key: emptyKey, + token: emptyToken, + err: nil, + }, + { + desc: "revoke refresh token", + key: refreshKey, + token: refreshToken, + err: nil, + }, + { + desc: "revoke revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + repoErr: nil, + cacheErr: repoerr.ErrCreateEntity, + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + repoErr: repoerr.ErrCreateEntity, + cacheErr: nil, + err: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + repoCall := repo.On("Save", context.Background(), tc.key.ID).Return(tc.repoErr) + cacheCall := cache.On("Save", context.Background(), "", tc.key.ID).Return(tc.cacheErr) + err := tokenizer.Revoke(context.Background(), tc.token) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) + cacheCall.Unset() + repoCall.Unset() } } diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go index ad79549016d..0c48ac45776 100644 --- a/auth/jwt/tokenizer.go +++ b/auth/jwt/tokenizer.go @@ -26,6 +26,8 @@ var ( ErrValidateJWTToken = errors.New("failed to validate jwt token") // ErrJSONHandle indicates an error in handling JSON. ErrJSONHandle = errors.New("failed to perform operation JSON") + // errRevokedToken indicates that the token is revoked. + errRevokedToken = errors.New("token is revoked") ) const ( @@ -40,14 +42,18 @@ const ( type tokenizer struct { secret []byte + cache auth.Cache + repo auth.TokenRepository } var _ auth.Tokenizer = (*tokenizer)(nil) // NewRepository instantiates an implementation of Token repository. -func New(secret []byte) auth.Tokenizer { +func New(secret []byte, repo auth.TokenRepository, cache auth.Cache) auth.Tokenizer { return &tokenizer{ secret: secret, + repo: repo, + cache: cache, } } @@ -79,7 +85,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) { return string(signedTkn), nil } -func (tok *tokenizer) Parse(token string) (auth.Key, error) { +func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) { tkn, err := tok.validateToken(token) if err != nil { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) @@ -90,9 +96,48 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) } + if key.Type == auth.RefreshKey { + switch tok.cache.Contains(ctx, "", key.ID) { + case true: + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedToken) + default: + if ok := tok.repo.Contains(ctx, key.ID); ok { + if err := tok.cache.Save(ctx, "", key.ID); err != nil { + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedToken) + } + } + } + return key, nil } +func (tok *tokenizer) Revoke(ctx context.Context, token string) error { + tkn, err := tok.validateToken(token) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + key, err := toKey(tkn) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + if key.Type == auth.RefreshKey { + if err := tok.repo.Save(ctx, key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + if err := tok.cache.Save(ctx, "", key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + } + + return nil +} + func (tok *tokenizer) validateToken(token string) (jwt.Token, error) { tkn, err := jwt.Parse( []byte(token), diff --git a/auth/mocks/cache.go b/auth/mocks/cache.go new file mode 100644 index 00000000000..92a4a33be68 --- /dev/null +++ b/auth/mocks/cache.go @@ -0,0 +1,84 @@ +// Code generated by mockery v2.42.3. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// Cache is an autogenerated mock type for the Cache type +type Cache struct { + mock.Mock +} + +// Contains provides a mock function with given fields: ctx, key, value +func (_m *Cache) Contains(ctx context.Context, key string, value string) bool { + ret := _m.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, string, string) bool); ok { + r0 = rf(ctx, key, value) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Remove provides a mock function with given fields: ctx, key +func (_m *Cache) Remove(ctx context.Context, key string) error { + ret := _m.Called(ctx, key) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, key) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, key, value +func (_m *Cache) Save(ctx context.Context, key string, value string) error { + ret := _m.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, key, value) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewCache creates a new instance of Cache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCache(t interface { + mock.TestingT + Cleanup(func()) +}) *Cache { + mock := &Cache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/service.go b/auth/mocks/service.go index 49f36f208db..1216c4dc754 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -595,6 +595,24 @@ func (_m *Service) Revoke(ctx context.Context, token string, id string) error { return r0 } +// RevokeToken provides a mock function with given fields: ctx, token +func (_m *Service) RevokeToken(ctx context.Context, token string) error { + ret := _m.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for RevokeToken") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, token) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UnassignUsers provides a mock function with given fields: ctx, token, id, userIds, relation func (_m *Service) UnassignUsers(ctx context.Context, token string, id string, userIds []string, relation string) error { ret := _m.Called(ctx, token, id, userIds, relation) diff --git a/auth/mocks/token.go b/auth/mocks/token.go new file mode 100644 index 00000000000..1ccd2ba9886 --- /dev/null +++ b/auth/mocks/token.go @@ -0,0 +1,66 @@ +// Code generated by mockery v2.42.3. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// TokenRepository is an autogenerated mock type for the TokenRepository type +type TokenRepository struct { + mock.Mock +} + +// Contains provides a mock function with given fields: ctx, id +func (_m *TokenRepository) Contains(ctx context.Context, id string) bool { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, id +func (_m *TokenRepository) Save(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewTokenRepository creates a new instance of TokenRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokenRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *TokenRepository { + mock := &TokenRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/policies.go b/auth/policies.go index d4877a227ea..3162aa20987 100644 --- a/auth/policies.go +++ b/auth/policies.go @@ -247,6 +247,8 @@ type PolicyAgent interface { // Cache represents a cache repository. It exposes functionalities // through `auth` to perform caching. +// +//go:generate mockery --name Cache --output=./mocks --filename cache.go --quiet --note "Copyright (c) Abstract Machines" type Cache interface { // Save saves the key-value pair in the cache. Save(ctx context.Context, key, value string) error diff --git a/auth/postgres/init.go b/auth/postgres/init.go index bae8674b366..54020f696da 100644 --- a/auth/postgres/init.go +++ b/auth/postgres/init.go @@ -51,6 +51,17 @@ func Migration() *migrate.MemoryMigrationSource { `DROP TABLE IF EXISTS keys`, }, }, + { + Id: "auth_2", + Up: []string{ + `CREATE TABLE IF NOT EXISTS tokens ( + id VARCHAR(36) PRIMARY KEY + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS tokens`, + }, + }, }, } } diff --git a/auth/postgres/token.go b/auth/postgres/token.go new file mode 100644 index 00000000000..2399b80f8e4 --- /dev/null +++ b/auth/postgres/token.go @@ -0,0 +1,61 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/internal/postgres" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" +) + +var _ auth.TokenRepository = (*tokenRepo)(nil) + +type tokenRepo struct { + db postgres.Database +} + +// NewTokensRepository instantiates a PostgreSQL implementation of tokens repository. +func NewTokensRepository(db postgres.Database) auth.TokenRepository { + return &tokenRepo{ + db: db, + } +} + +func (repo *tokenRepo) Save(ctx context.Context, id string) error { + q := `INSERT INTO tokens (id) VALUES ($1);` + + result, err := repo.db.ExecContext(ctx, q, id) + if err != nil { + return postgres.HandleError(repoerr.ErrCreateEntity, err) + } + if rows, err := result.RowsAffected(); rows == 0 { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (repo *tokenRepo) Contains(ctx context.Context, id string) bool { + q := `SELECT * FROM tokens WHERE id = $1;` + + rows, err := repo.db.QueryContext(ctx, q, id) + if err != nil { + return false + } + defer rows.Close() + + if rows.Next() { + id := "" + if err = rows.Scan(&id); err != nil { + return false + } + + return true + } + + return false +} diff --git a/auth/service.go b/auth/service.go index cbb5e426a25..c6e5c627ae2 100644 --- a/auth/service.go +++ b/auth/service.go @@ -80,6 +80,9 @@ type Authn interface { // issued by the user identified by the provided key. Revoke(ctx context.Context, token, id string) error + // RevokeToken revokes the token. + RevokeToken(ctx context.Context, token string) error + // RetrieveKey retrieves data for the Key identified by the provided // ID, that is issued by the user identified by the provided key. RetrieveKey(ctx context.Context, token, id string) (Key, error) @@ -131,6 +134,12 @@ func New(keys KeyRepository, domains DomainsRepository, idp magistrala.IDProvide func (svc service) Issue(ctx context.Context, token string, key Key) (Token, error) { key.IssuedAt = time.Now().UTC() + id, err := svc.idProvider.ID() + if err != nil { + return Token{}, errors.Wrap(errIssueUser, err) + } + key.ID = id + switch key.Type { case APIKey: return svc.userKey(ctx, token, key) @@ -146,7 +155,7 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err } func (svc service) Revoke(ctx context.Context, token, id string) error { - issuerID, _, err := svc.authenticate(token) + issuerID, _, err := svc.authenticate(ctx, token) if err != nil { return errors.Wrap(errRevoke, err) } @@ -156,8 +165,12 @@ func (svc service) Revoke(ctx context.Context, token, id string) error { return nil } +func (svc service) RevokeToken(ctx context.Context, token string) error { + return svc.tokenizer.Revoke(ctx, token) +} + func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, error) { - issuerID, _, err := svc.authenticate(token) + issuerID, _, err := svc.authenticate(ctx, token) if err != nil { return Key{}, errors.Wrap(errRetrieve, err) } @@ -170,7 +183,7 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro } func (svc service) Identify(ctx context.Context, token string) (Key, error) { - key, err := svc.tokenizer.Parse(token) + key, err := svc.tokenizer.Parse(ctx, token) if errors.Contains(err, ErrExpiry) { err = svc.keys.Remove(ctx, key.Issuer, key.ID) return Key{}, errors.Wrap(svcerr.ErrAuthentication, errors.Wrap(ErrKeyExpired, err)) @@ -459,7 +472,7 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) { } func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token, error) { - k, err := svc.tokenizer.Parse(token) + k, err := svc.tokenizer.Parse(ctx, token) if err != nil { return Token{}, errors.Wrap(errRetrieve, err) } @@ -523,7 +536,7 @@ func (svc service) checkUserDomain(ctx context.Context, key Key) (subject string } func (svc service) userKey(ctx context.Context, token string, key Key) (Token, error) { - id, sub, err := svc.authenticate(token) + id, sub, err := svc.authenticate(ctx, token) if err != nil { return Token{}, errors.Wrap(errIssueUser, err) } @@ -533,12 +546,6 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e key.Subject = sub } - keyID, err := svc.idProvider.ID() - if err != nil { - return Token{}, errors.Wrap(errIssueUser, err) - } - key.ID = keyID - if _, err := svc.keys.Save(ctx, key); err != nil { return Token{}, errors.Wrap(errIssueUser, err) } @@ -551,8 +558,8 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e return Token{AccessToken: tkn}, nil } -func (svc service) authenticate(token string) (string, string, error) { - key, err := svc.tokenizer.Parse(token) +func (svc service) authenticate(ctx context.Context, token string) (string, string, error) { + key, err := svc.tokenizer.Parse(ctx, token) if err != nil { return "", "", errors.Wrap(svcerr.ErrAuthentication, err) } diff --git a/auth/service_test.go b/auth/service_test.go index 25685ff7e63..d9f2cb366ff 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -63,13 +63,15 @@ var ( drepo *mocks.DomainsRepository ) -func newService() (auth.Service, string) { +func newService() (auth.Service, *mocks.TokenRepository, *mocks.Cache, string) { krepo = new(mocks.KeyRepository) + trepo := new(mocks.TokenRepository) + cache := new(mocks.Cache) prepo = new(mocks.PolicyAgent) drepo = new(mocks.DomainsRepository) idProvider := uuid.NewMock() - t := jwt.New([]byte(secret)) + t := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -80,13 +82,25 @@ func newService() (auth.Service, string) { } token, _ := t.Issue(key) - return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), token + return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), trepo, cache, token } -func TestIssue(t *testing.T) { - svc, accessToken := newService() +func newMinimalService() auth.Service { + krepo = new(mocks.KeyRepository) + trepo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + prepo = new(mocks.PolicyAgent) + drepo = new(mocks.DomainsRepository) + idProvider := uuid.NewMock() + + t := jwt.New([]byte(secret), trepo, cache) - n := jwt.New([]byte(secret)) + return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration) +} + +func TestIssue(t *testing.T) { + svc, trepo, cache, accessToken := newService() + n := jwt.New([]byte(secret), trepo, cache) apikey := auth.Key{ IssuedAt: time.Now(), @@ -379,6 +393,9 @@ func TestIssue(t *testing.T) { checkDOmainPolicyReq auth.PolicyReq checkPolicyErr error retrieveByIDErr error + cacheContains bool + repoContains bool + cacheSave error err error }{ { @@ -492,21 +509,82 @@ func TestIssue(t *testing.T) { retrieveByIDErr: repoerr.ErrNotFound, err: svcerr.ErrDomainAuthorization, }, + { + desc: "issue revoked refresh key in cache", + key: auth.Key{ + Type: auth.RefreshKey, + IssuedAt: time.Now(), + }, + checkPolicyRequest: auth.PolicyReq{ + Subject: email, + SubjectType: auth.UserType, + Object: auth.MagistralaObject, + ObjectType: auth.PlatformType, + Permission: auth.AdminPermission, + }, + cacheContains: true, + repoContains: false, + token: refreshToken, + err: svcerr.ErrAuthentication, + }, + { + desc: "issue revoked refresh key not in cache", + key: auth.Key{ + Type: auth.RefreshKey, + IssuedAt: time.Now(), + }, + checkPolicyRequest: auth.PolicyReq{ + Subject: email, + SubjectType: auth.UserType, + Object: auth.MagistralaObject, + ObjectType: auth.PlatformType, + Permission: auth.AdminPermission, + }, + cacheContains: false, + repoContains: true, + token: refreshToken, + err: svcerr.ErrAuthentication, + }, + { + desc: "issue revoked refresh key failed to save in cache", + key: auth.Key{ + Type: auth.RefreshKey, + IssuedAt: time.Now(), + }, + checkPolicyRequest: auth.PolicyReq{ + Subject: email, + SubjectType: auth.UserType, + Object: auth.MagistralaObject, + ObjectType: auth.PlatformType, + Permission: auth.AdminPermission, + }, + cacheContains: false, + repoContains: true, + cacheSave: repoerr.ErrCreateEntity, + token: refreshToken, + err: svcerr.ErrAuthentication, + }, } for _, tc := range cases4 { - repoCall := prepo.On("CheckPolicy", mock.Anything, tc.checkPolicyRequest).Return(tc.checkPolicyErr) - repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retrieveByIDErr) - repoCall2 := prepo.On("CheckPolicy", mock.Anything, tc.checkDOmainPolicyReq).Return(tc.checkPolicyErr) + cacheCall := cache.On("Contains", context.Background(), "", refreshkey.ID).Return(tc.cacheContains) + repoCall := trepo.On("Contains", context.Background(), refreshkey.ID).Return(tc.repoContains) + cacheCall1 := cache.On("Save", context.Background(), "", refreshkey.ID).Return(tc.cacheSave) + repoCall1 := prepo.On("CheckPolicy", mock.Anything, tc.checkPolicyRequest).Return(tc.checkPolicyErr) + repoCall2 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retrieveByIDErr) + repoCall3 := prepo.On("CheckPolicy", mock.Anything, tc.checkDOmainPolicyReq).Return(tc.checkPolicyErr) _, err := svc.Issue(context.Background(), tc.token, tc.key) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) + cacheCall.Unset() + cacheCall1.Unset() repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + repoCall3.Unset() } } func TestRevoke(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, errIssueUser) secret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) repocall.Unset() @@ -559,7 +637,7 @@ func TestRevoke(t *testing.T) { } func TestRetrieve(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) secret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) @@ -629,7 +707,7 @@ func TestRetrieve(t *testing.T) { } func TestIdentify(t *testing.T) { - svc, _ := newService() + svc, trepo, cache, _ := newService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) repocall1 := prepo.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil) @@ -655,7 +733,7 @@ func TestIdentify(t *testing.T) { assert.Nil(t, err, fmt.Sprintf("Issuing expired login key expected to succeed: %s", err)) repocall4.Unset() - te := jwt.New([]byte(secret)) + te := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -734,7 +812,7 @@ func TestIdentify(t *testing.T) { } func TestAuthorize(t *testing.T) { - svc, accessToken := newService() + svc, trepo, cache, accessToken := newService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) repocall1 := prepo.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil) @@ -755,7 +833,7 @@ func TestAuthorize(t *testing.T) { repocall2.Unset() repocall3.Unset() - te := jwt.New([]byte(secret)) + te := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -1198,7 +1276,7 @@ func TestAuthorize(t *testing.T) { } func TestAddPolicy(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() cases := []struct { desc string @@ -1240,7 +1318,7 @@ func TestAddPolicy(t *testing.T) { } func TestAddPolicies(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() cases := []struct { desc string @@ -1302,7 +1380,7 @@ func TestAddPolicies(t *testing.T) { } func TestDeletePolicy(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() cases := []struct { desc string @@ -1344,7 +1422,7 @@ func TestDeletePolicy(t *testing.T) { } func TestDeletePolicies(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() cases := []struct { desc string @@ -1406,7 +1484,7 @@ func TestDeletePolicies(t *testing.T) { } func TestListObjects(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() pageLen := 15 expectedPolicies := make([]auth.PolicyRes, pageLen) @@ -1459,7 +1537,7 @@ func TestListObjects(t *testing.T) { } func TestListAllObjects(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() pageLen := 15 expectedPolicies := make([]auth.PolicyRes, pageLen) @@ -1512,7 +1590,7 @@ func TestListAllObjects(t *testing.T) { } func TestCountObjects(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() pageLen := uint64(15) @@ -1524,7 +1602,7 @@ func TestCountObjects(t *testing.T) { } func TestListSubjects(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() pageLen := 15 expectedPolicies := make([]auth.PolicyRes, pageLen) @@ -1577,7 +1655,7 @@ func TestListSubjects(t *testing.T) { } func TestListAllSubjects(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() pageLen := 15 expectedPolicies := make([]auth.PolicyRes, pageLen) @@ -1630,7 +1708,7 @@ func TestListAllSubjects(t *testing.T) { } func TestCountSubjects(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() pageLen := uint64(15) repocall2 := prepo.On("RetrieveAllSubjectsCount", mock.Anything, mock.Anything, mock.Anything).Return(pageLen, nil) @@ -1641,7 +1719,7 @@ func TestCountSubjects(t *testing.T) { } func TestListPermissions(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() pr := auth.PolicyReq{ Subject: id, @@ -1698,7 +1776,7 @@ func TestSwitchToPermission(t *testing.T) { } func TestCreateDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1823,7 +1901,7 @@ func TestCreateDomain(t *testing.T) { } func TestRetrieveDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1883,7 +1961,7 @@ func TestRetrieveDomain(t *testing.T) { } func TestRetrieveDomainPermissions(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1942,7 +2020,7 @@ func TestRetrieveDomainPermissions(t *testing.T) { } func TestUpdateDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2022,7 +2100,7 @@ func TestUpdateDomain(t *testing.T) { } func TestChangeDomainStatus(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() disabledStatus := auth.DisabledStatus @@ -2099,7 +2177,7 @@ func TestChangeDomainStatus(t *testing.T) { } func TestListDomains(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2165,7 +2243,7 @@ func TestListDomains(t *testing.T) { } func TestAssignUsers(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2482,7 +2560,7 @@ func TestAssignUsers(t *testing.T) { } func TestUnassignUsers(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2669,7 +2747,7 @@ func TestUnassignUsers(t *testing.T) { } func TestListUsersDomains(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string diff --git a/auth/tokenizer.go b/auth/tokenizer.go index 1aaed7df4f0..991bbdf891d 100644 --- a/auth/tokenizer.go +++ b/auth/tokenizer.go @@ -3,11 +3,27 @@ package auth +import "context" + // Tokenizer specifies API for encoding and decoding between string and Key. type Tokenizer interface { // Issue converts API Key to its string representation. Issue(key Key) (token string, err error) // Parse extracts API Key data from string token. - Parse(token string) (key Key, err error) + Parse(ctx context.Context, token string) (key Key, err error) + + // Revoke revokes the token. + Revoke(ctx context.Context, token string) error +} + +// TokenRepository specifies token persistence API. +// +//go:generate mockery --name TokenRepository --output=./mocks --filename token.go --quiet --note "Copyright (c) Abstract Machines" +type TokenRepository interface { + // Save persists the token. + Save(ctx context.Context, id string) (err error) + + // Contains checks if token with provided ID exists. + Contains(ctx context.Context, id string) (ok bool) } diff --git a/auth/tracing/tracing.go b/auth/tracing/tracing.go index e95c33a7105..127d4937e18 100644 --- a/auth/tracing/tracing.go +++ b/auth/tracing/tracing.go @@ -43,6 +43,13 @@ func (tm *tracingMiddleware) Revoke(ctx context.Context, token, id string) error return tm.svc.Revoke(ctx, token, id) } +func (tm *tracingMiddleware) RevokeToken(ctx context.Context, token string) error { + ctx, span := tm.tracer.Start(ctx, "revoke") + defer span.End() + + return tm.svc.RevokeToken(ctx, token) +} + func (tm *tracingMiddleware) RetrieveKey(ctx context.Context, token, id string) (auth.Key, error) { ctx, span := tm.tracer.Start(ctx, "retrieve_key", trace.WithAttributes( attribute.String("id", id), diff --git a/cmd/auth/main.go b/cmd/auth/main.go index e592fb4944d..0a89ea6d8ad 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -217,13 +217,15 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, cacheClient *redis.Client, keyDuration time.Duration, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service { database := postgres.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) + tokensRepo := apostgres.NewTokensRepository(database) domainsRepo := apostgres.NewDomainRepository(database) policiesCache := cache.NewPoliciesCache(cacheClient, keyDuration) + tokensCache := cache.NewTokensCache(cacheClient, keyDuration) pa := spicedb.NewPolicyAgent(spicedbClient, logger, policiesCache) idProvider := uuid.New() - t := jwt.New([]byte(cfg.SecretKey)) + t := jwt.New([]byte(cfg.SecretKey), tokensRepo, tokensCache) svc := auth.New(keysRepo, domainsRepo, idProvider, t, pa, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) svc = api.LoggingMiddleware(svc, logger)