Skip to content

Commit

Permalink
feat: add endpoint to get a user's own organization (#3402)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdevine authored Oct 6, 2022
1 parent 75b9884 commit bac01df
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 8 deletions.
4 changes: 4 additions & 0 deletions api/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ type Organization struct {
Domain string `json:"domain"`
}

type GetOrganizationRequest struct {
ID IDOrSelf `uri:"id"`
}

type ListOrganizationsRequest struct {
Name string `form:"name"`
PaginationRequest
Expand Down
12 changes: 11 additions & 1 deletion internal/access/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import (
"github.com/infrahq/infra/uid"
)

// isOrganizationSelf is used by authorization checks to see if the calling identity is requesting their own organization
func isOrganizationSelf(c *gin.Context, orgID uid.ID) (bool, error) {
org := GetRequestContext(c).Authenticated.Organization
return org != nil && org.ID == orgID, nil
}

func ListOrganizations(c *gin.Context, name string, pg *data.Pagination) ([]models.Organization, error) {
selectors := []data.SelectorFunc{}
if name != "" {
Expand All @@ -29,7 +35,11 @@ func ListOrganizations(c *gin.Context, name string, pg *data.Pagination) ([]mode
}

func GetOrganization(c *gin.Context, id uid.ID) (*models.Organization, error) {
db, err := RequireInfraRole(c, models.InfraSupportAdminRole)
roles := []string{models.InfraSupportAdminRole}

// If the user is in the org, allow them to call this endpoint, otherwise they must be
// an InfraSupportAdmin.
db, err := hasAuthorization(c, id, isOrganizationSelf, roles...)
if err != nil {
return nil, HandleAuthErr(err, "organizations", "get", models.InfraSupportAdminRole)
}
Expand Down
4 changes: 4 additions & 0 deletions internal/server/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ func jsonUnmarshal(t *testing.T, raw string) interface{} {
return out
}

var cmpAPIOrganizationJSON = gocmp.Options{
gocmp.FilterPath(pathMapKey(`created`, `updated`), cmpApproximateTime),
}

var cmpAPIUserJSON = gocmp.Options{
gocmp.FilterPath(pathMapKey(`created`, `updated`, `lastSeenAt`), cmpApproximateTime),
gocmp.FilterPath(pathMapKey(`id`), cmpAnyValidUID),
Expand Down
8 changes: 5 additions & 3 deletions internal/server/models/organization.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ type Organization struct {

func (o *Organization) ToAPI() *api.Organization {
return &api.Organization{
ID: o.ID,
Name: o.Name,
Domain: o.Domain,
ID: o.ID,
Name: o.Name,
Created: api.Time(o.CreatedAt),
Updated: api.Time(o.UpdatedAt),
Domain: o.Domain,
}
}

Expand Down
14 changes: 12 additions & 2 deletions internal/server/organizations.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package server

import (
"fmt"

"github.com/gin-gonic/gin"

"github.com/infrahq/infra/api"
"github.com/infrahq/infra/internal"
"github.com/infrahq/infra/internal/access"
"github.com/infrahq/infra/internal/server/models"
)
Expand All @@ -22,8 +25,15 @@ func (a *API) ListOrganizations(c *gin.Context, r *api.ListOrganizationsRequest)
return result, nil
}

func (a *API) GetOrganization(c *gin.Context, r *api.Resource) (*api.Organization, error) {
org, err := access.GetOrganization(c, r.ID)
func (a *API) GetOrganization(c *gin.Context, r *api.GetOrganizationRequest) (*api.Organization, error) {
if r.ID.IsSelf {
iden := access.GetRequestContext(c).Authenticated.Organization
if iden == nil {
return nil, fmt.Errorf("%w: no user is logged in", internal.ErrUnauthorized)
}
r.ID.ID = iden.ID
}
org, err := access.GetOrganization(c, r.ID.ID)
if err != nil {
return nil, err
}
Expand Down
156 changes: 156 additions & 0 deletions internal/server/organizations_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package server

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"

"gotest.tools/v3/assert"

"github.com/infrahq/infra/api"
"github.com/infrahq/infra/internal/server/data"
"github.com/infrahq/infra/internal/server/models"
"github.com/infrahq/infra/uid"
)

func createOrgs(t *testing.T, db data.GormTxn, orgs ...*models.Organization) {
Expand All @@ -21,6 +25,158 @@ func createOrgs(t *testing.T, db data.GormTxn, orgs ...*models.Organization) {
}
}

func TestAPI_GetOrganization(t *testing.T) {
srv := setupServer(t, withAdminUser, withSupportAdminGrant, withMultiOrgEnabled)
routes := srv.GenerateRoutes()

var (
first = models.Organization{Name: "first", Domain: "first.com"}
)

createOrgs(t, srv.DB(), &first)
createID := func(t *testing.T, name string) uid.ID {
t.Helper()
var buf bytes.Buffer
body := api.CreateUserRequest{Name: name}
err := json.NewEncoder(&buf).Encode(body)
assert.NilError(t, err)

// nolint:noctx
req, err := http.NewRequest(http.MethodPost, "/api/users", &buf)
assert.NilError(t, err)
req.Header.Set("Authorization", "Bearer "+adminAccessKey(srv))
req.Header.Set("Infra-Version", apiVersionLatest)

resp := httptest.NewRecorder()
routes.ServeHTTP(resp, req)
assert.Equal(t, resp.Code, http.StatusCreated, resp.Body.String())
respObj := &api.CreateUserResponse{}
err = json.Unmarshal(resp.Body.Bytes(), respObj)
assert.NilError(t, err)
return respObj.ID
}
idMe := createID(t, "me@example.com")

token := &models.AccessKey{
IssuedFor: idMe,
ProviderID: data.InfraProvider(srv.DB()).ID,
ExpiresAt: time.Now().Add(10 * time.Second),
}

accessKeyMe, err := data.CreateAccessKey(srv.DB(), token)
assert.NilError(t, err)

type testCase struct {
urlPath string
setup func(t *testing.T, req *http.Request)
expected func(t *testing.T, resp *httptest.ResponseRecorder)
}

run := func(t *testing.T, tc testCase) {
req, err := http.NewRequest(http.MethodGet, tc.urlPath, nil)
assert.NilError(t, err)
req.Header.Add("Infra-Version", "0.15.2")

if tc.setup != nil {
tc.setup(t, req)
}

resp := httptest.NewRecorder()
routes.ServeHTTP(resp, req)

tc.expected(t, resp)
}

testCases := map[string]testCase{
"not authenticated": {
urlPath: "/api/organizations/" + first.ID.String(),
setup: func(t *testing.T, req *http.Request) {
req.Header.Del("Authorization")
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusUnauthorized)
},
},
"not authorized": {
urlPath: "/api/organizations/" + first.ID.String(),
setup: func(t *testing.T, req *http.Request) {
key, _ := createAccessKey(t, srv.DB(), "someonenew@example.com")
req.Header.Set("Authorization", "Bearer "+key)
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusForbidden)
},
},
"organization by ID for default org": {
urlPath: "/api/organizations/" + srv.db.DefaultOrg.ID.String(),
setup: func(t *testing.T, req *http.Request) {
req.Header.Set("Authorization", "Bearer "+accessKeyMe)
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusOK)
},
},
"organization by ID for a different org": {
urlPath: "/api/organizations/" + first.ID.String(),
setup: func(t *testing.T, req *http.Request) {
req.Header.Set("Authorization", "Bearer "+accessKeyMe)
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusForbidden)
},
},
"organization by ID for a different org by support admin": {
urlPath: "/api/organizations/" + first.ID.String(),
setup: func(t *testing.T, req *http.Request) {
req.Header.Set("Authorization", "Bearer "+adminAccessKey(srv))
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusOK)
},
},
"organization by self": {
urlPath: "/api/organizations/self",
setup: func(t *testing.T, req *http.Request) {
req.Header.Set("Authorization", "Bearer "+accessKeyMe)
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusOK)
},
},
"JSON response": {
urlPath: "/api/organizations/" + srv.db.DefaultOrg.ID.String(),
setup: func(t *testing.T, req *http.Request) {
req.Header.Set("Authorization", "Bearer "+accessKeyMe)
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusOK)

expected := jsonUnmarshal(t, fmt.Sprintf(`
{
"id": "%[1]v",
"name": "%[2]v",
"created": "%[3]v",
"updated": "%[3]v",
"domain": "%[4]v"
}`,
srv.db.DefaultOrg.ID.String(),
srv.db.DefaultOrg.Name,
time.Now().UTC().Format(time.RFC3339),
srv.db.DefaultOrg.Domain,
))
actual := jsonUnmarshal(t, resp.Body.String())
assert.DeepEqual(t, actual, expected, cmpAPIOrganizationJSON)
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
run(t, tc)
})
}

}

func TestAPI_ListOrganizations(t *testing.T) {
srv := setupServer(t, withAdminUser, withSupportAdminGrant)
routes := srv.GenerateRoutes()
Expand Down
5 changes: 3 additions & 2 deletions internal/server/testdata/openapi3.json
Original file line number Diff line number Diff line change
Expand Up @@ -3809,9 +3809,10 @@
"name": "id",
"required": true,
"schema": {
"description": "a uid or the literal self",
"example": "4yJ3n3D8E2",
"format": "uid",
"pattern": "[\\da-zA-HJ-NP-Z]{1,11}",
"format": "uid|self",
"pattern": "[\\da-zA-HJ-NP-Z]{1,11}|self",
"type": "string"
}
}
Expand Down

0 comments on commit bac01df

Please sign in to comment.