From 8096e3cb95aefe0d50e10843559266cdb77b8ee5 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Fri, 10 Jun 2022 17:50:55 +0200 Subject: [PATCH 1/2] Fix Stripe duplicate sub canceling. --- api/stripe.go | 18 ++++++++---------- api/stripe_test.go | 40 +++++++++++++++++++++++++++++++++++++++ build/debug_off.go | 1 + build/debug_on.go | 1 + build/release_dev.go | 1 + build/release_standard.go | 1 + build/release_testing.go | 1 + main.go | 4 ---- main_test.go | 6 ------ test/api/stripe_test.go | 1 - 10 files changed, 53 insertions(+), 21 deletions(-) create mode 100644 api/stripe_test.go diff --git a/api/stripe.go b/api/stripe.go index fc5cce11..6e5548ad 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -34,11 +34,6 @@ var ( // `https://account.` prepended to it). DashboardURL = "https://account.siasky.net" - // StripeTestMode tells us whether to use Stripe's test mode or prod mode - // plan and price ids. This depends on what kind of key is stored in the - // STRIPE_API_KEY environment variable. - StripeTestMode = false - // True is a helper for when we need to pass a *bool to Stripe. True = true @@ -122,9 +117,6 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er } // Cancel all subs aside from the latest one. p := stripe.SubscriptionCancelParams{ - Params: stripe.Params{ - StripeAccount: &s.Customer.ID, - }, InvoiceNow: &True, Prorate: &True, } @@ -329,7 +321,8 @@ func (api *API) stripeWebhookPOST(_ *database.User, w http.ResponseWriter, req * return } // Check the details about this subscription: - s, err := sub.Get(hasSub.Sub, nil) + var s *stripe.Subscription + s, err = sub.Get(hasSub.Sub, nil) if err != nil { api.staticLogger.Debugln("Webhook: Failed to fetch sub:", err) api.WriteError(w, err, http.StatusInternalServerError) @@ -365,8 +358,13 @@ func readStripeEvent(w http.ResponseWriter, req *http.Request) (*stripe.Event, i // StripePrices returns a mapping of Stripe price ids to Skynet tiers. func StripePrices() map[string]int { - if StripeTestMode { + if StripeTestMode() { return stripePricesTest } return stripePricesProd } + +// StripeTestMode tells us whether we're using a test key or a live key. +func StripeTestMode() bool { + return strings.HasPrefix(stripe.Key, "sk_test_") +} diff --git a/api/stripe_test.go b/api/stripe_test.go new file mode 100644 index 00000000..4db405ba --- /dev/null +++ b/api/stripe_test.go @@ -0,0 +1,40 @@ +package api + +import ( + "reflect" + "testing" + + "github.com/stripe/stripe-go/v72" +) + +// TestStripePrices ensures that we work with the correct set of prices. +func TestStripePrices(t *testing.T) { + // Set the Stripe key to a live key. + stripe.Key = "sk_live_FAKE_LIVE_KEY" + // Make sure we got the prod prices we expect. + if !reflect.DeepEqual(StripePrices(), stripePricesProd) { + t.Fatal("Expected prod prices, got something else.") + } + // Set the Stripe key to a test key. + stripe.Key = "sk_test_FAKE_TEST_KEY" + // Make sure we got the prod prices we expect. + if !reflect.DeepEqual(StripePrices(), stripePricesTest) { + t.Fatal("Expected test prices, got something else.") + } +} + +// TestStripeTestMode ensures that we detect test mode accurately. +func TestStripeTestMode(t *testing.T) { + // Set the Stripe key to a live key. + stripe.Key = "sk_live_FAKE_LIVE_KEY" + // Expect test mode to be off. + if StripeTestMode() { + t.Fatal("Expected live mode, got test mode.") + } + // Set the Stripe key to a test key. + stripe.Key = "sk_test_FAKE_TEST_KEY" + // Expect test mode to be on. + if !StripeTestMode() { + t.Fatal("Expected test mode, got live mode.") + } +} diff --git a/build/debug_off.go b/build/debug_off.go index edd7789e..330edcf9 100644 --- a/build/debug_off.go +++ b/build/debug_off.go @@ -1,3 +1,4 @@ +//go:build !debug // +build !debug package build diff --git a/build/debug_on.go b/build/debug_on.go index eab60313..1e6f5012 100644 --- a/build/debug_on.go +++ b/build/debug_on.go @@ -1,3 +1,4 @@ +//go:build debug // +build debug package build diff --git a/build/release_dev.go b/build/release_dev.go index 69cc259e..fda05068 100644 --- a/build/release_dev.go +++ b/build/release_dev.go @@ -1,3 +1,4 @@ +//go:build dev // +build dev package build diff --git a/build/release_standard.go b/build/release_standard.go index 39971c21..7c88a6ed 100644 --- a/build/release_standard.go +++ b/build/release_standard.go @@ -1,3 +1,4 @@ +//go:build !testing && !dev // +build !testing,!dev package build diff --git a/build/release_testing.go b/build/release_testing.go index e6aeedfb..38943e71 100644 --- a/build/release_testing.go +++ b/build/release_testing.go @@ -1,3 +1,4 @@ +//go:build testing // +build testing package build diff --git a/main.go b/main.go index 161afd42..4b713717 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "net/url" "os" "strconv" - "strings" "github.com/SkynetLabs/skynet-accounts/api" "github.com/SkynetLabs/skynet-accounts/build" @@ -70,7 +69,6 @@ type ( PortalAddressAccounts string ServerLockID string StripeKey string - StripeTestMode bool JWKSFile string JWTTTL int EmailURI string @@ -145,7 +143,6 @@ func parseConfiguration(logger *logrus.Logger) (ServiceConfig, error) { if sk := os.Getenv(envStripeAPIKey); sk != "" { config.StripeKey = sk - config.StripeTestMode = !strings.HasPrefix(sk, "sk_live_") } if jwks := os.Getenv(envAccountsJWKSFile); jwks != "" { config.JWKSFile = jwks @@ -231,7 +228,6 @@ func main() { api.DashboardURL = config.PortalAddressAccounts email.ServerLockID = config.ServerLockID stripe.Key = config.StripeKey - api.StripeTestMode = config.StripeTestMode jwt.AccountsJWKSFile = config.JWKSFile jwt.TTL = config.JWTTTL email.From = config.EmailFrom diff --git a/main_test.go b/main_test.go index 109d8916..8c0bec05 100644 --- a/main_test.go +++ b/main_test.go @@ -190,9 +190,6 @@ func TestParseConfiguration(t *testing.T) { if config.StripeKey != sk { t.Fatalf("Expected %s, got %s", sk, config.StripeKey) } - if config.StripeTestMode { - t.Fatal("Expected live mode.") - } if config.ServerLockID != serverDomain { t.Fatalf("Expected %s, got %s", serverDomain, config.ServerLockID) } @@ -227,9 +224,6 @@ func TestParseConfiguration(t *testing.T) { if config.StripeKey != sk { t.Fatalf("Expected %s, got %s", sk, config.StripeKey) } - if !config.StripeTestMode { - t.Fatal("Expected test mode.") - } if config.MaxAPIKeys != maxKeys { t.Fatalf("Expected %d, got %d", maxKeys, config.MaxAPIKeys) } diff --git a/test/api/stripe_test.go b/test/api/stripe_test.go index 9ec52465..60287eb0 100644 --- a/test/api/stripe_test.go +++ b/test/api/stripe_test.go @@ -32,7 +32,6 @@ func TestStripe(t *testing.T) { "Expected STRIPE_API_KEY that starts with '%s', got '%s'", t.Name(), "sk_test_", key[:8]) } stripe.Key = key - api.StripeTestMode = true tests := map[string]func(t *testing.T, at *test.AccountsTester){ "get billing": testStripeBillingGET, From e5fad5da3a25cb3cc61436ee735025b48c9f5683 Mon Sep 17 00:00:00 2001 From: Ivaylo Novakov Date: Mon, 20 Jun 2022 18:28:40 +0300 Subject: [PATCH 2/2] Fix variable overwriting in order to improve logged information. Add additional logging. --- api/stripe.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/api/stripe.go b/api/stripe.go index 6e5548ad..afe4ba6b 100644 --- a/api/stripe.go +++ b/api/stripe.go @@ -91,8 +91,11 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er Customer: s.Customer.ID, Status: string(stripe.SubscriptionStatusActive), }) - // Pick the latest active plan and set the user's tier based on that. subs := it.SubscriptionList().Data + if len(subs) > 1 { + api.staticLogger.Tracef("More than one active subscription detected: %+v", subs) + } + // Pick the latest active plan and set the user's tier based on that. var mostRecentSub *stripe.Subscription for _, subsc := range subs { if mostRecentSub == nil || subsc.Created > mostRecentSub.Created { @@ -128,9 +131,10 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er api.staticLogger.Warnf("Empty subscription ID! User ID '%s', Stripe ID '%s', subscription object '%+v'", u.ID.Hex(), u.StripeID, subs) continue } - subsc, err = sub.Cancel(subsc.ID, &p) + cs, err := sub.Cancel(subsc.ID, &p) if err != nil { api.staticLogger.Warnf("Failed to cancel sub with id '%s' for user '%s' with Stripe customer id '%s'. Error: '%s'", subsc.ID, u.ID.Hex(), s.Customer.ID, err.Error()) + api.staticLogger.Tracef("Sub information returned by Stripe: %+v", cs) } else { api.staticLogger.Tracef("Successfully cancelled sub with id '%s' for user '%s' with Stripe customer id '%s'.", subsc.ID, u.ID.Hex(), s.Customer.ID) }