From d9a0b16cdd96d0c2156c9e62349cb0e37bb1b924 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 16 Nov 2022 10:11:20 +0100 Subject: [PATCH] curve25519: use crypto/ecdh on Go 1.20 For golang/go#52221 Change-Id: I27e867d4cc89cd52c8d510f0dbab4e89b7cd4763 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/451115 Auto-Submit: Filippo Valsorda Reviewed-by: Cherry Mui TryBot-Result: Gopher Robot Run-TryBot: Filippo Valsorda Reviewed-by: Roland Shoemaker --- curve25519/curve25519.go | 99 ++---------------------------- curve25519/curve25519_compat.go | 105 ++++++++++++++++++++++++++++++++ curve25519/curve25519_go120.go | 46 ++++++++++++++ curve25519/curve25519_test.go | 28 +++++---- curve25519/vectors_test.go | 2 +- 5 files changed, 173 insertions(+), 107 deletions(-) create mode 100644 curve25519/curve25519_compat.go create mode 100644 curve25519/curve25519_go120.go diff --git a/curve25519/curve25519.go b/curve25519/curve25519.go index bc62161d6e..00f963ea20 100644 --- a/curve25519/curve25519.go +++ b/curve25519/curve25519.go @@ -5,71 +5,18 @@ // Package curve25519 provides an implementation of the X25519 function, which // performs scalar multiplication on the elliptic curve known as Curve25519. // See RFC 7748. +// +// Starting in Go 1.20, this package is a wrapper for the X25519 implementation +// in the crypto/ecdh package. package curve25519 // import "golang.org/x/crypto/curve25519" -import ( - "crypto/subtle" - "errors" - "strconv" - - "golang.org/x/crypto/curve25519/internal/field" -) - // ScalarMult sets dst to the product scalar * point. // // Deprecated: when provided a low-order point, ScalarMult will set dst to all // zeroes, irrespective of the scalar. Instead, use the X25519 function, which // will return an error. func ScalarMult(dst, scalar, point *[32]byte) { - var e [32]byte - - copy(e[:], scalar[:]) - e[0] &= 248 - e[31] &= 127 - e[31] |= 64 - - var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element - x1.SetBytes(point[:]) - x2.One() - x3.Set(&x1) - z3.One() - - swap := 0 - for pos := 254; pos >= 0; pos-- { - b := e[pos/8] >> uint(pos&7) - b &= 1 - swap ^= int(b) - x2.Swap(&x3, swap) - z2.Swap(&z3, swap) - swap = int(b) - - tmp0.Subtract(&x3, &z3) - tmp1.Subtract(&x2, &z2) - x2.Add(&x2, &z2) - z2.Add(&x3, &z3) - z3.Multiply(&tmp0, &x2) - z2.Multiply(&z2, &tmp1) - tmp0.Square(&tmp1) - tmp1.Square(&x2) - x3.Add(&z3, &z2) - z2.Subtract(&z3, &z2) - x2.Multiply(&tmp1, &tmp0) - tmp1.Subtract(&tmp1, &tmp0) - z2.Square(&z2) - - z3.Mult32(&tmp1, 121666) - x3.Square(&x3) - tmp0.Add(&tmp0, &z3) - z3.Multiply(&x1, &z2) - z2.Multiply(&tmp1, &tmp0) - } - - x2.Swap(&x3, swap) - z2.Swap(&z3, swap) - - z2.Invert(&z2) - x2.Multiply(&x2, &z2) - copy(dst[:], x2.Bytes()) + scalarMult(dst, scalar, point) } // ScalarBaseMult sets dst to the product scalar * base where base is the @@ -78,7 +25,7 @@ func ScalarMult(dst, scalar, point *[32]byte) { // It is recommended to use the X25519 function with Basepoint instead, as // copying into fixed size arrays can lead to unexpected bugs. func ScalarBaseMult(dst, scalar *[32]byte) { - ScalarMult(dst, scalar, &basePoint) + scalarBaseMult(dst, scalar) } const ( @@ -91,21 +38,10 @@ const ( // Basepoint is the canonical Curve25519 generator. var Basepoint []byte -var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} +var basePoint = [32]byte{9} func init() { Basepoint = basePoint[:] } -func checkBasepoint() { - if subtle.ConstantTimeCompare(Basepoint, []byte{ - 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }) != 1 { - panic("curve25519: global Basepoint value was modified") - } -} - // X25519 returns the result of the scalar multiplication (scalar * point), // according to RFC 7748, Section 5. scalar, point and the return value are // slices of 32 bytes. @@ -121,26 +57,3 @@ func X25519(scalar, point []byte) ([]byte, error) { var dst [32]byte return x25519(&dst, scalar, point) } - -func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) { - var in [32]byte - if l := len(scalar); l != 32 { - return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32") - } - if l := len(point); l != 32 { - return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32") - } - copy(in[:], scalar) - if &point[0] == &Basepoint[0] { - checkBasepoint() - ScalarBaseMult(dst, &in) - } else { - var base, zero [32]byte - copy(base[:], point) - ScalarMult(dst, &in, &base) - if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 { - return nil, errors.New("bad input point: low order point") - } - } - return dst[:], nil -} diff --git a/curve25519/curve25519_compat.go b/curve25519/curve25519_compat.go new file mode 100644 index 0000000000..ba647e8d77 --- /dev/null +++ b/curve25519/curve25519_compat.go @@ -0,0 +1,105 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.20 + +package curve25519 + +import ( + "crypto/subtle" + "errors" + "strconv" + + "golang.org/x/crypto/curve25519/internal/field" +) + +func scalarMult(dst, scalar, point *[32]byte) { + var e [32]byte + + copy(e[:], scalar[:]) + e[0] &= 248 + e[31] &= 127 + e[31] |= 64 + + var x1, x2, z2, x3, z3, tmp0, tmp1 field.Element + x1.SetBytes(point[:]) + x2.One() + x3.Set(&x1) + z3.One() + + swap := 0 + for pos := 254; pos >= 0; pos-- { + b := e[pos/8] >> uint(pos&7) + b &= 1 + swap ^= int(b) + x2.Swap(&x3, swap) + z2.Swap(&z3, swap) + swap = int(b) + + tmp0.Subtract(&x3, &z3) + tmp1.Subtract(&x2, &z2) + x2.Add(&x2, &z2) + z2.Add(&x3, &z3) + z3.Multiply(&tmp0, &x2) + z2.Multiply(&z2, &tmp1) + tmp0.Square(&tmp1) + tmp1.Square(&x2) + x3.Add(&z3, &z2) + z2.Subtract(&z3, &z2) + x2.Multiply(&tmp1, &tmp0) + tmp1.Subtract(&tmp1, &tmp0) + z2.Square(&z2) + + z3.Mult32(&tmp1, 121666) + x3.Square(&x3) + tmp0.Add(&tmp0, &z3) + z3.Multiply(&x1, &z2) + z2.Multiply(&tmp1, &tmp0) + } + + x2.Swap(&x3, swap) + z2.Swap(&z3, swap) + + z2.Invert(&z2) + x2.Multiply(&x2, &z2) + copy(dst[:], x2.Bytes()) +} + +func scalarBaseMult(dst, scalar *[32]byte) { + checkBasepoint() + scalarMult(dst, scalar, &basePoint) +} + +func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) { + var in [32]byte + if l := len(scalar); l != 32 { + return nil, errors.New("bad scalar length: " + strconv.Itoa(l) + ", expected 32") + } + if l := len(point); l != 32 { + return nil, errors.New("bad point length: " + strconv.Itoa(l) + ", expected 32") + } + copy(in[:], scalar) + if &point[0] == &Basepoint[0] { + scalarBaseMult(dst, &in) + } else { + var base, zero [32]byte + copy(base[:], point) + scalarMult(dst, &in, &base) + if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 { + return nil, errors.New("bad input point: low order point") + } + } + return dst[:], nil +} + +func checkBasepoint() { + if subtle.ConstantTimeCompare(Basepoint, []byte{ + 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }) != 1 { + panic("curve25519: global Basepoint value was modified") + } +} diff --git a/curve25519/curve25519_go120.go b/curve25519/curve25519_go120.go new file mode 100644 index 0000000000..627df49727 --- /dev/null +++ b/curve25519/curve25519_go120.go @@ -0,0 +1,46 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.20 + +package curve25519 + +import "crypto/ecdh" + +func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) { + curve := ecdh.X25519() + pub, err := curve.NewPublicKey(point) + if err != nil { + return nil, err + } + priv, err := curve.NewPrivateKey(scalar) + if err != nil { + return nil, err + } + out, err := priv.ECDH(pub) + if err != nil { + return nil, err + } + copy(dst[:], out) + return dst[:], nil +} + +func scalarMult(dst, scalar, point *[32]byte) { + if _, err := x25519(dst, scalar[:], point[:]); err != nil { + // The only error condition for x25519 when the inputs are 32 bytes long + // is if the output would have been the all-zero value. + for i := range dst { + dst[i] = 0 + } + } +} + +func scalarBaseMult(dst, scalar *[32]byte) { + curve := ecdh.X25519() + priv, err := curve.NewPrivateKey(scalar[:]) + if err != nil { + panic("curve25519: internal error: scalarBaseMult was not 32 bytes") + } + copy(dst[:], priv.PublicKey().Bytes()) +} diff --git a/curve25519/curve25519_test.go b/curve25519/curve25519_test.go index 5a315416f2..e2b338b5ec 100644 --- a/curve25519/curve25519_test.go +++ b/curve25519/curve25519_test.go @@ -2,13 +2,15 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package curve25519 +package curve25519_test import ( "bytes" "crypto/rand" "encoding/hex" "testing" + + "golang.org/x/crypto/curve25519" ) const expectedHex = "89161fde887b2b53de549af483940106ecc114d6982daa98256de23bdf77661a" @@ -19,7 +21,7 @@ func TestX25519Basepoint(t *testing.T) { for i := 0; i < 200; i++ { var err error - x, err = X25519(x, Basepoint) + x, err = curve25519.X25519(x, curve25519.Basepoint) if err != nil { t.Fatal(err) } @@ -32,12 +34,12 @@ func TestX25519Basepoint(t *testing.T) { } func TestLowOrderPoints(t *testing.T) { - scalar := make([]byte, ScalarSize) + scalar := make([]byte, curve25519.ScalarSize) if _, err := rand.Read(scalar); err != nil { t.Fatal(err) } for i, p := range lowOrderPoints { - out, err := X25519(scalar, p) + out, err := curve25519.X25519(scalar, p) if err == nil { t.Errorf("%d: expected error, got nil", i) } @@ -48,10 +50,10 @@ func TestLowOrderPoints(t *testing.T) { } func TestTestVectors(t *testing.T) { - t.Run("Legacy", func(t *testing.T) { testTestVectors(t, ScalarMult) }) + t.Run("Legacy", func(t *testing.T) { testTestVectors(t, curve25519.ScalarMult) }) t.Run("X25519", func(t *testing.T) { testTestVectors(t, func(dst, scalar, point *[32]byte) { - out, err := X25519(scalar[:], point[:]) + out, err := curve25519.X25519(scalar[:], point[:]) if err != nil { t.Fatal(err) } @@ -88,10 +90,10 @@ func TestHighBitIgnored(t *testing.T) { var hi0, hi1 [32]byte u[31] &= 0x7f - ScalarMult(&hi0, &s, &u) + curve25519.ScalarMult(&hi0, &s, &u) u[31] |= 0x80 - ScalarMult(&hi1, &s, &u) + curve25519.ScalarMult(&hi1, &s, &u) if !bytes.Equal(hi0[:], hi1[:]) { t.Errorf("high bit of group point should not affect result") @@ -101,14 +103,14 @@ func TestHighBitIgnored(t *testing.T) { var benchmarkSink byte func BenchmarkX25519Basepoint(b *testing.B) { - scalar := make([]byte, ScalarSize) + scalar := make([]byte, curve25519.ScalarSize) if _, err := rand.Read(scalar); err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { - out, err := X25519(scalar, Basepoint) + out, err := curve25519.X25519(scalar, curve25519.Basepoint) if err != nil { b.Fatal(err) } @@ -117,11 +119,11 @@ func BenchmarkX25519Basepoint(b *testing.B) { } func BenchmarkX25519(b *testing.B) { - scalar := make([]byte, ScalarSize) + scalar := make([]byte, curve25519.ScalarSize) if _, err := rand.Read(scalar); err != nil { b.Fatal(err) } - point, err := X25519(scalar, Basepoint) + point, err := curve25519.X25519(scalar, curve25519.Basepoint) if err != nil { b.Fatal(err) } @@ -131,7 +133,7 @@ func BenchmarkX25519(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - out, err := X25519(scalar, point) + out, err := curve25519.X25519(scalar, point) if err != nil { b.Fatal(err) } diff --git a/curve25519/vectors_test.go b/curve25519/vectors_test.go index 946e9a8a3d..f4c0a1414f 100644 --- a/curve25519/vectors_test.go +++ b/curve25519/vectors_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package curve25519 +package curve25519_test // lowOrderPoints from libsodium. // https://github.com/jedisct1/libsodium/blob/65621a1059a37d/src/libsodium/crypto_scalarmult/curve25519/ref10/x25519_ref10.c#L11-L70