diff --git a/api/api.go b/api/api.go index 62977ff..3f984c1 100644 --- a/api/api.go +++ b/api/api.go @@ -12,6 +12,7 @@ import ( "golang.org/x/exp/slices" "tier.run/api/apitypes" "tier.run/api/materialize" + "tier.run/client/tier" "tier.run/control" "tier.run/refs" "tier.run/stripe" @@ -163,6 +164,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (h *Handler) serve(w http.ResponseWriter, r *http.Request) error { + clockID := r.Header.Get(tier.ClockHeader) + r = r.Clone(control.WithClock(r.Context(), clockID)) + switch r.URL.Path { case "/v1/whoami": return h.serveWhoAmI(w, r) @@ -184,6 +188,8 @@ func (h *Handler) serve(w http.ResponseWriter, r *http.Request) error { return h.servePush(w, r) case "/v1/payment_methods": return h.servePaymentMethods(w, r) + case "/v1/clock": + return h.serveClock(w, r) default: return trweb.NotFound } @@ -287,6 +293,7 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) error { Name: info.Name, Description: info.Description, Phone: info.Phone, + Created: info.CreatedAt(), Metadata: info.Metadata, PaymentMethod: info.PaymentMethod, InvoiceSettings: apitypes.InvoiceSettings(info.InvoiceSettings), @@ -412,6 +419,48 @@ func (h *Handler) servePaymentMethods(w http.ResponseWriter, r *http.Request) er }) } +func (h *Handler) serveClock(w http.ResponseWriter, r *http.Request) error { + writeResp := func(c *control.Clock) error { + return httpJSON(w, apitypes.ClockResponse{ + ID: c.ID(), + Link: c.Link(), + Present: c.Present(), + Status: c.Status(), + }) + } + + switch r.Method { + case "GET": + clockID := r.FormValue("id") + c := h.c.ClockFromID(clockID) + if err := c.Sync(r.Context()); err != nil { + return err + } + return writeResp(c) + case "POST": + var v apitypes.ClockRequest + if err := trweb.DecodeStrict(r, &v); err != nil { + return err + } + + if v.ID == "" { + c, err := h.c.NewClock(r.Context(), v.Name, v.Present) + if err != nil { + return err + } + return writeResp(c) + } else { + c := h.c.ClockFromID(v.ID) + if err := c.Advance(r.Context(), v.Present); err != nil { + return err + } + return writeResp(c) + } + default: + return trweb.MethodNotAllowed + } +} + func httpJSON(w http.ResponseWriter, v any) error { w.Header().Set("Content-Type", "application/json") enc := json.NewEncoder(w) diff --git a/api/api_test.go b/api/api_test.go index 1292ec3..d00d492 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -36,6 +36,7 @@ func newTestClient(t *testing.T) (*tier.Client, *control.Client) { s := httptest.NewTLSServer(h) t.Cleanup(s.Close) tc := &tier.Client{ + Logf: t.Logf, BaseURL: s.URL, HTTPClient: s.Client(), } @@ -135,7 +136,6 @@ func TestAPISubscribe(t *testing.T) { defer maybeFailNow(t) fn := mpn(feature) err := tc.ReportUsage(ctx, org, fn.String(), n, &tier.ReportParams{ - At: time.Now().Add(1 * time.Minute), Clobber: false, }) diff.Test(t, t.Errorf, err, wantErr) @@ -465,15 +465,46 @@ func TestWhoAmI(t *testing.T) { } } +func TestClock(t *testing.T) { + t.Parallel() + + tc, _ := newTestClient(t) + + now := time.Now().Truncate(time.Second) + ctx, err := tc.WithClock(context.Background(), t.Name(), now) + if err != nil { + t.Fatal(err) + } + + now = now.Add(100 * time.Hour) + if err := tc.Advance(ctx, now); err != nil { + t.Fatal(err) + } + + if err := tc.Subscribe(ctx, "org:test"); err != nil { + t.Fatal(err) + } + + c, err := tc.LookupOrg(ctx, "org:test") + if err != nil { + t.Fatal(err) + } + + if !c.Created.Equal(now) { + t.Errorf("c.Created = %v; want %v", c.Created, now) + } +} + func TestTierReport(t *testing.T) { t.Parallel() - ctx := context.Background() - tc, cc := newTestClient(t) + tc, _ := newTestClient(t) - farIntoTheFuture := time.Now().Add(24 * time.Hour) - clock := stroke.NewClock(t, cc.Stripe, t.Name(), farIntoTheFuture) - cc.Clock = clock.ID() + now := time.Now() + ctx, err := tc.WithClock(context.Background(), t.Name(), now) + if err != nil { + t.Fatal(err) + } pr, err := tc.PushJSON(ctx, []byte(` { @@ -503,9 +534,10 @@ func TestTierReport(t *testing.T) { } report := func(n int, at time.Time, wantErr error) { + t.Helper() if err := tc.ReportUsage(ctx, "org:test", "feature:t", n, &tier.ReportParams{ // Force 'now' at Stripe. - At: time.Time{}, + At: at, Clobber: false, }); !errors.Is(err, wantErr) { @@ -513,8 +545,8 @@ func TestTierReport(t *testing.T) { } } - report(10, time.Time{}, nil) - report(10, clock.Now().Add(1*time.Minute), nil) + report(10, now, nil) + report(10, now.Add(1*time.Minute), nil) limit, used, err := tc.LookupLimit(ctx, "org:test", "feature:t") if err != nil { diff --git a/api/apitypes/apitypes.go b/api/apitypes/apitypes.go index 75cc3b8..f5f650d 100644 --- a/api/apitypes/apitypes.go +++ b/api/apitypes/apitypes.go @@ -46,6 +46,7 @@ type OrgInfo struct { Name string `json:"name"` Description string `json:"description"` Phone string `json:"phone"` + Created time.Time `json:"created"` Metadata map[string]string `json:"metadata"` PaymentMethod string `json:"payment_method"` @@ -123,3 +124,16 @@ type WhoAmIResponse struct { Isolated bool `json:"isolated"` URL string `json:"url"` } + +type ClockRequest struct { + ID string + Name string + Present time.Time +} + +type ClockResponse struct { + ID string `json:"id"` + Link string `json:"link"` + Present time.Time `json:"present"` + Status string `json:"status"` +} diff --git a/client/tier/client.go b/client/tier/client.go index 30be6e6..418189c 100644 --- a/client/tier/client.go +++ b/client/tier/client.go @@ -7,16 +7,22 @@ package tier import ( "context" "encoding/json" + "errors" "net/http" "net/url" "os" "time" + "tailscale.com/logtail/backoff" "tier.run/api/apitypes" "tier.run/fetch" "tier.run/refs" ) +// ClockHeader is the header used to pass the clock ID to the tier sidecar. +// It is exported for use by the sidecar API. Most users want to use WithClock. +const ClockHeader = "Tier-Clock" + const Inf = 1<<63 - 1 type Client struct { @@ -25,6 +31,27 @@ type Client struct { BaseURL string // the base URL of the tier sidecar; default is http://127.0.0.1:8080 HTTPClient *http.Client + + Logf func(fmt string, args ...any) +} + +func (c *Client) logf(fmt string, args ...any) { + if c.Logf != nil { + c.Logf(fmt, args...) + } +} + +type clockKey struct{} + +// WithClock returns a context with the provided clock ID set. The clock ID is +// pass via the Tier-Clock header to be used by the tier sidecar. +func WithClock(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, clockKey{}, id) +} + +func clockFromContext(ctx context.Context) string { + id, _ := ctx.Value(clockKey{}).(string) + return id } // FromEnv returns a Client configured from the environment. The BaseURL is set @@ -80,7 +107,7 @@ func (c *Client) PullJSON(ctx context.Context) ([]byte, error) { return fetchOK[[]byte, *apitypes.Error](ctx, c, "GET", "/v1/pull", nil) } -// WhoIS reports the Stripe ID for the given organization. +// WhoIs reports the Stripe customer ID for the provided org. OrgInfo is not set. func (c *Client) WhoIs(ctx context.Context, org string) (apitypes.WhoIsResponse, error) { return fetchOK[apitypes.WhoIsResponse, *apitypes.Error](ctx, c, "GET", "/v1/whois?org="+org, nil) } @@ -299,7 +326,69 @@ func (c *Client) WhoAmI(ctx context.Context) (apitypes.WhoAmIResponse, error) { return fetchOK[apitypes.WhoAmIResponse, *apitypes.Error](ctx, c, "GET", "/v1/whoami", nil) } +// WithClock creates a new test clock with the provided name and start time, +// and returns a new context with the clock ID set. +// +// It is an error to call WithClock if a clock is already set in the context. +func (c *Client) WithClock(ctx context.Context, name string, start time.Time) (context.Context, error) { + if clockFromContext(ctx) != "" { + return nil, errors.New("tier: clock already set in context") + } + clock, err := fetchOK[apitypes.ClockResponse, *apitypes.Error](ctx, c, "POST", "/v1/clock", apitypes.ClockRequest{ + Name: name, + Present: start, + }) + if err != nil { + return nil, err + } + return WithClock(ctx, clock.ID), nil +} + +// Advance advances the test clock set in the context to t. +// +// It is an error to call Advance if no clock is set in the context. +func (c *Client) Advance(ctx context.Context, t time.Time) error { + clockID := clockFromContext(ctx) + if clockID == "" { + return errors.New("tier: no clock set in context") + } + cr, err := fetchOK[apitypes.ClockResponse, *apitypes.Error](ctx, c, "POST", "/v1/clock", apitypes.ClockRequest{ + ID: clockID, + Present: t, + }) + if err != nil { + return err + } + return c.awaitClockReady(ctx, cr.ID) +} + +var errForBackoff = errors.New("force backoff") + +func (c *Client) awaitClockReady(ctx context.Context, id string) error { + bo := backoff.NewBackoff("tier", c.logf, 5*time.Second) + for { + cr, err := c.syncClock(ctx, id) + if err != nil { + return err + } + if cr.Status == "ready" { + return nil + } + c.logf("clock %s status = %q; waiting", id, cr.Status) + bo.BackOff(ctx, errForBackoff) + } +} + +func (c *Client) syncClock(ctx context.Context, id string) (apitypes.ClockResponse, error) { + return fetchOK[apitypes.ClockResponse, *apitypes.Error](ctx, c, "GET", "/v1/clock?id="+id, nil) +} + func fetchOK[T any, E error](ctx context.Context, c *Client, method, path string, body any) (T, error) { up := url.UserPassword(c.APIKey, "") - return fetch.OK[T, E](ctx, c.client(), method, c.baseURL(path), body, up) + clockID := clockFromContext(ctx) + var h http.Header + if clockID != "" { + h = http.Header{ClockHeader: []string{clockID}} + } + return fetch.OK[T, E](ctx, c.client(), method, c.baseURL(path), body, up, h) } diff --git a/client/tier/example_test.go b/client/tier/example_test.go index 68d9461..37c83b0 100644 --- a/client/tier/example_test.go +++ b/client/tier/example_test.go @@ -6,6 +6,7 @@ import ( "io" "log" "net/http" + "time" "tier.run/client/tier" ) @@ -60,6 +61,30 @@ func ExampleClient_Can_report() { fmt.Println(convert(readInput())) } +func ExampleClient_WithClock_testClocks() { + c, err := tier.FromEnv() + if err != nil { + panic(err) + } + + now := time.Now() + + ctx, err := c.WithClock(context.Background(), "testName", now) + if err != nil { + panic(err) + } + + // Use ctx with other Client methods + + // This creates the customer and subscription using the clock. + _ = c.Subscribe(ctx, "org:example", "plan:free@0") + + // Advance the clock by 24 hours, and then report usage. + _ = c.Advance(ctx, now.Add(24*time.Hour)) + + _ = c.Report(ctx, "org:example", "feature:bandwidth", 1000) +} + func orgFromSession(r *http.Request) string { return "org:example" } diff --git a/control/client.go b/control/client.go index 54e9caf..878c0ca 100644 --- a/control/client.go +++ b/control/client.go @@ -138,7 +138,6 @@ type Tier struct { type Client struct { Logf func(format string, args ...any) Stripe *stripe.Client - Clock string // a test clock name if any should be used KeySource string // the source of the API key cache memo diff --git a/control/client_test.go b/control/client_test.go index df18078..5c253c5 100644 --- a/control/client_test.go +++ b/control/client_test.go @@ -5,7 +5,6 @@ import ( "errors" "sync" "testing" - "time" "golang.org/x/exp/slices" "kr.dev/diff" @@ -31,12 +30,6 @@ func newTestClient(t *testing.T) *Client { } } -func (c *Client) setClock(t *testing.T, now time.Time) *stroke.Clock { - clock := stroke.NewClock(t, c.Stripe, t.Name(), now) - c.Clock = clock.ID() - return clock -} - func TestRoundTrip(t *testing.T) { tc := newTestClient(t) ctx := context.Background() diff --git a/control/clock.go b/control/clock.go new file mode 100644 index 0000000..47c4068 --- /dev/null +++ b/control/clock.go @@ -0,0 +1,112 @@ +package control + +import ( + "context" + "errors" + "net/url" + "time" + + "tailscale.com/logtail/backoff" + "tier.run/stripe" +) + +type stripeClock struct { + ID string + Status string + Present int64 `json:"frozen_time"` +} + +type Clock struct { + id string + present time.Time + status string + + sc *stripe.Client + logf func(format string, args ...any) +} + +// ClockFromID returns a Clock for the given clock ID. It does not check that +// the clock exists or what its status is. If the status or present time are +// needed, clients should call Sync. +func (c *Client) ClockFromID(id string) *Clock { + return &Clock{ + id: id, + sc: c.Stripe, + } +} + +// NewClock creates a new clock in the Stripe account associated with the +// client and returns a Clock ready to use. +func (c *Client) NewClock(ctx context.Context, name string, start time.Time) (*Clock, error) { + var f stripe.Form + f.Set("name", name) + f.Set("frozen_time", start.Truncate(time.Second)) + var v stripeClock + if err := c.Stripe.Do(ctx, "POST", "/v1/test_helpers/test_clocks", f, &v); err != nil { + return nil, err + } + return &Clock{ + id: v.ID, + present: time.Unix(v.Present, 0), + sc: c.Stripe, + logf: c.Logf, + }, nil +} + +func (c *Clock) ID() string { return c.id } +func (c *Clock) Present() time.Time { return c.present } +func (c *Clock) Status() string { return c.status } + +// Link returns a link to the clock in the Stripe dashboard. +func (c *Clock) Link() string { + dashURL, err := url.JoinPath("https://dashboard.stripe.com", c.sc.AccountID, "/test/test-clocks", c.ID()) + if err != nil { + panic(err) // should never happen + } + return dashURL +} + +// Advance advances the clock to the given time. +func (c *Clock) Advance(ctx context.Context, to time.Time) error { + var f stripe.Form + f.Set("frozen_time", to.Truncate(time.Second)) + var v stripeClock + if err := c.sc.Do(ctx, "POST", "/v1/test_helpers/test_clocks/"+c.ID()+"/advance", f, &v); err != nil { + return err + } + c.present = time.Unix(v.Present, 0) + return nil +} + +var errForceBackoff = errors.New("force backoff") + +// Wait waits for the clock to be ready, or until the context is canceled. It +// returns an error if any. +func (c *Clock) Wait(ctx context.Context) error { + bo := backoff.NewBackoff("stroke: clock: advance backoff", c.vlogf, 5*time.Second) + for { + c.Sync(ctx) + c.vlogf("clock %q is %q", c.ID(), c.Status()) + if c.Status() == "ready" { + return nil + } + bo.BackOff(context.Background(), errForceBackoff) + } +} + +func (c *Clock) Sync(ctx context.Context) error { + var f stripe.Form + var v stripeClock + if err := c.sc.Do(ctx, "GET", "/v1/test_helpers/test_clocks/"+c.ID(), f, &v); err != nil { + return err + } + c.present = time.Unix(v.Present, 0) + c.status = v.Status + return nil +} + +func (c *Clock) vlogf(format string, args ...any) { + if c.logf != nil { + c.logf(format, args...) + } +} diff --git a/control/schedule.go b/control/schedule.go index e218342..491012c 100644 --- a/control/schedule.go +++ b/control/schedule.go @@ -16,6 +16,19 @@ import ( "tier.run/values" ) +type clockKey struct{} + +func WithClock(ctx context.Context, clockID string) context.Context { + return context.WithValue(ctx, clockKey{}, clockID) +} + +func clockFromContext(ctx context.Context) string { + if v := ctx.Value(clockKey{}); v != nil { + return v.(string) + } + return "" +} + const defaultScheduleName = "default" // Errors @@ -50,11 +63,16 @@ type OrgInfo struct { Description string Phone string Metadata map[string]string + Created int PaymentMethod string InvoiceSettings InvoiceSettings } +func (oi *OrgInfo) CreatedAt() time.Time { + return time.Unix(int64(oi.Created), 0) +} + type Phase struct { Org string // set on read Effective time.Time @@ -923,8 +941,8 @@ func (c *Client) createCustomer(ctx context.Context, org string, info *OrgInfo) if err := setOrgInfo(&f, info); err != nil { return "", err } - if c.Clock != "" { - f.Set("test_clock", c.Clock) + if clockID := clockFromContext(ctx); clockID != "" { + f.Set("test_clock", clockID) } var created struct { stripe.ID diff --git a/control/schedule_test.go b/control/schedule_test.go index 5c70622..9be6a96 100644 --- a/control/schedule_test.go +++ b/control/schedule_test.go @@ -23,7 +23,6 @@ import ( "kr.dev/errorfmt" "tier.run/refs" "tier.run/stripe" - "tier.run/stripe/stroke" "tier.run/values" ) @@ -48,8 +47,7 @@ var ignoreProviderIDs = diff.OptionList( func TestSchedule(t *testing.T) { ciOnly(t) - c := newTestClient(t) - ctx := context.Background() + s := newScheduleTester(t) var model []Feature plan := func(ff []Feature) []refs.FeaturePlan { @@ -74,20 +72,11 @@ func TestSchedule(t *testing.T) { Currency: "usd", }}) - c.Push(ctx, model, pushLogger(t)) - - sub := func(org string, fs []refs.FeaturePlan) { - t.Helper() - t.Logf("subscribing %s to %# v", org, pretty.Formatter(fs)) - - if err := c.SubscribeTo(ctx, org, fs); err != nil { - t.Fatalf("%# v", pretty.Formatter(err)) - } - } + s.push(model) check := func(org string, want []Phase) { t.Helper() - got, err := c.LookupPhases(ctx, org) + got, err := s.cc.LookupPhases(s.ctx, org) if err != nil { t.Fatal(err) } @@ -95,8 +84,7 @@ func TestSchedule(t *testing.T) { diff.Test(t, t.Errorf, got, want, ignoreProviderIDs) } - clock := c.setClock(t, t0) - sub("org:example", planFree) + s.schedule("org:example", 0, "", planFree...) check("org:example", []Phase{{ Org: "org:example", Current: true, @@ -105,8 +93,8 @@ func TestSchedule(t *testing.T) { Plans: plans("plan:free@0"), }}) - clock.Advance(t1) - sub("org:example", planPro) + s.advanceTo(t1) + s.schedule("org:example", 0, "", planPro...) check("org:example", []Phase{ { Org: "org:example", @@ -118,7 +106,7 @@ func TestSchedule(t *testing.T) { }) // downgrade and check no new phases - sub("org:example", planFree) + s.schedule("org:example", 0, "", planFree...) check("org:example", []Phase{ { Org: "org:example", @@ -131,38 +119,52 @@ func TestSchedule(t *testing.T) { } type scheduleTester struct { + ctx context.Context t *testing.T cc *Client - clock *stroke.Clock + clock *Clock } func newScheduleTester(t *testing.T) *scheduleTester { t.Helper() c := newTestClient(t) - clock := c.setClock(t, t0) - return &scheduleTester{t: t, cc: c, clock: clock} + clock, err := c.NewClock(context.Background(), t.Name(), t0) + if err != nil { + t.Fatal(err) + } + ctx := WithClock(context.Background(), clock.ID()) + return &scheduleTester{ctx: ctx, t: t, cc: c, clock: clock} } func (s *scheduleTester) push(model []Feature) { s.t.Helper() - s.cc.Push(context.Background(), model, pushLogger(s.t)) + s.cc.Push(s.ctx, model, pushLogger(s.t)) if s.t.Failed() { s.t.FailNow() } } func (s *scheduleTester) advance(days int) { - s.clock.Advance(s.clock.Now().AddDate(0, 0, days)) + s.advanceTo(s.clock.Present().AddDate(0, 0, days)) +} + +func (s *scheduleTester) advanceTo(t time.Time) { + if err := s.clock.Advance(s.ctx, t); err != nil { + s.t.Fatal(err) + } + if err := s.clock.Wait(s.ctx); err != nil { + s.t.Fatal(err) + } } func (s *scheduleTester) advanceToNextPeriod() { // TODO(bmizerany): make Phase aware so that it jumps based on the // start of the next phase if the current phase ends sooner than than 1 // interval of the current phase. - now := s.clock.Now() + now := s.clock.Present() eop := time.Date(now.Year(), now.Month()+1, 1, 0, 0, 0, 0, time.UTC) s.t.Logf("advancing to next period %s", eop) - s.clock.Advance(eop) + s.advanceTo(eop) } func (s *scheduleTester) cancel(org string) { @@ -174,7 +176,7 @@ func (s *scheduleTester) cancel(org string) { func (s *scheduleTester) setPaymentMethod(org string, pm string) { s.t.Helper() s.t.Logf("setting payment method for %s to %s", org, pm) - if err := s.cc.PutCustomer(context.Background(), org, &OrgInfo{ + if err := s.cc.PutCustomer(s.ctx, org, &OrgInfo{ PaymentMethod: pm, InvoiceSettings: InvoiceSettings{ DefaultPaymentMethod: pm, @@ -209,14 +211,14 @@ func (s *scheduleTester) schedule(org string, trialDays int, payment string, fs PaymentMethod: payment, Phases: ps, } - if err := s.cc.Schedule(context.Background(), org, p); err != nil { + if err := s.cc.Schedule(s.ctx, org, p); err != nil { s.t.Fatalf("error subscribing: %v", err) } } func (s *scheduleTester) report(org, name string, n int) { s.t.Helper() - if err := s.cc.ReportUsage(context.Background(), org, mpn(name), Report{ + if err := s.cc.ReportUsage(s.ctx, org, mpn(name), Report{ N: n, }); err != nil { s.t.Fatal(err) @@ -228,10 +230,13 @@ func (s *scheduleTester) report(org, name string, n int) { //lint:ignore U1000 saving for a rainy day func (s *scheduleTester) checkLimits(org string, want []Usage) { s.t.Helper() - got, err := s.cc.LookupLimits(context.Background(), org) + got, err := s.cc.LookupLimits(s.ctx, org) if err != nil { s.t.Fatal(err) } + slices.SortFunc(got, func(a, b Usage) bool { + return refs.ByName(a.Feature, b.Feature) + }) s.diff(got, want, diff.ZeroFields[Usage]("Start", "End", "Feature")) if s.t.Failed() { s.t.FailNow() @@ -241,7 +246,7 @@ func (s *scheduleTester) checkLimits(org string, want []Usage) { // ignores period dates func (s *scheduleTester) checkInvoices(org string, want []Invoice) { s.t.Helper() - got, err := s.cc.LookupInvoices(context.Background(), org) + got, err := s.cc.LookupInvoices(s.ctx, org) if err != nil { s.t.Fatal(err) } @@ -360,7 +365,7 @@ func TestSchedule_TrialSwapWithPaid(t *testing.T) { for i, step := range steps { t.Logf("step: %+v", step) s.schedule("org:test", step.trialDays, "", featureX) - status, err := s.cc.LookupStatus(context.Background(), "org:test") + status, err := s.cc.LookupStatus(s.ctx, "org:test") if err != nil { t.Fatalf("[%d]: unexpected error: %v", i, err) } @@ -623,8 +628,7 @@ func TestSchedulePaymentMethod(t *testing.T) { } func TestLookupPhasesWithTiersRoundTrip(t *testing.T) { - c := newTestClient(t) - ctx := context.Background() + s := newScheduleTester(t) fs := []Feature{ { @@ -656,13 +660,10 @@ func TestLookupPhasesWithTiersRoundTrip(t *testing.T) { fps[i] = f.FeaturePlan } - c.setClock(t, t0) - c.Push(ctx, fs, pushLogger(t)) - if err := c.SubscribeTo(ctx, "org:example", fps); err != nil { - t.Fatal(err) - } + s.push(fs) + s.schedule("org:example", 0, "", fps...) - got, err := c.LookupPhases(ctx, "org:example") + got, err := s.cc.LookupPhases(s.ctx, "org:example") if err != nil { t.Fatal(err) } @@ -692,20 +693,17 @@ func TestSubscribeToPlan(t *testing.T) { Currency: "usd", }} - ctx := context.Background() - tc := newTestClient(t) - tc.Push(ctx, fs, pushLogger(t)) - tc.setClock(t, t0) + s := newScheduleTester(t) + s.push(fs) efs, err := Expand(fs, "plan:pro@0") if err != nil { t.Fatal(err) } - if err := tc.SubscribeTo(ctx, "org:example", efs); err != nil { - t.Fatal(err) - } - got, err := tc.LookupPhases(ctx, "org:example") + s.schedule("org:example", 0, "", efs...) + + got, err := s.cc.LookupPhases(s.ctx, "org:example") if err != nil { t.Fatal(err) } @@ -764,17 +762,11 @@ func TestLookupPhases(t *testing.T) { }, } - tc := newTestClient(t) - ctx := context.Background() - tc.Push(ctx, fs0, pushLogger(t)) - - tc.setClock(t, t0) - - if err := tc.SubscribeTo(ctx, "org:example", FeaturePlans(fs0)); err != nil { - t.Fatal(err) - } + s := newScheduleTester(t) + s.push(fs0) + s.schedule("org:example", 0, "", FeaturePlans(fs0)...) - got, err := tc.LookupPhases(ctx, "org:example") + got, err := s.cc.LookupPhases(s.ctx, "org:example") if err != nil { t.Fatal(err) } @@ -800,14 +792,12 @@ func TestLookupPhases(t *testing.T) { Currency: "usd", }, } - tc.Push(ctx, fs1, pushLogger(t)) + s.push(fs1) fpsFrag := FeaturePlans(append(fs0, fs1[1:]...)) - if err := tc.SubscribeTo(ctx, "org:example", fpsFrag); err != nil { - t.Fatal(err) - } + s.schedule("org:example", 0, "", fpsFrag...) - got, err = tc.LookupPhases(ctx, "org:example") + got, err = s.cc.LookupPhases(s.ctx, "org:example") if err != nil { t.Fatal(err) } @@ -818,8 +808,6 @@ func TestLookupPhases(t *testing.T) { got[i] = p } - t.Logf("got: %# v", pretty.Formatter(got)) - want = []Phase{{ Org: "org:example", Current: true, @@ -1131,16 +1119,12 @@ func TestReportUsage(t *testing.T) { }, } - tc := newTestClient(t) - ctx := context.Background() - tc.Push(ctx, fs, pushLogger(t)) - tc.setClock(t, t0) + s := newScheduleTester(t) + s.push(fs) - if err := tc.SubscribeTo(ctx, "org:example", FeaturePlans(fs)); err != nil { - t.Fatal(err) - } + s.schedule("org:example", 0, "", FeaturePlans(fs)...) - g, groupCtx := errgroup.WithContext(ctx) + g, groupCtx := errgroup.WithContext(s.ctx) report := func(feature string, n int) { fn, err := refs.ParseName(feature) if err != nil { @@ -1148,9 +1132,8 @@ func TestReportUsage(t *testing.T) { } g.Go(func() (err error) { defer errorfmt.Handlef("%s: %w", feature, &err) - return tc.ReportUsage(groupCtx, "org:example", fn, Report{ - N: n, - At: t0, + return s.cc.ReportUsage(groupCtx, "org:example", fn, Report{ + N: n, }) }) } @@ -1161,22 +1144,11 @@ func TestReportUsage(t *testing.T) { t.Fatal(err) } - got, err := tc.LookupLimits(ctx, "org:example") - if err != nil { - t.Fatal(err) - } - - slices.SortFunc(got, func(a, b Usage) bool { - return refs.ByName(a.Feature, b.Feature) - }) - - want := []Usage{ + s.checkLimits("org:example", []Usage{ {Feature: mpf("feature:10@plan:test@0"), Start: t0, End: endOfStripeMonth(t0), Used: 3, Limit: 10}, {Feature: mpf("feature:inf@plan:test@0"), Start: t0, End: endOfStripeMonth(t0), Used: 9, Limit: Inf}, {Feature: mpf("feature:lic@plan:test@0"), Start: t1, End: t2, Used: 1, Limit: Inf}, - } - - diff.Test(t, t.Errorf, got, want) + }) } func TestReportUsageFeatureNotFound(t *testing.T) { @@ -1257,7 +1229,7 @@ func TestSchedulePutCustomer(t *testing.T) { want.Metadata = map[string]string{} } } - diff.Test(t, t.Errorf, got, want) + diff.Test(t, t.Errorf, got, want, diff.ZeroFields[OrgInfo]("Created")) } check("org:invalid", &o{Email: "invalid"}, nil, ErrInvalidEmail, ErrOrgNotFound) diff --git a/control/usage.go b/control/usage.go index 89a6917..ee4139f 100644 --- a/control/usage.go +++ b/control/usage.go @@ -10,7 +10,6 @@ import ( "golang.org/x/exp/maps" "kr.dev/errorfmt" - "tailscale.com/logtail/backoff" "tier.run/refs" "tier.run/stripe" ) @@ -49,24 +48,7 @@ func (c *Client) ReportUsage(ctx context.Context, org string, feature refs.Name, f.SetIdempotencyKey(randomString()) - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) - defer cancel() - - // TODO(bmizerany): use Dedup here - bo := backoff.NewBackoff("ReportUsage", c.Logf, 3*time.Second) - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - err := c.Stripe.Do(ctx, "POST", "/v1/subscription_items/"+itemID+"/usage_records", f, nil) - c.Logf("ReportUsage: %v", err) - bo.BackOff(ctx, err) - if err == nil { - return nil - } - } + return c.Stripe.Do(ctx, "POST", "/v1/subscription_items/"+itemID+"/usage_records", f, nil) } func (c *Client) LookupLimits(ctx context.Context, org string) ([]Usage, error) { diff --git a/stripe/stroke/clock_test.go b/stripe/stroke/clock_test.go deleted file mode 100644 index 9258480..0000000 --- a/stripe/stroke/clock_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package stroke - -import ( - "testing" - "time" -) - -func TestClock(t *testing.T) { - c := Client(t) - now := time.Now() - clock := NewClock(t, c, "test", now) - - want := now.Truncate(time.Second) - if got := clock.Now(); !got.Equal(want) { - t.Fatalf("got %v, want %v", got, want) - } - - clock.Advance(now.Add(time.Hour)) - want = want.Add(time.Hour) - if got := clock.Now(); !got.Equal(want) { - t.Fatalf("got %v, want %v", got, want) - } -} diff --git a/stripe/stroke/stroke.go b/stripe/stroke/stroke.go index 1fd6c8a..d8ccb41 100644 --- a/stripe/stroke/stroke.go +++ b/stripe/stroke/stroke.go @@ -6,8 +6,6 @@ import ( "context" "crypto/rand" "encoding/hex" - "errors" - "net/url" "strings" "testing" "time" @@ -75,95 +73,3 @@ func createAccount(c *stripe.Client, t testing.TB) (string, error) { } } - -type Clock struct { - id string - helper func() - advance func(time.Time) - - sync func() - now time.Time - status string - - logf func(string, ...any) - - dashURL string -} - -func NewClock(t *testing.T, c *stripe.Client, name string, start time.Time) *Clock { - type T struct { - ID string - Status string - Time int64 `json:"frozen_time"` - } - - ctx := context.Background() - - do := func(method, path string, f stripe.Form) (v T) { - t.Helper() - if err := c.Do(ctx, method, path, f, &v); err != nil { - t.Fatalf("error calling %s %s: %v", method, path, err) - } - return - } - - var f stripe.Form - f.Set("name", name) - f.Set("frozen_time", start) - v := do("POST", "/v1/test_helpers/test_clocks", f) - path := "/v1/test_helpers/test_clocks/" + v.ID - - // NOTE: There is no point in deleting clocks. Clients should use - // isolated accounts, which when deleted, delete all associated clocks - // and other objects. The API call to delete each clock would just be a - // waste of time. - - dashURL, err := url.JoinPath("https://dashboard.stripe.com", c.AccountID, "/test/test-clocks", v.ID) - if err != nil { - panic(err) // should never happen - } - - var cl *Clock - cl = &Clock{ - id: v.ID, - helper: t.Helper, - advance: func(now time.Time) { - var f stripe.Form - f.Set("frozen_time", now) - do("POST", path+"/advance", f) - }, - sync: func() { - v := do("GET", path, stripe.Form{}) - cl.logf("clock: sync: status=%s, time=%v", v.Status, time.Unix(v.Time, 0)) - cl.now = time.Unix(v.Time, 0).UTC() - cl.status = v.Status - }, - now: start.Truncate(time.Second), - logf: t.Logf, - dashURL: dashURL, - } - - return cl -} - -var errForceBackoff = errors.New("force backoff") - -// ID returns the ID of the clock. -func (c *Clock) ID() string { return c.id } -func (c *Clock) DashboardURL() string { return c.dashURL } - -func (c *Clock) Advance(t time.Time) { - c.helper() - c.advance(t) - bo := backoff.NewBackoff("stroke: clock: advance backoff", c.logf, 5*time.Second) - for { - c.sync() - if c.status == "ready" { - return - } - bo.BackOff(context.Background(), errForceBackoff) - } -} - -// Now retrieves the current time for the clock from Stripe and returns it. -func (c *Clock) Now() time.Time { return c.now }