Skip to content

Commit

Permalink
Feature cont.: authorize CLI as admin with private (#4338)
Browse files Browse the repository at this point in the history
  • Loading branch information
iamrodrigo authored Jul 31, 2021
1 parent 28e0489 commit 9a46d9d
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 42 deletions.
24 changes: 10 additions & 14 deletions common/authorization/oauthAuthorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type oauthAuthority struct {
log log.Logger
}

type jwtClaims struct {
type JWTClaims struct {
Sub string
Name string
Groups string // separated by space
Expand Down Expand Up @@ -117,17 +117,17 @@ func (a *oauthAuthority) getVerifier() (jwt.Verifier, error) {
return verifier, nil
}

func (a *oauthAuthority) parseToken(tokenStr string, verifier jwt.Verifier) (*jwtClaims, error) {
func (a *oauthAuthority) parseToken(tokenStr string, verifier jwt.Verifier) (*JWTClaims, error) {
token, verifyErr := jwt.ParseAndVerifyString(tokenStr, verifier)
if verifyErr != nil {
return nil, verifyErr
}
var claims jwtClaims
var claims JWTClaims
_ = json.Unmarshal(token.RawClaims(), &claims)
return &claims, nil
}

func (a *oauthAuthority) validateTTL(claims *jwtClaims) error {
func (a *oauthAuthority) validateTTL(claims *JWTClaims) error {
if claims.TTL > a.authorizationCfg.MaxJwtTTL {
return fmt.Errorf("TTL in token is larger than MaxTTL allowed")
}
Expand All @@ -137,26 +137,22 @@ func (a *oauthAuthority) validateTTL(claims *jwtClaims) error {
return nil
}

func (a *oauthAuthority) validatePermission(claims *jwtClaims, attributes *Attributes, data map[string]string) error {
func (a *oauthAuthority) validatePermission(claims *JWTClaims, attributes *Attributes, data map[string]string) error {
groups := ""
switch attributes.Permission {
case PermissionRead:
groups = data[common.DomainDataKeyForReadGroups] + groupSeparator + data[common.DomainDataKeyForWriteGroups]
case PermissionWrite:
groups = data[common.DomainDataKeyForWriteGroups]
default:
if claims.Admin {
return nil
} else {
return fmt.Errorf("token doesn't have permission for admin API")
}
return fmt.Errorf("token doesn't have permission for %v API", attributes.Permission)
}
// groups are separated by space
jwtGroups := strings.Split(groups, groupSeparator)
allowedGroups := strings.Split(claims.Groups, groupSeparator)
allowedGroups := strings.Split(groups, groupSeparator) // groups that allowed by domain configuration(in domainData)
jwtGroups := strings.Split(claims.Groups, groupSeparator) // groups that the request has associated with

for _, group1 := range jwtGroups {
for _, group2 := range allowedGroups {
for _, group1 := range allowedGroups {
for _, group2 := range jwtGroups {
if group1 == group2 {
return nil
}
Expand Down
7 changes: 4 additions & 3 deletions common/authorization/oauthAutorizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func (s *oauthSuite) TestDifferentGroup() {
s.att.Permission = PermissionWrite
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have the right permission, jwt groups: [], allowed groups: [a b c]"
return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have the right permission, jwt groups: [a b c], allowed groups: []"
}))
result, _ := authorizer.Authorize(s.ctx, &s.att)
s.Equal(result.Decision, DecisionDeny)
Expand All @@ -222,8 +222,9 @@ func (s *oauthSuite) TestIncorrectPermission() {
s.att.Permission = Permission(15)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have permission for admin API"
return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have permission for 15 API"
}))
result, _ := authorizer.Authorize(s.ctx, &s.att)
result, err := authorizer.Authorize(s.ctx, &s.att)
s.NoError(err)
s.Equal(result.Decision, DecisionDeny)
}
82 changes: 82 additions & 0 deletions common/rsa.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) 2021 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package common

import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"strings"
)

type KeyType string

const (
KeyTypePrivate KeyType = "private key"

KeyTypePublic KeyType = "public key"
)

func loadRSAKey(path string, keyType KeyType) (interface{}, error) {
keyString, err := ioutil.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("invalid %s path %s", keyType, path)
}
block, _ := pem.Decode(keyString)
if block == nil || strings.ToLower(block.Type) != strings.ToLower(string(keyType)) {
return nil, fmt.Errorf("failed to parse PEM block containing the %s", keyType)
}

switch keyType {
case KeyTypePrivate:
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse DER encoded %s: %s", keyType, err.Error())
}
return key, nil
case KeyTypePublic:
key, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse DER encoded %s: %s", keyType, err.Error())
}
return key, nil
default:
return nil, fmt.Errorf("invalid Key Type")
}
}

func LoadRSAPublicKey(path string) (*rsa.PublicKey, error) {
key, err := loadRSAKey(path, KeyTypePublic)
if err != nil {
return nil, err
}
return key.(*rsa.PublicKey), err
}

func LoadRSAPrivateKey(path string) (*rsa.PrivateKey, error) {
key, err := loadRSAKey(path, KeyTypePrivate)
if err != nil {
return nil, err
}
return key.(*rsa.PrivateKey), err
}
23 changes: 0 additions & 23 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,8 @@ package common

import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"io/ioutil"
"math"
"math/rand"
"sort"
Expand Down Expand Up @@ -958,25 +954,6 @@ func SleepWithMinDuration(desired time.Duration, available time.Duration) time.D
return available - d
}

// LoadRSAPublicKey loads a rsa.PublicKey from the given filepath
func LoadRSAPublicKey(path string) (*rsa.PublicKey, error) {
key, err := ioutil.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("invalid public key path %s", path)
}
block, _ := pem.Decode(key)
if block == nil || block.Type != "PUBLIC KEY" {
return nil, fmt.Errorf("failed to parse PEM block containing the public key")
}

pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse DER encoded public key: " + err.Error())
}
publicKey := pub.(*rsa.PublicKey)
return publicKey, nil
}

// ConvertErrToGetTaskFailedCause converts error to GetTaskFailedCause
func ConvertErrToGetTaskFailedCause(err error) types.GetTaskFailedCause {
if IsContextTimeoutError(err) {
Expand Down
7 changes: 6 additions & 1 deletion tools/cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,14 @@ func NewCliApp() *cli.App {
},
cli.StringFlag{
Name: FlagJWT,
Usage: "optional JWT for authorization",
Usage: "optional JWT for authorization. Either this or --jwt-private-key is needed for jwt authorization",
EnvVar: "CADENCE_CLI_JWT",
},
cli.StringFlag{
Name: FlagJWTPrivateKeyWithAlias,
Usage: "optional private key path to create JWT. Either this or --jwt is needed for jwt authorization. --jwt flag has priority over this one if both provided",
EnvVar: "CADENCE_CLI_JWT_PRIVATE_KEY",
},
}
app.Commands = []cli.Command{
{
Expand Down
4 changes: 4 additions & 0 deletions tools/cli/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,7 @@ func (vm *versionMiddleware) Call(ctx context.Context, request *transport.Reques
func getJWT(c *cli.Context) string {
return c.GlobalString(FlagJWT)
}

func getJWTPrivateKey(c *cli.Context) string {
return c.GlobalString(FlagJWTPrivateKey)
}
2 changes: 2 additions & 0 deletions tools/cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ const (
DelayStartSeconds = "delay_start_seconds"
FlagConnectionAttributes = "conn_attrs"
FlagJWT = "jwt"
FlagJWTPrivateKey = "jwt-private-key"
FlagJWTPrivateKeyWithAlias = FlagJWTPrivateKey + ", jwt-pk"
)

var flagsForExecution = []cli.Flag{
Expand Down
49 changes: 48 additions & 1 deletion tools/cli/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ import (
"strings"
"time"

"github.com/cristalhq/jwt/v3"
"github.com/fatih/color"
"github.com/urfave/cli"
"github.com/valyala/fastjson"
s "go.uber.org/cadence/.gen/go/shared"
"go.uber.org/cadence/client"

"github.com/uber/cadence/common"
"github.com/uber/cadence/common/authorization"
cc "github.com/uber/cadence/common/client"
)

Expand Down Expand Up @@ -763,8 +765,27 @@ func getCliIdentity() string {
return fmt.Sprintf("cadence-cli@%s", hostName)
}

func processJWTFlags(ctx context.Context, cliCtx *cli.Context) context.Context {
path := getJWTPrivateKey(cliCtx)
t := getJWT(cliCtx)
var token string

if t != "" {
token = t
} else if path != "" {
createdToken, err := createJWT(path)
if err != nil {
ErrorAndExit("Error creating JWT token", err)
}
token = *createdToken
}

ctx = context.WithValue(ctx, CtxKeyJWT, token)
return ctx
}

func populateContextFromCLIContext(ctx context.Context, cliCtx *cli.Context) context.Context {
ctx = context.WithValue(ctx, CtxKeyJWT, getJWT(cliCtx))
ctx = processJWTFlags(ctx, cliCtx)
return ctx
}

Expand Down Expand Up @@ -1007,3 +1028,29 @@ func getInputFile(inputFile string) *os.File {
}
return f
}

// createJWT defines the logic to create a JWT
func createJWT(keyPath string) (*string, error) {
claims := authorization.JWTClaims{
Admin: true,
Iat: time.Now().Unix(),
TTL: 60 * 10,
}

privateKey, err := common.LoadRSAPrivateKey(keyPath)
if err != nil {
return nil, err
}

signer, err := jwt.NewSignerRS(jwt.RS256, privateKey)
if err != nil {
return nil, err
}
builder := jwt.NewBuilder(signer)
token, err := builder.Build(claims)
if token == nil {
return nil, err
}
tokenString := token.String()
return &tokenString, nil
}

0 comments on commit 9a46d9d

Please sign in to comment.