diff --git a/api/routes.go b/api/routes.go index 0692dfb2..dfcda593 100644 --- a/api/routes.go +++ b/api/routes.go @@ -77,7 +77,9 @@ func (api *API) buildHTTPRoutes() { api.staticRouter.POST("/user/recover/request", api.WithDBSession(api.noAuth(api.userRecoverRequestPOST))) api.staticRouter.POST("/user/recover", api.WithDBSession(api.noAuth(api.userRecoverPOST))) - api.staticRouter.POST("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingPOST, false))) + api.staticRouter.GET("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingHANDLER, false))) + // `POST /stripe/billing` is deprecated. Please use `GET /stripe/billing`. + api.staticRouter.POST("/stripe/billing", api.WithDBSession(api.withAuth(api.stripeBillingHANDLER, false))) api.staticRouter.POST("/stripe/checkout", api.WithDBSession(api.withAuth(api.stripeCheckoutPOST, false))) api.staticRouter.GET("/stripe/prices", api.noAuth(api.stripePricesGET)) api.staticRouter.POST("/stripe/webhook", api.WithDBSession(api.noAuth(api.stripeWebhookPOST))) diff --git a/api/stripe.go b/api/stripe.go index 965c8482..fc5cce11 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -152,10 +152,10 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er return err } -// stripeBillingPOST creates a new billing session for the user and redirects +// stripeBillingHANDLER creates a new billing session for the user and redirects // them to it. If the user does not yet have a Stripe customer, one is // registered for them. -func (api *API) stripeBillingPOST(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { +func (api *API) stripeBillingHANDLER(u *database.User, w http.ResponseWriter, req *http.Request, _ httprouter.Params) { if u.StripeID == "" { id, err := api.stripeCreateCustomer(req.Context(), u) if err != nil { diff --git a/test/api/stripe_test.go b/test/api/stripe_test.go index 25ede319..9ec52465 100644 --- a/test/api/stripe_test.go +++ b/test/api/stripe_test.go @@ -35,6 +35,7 @@ func TestStripe(t *testing.T) { api.StripeTestMode = true tests := map[string]func(t *testing.T, at *test.AccountsTester){ + "get billing": testStripeBillingGET, "post billing": testStripeBillingPOST, "get prices": testStripePricesGET, "post checkout": testStripeCheckoutPOST, @@ -52,6 +53,37 @@ func TestStripe(t *testing.T) { } } +// testStripeBillingGET ensures that we can create a new billing session. +func testStripeBillingGET(t *testing.T, at *test.AccountsTester) { + name := test.DBNameForTest(t.Name()) + r, _, err := at.UserPOST(name+"@siasky.net", name+"pass") + if err != nil { + t.Fatal(err) + } + c := test.ExtractCookie(r) + + at.SetFollowRedirects(false) + + // Try to start a billing session without valid user auth. + at.ClearCredentials() + _, s, err := at.StripeBillingGET() + if err == nil || s != http.StatusUnauthorized { + t.Fatalf("Expected 401 Unauthorized, got %d %s", s, err) + } + // Try with a valid user. Expect a temporary redirect error. This is not a + // fail case, we expect that to happen. In production we'll follow that + // redirect. + at.SetCookie(c) + h, s, err := at.StripeBillingGET() + if err != nil || s != http.StatusTemporaryRedirect { + t.Fatalf("Expected %d and no error, got %d '%s'", http.StatusTemporaryRedirect, s, err) + } + expectedRedirectPrefix := "https://billing.stripe.com/session/" + if !strings.HasPrefix(h.Get("Location"), expectedRedirectPrefix) { + t.Fatalf("Expected a redirect link with prefix '%s', got '%s'", expectedRedirectPrefix, h.Get("Location")) + } +} + // testStripeBillingPOST ensures that we can create a new billing session. func testStripeBillingPOST(t *testing.T, at *test.AccountsTester) { name := test.DBNameForTest(t.Name()) diff --git a/test/database/user_test.go b/test/database/user_test.go index 1926f1f6..7918426f 100644 --- a/test/database/user_test.go +++ b/test/database/user_test.go @@ -265,7 +265,7 @@ func TestUserConfirmEmail(t *testing.T) { t.Fatal("Failed to generate a token.") } // Set the expiration of the token in the past. - u.EmailConfirmationTokenExpiration = time.Now().UTC().Add(-time.Minute) + u.EmailConfirmationTokenExpiration = time.Now().UTC().Add(-time.Minute).Truncate(time.Millisecond) err = db.UserSave(ctx, u) if err != nil { t.Fatal("Failed to save the user:", err) diff --git a/test/tester.go b/test/tester.go index c6bae60a..999cc12d 100644 --- a/test/tester.go +++ b/test/tester.go @@ -668,6 +668,17 @@ func (at *AccountsTester) UploadInfo(sl string) ([]api.UploadInfo, int, error) { /*** Stripe helpers ***/ +// StripeBillingGET performs a `GET /stripe/billing` +func (at *AccountsTester) StripeBillingGET() (http.Header, int, error) { + r, err := at.Request(http.MethodGet, "/stripe/billing", nil, nil, nil, nil) + // We ignore the temporary redirect error because it's the expected + // behaviour of this endpoint. + if err != nil && !strings.Contains(err.Error(), "307 Temporary Redirect") { + return nil, r.StatusCode, err + } + return r.Header, r.StatusCode, nil +} + // StripeBillingPOST performs a `POST /stripe/billing` func (at *AccountsTester) StripeBillingPOST() (http.Header, int, error) { r, err := at.Request(http.MethodPost, "/stripe/billing", nil, nil, nil, nil)