diff --git a/api/api.go b/api/api.go index eefa826..af4a6a2 100644 --- a/api/api.go +++ b/api/api.go @@ -189,9 +189,10 @@ func (h *Handler) serveCheckout(w http.ResponseWriter, r *http.Request) error { return err } link, err := h.c.Checkout(r.Context(), cr.Org, cr.SuccessURL, &control.CheckoutParams{ - TrialDays: cr.TrialDays, - Features: fs, - CancelURL: cr.CancelURL, + TrialDays: cr.TrialDays, + Features: fs, + CancelURL: cr.CancelURL, + RequireBillingAddress: cr.RequireBillingAddress, }) if err != nil { return err diff --git a/api/apitypes/apitypes.go b/api/apitypes/apitypes.go index e675d95..ca9f6ea 100644 --- a/api/apitypes/apitypes.go +++ b/api/apitypes/apitypes.go @@ -47,11 +47,12 @@ type OrgInfo struct { } type CheckoutRequest struct { - Org string `json:"org"` - TrialDays int `json:"trial_days"` - Features []string `json:"features"` - SuccessURL string `json:"success_url"` - CancelURL string `json:"cancel_url"` + Org string `json:"org"` + TrialDays int `json:"trial_days"` + Features []string `json:"features"` + SuccessURL string `json:"success_url"` + CancelURL string `json:"cancel_url"` + RequireBillingAddress bool `json:"require_billing_address"` } type ScheduleRequest struct { diff --git a/client/tier/client.go b/client/tier/client.go index d5aed74..cf3f497 100644 --- a/client/tier/client.go +++ b/client/tier/client.go @@ -250,11 +250,12 @@ func (c *Client) Checkout(ctx context.Context, org string, successURL string, p p = &CheckoutParams{} } r := &apitypes.CheckoutRequest{ - Org: org, - SuccessURL: successURL, - CancelURL: p.CancelURL, - TrialDays: p.TrialDays, - Features: p.Features, + Org: org, + SuccessURL: successURL, + CancelURL: p.CancelURL, + TrialDays: p.TrialDays, + Features: p.Features, + RequireBillingAddress: p.RequireBillingAddress, } return fetchOK[*apitypes.CheckoutResponse, *apitypes.Error](ctx, c, "POST", "/v1/checkout", r) } @@ -263,9 +264,10 @@ type Phase = apitypes.Phase type OrgInfo = apitypes.OrgInfo type CheckoutParams struct { - TrialDays int - Features []string - CancelURL string + TrialDays int + Features []string + CancelURL string + RequireBillingAddress bool } type ScheduleParams struct { diff --git a/cmd/tier/tier.go b/cmd/tier/tier.go index 2d28d41..8181ece 100644 --- a/cmd/tier/tier.go +++ b/cmd/tier/tier.go @@ -277,6 +277,7 @@ func runTier(cmd string, args []string) (err error) { cancel := fs.Bool("cancel", false, "cancels the subscription") successURL := fs.String("checkout", "", "subscribe via Stripe checkout") cancelURL := fs.String("cancel_url", "", "sets the cancel URL for use with -checkout") + requireBillingAddress := fs.Bool("require_billing_address", false, "require billing address for use with --checkout") if err := fs.Parse(args); err != nil { return err } @@ -301,9 +302,10 @@ func runTier(cmd string, args []string) (err error) { useCheckout := *successURL != "" if useCheckout { cr, err := tc().Checkout(ctx, org, *successURL, &tier.CheckoutParams{ - TrialDays: *trial, - Features: refs, - CancelURL: *cancelURL, + TrialDays: *trial, + Features: refs, + CancelURL: *cancelURL, + RequireBillingAddress: *requireBillingAddress, }) if err != nil { return err diff --git a/control/schedule.go b/control/schedule.go index 5642bda..07a0aa0 100644 --- a/control/schedule.go +++ b/control/schedule.go @@ -419,9 +419,10 @@ func addPhases(ctx context.Context, c *Client, f *stripe.Form, update bool, name } type CheckoutParams struct { - TrialDays int - Features []Feature - CancelURL string + TrialDays int + Features []Feature + CancelURL string + RequireBillingAddress bool } func (c *Client) Checkout(ctx context.Context, org string, successURL string, p *CheckoutParams) (link string, err error) { @@ -446,6 +447,9 @@ func (c *Client) Checkout(ctx context.Context, org string, successURL string, p if p.CancelURL != "" { f.Set("cancel_url", p.CancelURL) } + if p.RequireBillingAddress { + f.Set("billing_address_collection", "required") + } if len(p.Features) == 0 { f.Set("mode", "setup") // TODO: support other payment methods: diff --git a/control/schedule_test.go b/control/schedule_test.go index 5e41cea..c8f6946 100644 --- a/control/schedule_test.go +++ b/control/schedule_test.go @@ -2,6 +2,7 @@ package control import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -9,6 +10,9 @@ import ( "net/http/httptest" "os" "regexp" + "strconv" + "strings" + "sync" "testing" "time" @@ -695,6 +699,95 @@ func TestLookupPhases(t *testing.T) { diff.Test(t, t.Errorf, got, want, ignoreProviderIDs) } +func TestCheckoutRequiredAddress(t *testing.T) { + type G struct { + successURL string + cancelURL string + bac string // billing_address_collection + trialDays string + } + + var mu sync.Mutex + var got []G + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case wants(r, "GET", "/v1/customers"): + jsonEncode(t, w, msa{ + "data": []msa{ + { + "metadata": msa{ + "tier.org": "org:demo", + }, + }, + }, + }) + case wants(r, "POST", "/v1/checkout/sessions"): + mu.Lock() + got = append(got, G{ + successURL: r.FormValue("success_url"), + cancelURL: r.FormValue("cancel_url"), + bac: r.FormValue("billing_address_collection"), + trialDays: r.FormValue("subscription_data[trial_period_days]"), + }) + mu.Unlock() + jsonEncode(t, w, msa{ + "URL": "http://co.com/123", + }) + default: + t.Errorf("UNEXPECTED: %s %s", r.Method, r.URL.Path) + } + }) + + s := httptest.NewServer(h) + t.Cleanup(s.Close) + + cc := &Client{ + Logf: t.Logf, + Stripe: &stripe.Client{ + BaseURL: s.URL, + }, + } + + TF := []bool{true, false} + for _, withAddress := range TF { + for _, withFeatures := range TF { + for _, withCancel := range TF { + for _, withTrial := range TF { + got = nil + + var ( + bac = setif(withAddress, "required") + cancelURL = setif(withCancel, "https://c.com") + features = setif(withFeatures, []Feature{{}}) + trialDays = setif(withTrial, 14) + ) + + link, err := cc.Checkout(context.Background(), "org:demo", "http://s.com", &CheckoutParams{ + Features: features, + RequireBillingAddress: withAddress, + CancelURL: cancelURL, + TrialDays: trialDays, + }) + if err != nil { + t.Fatal(err) + } + + if want := "http://co.com/123"; link != want { + t.Errorf("link = %q; want %q", link, want) + } + + diff.Test(t, t.Errorf, got, []G{{ + successURL: "http://s.com", + cancelURL: cancelURL, + bac: bac, + trialDays: setif(withTrial && withFeatures, strconv.Itoa(trialDays)), + }}) + } + } + } + } +} + func TestLookupPhasesNoSchedule(t *testing.T) { // TODO(bmizerany): This tests assumptions, but we need an integration // test provin fields "like" trial actually fall off / go to zero when @@ -1124,3 +1217,38 @@ func writeHuJSON(w io.Writer, s string, args ...any) { panic(err) } } + +type msa map[string]any + +func jsonEncode(t *testing.T, w io.Writer, v any) { + t.Helper() + if err := json.NewEncoder(w).Encode(v); err != nil { + t.Error(err) + } +} + +func cloneBody(t *testing.T, r io.Reader) io.ReadCloser { + t.Helper() + var b strings.Builder + _, err := io.Copy(&b, r) + if err != nil { + t.Fatal(err) + } + return io.NopCloser(strings.NewReader(b.String())) +} + +func newRequest(method, target string, body io.Reader, f func(*http.Request)) *http.Request { + r := httptest.NewRequest(method, target, body) + if f != nil { + f(r) + } + return r +} + +func setif[T any](cond bool, v T) T { + if cond { + return v + } + var zero T + return zero +}