-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat/vlad/poseidon go binding (#513)
- Loading branch information
Showing
77 changed files
with
1,025 additions
and
142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
test_poseidon: test.cu poseidon.cu kernels.cu constants.cu | ||
nvcc -o test_poseidon -I../../include -DFIELD_ID=2 -DCURVE_ID=2 test.cu | ||
./test_poseidon | ||
test_poseidon : test.cu poseidon.cu kernels.cu constants.cu nvcc - o test_poseidon - I../../ include - DFIELD_ID = | ||
2 - DCURVE_ID = 2 test.cu./ test_poseidon |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#include "fields/field_config.cuh" | ||
|
||
using namespace field_config; | ||
|
||
#include "poseidon.cu" | ||
#include "constants.cu" | ||
|
||
#include "gpu-utils/device_context.cuh" | ||
#include "utils/utils.h" | ||
|
||
namespace poseidon { | ||
/** | ||
* Extern "C" version of [poseidon_hash_cuda] function with the following | ||
* value of template parameter (where the field is given by `-DFIELD` env variable during build): | ||
* - `S` is the [field](@ref scalar_t) - either a scalar field of the elliptic curve or a | ||
* stand-alone "STARK field"; | ||
* @return `cudaSuccess` if the execution was successful and an error code otherwise. | ||
*/ | ||
extern "C" cudaError_t CONCAT_EXPAND(FIELD, poseidon_hash_cuda)( | ||
scalar_t* input, | ||
scalar_t* output, | ||
int number_of_states, | ||
int arity, | ||
const PoseidonConstants<scalar_t>& constants, | ||
PoseidonConfig& config) | ||
{ | ||
switch (arity) { | ||
case 2: | ||
return poseidon_hash<scalar_t, 3>(input, output, number_of_states, constants, config); | ||
case 4: | ||
return poseidon_hash<scalar_t, 5>(input, output, number_of_states, constants, config); | ||
case 8: | ||
return poseidon_hash<scalar_t, 9>(input, output, number_of_states, constants, config); | ||
case 11: | ||
return poseidon_hash<scalar_t, 12>(input, output, number_of_states, constants, config); | ||
default: | ||
THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "PoseidonHash: #arity must be one of [2, 4, 8, 11]"); | ||
} | ||
return CHK_LAST(); | ||
} | ||
|
||
extern "C" cudaError_t CONCAT_EXPAND(FIELD, create_optimized_poseidon_constants_cuda)( | ||
int arity, | ||
int full_rounds_half, | ||
int partial_rounds, | ||
const scalar_t* constants, | ||
device_context::DeviceContext& ctx, | ||
PoseidonConstants<scalar_t>* poseidon_constants) | ||
{ | ||
return create_optimized_poseidon_constants<scalar_t>( | ||
arity, full_rounds_half, partial_rounds, constants, ctx, poseidon_constants); | ||
} | ||
|
||
extern "C" cudaError_t CONCAT_EXPAND(FIELD, init_optimized_poseidon_constants_cuda)( | ||
int arity, device_context::DeviceContext& ctx, PoseidonConstants<scalar_t>* constants) | ||
{ | ||
return init_optimized_poseidon_constants<scalar_t>(arity, ctx, constants); | ||
} | ||
} // namespace poseidon |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package core | ||
|
||
import ( | ||
"fmt" | ||
"unsafe" | ||
|
||
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime" | ||
) | ||
|
||
type PoseidonConfig struct { | ||
/// Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext). | ||
Ctx cr.DeviceContext | ||
areInputsOnDevice bool | ||
areOutputsOnDevice bool | ||
///If true, input is considered to be a states vector, holding the preimages in aligned or not aligned format. | ||
///Memory under the input pointer will be used for states. If false, fresh states memory will be allocated and input will be copied into it */ | ||
InputIsAState bool | ||
/// If true - input should be already aligned for poseidon permutation. | ||
///* Aligned format: [0, A, B, 0, C, D, ...] (as you might get by using loop_state) | ||
///* not aligned format: [A, B, 0, C, D, 0, ...] (as you might get from cudaMemcpy2D) */ | ||
Aligned bool | ||
///If true, hash results will also be copied in the input pointer in aligned format | ||
LoopState bool | ||
///Whether to run the Poseidon asynchronously. If set to `true`, the poseidon_hash function will be | ||
///non-blocking and you'd need to synchronize it explicitly by running `cudaStreamSynchronize` or `cudaDeviceSynchronize`. | ||
///If set to false, the poseidon_hash function will block the current CPU thread. */ | ||
IsAsync bool | ||
} | ||
|
||
type PoseidonConstants[T any] struct { | ||
Arity int32 | ||
PartialRounds int32 | ||
FullRoundsHalf int32 | ||
RoundConstants unsafe.Pointer | ||
MdsMatrix unsafe.Pointer | ||
NonSparseMatrix unsafe.Pointer | ||
SparseMatrices unsafe.Pointer | ||
DomainTag T | ||
} | ||
|
||
func GetDefaultPoseidonConfig() PoseidonConfig { | ||
ctx, _ := cr.GetDefaultDeviceContext() | ||
return PoseidonConfig{ | ||
ctx, // Ctx | ||
false, // areInputsOnDevice | ||
false, // areOutputsOnDevice | ||
false, // inputIsAState | ||
false, // aligned | ||
false, // loopState | ||
false, // IsAsync | ||
} | ||
} | ||
|
||
func PoseidonCheck[T any](input, output HostOrDeviceSlice, cfg *PoseidonConfig, constants *PoseidonConstants[T], numberOfStates int) (unsafe.Pointer, unsafe.Pointer, unsafe.Pointer) { | ||
inputLen, outputLen := input.Len(), output.Len() | ||
arity := int(constants.Arity) | ||
expectedInputLen := arity * numberOfStates | ||
if cfg.InputIsAState { | ||
expectedInputLen += numberOfStates | ||
} | ||
|
||
if inputLen != expectedInputLen { | ||
errorString := fmt.Sprintf( | ||
"input is not the right length for the given parameters: %d, should be: %d", | ||
inputLen, | ||
arity*numberOfStates, | ||
) | ||
panic(errorString) | ||
} | ||
|
||
if outputLen != numberOfStates { | ||
errorString := fmt.Sprintf( | ||
"output is not the right length for the given parameters: %d, should be: %d", | ||
outputLen, | ||
numberOfStates, | ||
) | ||
panic(errorString) | ||
} | ||
cfg.areInputsOnDevice = input.IsOnDevice() | ||
cfg.areOutputsOnDevice = output.IsOnDevice() | ||
|
||
if input.IsOnDevice() { | ||
input.(DeviceSlice).CheckDevice() | ||
|
||
} | ||
|
||
if output.IsOnDevice() { | ||
output.(DeviceSlice).CheckDevice() | ||
} | ||
|
||
cfgPointer := unsafe.Pointer(cfg) | ||
|
||
return input.AsUnsafePointer(), output.AsUnsafePointer(), cfgPointer | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
wrappers/golang/curves/bls12377/poseidon/include/poseidon.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include <cuda_runtime.h> | ||
#include <stdbool.h> | ||
|
||
#ifndef _BLS12_377_POSEIDON_H | ||
#define _BLS12_377_POSEIDON_H | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
typedef struct scalar_t scalar_t; | ||
typedef struct PoseidonConfig PoseidonConfig; | ||
typedef struct DeviceContext DeviceContext; | ||
typedef struct PoseidonConstants PoseidonConstants; | ||
|
||
|
||
cudaError_t bls12_377_poseidon_hash_cuda(const scalar_t* input, scalar_t* output, int number_of_states, int arity, PoseidonConstants* constants, PoseidonConfig* config); | ||
cudaError_t bls12_377_create_optimized_poseidon_constants_cuda(int arity, int full_rounds_halfs, int partial_rounds, const scalar_t* constants, DeviceContext* ctx, PoseidonConstants* poseidon_constants); | ||
cudaError_t bls12_377_init_optimized_poseidon_constants_cuda(int arity, DeviceContext* ctx, PoseidonConstants* constants); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package poseidon | ||
|
||
// #cgo CFLAGS: -I./include/ | ||
// #include "poseidon.h" | ||
import "C" | ||
|
||
import ( | ||
"unsafe" | ||
|
||
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core" | ||
cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime" | ||
) | ||
|
||
func GetDefaultPoseidonConfig() core.PoseidonConfig { | ||
return core.GetDefaultPoseidonConfig() | ||
} | ||
|
||
func PoseidonHash[T any](scalars, results core.HostOrDeviceSlice, numberOfStates int, cfg *core.PoseidonConfig, constants *core.PoseidonConstants[T]) core.IcicleError { | ||
scalarsPointer, resultsPointer, cfgPointer := core.PoseidonCheck(scalars, results, cfg, constants, numberOfStates) | ||
|
||
cScalars := (*C.scalar_t)(scalarsPointer) | ||
cResults := (*C.scalar_t)(resultsPointer) | ||
cNumberOfStates := (C.int)(numberOfStates) | ||
cArity := (C.int)(constants.Arity) | ||
cConstants := (*C.PoseidonConstants)(unsafe.Pointer(constants)) | ||
cCfg := (*C.PoseidonConfig)(cfgPointer) | ||
|
||
__ret := C.bls12_377_poseidon_hash_cuda(cScalars, cResults, cNumberOfStates, cArity, cConstants, cCfg) | ||
|
||
err := (cr.CudaError)(__ret) | ||
return core.FromCudaError(err) | ||
} | ||
|
||
func CreateOptimizedPoseidonConstants[T any](arity, fullRoundsHalfs, partialRounds int, constants core.HostOrDeviceSlice, ctx cr.DeviceContext, poseidonConstants *core.PoseidonConstants[T]) core.IcicleError { | ||
|
||
cArity := (C.int)(arity) | ||
cFullRoundsHalfs := (C.int)(fullRoundsHalfs) | ||
cPartialRounds := (C.int)(partialRounds) | ||
cConstants := (*C.scalar_t)(constants.AsUnsafePointer()) | ||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) | ||
cPoseidonConstants := (*C.PoseidonConstants)(unsafe.Pointer(poseidonConstants)) | ||
|
||
__ret := C.bls12_377_create_optimized_poseidon_constants_cuda(cArity, cFullRoundsHalfs, cPartialRounds, cConstants, cCtx, cPoseidonConstants) | ||
err := (cr.CudaError)(__ret) | ||
return core.FromCudaError(err) | ||
} | ||
|
||
func InitOptimizedPoseidonConstantsCuda[T any](arity int, ctx cr.DeviceContext, constants *core.PoseidonConstants[T]) core.IcicleError { | ||
|
||
cArity := (C.int)(arity) | ||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) | ||
cConstants := (*C.PoseidonConstants)(unsafe.Pointer(constants)) | ||
|
||
__ret := C.bls12_377_init_optimized_poseidon_constants_cuda(cArity, cCtx, cConstants) | ||
err := (cr.CudaError)(__ret) | ||
return core.FromCudaError(err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.