Skip to content

Commit

Permalink
[FEAT] pow field Arithmetic API (#713)
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonHibnik authored Jan 1, 2025
1 parent 696cfe3 commit d7af027
Show file tree
Hide file tree
Showing 34 changed files with 210 additions and 4 deletions.
11 changes: 11 additions & 0 deletions icicle/include/icicle/fields/complex_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,17 @@ class ComplexExtensionField
FF xs_norm_squared = FF::sqr(xs.real) - nonresidue_times_im;
return xs_conjugate * ComplexExtensionField{FF::inverse(xs_norm_squared), FF::zero()};
}

static constexpr HOST_DEVICE ComplexExtensionField pow(ComplexExtensionField base, int exp)
{
ComplexExtensionField res = one();
while (exp > 0) {
if (exp & 1) res = res * base;
base = base * base;
exp >>= 1;
}
return res;
}
};

#ifdef __CUDACC__
Expand Down
11 changes: 11 additions & 0 deletions icicle/include/icicle/fields/quartic_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,17 @@ class QuarticExtensionField
FF::reduce(FF::mul_wide(xs.im1, x2) - FF::mul_wide(xs.im3, x0)),
};
}

static constexpr HOST_DEVICE QuarticExtensionField pow(QuarticExtensionField base, int exp)
{
QuarticExtensionField res = one();
while (exp > 0) {
if (exp & 1) res = res * base;
base = base * base;
exp >>= 1;
}
return res;
}
};
#if __CUDACC__
template <class CONFIG, class T>
Expand Down
10 changes: 10 additions & 0 deletions icicle/src/fields/ffi_extern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ extern "C" void CONCAT_EXPAND(FIELD, inv)(scalar_t* scalar1, scalar_t* result)
*result = scalar_t::inverse(*scalar1);
}

extern "C" void CONCAT_EXPAND(FIELD, pow)(scalar_t* base, int exp, scalar_t* result)
{
*result = scalar_t::pow(*base, exp);
}

#ifdef EXT_FIELD
extern "C" void CONCAT_EXPAND(FIELD, extension_generate_scalars)(extension_t* scalars, int size)
{
Expand All @@ -53,4 +58,9 @@ extern "C" void CONCAT_EXPAND(FIELD, extension_inv)(extension_t* scalar1, extens
{
*result = extension_t::inverse(*scalar1);
}

extern "C" void CONCAT_EXPAND(FIELD, extension_pow)(extension_t* base, int exp, extension_t* result)
{
*result = extension_t::pow(*base, exp);
}
#endif // EXT_FIELD
1 change: 1 addition & 0 deletions wrappers/golang/curves/bls12377/g2/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void bls12_377_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_inv(const scalar_t* a, scalar_t* result);
void bls12_377_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bls12377/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void bls12_377_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_inv(const scalar_t* a, scalar_t* result);
void bls12_377_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
12 changes: 12 additions & 0 deletions wrappers/golang/curves/bls12377/scalar_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField {
return res
}

func (f ScalarField) Pow(exp int) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cExp := (C.int)(exp)
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bls12_377_pow(cF, cExp, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
5 changes: 5 additions & 0 deletions wrappers/golang/curves/bls12377/tests/scalar_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) {

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

pow4 := scalarA.Pow(4)
mulBySelf := mul.Mul(&mul)

suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bls12381/g2/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void bls12_381_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_381_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_381_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_381_inv(const scalar_t* a, scalar_t* result);
void bls12_381_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bls12381/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void bls12_381_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_381_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_381_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_381_inv(const scalar_t* a, scalar_t* result);
void bls12_381_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
12 changes: 12 additions & 0 deletions wrappers/golang/curves/bls12381/scalar_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField {
return res
}

func (f ScalarField) Pow(exp int) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cExp := (C.int)(exp)
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bls12_381_pow(cF, cExp, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
5 changes: 5 additions & 0 deletions wrappers/golang/curves/bls12381/tests/scalar_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) {

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

pow4 := scalarA.Pow(4)
mulBySelf := mul.Mul(&mul)

suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bn254/g2/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void bn254_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bn254_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bn254_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bn254_inv(const scalar_t* a, scalar_t* result);
void bn254_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bn254/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void bn254_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bn254_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bn254_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bn254_inv(const scalar_t* a, scalar_t* result);
void bn254_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
12 changes: 12 additions & 0 deletions wrappers/golang/curves/bn254/scalar_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField {
return res
}

func (f ScalarField) Pow(exp int) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cExp := (C.int)(exp)
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bn254_pow(cF, cExp, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
5 changes: 5 additions & 0 deletions wrappers/golang/curves/bn254/tests/scalar_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) {

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

pow4 := scalarA.Pow(4)
mulBySelf := mul.Mul(&mul)

suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bw6761/g2/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void bw6_761_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bw6_761_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bw6_761_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bw6_761_inv(const scalar_t* a, scalar_t* result);
void bw6_761_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/bw6761/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void bw6_761_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bw6_761_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bw6_761_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bw6_761_inv(const scalar_t* a, scalar_t* result);
void bw6_761_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
12 changes: 12 additions & 0 deletions wrappers/golang/curves/bw6761/scalar_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField {
return res
}

func (f ScalarField) Pow(exp int) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cExp := (C.int)(exp)
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bw6_761_pow(cF, cExp, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
5 changes: 5 additions & 0 deletions wrappers/golang/curves/bw6761/tests/scalar_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) {

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

pow4 := scalarA.Pow(4)
mulBySelf := mul.Mul(&mul)

suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/curves/grumpkin/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void grumpkin_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void grumpkin_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void grumpkin_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void grumpkin_inv(const scalar_t* a, scalar_t* result);
void grumpkin_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
12 changes: 12 additions & 0 deletions wrappers/golang/curves/grumpkin/scalar_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField {
return res
}

func (f ScalarField) Pow(exp int) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cExp := (C.int)(exp)
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.grumpkin_pow(cF, cExp, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
5 changes: 5 additions & 0 deletions wrappers/golang/curves/grumpkin/tests/scalar_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) {

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

pow4 := scalarA.Pow(4)
mulBySelf := mul.Mul(&mul)

suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
Expand Down
12 changes: 12 additions & 0 deletions wrappers/golang/fields/babybear/extension/extension_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func (f ExtensionField) Sqr() ExtensionField {
return res
}

func (f ExtensionField) Pow(exp int) ExtensionField {
var res ExtensionField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cExp := (C.int)(exp)
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.babybear_extension_pow(cF, cExp, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void babybear_extension_add(const scalar_t* a, const scalar_t* b, scalar_t* resu
void babybear_extension_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void babybear_extension_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void babybear_extension_inv(const scalar_t* a, scalar_t* result);
void babybear_extension_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
1 change: 1 addition & 0 deletions wrappers/golang/fields/babybear/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void babybear_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void babybear_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void babybear_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void babybear_inv(const scalar_t* a, scalar_t* result);
void babybear_pow(const scalar_t* a, int exp, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
12 changes: 12 additions & 0 deletions wrappers/golang/fields/babybear/scalar_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ func (f ScalarField) Sqr() ScalarField {
return res
}

func (f ScalarField) Pow(exp int) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cExp := (C.int)(exp)
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.babybear_pow(cF, cExp, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
5 changes: 5 additions & 0 deletions wrappers/golang/fields/babybear/tests/extension_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ func testExtensionFieldArithmetic(suite *suite.Suite) {

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

pow4 := scalarA.Pow(4)
mulBySelf := mul.Mul(&mul)

suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
Expand Down
5 changes: 5 additions & 0 deletions wrappers/golang/fields/babybear/tests/scalar_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ func testScalarFieldArithmetic(suite *suite.Suite) {

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

pow4 := scalarA.Pow(4)
mulBySelf := mul.Mul(&mul)

suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
Expand Down
12 changes: 12 additions & 0 deletions wrappers/golang/internal/generator/fields/templates/field.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ func (f {{.FieldPrefix}}Field) Sqr() {{.FieldPrefix}}Field {
return res
}

func (f {{.FieldPrefix}}Field) Pow(exp int) {{.FieldPrefix}}Field {
var res {{.FieldPrefix}}Field

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cExp := (C.int)(exp)
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.{{.Field}}_pow(cF, cExp, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ func test{{.FieldPrefix}}FieldArithmetic(suite *suite.Suite) {

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

pow4 := scalarA.Pow(4)
mulBySelf := mul.Mul(&mul)

suite.Equal(pow4, mulBySelf, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
Expand Down
Loading

0 comments on commit d7af027

Please sign in to comment.