Skip to content

Commit

Permalink
fix(exchange): add some nonce checks
Browse files Browse the repository at this point in the history
  • Loading branch information
tdakkota committed Feb 17, 2021
1 parent 9bc6d57 commit 39fc824
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 38 deletions.
27 changes: 21 additions & 6 deletions internal/exchange/client_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ import (
"math/big"

"go.uber.org/zap"

"github.com/gotd/td/internal/proto"

"golang.org/x/xerrors"

"github.com/gotd/td/bin"
"github.com/gotd/td/internal/crypto"
"github.com/gotd/td/internal/mt"
"github.com/gotd/td/internal/proto"
)

// Run runs client-side flow.
Expand All @@ -41,6 +39,7 @@ func (c ClientExchange) Run(ctx context.Context) (ClientExchangeResult, error) {
if res.Nonce != nonce {
return ClientExchangeResult{}, xerrors.New("ResPQ nonce mismatch")
}
serverNonce := res.ServerNonce

// Selecting first public key that match fingerprint.
var selectedPubKey *rsa.PublicKey
Expand Down Expand Up @@ -89,7 +88,7 @@ Loop:
Pq: res.Pq,
Nonce: nonce,
NewNonce: newNonce,
ServerNonce: res.ServerNonce,
ServerNonce: serverNonce,
P: pBytes,
Q: qBytes,
}
Expand All @@ -105,7 +104,7 @@ Loop:
}
reqDHParams := &mt.ReqDHParamsRequest{
Nonce: nonce,
ServerNonce: res.ServerNonce,
ServerNonce: serverNonce,
P: pBytes,
Q: qBytes,
PublicKeyFingerprint: crypto.RSAFingerprint(selectedPubKey),
Expand Down Expand Up @@ -138,8 +137,11 @@ Loop:
if p.Nonce != nonce {
return ClientExchangeResult{}, xerrors.New("ServerDHParamsOk nonce mismatch")
}
if p.ServerNonce != serverNonce {
return ClientExchangeResult{}, xerrors.New("ServerDHParamsOk server nonce mismatch")
}

key, iv := crypto.TempAESKeys(newNonce.BigInt(), res.ServerNonce.BigInt())
key, iv := crypto.TempAESKeys(newNonce.BigInt(), serverNonce.BigInt())
// Decrypting inner data.
data, err := crypto.DecryptExchangeAnswer(p.EncryptedAnswer, key, iv)
if err != nil {
Expand All @@ -151,6 +153,12 @@ Loop:
if err := innerData.Decode(b); err != nil {
return ClientExchangeResult{}, err
}
if innerData.Nonce != nonce {
return ClientExchangeResult{}, xerrors.New("ServerDHInnerData nonce mismatch")
}
if innerData.ServerNonce != serverNonce {
return ClientExchangeResult{}, xerrors.New("ServerDHInnerData server nonce mismatch")
}

dhPrime := big.NewInt(0).SetBytes(innerData.DhPrime)
g := big.NewInt(int64(innerData.G))
Expand Down Expand Up @@ -215,6 +223,13 @@ Loop:
}
switch v := dhSetRes.(type) {
case *mt.DhGenOk: // dh_gen_ok#3bcbf734
if v.Nonce != nonce {
return ClientExchangeResult{}, xerrors.New("DhGenOk nonce mismatch")
}
if v.ServerNonce != serverNonce {
return ClientExchangeResult{}, xerrors.New("DhGenOk server nonce mismatch")
}

var key crypto.Key
authKey.FillBytes(key[:])
authKeyID := key.ID()
Expand Down
47 changes: 47 additions & 0 deletions internal/exchange/client_flow_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package exchange

import (
"context"
"crypto/rsa"
"math/rand"
"net"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"

"github.com/gotd/td/internal/tdsync"
"github.com/gotd/td/transport"
)

func TestExchangeTimeout(t *testing.T) {
a := require.New(t)

reader := rand.New(rand.NewSource(1))
key, err := rsa.GenerateKey(reader, 2048)
a.NoError(err)
log := zaptest.NewLogger(t)

i := transport.Intermediate(nil)
client, _ := i.Pipe()

ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

grp := tdsync.NewCancellableGroup(ctx)
grp.Go(func(groupCtx context.Context) error {
_, err := NewExchanger(client).
WithLogger(log.Named("client")).
WithRand(reader).
WithTimeout(1 * time.Second).
Client([]*rsa.PublicKey{&key.PublicKey}).
Run(groupCtx)
return err
})

err = grp.Wait()
if err, ok := err.(net.Error); !ok || !err.Timeout() {
require.NoError(t, err)
}
}
32 changes: 0 additions & 32 deletions internal/exchange/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/rsa"
"fmt"
"math/rand"
"net"
"testing"
"time"

Expand Down Expand Up @@ -53,37 +52,6 @@ func TestExchange(t *testing.T) {
require.NoError(t, grp.Wait())
}

func TestExchangeTimeout(t *testing.T) {
a := require.New(t)

reader := rand.New(rand.NewSource(1))
key, err := rsa.GenerateKey(reader, 2048)
a.NoError(err)
log := zaptest.NewLogger(t)

i := transport.Intermediate(nil)
client, _ := i.Pipe()

ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

grp := tdsync.NewCancellableGroup(ctx)
grp.Go(func(groupCtx context.Context) error {
_, err := NewExchanger(client).
WithLogger(log.Named("client")).
WithRand(reader).
WithTimeout(1 * time.Second).
Client([]*rsa.PublicKey{&key.PublicKey}).
Run(groupCtx)
return err
})

err = grp.Wait()
if err, ok := err.(net.Error); !ok || !err.Timeout() {
require.NoError(t, err)
}
}

func TestExchangeCorpus(t *testing.T) {
k := testutil.RSAPrivateKey()

Expand Down

0 comments on commit 39fc824

Please sign in to comment.