From 9a46d9d483cfdb8fd5ca77d1f448d79d78b0e791 Mon Sep 17 00:00:00 2001 From: Rodrigo Villarreal Date: Sat, 31 Jul 2021 09:46:00 -0700 Subject: [PATCH] Feature cont.: authorize CLI as admin with private (#4338) --- common/authorization/oauthAuthorizer.go | 24 +++--- common/authorization/oauthAutorizer_test.go | 7 +- common/rsa.go | 82 +++++++++++++++++++++ common/util.go | 23 ------ tools/cli/app.go | 7 +- tools/cli/factory.go | 4 + tools/cli/flags.go | 2 + tools/cli/util.go | 49 +++++++++++- 8 files changed, 156 insertions(+), 42 deletions(-) create mode 100644 common/rsa.go diff --git a/common/authorization/oauthAuthorizer.go b/common/authorization/oauthAuthorizer.go index be1df4e345f..145f37141aa 100644 --- a/common/authorization/oauthAuthorizer.go +++ b/common/authorization/oauthAuthorizer.go @@ -43,7 +43,7 @@ type oauthAuthority struct { log log.Logger } -type jwtClaims struct { +type JWTClaims struct { Sub string Name string Groups string // separated by space @@ -117,17 +117,17 @@ func (a *oauthAuthority) getVerifier() (jwt.Verifier, error) { return verifier, nil } -func (a *oauthAuthority) parseToken(tokenStr string, verifier jwt.Verifier) (*jwtClaims, error) { +func (a *oauthAuthority) parseToken(tokenStr string, verifier jwt.Verifier) (*JWTClaims, error) { token, verifyErr := jwt.ParseAndVerifyString(tokenStr, verifier) if verifyErr != nil { return nil, verifyErr } - var claims jwtClaims + var claims JWTClaims _ = json.Unmarshal(token.RawClaims(), &claims) return &claims, nil } -func (a *oauthAuthority) validateTTL(claims *jwtClaims) error { +func (a *oauthAuthority) validateTTL(claims *JWTClaims) error { if claims.TTL > a.authorizationCfg.MaxJwtTTL { return fmt.Errorf("TTL in token is larger than MaxTTL allowed") } @@ -137,7 +137,7 @@ func (a *oauthAuthority) validateTTL(claims *jwtClaims) error { return nil } -func (a *oauthAuthority) validatePermission(claims *jwtClaims, attributes *Attributes, data map[string]string) error { +func (a *oauthAuthority) validatePermission(claims *JWTClaims, attributes *Attributes, data map[string]string) error { groups := "" switch attributes.Permission { case PermissionRead: @@ -145,18 +145,14 @@ func (a *oauthAuthority) validatePermission(claims *jwtClaims, attributes *Attri case PermissionWrite: groups = data[common.DomainDataKeyForWriteGroups] default: - if claims.Admin { - return nil - } else { - return fmt.Errorf("token doesn't have permission for admin API") - } + return fmt.Errorf("token doesn't have permission for %v API", attributes.Permission) } // groups are separated by space - jwtGroups := strings.Split(groups, groupSeparator) - allowedGroups := strings.Split(claims.Groups, groupSeparator) + allowedGroups := strings.Split(groups, groupSeparator) // groups that allowed by domain configuration(in domainData) + jwtGroups := strings.Split(claims.Groups, groupSeparator) // groups that the request has associated with - for _, group1 := range jwtGroups { - for _, group2 := range allowedGroups { + for _, group1 := range allowedGroups { + for _, group2 := range jwtGroups { if group1 == group2 { return nil } diff --git a/common/authorization/oauthAutorizer_test.go b/common/authorization/oauthAutorizer_test.go index 205d40bc1a6..7203aef5070 100644 --- a/common/authorization/oauthAutorizer_test.go +++ b/common/authorization/oauthAutorizer_test.go @@ -211,7 +211,7 @@ func (s *oauthSuite) TestDifferentGroup() { s.att.Permission = PermissionWrite authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache) s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool { - return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have the right permission, jwt groups: [], allowed groups: [a b c]" + return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have the right permission, jwt groups: [a b c], allowed groups: []" })) result, _ := authorizer.Authorize(s.ctx, &s.att) s.Equal(result.Decision, DecisionDeny) @@ -222,8 +222,9 @@ func (s *oauthSuite) TestIncorrectPermission() { s.att.Permission = Permission(15) authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache) s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool { - return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have permission for admin API" + return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have permission for 15 API" })) - result, _ := authorizer.Authorize(s.ctx, &s.att) + result, err := authorizer.Authorize(s.ctx, &s.att) + s.NoError(err) s.Equal(result.Decision, DecisionDeny) } diff --git a/common/rsa.go b/common/rsa.go new file mode 100644 index 00000000000..b2a3c2d7a22 --- /dev/null +++ b/common/rsa.go @@ -0,0 +1,82 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package common + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "strings" +) + +type KeyType string + +const ( + KeyTypePrivate KeyType = "private key" + + KeyTypePublic KeyType = "public key" +) + +func loadRSAKey(path string, keyType KeyType) (interface{}, error) { + keyString, err := ioutil.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("invalid %s path %s", keyType, path) + } + block, _ := pem.Decode(keyString) + if block == nil || strings.ToLower(block.Type) != strings.ToLower(string(keyType)) { + return nil, fmt.Errorf("failed to parse PEM block containing the %s", keyType) + } + + switch keyType { + case KeyTypePrivate: + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse DER encoded %s: %s", keyType, err.Error()) + } + return key, nil + case KeyTypePublic: + key, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse DER encoded %s: %s", keyType, err.Error()) + } + return key, nil + default: + return nil, fmt.Errorf("invalid Key Type") + } +} + +func LoadRSAPublicKey(path string) (*rsa.PublicKey, error) { + key, err := loadRSAKey(path, KeyTypePublic) + if err != nil { + return nil, err + } + return key.(*rsa.PublicKey), err +} + +func LoadRSAPrivateKey(path string) (*rsa.PrivateKey, error) { + key, err := loadRSAKey(path, KeyTypePrivate) + if err != nil { + return nil, err + } + return key.(*rsa.PrivateKey), err +} diff --git a/common/util.go b/common/util.go index 4f250598280..4b5381cb3f2 100644 --- a/common/util.go +++ b/common/util.go @@ -22,12 +22,8 @@ package common import ( "context" - "crypto/rsa" - "crypto/x509" "encoding/json" - "encoding/pem" "fmt" - "io/ioutil" "math" "math/rand" "sort" @@ -958,25 +954,6 @@ func SleepWithMinDuration(desired time.Duration, available time.Duration) time.D return available - d } -// LoadRSAPublicKey loads a rsa.PublicKey from the given filepath -func LoadRSAPublicKey(path string) (*rsa.PublicKey, error) { - key, err := ioutil.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("invalid public key path %s", path) - } - block, _ := pem.Decode(key) - if block == nil || block.Type != "PUBLIC KEY" { - return nil, fmt.Errorf("failed to parse PEM block containing the public key") - } - - pub, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse DER encoded public key: " + err.Error()) - } - publicKey := pub.(*rsa.PublicKey) - return publicKey, nil -} - // ConvertErrToGetTaskFailedCause converts error to GetTaskFailedCause func ConvertErrToGetTaskFailedCause(err error) types.GetTaskFailedCause { if IsContextTimeoutError(err) { diff --git a/tools/cli/app.go b/tools/cli/app.go index 81883defb1c..6333f7860bb 100644 --- a/tools/cli/app.go +++ b/tools/cli/app.go @@ -66,9 +66,14 @@ func NewCliApp() *cli.App { }, cli.StringFlag{ Name: FlagJWT, - Usage: "optional JWT for authorization", + Usage: "optional JWT for authorization. Either this or --jwt-private-key is needed for jwt authorization", EnvVar: "CADENCE_CLI_JWT", }, + cli.StringFlag{ + Name: FlagJWTPrivateKeyWithAlias, + Usage: "optional private key path to create JWT. Either this or --jwt is needed for jwt authorization. --jwt flag has priority over this one if both provided", + EnvVar: "CADENCE_CLI_JWT_PRIVATE_KEY", + }, } app.Commands = []cli.Command{ { diff --git a/tools/cli/factory.go b/tools/cli/factory.go index b75f27d4876..af00e6ebc2f 100644 --- a/tools/cli/factory.go +++ b/tools/cli/factory.go @@ -140,3 +140,7 @@ func (vm *versionMiddleware) Call(ctx context.Context, request *transport.Reques func getJWT(c *cli.Context) string { return c.GlobalString(FlagJWT) } + +func getJWTPrivateKey(c *cli.Context) string { + return c.GlobalString(FlagJWTPrivateKey) +} diff --git a/tools/cli/flags.go b/tools/cli/flags.go index b431fff5203..72af09ab8e6 100644 --- a/tools/cli/flags.go +++ b/tools/cli/flags.go @@ -277,6 +277,8 @@ const ( DelayStartSeconds = "delay_start_seconds" FlagConnectionAttributes = "conn_attrs" FlagJWT = "jwt" + FlagJWTPrivateKey = "jwt-private-key" + FlagJWTPrivateKeyWithAlias = FlagJWTPrivateKey + ", jwt-pk" ) var flagsForExecution = []cli.Flag{ diff --git a/tools/cli/util.go b/tools/cli/util.go index 0d94f892bb3..8cb42937c16 100644 --- a/tools/cli/util.go +++ b/tools/cli/util.go @@ -38,6 +38,7 @@ import ( "strings" "time" + "github.com/cristalhq/jwt/v3" "github.com/fatih/color" "github.com/urfave/cli" "github.com/valyala/fastjson" @@ -45,6 +46,7 @@ import ( "go.uber.org/cadence/client" "github.com/uber/cadence/common" + "github.com/uber/cadence/common/authorization" cc "github.com/uber/cadence/common/client" ) @@ -763,8 +765,27 @@ func getCliIdentity() string { return fmt.Sprintf("cadence-cli@%s", hostName) } +func processJWTFlags(ctx context.Context, cliCtx *cli.Context) context.Context { + path := getJWTPrivateKey(cliCtx) + t := getJWT(cliCtx) + var token string + + if t != "" { + token = t + } else if path != "" { + createdToken, err := createJWT(path) + if err != nil { + ErrorAndExit("Error creating JWT token", err) + } + token = *createdToken + } + + ctx = context.WithValue(ctx, CtxKeyJWT, token) + return ctx +} + func populateContextFromCLIContext(ctx context.Context, cliCtx *cli.Context) context.Context { - ctx = context.WithValue(ctx, CtxKeyJWT, getJWT(cliCtx)) + ctx = processJWTFlags(ctx, cliCtx) return ctx } @@ -1007,3 +1028,29 @@ func getInputFile(inputFile string) *os.File { } return f } + +// createJWT defines the logic to create a JWT +func createJWT(keyPath string) (*string, error) { + claims := authorization.JWTClaims{ + Admin: true, + Iat: time.Now().Unix(), + TTL: 60 * 10, + } + + privateKey, err := common.LoadRSAPrivateKey(keyPath) + if err != nil { + return nil, err + } + + signer, err := jwt.NewSignerRS(jwt.RS256, privateKey) + if err != nil { + return nil, err + } + builder := jwt.NewBuilder(signer) + token, err := builder.Build(claims) + if token == nil { + return nil, err + } + tokenString := token.String() + return &tokenString, nil +}