-
-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathhttp.go
276 lines (265 loc) · 8.93 KB
/
http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
package jwkset
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"golang.org/x/time/rate"
)
var (
// ErrNewClient is the sentinel error wrapped into the errors returned when a new JWK Set client cannot be created.
// Check for it with errors.Is.
ErrNewClient = errors.New("failed to create new JWK Set client")
)
// HTTPClientOptions are options for creating a new JWK Set client.
//
// At least one of Given or HTTPURLs must be provided, otherwise NewHTTPClient returns an error wrapping ErrNewClient.
type HTTPClientOptions struct {
// Given contains keys known from outside HTTP URLs. If nil, NewHTTPClient substitutes the result of
// NewMemoryStorage().
Given Storage
// HTTPURLs maps the HTTP URL of each JWK Set endpoint to the storage implementation for the keys located at that
// URL. A nil storage value is replaced, in this map, by NewHTTPClient with a storage created through
// NewStorageFromHTTP using zero-value HTTPClientStorageOptions. If empty, HTTP will not be used.
HTTPURLs map[string]Storage
// PrioritizeHTTP is a flag that indicates whether keys from the HTTP URL should be prioritized over keys from the
// given storage. When false, KeyRead consults the given storage first; when true, it consults the HTTP storages
// first and falls back to the given storage.
PrioritizeHTTP bool
// RateLimitWaitMax is the timeout for waiting for rate limiting to end. A value that is not greater than zero adds
// no timeout on top of the caller's context.
// NOTE(review): in KeyRead the context derived from this timeout is also passed to the refresh and the key read that
// follow the rate limiter wait, so the timeout bounds those calls as well.
RateLimitWaitMax time.Duration
// RefreshUnknownKID is non-nil to indicate that remote HTTP resources should be refreshed if a key with an unknown
// key ID is trying to be read. This makes reading methods block until the context is over, a key with the matching
// key ID is found in a refreshed remote resource, or all refreshes complete.
RefreshUnknownKID *rate.Limiter
}
// httpClient is a JWK Set client that combines a given storage with storages backed by remote HTTP resources. Its
// fields are copied from the matching HTTPClientOptions fields by NewHTTPClient.
type httpClient struct {
// given holds keys known from outside HTTP URLs; NewHTTPClient never leaves it nil.
given Storage
// httpURLs maps each JWK Set endpoint URL to the storage for the keys located at that URL.
httpURLs map[string]Storage
// prioritizeHTTP makes KeyRead consult httpURLs before given.
prioritizeHTTP bool
// rateLimitWaitMax, when greater than zero, is applied as a context timeout before waiting on refreshUnknownKID.
rateLimitWaitMax time.Duration
// refreshUnknownKID, when non-nil, rate limits the refresh that KeyRead performs for unknown key IDs.
refreshUnknownKID *rate.Limiter
}
// NewHTTPClient creates a new JWK Set client from remote HTTP resources.
//
// An error wrapping ErrNewClient is returned when neither given keys nor HTTP URLs are supplied, or when a storage for
// a URL with a nil map value cannot be created. Nil values in options.HTTPURLs are filled in place with storages
// created through NewStorageFromHTTP using zero-value options. A nil options.Given is replaced by a new memory storage.
func NewHTTPClient(options HTTPClientOptions) (Storage, error) {
	if len(options.HTTPURLs) == 0 && options.Given == nil {
		return nil, fmt.Errorf("%w: no given keys or HTTP URLs", ErrNewClient)
	}

	for endpoint, existing := range options.HTTPURLs {
		if existing != nil {
			continue
		}
		created, err := NewStorageFromHTTP(endpoint, HTTPClientStorageOptions{})
		// The map entry is assigned before the error check, exactly as a direct multi-assignment would do.
		options.HTTPURLs[endpoint] = created
		if err != nil {
			return nil, fmt.Errorf("failed to create HTTP client storage for %q: %w", endpoint, errors.Join(err, ErrNewClient))
		}
	}

	client := httpClient{
		given:             options.Given,
		httpURLs:          options.HTTPURLs,
		prioritizeHTTP:    options.PrioritizeHTTP,
		rateLimitWaitMax:  options.RateLimitWaitMax,
		refreshUnknownKID: options.RefreshUnknownKID,
	}
	if client.given == nil {
		client.given = NewMemoryStorage()
	}
	return client, nil
}
// NewDefaultHTTPClient creates a new JWK Set client with default options from remote HTTP resources. It calls
// NewDefaultHTTPClientCtx with context.Background().
//
// The defaults configured by NewDefaultHTTPClientCtx are:
//  1. The storage for each URL is created with a refresh interval of one hour.
//  2. Reading a key with an unknown key ID triggers a refresh of the remote HTTP resources, rate limited to one
//     refresh every 5 minutes, waiting at most one minute for the rate limiter.
//  3. Refresh failures are logged to slog.Default().
//
// NOTE(review): PrioritizeHTTP is not set by NewDefaultHTTPClientCtx, so keys added through KeyWrite (held in the given
// storage) are consulted before keys from remote HTTP resources. Earlier documentation stated that HTTP keys are
// prioritized; confirm which behavior is intended.
func NewDefaultHTTPClient(urls []string) (Storage, error) {
	background := context.Background()
	return NewDefaultHTTPClientCtx(background, urls)
}
// NewDefaultHTTPClientCtx is the same as NewDefaultHTTPClient, but with a context that can end the refresh goroutine.
//
// For every URL a storage is created through NewStorageFromHTTP with the given context, a refresh interval of one hour,
// NoErrorReturnFirstHTTPReq enabled, and a refresh error handler that logs the error and the URL to slog.Default().
// The resulting client waits at most one minute on a rate limiter that allows one unknown key ID refresh every
// 5 minutes. PrioritizeHTTP is left unset.
//
// An error wrapping ErrNewClient is returned if a storage cannot be created or if NewHTTPClient rejects the options
// (for example, when urls is empty).
func NewDefaultHTTPClientCtx(ctx context.Context, urls []string) (Storage, error) {
	clientOptions := HTTPClientOptions{
		HTTPURLs:          make(map[string]Storage, len(urls)),
		RateLimitWaitMax:  time.Minute,
		RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1),
	}
	for _, u := range urls {
		// Shadow the range variable so the error handler below logs its own URL. Before Go 1.22 the range variable is
		// shared across iterations, which would make every handler report the last URL.
		u := u
		refreshErrorHandler := func(ctx context.Context, err error) {
			slog.Default().ErrorContext(ctx, "Failed to refresh HTTP JWK Set from remote HTTP resource.",
				"error", err,
				"url", u,
			)
		}
		options := HTTPClientStorageOptions{
			Ctx:                       ctx,
			NoErrorReturnFirstHTTPReq: true,
			RefreshErrorHandler:       refreshErrorHandler,
			RefreshInterval:           time.Hour,
		}
		c, err := NewStorageFromHTTP(u, options)
		if err != nil {
			return nil, fmt.Errorf("failed to create HTTP client storage for %q: %w", u, errors.Join(err, ErrNewClient))
		}
		clientOptions.HTTPURLs[u] = c
	}
	return NewHTTPClient(clientOptions)
}
// KeyDelete deletes the key with the given key ID from the first storage that reports a deletion. The given storage is
// tried first, then each HTTP storage. An ErrKeyNotFound from a storage is not treated as a failure; any other error
// aborts the operation. It returns false with a nil error when no storage deleted the key.
func (c httpClient) KeyDelete(ctx context.Context, keyID string) (ok bool, err error) {
	deleted, delErr := c.given.KeyDelete(ctx, keyID)
	switch {
	case delErr != nil && !errors.Is(delErr, ErrKeyNotFound):
		return false, fmt.Errorf("failed to delete key with ID %q from given storage due to error: %w", keyID, delErr)
	case deleted:
		return true, nil
	}

	for _, remote := range c.httpURLs {
		deleted, delErr = remote.KeyDelete(ctx, keyID)
		switch {
		case delErr != nil && !errors.Is(delErr, ErrKeyNotFound):
			return false, fmt.Errorf("failed to delete key with ID %q from HTTP storage due to error: %w", keyID, delErr)
		case deleted:
			return true, nil
		}
	}
	return false, nil
}
// KeyRead finds the key with the given key ID.
//
// The given storage is consulted before the HTTP storages unless prioritizeHTTP is set, in which case it is consulted
// after them. If the key is still unknown and refreshUnknownKID is non-nil, the method waits on that rate limiter
// (bounded by rateLimitWaitMax when it is greater than zero), refreshes each HTTP storage whose concrete type is
// httpStorage, and reads the key from each storage that refreshed successfully. Refresh failures are passed to the
// storage's RefreshErrorHandler, if any, and do not abort the lookup.
//
// An error wrapping ErrKeyNotFound is returned when no storage has the key. Any other storage error, or a failure to
// wait for the rate limiter, is returned wrapped.
func (c httpClient) KeyRead(ctx context.Context, keyID string) (jwk JWK, err error) {
	// fromGiven queries the given storage. found is false with a nil error when the key is simply absent.
	fromGiven := func() (key JWK, found bool, readErr error) {
		key, readErr = c.given.KeyRead(ctx, keyID)
		if errors.Is(readErr, ErrKeyNotFound) {
			return JWK{}, false, nil
		}
		if readErr != nil {
			return JWK{}, false, fmt.Errorf("failed to find JWT key with ID %q in given storage due to error: %w", keyID, readErr)
		}
		return key, true, nil
	}

	if !c.prioritizeHTTP {
		key, found, readErr := fromGiven()
		if readErr != nil {
			return JWK{}, readErr
		}
		if found {
			return key, nil
		}
	}

	for _, remote := range c.httpURLs {
		key, readErr := remote.KeyRead(ctx, keyID)
		if errors.Is(readErr, ErrKeyNotFound) {
			continue
		}
		if readErr != nil {
			return JWK{}, fmt.Errorf("failed to find JWT key with ID %q in HTTP storage due to error: %w", keyID, readErr)
		}
		return key, nil
	}

	if c.prioritizeHTTP {
		key, found, readErr := fromGiven()
		if readErr != nil {
			return JWK{}, readErr
		}
		if found {
			return key, nil
		}
	}

	// Without a rate limiter there is no refresh on unknown key IDs.
	if c.refreshUnknownKID == nil {
		return JWK{}, fmt.Errorf("%w %q", ErrKeyNotFound, keyID)
	}

	// The same, possibly time-limited, context covers the rate limiter wait, the refreshes, and the follow-up reads.
	refreshCtx, stop := ctx, context.CancelFunc(func() {})
	if c.rateLimitWaitMax > 0 {
		refreshCtx, stop = context.WithTimeout(ctx, c.rateLimitWaitMax)
	}
	defer stop()

	if waitErr := c.refreshUnknownKID.Wait(refreshCtx); waitErr != nil {
		return JWK{}, fmt.Errorf("failed to wait for JWK Set refresh rate limiter due to error: %w", waitErr)
	}

	for _, remote := range c.httpURLs {
		refreshable, isHTTP := remote.(httpStorage)
		if !isHTTP {
			continue
		}
		if refreshErr := refreshable.refresh(refreshCtx); refreshErr != nil {
			if refreshable.options.RefreshErrorHandler != nil {
				refreshable.options.RefreshErrorHandler(refreshCtx, refreshErr)
			}
			continue
		}
		key, readErr := remote.KeyRead(refreshCtx, keyID)
		if errors.Is(readErr, ErrKeyNotFound) {
			continue
		}
		if readErr != nil {
			return JWK{}, fmt.Errorf("failed to find JWT key with ID %q in HTTP storage due to error: %w", keyID, readErr)
		}
		return key, nil
	}

	return JWK{}, fmt.Errorf("%w %q", ErrKeyNotFound, keyID)
}
// KeyReadAll returns the keys of the given storage followed by the keys of every HTTP storage. The HTTP storages are
// visited in map iteration order, so the relative order of their keys is not fixed. The first storage error aborts the
// read and is returned wrapped.
func (c httpClient) KeyReadAll(ctx context.Context) ([]JWK, error) {
	all, err := c.given.KeyReadAll(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to snapshot given keys due to error: %w", err)
	}

	for endpoint, remote := range c.httpURLs {
		remoteKeys, readErr := remote.KeyReadAll(ctx)
		if readErr != nil {
			return nil, fmt.Errorf("failed to snapshot HTTP keys from %q due to error: %w", endpoint, readErr)
		}
		all = append(all, remoteKeys...)
	}
	return all, nil
}
// KeyWrite writes the JWK to the given storage only; the HTTP storages are not written to. The error of the given
// storage's KeyWrite is returned as is.
func (c httpClient) KeyWrite(ctx context.Context, jwk JWK) error {
return c.given.KeyWrite(ctx, jwk)
}
// JSON copies the keys of all storages into a temporary memory storage via combineStorage and returns the result of
// that storage's JSON method.
func (c httpClient) JSON(ctx context.Context) (json.RawMessage, error) {
	combined, combineErr := c.combineStorage(ctx)
	if combineErr != nil {
		return nil, fmt.Errorf("failed to combine storage due to error: %w", combineErr)
	}
	return combined.JSON(ctx)
}
// JSONPublic copies the keys of all storages into a temporary memory storage via combineStorage and returns the result
// of that storage's JSONPublic method.
func (c httpClient) JSONPublic(ctx context.Context) (json.RawMessage, error) {
	combined, combineErr := c.combineStorage(ctx)
	if combineErr != nil {
		return nil, fmt.Errorf("failed to combine storage due to error: %w", combineErr)
	}
	return combined.JSONPublic(ctx)
}
// JSONPrivate copies the keys of all storages into a temporary memory storage via combineStorage and returns the
// result of that storage's JSONPrivate method.
func (c httpClient) JSONPrivate(ctx context.Context) (json.RawMessage, error) {
	combined, combineErr := c.combineStorage(ctx)
	if combineErr != nil {
		return nil, fmt.Errorf("failed to combine storage due to error: %w", combineErr)
	}
	return combined.JSONPrivate(ctx)
}
// JSONWithOptions copies the keys of all storages into a temporary memory storage via combineStorage and returns the
// result of that storage's JSONWithOptions method, called with the supplied marshal and validation options.
func (c httpClient) JSONWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (json.RawMessage, error) {
	combined, combineErr := c.combineStorage(ctx)
	if combineErr != nil {
		return nil, fmt.Errorf("failed to combine storage due to error: %w", combineErr)
	}
	return combined.JSONWithOptions(ctx, marshalOptions, validationOptions)
}
// Marshal copies the keys of all storages into a temporary memory storage via combineStorage and returns the result of
// that storage's Marshal method. A zero JWKSMarshal is returned when the storages cannot be combined.
func (c httpClient) Marshal(ctx context.Context) (JWKSMarshal, error) {
	combined, combineErr := c.combineStorage(ctx)
	if combineErr != nil {
		return JWKSMarshal{}, fmt.Errorf("failed to combine storage due to error: %w", combineErr)
	}
	return combined.Marshal(ctx)
}
// MarshalWithOptions copies the keys of all storages into a temporary memory storage via combineStorage and returns
// the result of that storage's MarshalWithOptions method, called with the supplied marshal and validation options. A
// zero JWKSMarshal is returned when the storages cannot be combined.
func (c httpClient) MarshalWithOptions(ctx context.Context, marshalOptions JWKMarshalOptions, validationOptions JWKValidateOptions) (JWKSMarshal, error) {
	combined, combineErr := c.combineStorage(ctx)
	if combineErr != nil {
		return JWKSMarshal{}, fmt.Errorf("failed to combine storage due to error: %w", combineErr)
	}
	return combined.MarshalWithOptions(ctx, marshalOptions, validationOptions)
}
// combineStorage snapshots every key returned by KeyReadAll into a new memory storage and returns that storage. Keys
// are written in the order KeyReadAll returns them. The first read or write error aborts the operation and is returned
// wrapped.
func (c httpClient) combineStorage(ctx context.Context) (Storage, error) {
	keys, err := c.KeyReadAll(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to snapshot keys due to error: %w", err)
	}

	combined := NewMemoryStorage()
	for i := range keys {
		if err = combined.KeyWrite(ctx, keys[i]); err != nil {
			return nil, fmt.Errorf("failed to write key to memory storage due to error: %w", err)
		}
	}
	return combined, nil
}