diff --git a/jose.go b/jose.go index 9bc6482..6942903 100644 --- a/jose.go +++ b/jose.go @@ -458,3 +458,14 @@ func retrieveActualKey(headers map[string]interface{}, payload string, key inter return key, nil } + +func MatchAlg(expected string, key interface{}) func(headers map[string]interface{}, payload string) interface{} { + return func(headers map[string]interface{}, payload string) interface{} { + alg := headers["alg"].(string) + if expected == alg { + return key + } + + return errors.New("Expected alg to be '" + expected + "' but got '" + alg + "'") + } +} diff --git a/jose_test.go b/jose_test.go index eff9e95..dffd4dc 100644 --- a/jose_test.go +++ b/jose_test.go @@ -2600,6 +2600,33 @@ func (s *TestSuite) TestDeregisterJwc(c *C) { c.Assert(test, Equals, "") } +func (s *TestSuite) TestDecode_TwoPhased_MatchAlg(c *C) { + //given + token := "eyJhbGciOiJFUzI1NiIsImN0eSI6InRleHRcL3BsYWluIn0.eyJoZWxsbyI6ICJ3b3JsZCJ9.EVnmDMlz-oi05AQzts-R3aqWvaBlwVZddWkmaaHyMx5Phb2NSLgyI0kccpgjjAyo1S5KCB3LIMPfmxCX_obMKA" + + //when + test, _, err := Decode(token, MatchAlg("ES256", Ecc256Public())) + + //then + c.Assert(err, IsNil) + c.Assert(test, Equals, `{"hello": "world"}`) +} + +func (s *TestSuite) TestDecode_TwoPhased_MatchAlg_Invalid(c *C) { + //given + token := "eyJhbGciOiJFUzI1NiIsImN0eSI6InRleHRcL3BsYWluIn0.eyJoZWxsbyI6ICJ3b3JsZCJ9.EVnmDMlz-oi05AQzts-R3aqWvaBlwVZddWkmaaHyMx5Phb2NSLgyI0kccpgjjAyo1S5KCB3LIMPfmxCX_obMKA" + + //when + test, headers, err := Decode(token, MatchAlg("RS256", Ecc256Public())) + + fmt.Printf("\nalg doesn't match err=%v\n", err) + + //then + c.Assert(headers, IsNil) + c.Assert(err, NotNil) + c.Assert(test, Equals, "") +} + // test utils func PubKey() *rsa.PublicKey { key, _ := Rsa.ReadPublic([]byte(pubKey))