Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
api: accepted require_billing_address in /v1/checkout
Browse files Browse the repository at this point in the history
This also adds test coverage for Checkout.

Fixes #252
  • Loading branch information
bmizerany committed Feb 19, 2023
1 parent ef7129f commit 55725f1
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 22 deletions.
7 changes: 4 additions & 3 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,10 @@ func (h *Handler) serveCheckout(w http.ResponseWriter, r *http.Request) error {
return err
}
link, err := h.c.Checkout(r.Context(), cr.Org, cr.SuccessURL, &control.CheckoutParams{
TrialDays: cr.TrialDays,
Features: fs,
CancelURL: cr.CancelURL,
TrialDays: cr.TrialDays,
Features: fs,
CancelURL: cr.CancelURL,
RequireBillingAddress: cr.RequireBillingAddress,
})
if err != nil {
return err
Expand Down
11 changes: 6 additions & 5 deletions api/apitypes/apitypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ type OrgInfo struct {
}

type CheckoutRequest struct {
Org string `json:"org"`
TrialDays int `json:"trial_days"`
Features []string `json:"features"`
SuccessURL string `json:"success_url"`
CancelURL string `json:"cancel_url"`
Org string `json:"org"`
TrialDays int `json:"trial_days"`
Features []string `json:"features"`
SuccessURL string `json:"success_url"`
CancelURL string `json:"cancel_url"`
RequireBillingAddress bool `json:"require_billing_address"`
}

type ScheduleRequest struct {
Expand Down
18 changes: 10 additions & 8 deletions client/tier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,12 @@ func (c *Client) Checkout(ctx context.Context, org string, successURL string, p
p = &CheckoutParams{}
}
r := &apitypes.CheckoutRequest{
Org: org,
SuccessURL: successURL,
CancelURL: p.CancelURL,
TrialDays: p.TrialDays,
Features: p.Features,
Org: org,
SuccessURL: successURL,
CancelURL: p.CancelURL,
TrialDays: p.TrialDays,
Features: p.Features,
RequireBillingAddress: p.RequireBillingAddress,
}
return fetchOK[*apitypes.CheckoutResponse, *apitypes.Error](ctx, c, "POST", "/v1/checkout", r)
}
Expand All @@ -263,9 +264,10 @@ type Phase = apitypes.Phase
type OrgInfo = apitypes.OrgInfo

type CheckoutParams struct {
TrialDays int
Features []string
CancelURL string
TrialDays int
Features []string
CancelURL string
RequireBillingAddress bool
}

type ScheduleParams struct {
Expand Down
8 changes: 5 additions & 3 deletions cmd/tier/tier.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ func runTier(cmd string, args []string) (err error) {
cancel := fs.Bool("cancel", false, "cancels the subscription")
successURL := fs.String("checkout", "", "subscribe via Stripe checkout")
cancelURL := fs.String("cancel_url", "", "sets the cancel URL for use with -checkout")
requireBillingAddress := fs.Bool("require_billing_address", false, "require billing address for use with --checkout")
if err := fs.Parse(args); err != nil {
return err
}
Expand All @@ -301,9 +302,10 @@ func runTier(cmd string, args []string) (err error) {
useCheckout := *successURL != ""
if useCheckout {
cr, err := tc().Checkout(ctx, org, *successURL, &tier.CheckoutParams{
TrialDays: *trial,
Features: refs,
CancelURL: *cancelURL,
TrialDays: *trial,
Features: refs,
CancelURL: *cancelURL,
RequireBillingAddress: *requireBillingAddress,
})
if err != nil {
return err
Expand Down
10 changes: 7 additions & 3 deletions control/schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,10 @@ func addPhases(ctx context.Context, c *Client, f *stripe.Form, update bool, name
}

type CheckoutParams struct {
TrialDays int
Features []Feature
CancelURL string
TrialDays int
Features []Feature
CancelURL string
RequireBillingAddress bool
}

func (c *Client) Checkout(ctx context.Context, org string, successURL string, p *CheckoutParams) (link string, err error) {
Expand All @@ -446,6 +447,9 @@ func (c *Client) Checkout(ctx context.Context, org string, successURL string, p
if p.CancelURL != "" {
f.Set("cancel_url", p.CancelURL)
}
if p.RequireBillingAddress {
f.Set("billing_address_collection", "required")
}
if len(p.Features) == 0 {
f.Set("mode", "setup")
// TODO: support other payment methods:
Expand Down
102 changes: 102 additions & 0 deletions control/schedule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package control

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"regexp"
"strconv"
"sync"
"testing"
"time"

Expand All @@ -21,6 +24,7 @@ import (
"tier.run/refs"
"tier.run/stripe"
"tier.run/stripe/stroke"
"tier.run/values"
)

var (
Expand Down Expand Up @@ -695,6 +699,95 @@ func TestLookupPhases(t *testing.T) {
diff.Test(t, t.Errorf, got, want, ignoreProviderIDs)
}

func TestCheckoutRequiredAddress(t *testing.T) {
type G struct {
successURL string
cancelURL string
bac string // billing_address_collection
trialDays string
}

var mu sync.Mutex
var got []G
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case wants(r, "GET", "/v1/customers"):
jsonEncode(t, w, msa{
"data": []msa{
{
"metadata": msa{
"tier.org": "org:demo",
},
},
},
})
case wants(r, "POST", "/v1/checkout/sessions"):
mu.Lock()
got = append(got, G{
successURL: r.FormValue("success_url"),
cancelURL: r.FormValue("cancel_url"),
bac: r.FormValue("billing_address_collection"),
trialDays: r.FormValue("subscription_data[trial_period_days]"),
})
mu.Unlock()
jsonEncode(t, w, msa{
"URL": "http://co.com/123",
})
default:
t.Errorf("UNEXPECTED: %s %s", r.Method, r.URL.Path)
}
})

s := httptest.NewServer(h)
t.Cleanup(s.Close)

cc := &Client{
Logf: t.Logf,
Stripe: &stripe.Client{
BaseURL: s.URL,
},
}

TF := []bool{true, false}
for _, withAddress := range TF {
for _, withFeatures := range TF {
for _, withCancel := range TF {
for _, withTrial := range TF {
got = nil

var (
bac = values.ReturnIf(withAddress, "required")
cancelURL = values.ReturnIf(withCancel, "https://c.com")
features = values.ReturnIf(withFeatures, []Feature{{}})
trialDays = values.ReturnIf(withTrial, 14)
)

link, err := cc.Checkout(context.Background(), "org:demo", "http://s.com", &CheckoutParams{
Features: features,
RequireBillingAddress: withAddress,
CancelURL: cancelURL,
TrialDays: trialDays,
})
if err != nil {
t.Fatal(err)
}

if want := "http://co.com/123"; link != want {
t.Errorf("link = %q; want %q", link, want)
}

diff.Test(t, t.Errorf, got, []G{{
successURL: "http://s.com",
cancelURL: cancelURL,
bac: bac,
trialDays: values.ReturnIf(withTrial && withFeatures, strconv.Itoa(trialDays)),
}})
}
}
}
}
}

func TestLookupPhasesNoSchedule(t *testing.T) {
// TODO(bmizerany): This tests assumptions, but we need an integration
// test provin fields "like" trial actually fall off / go to zero when
Expand Down Expand Up @@ -1124,3 +1217,12 @@ func writeHuJSON(w io.Writer, s string, args ...any) {
panic(err)
}
}

type msa map[string]any

func jsonEncode(t *testing.T, w io.Writer, v any) {
t.Helper()
if err := json.NewEncoder(w).Encode(v); err != nil {
t.Error(err)
}
}
17 changes: 17 additions & 0 deletions values/values.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ func MaybeSet[T comparable](v *T, a T) {
*v = Coalesce(*v, a)
}

// ReturnIf retuns a if cond is true; otherwise the zero value of T is
// returned.
func ReturnIf[T any](cond bool, a T) T {
if cond {
return a
}
var zero T
return zero
}

// SetIf sets v to a if cond is true
func SetIf[T any](v *T, cond bool, a T) {
if cond {
*v = a
}
}

type Collection[K comparable, V any] map[K][]V

func (c *Collection[K, V]) Add(key K, v V) {
Expand Down

0 comments on commit 55725f1

Please sign in to comment.