Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
api: add test clock support (#264)
Browse files Browse the repository at this point in the history
This makes the API aware of test clocks, and will pass them along in
requests to Stripe that have support for test clocks.

This also adds WithClock to the tier package for passing a clock ID
through to the API using contexts.

Also, expose Created date on OrgInfo. This was useful for testing test
clocks, but is a useful field to have in general.
  • Loading branch information
bmizerany authored Feb 25, 2023
1 parent e99e82d commit 7322395
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 247 deletions.
49 changes: 49 additions & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"golang.org/x/exp/slices"
"tier.run/api/apitypes"
"tier.run/api/materialize"
"tier.run/client/tier"
"tier.run/control"
"tier.run/refs"
"tier.run/stripe"
Expand Down Expand Up @@ -163,6 +164,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

func (h *Handler) serve(w http.ResponseWriter, r *http.Request) error {
clockID := r.Header.Get(tier.ClockHeader)
r = r.Clone(control.WithClock(r.Context(), clockID))

switch r.URL.Path {
case "/v1/whoami":
return h.serveWhoAmI(w, r)
Expand All @@ -184,6 +188,8 @@ func (h *Handler) serve(w http.ResponseWriter, r *http.Request) error {
return h.servePush(w, r)
case "/v1/payment_methods":
return h.servePaymentMethods(w, r)
case "/v1/clock":
return h.serveClock(w, r)
default:
return trweb.NotFound
}
Expand Down Expand Up @@ -287,6 +293,7 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) error {
Name: info.Name,
Description: info.Description,
Phone: info.Phone,
Created: info.CreatedAt(),
Metadata: info.Metadata,
PaymentMethod: info.PaymentMethod,
InvoiceSettings: apitypes.InvoiceSettings(info.InvoiceSettings),
Expand Down Expand Up @@ -412,6 +419,48 @@ func (h *Handler) servePaymentMethods(w http.ResponseWriter, r *http.Request) er
})
}

func (h *Handler) serveClock(w http.ResponseWriter, r *http.Request) error {
writeResp := func(c *control.Clock) error {
return httpJSON(w, apitypes.ClockResponse{
ID: c.ID(),
Link: c.Link(),
Present: c.Present(),
Status: c.Status(),
})
}

switch r.Method {
case "GET":
clockID := r.FormValue("id")
c := h.c.ClockFromID(clockID)
if err := c.Sync(r.Context()); err != nil {
return err
}
return writeResp(c)
case "POST":
var v apitypes.ClockRequest
if err := trweb.DecodeStrict(r, &v); err != nil {
return err
}

if v.ID == "" {
c, err := h.c.NewClock(r.Context(), v.Name, v.Present)
if err != nil {
return err
}
return writeResp(c)
} else {
c := h.c.ClockFromID(v.ID)
if err := c.Advance(r.Context(), v.Present); err != nil {
return err
}
return writeResp(c)
}
default:
return trweb.MethodNotAllowed
}
}

func httpJSON(w http.ResponseWriter, v any) error {
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
Expand Down
50 changes: 41 additions & 9 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func newTestClient(t *testing.T) (*tier.Client, *control.Client) {
s := httptest.NewTLSServer(h)
t.Cleanup(s.Close)
tc := &tier.Client{
Logf: t.Logf,
BaseURL: s.URL,
HTTPClient: s.Client(),
}
Expand Down Expand Up @@ -135,7 +136,6 @@ func TestAPISubscribe(t *testing.T) {
defer maybeFailNow(t)
fn := mpn(feature)
err := tc.ReportUsage(ctx, org, fn.String(), n, &tier.ReportParams{
At: time.Now().Add(1 * time.Minute),
Clobber: false,
})
diff.Test(t, t.Errorf, err, wantErr)
Expand Down Expand Up @@ -465,15 +465,46 @@ func TestWhoAmI(t *testing.T) {
}
}

func TestClock(t *testing.T) {
t.Parallel()

tc, _ := newTestClient(t)

now := time.Now().Truncate(time.Second)
ctx, err := tc.WithClock(context.Background(), t.Name(), now)
if err != nil {
t.Fatal(err)
}

now = now.Add(100 * time.Hour)
if err := tc.Advance(ctx, now); err != nil {
t.Fatal(err)
}

if err := tc.Subscribe(ctx, "org:test"); err != nil {
t.Fatal(err)
}

c, err := tc.LookupOrg(ctx, "org:test")
if err != nil {
t.Fatal(err)
}

if !c.Created.Equal(now) {
t.Errorf("c.Created = %v; want %v", c.Created, now)
}
}

func TestTierReport(t *testing.T) {
t.Parallel()

ctx := context.Background()
tc, cc := newTestClient(t)
tc, _ := newTestClient(t)

farIntoTheFuture := time.Now().Add(24 * time.Hour)
clock := stroke.NewClock(t, cc.Stripe, t.Name(), farIntoTheFuture)
cc.Clock = clock.ID()
now := time.Now()
ctx, err := tc.WithClock(context.Background(), t.Name(), now)
if err != nil {
t.Fatal(err)
}

pr, err := tc.PushJSON(ctx, []byte(`
{
Expand Down Expand Up @@ -503,18 +534,19 @@ func TestTierReport(t *testing.T) {
}

report := func(n int, at time.Time, wantErr error) {
t.Helper()
if err := tc.ReportUsage(ctx, "org:test", "feature:t", n, &tier.ReportParams{
// Force 'now' at Stripe.
At: time.Time{},
At: at,

Clobber: false,
}); !errors.Is(err, wantErr) {
t.Errorf("err = %v; want %v", err, wantErr)
}
}

report(10, time.Time{}, nil)
report(10, clock.Now().Add(1*time.Minute), nil)
report(10, now, nil)
report(10, now.Add(1*time.Minute), nil)

limit, used, err := tc.LookupLimit(ctx, "org:test", "feature:t")
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions api/apitypes/apitypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type OrgInfo struct {
Name string `json:"name"`
Description string `json:"description"`
Phone string `json:"phone"`
Created time.Time `json:"created"`
Metadata map[string]string `json:"metadata"`

PaymentMethod string `json:"payment_method"`
Expand Down Expand Up @@ -123,3 +124,16 @@ type WhoAmIResponse struct {
Isolated bool `json:"isolated"`
URL string `json:"url"`
}

type ClockRequest struct {
ID string
Name string
Present time.Time
}

type ClockResponse struct {
ID string `json:"id"`
Link string `json:"link"`
Present time.Time `json:"present"`
Status string `json:"status"`
}
93 changes: 91 additions & 2 deletions client/tier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@ package tier
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/url"
"os"
"time"

"tailscale.com/logtail/backoff"
"tier.run/api/apitypes"
"tier.run/fetch"
"tier.run/refs"
)

// ClockHeader is the header used to pass the clock ID to the tier sidecar.
// It is exported for use by the sidecar API. Most users want to use WithClock.
const ClockHeader = "Tier-Clock"

const Inf = 1<<63 - 1

type Client struct {
Expand All @@ -25,6 +31,27 @@ type Client struct {

BaseURL string // the base URL of the tier sidecar; default is http://127.0.0.1:8080
HTTPClient *http.Client

Logf func(fmt string, args ...any)
}

func (c *Client) logf(fmt string, args ...any) {
if c.Logf != nil {
c.Logf(fmt, args...)
}
}

type clockKey struct{}

// WithClock returns a context with the provided clock ID set. The clock ID is
// pass via the Tier-Clock header to be used by the tier sidecar.
func WithClock(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, clockKey{}, id)
}

func clockFromContext(ctx context.Context) string {
id, _ := ctx.Value(clockKey{}).(string)
return id
}

// FromEnv returns a Client configured from the environment. The BaseURL is set
Expand Down Expand Up @@ -80,7 +107,7 @@ func (c *Client) PullJSON(ctx context.Context) ([]byte, error) {
return fetchOK[[]byte, *apitypes.Error](ctx, c, "GET", "/v1/pull", nil)
}

// WhoIS reports the Stripe ID for the given organization.
// WhoIs reports the Stripe customer ID for the provided org. OrgInfo is not set.
func (c *Client) WhoIs(ctx context.Context, org string) (apitypes.WhoIsResponse, error) {
return fetchOK[apitypes.WhoIsResponse, *apitypes.Error](ctx, c, "GET", "/v1/whois?org="+org, nil)
}
Expand Down Expand Up @@ -299,7 +326,69 @@ func (c *Client) WhoAmI(ctx context.Context) (apitypes.WhoAmIResponse, error) {
return fetchOK[apitypes.WhoAmIResponse, *apitypes.Error](ctx, c, "GET", "/v1/whoami", nil)
}

// WithClock creates a new test clock with the provided name and start time,
// and returns a new context with the clock ID set.
//
// It is an error to call WithClock if a clock is already set in the context.
func (c *Client) WithClock(ctx context.Context, name string, start time.Time) (context.Context, error) {
if clockFromContext(ctx) != "" {
return nil, errors.New("tier: clock already set in context")
}
clock, err := fetchOK[apitypes.ClockResponse, *apitypes.Error](ctx, c, "POST", "/v1/clock", apitypes.ClockRequest{
Name: name,
Present: start,
})
if err != nil {
return nil, err
}
return WithClock(ctx, clock.ID), nil
}

// Advance advances the test clock set in the context to t.
//
// It is an error to call Advance if no clock is set in the context.
func (c *Client) Advance(ctx context.Context, t time.Time) error {
clockID := clockFromContext(ctx)
if clockID == "" {
return errors.New("tier: no clock set in context")
}
cr, err := fetchOK[apitypes.ClockResponse, *apitypes.Error](ctx, c, "POST", "/v1/clock", apitypes.ClockRequest{
ID: clockID,
Present: t,
})
if err != nil {
return err
}
return c.awaitClockReady(ctx, cr.ID)
}

var errForBackoff = errors.New("force backoff")

func (c *Client) awaitClockReady(ctx context.Context, id string) error {
bo := backoff.NewBackoff("tier", c.logf, 5*time.Second)
for {
cr, err := c.syncClock(ctx, id)
if err != nil {
return err
}
if cr.Status == "ready" {
return nil
}
c.logf("clock %s status = %q; waiting", id, cr.Status)
bo.BackOff(ctx, errForBackoff)
}
}

func (c *Client) syncClock(ctx context.Context, id string) (apitypes.ClockResponse, error) {
return fetchOK[apitypes.ClockResponse, *apitypes.Error](ctx, c, "GET", "/v1/clock?id="+id, nil)
}

func fetchOK[T any, E error](ctx context.Context, c *Client, method, path string, body any) (T, error) {
up := url.UserPassword(c.APIKey, "")
return fetch.OK[T, E](ctx, c.client(), method, c.baseURL(path), body, up)
clockID := clockFromContext(ctx)
var h http.Header
if clockID != "" {
h = http.Header{ClockHeader: []string{clockID}}
}
return fetch.OK[T, E](ctx, c.client(), method, c.baseURL(path), body, up, h)
}
25 changes: 25 additions & 0 deletions client/tier/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"log"
"net/http"
"time"

"tier.run/client/tier"
)
Expand Down Expand Up @@ -60,6 +61,30 @@ func ExampleClient_Can_report() {
fmt.Println(convert(readInput()))
}

func ExampleClient_WithClock_testClocks() {
c, err := tier.FromEnv()
if err != nil {
panic(err)
}

now := time.Now()

ctx, err := c.WithClock(context.Background(), "testName", now)
if err != nil {
panic(err)
}

// Use ctx with other Client methods

// This creates the customer and subscription using the clock.
_ = c.Subscribe(ctx, "org:example", "plan:free@0")

// Advance the clock by 24 hours, and then report usage.
_ = c.Advance(ctx, now.Add(24*time.Hour))

_ = c.Report(ctx, "org:example", "feature:bandwidth", 1000)
}

func orgFromSession(r *http.Request) string {
return "org:example"
}
Expand Down
1 change: 0 additions & 1 deletion control/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ type Tier struct {
type Client struct {
Logf func(format string, args ...any)
Stripe *stripe.Client
Clock string // a test clock name if any should be used
KeySource string // the source of the API key

cache memo
Expand Down
7 changes: 0 additions & 7 deletions control/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"sync"
"testing"
"time"

"golang.org/x/exp/slices"
"kr.dev/diff"
Expand All @@ -31,12 +30,6 @@ func newTestClient(t *testing.T) *Client {
}
}

func (c *Client) setClock(t *testing.T, now time.Time) *stroke.Clock {
clock := stroke.NewClock(t, c.Stripe, t.Name(), now)
c.Clock = clock.ID()
return clock
}

func TestRoundTrip(t *testing.T) {
tc := newTestClient(t)
ctx := context.Background()
Expand Down
Loading

0 comments on commit 7322395

Please sign in to comment.