Skip to content

Commit

Permalink
refactor(ring): newring method takes a pool as an argument
Browse files Browse the repository at this point in the history
BREAKING CHANGE: the method ring.NewRing has now a mandatory argument of
type structs.BufferPool[*[]uint64]
  • Loading branch information
lehugueni committed Jan 20, 2025
1 parent d8213b6 commit 9074893
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 38 deletions.
1 change: 0 additions & 1 deletion core/rlwe/ciphertext.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,4 @@ func RecycleCiphertextInUintPool(params ParameterProvider, ct *Ciphertext) {
ringQ.RecyclePolyInUintPool(&ct.Value[i])
}
ct = nil
return
}
2 changes: 1 addition & 1 deletion examples/singleparty/bgv_vectorized_ole/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func newvOLErings(params parameters) *vOLErings {
panic(err)
}

if rings.ringQ, err = ring.NewRing(N, primes); err != nil {
if rings.ringQ, err = ring.NewRing(N, primes, nil); err != nil {
panic(err)
}

Expand Down
2 changes: 1 addition & 1 deletion ring/interpolation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Interpolator struct {
func NewInterpolator(degree int, T uint64) (itp *Interpolator, err error) {
itp = new(Interpolator)

if itp.r, err = NewRing(1<<bits.Len64(uint64(degree)), []uint64{T}); err != nil {
if itp.r, err = NewRing(1<<bits.Len64(uint64(degree)), []uint64{T}, nil); err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions ring/ntt_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func BenchmarkNTT(b *testing.B) {

func benchNTT(LogN, Qi int, b *testing.B) {
b.Run(fmt.Sprintf("Forward/N=%d/Qi=%d", 1<<LogN, Qi), func(b *testing.B) {
r, err := NewRing(1<<LogN, Qi60[:Qi])
r, err := NewRing(1<<LogN, Qi60[:Qi], nil)
if err != nil {
b.Fatal(err)
}
Expand All @@ -46,7 +46,7 @@ func benchNTT(LogN, Qi int, b *testing.B) {

func benchINTT(LogN, Qi int, b *testing.B) {
b.Run(fmt.Sprintf("Backward/N=%d/Qi=%d", 1<<LogN, Qi), func(b *testing.B) {
r, err := NewRing(1<<LogN, Qi60[:Qi])
r, err := NewRing(1<<LogN, Qi60[:Qi], nil)
if err != nil {
b.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion ring/ntt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestNTT(t *testing.T) {

for _, tv := range testVector[:] {

ringQ, err := NewRing(tv.N, tv.Qis)
ringQ, err := NewRing(tv.N, tv.Qis, nil)

if err != nil {
t.Fatal(err)
Expand Down
26 changes: 16 additions & 10 deletions ring/ring.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,10 @@ func (r Ring) BRedConstants() (BRC [][2]uint64) {

// NewRing creates a new RNS Ring with degree N and coefficient moduli Moduli with Standard NTT. N must be a power of two larger than 8. Moduli should be
// a non-empty []uint64 with distinct prime elements. All moduli must also be equal to 1 modulo 2*N.
// A pool implementing BufferPool[*[]uint64] will be stored in the returned Ring and will be used to efficiently instantiate large objects.
// An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRing(N int, Moduli []uint64, pool ...structs.BufferPool[*[]uint64]) (r *Ring, err error) {
var bp structs.BufferPool[*[]uint64]
switch len(pool) {
case 0:
case 1:
bp = pool[0]
default:
return nil, fmt.Errorf("cannot create new ring: more than 1 buffer pools provided")
}
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N, bp)
func NewRing(N int, Moduli []uint64, pool structs.BufferPool[*[]uint64]) (r *Ring, err error) {
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N, pool)
}

// NewRingConjugateInvariant creates a new RNS Ring with degree N and coefficient moduli Moduli with Conjugate Invariant NTT. N must be a power of two larger than 8. Moduli should be
Expand Down Expand Up @@ -334,6 +327,19 @@ func NewRingWithCustomNTT(N int, ModuliChain []uint64, ntt func(*SubRing, int) N

r.bufferPool = pool

if r.bufferPool != nil { // Check that provided pool returns slices of length N
arr := r.bufferPool.Get()
if len(*arr) != N {
return nil, fmt.Errorf("invalid pool: pool must return []uint64 of length=%d != %d", N, len(*arr))
}
r.bufferPool.Put(arr)
} else { // If no pool provided: create one
r.bufferPool = structs.NewSyncPool(func() *[]uint64 {
arr := make([]uint64, N)
return &arr
})
}

return r, r.generateNTTConstants(nil, nil)
}

Expand Down
2 changes: 1 addition & 1 deletion ring/ring_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func benchGenRing(tc *testParams, b *testing.B) {

b.Run(testString("GenRing", tc.ringQ), func(b *testing.B) {
for i := 0; i < b.N; i++ {
if _, err := NewRing(tc.ringQ.N(), tc.ringQ.ModuliChain()); err != nil {
if _, err := NewRing(tc.ringQ.N(), tc.ringQ.ModuliChain(), tc.ringQ.bufferPool); err != nil {
b.Error(err)
}
}
Expand Down
47 changes: 29 additions & 18 deletions ring/ring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func testNTTConjugateInvariant(tc *testParams, t *testing.T) {
ringQ := tc.ringQ
Q := ringQ.ModuliChain()
N := ringQ.N()
ringQ2N, _ := NewRing(N<<1, Q)
ringQ2N, _ := NewRing(N<<1, Q, nil)
ringQConjugateInvariant, _ := NewRingFromType(N, Q, ConjugateInvariant, nil)

sampler := NewUniformSampler(tc.prng, ringQ)
Expand Down Expand Up @@ -131,42 +131,53 @@ func testNTTConjugateInvariant(tc *testParams, t *testing.T) {

func testNewRing(t *testing.T) {
t.Run("NewRing", func(t *testing.T) {
r, err := NewRing(0, nil)
r, err := NewRing(0, nil, nil)
require.Nil(t, r)
require.Error(t, err)

r, err = NewRing(0, []uint64{})
r, err = NewRing(0, []uint64{}, nil)
require.Nil(t, r)
require.Error(t, err)

r, err = NewRing(4, []uint64{})
r, err = NewRing(4, []uint64{}, nil)
require.Nil(t, r)
require.Error(t, err)

r, err = NewRing(8, []uint64{})
r, err = NewRing(8, []uint64{}, nil)
require.Nil(t, r)
require.Error(t, err)

r, err = NewRing(16, []uint64{7}) // Passing non NTT-enabling coeff modulus
require.NotNil(t, r) // Should still return a Ring instance
require.Error(t, err) // Should also return an error due to non NTT
r, err = NewRing(16, []uint64{7}, nil) // Passing non NTT-enabling coeff modulus
require.NotNil(t, r) // Should still return a Ring instance
require.Error(t, err) // Should also return an error due to non NTT

r, err = NewRing(16, []uint64{4}) // Passing non prime moduli
require.NotNil(t, r) // Should still return a Ring instance
require.Error(t, err) // Should also return an error due to non NTT
r, err = NewRing(16, []uint64{4}, nil) // Passing non prime moduli
require.NotNil(t, r) // Should still return a Ring instance
require.Error(t, err) // Should also return an error due to non NTT

r, err = NewRing(16, []uint64{97, 7}) // Passing a NTT-enabling and a non NTT-enabling coeff modulus
require.NotNil(t, r) // Should still return a Ring instance
require.Error(t, err) // Should also return an error due to non NTT
r, err = NewRing(16, []uint64{97, 7}, nil) // Passing a NTT-enabling and a non NTT-enabling coeff modulus
require.NotNil(t, r) // Should still return a Ring instance
require.Error(t, err) // Should also return an error due to non NTT

r, err = NewRing(16, []uint64{97, 97}) // Passing non CRT-enabling coeff modulus
require.Nil(t, r) // Should not return a Ring instance
r, err = NewRing(16, []uint64{97, 97}, nil) // Passing non CRT-enabling coeff modulus
require.Nil(t, r) // Should not return a Ring instance
require.Error(t, err)

r, err = NewRing(16, []uint64{97}) // Passing NTT-enabling coeff modulus
r, err = NewRing(16, []uint64{97}, nil) // Passing NTT-enabling coeff modulus
require.NotNil(t, r)
require.NoError(t, err)

pool := structs.NewSyncPool(func() *[]uint64 {
arr := make([]uint64, 16)
return &arr
})

r, err = NewRing(16, []uint64{97}, pool) // Passing NTT-enabling coeff modulus
require.NotNil(t, r)
require.NoError(t, err)
r, err = NewRing(32, []uint64{97}, pool) // Passing NTT-enabling coeff modulus
require.Nil(t, r)
require.Error(t, err)
})
}

Expand Down Expand Up @@ -907,7 +918,7 @@ func testMultByMonomial(tc *testParams, t *testing.T) {

func testShift(t *testing.T) {

r, _ := NewRing(16, []uint64{97})
r, _ := NewRing(16, []uint64{97}, nil)
p1, p2 := r.NewPoly(), r.NewPoly()

for i := range p1.Coeffs[0] {
Expand Down
4 changes: 2 additions & 2 deletions ring/ringqp/ring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import (

func TestRingQP(t *testing.T) {
LogN := 10
ringQ, err := ring.NewRing(1<<LogN, ring.Qi60[:4])
ringQ, err := ring.NewRing(1<<LogN, ring.Qi60[:4], nil)
require.NoError(t, err)

ringP, err := ring.NewRing(1<<LogN, ring.Pi60[:4])
ringP, err := ring.NewRing(1<<LogN, ring.Pi60[:4], nil)
require.NoError(t, err)

ringQP := Ring{ringQ, ringP}
Expand Down
2 changes: 1 addition & 1 deletion schemes/bgv/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func NewParameters(rlweParams rlwe.Parameters, t uint64) (p Parameters, err erro
if err != nil {
return Parameters{}, err
}
// One can reuse the pool from rlweParams as the ring dimension N is the same
// One can reuse the pool from rlweParams.ringQ as the ring dimension N is the same
poolQMul := rlweParams.RingQ().BufferPool()
if ringQMul, err = ring.NewRing(rlweParams.N(), primes, poolQMul); err != nil {
return Parameters{}, err
Expand Down

0 comments on commit 9074893

Please sign in to comment.