Skip to content

Commit

Permalink
feat: add support for casting (#118)
Browse files Browse the repository at this point in the history
* feat(tfhe): add support for casting

* feat: add `cast` precompile

* feat(cast): add type validity check
  • Loading branch information
tremblaythibaultl authored Jun 22, 2023
1 parent 4818f9c commit 1038691
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 18 deletions.
71 changes: 55 additions & 16 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
// common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
common.BytesToAddress([]byte{99}): &faucet{},
}
Expand Down Expand Up @@ -106,7 +106,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
// common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
common.BytesToAddress([]byte{99}): &faucet{},
}
Expand Down Expand Up @@ -136,7 +136,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
// common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
common.BytesToAddress([]byte{99}): &faucet{},
}
Expand Down Expand Up @@ -166,7 +166,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
// common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
common.BytesToAddress([]byte{99}): &faucet{},
}
Expand Down Expand Up @@ -196,7 +196,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{73}): &fheLt{},
// common.BytesToAddress([]byte{74}): &fheRand{},
common.BytesToAddress([]byte{75}): &optimisticRequire{},
// common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{76}): &cast{},
common.BytesToAddress([]byte{77}): &trivialEncrypt{},
common.BytesToAddress([]byte{99}): &faucet{},
}
Expand Down Expand Up @@ -1438,7 +1438,7 @@ func (e *verifyCiphertext) RequiredGas(accessibleState PrecompileAccessibleState
func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := accessibleState.Interpreter().evm.Logger
if len(input) <= 1 {
msg := "verifyCiphertext RequiredGas() input needs to contain a ciphertext and one byte for its type"
msg := "verifyCiphertext Run() input needs to contain a ciphertext and one byte for its type"
logger.Error(msg, "len", len(input))
return nil, errors.New(msg)
}
Expand Down Expand Up @@ -2026,18 +2026,57 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add
// return ctHash[:], nil
// }

// type cast struct{}
type cast struct{}

// func (e *cast) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 {
// return 0
// }
func (e *cast) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 {
if len(input) != 33 {
accessibleState.Interpreter().evm.Logger.Error(
"cast RequiredGas() input needs to contain a ciphertext and one byte for its type",
"len", len(input))
return 0
}
return params.FheCastGas
}

// // Implementation of the following is pending and will be completed once TFHE-rs add type casts to their high-level C API.
// func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
// // var ctHandle = common.BytesToHash(input[0:31])
// // var toType = input[32]
// return nil, nil
// }
// Implementation of the following is pending and will be completed once TFHE-rs add type casts to their high-level C API.
func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := accessibleState.Interpreter().evm.Logger
if len(input) != 33 {
msg := "cast Run() input needs to contain a ciphertext and one byte for its type"
logger.Error(msg, "len", len(input))
return nil, errors.New(msg)
}

ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if ct == nil {
logger.Error("cast input not verified")
return nil, errors.New("unverified ciphertext handle")
}

castToType := fheUintType(input[32])
if !castToType.isValid() {
logger.Error("invalid type to cast to")
return nil, errors.New("invalid type provided")
}

res, err := ct.ciphertext.castTo(castToType)
if err != nil {
msg := "cast Run() error casting ciphertext to"
logger.Error(msg, "type", castToType)
return nil, errors.New(msg)
}

resHash := res.getHash()

importCiphertext(accessibleState, res)
if accessibleState.Interpreter().evm.Commit {
logger.Info("cast success",
"ctHash", resHash.Hex(),
)
}

return resHash.Bytes(), nil
}

type faucet struct{}

Expand Down
107 changes: 107 additions & 0 deletions core/vm/tfhe.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,66 @@ void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value,
assert(r == 0);
}
void* cast_8_16(void* ct, void* sks) {
FheUint16* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint8_cast_into_fhe_uint16(ct, &result);
assert(r == 0);
return result;
}
void* cast_8_32(void* ct, void* sks) {
FheUint32* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint8_cast_into_fhe_uint32(ct, &result);
assert(r == 0);
return result;
}
void* cast_16_8(void* ct, void* sks) {
FheUint8* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint16_cast_into_fhe_uint8(ct, &result);
assert(r == 0);
return result;
}
void* cast_16_32(void* ct, void* sks) {
FheUint32* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint16_cast_into_fhe_uint32(ct, &result);
assert(r == 0);
return result;
}
void* cast_32_8(void* ct, void* sks) {
FheUint8* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint32_cast_into_fhe_uint8(ct, &result);
assert(r == 0);
return result;
}
void* cast_32_16(void* ct, void* sks) {
FheUint16* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint32_cast_into_fhe_uint16(ct, &result);
assert(r == 0);
return result;
}
*/
import "C"

Expand Down Expand Up @@ -810,6 +870,49 @@ func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) (*tfheCiphertext, error) {
return res, nil
}

func (ct *tfheCiphertext) castTo(castToType fheUintType) (*tfheCiphertext, error) {
if !ct.availableForOps() {
panic("cannot cast a non-initialized ciphertext")
}

if ct.fheUintType == castToType {
return nil, errors.New("casting to same type is not supported")
}

if !castToType.isValid() {
return nil, errors.New("invalid type to cast to")
}

res := new(tfheCiphertext)
res.fheUintType = castToType

switch ct.fheUintType {
case FheUint8:
switch castToType {
case FheUint16:
res.setPtr(C.cast_8_16(ct.ptr, sks))
case FheUint32:
res.setPtr(C.cast_8_32(ct.ptr, sks))
}
case FheUint16:
switch castToType {
case FheUint8:
res.setPtr(C.cast_16_8(ct.ptr, sks))
case FheUint32:
res.setPtr(C.cast_16_32(ct.ptr, sks))
}
case FheUint32:
switch castToType {
case FheUint8:
res.setPtr(C.cast_32_8(ct.ptr, sks))
case FheUint16:
res.setPtr(C.cast_32_16(ct.ptr, sks))
}
}

return res, nil
}

func (ct *tfheCiphertext) decrypt() big.Int {
if !ct.availableForOps() {
panic("cannot decrypt a null ciphertext")
Expand Down Expand Up @@ -869,6 +972,10 @@ func (ct *tfheCiphertext) initialized() bool {
return (ct.ptr != nil)
}

func (t *fheUintType) isValid() bool {
return (*t <= 2)
}

// Used for testing.
func encryptAndSerializeCompact(value uint32, fheUintType fheUintType) []byte {
out := &C.Buffer{}
Expand Down
67 changes: 65 additions & 2 deletions core/vm/tfhe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package vm

import (
"bytes"
"math"
"math/big"
"testing"
)
Expand Down Expand Up @@ -301,9 +302,9 @@ func TfheLt(t *testing.T, fheUintType fheUintType) {
b.SetUint64(133337)
}
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint8)
ctA.encrypt(a, fheUintType)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint8)
ctB.encrypt(b, fheUintType)
ctRes1, _ := ctA.lte(ctB)
ctRes2, _ := ctB.lte(ctA)
res1 := ctRes1.decrypt()
Expand All @@ -316,6 +317,44 @@ func TfheLt(t *testing.T, fheUintType fheUintType) {
}
}

func TfheCast(t *testing.T, fheUintTypeFrom fheUintType, fheUintTypeTo fheUintType) {
var a big.Int
switch fheUintTypeFrom {
case FheUint8:
a.SetUint64(2)
case FheUint16:
a.SetUint64(4283)
case FheUint32:
a.SetUint64(1333337)
}

var modulus uint64
switch fheUintTypeTo {
case FheUint8:
modulus = uint64(math.Pow(2, 8))
case FheUint16:
modulus = uint64(math.Pow(2, 16))
case FheUint32:
modulus = uint64(math.Pow(2, 32))
}

ctA := new(tfheCiphertext)
ctA.encrypt(a, fheUintTypeFrom)
ctRes, err := ctA.castTo(fheUintTypeTo)
if err != nil {
t.Fatal(err)
}

if ctRes.fheUintType != fheUintTypeTo {
t.Fatalf("type %d != type %d", ctA.fheUintType, fheUintTypeTo)
}
res := ctRes.decrypt()
expected := a.Uint64() % modulus
if res.Uint64() != expected {
t.Fatalf("%d != %d", res.Uint64(), expected)
}
}

func TestTfheEncryptDecrypt8(t *testing.T) {
TfheEncryptDecrypt(t, FheUint8)
}
Expand Down Expand Up @@ -470,3 +509,27 @@ func TestTfheLte32(t *testing.T) {
func TestTfheLt32(t *testing.T) {
TfheLte(t, FheUint32)
}

func TestTfhe8Cast16(t *testing.T) {
TfheCast(t, FheUint8, FheUint16)
}

func TestTfhe8Cast32(t *testing.T) {
TfheCast(t, FheUint8, FheUint32)
}

func TestTfhe16Cast8(t *testing.T) {
TfheCast(t, FheUint16, FheUint8)
}

func TestTfhe16Cast32(t *testing.T) {
TfheCast(t, FheUint16, FheUint32)
}

func TestTfhe32Cast8(t *testing.T) {
TfheCast(t, FheUint16, FheUint8)
}

func TestTfhe32Cast16(t *testing.T) {
TfheCast(t, FheUint16, FheUint8)
}
2 changes: 2 additions & 0 deletions params/protocol_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ const (
FheUint16ProtectedStorageSloadGas uint64 = FheUint8ProtectedStorageSloadGas * 2
FheUint32ProtectedStorageSloadGas uint64 = FheUint16ProtectedStorageSloadGas * 4

FheCastGas uint64 = 100

FhePubKeyGas uint64 = 2

FheUint8TrivialEncryptGas uint64 = 100
Expand Down

0 comments on commit 1038691

Please sign in to comment.