Skip to content

Commit

Permalink
bufferpool uint64 in ring
Browse files Browse the repository at this point in the history
  • Loading branch information
lehugueni committed Dec 12, 2024
1 parent 5e6a37a commit 77c62e6
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 66 deletions.
12 changes: 7 additions & 5 deletions core/rlwe/ciphertext.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/tuneinsight/lattigo/v6/ring"
"github.com/tuneinsight/lattigo/v6/utils/sampling"
"github.com/tuneinsight/lattigo/v6/utils/structs"
)

// Ciphertext is a generic type for RLWE ciphertexts.
Expand Down Expand Up @@ -71,12 +70,14 @@ func (ct Ciphertext) Equal(other *Ciphertext) bool {
return ct.Element.Equal(&other.Element)
}

func NewCiphertextFromUintPool(pool structs.BufferPool[*[]uint64], params ParameterProvider, degree int, level int) *Ciphertext {
func NewCiphertextFromUintPool(params ParameterProvider, degree int, levelQ int) *Ciphertext {
p := params.GetRLWEParameters()

ringQ := p.RingQ().AtLevel(levelQ)

Value := make([]ring.Poly, degree+1)
for i := range Value {
Value[i] = *ring.NewPolyFromUintPool(pool, p.N(), level)
Value[i] = *ringQ.NewPolyFromUintPool()
}

el := Element[ring.Poly]{
Expand All @@ -90,9 +91,10 @@ func NewCiphertextFromUintPool(pool structs.BufferPool[*[]uint64], params Parame
return &Ciphertext{el}
}

func RecycleCiphertextInUintPool(pool structs.BufferPool[*[]uint64], ct *Ciphertext) {
func RecycleCiphertextInUintPool(params ParameterProvider, ct *Ciphertext) {
ringQ := params.GetRLWEParameters().ringQ
for i := range ct.Value {
ring.RecyclePolyInUintPool(pool, &ct.Value[i])
ringQ.RecyclePolyInUintPool(&ct.Value[i])
}
ct = nil
return

Check failure on line 100 in core/rlwe/ciphertext.go

View workflow job for this annotation

GitHub Actions / Run static checks

redundant return statement (S1023)
Expand Down
38 changes: 17 additions & 21 deletions core/rlwe/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,37 +38,33 @@ func newBuffer[T any](f func() T) structs.BufferPool[T] {
func NewEvaluatorBuffersWithUintPool(params Parameters) *EvaluatorBuffers {
buff := new(EvaluatorBuffers)
ringQP := params.RingQP()
ringQ := params.ringQ

buffUint := newBuffer(func() *[]uint64 {
buff := make([]uint64, params.RingQ().N())
return &buff
})

buff.BuffQPPool = structs.NewBuffFromUintPool(buffUint,
func(bp structs.BufferPool[*[]uint64]) *ringqp.Poly {
return ringQP.NewPolyQPFromUintPool(bp)
buff.BuffQPPool = structs.NewBuffFromUintPool(
func() *ringqp.Poly {
return ringQP.NewPolyQPFromUintPool()
},
func(bp structs.BufferPool[*[]uint64], poly *ringqp.Poly) {
ringqp.RecyclePolyQPFromUintPool(bp, poly)
func(poly *ringqp.Poly) {
ringQP.RecyclePolyQPFromUintPool(poly)
},
)
buff.BuffQPool = structs.NewBuffFromUintPool(buffUint,
func(bp structs.BufferPool[*[]uint64]) *ring.Poly {
return ring.NewPolyFromUintPool(bp, params.ringQ.N(), params.ringQ.Level())
buff.BuffQPool = structs.NewBuffFromUintPool(
func() *ring.Poly {
return ringQ.NewPolyFromUintPool()
},
func(bp structs.BufferPool[*[]uint64], poly *ring.Poly) {
ring.RecyclePolyInUintPool(bp, poly)
func(poly *ring.Poly) {
ringQ.RecyclePolyInUintPool(poly)
},
)
buff.BuffCtPool = structs.NewBuffFromUintPool(buffUint,
func(bp structs.BufferPool[*[]uint64]) *Ciphertext {
return NewCiphertextFromUintPool(bp, params, 2, params.MaxLevel())
buff.BuffCtPool = structs.NewBuffFromUintPool(
func() *Ciphertext {
return NewCiphertextFromUintPool(params, 2, params.MaxLevel())
},
func(bp structs.BufferPool[*[]uint64], ct *Ciphertext) {
RecycleCiphertextInUintPool(bp, ct)
func(ct *Ciphertext) {
RecycleCiphertextInUintPool(params, ct)
},
)
buff.BuffBitPool = buffUint
buff.BuffBitPool = ringQ.BufferPool()
return buff
}

Expand Down
10 changes: 8 additions & 2 deletions core/rlwe/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/tuneinsight/lattigo/v6/ring/ringqp"
"github.com/tuneinsight/lattigo/v6/utils"
"github.com/tuneinsight/lattigo/v6/utils/buffer"
"github.com/tuneinsight/lattigo/v6/utils/structs"
)

// MaxLogN is the log2 of the largest supported polynomial modulus degree.
Expand Down Expand Up @@ -855,11 +856,16 @@ func GenModuli(LogNthRoot int, logQ, logP []int) (q, p []uint64, err error) {
}

func (p *Parameters) initRings() (err error) {
if p.ringQ, err = ring.NewRingFromType(1<<p.logN, p.qi, p.ringType); err != nil {
N := 1 << p.logN
buffPool := structs.NewSyncPool(func() *[]uint64 {
buff := make([]uint64, N)
return &buff
})
if p.ringQ, err = ring.NewRingFromType(1<<p.logN, p.qi, p.ringType, buffPool); err != nil {
return fmt.Errorf("initRings/ringQ: %w", err)
}
if len(p.pi) != 0 {
if p.ringP, err = ring.NewRingFromType(1<<p.logN, p.pi, p.ringType); err != nil {
if p.ringP, err = ring.NewRingFromType(1<<p.logN, p.pi, p.ringType, buffPool); err != nil {
return fmt.Errorf("initRings/ringP: %w", err)
}
}
Expand Down
24 changes: 16 additions & 8 deletions ring/basis_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,22 @@ func NewBasisExtender(ringQ, ringP *Ring) (be *BasisExtender) {
be.modDownConstantsPtoQ = genmodDownConstants(ringQ, ringP)
be.modDownConstantsQtoP = genmodDownConstants(ringP, ringQ)

be.buffQPool = structs.NewSyncPool(func() *Poly {
polyQ := ringQ.NewPoly()
return &polyQ
})
be.buffPPool = structs.NewSyncPool(func() *Poly {
polyP := ringP.NewPoly()
return &polyP
})
be.buffQPool = structs.NewBuffFromUintPool(
func() *Poly {
return ringQ.NewPolyFromUintPool()
},
func(poly *Poly) {
ringQ.RecyclePolyInUintPool(poly)
},
)
be.buffPPool = structs.NewBuffFromUintPool(
func() *Poly {
return ringP.NewPolyFromUintPool()
},
func(poly *Poly) {
ringP.RecyclePolyInUintPool(poly)
},
)

return
}
Expand Down
2 changes: 1 addition & 1 deletion ring/poly.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type Poly struct {
Coeffs structs.Matrix[uint64]
}

func NewPolyFromUintPool(pool structs.BufferPool[*[]uint64], N, level int) (pol *Poly) {
func NewPolyFromUintPool(pool structs.BufferPool[*[]uint64], level int) (pol *Poly) {
coeffs := make([][]uint64, level+1)
for i := range coeffs {
coeffs[i] = *pool.Get()
Expand Down
51 changes: 43 additions & 8 deletions ring/ring.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/tuneinsight/lattigo/v6/utils"
"github.com/tuneinsight/lattigo/v6/utils/bignum"
"github.com/tuneinsight/lattigo/v6/utils/structs"
)

const (
Expand Down Expand Up @@ -78,6 +79,8 @@ type Ring struct {
RescaleConstants [][]uint64

level int

bufferPool structs.BufferPool[*[]uint64]
}

// ConjugateInvariantRing returns the conjugate invariant ring of the receiver ring.
Expand Down Expand Up @@ -173,6 +176,11 @@ func (r Ring) Level() int {
return r.level
}

// BufferPool returns the pool of *[]uint64
func (r Ring) BufferPool() structs.BufferPool[*[]uint64] {
return r.bufferPool
}

// AtLevel returns an instance of the target ring that operates at the target level.
// This instance is thread safe and can be use concurrently with the base ring.
func (r Ring) AtLevel(level int) *Ring {
Expand All @@ -192,6 +200,7 @@ func (r Ring) AtLevel(level int) *Ring {
ModulusAtLevel: r.ModulusAtLevel,
RescaleConstants: r.RescaleConstants,
level: level,
bufferPool: r.bufferPool,
}
}

Expand Down Expand Up @@ -241,28 +250,44 @@ func (r Ring) BRedConstants() (BRC [][2]uint64) {
// NewRing creates a new RNS Ring with degree N and coefficient moduli Moduli with Standard NTT. N must be a power of two larger than 8. Moduli should be
// a non-empty []uint64 with distinct prime elements. All moduli must also be equal to 1 modulo 2*N.
// An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRing(N int, Moduli []uint64) (r *Ring, err error) {
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N)
func NewRing(N int, Moduli []uint64, pool ...structs.BufferPool[*[]uint64]) (r *Ring, err error) {
var bp structs.BufferPool[*[]uint64]
switch len(pool) {
case 0:
case 1:
bp = pool[0]
default:
return nil, fmt.Errorf("cannot create new ring: more than 1 buffer pools provided")
}
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N, bp)
}

// NewRingConjugateInvariant creates a new RNS Ring with degree N and coefficient moduli Moduli with Conjugate Invariant NTT. N must be a power of two larger than 8. Moduli should be
// a non-empty []uint64 with distinct prime elements. All moduli must also be equal to 1 modulo 4*N.
// An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRingConjugateInvariant(N int, Moduli []uint64) (r *Ring, err error) {
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerConjugateInvariant, 4*N)
func NewRingConjugateInvariant(N int, Moduli []uint64, pool ...structs.BufferPool[*[]uint64]) (r *Ring, err error) {
var bp structs.BufferPool[*[]uint64]
switch len(pool) {
case 0:
case 1:
bp = pool[0]
default:
return nil, fmt.Errorf("cannot create new ring: more than 1 buffer pools provided")
}
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerConjugateInvariant, 4*N, bp)
}

// NewRingFromType creates a new RNS Ring with degree N and coefficient moduli Moduli for which the type of NTT is determined by the ringType argument.
// If ringType==Standard, the ring is instantiated with standard NTT with the Nth root of unity 2*N. If ringType==ConjugateInvariant, the ring
// is instantiated with a ConjugateInvariant NTT with Nth root of unity 4*N. N must be a power of two larger than 8.
// Moduli should be a non-empty []uint64 with distinct prime elements. All moduli must also be equal to 1 modulo the root of unity.
// An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRingFromType(N int, Moduli []uint64, ringType Type) (r *Ring, err error) {
func NewRingFromType(N int, Moduli []uint64, ringType Type, pool structs.BufferPool[*[]uint64]) (r *Ring, err error) {
switch ringType {
case Standard:
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N)
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerStandard, 2*N, pool)
case ConjugateInvariant:
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerConjugateInvariant, 4*N)
return NewRingWithCustomNTT(N, Moduli, NewNumberTheoreticTransformerConjugateInvariant, 4*N, pool)
default:
return nil, fmt.Errorf("invalid ring type")
}
Expand All @@ -272,7 +297,7 @@ func NewRingFromType(N int, Moduli []uint64, ringType Type) (r *Ring, err error)
// ModuliChain should be a non-empty []uint64 with distinct prime elements.
// All moduli must also be equal to 1 modulo the root of unity.
// N must be a power of two larger than 8. An error is returned with a nil *Ring in the case of non NTT-enabling parameters.
func NewRingWithCustomNTT(N int, ModuliChain []uint64, ntt func(*SubRing, int) NumberTheoreticTransformer, NthRoot int) (r *Ring, err error) {
func NewRingWithCustomNTT(N int, ModuliChain []uint64, ntt func(*SubRing, int) NumberTheoreticTransformer, NthRoot int, pool structs.BufferPool[*[]uint64]) (r *Ring, err error) {
r = new(Ring)

// Checks if N is a power of 2
Expand Down Expand Up @@ -307,6 +332,8 @@ func NewRingWithCustomNTT(N int, ModuliChain []uint64, ntt func(*SubRing, int) N

r.level = len(ModuliChain) - 1

r.bufferPool = pool

return r, r.generateNTTConstants(nil, nil)
}

Expand Down Expand Up @@ -359,6 +386,14 @@ func (r Ring) NewPoly() Poly {
return NewPoly(r.N(), r.level)
}

func (r Ring) NewPolyFromUintPool() (p *Poly) {
return NewPolyFromUintPool(r.bufferPool, r.level)
}

func (r Ring) RecyclePolyInUintPool(pol *Poly) {
RecyclePolyInUintPool(r.bufferPool, pol)
}

// NewMonomialXi returns a polynomial X^{i}.
func (r Ring) NewMonomialXi(i int) (p Poly) {

Expand Down
10 changes: 7 additions & 3 deletions ring/ring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ func genTestParams(defaultParams Parameters) (tc *testParams, err error) {

tc = new(testParams)

if tc.ringQ, err = NewRing(1<<defaultParams.logN, defaultParams.qi); err != nil {
pool := structs.NewSyncPool(func() *[]uint64 {
buff := make([]uint64, 1<<defaultParams.logN)
return &buff
})
if tc.ringQ, err = NewRing(1<<defaultParams.logN, defaultParams.qi, pool); err != nil {
return nil, err
}
if tc.ringP, err = NewRing(1<<defaultParams.logN, defaultParams.pi); err != nil {
if tc.ringP, err = NewRing(1<<defaultParams.logN, defaultParams.pi, pool); err != nil {
return nil, err
}
if tc.prng, err = sampling.NewPRNG(); err != nil {
Expand Down Expand Up @@ -90,7 +94,7 @@ func testNTTConjugateInvariant(tc *testParams, t *testing.T) {
Q := ringQ.ModuliChain()
N := ringQ.N()
ringQ2N, _ := NewRing(N<<1, Q)
ringQConjugateInvariant, _ := NewRingFromType(N, Q, ConjugateInvariant)
ringQConjugateInvariant, _ := NewRingFromType(N, Q, ConjugateInvariant, nil)

sampler := NewUniformSampler(tc.prng, ringQ)
p1 := sampler.ReadNew()
Expand Down
14 changes: 7 additions & 7 deletions ring/ringqp/ring.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/tuneinsight/lattigo/v6/ring"
"github.com/tuneinsight/lattigo/v6/utils/bignum"
"github.com/tuneinsight/lattigo/v6/utils/structs"
)

// Ring is a structure that implements the operation in the ring R_QP.
Expand Down Expand Up @@ -221,18 +220,19 @@ func (r Ring) NewPoly() Poly {
return Poly{Q, P}
}

func (r Ring) NewPolyQPFromUintPool(pool structs.BufferPool[*[]uint64]) *Poly {
// NewPolyQPFromUintPool creates a new polynomial using the *[]uint64 BufferPool for backing arrays.
func (r Ring) NewPolyQPFromUintPool() *Poly {
var Q, P *ring.Poly
if r.RingQ != nil {
Q = ring.NewPolyFromUintPool(pool, r.RingQ.N(), r.RingQ.Level())
Q = r.RingQ.NewPolyFromUintPool()
}
if r.RingP != nil {
P = ring.NewPolyFromUintPool(pool, r.RingP.N(), r.RingP.Level())
P = r.RingP.NewPolyFromUintPool()
}
return &Poly{*Q, *P}
}

func RecyclePolyQPFromUintPool(pool structs.BufferPool[*[]uint64], poly *Poly) {
ring.RecyclePolyInUintPool(pool, &poly.Q)
ring.RecyclePolyInUintPool(pool, &poly.P)
func (r Ring) RecyclePolyQPFromUintPool(poly *Poly) {
r.RingQ.RecyclePolyInUintPool(&poly.Q)
r.RingP.RecyclePolyInUintPool(&poly.P)
}
15 changes: 11 additions & 4 deletions schemes/ckks/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,17 @@ func NewEncoder(parameters Parameters, precision ...uint) (ecd *Encoder) {
buff := make([]*big.Int, m>>1)
return &buff
})
ecd.BuffPolyPool = structs.NewSyncPool(func() *ring.Poly {
poly := parameters.RingQ().NewPoly()
return &poly
})

ringQ := parameters.RingQ()

ecd.BuffPolyPool = structs.NewBuffFromUintPool(
func() *ring.Poly {
return ringQ.NewPolyFromUintPool()
},
func(poly *ring.Poly) {
ringQ.RecyclePolyInUintPool(poly)
},
)

if prec <= 53 {

Expand Down
12 changes: 5 additions & 7 deletions utils/structs/concurrent_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,23 @@ func (spool *SyncPool[T]) Put(buff T) {
}

type BuffFromUintPool[T any] struct {
uintPool BufferPool[*[]uint64] // Pool that must contain *[]uint64 objects
createObject func(BufferPool[*[]uint64]) T
recycleObject func(BufferPool[*[]uint64], T)
createObject func() T
recycleObject func(T)
}

func NewBuffFromUintPool[T any](pool BufferPool[*[]uint64], create func(BufferPool[*[]uint64]) T, recycle func(BufferPool[*[]uint64], T)) *BuffFromUintPool[T] {
func NewBuffFromUintPool[T any](create func() T, recycle func(T)) *BuffFromUintPool[T] {
return &BuffFromUintPool[T]{
uintPool: pool,
createObject: create,
recycleObject: recycle,
}
}

func (bu *BuffFromUintPool[T]) Get() T {
return bu.createObject(bu.uintPool)
return bu.createObject()
}

func (bu *BuffFromUintPool[T]) Put(obj T) {
bu.recycleObject(bu.uintPool, obj)
bu.recycleObject(obj)
}

type FreeList[T any] struct {
Expand Down

0 comments on commit 77c62e6

Please sign in to comment.