Skip to content

Commit

Permalink
fix: add missing gas costs for loading ciphertexts
Browse files Browse the repository at this point in the history
  • Loading branch information
dartdart26 committed Jun 11, 2024
1 parent 09e1ffe commit ba51fe3
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 187 deletions.
4 changes: 2 additions & 2 deletions fhevm/ciphertext_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func loadCiphertext(env EVMEnvironment, handle common.Hash) (ct *tfhe.TfheCipher

metadataInt := newInt(env.GetState(ciphertextStorage, handle).Bytes())
if metadataInt.IsZero() {
return nil, 0
return nil, ColdSloadCostEIP2929
}
metadata := newCiphertextMetadata(metadataInt.Bytes32())
ctBytes := make([]byte, 0)
Expand All @@ -93,7 +93,7 @@ func loadCiphertext(env EVMEnvironment, handle common.Hash) (ct *tfhe.TfheCipher
err := ct.Deserialize(ctBytes, metadata.fheUintType)
if err != nil {
logger.Error("failed to deserialize ciphertext from storage", "err", err)
return nil, 0
return nil, ColdSloadCostEIP2929 + DeserializeCiphertextGas
}
env.FhevmData().loadedCiphertexts[handle] = ct
return ct, env.FhevmParams().GasCosts.FheStorageSloadGas[ct.Type()]
Expand Down
81 changes: 0 additions & 81 deletions fhevm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3474,23 +3474,6 @@ func FheArrayEqNoRhs(t *testing.T, fheUintType tfhe.FheUintType) {
}
}

func FheArrayEqNoRhsGas(t *testing.T, fheUintType tfhe.FheUintType) {
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth

lhs := make([]*big.Int, 3)
lhs[0] = loadCiphertextInTestMemory(environment, 1, depth, fheUintType).GetHash().Big()
lhs[1] = loadCiphertextInTestMemory(environment, 2, depth, fheUintType).GetHash().Big()
lhs[2] = loadCiphertextInTestMemory(environment, 3, depth, fheUintType).GetHash().Big()
input, _ := arrayEqMethod.Inputs.Pack(lhs)

gas := fheArrayEqRequiredGas(environment, input)
if gas != 0 {
t.Fatalf("fheArrayEq expected 0 gas value")
}
}

func TestFheArrayEqUnverifiedCtInLhs(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
Expand Down Expand Up @@ -3519,28 +3502,6 @@ func TestFheArrayEqUnverifiedCtInLhs(t *testing.T) {
}
}

func TestFheArrayEqUnverifiedCtInLhsGas(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth

lhs := make([]*big.Int, 3)
lhs[0] = loadCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big()
lhs[0].Add(lhs[0], big.NewInt(1))
lhs[1] = loadCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big()
lhs[2] = loadCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big()
rhs := make([]*big.Int, 3)
rhs[0] = loadCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big()
rhs[1] = loadCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big()
rhs[2] = loadCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big()
input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs)

gas := fheArrayEqRequiredGas(environment, input)
if gas != 0 {
t.Fatalf("fheArrayEq expected 0 gas value")
}
}

func TestFheArrayEqUnverifiedCtInRhs(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
Expand Down Expand Up @@ -3570,28 +3531,6 @@ func TestFheArrayEqUnverifiedCtInRhs(t *testing.T) {
}
}

func TestFheArrayEqUnverifiedCtInRhsGas(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth

lhs := make([]*big.Int, 3)
lhs[0] = loadCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big()
lhs[1] = loadCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big()
lhs[2] = loadCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big()
rhs := make([]*big.Int, 3)
rhs[0] = loadCiphertextInTestMemory(environment, 1, depth, tfhe.FheUint32).GetHash().Big()
rhs[1] = loadCiphertextInTestMemory(environment, 2, depth, tfhe.FheUint32).GetHash().Big()
rhs[1].Add(lhs[0], big.NewInt(1))
rhs[2] = loadCiphertextInTestMemory(environment, 3, depth, tfhe.FheUint32).GetHash().Big()
input, _ := arrayEqMethod.Inputs.Pack(lhs, rhs)

gas := fheArrayEqRequiredGas(environment, input)
if gas != 0 {
t.Fatalf("fheArrayEq expected 0 gas value")
}
}

func TestVerifyCiphertextInvalidType(t *testing.T) {
depth := 1
environment := newTestEVMEnvironment()
Expand Down Expand Up @@ -5164,23 +5103,3 @@ func TesFheArrayEqNoRhs32(t *testing.T) {
func TestFheArrayEqNoRhs64(t *testing.T) {
FheArrayEqNoRhs(t, tfhe.FheUint64)
}

func TestFheArrayEqNoRhsGas4(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint4)
}

func TestFheArrayEqNoRhsGas8(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint8)
}

func TestFheArrayEqNoRhsGas16(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint16)
}

func TesFheArrayEqNoRhsGas32(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint32)
}

func TestFheArrayEqNoRhsGas64(t *testing.T) {
FheArrayEqNoRhsGas(t, tfhe.FheUint64)
}
12 changes: 6 additions & 6 deletions fhevm/fhelib.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,11 @@ func load2Ciphertexts(environment EVMEnvironment, input []byte) (lhs *tfhe.TfheC
loadGasRhs := uint64(0)
lhs, loadGasLhs = loadCiphertext(environment, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, loadGasLhs, errors.New("unverified ciphertext handle")
}
rhs, loadGasRhs = loadCiphertext(environment, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, loadGasLhs + loadGasRhs, errors.New("unverified ciphertext handle")
}
err = nil
loadGas = loadGasLhs + loadGasRhs
Expand All @@ -337,15 +337,15 @@ func load3Ciphertexts(environment EVMEnvironment, input []byte) (first *tfhe.Tfh
loadGasThird := uint64(0)
first, loadGasFirst = loadCiphertext(environment, common.BytesToHash(input[0:32]))
if first == nil {
return nil, nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, nil, loadGasFirst, errors.New("unverified ciphertext handle")
}
second, loadGasSecond = loadCiphertext(environment, common.BytesToHash(input[32:64]))
if second == nil {
return nil, nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, nil, loadGasFirst + loadGasSecond, errors.New("unverified ciphertext handle")
}
third, loadGasThird = loadCiphertext(environment, common.BytesToHash(input[64:96]))
if third == nil {
return nil, nil, nil, 0, errors.New("unverified ciphertext handle")
return nil, nil, nil, loadGasFirst + loadGasSecond + loadGasThird, errors.New("unverified ciphertext handle")
}
err = nil
loadGas = loadGasFirst + loadGasSecond + loadGasThird
Expand All @@ -358,7 +358,7 @@ func getScalarOperands(environment EVMEnvironment, input []byte) (lhs *tfhe.Tfhe
}
lhs, loadGas = loadCiphertext(environment, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, nil, 0, errors.New("failed to load ciphertext")
return nil, nil, loadGas, errors.New("failed to load ciphertext")
}
rhs = &big.Int{}
rhs.SetBytes(input[32:64])
Expand Down
33 changes: 14 additions & 19 deletions fhevm/operators_arithmetic_gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ func fheAddSubRequiredGas(environment EVMEnvironment, input []byte) uint64 {
lhs, rhs, loadGas, err = load2Ciphertexts(environment, input)
if err != nil {
logger.Error("fheAdd/Sub RequiredGas() ciphertext failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
if lhs.Type() != rhs.Type() {
logger.Error("fheAdd/Sub RequiredGas() operand type mismatch", "lhs", lhs.Type(), "rhs", rhs.Type())
return 0
return loadGas
}
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
if err != nil {
logger.Error("fheAdd/Sub RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
}

Expand All @@ -47,24 +47,22 @@ func fheMulRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("fheMul RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
loadGas := uint64(0)
var lhs, rhs *tfhe.TfheCiphertext
if !isScalar {
lhs, rhs, loadGas, err = load2Ciphertexts(environment, input)
lhs, rhs, loadGas, err := load2Ciphertexts(environment, input)
if err != nil {
logger.Error("fheMul RequiredGas() ciphertext failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
if lhs.Type() != rhs.Type() {
logger.Error("fheMul RequiredGas() operand type mismatch", "lhs", lhs.Type(), "rhs", rhs.Type())
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheMul[lhs.Type()]
return environment.FhevmParams().GasCosts.FheMul[lhs.Type()] + loadGas
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
lhs, _, loadGas, err := getScalarOperands(environment, input)
if err != nil {
logger.Error("fheMul RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheScalarMul[lhs.Type()] + loadGas
}
Expand All @@ -79,16 +77,15 @@ func fheDivRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("fheDiv RequiredGas() cannot detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
loadGas := uint64(0)
var lhs *tfhe.TfheCiphertext

if !isScalar {
logger.Error("fheDiv RequiredGas() only scalar in division is supported, two ciphertexts received", "input", hex.EncodeToString(input))
return 0
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
lhs, _, loadGas, err := getScalarOperands(environment, input)
if err != nil {
logger.Error("fheDiv RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheScalarDiv[lhs.Type()] + loadGas
}
Expand All @@ -103,16 +100,14 @@ func fheRemRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("fheRem RequiredGas() cannot detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
var lhs *tfhe.TfheCiphertext
loadGas := uint64(0)
if !isScalar {
logger.Error("fheRem RequiredGas() only scalar in division is supported, two ciphertexts received", "input", hex.EncodeToString(input))
return 0
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
lhs, _, loadGas, err := getScalarOperands(environment, input)
if err != nil {
logger.Error("fheRem RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheScalarRem[lhs.Type()] + loadGas
}
Expand Down
23 changes: 10 additions & 13 deletions fhevm/operators_bit_gas.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/hex"

"github.com/ethereum/go-ethereum/common"
"github.com/zama-ai/fhevm-go/fhevm/tfhe"
)

func fheShlRequiredGas(environment EVMEnvironment, input []byte) uint64 {
Expand All @@ -16,24 +15,22 @@ func fheShlRequiredGas(environment EVMEnvironment, input []byte) uint64 {
logger.Error("fheShift RequiredGas() can not detect if operator is meant to be scalar", "err", err, "input", hex.EncodeToString(input))
return 0
}
var lhs, rhs *tfhe.TfheCiphertext
loadGas := uint64(0)
if !isScalar {
lhs, rhs, loadGas, err = load2Ciphertexts(environment, input)
lhs, rhs, loadGas, err := load2Ciphertexts(environment, input)
if err != nil {
logger.Error("fheShift RequiredGas() ciphertext failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
if lhs.Type() != rhs.Type() {
logger.Error("fheShift RequiredGas() operand type mismatch", "lhs", lhs.Type(), "rhs", rhs.Type())
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheShift[lhs.Type()]
return environment.FhevmParams().GasCosts.FheShift[lhs.Type()] + loadGas
} else {
lhs, _, loadGas, err = getScalarOperands(environment, input)
lhs, _, loadGas, err := getScalarOperands(environment, input)
if err != nil {
logger.Error("fheShift RequiredGas() scalar failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheScalarShift[lhs.Type()] + loadGas
}
Expand Down Expand Up @@ -65,7 +62,7 @@ func fheNegRequiredGas(environment EVMEnvironment, input []byte) uint64 {
ct, loadGas := loadCiphertext(environment, common.BytesToHash(input[0:32]))
if ct == nil {
logger.Error("fheNeg failed to load input", "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheNeg[ct.Type()] + loadGas
}
Expand All @@ -81,7 +78,7 @@ func fheNotRequiredGas(environment EVMEnvironment, input []byte) uint64 {
ct, loadGas := loadCiphertext(environment, common.BytesToHash(input[0:32]))
if ct == nil {
logger.Error("fheNot failed to load input", "input", hex.EncodeToString(input))
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheNot[ct.Type()] + loadGas
}
Expand All @@ -106,11 +103,11 @@ func fheBitAndRequiredGas(environment EVMEnvironment, input []byte) uint64 {
lhs, rhs, loadGas, err := load2Ciphertexts(environment, input)
if err != nil {
logger.Error("Bitwise op RequiredGas() failed to load inputs", "err", err, "input", hex.EncodeToString(input))
return 0
return loadGas
}
if lhs.Type() != rhs.Type() {
logger.Error("Bitwise op RequiredGas() operand type mismatch", "lhs", lhs.Type(), "rhs", rhs.Type())
return 0
return loadGas
}
return environment.FhevmParams().GasCosts.FheBitwiseOp[lhs.Type()] + loadGas
}
Expand Down
16 changes: 9 additions & 7 deletions fhevm/operators_comparison.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,20 +617,22 @@ func init() {
}
}

func getVerifiedCiphertexts(environment EVMEnvironment, unpacked interface{}) ([]*tfhe.TfheCiphertext, error) {
func getVerifiedCiphertexts(environment EVMEnvironment, unpacked interface{}) ([]*tfhe.TfheCiphertext, uint64, error) {
totalLoadGas := uint64(0)
big, ok := unpacked.([]*big.Int)
if !ok {
return nil, fmt.Errorf("fheArrayEq failed to cast to []*big.Int")
return nil, 0, fmt.Errorf("fheArrayEq failed to cast to []*big.Int")
}
ret := make([]*tfhe.TfheCiphertext, 0, len(big))
for _, b := range big {
ct, _ := loadCiphertext(environment, common.BigToHash(b))
ct, loadGas := loadCiphertext(environment, common.BigToHash(b))
if ct == nil {
return nil, fmt.Errorf("fheArrayEq unverified ciphertext")
return nil, totalLoadGas + loadGas, fmt.Errorf("fheArrayEq unverified ciphertext")
}
totalLoadGas += loadGas
ret = append(ret, ct)
}
return ret, nil
return ret, totalLoadGas, nil
}

func fheArrayEqRun(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool, runSpan trace.Span) ([]byte, error) {
Expand All @@ -649,14 +651,14 @@ func fheArrayEqRun(environment EVMEnvironment, caller common.Address, addr commo
return nil, err
}

lhs, err := getVerifiedCiphertexts(environment, unpacked[0])
lhs, _, err := getVerifiedCiphertexts(environment, unpacked[0])
if err != nil {
msg := "fheArrayEqRun failed to get lhs to verified ciphertexts"
logger.Error(msg, "err", err)
return nil, err
}

rhs, err := getVerifiedCiphertexts(environment, unpacked[1])
rhs, _, err := getVerifiedCiphertexts(environment, unpacked[1])
if err != nil {
msg := "fheArrayEqRun failed to get rhs to verified ciphertexts"
logger.Error(msg, "err", err)
Expand Down
Loading

0 comments on commit ba51fe3

Please sign in to comment.