Skip to content

Commit

Permalink
Feat/vlad/poseidon go binding (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
vladfdp authored and yshekel committed May 19, 2024
1 parent df7333e commit a359b06
Show file tree
Hide file tree
Showing 77 changed files with 1,025 additions and 142 deletions.
1 change: 1 addition & 0 deletions icicle/src/fields/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ set(POLYNOMIAL_SOURCE_FILES

# TODO: impl poseidon for small fields. note that it needs to be defined over the extension field!
if (DEFINED CURVE)
list(APPEND FIELD_SOURCE ${SRC}/poseidon/extern.cu)
list(APPEND FIELD_SOURCE ${SRC}/poseidon/poseidon.cu)
list(APPEND FIELD_SOURCE ${SRC}/poseidon/tree/merkle.cu)
endif()
Expand Down
5 changes: 2 additions & 3 deletions icicle/src/poseidon/Makefile
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
test_poseidon: test.cu poseidon.cu kernels.cu constants.cu
nvcc -o test_poseidon -I../../include -DFIELD_ID=2 -DCURVE_ID=2 test.cu
./test_poseidon
test_poseidon : test.cu poseidon.cu kernels.cu constants.cu nvcc - o test_poseidon - I../../ include - DFIELD_ID =
2 - DCURVE_ID = 2 test.cu./ test_poseidon
18 changes: 0 additions & 18 deletions icicle/src/poseidon/constants.cu
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,4 @@ namespace poseidon {

return CHK_LAST();
}

extern "C" cudaError_t CONCAT_EXPAND(FIELD, create_optimized_poseidon_constants_cuda)(
int arity,
int full_rounds_half,
int partial_rounds,
const scalar_t* constants,
device_context::DeviceContext& ctx,
PoseidonConstants<scalar_t>* poseidon_constants)
{
return create_optimized_poseidon_constants<scalar_t>(
arity, full_rounds_half, partial_rounds, constants, ctx, poseidon_constants);
}

extern "C" cudaError_t CONCAT_EXPAND(FIELD, init_optimized_poseidon_constants_cuda)(
int arity, device_context::DeviceContext& ctx, PoseidonConstants<scalar_t>* constants)
{
return init_optimized_poseidon_constants<scalar_t>(arity, ctx, constants);
}
} // namespace poseidon
59 changes: 59 additions & 0 deletions icicle/src/poseidon/extern.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "fields/field_config.cuh"

using namespace field_config;

#include "poseidon.cu"
#include "constants.cu"

#include "gpu-utils/device_context.cuh"
#include "utils/utils.h"

namespace poseidon {
/**
* Extern "C" version of [poseidon_hash_cuda] function with the following
* value of template parameter (where the field is given by `-DFIELD` env variable during build):
* - `S` is the [field](@ref scalar_t) - either a scalar field of the elliptic curve or a
* stand-alone "STARK field";
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(FIELD, poseidon_hash_cuda)(
scalar_t* input,
scalar_t* output,
int number_of_states,
int arity,
const PoseidonConstants<scalar_t>& constants,
PoseidonConfig& config)
{
switch (arity) {
case 2:
return poseidon_hash<scalar_t, 3>(input, output, number_of_states, constants, config);
case 4:
return poseidon_hash<scalar_t, 5>(input, output, number_of_states, constants, config);
case 8:
return poseidon_hash<scalar_t, 9>(input, output, number_of_states, constants, config);
case 11:
return poseidon_hash<scalar_t, 12>(input, output, number_of_states, constants, config);
default:
THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "PoseidonHash: #arity must be one of [2, 4, 8, 11]");
}
return CHK_LAST();
}

extern "C" cudaError_t CONCAT_EXPAND(FIELD, create_optimized_poseidon_constants_cuda)(
int arity,
int full_rounds_half,
int partial_rounds,
const scalar_t* constants,
device_context::DeviceContext& ctx,
PoseidonConstants<scalar_t>* poseidon_constants)
{
return create_optimized_poseidon_constants<scalar_t>(
arity, full_rounds_half, partial_rounds, constants, ctx, poseidon_constants);
}

extern "C" cudaError_t CONCAT_EXPAND(FIELD, init_optimized_poseidon_constants_cuda)(
int arity, device_context::DeviceContext& ctx, PoseidonConstants<scalar_t>* constants)
{
return init_optimized_poseidon_constants<scalar_t>(arity, ctx, constants);
}
} // namespace poseidon
24 changes: 0 additions & 24 deletions icicle/src/poseidon/poseidon.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using namespace field_config;

#include "poseidon/poseidon.cuh"
#include "constants.cu"
#include "kernels.cu"

namespace poseidon {
Expand Down Expand Up @@ -88,27 +87,4 @@ namespace poseidon {
if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
return CHK_LAST();
}

extern "C" cudaError_t CONCAT_EXPAND(FIELD, poseidon_hash_cuda)(
scalar_t* input,
scalar_t* output,
int number_of_states,
int arity,
const PoseidonConstants<scalar_t>& constants,
PoseidonConfig& config)
{
switch (arity) {
case 2:
return poseidon_hash<scalar_t, 3>(input, output, number_of_states, constants, config);
case 4:
return poseidon_hash<scalar_t, 5>(input, output, number_of_states, constants, config);
case 8:
return poseidon_hash<scalar_t, 9>(input, output, number_of_states, constants, config);
case 11:
return poseidon_hash<scalar_t, 12>(input, output, number_of_states, constants, config);
default:
THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "PoseidonHash: #arity must be one of [2, 4, 8, 11]");
}
return CHK_LAST();
}
} // namespace poseidon
94 changes: 94 additions & 0 deletions wrappers/golang/core/poseidon.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package core

import (
"fmt"
"unsafe"

cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
)

type PoseidonConfig struct {
/// Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext).
Ctx cr.DeviceContext
areInputsOnDevice bool
areOutputsOnDevice bool
///If true, input is considered to be a states vector, holding the preimages in aligned or not aligned format.
///Memory under the input pointer will be used for states. If false, fresh states memory will be allocated and input will be copied into it */
InputIsAState bool
/// If true - input should be already aligned for poseidon permutation.
///* Aligned format: [0, A, B, 0, C, D, ...] (as you might get by using loop_state)
///* not aligned format: [A, B, 0, C, D, 0, ...] (as you might get from cudaMemcpy2D) */
Aligned bool
///If true, hash results will also be copied in the input pointer in aligned format
LoopState bool
///Whether to run the Poseidon asynchronously. If set to `true`, the poseidon_hash function will be
///non-blocking and you'd need to synchronize it explicitly by running `cudaStreamSynchronize` or `cudaDeviceSynchronize`.
///If set to false, the poseidon_hash function will block the current CPU thread. */
IsAsync bool
}

type PoseidonConstants[T any] struct {
Arity int32
PartialRounds int32
FullRoundsHalf int32
RoundConstants unsafe.Pointer
MdsMatrix unsafe.Pointer
NonSparseMatrix unsafe.Pointer
SparseMatrices unsafe.Pointer
DomainTag T
}

func GetDefaultPoseidonConfig() PoseidonConfig {
ctx, _ := cr.GetDefaultDeviceContext()
return PoseidonConfig{
ctx, // Ctx
false, // areInputsOnDevice
false, // areOutputsOnDevice
false, // inputIsAState
false, // aligned
false, // loopState
false, // IsAsync
}
}

func PoseidonCheck[T any](input, output HostOrDeviceSlice, cfg *PoseidonConfig, constants *PoseidonConstants[T], numberOfStates int) (unsafe.Pointer, unsafe.Pointer, unsafe.Pointer) {
inputLen, outputLen := input.Len(), output.Len()
arity := int(constants.Arity)
expectedInputLen := arity * numberOfStates
if cfg.InputIsAState {
expectedInputLen += numberOfStates
}

if inputLen != expectedInputLen {
errorString := fmt.Sprintf(
"input is not the right length for the given parameters: %d, should be: %d",
inputLen,
arity*numberOfStates,
)
panic(errorString)
}

if outputLen != numberOfStates {
errorString := fmt.Sprintf(
"output is not the right length for the given parameters: %d, should be: %d",
outputLen,
numberOfStates,
)
panic(errorString)
}
cfg.areInputsOnDevice = input.IsOnDevice()
cfg.areOutputsOnDevice = output.IsOnDevice()

if input.IsOnDevice() {
input.(DeviceSlice).CheckDevice()

}

if output.IsOnDevice() {
output.(DeviceSlice).CheckDevice()
}

cfgPointer := unsafe.Pointer(cfg)

return input.AsUnsafePointer(), output.AsUnsafePointer(), cfgPointer
}
3 changes: 1 addition & 2 deletions wrappers/golang/curves/bls12377/g2/msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ package g2
import "C"

import (
"unsafe"

"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
"unsafe"
)

func G2GetDefaultMSMConfig() core.MSMConfig {
Expand Down
3 changes: 1 addition & 2 deletions wrappers/golang/curves/bls12377/msm/msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ package msm
import "C"

import (
"unsafe"

"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
"unsafe"
)

func GetDefaultMSMConfig() core.MSMConfig {
Expand Down
6 changes: 4 additions & 2 deletions wrappers/golang/curves/bls12377/ntt/ntt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ package ntt
import "C"

import (
"unsafe"

"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
bls12_377 "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377"
)

import (
"unsafe"
)

func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
scalarsPointer, resultsPointer, size, cfgPointer := core.NttCheck[T](scalars, cfg, results)

Expand Down
25 changes: 25 additions & 0 deletions wrappers/golang/curves/bls12377/poseidon/include/poseidon.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <cuda_runtime.h>
#include <stdbool.h>

#ifndef _BLS12_377_POSEIDON_H
#define _BLS12_377_POSEIDON_H

#ifdef __cplusplus
extern "C" {
#endif

typedef struct scalar_t scalar_t;
typedef struct PoseidonConfig PoseidonConfig;
typedef struct DeviceContext DeviceContext;
typedef struct PoseidonConstants PoseidonConstants;


cudaError_t bls12_377_poseidon_hash_cuda(const scalar_t* input, scalar_t* output, int number_of_states, int arity, PoseidonConstants* constants, PoseidonConfig* config);
cudaError_t bls12_377_create_optimized_poseidon_constants_cuda(int arity, int full_rounds_halfs, int partial_rounds, const scalar_t* constants, DeviceContext* ctx, PoseidonConstants* poseidon_constants);
cudaError_t bls12_377_init_optimized_poseidon_constants_cuda(int arity, DeviceContext* ctx, PoseidonConstants* constants);

#ifdef __cplusplus
}
#endif

#endif
57 changes: 57 additions & 0 deletions wrappers/golang/curves/bls12377/poseidon/poseidon.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package poseidon

// #cgo CFLAGS: -I./include/
// #include "poseidon.h"
import "C"

import (
"unsafe"

"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
)

func GetDefaultPoseidonConfig() core.PoseidonConfig {
return core.GetDefaultPoseidonConfig()
}

func PoseidonHash[T any](scalars, results core.HostOrDeviceSlice, numberOfStates int, cfg *core.PoseidonConfig, constants *core.PoseidonConstants[T]) core.IcicleError {
scalarsPointer, resultsPointer, cfgPointer := core.PoseidonCheck(scalars, results, cfg, constants, numberOfStates)

cScalars := (*C.scalar_t)(scalarsPointer)
cResults := (*C.scalar_t)(resultsPointer)
cNumberOfStates := (C.int)(numberOfStates)
cArity := (C.int)(constants.Arity)
cConstants := (*C.PoseidonConstants)(unsafe.Pointer(constants))
cCfg := (*C.PoseidonConfig)(cfgPointer)

__ret := C.bls12_377_poseidon_hash_cuda(cScalars, cResults, cNumberOfStates, cArity, cConstants, cCfg)

err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func CreateOptimizedPoseidonConstants[T any](arity, fullRoundsHalfs, partialRounds int, constants core.HostOrDeviceSlice, ctx cr.DeviceContext, poseidonConstants *core.PoseidonConstants[T]) core.IcicleError {

cArity := (C.int)(arity)
cFullRoundsHalfs := (C.int)(fullRoundsHalfs)
cPartialRounds := (C.int)(partialRounds)
cConstants := (*C.scalar_t)(constants.AsUnsafePointer())
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
cPoseidonConstants := (*C.PoseidonConstants)(unsafe.Pointer(poseidonConstants))

__ret := C.bls12_377_create_optimized_poseidon_constants_cuda(cArity, cFullRoundsHalfs, cPartialRounds, cConstants, cCtx, cPoseidonConstants)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func InitOptimizedPoseidonConstantsCuda[T any](arity int, ctx cr.DeviceContext, constants *core.PoseidonConstants[T]) core.IcicleError {

cArity := (C.int)(arity)
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
cConstants := (*C.PoseidonConstants)(unsafe.Pointer(constants))

__ret := C.bls12_377_init_optimized_poseidon_constants_cuda(cArity, cCtx, cConstants)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
3 changes: 1 addition & 2 deletions wrappers/golang/curves/bls12377/scalar_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ import "C"
import (
"encoding/binary"
"fmt"
"unsafe"

"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
"unsafe"
)

const (
Expand Down
3 changes: 1 addition & 2 deletions wrappers/golang/curves/bls12377/tests/base_field_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package tests

import (
"testing"

bls12_377 "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/test_helpers"
"github.com/stretchr/testify/assert"
"testing"
)

const (
Expand Down
3 changes: 1 addition & 2 deletions wrappers/golang/curves/bls12377/tests/curve_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package tests

import (
"testing"

bls12_377 "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/test_helpers"
"github.com/stretchr/testify/assert"
"testing"
)

func TestAffineZero(t *testing.T) {
Expand Down
Loading

0 comments on commit a359b06

Please sign in to comment.