Skip to content

Commit

Permalink
Add context for google login and cancel if HTTP request is cancelled
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Feb 26, 2024
1 parent 7db7fdf commit 59b3b7d
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 33 deletions.
2 changes: 1 addition & 1 deletion commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func fnLoginGoogleCookies(ce *WrappedCommandEvent) {
return
}
ce.Redact()
err = ce.User.LoginGoogle(cookies, func(emoji string) {
err = ce.User.LoginGoogle(ce.Ctx, cookies, func(emoji string) {
ce.Reply(emoji)
})
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion libgm/gmtest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -56,7 +57,7 @@ func main() {
cli = libgm.NewClient(&sess, log)
cli.SetEventHandler(evtHandler)
if doLogin {
err = cli.DoGaiaPairing(func(emoji string) {
err = cli.DoGaiaPairing(context.TODO(), func(emoji string) {
fmt.Println(emoji)
})
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion libgm/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ const ContentTypeProtobuf = "application/x-protobuf"
const ContentTypePBLite = "application/json+protobuf"

func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, contentType string) (*http.Response, error) {
ctx := c.Logger.WithContext(context.TODO())
return c.makeProtobufHTTPRequestContext(ctx, url, data, contentType)
}

func (c *Client) makeProtobufHTTPRequestContext(ctx context.Context, url string, data proto.Message, contentType string) (*http.Response, error) {
var body []byte
var err error
switch contentType {
Expand All @@ -37,7 +42,6 @@ func (c *Client) makeProtobufHTTPRequest(url string, data proto.Message, content
if err != nil {
return nil, err
}
ctx := c.Logger.WithContext(context.TODO())
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
Expand Down
39 changes: 21 additions & 18 deletions libgm/pair_google.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package libgm

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
Expand Down Expand Up @@ -61,15 +62,15 @@ func (c *Client) baseSignInGaiaPayload() *gmproto.SignInGaiaRequest {
}
}

func (c *Client) signInGaiaInitial() (*gmproto.SignInGaiaResponse, error) {
func (c *Client) signInGaiaInitial(ctx context.Context) (*gmproto.SignInGaiaResponse, error) {
payload := c.baseSignInGaiaPayload()
payload.UnknownInt3 = 1
return typedHTTPResponse[*gmproto.SignInGaiaResponse](
c.makeProtobufHTTPRequest(util.SignInGaiaURL, payload, ContentTypePBLite),
c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite),
)
}

func (c *Client) signInGaiaGetToken() (*gmproto.SignInGaiaResponse, error) {
func (c *Client) signInGaiaGetToken(ctx context.Context) (*gmproto.SignInGaiaResponse, error) {
key, err := x509.MarshalPKIXPublicKey(c.AuthData.RefreshKey.GetPublicKey())
if err != nil {
return nil, err
Expand All @@ -80,7 +81,7 @@ func (c *Client) signInGaiaGetToken() (*gmproto.SignInGaiaResponse, error) {
SomeData: key,
}
resp, err := typedHTTPResponse[*gmproto.SignInGaiaResponse](
c.makeProtobufHTTPRequest(util.SignInGaiaURL, payload, ContentTypePBLite),
c.makeProtobufHTTPRequestContext(ctx, util.SignInGaiaURL, payload, ContentTypePBLite),
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -242,11 +243,11 @@ var (
ErrPairingTimeout = errors.New("pairing timed out")
)

func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
func (c *Client) DoGaiaPairing(ctx context.Context, emojiCallback func(string)) error {
if len(c.AuthData.Cookies) == 0 {
return ErrNoCookies
}
sigResp, err := c.signInGaiaGetToken()
sigResp, err := c.signInGaiaGetToken(ctx)
if err != nil {
return fmt.Errorf("failed to prepare gaia pairing: %w", err)
}
Expand All @@ -272,7 +273,7 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
if err != nil {
return fmt.Errorf("failed to prepare pairing payloads: %w", err)
}
serverInit, err := c.sendGaiaPairingMessage(ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, clientInit)
serverInit, err := c.sendGaiaPairingMessage(ctx, ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_INIT, clientInit)
if err != nil {
return fmt.Errorf("failed to send client init: %w", err)
}
Expand All @@ -281,7 +282,7 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
return fmt.Errorf("error processing server init: %w", err)
}
emojiCallback(pairingEmoji)
finishResp, err := c.sendGaiaPairingMessage(ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED, clientFinish)
finishResp, err := c.sendGaiaPairingMessage(ctx, ps, gmproto.ActionType_CREATE_GAIA_PAIRING_CLIENT_FINISHED, clientFinish)
if finishResp.GetFinishErrorType() != 0 {
switch finishResp.GetFinishErrorCode() {
case 5:
Expand Down Expand Up @@ -312,8 +313,8 @@ func (c *Client) DoGaiaPairing(emojiCallback func(string)) error {
return nil
}

func (c *Client) sendGaiaPairingMessage(sess PairingSession, action gmproto.ActionType, msg []byte) (*gmproto.GaiaPairingResponseContainer, error) {
resp, err := c.sessionHandler.sendMessageWithParams(SendMessageParams{
func (c *Client) sendGaiaPairingMessage(ctx context.Context, sess PairingSession, action gmproto.ActionType, msg []byte) (*gmproto.GaiaPairingResponseContainer, error) {
respCh, err := c.sessionHandler.sendAsyncMessage(SendMessageParams{
Action: action,
Data: &gmproto.GaiaPairingRequestContainer{
PairingAttemptID: sess.UUID.String(),
Expand All @@ -324,18 +325,21 @@ func (c *Client) sendGaiaPairingMessage(sess PairingSession, action gmproto.Acti
DontEncrypt: true,
CustomTTL: (300 * time.Second).Microseconds(),
MessageType: gmproto.MessageType_GAIA_2,

NoPingOnTimeout: true,
})
if err != nil {
return nil, err
}
var respDat gmproto.GaiaPairingResponseContainer
err = proto.Unmarshal(resp.Message.UnencryptedData, &respDat)
if err != nil {
return nil, err
select {
case resp := <-respCh:
var respDat gmproto.GaiaPairingResponseContainer
err = proto.Unmarshal(resp.Message.UnencryptedData, &respDat)
if err != nil {
return nil, err
}
return &respDat, nil
case <-ctx.Done():
return nil, ctx.Err()
}
return &respDat, nil
}

func (c *Client) UnpairGaia() error {
Expand All @@ -344,6 +348,5 @@ func (c *Client) UnpairGaia() error {
Data: &gmproto.RevokeGaiaPairingRequest{
PairingAttemptID: c.AuthData.PairingID.String(),
},
NoPingOnTimeout: true,
})
}
6 changes: 0 additions & 6 deletions libgm/session_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,6 @@ func (s *SessionHandler) sendMessageWithParams(params SendMessageParams) (*Incom
return nil, err
}

if params.NoPingOnTimeout {
return <-ch, nil
}

select {
case resp := <-ch:
return resp, nil
Expand Down Expand Up @@ -175,8 +171,6 @@ type SendMessageParams struct {
CustomTTL int64
DontEncrypt bool
MessageType gmproto.MessageType

NoPingOnTimeout bool
}

func (s *SessionHandler) buildMessage(params SendMessageParams) (string, proto.Message, error) {
Expand Down
8 changes: 7 additions & 1 deletion provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ func (prov *ProvisioningAPI) GoogleLoginWait(w http.ResponseWriter, r *http.Requ

log := prov.zlog.With().Str("user_id", user.MXID.String()).Str("endpoint", "login").Logger()

err := user.AsyncLoginGoogleWait()
err := user.AsyncLoginGoogleWait(r.Context())
if err != nil {
log.Err(err).Msg("Failed to wait for google login")
switch {
Expand All @@ -388,6 +388,12 @@ func (prov *ProvisioningAPI) GoogleLoginWait(w http.ResponseWriter, r *http.Requ
Error: err.Error(),
ErrCode: "timeout",
})
case errors.Is(err, context.Canceled):
// This should only happen if the client already disconnected, so clients will probably never see this error code.
jsonResponse(w, http.StatusBadRequest, Error{
Error: err.Error(),
ErrCode: "context-cancelled",
})
default:
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Failed to finish login",
Expand Down
24 changes: 19 additions & 5 deletions user.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,13 @@ func (user *User) AsyncLoginGoogleStart(cookies map[string]string) (outEmoji str
outEmoji = emoji
initialWait.Done()
}
var ctx context.Context
ctx, user.cancelLogin = context.WithCancel(context.Background())
go func() {
err := user.LoginGoogle(cookies, callback)
defer func() {
user.cancelLogin = nil
}()
err := user.LoginGoogle(ctx, cookies, callback)
if !callbackDone {
user.zlog.Err(err).Msg("Async google login failed before callback")
initialWait.Done()
Expand All @@ -505,15 +510,24 @@ func (user *User) AsyncLoginGoogleStart(cookies map[string]string) (outEmoji str
return
}

func (user *User) AsyncLoginGoogleWait() error {
func (user *User) AsyncLoginGoogleWait(ctx context.Context) error {
ch := user.googleAsyncPairErrChan.Swap(nil)
if ch == nil {
return ErrNoLoginInProgress
}
return <-*ch
select {
case ret := <-*ch:
return ret
case <-ctx.Done():
user.zlog.Err(ctx.Err()).Msg("Login wait context canceled, canceling login")
if cancelLogin := user.cancelLogin; cancelLogin != nil {
cancelLogin()
}
return ctx.Err()
}
}

func (user *User) LoginGoogle(cookies map[string]string, emojiCallback func(string)) error {
func (user *User) LoginGoogle(ctx context.Context, cookies map[string]string, emojiCallback func(string)) error {
user.connLock.Lock()
defer user.connLock.Unlock()
if user.Session != nil {
Expand All @@ -533,7 +547,7 @@ func (user *User) LoginGoogle(cookies map[string]string, emojiCallback func(stri
authData.Cookies = cookies
user.createClient(authData)
Analytics.Track(user.MXID, "$login_start")
err := user.Client.DoGaiaPairing(emojiCallback)
err := user.Client.DoGaiaPairing(ctx, emojiCallback)
if err != nil {
user.unlockedDeleteConnection()
return err
Expand Down

0 comments on commit 59b3b7d

Please sign in to comment.