Skip to content

Commit

Permalink
Merge pull request #221 from SkynetLabs/ivo/fix_stripe_subs_cancelling
Browse files Browse the repository at this point in the history
Fix Stripe duplicate sub canceling.
  • Loading branch information
ro-tex committed Jun 23, 2022
2 parents 3b240ee + e5fad5d commit d15e4dc
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 23 deletions.
26 changes: 14 additions & 12 deletions api/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ var (
// `https://account.` prepended to it).
DashboardURL = "https://account.siasky.net"

// StripeTestMode tells us whether to use Stripe's test mode or prod mode
// plan and price ids. This depends on what kind of key is stored in the
// STRIPE_API_KEY environment variable.
StripeTestMode = false

// True is a helper for when we need to pass a *bool to Stripe.
True = true

Expand Down Expand Up @@ -96,8 +91,11 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
Customer: s.Customer.ID,
Status: string(stripe.SubscriptionStatusActive),
})
// Pick the latest active plan and set the user's tier based on that.
subs := it.SubscriptionList().Data
if len(subs) > 1 {
api.staticLogger.Tracef("More than one active subscription detected: %+v", subs)
}
// Pick the latest active plan and set the user's tier based on that.
var mostRecentSub *stripe.Subscription
for _, subsc := range subs {
if mostRecentSub == nil || subsc.Created > mostRecentSub.Created {
Expand All @@ -122,9 +120,6 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
}
// Cancel all subs aside from the latest one.
p := stripe.SubscriptionCancelParams{
Params: stripe.Params{
StripeAccount: &s.Customer.ID,
},
InvoiceNow: &True,
Prorate: &True,
}
Expand All @@ -136,9 +131,10 @@ func (api *API) processStripeSub(ctx context.Context, s *stripe.Subscription) er
api.staticLogger.Warnf("Empty subscription ID! User ID '%s', Stripe ID '%s', subscription object '%+v'", u.ID.Hex(), u.StripeID, subs)
continue
}
subsc, err = sub.Cancel(subsc.ID, &p)
cs, err := sub.Cancel(subsc.ID, &p)
if err != nil {
api.staticLogger.Warnf("Failed to cancel sub with id '%s' for user '%s' with Stripe customer id '%s'. Error: '%s'", subsc.ID, u.ID.Hex(), s.Customer.ID, err.Error())
api.staticLogger.Tracef("Sub information returned by Stripe: %+v", cs)
} else {
api.staticLogger.Tracef("Successfully cancelled sub with id '%s' for user '%s' with Stripe customer id '%s'.", subsc.ID, u.ID.Hex(), s.Customer.ID)
}
Expand Down Expand Up @@ -329,7 +325,8 @@ func (api *API) stripeWebhookPOST(_ *database.User, w http.ResponseWriter, req *
return
}
// Check the details about this subscription:
s, err := sub.Get(hasSub.Sub, nil)
var s *stripe.Subscription
s, err = sub.Get(hasSub.Sub, nil)
if err != nil {
api.staticLogger.Debugln("Webhook: Failed to fetch sub:", err)
api.WriteError(w, err, http.StatusInternalServerError)
Expand Down Expand Up @@ -365,8 +362,13 @@ func readStripeEvent(w http.ResponseWriter, req *http.Request) (*stripe.Event, i

// StripePrices returns a mapping of Stripe price ids to Skynet tiers.
func StripePrices() map[string]int {
if StripeTestMode {
if StripeTestMode() {
return stripePricesTest
}
return stripePricesProd
}

// StripeTestMode tells us whether we're using a test key or a live key.
func StripeTestMode() bool {
return strings.HasPrefix(stripe.Key, "sk_test_")
}
40 changes: 40 additions & 0 deletions api/stripe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package api

import (
"reflect"
"testing"

"github.com/stripe/stripe-go/v72"
)

// TestStripePrices ensures that we work with the correct set of prices.
func TestStripePrices(t *testing.T) {
// Set the Stripe key to a live key.
stripe.Key = "sk_live_FAKE_LIVE_KEY"
// Make sure we got the prod prices we expect.
if !reflect.DeepEqual(StripePrices(), stripePricesProd) {
t.Fatal("Expected prod prices, got something else.")
}
// Set the Stripe key to a test key.
stripe.Key = "sk_test_FAKE_TEST_KEY"
// Make sure we got the prod prices we expect.
if !reflect.DeepEqual(StripePrices(), stripePricesTest) {
t.Fatal("Expected test prices, got something else.")
}
}

// TestStripeTestMode ensures that we detect test mode accurately.
func TestStripeTestMode(t *testing.T) {
// Set the Stripe key to a live key.
stripe.Key = "sk_live_FAKE_LIVE_KEY"
// Expect test mode to be off.
if StripeTestMode() {
t.Fatal("Expected live mode, got test mode.")
}
// Set the Stripe key to a test key.
stripe.Key = "sk_test_FAKE_TEST_KEY"
// Expect test mode to be on.
if !StripeTestMode() {
t.Fatal("Expected test mode, got live mode.")
}
}
4 changes: 0 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net/url"
"os"
"strconv"
"strings"

"github.com/SkynetLabs/skynet-accounts/api"
"github.com/SkynetLabs/skynet-accounts/build"
Expand Down Expand Up @@ -70,7 +69,6 @@ type (
PortalAddressAccounts string
ServerLockID string
StripeKey string
StripeTestMode bool
JWKSFile string
JWTTTL int
EmailURI string
Expand Down Expand Up @@ -145,7 +143,6 @@ func parseConfiguration(logger *logrus.Logger) (ServiceConfig, error) {

if sk := os.Getenv(envStripeAPIKey); sk != "" {
config.StripeKey = sk
config.StripeTestMode = !strings.HasPrefix(sk, "sk_live_")
}
if jwks := os.Getenv(envAccountsJWKSFile); jwks != "" {
config.JWKSFile = jwks
Expand Down Expand Up @@ -231,7 +228,6 @@ func main() {
api.DashboardURL = config.PortalAddressAccounts
email.ServerLockID = config.ServerLockID
stripe.Key = config.StripeKey
api.StripeTestMode = config.StripeTestMode
jwt.AccountsJWKSFile = config.JWKSFile
jwt.TTL = config.JWTTTL
email.From = config.EmailFrom
Expand Down
6 changes: 0 additions & 6 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,6 @@ func TestParseConfiguration(t *testing.T) {
if config.StripeKey != sk {
t.Fatalf("Expected %s, got %s", sk, config.StripeKey)
}
if config.StripeTestMode {
t.Fatal("Expected live mode.")
}
if config.ServerLockID != serverDomain {
t.Fatalf("Expected %s, got %s", serverDomain, config.ServerLockID)
}
Expand Down Expand Up @@ -227,9 +224,6 @@ func TestParseConfiguration(t *testing.T) {
if config.StripeKey != sk {
t.Fatalf("Expected %s, got %s", sk, config.StripeKey)
}
if !config.StripeTestMode {
t.Fatal("Expected test mode.")
}
if config.MaxAPIKeys != maxKeys {
t.Fatalf("Expected %d, got %d", maxKeys, config.MaxAPIKeys)
}
Expand Down
1 change: 0 additions & 1 deletion test/api/stripe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ func TestStripe(t *testing.T) {
"Expected STRIPE_API_KEY that starts with '%s', got '%s'", t.Name(), "sk_test_", key[:8])
}
stripe.Key = key
api.StripeTestMode = true

tests := map[string]func(t *testing.T, at *test.AccountsTester){
"get billing": testStripeBillingGET,
Expand Down

0 comments on commit d15e4dc

Please sign in to comment.