-
Notifications
You must be signed in to change notification settings - Fork 1
/
sd-jwt.go
369 lines (304 loc) · 11 KB
/
sd-jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
// Package go_sd_jwt provides a library for creating and validating SD-JWTs.
// The resulting SdJwt object exposes methods for retrieving the claims and
// disclosures as well as retrieving all disclosed claims in line with the specification.
package go_sd_jwt
import (
"crypto"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash"
"slices"
"strings"
"time"
"github.com/MichaelFraser99/go-sd-jwt/disclosure"
e "github.com/MichaelFraser99/go-sd-jwt/internal/error"
"github.com/MichaelFraser99/go-sd-jwt/internal/utils"
"github.com/MichaelFraser99/go-sd-jwt/kbjwt"
)
// SdJwt represents a parsed SD-JWT. Instances are produced by New or
// NewFromComponents, both of which run the package's token validation before
// returning. Helper methods are provided for serialising the token (Token),
// attaching a key binding JWT (AddKeyBindingJwt) and resolving the disclosed
// claims (GetDisclosedClaims).
type SdJwt struct {
	// Head is the JSON-decoded header of the issuer-signed JWT.
	Head map[string]any
	// Body is the JSON-decoded payload of the issuer-signed JWT as received,
	// including any _sd_alg claim; disclosures are not merged into it.
	Body map[string]any
	// Signature is the third segment of the issuer-signed JWT, kept in its
	// encoded form exactly as it appeared in the token.
	Signature string
	// KbJwt is the key binding JWT, or nil when the token carried none and
	// AddKeyBindingJwt has not been called.
	KbJwt *kbjwt.KbJwt
	// Disclosures holds the disclosures parsed from the token.
	Disclosures []disclosure.Disclosure
}
// New parses token, a compact SD-JWT of the form
// "<issuer-jwt>~<disclosure>~...~[<kb-jwt>]", and returns it as an SdJwt.
//
// The token is checked against the SD-JWT structure rules; when a trailing
// key binding JWT is present it is parsed as well and its sd_hash is verified.
// On failure the returned *SdJwt is nil.
func New(token string) (*SdJwt, error) {
	parsed, err := validateJwt(token)
	if err != nil {
		return nil, err
	}
	return parsed, nil
}
// NewFromComponents assembles a compact SD-JWT from its individual parts and
// validates it exactly as New does. It is intended for serialisations (such as
// the JSON forms) in which the issuer-signed JWT, the disclosures and the key
// binding JWT are supplied separately.
//
// protected, payload and signature are the three encoded segments of the
// issuer-signed JWT. disclosures are encoded disclosures and may be empty.
// kbJwt, when non-nil, is appended as the trailing key binding JWT.
//
// The issuer-signed JWT is always followed by "~", even when there are no
// disclosures. The assembled string therefore has the layout
// "<jwt>~[<disclosure>~...][<kb-jwt>]", the same shape Token produces.
// Previously, an empty disclosures slice yielded a "~"-less token: it was
// rejected outright, or, when kbJwt was set, fused the KB-JWT onto the signature.
func NewFromComponents(protected, payload, signature string, disclosures []string, kbJwt *string) (*SdJwt, error) {
	var b strings.Builder
	b.WriteString(protected)
	b.WriteByte('.')
	b.WriteString(payload)
	b.WriteByte('.')
	b.WriteString(signature)
	b.WriteByte('~')
	for _, d := range disclosures {
		b.WriteString(d)
		b.WriteByte('~')
	}
	if kbJwt != nil {
		b.WriteString(*kbJwt)
	}
	return validateJwt(b.String())
}
// Token serialises the SD-JWT in its current state into the compact form
// "<header>.<payload>.<signature>~<disclosure>~...~[<kb-jwt>]".
//
// Head and Body are JSON-marshalled and base64url-encoded without padding.
// Signature and each disclosure's EncodedValue are written verbatim. The
// KB-JWT token, if any, is appended last. Marshalling errors are wrapped
// with the segment that failed.
//
// NOTE(review): Head and Body are re-marshalled with encoding/json rather than
// reusing the originally received segments. The output is therefore only
// byte-identical to a parsed token if the issuer used the same JSON encoding
// (encoding/json sorts map keys and emits no whitespace). If it did not,
// Signature will not verify over the re-encoded segments. Confirm whether
// this is intended.
func (s *SdJwt) Token() (*string, error) {
	headBytes, err := json.Marshal(s.Head)
	if err != nil {
		return nil, fmt.Errorf("error marshalling sd-jwt header: %w", err)
	}
	bodyBytes, err := json.Marshal(s.Body)
	if err != nil {
		return nil, fmt.Errorf("error marshalling sd-jwt body: %w", err)
	}

	var b strings.Builder
	b.WriteString(base64.RawURLEncoding.EncodeToString(headBytes))
	b.WriteByte('.')
	b.WriteString(base64.RawURLEncoding.EncodeToString(bodyBytes))
	b.WriteByte('.')
	b.WriteString(s.Signature)
	// The issuer-signed JWT is always terminated by "~", even with no disclosures.
	b.WriteByte('~')
	for _, d := range s.Disclosures {
		b.WriteString(d.EncodedValue)
		b.WriteByte('~')
	}
	if s.KbJwt != nil {
		b.WriteString(s.KbJwt.Token)
	}

	token := b.String()
	return &token, nil
}
// AddKeyBindingJwt creates a key binding JWT (KB-JWT) for this SD-JWT, signs it
// with signer, and attaches it as s.KbJwt.
//
// Hash selection:
//   - h must match the SD-JWT's _sd_alg claim, compared case-insensitively
//     against h.String().
//   - When _sd_alg is absent, h must be SHA-256.
//   - Otherwise an error is returned.
//
// sd_hash is computed with h over the issuer-signed JWT plus every disclosure
// currently held in s.Disclosures. They are serialised as
// "<header>.<payload>.<signature>~<disclosure>~...~", which is what Token emits
// without a KB-JWT.
//
// KB-JWT contents:
//   - the "alg" header is alg upper-cased, and "typ" is "kb+jwt";
//   - aud and nonce populate the matching claims;
//   - iat is the current Unix time.
//
// An error is returned if:
//   - a KB-JWT is already attached;
//   - h does not match;
//   - h's implementation is not linked into the binary;
//   - marshalling or signing fails.
//
// s is only modified on success.
//
// NOTE(review): signer.Sign receives the raw JWS signing input (not a digest)
// with nil opts, and its output is used verbatim as the JWS signature. That
// presumes a signer that hashes internally and returns a JWS-formatted
// signature. Standard-library keys do not behave this way: *ecdsa.PrivateKey
// expects a digest and returns ASN.1 DER, and ed25519.PrivateKey calls
// opts.HashFunc(). Confirm the intended signer type.
func (s *SdJwt) AddKeyBindingJwt(signer crypto.Signer, h crypto.Hash, alg, aud, nonce string) error {
	if s.KbJwt != nil {
		return errors.New("key binding jwt already exists")
	}

	sdAlg, ok := s.Body["_sd_alg"].(string)
	if (ok && !strings.EqualFold(sdAlg, h.String())) || (!ok && strings.ToLower(h.String()) != "sha-256") {
		return errors.New("key binding jwt hashing algorithm does not match the hashing algorithm specified in the sd-jwt - if sd-jwt does not specify a hashing algorithm, sha-256 is selected by default")
	}
	// crypto.Hash.New panics when no implementation of h is linked into the
	// binary (e.g. SHA-3 without a registering import); report it as an error.
	if !h.Available() {
		return fmt.Errorf("key binding jwt hashing algorithm %s is not available", h.String())
	}

	// Serialise the SD-JWT (without KB-JWT) exactly as the sd_hash input expects.
	bSdHead, err := json.Marshal(s.Head)
	if err != nil {
		return fmt.Errorf("error marshalling sd-jwt header: %w", err)
	}
	bSdBody, err := json.Marshal(s.Body)
	if err != nil {
		return fmt.Errorf("error marshalling sd-jwt body: %w", err)
	}
	var presentation strings.Builder
	presentation.WriteString(base64.RawURLEncoding.EncodeToString(bSdHead))
	presentation.WriteByte('.')
	presentation.WriteString(base64.RawURLEncoding.EncodeToString(bSdBody))
	presentation.WriteByte('.')
	presentation.WriteString(s.Signature)
	presentation.WriteByte('~')
	for _, d := range s.Disclosures {
		presentation.WriteString(d.EncodedValue)
		presentation.WriteByte('~')
	}

	hasher := h.New()
	hasher.Write([]byte(presentation.String())) // hash.Hash.Write never returns an error

	kbJwt := kbjwt.KbJwt{
		Iat:    utils.Pointer(time.Now().Unix()),
		Aud:    utils.Pointer(aud),
		Nonce:  utils.Pointer(nonce),
		SdHash: utils.Pointer(base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))),
	}

	kbHead := map[string]string{
		"typ": "kb+jwt",
		"alg": strings.ToUpper(alg),
	}
	bKbHead, err := json.Marshal(kbHead)
	if err != nil {
		return fmt.Errorf("error marshalling kb-jwt header: %w", err)
	}
	bKbBody, err := json.Marshal(kbJwt)
	if err != nil {
		return fmt.Errorf("error marshalling kb-jwt body: %w", err)
	}

	signInput := base64.RawURLEncoding.EncodeToString(bKbHead) + "." + base64.RawURLEncoding.EncodeToString(bKbBody)
	sig, err := signer.Sign(rand.Reader, []byte(signInput), nil)
	if err != nil {
		return fmt.Errorf("error signing kb-jwt: %w", err)
	}

	kbJwt.Token = signInput + "." + base64.RawURLEncoding.EncodeToString(sig)
	s.KbJwt = &kbJwt
	return nil
}
// GetHash returns a fresh hash.Hash for an _sd_alg value.
//
// Names are matched case-insensitively:
//   - "sha-256", or the empty string (the default);
//   - "sha-224", "sha-384", "sha-512", "sha-512/224", "sha-512/256";
//   - "sha3-224", "sha3-256", "sha3-384", "sha3-512".
//
// An error is returned for any other name. An error is also returned for a
// SHA-3 variant whose implementation is not registered in the binary;
// crypto.Hash.New would panic in that case rather than fail cleanly.
func GetHash(hashString string) (hash.Hash, error) {
	var registered crypto.Hash
	switch strings.ToLower(hashString) {
	case "sha-256", "":
		// default to sha-256
		return sha256.New(), nil
	case "sha-224":
		return sha256.New224(), nil
	case "sha-512":
		return sha512.New(), nil
	case "sha-384":
		return sha512.New384(), nil
	case "sha-512/224":
		return sha512.New512_224(), nil
	case "sha-512/256":
		return sha512.New512_256(), nil
	case "sha3-224":
		registered = crypto.SHA3_224
	case "sha3-256":
		registered = crypto.SHA3_256
	case "sha3-384":
		registered = crypto.SHA3_384
	case "sha3-512":
		registered = crypto.SHA3_512
	default:
		return nil, errors.New("unsupported _sd_alg: " + hashString)
	}

	// SHA-3 is resolved through the crypto registry, which only knows about
	// implementations that some linked package has registered.
	if !registered.Available() {
		return nil, errors.New("unsupported _sd_alg: " + hashString + " (hash implementation not linked into binary)")
	}
	return registered.New(), nil
}
// GetDisclosedClaims returns the payload claims with every disclosure resolved
// against its digest, after removing SD bookkeeping via utils.StripSDClaims.
// The receiver is not modified: the payload is copied with utils.CopyMap and
// disclosures are consumed from a private clone of s.Disclosures.
//
// Each disclosure's EncodedValue is hashed with the algorithm named by _sd_alg
// (SHA-256 when the claim is absent or not a string). The digest is
// base64url-encoded without padding and handed to utils.ValidateSDClaims.
// Matching runs in repeated passes until every disclosure has been placed.
// Presumably this lets digests nested inside other disclosures be found once
// their parent is expanded; that depends on utils.ValidateSDClaims, which is
// not visible here.
//
// An error is returned when:
//  1. a pass matches none of the remaining disclosures (no matching digest);
//  2. utils.ValidateSDClaims reports an error — per the original contract this
//     covers a malformed _sd claim or a disclosure unsuitable for its position
//     (e.g. no claim name for a non-array digest);
//  3. _sd_alg names an unsupported hash algorithm.
func (s *SdJwt) GetDisclosedClaims() (map[string]any, error) {
	// A missing or non-string _sd_alg becomes "", which GetHash maps to SHA-256.
	algName, _ := s.Body["_sd_alg"].(string)
	hasher, err := GetHash(algName)
	if err != nil {
		return nil, err
	}

	remaining := slices.Clone(s.Disclosures)
	claims := utils.CopyMap(s.Body)

	for len(remaining) > 0 {
		var resolved []int
		for idx := range remaining {
			candidate := remaining[idx]

			hasher.Reset()
			hasher.Write([]byte(candidate.EncodedValue))
			digest := base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))

			matched, err := utils.ValidateSDClaims(utils.PointerMap(claims), &candidate, digest)
			if err != nil {
				return nil, err
			}
			if matched {
				resolved = append(resolved, idx)
			}
		}

		if len(resolved) == 0 {
			return nil, fmt.Errorf("no matching digest found for: %v", utils.StringifyDisclosures(remaining))
		}

		// resolved is ascending; delete from the back so earlier indexes stay valid.
		for k := len(resolved) - 1; k >= 0; k-- {
			remaining = slices.Delete(remaining, resolved[k], resolved[k]+1)
		}
	}

	return utils.StripSDClaims(claims), nil
}
// validateJwt parses a compact SD-JWT of the form
// "<header>.<payload>.<signature>~<disclosure>~...~[<kb-jwt>]" into an SdJwt.
//
// It checks that:
//   - the token contains at least one "~";
//   - the issuer-signed part has exactly three dot-separated segments;
//   - the header and payload are unpadded base64url-encoded JSON;
//   - the disclosures pass utils.ValidateDisclosures;
//   - the payload digests pass utils.ValidateDigests.
//
// If the final "~"-separated section is non-empty, it must be a key binding
// JWT. That JWT is parsed with kbjwt.NewFromToken, and its sd_hash must equal
// the hash of everything before it, including the trailing "~". The hash uses
// the _sd_alg algorithm (SHA-256 by default), base64url-encoded without padding.
//
// NOTE(review): the issuer signature (third segment) is stored in
// SdJwt.Signature but is not cryptographically verified here.
func validateJwt(token string) (*SdJwt, error) {
	sdJwt := &SdJwt{}

	sections := strings.Split(token, "~")
	if len(sections) < 2 {
		return nil, fmt.Errorf("%wtoken has no specified disclosures", e.InvalidToken)
	}

	tokenSections := strings.Split(sections[0], ".")
	if len(tokenSections) != 3 {
		return nil, fmt.Errorf("%wtoken is not a valid JWT", e.InvalidToken)
	}

	jwtHead := map[string]any{}
	hb, err := base64.RawURLEncoding.DecodeString(tokenSections[0])
	if err != nil {
		return nil, fmt.Errorf("%wfailed to decode header: %s", e.InvalidToken, err.Error())
	}
	if err = json.Unmarshal(hb, &jwtHead); err != nil {
		return nil, fmt.Errorf("%wfailed to json parse decoded header: %s", e.InvalidToken, err.Error())
	}
	sdJwt.Head = jwtHead
	sdJwt.Signature = tokenSections[2]

	// strings.Split removes every "~", so a non-empty final section means the
	// token did not end with "~"; that section must then be a KB-JWT.
	if last := sections[len(sections)-1]; last != "" {
		kbJwt := utils.CheckForKbJwt(last)
		if kbJwt == nil {
			return nil, fmt.Errorf("%wif no kb-jwt is provided, the last disclosure must be followed by a ~", e.InvalidToken)
		}
		sections = sections[:len(sections)-1]
		sdJwt.KbJwt, err = kbjwt.NewFromToken(*kbJwt)
		if err != nil {
			return nil, fmt.Errorf("failed to extract kb-jwt: %w", err)
		}
	}

	disclosures, err := utils.ValidateDisclosures(sections[1:])
	if err != nil {
		return nil, fmt.Errorf("%wfailed to validate disclosures: %s", e.InvalidToken, err.Error())
	}
	sdJwt.Disclosures = disclosures

	b, err := base64.RawURLEncoding.DecodeString(tokenSections[1])
	if err != nil {
		return nil, fmt.Errorf("%wfailed to decode payload: %s", e.InvalidToken, err.Error())
	}
	var m map[string]any
	if err = json.Unmarshal(b, &m); err != nil {
		return nil, fmt.Errorf("%wfailed to json parse decoded payload: %s", e.InvalidToken, err.Error())
	}
	if err = utils.ValidateDigests(m); err != nil {
		return nil, fmt.Errorf("%wfailed to validate digests: %s", e.InvalidToken, err.Error())
	}
	sdJwt.Body = m

	if sdJwt.KbJwt != nil {
		// A missing or non-string _sd_alg becomes "", which GetHash maps to SHA-256.
		strAlg, _ := sdJwt.Body["_sd_alg"].(string)
		h, err := GetHash(strAlg)
		if err != nil {
			return nil, err
		}
		// sd_hash covers every section before the KB-JWT, including the trailing "~".
		if _, err = h.Write([]byte(strings.Join(sections, "~") + "~")); err != nil {
			return nil, fmt.Errorf("failed to hash provided token for kbjwt validation: %s", err.Error())
		}
		calculated := base64.RawURLEncoding.EncodeToString(h.Sum(nil))

		// Guard the dereference: a KB-JWT without sd_hash previously caused a nil-pointer panic.
		if sdJwt.KbJwt.SdHash == nil {
			return nil, fmt.Errorf("%wkb-jwt does not contain an sd_hash claim", e.InvalidToken)
		}
		if calculated != *sdJwt.KbJwt.SdHash {
			return nil, fmt.Errorf("%wsd hash validation failed: calculated hash %s does not equal provided hash %s", e.InvalidToken, calculated, *sdJwt.KbJwt.SdHash)
		}
	}

	return sdJwt, nil
}