Skip to content

Commit

Permalink
Add Whitelist
Browse files Browse the repository at this point in the history
  • Loading branch information
lestrrat committed Mar 22, 2022
1 parent b33fb19 commit 838d3ec
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 8 deletions.
30 changes: 30 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@ import (
"time"
)

// Whitelist is an interface for a set of URL whitelists. When provided
// to fetching operations, urls are checked against this object, and
// the object must return true for urls to be fetched.
type Whitelist interface {
IsAllowed(string) bool
}

// WhitelistFunc is a httprc.Whitelist object based on a function.
// You can perform any sort of check against the given URL to determine
// if it can be fetched or not.
type WhitelistFunc func(string) bool

func (w WhitelistFunc) IsAllowed(u string) bool {
return w(u)
}

// ErrSink is an abstraction that allows users to consume errors
// produced while the cache queue is running.
type HTTPClient interface {
Get(string) (*http.Response, error)
}
Expand All @@ -27,6 +45,7 @@ type HTTPClient interface {
type Cache struct {
mu sync.RWMutex
queue *queue
wl Whitelist
}

const defaultRefreshWindow = 15 * time.Minute
Expand All @@ -52,6 +71,7 @@ func New(ctx context.Context, options ...ConstructorOption) *Cache {
var refreshWindow time.Duration
var errSink ErrSink
var nfetchers int
var wl Whitelist
for _, option := range options {
//nolint:forcetypeassert
switch option.Ident() {
Expand All @@ -61,6 +81,8 @@ func New(ctx context.Context, options ...ConstructorOption) *Cache {
nfetchers = option.Value().(int)
case identErrSink{}:
errSink = option.Value().(ErrSink)
case identWhitelist{}:
wl = option.Value().(Whitelist)
}
}

Expand All @@ -77,6 +99,7 @@ func New(ctx context.Context, options ...ConstructorOption) *Cache {

return &Cache{
queue: queue,
wl: wl,
}
}

Expand All @@ -87,6 +110,13 @@ func New(ctx context.Context, options ...ConstructorOption) *Cache {
func (c *Cache) Register(u string, options ...RegisterOption) error {
c.mu.Lock()
defer c.mu.Unlock()

if wl := c.wl; wl != nil {
if !wl.IsAllowed(u) {
return fmt.Errorf(`httprc.Cache: url %q has been rejected by whitelist`, u)
}
}

c.queue.Register(u, options...)
return nil
}
Expand Down
18 changes: 10 additions & 8 deletions httprc_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package httprc_test
import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"sync"
Expand All @@ -23,7 +22,7 @@ func Example() {
msg := helloWorld

srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(`Cache-Control`, fmt.Sprintf(`max-age=%d`, 3))
w.Header().Set(`Cache-Control`, fmt.Sprintf(`max-age=%d`, 2))
w.WriteHeader(http.StatusOK)
mu.RLock()
fmt.Fprint(w, msg)
Expand All @@ -35,24 +34,27 @@ func Example() {
defer cancel()

errSink := httprc.ErrSinkFunc(func(err error) {
log.Printf("%s", err)
fmt.Printf("%s\n", err)
})

c := httprc.New(ctx,
httprc.WithErrSink(errSink),
httprc.WithRefreshWindow(time.Second), // force checks every second
)

c.Register(srv.URL, httprc.WithHTTPClient(srv.Client()))
c.Register(srv.URL,
httprc.WithHTTPClient(srv.Client()), // we need client with TLS settings
httprc.WithMinRefreshInterval(time.Second), // allow max-age=1 (smallest)
)

payload, err := c.Get(ctx, srv.URL)
if err != nil {
log.Printf("%s", err)
fmt.Printf("%s\n", err)
return
}

if string(payload.([]byte)) != helloWorld {
log.Printf("payload mismatch: %s", payload)
fmt.Printf("payload mismatch: %s\n", payload)
return
}

Expand All @@ -64,12 +66,12 @@ func Example() {

payload, err = c.Get(ctx, srv.URL)
if err != nil {
log.Printf("%s", err)
fmt.Printf("%s\n", err)
return
}

if string(payload.([]byte)) != goodbyeWorld {
log.Printf("payload mismatch: %s", payload)
fmt.Printf("payload mismatch: %s\n", payload)
return
}

Expand Down
4 changes: 4 additions & 0 deletions httprc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import (
"github.com/stretchr/testify/assert"
)

func TestWhitelist(t *testing.T) {

}

type dummyErrSink struct {
mu sync.RWMutex
errors []error
Expand Down
6 changes: 6 additions & 0 deletions options.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ options:
comment: |
WithFetchWorkerCount specifies the number of HTTP fetch workers that are spawned
in the backend. By default 3 workers are spawned.
- ident: Whitelist
interface: ConstructorOption
argument_type: Whitelist
comment: |
WithWhitelist specifies the Whitelist object that can control which URLs can be
registered to the cache.
- ident: Transformer
interface: RegisterOption
argument_type: Transformer
Expand Down
11 changes: 11 additions & 0 deletions options_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type identMinRefreshInterval struct{}
type identRefreshInterval struct{}
type identRefreshWindow struct{}
type identTransformer struct{}
type identWhitelist struct{}

func (identErrSink) String() string {
return "WithErrSink"
Expand Down Expand Up @@ -70,6 +71,10 @@ func (identTransformer) String() string {
return "WithTransformer"
}

func (identWhitelist) String() string {
return "WithWhitelist"
}

// WithErrSink specifies the `httprc.ErrSink` object that handles errors
// that occurred during the cache's execution. For example, you will be
// able to intercept errors that occurred during the execution of Transformers.
Expand Down Expand Up @@ -147,3 +152,9 @@ func WithRefreshWindow(v time.Duration) ConstructorOption {
func WithTransformer(v Transformer) RegisterOption {
return &registerOption{option.New(identTransformer{}, v)}
}

// WithWhitelist specifies the Whitelist object that can control which URLs can be
// registered to the cache.
func WithWhitelist(v Whitelist) ConstructorOption {
return &constructorOption{option.New(identWhitelist{}, v)}
}
1 change: 1 addition & 0 deletions options_gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ func TestOptionIdent(t *testing.T) {
require.Equal(t, "WithRefreshInterval", identRefreshInterval{}.String())
require.Equal(t, "WithRefreshWindow", identRefreshWindow{}.String())
require.Equal(t, "WithTransformer", identTransformer{}.String())
require.Equal(t, "WithWhitelist", identWhitelist{}.String())
}

0 comments on commit 838d3ec

Please sign in to comment.