Skip to content

Commit

Permalink
feat(vm): add remainder precompile
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Oct 5, 2023
1 parent 328447d commit 9fa664f
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 0 deletions.
84 changes: 84 additions & 0 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{91}): &decrypt{}, // lib
common.BytesToAddress([]byte{92}): &fheDiv{}, // lib
common.BytesToAddress([]byte{93}): &fheLib{},
common.BytesToAddress([]byte{94}): &fheRem{},
common.BytesToAddress([]byte{99}): &faucet{},
}

Expand Down Expand Up @@ -139,6 +140,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{91}): &decrypt{},
common.BytesToAddress([]byte{92}): &fheDiv{},
common.BytesToAddress([]byte{93}): &fheLib{},
common.BytesToAddress([]byte{94}): &fheRem{},
common.BytesToAddress([]byte{99}): &faucet{},
}

Expand Down Expand Up @@ -184,6 +186,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{91}): &decrypt{},
common.BytesToAddress([]byte{92}): &fheDiv{},
common.BytesToAddress([]byte{93}): &fheLib{},
common.BytesToAddress([]byte{94}): &fheRem{},
common.BytesToAddress([]byte{99}): &faucet{},
}

Expand Down Expand Up @@ -229,6 +232,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{91}): &decrypt{},
common.BytesToAddress([]byte{92}): &fheDiv{},
common.BytesToAddress([]byte{93}): &fheLib{},
common.BytesToAddress([]byte{94}): &fheRem{},
common.BytesToAddress([]byte{99}): &faucet{},
}

Expand Down Expand Up @@ -274,6 +278,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{91}): &decrypt{},
common.BytesToAddress([]byte{92}): &fheDiv{},
common.BytesToAddress([]byte{93}): &fheLib{},
common.BytesToAddress([]byte{94}): &fheRem{},
common.BytesToAddress([]byte{99}): &faucet{},
}

Expand Down Expand Up @@ -1418,6 +1423,12 @@ var fheDivGasCosts = map[fheUintType]uint64{
FheUint32: params.FheUint32DivGas,
}

var fheRemGasCosts = map[fheUintType]uint64{
FheUint8: params.FheUint8RemGas,
FheUint16: params.FheUint16RemGas,
FheUint32: params.FheUint32RemGas,
}

var fheShiftGasCosts = map[fheUintType]uint64{
FheUint8: params.FheUint8ShiftGas,
FheUint16: params.FheUint16ShiftGas,
Expand Down Expand Up @@ -1496,6 +1507,7 @@ var signatureFheMax = makeKeccakSignature("fheMax(uint256,uint256,bytes1)")
var signatureFheNeg = makeKeccakSignature("fheNeg(uint256)")
var signatureFheNot = makeKeccakSignature("fheNot(uint256)")
var signatureFheDiv = makeKeccakSignature("fheDiv(uint256,uint256,bytes1)")
var signatureFheRem = makeKeccakSignature("fheRem(uint256,uint256,bytes1)")
var signatureFheBitAnd = makeKeccakSignature("fheBitAnd(uint256,uint256,bytes1)")
var signatureFheBitOr = makeKeccakSignature("fheBitOr(uint256,uint256,bytes1)")
var signatureFheBitXor = makeKeccakSignature("fheBitXor(uint256,uint256,bytes1)")
Expand Down Expand Up @@ -1583,6 +1595,10 @@ func (e *fheLib) RequiredGas(accessibleState PrecompileAccessibleState, input []
case signatureFheDiv:
bwCompatBytes := input[4:minInt(69, len(input))]
return (&fheDiv{}).RequiredGas(accessibleState, bwCompatBytes)
// first 4 bytes of keccak256('fheRem(uint256,uint256,bytes1)')
case signatureFheRem:
bwCompatBytes := input[4:minInt(69, len(input))]
return (&fheRem{}).RequiredGas(accessibleState, bwCompatBytes)
// first 4 bytes of keccak256('fheBitAnd(uint256,uint256,bytes1)')
case signatureFheBitAnd:
bwCompatBytes := input[4:minInt(69, len(input))]
Expand Down Expand Up @@ -1708,6 +1724,10 @@ func (e *fheLib) Run(accessibleState PrecompileAccessibleState, caller common.Ad
case signatureFheDiv:
bwCompatBytes := input[4:minInt(69, len(input))]
return (&fheDiv{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly)
// first 4 bytes of keccak256('fheRem(uint256,uint256,bytes1)')
case signatureFheRem:
bwCompatBytes := input[4:minInt(69, len(input))]
return (&fheRem{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly)
// first 4 bytes of keccak256('fheBitAnd(uint256,uint256,bytes1)')
case signatureFheBitAnd:
bwCompatBytes := input[4:minInt(69, len(input))]
Expand Down Expand Up @@ -2591,6 +2611,70 @@ func (e *fheDiv) Run(accessibleState PrecompileAccessibleState, caller common.Ad
}
}

type fheRem struct{}

func (e *fheRem) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 {
logger := accessibleState.Interpreter().evm.Logger
isScalar, err := isScalarOp(accessibleState, input)
if err != nil {
logger.Error("fheRem RequiredGas() cannot detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
var lhs *verifiedCiphertext
if !isScalar {
logger.Error("fheRem RequiredGas() only scalar in division is supported, two ciphertexts received", "input", hex.EncodeToString(input))
return 0
} else {
lhs, _, err = getScalarOperands(accessibleState, input)
if err != nil {
logger.Error("fheRem RequiredGas() scalar inputs not verified", "err", err, "input", hex.EncodeToString(input))
return 0
}
}
return fheRemGasCosts[lhs.ciphertext.fheUintType]
}

func (e *fheRem) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := accessibleState.Interpreter().evm.Logger

isScalar, err := isScalarOp(accessibleState, input)
if err != nil {
logger.Error("fheRem cannot detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

if !isScalar {
err = errors.New("fheRem supports only scalar input operation, two ciphertexts received")
logger.Error("fheRem supports only scalar input operation, two ciphertexts received", "input", hex.EncodeToString(input))
return nil, err
} else {
lhs, rhs, err := getScalarOperands(accessibleState, input)
if err != nil {
logger.Error("fheRem scalar inputs not verified", "err", err, "input", hex.EncodeToString(input))
return nil, err
}

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil
}

result, err := lhs.ciphertext.scalarRem(rhs.Uint64())
if err != nil {
logger.Error("fheRem failed", "err", err)
return nil, err
}
importCiphertext(accessibleState, result)

// TODO: for testing
writeResult(result, "rem_scalar_result", logger)

resultHash := result.getHash()
logger.Info("fheRem scalar success", "lhs", lhs.ciphertext.getHash().Hex(), "rhs", rhs.Uint64(), "result", resultHash.Hex())
return resultHash[:], nil
}
}

type fheBitAnd struct{}

func (e *fheBitAnd) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 {
Expand Down
125 changes: 125 additions & 0 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,55 @@ func FheLibDiv(t *testing.T, fheUintType fheUintType, scalar bool) {
}
}

func FheLibRem(t *testing.T, fheUintType fheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 7
rhs = 3
case FheUint16:
lhs = 721
rhs = 1000
case FheUint32:
lhs = 1337
rhs = 73
}
expected := lhs % rhs
c := &fheLib{}
signature := "fheRem(uint256,uint256,bytes1)"
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash()
var rhsHash common.Hash
if scalar {
rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes())
} else {
rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash()
}
input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash)
out, err := c.Run(state, addr, addr, input, readOnly)
if scalar {
if err != nil {
t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.decrypt()
if err != nil || decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
} else {
if err == nil {
t.Fatal("Non scalar remainder should fail")
}
}
}

func FheLibBitAnd(t *testing.T, fheUintType fheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
Expand Down Expand Up @@ -1926,6 +1975,54 @@ func FheDiv(t *testing.T, fheUintType fheUintType, scalar bool) {
}
}

func FheRem(t *testing.T, fheUintType fheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
case FheUint8:
lhs = 9
rhs = 5
case FheUint16:
lhs = 1773
rhs = 523
case FheUint32:
lhs = 123765
rhs = 2179
}
expected := lhs % rhs
c := &fheRem{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash()
var rhsHash common.Hash
if scalar {
rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes())
} else {
rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash()
}
input := toPrecompileInput(scalar, lhsHash, rhsHash)
out, err := c.Run(state, addr, addr, input, readOnly)
if scalar {
if err != nil {
t.Fatalf(err.Error())
}
res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out))
if res == nil {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.decrypt()
if err != nil || decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
} else {
if err == nil {
t.Fatal("Non scalar remainder should fail")
}
}
}

func FheBitAnd(t *testing.T, fheUintType fheUintType, scalar bool) {
var lhs, rhs uint64
switch fheUintType {
Expand Down Expand Up @@ -3126,6 +3223,10 @@ func TestFheLibDiv8(t *testing.T) {
FheLibDiv(t, FheUint8, true)
}

func TestFheLibRem8(t *testing.T) {
FheLibRem(t, FheUint8, true)
}

func TestFheLibBitAnd8(t *testing.T) {
FheLibBitAnd(t, FheUint8, false)
}
Expand Down Expand Up @@ -3246,6 +3347,30 @@ func TestFheScalarDiv32(t *testing.T) {
FheDiv(t, FheUint32, true)
}

func TestFheRem8(t *testing.T) {
FheRem(t, FheUint8, false)
}

func TestFheRem16(t *testing.T) {
FheRem(t, FheUint16, false)
}

func TestFheRem32(t *testing.T) {
FheRem(t, FheUint32, false)
}

func TestFheScalarRem8(t *testing.T) {
FheRem(t, FheUint8, true)
}

func TestFheScalarRem16(t *testing.T) {
FheRem(t, FheUint16, true)
}

func TestFheScalarRem32(t *testing.T) {
FheRem(t, FheUint32, true)
}

func TestFheBitAnd8(t *testing.T) {
FheBitAnd(t, FheUint8, false)
}
Expand Down
46 changes: 46 additions & 0 deletions core/vm/tfhe.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,39 @@ void* scalar_div_fhe_uint32(void* ct, uint32_t pt, void* sks)
return result;
}
void* scalar_rem_fhe_uint8(void* ct, uint8_t pt, void* sks)
{
FheUint8* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint8_scalar_rem(ct, pt, &result);
if(r != 0) return NULL;
return result;
}
void* scalar_rem_fhe_uint16(void* ct, uint16_t pt, void* sks)
{
FheUint16* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint16_scalar_rem(ct, pt, &result);
if(r != 0) return NULL;
return result;
}
void* scalar_rem_fhe_uint32(void* ct, uint32_t pt, void* sks)
{
FheUint32* result = NULL;
checked_set_server_key(sks);
const int r = fhe_uint32_scalar_rem(ct, pt, &result);
if(r != 0) return NULL;
return result;
}
void* bitand_fhe_uint8(void* ct1, void* ct2, void* sks)
{
FheUint8* result = NULL;
Expand Down Expand Up @@ -2011,6 +2044,19 @@ func (lhs *tfheCiphertext) scalarDiv(rhs uint64) (*tfheCiphertext, error) {
})
}

func (lhs *tfheCiphertext) scalarRem(rhs uint64) (*tfheCiphertext, error) {
return lhs.executeBinaryScalarOperation(rhs,
func(lhs unsafe.Pointer, rhs C.uint8_t) unsafe.Pointer {
return C.scalar_rem_fhe_uint8(lhs, rhs, sks)
},
func(lhs unsafe.Pointer, rhs C.uint16_t) unsafe.Pointer {
return C.scalar_rem_fhe_uint16(lhs, rhs, sks)
},
func(lhs unsafe.Pointer, rhs C.uint32_t) unsafe.Pointer {
return C.scalar_rem_fhe_uint32(lhs, rhs, sks)
})
}

func (lhs *tfheCiphertext) bitand(rhs *tfheCiphertext) (*tfheCiphertext, error) {
return lhs.executeBinaryCiphertextOperation(rhs,
func(lhs unsafe.Pointer, rhs unsafe.Pointer) unsafe.Pointer {
Expand Down
Loading

0 comments on commit 9fa664f

Please sign in to comment.