From e6120766ec338bbe0a8dab5dd48653067670a3ba Mon Sep 17 00:00:00 2001 From: armfazh Date: Thu, 28 Jul 2022 17:53:41 -0700 Subject: [PATCH] Adds conditional move and select to group. --- go.mod | 6 ++-- go.sum | 12 ++++---- group/group.go | 4 +++ group/group_test.go | 65 +++++++++++++++++++++++++++++++++++++++++++ group/ristretto255.go | 24 ++++++++++++++++ group/short.go | 60 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 162 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 37829514d..f9de51b5e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/cloudflare/circl go 1.16 require ( - github.com/bwesterb/go-ristretto v1.2.1 - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d - golang.org/x/sys v0.0.0-20220624220833-87e55d714810 + github.com/bwesterb/go-ristretto v1.2.2 + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa + golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 ) diff --git a/go.sum b/go.sum index cb1c6f8bd..c04d92360 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,13 @@ -github.com/bwesterb/go-ristretto v1.2.1 h1:Xd9ZXmjKE2aY8Ub7+4bX7tXsIPsV1pIZaUlJUjI1toE= -github.com/bwesterb/go-ristretto v1.2.1/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +github.com/bwesterb/go-ristretto v1.2.2 h1:S2C0mmSjCLS3H9+zfXoIoKzl+cOncvBvt6pE+zTm5Ms= +github.com/bwesterb/go-ristretto v1.2.2/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220624220833-87e55d714810 h1:rHZQSjJdAI4Xf5Qzeh2bBc5YJIkPFVM6oDtMFYmgws0= -golang.org/x/sys v0.0.0-20220624220833-87e55d714810/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/group/group.go b/group/group.go index 8af94cbef..9de399c34 100644 --- a/group/group.go +++ b/group/group.go @@ -36,6 +36,8 @@ type Element interface { Copy() Element IsIdentity() bool IsEqual(Element) bool + CMov(int, Element) Element + CSelect(int, Element, Element) Element Add(Element, Element) Element Dbl(Element) Element Neg(Element) Element @@ -53,6 +55,8 @@ type Scalar interface { Copy() Scalar IsEqual(Scalar) bool SetUint64(uint64) + CMov(int, Scalar) Scalar + CSelect(int, Scalar, Scalar) Scalar Add(Scalar, Scalar) Scalar Sub(Scalar, Scalar) Scalar Mul(Scalar, Scalar) Scalar diff --git a/group/group_test.go b/group/group_test.go index 76a0b1f85..de91cfebe 100644 --- a/group/group_test.go +++ b/group/group_test.go @@ -26,6 +26,8 @@ func TestGroup(t *testing.T) { t.Run(n+"/Neg", func(tt *testing.T) { testNeg(tt, testTimes, g) }) t.Run(n+"/Mul", func(tt *testing.T) { testMul(tt, testTimes, g) }) t.Run(n+"/MulGen", func(tt *testing.T) { testMulGen(tt, testTimes, g) }) + t.Run(n+"/CMov", func(tt *testing.T) { testCMov(tt, testTimes, g) }) + t.Run(n+"/CSelect", func(tt *testing.T) { testCSelect(tt, testTimes, g) }) t.Run(n+"/Order", func(tt *testing.T) { testOrder(tt, testTimes, g) }) t.Run(n+"/Marshal", func(tt *testing.T) { testMarshal(tt, testTimes, g) }) t.Run(n+"/Scalar", func(tt *testing.T) { testScalar(tt, testTimes, g) }) @@ -101,6 +103,45 @@ func testMulGen(t *testing.T, testTimes int, g group.Group) { } } +func testCMov(t *testing.T, testTimes int, g group.Group) { + for i := 0; i < testTimes; i++ { + P := g.RandomElement(rand.Reader) + Q := g.RandomElement(rand.Reader) + + want := P.Copy() + got := P.CMov(0, Q) + if !got.IsEqual(want) { + test.ReportError(t, got, want) + } + + want = Q.Copy() + got = P.CMov(1, Q) + if !got.IsEqual(want) { + test.ReportError(t, got, want) + } + } +} + +func testCSelect(t *testing.T, testTimes int, g group.Group) { + for i := 0; i < testTimes; i++ { + P := g.RandomElement(rand.Reader) + Q := g.RandomElement(rand.Reader) + R := g.RandomElement(rand.Reader) + + want := R.Copy() + got := P.CSelect(0, Q, R) + if !got.IsEqual(want) { + test.ReportError(t, got, want) + } + + want = Q.Copy() + got = P.CSelect(1, Q, R) + if !got.IsEqual(want) { + test.ReportError(t, got, want) + } + } +} + func testOrder(t *testing.T, testTimes int, g group.Group) { Q := g.NewElement() order := g.Order() @@ -207,6 +248,30 @@ func testScalar(t *testing.T, testTimes int, g group.Group) { if l := uint(len(enc1)); l != params.ScalarLength { test.ReportError(t, l, params.ScalarLength) } + + want := c.Copy() + got := c.CMov(0, a) + if !got.IsEqual(want) { + test.ReportError(t, got, want) + } + + want = b.Copy() + got = d.CMov(1, b) + if !got.IsEqual(want) { + test.ReportError(t, got, want) + } + + want = b.Copy() + got = e.CSelect(0, a, b) + if !got.IsEqual(want) { + test.ReportError(t, got, want) + } + + want = a.Copy() + got = f.CSelect(1, a, b) + if !got.IsEqual(want) { + test.ReportError(t, got, want) + } } a := g.RandomScalar(rand.Reader) diff --git a/group/ristretto255.go b/group/ristretto255.go index c72f4e9da..5b640bd22 100644 --- a/group/ristretto255.go +++ b/group/ristretto255.go @@ -128,6 +128,8 @@ func (g ristrettoGroup) HashToScalar(msg, dst []byte) Scalar { func (e *ristrettoElement) Group() Group { return Ristretto255 } +func (e *ristrettoElement) String() string { return fmt.Sprintf("%x", e.p.Bytes()) } + func (e *ristrettoElement) IsIdentity() bool { var zero r255.Point zero.SetZero() @@ -147,6 +149,17 @@ func (e *ristrettoElement) Copy() Element { return &ristrettoElement{*new(r255.Point).Set(&e.p)} } +func (e *ristrettoElement) CMov(v int, x Element) Element { + e.p.ConditionalSet(&x.(*ristrettoElement).p, int32(v)) + return e +} + +func (e *ristrettoElement) CSelect(v int, x Element, y Element) Element { + e.p.ConditionalSet(&x.(*ristrettoElement).p, int32(v)) + e.p.ConditionalSet(&y.(*ristrettoElement).p, int32(1-v)) + return e +} + func (e *ristrettoElement) Add(x Element, y Element) Element { e.p.Add(&x.(*ristrettoElement).p, &y.(*ristrettoElement).p) return e @@ -200,6 +213,17 @@ func (s *ristrettoScalar) Copy() Scalar { return &ristrettoScalar{*new(r255.Scalar).Set(&s.s)} } +func (s *ristrettoScalar) CMov(v int, x Scalar) Scalar { + s.s.ConditionalSet(&x.(*ristrettoScalar).s, int32(v)) + return s +} + +func (s *ristrettoScalar) CSelect(v int, x Scalar, y Scalar) Scalar { + s.s.ConditionalSet(&x.(*ristrettoScalar).s, int32(v)) + s.s.ConditionalSet(&y.(*ristrettoScalar).s, int32(1-v)) + return s +} + func (s *ristrettoScalar) Add(x Scalar, y Scalar) Scalar { s.s.Add(&x.(*ristrettoScalar).s, &y.(*ristrettoScalar).s) return s diff --git a/group/short.go b/group/short.go index 4fe886ed9..edb72c964 100644 --- a/group/short.go +++ b/group/short.go @@ -141,6 +141,51 @@ func (e *wElt) Set(a Element) Element { } func (e *wElt) Copy() Element { return e.wG.zeroElement().Set(e) } + +func (e *wElt) CMov(v int, a Element) Element { + aa := e.cvtElt(a) + l := (e.wG.c.Params().BitSize + 7) / 8 + bufE := make([]byte, l) + bufA := make([]byte, l) + e.x.FillBytes(bufE) + aa.x.FillBytes(bufA) + subtle.ConstantTimeCopy(v, bufE, bufA) + e.x.SetBytes(bufE) + + e.y.FillBytes(bufE) + aa.y.FillBytes(bufA) + subtle.ConstantTimeCopy(v, bufE, bufA) + e.y.SetBytes(bufE) + + return e +} + +func (e *wElt) CSelect(v int, a Element, b Element) Element { + aa, bb := e.cvtElt(a), e.cvtElt(b) + l := (e.wG.c.Params().BitSize + 7) / 8 + bufE := make([]byte, l) + bufA := make([]byte, l) + bufB := make([]byte, l) + + e.x.FillBytes(bufE) + aa.x.FillBytes(bufA) + bb.x.FillBytes(bufB) + for i := range bufE { + bufE[i] = byte(subtle.ConstantTimeSelect(v, int(bufA[i]), int(bufB[i]))) + } + e.x.SetBytes(bufE) + + e.y.FillBytes(bufE) + aa.y.FillBytes(bufA) + bb.y.FillBytes(bufB) + for i := range bufE { + bufE[i] = byte(subtle.ConstantTimeSelect(v, int(bufA[i]), int(bufB[i]))) + } + e.y.SetBytes(bufE) + + return e +} + func (e *wElt) Add(a, b Element) Element { aa, bb := e.cvtElt(a), e.cvtElt(b) e.x, e.y = e.c.Add(aa.x, aa.y, bb.x, bb.y) @@ -244,6 +289,21 @@ func (s *wScl) Set(a Scalar) Scalar { } func (s *wScl) Copy() Scalar { return s.wG.zeroScalar().Set(s) } + +func (s *wScl) CMov(v int, a Scalar) Scalar { + aa := s.cvtScl(a) + subtle.ConstantTimeCopy(v, s.k, aa.k) + return s +} + +func (s *wScl) CSelect(v int, a Scalar, b Scalar) Scalar { + aa, bb := s.cvtScl(a), s.cvtScl(b) + for i := range s.k { + s.k[i] = byte(subtle.ConstantTimeSelect(v, int(aa.k[i]), int(bb.k[i]))) + } + return s +} + func (s *wScl) Add(a, b Scalar) Scalar { aa, bb := s.cvtScl(a), s.cvtScl(b) r := new(big.Int)