Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

api: accept payment on schedule #257

Merged
merged 1 commit into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"strings"

"github.com/kr/pretty"
"golang.org/x/exp/slices"
Expand Down Expand Up @@ -120,6 +121,17 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
})
return
}

var ipe *stripe.Error
if errors.As(err, &ipe) && strings.Contains(ipe.Message, "No such PaymentMethod") {
trweb.WriteError(w, &trweb.HTTPError{
Status: 400,
Code: "invalid_payment_method",
Message: ipe.Message,
})
return
}

if trweb.WriteError(w, lookupErr(err)) || trweb.WriteError(w, err) {
return
}
Expand Down Expand Up @@ -234,7 +246,11 @@ func (h *Handler) serveSubscribe(w http.ResponseWriter, r *http.Request) error {
})
}
}
return h.c.Schedule(r.Context(), sr.Org, phases)

return h.c.Schedule(r.Context(), sr.Org, control.ScheduleParams{
PaymentMethod: sr.PaymentMethodID,
Phases: phases,
})
}

func (h *Handler) serveReport(w http.ResponseWriter, r *http.Request) error {
Expand Down
21 changes: 21 additions & 0 deletions api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,27 @@ func TestAPISubscribe(t *testing.T) {
Code: "TERR1020",
Message: "feature or plan not found",
})

_, err := tc.Schedule(ctx, "org:test", &tier.ScheduleParams{
Phases: []apitypes.Phase{
{Trial: true, Features: []string{"plan:test@0"}},
},
PaymentMethodID: "pm_card_us",
})

// Quick lint check to make sure the PaymentMethod made it to Stripe.
// In production, payment methods can be set on a sub by sub basis;
// however in test mode, we may only use test payment methods, and in
// test mode, stripe does not accept test payment methods on a sub by
// sub basis, so there is no real way to test our support for this
// feature. Instead, here, we just check stripe complains about the
// payment method to show it saw what we wanted it to see in
// production.
diff.Test(t, t.Errorf, err, &apitypes.Error{
Status: 400,
Code: "invalid_payment_method",
Message: "No such PaymentMethod: 'pm_card_us'",
})
}

func TestPhaseBadOrg(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions api/apitypes/apitypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ type CheckoutRequest struct {
}

type ScheduleRequest struct {
Org string `json:"org"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
Org string `json:"org"`
PaymentMethodID string `json:"payment_method_id"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
}

// ScheduleResponse is the expected response from a schedule request. It is
Expand Down
12 changes: 7 additions & 5 deletions client/tier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,17 @@ type CheckoutParams struct {
}

type ScheduleParams struct {
Info *OrgInfo
Phases []Phase
Info *OrgInfo
Phases []Phase
PaymentMethodID string
}

func (c *Client) Schedule(ctx context.Context, org string, p *ScheduleParams) (*apitypes.ScheduleResponse, error) {
return fetchOK[*apitypes.ScheduleResponse, *apitypes.Error](ctx, c, "POST", "/v1/subscribe", &apitypes.ScheduleRequest{
Org: org,
Info: (*apitypes.OrgInfo)(p.Info),
Phases: p.Phases,
Org: org,
Info: (*apitypes.OrgInfo)(p.Info),
Phases: p.Phases,
PaymentMethodID: p.PaymentMethodID,
})

}
Expand Down
3 changes: 3 additions & 0 deletions cmd/tier/help.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ Checkout only flags:
--cancel_url=<cancel_url>
specify a cancel_url for Stripe Checkout. This flag is ignored
if --checkout is not set.
--paymentmethod=<paymentmethod_id>
specify a payment method to use for the subscription. This flag
is ignored with --checkout.

Global Flags:

Expand Down
2 changes: 2 additions & 0 deletions cmd/tier/tier.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ func runTier(cmd string, args []string) (err error) {
successURL := fs.String("checkout", "", "subscribe via Stripe checkout")
cancelURL := fs.String("cancel_url", "", "sets the cancel URL for use with -checkout")
requireBillingAddress := fs.Bool("require_billing_address", false, "require billing address for use with --checkout")
paymentMethod := fs.String("paymentmethod", "", "sets the Stripe payment method for the subscription (e.g. pm_123). It is ignored with --checkout")
if err := fs.Parse(args); err != nil {
return err
}
Expand Down Expand Up @@ -317,6 +318,7 @@ func runTier(cmd string, args []string) (err error) {
Info: &tier.OrgInfo{
Email: *email,
},
PaymentMethodID: *paymentMethod,
}
switch {
case *trial > 0:
Expand Down
111 changes: 62 additions & 49 deletions control/schedule.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (c *Client) lookupSubscription(ctx context.Context, org, name string) (sub
return s, nil
}

func (c *Client) createSchedule(ctx context.Context, org, name string, fromSub string, phases []Phase) (err error) {
func (c *Client) createSchedule(ctx context.Context, org, name string, fromSub string, p ScheduleParams) (err error) {
defer errorfmt.Handlef("stripe: createSchedule: %q: %w", org, &err)

create := func(f stripe.Form) (string, error) {
Expand All @@ -200,23 +200,43 @@ func (c *Client) createSchedule(ctx context.Context, org, name string, fromSub s
}
// We can only update phases after the schedule is created from
// the subscription.
return c.updateSchedule(ctx, sid, name, phases)
return c.updateSchedule(ctx, sid, name, p)
} else {
defer errorfmt.Handlef("newSub: %w", &err)
cid, err := c.WhoIs(ctx, org)
if err != nil {
return err
}
var f stripe.Form
if p.PaymentMethod != "" {
f.Set("default_settings", "default_payment_method", p.PaymentMethod)
}
f.Set("customer", cid)
if err := addPhases(ctx, c, &f, false, name, phases); err != nil {
if err := addPhases(ctx, c, &f, false, name, p.Phases); err != nil {
return err
}
_, err = create(f)
return err
}
}

type stripeSubSchedule struct {
stripe.ID
Current struct {
Start int64 `json:"start_date"`
End int64 `json:"end_date"`
} `json:"current_phase"`
Phases []struct {
Metadata struct {
Name string `json:"tier.subscription"`
}
Start int64 `json:"start_date"`
Items []struct {
Price stripePrice
}
}
}

func (c *Client) lookupPhases(ctx context.Context, org string, s subscription, name string) (current Phase, all []Phase, err error) {
defer errorfmt.Handlef("lookupPhases: %w", &err)

Expand All @@ -225,26 +245,9 @@ func (c *Client) lookupPhases(ctx context.Context, org string, s subscription, n
return ps[0], ps, nil
}

type T struct {
stripe.ID
Current struct {
Start int64 `json:"start_date"`
End int64 `json:"end_date"`
} `json:"current_phase"`
Phases []struct {
Metadata struct {
Name string `json:"tier.subscription"`
}
Start int64 `json:"start_date"`
Items []struct {
Price stripePrice
}
}
}

g, ctx := errgroup.WithContext(ctx)

var ss T
var ss stripeSubSchedule
g.Go(func() error {
var f stripe.Form
f.Add("expand[]", "phases.items.price")
Expand Down Expand Up @@ -348,13 +351,16 @@ func subscriptionToPhases(org string, s subscription) []Phase {
return ps
}

func (c *Client) updateSchedule(ctx context.Context, schedID, name string, phases []Phase) (err error) {
func (c *Client) updateSchedule(ctx context.Context, schedID, name string, p ScheduleParams) (err error) {
defer errorfmt.Handlef("stripe: updateSchedule: %q: %w", schedID, &err)
if schedID == "" {
return errors.New("subscription id required")
}
var f stripe.Form
if err := addPhases(ctx, c, &f, true, name, phases); err != nil {
if p.PaymentMethod != "" {
f.Set("default_settings", "default_payment_method", p.PaymentMethod)
}
if err := addPhases(ctx, c, &f, true, name, p.Phases); err != nil {
return err
}
return c.Stripe.Do(ctx, "POST", "/v1/subscription_schedules/"+schedID, f, nil)
Expand Down Expand Up @@ -478,8 +484,13 @@ func (c *Client) Checkout(ctx context.Context, org string, successURL string, p
}
}

func (c *Client) Schedule(ctx context.Context, org string, phases []Phase) error {
err := c.schedule(ctx, org, phases)
type ScheduleParams struct {
PaymentMethod string
Phases []Phase
}

func (c *Client) Schedule(ctx context.Context, org string, p ScheduleParams) error {
err := c.schedule(ctx, org, p)
c.Logf("stripe: schedule: %v", err)
var e *stripe.Error
if errors.As(err, &e) {
Expand All @@ -493,21 +504,21 @@ func (c *Client) Schedule(ctx context.Context, org string, phases []Phase) error
return err
}

func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err error) {
func (c *Client) schedule(ctx context.Context, org string, p ScheduleParams) (err error) {
defer errorfmt.Handlef("tier: schedule: %q: %w", org, &err)

if err := c.PutCustomer(ctx, org, nil); err != nil {
return err
}

if len(phases) == 0 {
if len(p.Phases) == 0 {
return errors.New("tier: schedule: at least one phase required")
}

scheduleNow := phases[0].Effective.IsZero()
cancelNow := scheduleNow && len(phases[0].Features) == 0
scheduleNow := p.Phases[0].Effective.IsZero()
cancelNow := scheduleNow && len(p.Phases[0].Features) == 0

if cancelNow && len(phases) > 1 {
if cancelNow && len(p.Phases) > 1 {
return errors.New("tier: a cancel phase must be the final phase")
}

Expand All @@ -522,7 +533,7 @@ func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err
//
// If this is a "cancel immediately" request, it returns
// ErrInvalidCancel because there is no subscription to cancel.
return c.createSchedule(ctx, org, defaultScheduleName, "", phases)
return c.createSchedule(ctx, org, defaultScheduleName, "", p)
}
if err != nil {
return err
Expand All @@ -536,7 +547,7 @@ func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err

if s.ScheduleID == "" {
// We have a subscription, but it is has no active schedule, so start a new one.
return c.createSchedule(ctx, org, defaultScheduleName, s.ID, phases)
return c.createSchedule(ctx, org, defaultScheduleName, s.ID, p)
} else {
cp, _, err := c.lookupPhases(ctx, org, s, defaultScheduleName)
if err != nil {
Expand All @@ -546,18 +557,18 @@ func (c *Client) schedule(ctx context.Context, org string, phases []Phase) (err
if cp.Valid() {
if scheduleNow {
// attach phase to current
phases[0].Effective = cp.Effective
p.Phases[0].Effective = cp.Effective
} else {
phases = append([]Phase{cp}, phases...)
p.Phases = append([]Phase{cp}, p.Phases...)
}
}

err = c.updateSchedule(ctx, s.ScheduleID, defaultScheduleName, phases)
err = c.updateSchedule(ctx, s.ScheduleID, defaultScheduleName, p)
if isReleased(err) {
// Lost a race with the clock and the schedule was
// released just after seeing it, but before our
// update.
return c.createSchedule(ctx, org, defaultScheduleName, s.ID, phases)
return c.createSchedule(ctx, org, defaultScheduleName, s.ID, p)
}
if err != nil {
return err
Expand All @@ -582,9 +593,9 @@ func isReleased(err error) bool {
// taking over any in-progress schedule. The customer is billed immediately
// with prorations if any.
func (c *Client) SubscribeTo(ctx context.Context, org string, fs []refs.FeaturePlan) error {
return c.Schedule(ctx, org, []Phase{{
Features: fs,
}})
return c.Schedule(ctx, org, ScheduleParams{
Phases: []Phase{{Features: fs}},
})

}

Expand Down Expand Up @@ -840,24 +851,26 @@ func (c *Client) Isolated() bool {
return c.Stripe.AccountID != ""
}

type stripeCustomer struct {
stripe.ID
Email string
Metadata struct {
Org string `json:"tier.org"`
}
}

func (c *Client) WhoIs(ctx context.Context, org string) (id string, err error) {
defer errorfmt.Handlef("whois: %q: %w", org, &err)
if !strings.HasPrefix(org, "org:") {
return "", &ValidationError{Message: "org must be prefixed with \"org:\""}
}

cid, err := c.cache.load(org, func() (string, error) {
type T struct {
stripe.ID
Email string
Metadata struct {
Org string `json:"tier.org"`
}
}
var f stripe.Form
cus, err := stripe.List[T](ctx, c.Stripe, "GET", "/v1/customers", f).Find(func(v T) bool {
return v.Metadata.Org == org
})
cus, err := stripe.List[stripeCustomer](ctx, c.Stripe, "GET", "/v1/customers", f).
Find(func(v stripeCustomer) bool {
return v.Metadata.Org == org
})
if err != nil {
return "", err
}
Expand Down
Loading