From d7af027c2032a5da69c3962cc0b60fa7b8d5c92d Mon Sep 17 00:00:00 2001 From: Leon Hibnik <107353745+LeonHibnik@users.noreply.github.com> Date: Wed, 1 Jan 2025 11:42:13 +0200 Subject: [PATCH] [FEAT] pow field Arithmetic API (#713) --- .../include/icicle/fields/complex_extension.h | 11 ++++++++++ .../include/icicle/fields/quartic_extension.h | 11 ++++++++++ icicle/src/fields/ffi_extern.cpp | 10 ++++++++++ .../curves/bls12377/g2/include/scalar_field.h | 1 + .../curves/bls12377/include/scalar_field.h | 1 + .../golang/curves/bls12377/scalar_field.go | 12 +++++++++++ .../bls12377/tests/scalar_field_test.go | 5 +++++ .../curves/bls12381/g2/include/scalar_field.h | 1 + .../curves/bls12381/include/scalar_field.h | 1 + .../golang/curves/bls12381/scalar_field.go | 12 +++++++++++ .../bls12381/tests/scalar_field_test.go | 5 +++++ .../curves/bn254/g2/include/scalar_field.h | 1 + .../curves/bn254/include/scalar_field.h | 1 + wrappers/golang/curves/bn254/scalar_field.go | 12 +++++++++++ .../curves/bn254/tests/scalar_field_test.go | 5 +++++ .../curves/bw6761/g2/include/scalar_field.h | 1 + .../curves/bw6761/include/scalar_field.h | 1 + wrappers/golang/curves/bw6761/scalar_field.go | 12 +++++++++++ .../curves/bw6761/tests/scalar_field_test.go | 5 +++++ .../curves/grumpkin/include/scalar_field.h | 1 + .../golang/curves/grumpkin/scalar_field.go | 12 +++++++++++ .../grumpkin/tests/scalar_field_test.go | 5 +++++ .../babybear/extension/extension_field.go | 12 +++++++++++ .../babybear/extension/include/scalar_field.h | 1 + .../fields/babybear/include/scalar_field.h | 1 + .../golang/fields/babybear/scalar_field.go | 12 +++++++++++ .../babybear/tests/extension_field_test.go | 5 +++++ .../babybear/tests/scalar_field_test.go | 5 +++++ .../generator/fields/templates/field.go.tmpl | 12 +++++++++++ .../fields/templates/field_test.go.tmpl | 5 +++++ .../fields/templates/scalar_field.h.tmpl | 1 + wrappers/rust/icicle-core/src/field.rs | 20 +++++++++++++++++++ wrappers/rust/icicle-core/src/tests.rs | 13 ++++++++---- wrappers/rust/icicle-core/src/traits.rs | 1 + 34 files changed, 210 insertions(+), 4 deletions(-) diff --git a/icicle/include/icicle/fields/complex_extension.h b/icicle/include/icicle/fields/complex_extension.h index a6d116a3a..42740859f 100644 --- a/icicle/include/icicle/fields/complex_extension.h +++ b/icicle/include/icicle/fields/complex_extension.h @@ -210,6 +210,17 @@ class ComplexExtensionField FF xs_norm_squared = FF::sqr(xs.real) - nonresidue_times_im; return xs_conjugate * ComplexExtensionField{FF::inverse(xs_norm_squared), FF::zero()}; } + + static constexpr HOST_DEVICE ComplexExtensionField pow(ComplexExtensionField base, int exp) + { + ComplexExtensionField res = one(); + while (exp > 0) { + if (exp & 1) res = res * base; + base = base * base; + exp >>= 1; + } + return res; + } }; #ifdef __CUDACC__ diff --git a/icicle/include/icicle/fields/quartic_extension.h b/icicle/include/icicle/fields/quartic_extension.h index 43038d588..784298cac 100644 --- a/icicle/include/icicle/fields/quartic_extension.h +++ b/icicle/include/icicle/fields/quartic_extension.h @@ -251,6 +251,17 @@ class QuarticExtensionField FF::reduce(FF::mul_wide(xs.im1, x2) - FF::mul_wide(xs.im3, x0)), }; } + + static constexpr HOST_DEVICE QuarticExtensionField pow(QuarticExtensionField base, int exp) + { + QuarticExtensionField res = one(); + while (exp > 0) { + if (exp & 1) res = res * base; + base = base * base; + exp >>= 1; + } + return res; + } }; #if __CUDACC__ template diff --git a/icicle/src/fields/ffi_extern.cpp b/icicle/src/fields/ffi_extern.cpp index 12a0ce146..605bb7ea0 100644 --- a/icicle/src/fields/ffi_extern.cpp +++ b/icicle/src/fields/ffi_extern.cpp @@ -28,6 +28,11 @@ extern "C" void CONCAT_EXPAND(FIELD, inv)(scalar_t* scalar1, scalar_t* result) *result = scalar_t::inverse(*scalar1); } +extern "C" void CONCAT_EXPAND(FIELD, pow)(scalar_t* base, int exp, scalar_t* result) +{ + *result = scalar_t::pow(*base, exp); +} + #ifdef EXT_FIELD extern "C" void CONCAT_EXPAND(FIELD, extension_generate_scalars)(extension_t* scalars, int size) { @@ -53,4 +58,9 @@ extern "C" void CONCAT_EXPAND(FIELD, extension_inv)(extension_t* scalar1, extens { *result = extension_t::inverse(*scalar1); } + +extern "C" void CONCAT_EXPAND(FIELD, extension_pow)(extension_t* base, int exp, extension_t* result) +{ + *result = extension_t::pow(*base, exp); +} #endif // EXT_FIELD diff --git a/wrappers/golang/curves/bls12377/g2/include/scalar_field.h b/wrappers/golang/curves/bls12377/g2/include/scalar_field.h index 0fe8729a0..a775b53ff 100644 --- a/wrappers/golang/curves/bls12377/g2/include/scalar_field.h +++ b/wrappers/golang/curves/bls12377/g2/include/scalar_field.h @@ -16,6 +16,7 @@ void bls12_377_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_377_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_377_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_377_inv(const scalar_t* a, scalar_t* result); +void bls12_377_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12377/include/scalar_field.h b/wrappers/golang/curves/bls12377/include/scalar_field.h index 0fe8729a0..a775b53ff 100644 --- a/wrappers/golang/curves/bls12377/include/scalar_field.h +++ b/wrappers/golang/curves/bls12377/include/scalar_field.h @@ -16,6 +16,7 @@ void bls12_377_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_377_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_377_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_377_inv(const scalar_t* a, scalar_t* result); +void bls12_377_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12377/scalar_field.go b/wrappers/golang/curves/bls12377/scalar_field.go index c771f6ae3..6eab46b76 100644 --- a/wrappers/golang/curves/bls12377/scalar_field.go +++ b/wrappers/golang/curves/bls12377/scalar_field.go @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField { return res } +func (f ScalarField) Pow(exp int) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cExp := (C.int)(exp) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_377_pow(cF, cExp, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/bls12377/tests/scalar_field_test.go b/wrappers/golang/curves/bls12377/tests/scalar_field_test.go index e7928eb19..766924485 100644 --- a/wrappers/golang/curves/bls12377/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bls12377/tests/scalar_field_test.go @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) { suite.Equal(square, mul, "Square and multiplication do not yield the same value") + pow4 := scalarA.Pow(4) + mulBySelf := mul.Mul(&mul) + + suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value") + inv := scalarA.Inv() one := scalarA.Mul(&inv) diff --git a/wrappers/golang/curves/bls12381/g2/include/scalar_field.h b/wrappers/golang/curves/bls12381/g2/include/scalar_field.h index 54c08d6f4..265aa6dba 100644 --- a/wrappers/golang/curves/bls12381/g2/include/scalar_field.h +++ b/wrappers/golang/curves/bls12381/g2/include/scalar_field.h @@ -16,6 +16,7 @@ void bls12_381_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_381_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_381_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_381_inv(const scalar_t* a, scalar_t* result); +void bls12_381_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12381/include/scalar_field.h b/wrappers/golang/curves/bls12381/include/scalar_field.h index 54c08d6f4..265aa6dba 100644 --- a/wrappers/golang/curves/bls12381/include/scalar_field.h +++ b/wrappers/golang/curves/bls12381/include/scalar_field.h @@ -16,6 +16,7 @@ void bls12_381_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_381_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_381_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void bls12_381_inv(const scalar_t* a, scalar_t* result); +void bls12_381_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12381/scalar_field.go b/wrappers/golang/curves/bls12381/scalar_field.go index ed6e7002e..d59feb942 100644 --- a/wrappers/golang/curves/bls12381/scalar_field.go +++ b/wrappers/golang/curves/bls12381/scalar_field.go @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField { return res } +func (f ScalarField) Pow(exp int) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cExp := (C.int)(exp) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_381_pow(cF, cExp, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/bls12381/tests/scalar_field_test.go b/wrappers/golang/curves/bls12381/tests/scalar_field_test.go index c4b68d9cc..11dcdca39 100644 --- a/wrappers/golang/curves/bls12381/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bls12381/tests/scalar_field_test.go @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) { suite.Equal(square, mul, "Square and multiplication do not yield the same value") + pow4 := scalarA.Pow(4) + mulBySelf := mul.Mul(&mul) + + suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value") + inv := scalarA.Inv() one := scalarA.Mul(&inv) diff --git a/wrappers/golang/curves/bn254/g2/include/scalar_field.h b/wrappers/golang/curves/bn254/g2/include/scalar_field.h index 96cda48f9..c91e99236 100644 --- a/wrappers/golang/curves/bn254/g2/include/scalar_field.h +++ b/wrappers/golang/curves/bn254/g2/include/scalar_field.h @@ -16,6 +16,7 @@ void bn254_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void bn254_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void bn254_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void bn254_inv(const scalar_t* a, scalar_t* result); +void bn254_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bn254/include/scalar_field.h b/wrappers/golang/curves/bn254/include/scalar_field.h index 96cda48f9..c91e99236 100644 --- a/wrappers/golang/curves/bn254/include/scalar_field.h +++ b/wrappers/golang/curves/bn254/include/scalar_field.h @@ -16,6 +16,7 @@ void bn254_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void bn254_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void bn254_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void bn254_inv(const scalar_t* a, scalar_t* result); +void bn254_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bn254/scalar_field.go b/wrappers/golang/curves/bn254/scalar_field.go index 0c07ad3ed..84f7122a3 100644 --- a/wrappers/golang/curves/bn254/scalar_field.go +++ b/wrappers/golang/curves/bn254/scalar_field.go @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField { return res } +func (f ScalarField) Pow(exp int) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cExp := (C.int)(exp) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bn254_pow(cF, cExp, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/bn254/tests/scalar_field_test.go b/wrappers/golang/curves/bn254/tests/scalar_field_test.go index 01f5c6609..be1ea1573 100644 --- a/wrappers/golang/curves/bn254/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bn254/tests/scalar_field_test.go @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) { suite.Equal(square, mul, "Square and multiplication do not yield the same value") + pow4 := scalarA.Pow(4) + mulBySelf := mul.Mul(&mul) + + suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value") + inv := scalarA.Inv() one := scalarA.Mul(&inv) diff --git a/wrappers/golang/curves/bw6761/g2/include/scalar_field.h b/wrappers/golang/curves/bw6761/g2/include/scalar_field.h index a40217357..117745476 100644 --- a/wrappers/golang/curves/bw6761/g2/include/scalar_field.h +++ b/wrappers/golang/curves/bw6761/g2/include/scalar_field.h @@ -16,6 +16,7 @@ void bw6_761_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void bw6_761_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void bw6_761_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void bw6_761_inv(const scalar_t* a, scalar_t* result); +void bw6_761_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bw6761/include/scalar_field.h b/wrappers/golang/curves/bw6761/include/scalar_field.h index a40217357..117745476 100644 --- a/wrappers/golang/curves/bw6761/include/scalar_field.h +++ b/wrappers/golang/curves/bw6761/include/scalar_field.h @@ -16,6 +16,7 @@ void bw6_761_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void bw6_761_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void bw6_761_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void bw6_761_inv(const scalar_t* a, scalar_t* result); +void bw6_761_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bw6761/scalar_field.go b/wrappers/golang/curves/bw6761/scalar_field.go index 8ab39796e..cfe1d4199 100644 --- a/wrappers/golang/curves/bw6761/scalar_field.go +++ b/wrappers/golang/curves/bw6761/scalar_field.go @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField { return res } +func (f ScalarField) Pow(exp int) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cExp := (C.int)(exp) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bw6_761_pow(cF, cExp, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/bw6761/tests/scalar_field_test.go b/wrappers/golang/curves/bw6761/tests/scalar_field_test.go index 93bd0aa3d..5d42279c8 100644 --- a/wrappers/golang/curves/bw6761/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bw6761/tests/scalar_field_test.go @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) { suite.Equal(square, mul, "Square and multiplication do not yield the same value") + pow4 := scalarA.Pow(4) + mulBySelf := mul.Mul(&mul) + + suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value") + inv := scalarA.Inv() one := scalarA.Mul(&inv) diff --git a/wrappers/golang/curves/grumpkin/include/scalar_field.h b/wrappers/golang/curves/grumpkin/include/scalar_field.h index c53e61bbd..138407651 100644 --- a/wrappers/golang/curves/grumpkin/include/scalar_field.h +++ b/wrappers/golang/curves/grumpkin/include/scalar_field.h @@ -16,6 +16,7 @@ void grumpkin_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void grumpkin_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void grumpkin_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void grumpkin_inv(const scalar_t* a, scalar_t* result); +void grumpkin_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/grumpkin/scalar_field.go b/wrappers/golang/curves/grumpkin/scalar_field.go index 379975aa5..f8cdb2b45 100644 --- a/wrappers/golang/curves/grumpkin/scalar_field.go +++ b/wrappers/golang/curves/grumpkin/scalar_field.go @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField { return res } +func (f ScalarField) Pow(exp int) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cExp := (C.int)(exp) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.grumpkin_pow(cF, cExp, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go b/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go index 61e898cd6..f1a3094c0 100644 --- a/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go +++ b/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) { suite.Equal(square, mul, "Square and multiplication do not yield the same value") + pow4 := scalarA.Pow(4) + mulBySelf := mul.Mul(&mul) + + suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value") + inv := scalarA.Inv() one := scalarA.Mul(&inv) diff --git a/wrappers/golang/fields/babybear/extension/extension_field.go b/wrappers/golang/fields/babybear/extension/extension_field.go index e0d064e7f..ea7348f38 100644 --- a/wrappers/golang/fields/babybear/extension/extension_field.go +++ b/wrappers/golang/fields/babybear/extension/extension_field.go @@ -168,6 +168,18 @@ func (f ExtensionField) Sqr() ExtensionField { return res } +func (f ExtensionField) Pow(exp int) ExtensionField { + var res ExtensionField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cExp := (C.int)(exp) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_extension_pow(cF, cExp, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/fields/babybear/extension/include/scalar_field.h b/wrappers/golang/fields/babybear/extension/include/scalar_field.h index 79e4a4603..fb05f39a3 100644 --- a/wrappers/golang/fields/babybear/extension/include/scalar_field.h +++ b/wrappers/golang/fields/babybear/extension/include/scalar_field.h @@ -16,6 +16,7 @@ void babybear_extension_add(const scalar_t* a, const scalar_t* b, scalar_t* resu void babybear_extension_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void babybear_extension_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void babybear_extension_inv(const scalar_t* a, scalar_t* result); +void babybear_extension_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/fields/babybear/include/scalar_field.h b/wrappers/golang/fields/babybear/include/scalar_field.h index 9f8f0e5a7..9d9e0a5f4 100644 --- a/wrappers/golang/fields/babybear/include/scalar_field.h +++ b/wrappers/golang/fields/babybear/include/scalar_field.h @@ -16,6 +16,7 @@ void babybear_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void babybear_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void babybear_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void babybear_inv(const scalar_t* a, scalar_t* result); +void babybear_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/fields/babybear/scalar_field.go b/wrappers/golang/fields/babybear/scalar_field.go index ef6185b94..d1b58c737 100644 --- a/wrappers/golang/fields/babybear/scalar_field.go +++ b/wrappers/golang/fields/babybear/scalar_field.go @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField { return res } +func (f ScalarField) Pow(exp int) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cExp := (C.int)(exp) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_pow(cF, cExp, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/fields/babybear/tests/extension_field_test.go b/wrappers/golang/fields/babybear/tests/extension_field_test.go index e468fc309..88f609007 100644 --- a/wrappers/golang/fields/babybear/tests/extension_field_test.go +++ b/wrappers/golang/fields/babybear/tests/extension_field_test.go @@ -119,6 +119,11 @@ func testExtensionFieldArithmetic(suite *suite.Suite) { suite.Equal(square, mul, "Square and multiplication do not yield the same value") + pow4 := scalarA.Pow(4) + mulBySelf := mul.Mul(&mul) + + suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value") + inv := scalarA.Inv() one := scalarA.Mul(&inv) diff --git a/wrappers/golang/fields/babybear/tests/scalar_field_test.go b/wrappers/golang/fields/babybear/tests/scalar_field_test.go index abf8f0c3a..79a30df8a 100644 --- a/wrappers/golang/fields/babybear/tests/scalar_field_test.go +++ b/wrappers/golang/fields/babybear/tests/scalar_field_test.go @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) { suite.Equal(square, mul, "Square and multiplication do not yield the same value") + pow4 := scalarA.Pow(4) + mulBySelf := mul.Mul(&mul) + + suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value") + inv := scalarA.Inv() one := scalarA.Mul(&inv) diff --git a/wrappers/golang/internal/generator/fields/templates/field.go.tmpl b/wrappers/golang/internal/generator/fields/templates/field.go.tmpl index c03943389..09864379c 100644 --- a/wrappers/golang/internal/generator/fields/templates/field.go.tmpl +++ b/wrappers/golang/internal/generator/fields/templates/field.go.tmpl @@ -171,6 +171,18 @@ func (f {{.FieldPrefix}}Field) Sqr() {{.FieldPrefix}}Field { return res } +func (f {{.FieldPrefix}}Field) Pow(exp int) {{.FieldPrefix}}Field { + var res {{.FieldPrefix}}Field + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cExp := (C.int)(exp) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.{{.Field}}_pow(cF, cExp, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl b/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl index b60362a61..d739433db 100644 --- a/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl +++ b/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl @@ -120,6 +120,11 @@ func test{{.FieldPrefix}}FieldArithmetic(suite *suite.Suite) { suite.Equal(square, mul, "Square and multiplication do not yield the same value") + pow4 := scalarA.Pow(4) + mulBySelf := mul.Mul(&mul) + + suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value") + inv := scalarA.Inv() one := scalarA.Mul(&inv) diff --git a/wrappers/golang/internal/generator/fields/templates/scalar_field.h.tmpl b/wrappers/golang/internal/generator/fields/templates/scalar_field.h.tmpl index 59b7573b7..92daf6440 100644 --- a/wrappers/golang/internal/generator/fields/templates/scalar_field.h.tmpl +++ b/wrappers/golang/internal/generator/fields/templates/scalar_field.h.tmpl @@ -16,6 +16,7 @@ void {{.Field}}_add(const scalar_t* a, const scalar_t* b, scalar_t* result); void {{.Field}}_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); void {{.Field}}_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); void {{.Field}}_inv(const scalar_t* a, scalar_t* result); +void {{.Field}}_pow(const scalar_t* a, int exp, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/rust/icicle-core/src/field.rs b/wrappers/rust/icicle-core/src/field.rs index 9e3a16c2e..a63f22fd2 100644 --- a/wrappers/rust/icicle-core/src/field.rs +++ b/wrappers/rust/icicle-core/src/field.rs @@ -114,6 +114,7 @@ pub trait FieldArithmetic { fn mul(first: F, second: F) -> F; fn sqr(first: F) -> F; fn inv(first: F) -> F; + fn pow(first: F, exp: usize) -> F; } impl Arithmetic for Field @@ -127,6 +128,10 @@ where fn inv(self) -> Self { F::inv(self) } + + fn pow(self, exp: usize) -> Self { + F::pow(self, exp) + } } impl Add for Field @@ -232,6 +237,9 @@ macro_rules! impl_scalar_field { #[link_name = concat!($field_prefix, "_inv")] pub(crate) fn inv(a: *const $field_name, result: *mut $field_name); + + #[link_name = concat!($field_prefix, "_pow")] + pub(crate) fn pow(a: *const $field_name, exp: usize, result: *mut $field_name); } pub(crate) fn convert_scalars_montgomery( @@ -303,6 +311,18 @@ macro_rules! impl_scalar_field { $field_prefix_ident::inv(&first as *const $field_name, &mut result as *mut $field_name); } + result + } + fn pow(first: $field_name, exp: usize) -> $field_name { + let mut result = $field_name::zero(); + unsafe { + $field_prefix_ident::pow( + &first as *const $field_name, + exp as usize, + &mut result as *mut $field_name, + ); + } + result } } diff --git a/wrappers/rust/icicle-core/src/tests.rs b/wrappers/rust/icicle-core/src/tests.rs index 370fc01b7..7a18d9297 100644 --- a/wrappers/rust/icicle-core/src/tests.rs +++ b/wrappers/rust/icicle-core/src/tests.rs @@ -30,12 +30,17 @@ where let result2 = result1 - scalars_b[i]; assert_eq!(result2, scalars_a[i]); } - + + // Test field multiplication API let scalar_a = scalars_a[0]; let square = scalar_a.sqr(); - let mul = scalar_a.mul(scalar_a); - - assert_eq!(square, mul); + let mul_by_self = scalar_a.mul(scalar_a); + assert_eq!(square, mul_by_self); + + // Test field pow API + let pow_4 = scalar_a.pow(4); + let mul_mul = mul_by_self.mul(mul_by_self); + assert_eq!(pow_4, mul_mul); let inv = scalar_a.inv(); let one = scalar_a.mul(inv); diff --git a/wrappers/rust/icicle-core/src/traits.rs b/wrappers/rust/icicle-core/src/traits.rs index 835a0dca0..0c1ef6222 100644 --- a/wrappers/rust/icicle-core/src/traits.rs +++ b/wrappers/rust/icicle-core/src/traits.rs @@ -35,4 +35,5 @@ pub trait MontgomeryConvertible: Sized { pub trait Arithmetic: Sized + Add + Sub + Mul { fn sqr(self) -> Self; fn inv(self) -> Self; + fn pow(self, exp: usize) -> Self; }