diff --git a/core/vm/contracts.go b/core/vm/contracts.go index d8f21dcd5..db6137722 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -67,33 +67,33 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{4}): &dataCopy{}, // Zama-specific contracts - common.BytesToAddress([]byte{65}): &fheAdd{}, - common.BytesToAddress([]byte{66}): &verifyCiphertext{}, - common.BytesToAddress([]byte{67}): &reencrypt{}, - common.BytesToAddress([]byte{68}): &fhePubKey{}, - common.BytesToAddress([]byte{70}): &fheLe{}, - common.BytesToAddress([]byte{71}): &fheSub{}, - common.BytesToAddress([]byte{72}): &fheMul{}, - common.BytesToAddress([]byte{73}): &fheLt{}, - common.BytesToAddress([]byte{74}): &fheRand{}, - common.BytesToAddress([]byte{75}): &optimisticRequire{}, - common.BytesToAddress([]byte{76}): &cast{}, - common.BytesToAddress([]byte{77}): &trivialEncrypt{}, - common.BytesToAddress([]byte{78}): &fheBitAnd{}, - common.BytesToAddress([]byte{79}): &fheBitOr{}, - common.BytesToAddress([]byte{80}): &fheBitXor{}, - common.BytesToAddress([]byte{81}): &fheEq{}, - common.BytesToAddress([]byte{82}): &fheGe{}, - common.BytesToAddress([]byte{83}): &fheGt{}, - common.BytesToAddress([]byte{84}): &fheShl{}, - common.BytesToAddress([]byte{85}): &fheShr{}, - common.BytesToAddress([]byte{86}): &fheNe{}, - common.BytesToAddress([]byte{87}): &fheMin{}, - common.BytesToAddress([]byte{88}): &fheMax{}, - common.BytesToAddress([]byte{89}): &fheNeg{}, - common.BytesToAddress([]byte{90}): &fheNot{}, - common.BytesToAddress([]byte{91}): &decrypt{}, - common.BytesToAddress([]byte{92}): &fheDiv{}, + common.BytesToAddress([]byte{65}): &fheAdd{}, // lib + common.BytesToAddress([]byte{66}): &verifyCiphertext{}, // lib + common.BytesToAddress([]byte{67}): &reencrypt{}, // lib + common.BytesToAddress([]byte{68}): &fhePubKey{}, // lib + common.BytesToAddress([]byte{70}): &fheLe{}, // lib + common.BytesToAddress([]byte{71}): &fheSub{}, // lib + common.BytesToAddress([]byte{72}): &fheMul{}, // lib + common.BytesToAddress([]byte{73}): &fheLt{}, // lib + common.BytesToAddress([]byte{74}): &fheRand{}, // lib + common.BytesToAddress([]byte{75}): &optimisticRequire{}, // lib + common.BytesToAddress([]byte{76}): &cast{}, // lib + common.BytesToAddress([]byte{77}): &trivialEncrypt{}, // lib + common.BytesToAddress([]byte{78}): &fheBitAnd{}, // lib + common.BytesToAddress([]byte{79}): &fheBitOr{}, // lib + common.BytesToAddress([]byte{80}): &fheBitXor{}, // lib + common.BytesToAddress([]byte{81}): &fheEq{}, // lib + common.BytesToAddress([]byte{82}): &fheGe{}, // lib + common.BytesToAddress([]byte{83}): &fheGt{}, // lib + common.BytesToAddress([]byte{84}): &fheShl{}, // lib + common.BytesToAddress([]byte{85}): &fheShr{}, // lib + common.BytesToAddress([]byte{86}): &fheNe{}, // lib + common.BytesToAddress([]byte{87}): &fheMin{}, // lib + common.BytesToAddress([]byte{88}): &fheMax{}, // lib + common.BytesToAddress([]byte{89}): &fheNeg{}, // lib + common.BytesToAddress([]byte{90}): &fheNot{}, // lib + common.BytesToAddress([]byte{91}): &decrypt{}, // lib + common.BytesToAddress([]byte{92}): &fheDiv{}, // lib common.BytesToAddress([]byte{93}): &fheLib{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -1476,6 +1476,38 @@ func writeResult(ct *tfheCiphertext, fileName string, logger Logger) { os.WriteFile("/tmp/"+fileName, ct.serialize(), 0644) } +func makeKeccakSignature(input string) uint32 { + return binary.BigEndian.Uint32(crypto.Keccak256([]byte(input))[0:4]) +} + +var signatureFheAdd = makeKeccakSignature("fheAdd(uint256,uint256,bytes1)") +var signatureFheSub = makeKeccakSignature("fheSub(uint256,uint256,bytes1)") +var signatureFheMul = makeKeccakSignature("fheMul(uint256,uint256,bytes1)") +var signatureFheLe = makeKeccakSignature("fheLe(uint256,uint256,bytes1)") +var signatureFheLt = makeKeccakSignature("fheLt(uint256,uint256,bytes1)") +var signatureFheEq = makeKeccakSignature("fheEq(uint256,uint256,bytes1)") +var signatureFheGe = makeKeccakSignature("fheGe(uint256,uint256,bytes1)") +var signatureFheGt = makeKeccakSignature("fheGt(uint256,uint256,bytes1)") +var signatureFheShl = makeKeccakSignature("fheShl(uint256,uint256,bytes1)") +var signatureFheShr = makeKeccakSignature("fheShr(uint256,uint256,bytes1)") +var signatureFheNe = makeKeccakSignature("fheNe(uint256,uint256,bytes1)") +var signatureFheMin = makeKeccakSignature("fheMin(uint256,uint256,bytes1)") +var signatureFheMax = makeKeccakSignature("fheMax(uint256,uint256,bytes1)") +var signatureFheNeg = makeKeccakSignature("fheNeg(uint256)") +var signatureFheNot = makeKeccakSignature("fheNot(uint256)") +var signatureFheDiv = makeKeccakSignature("fheDiv(uint256,uint256,bytes1)") +var signatureFheBitAnd = makeKeccakSignature("fheBitAnd(uint256,uint256,bytes1)") +var signatureFheBitOr = makeKeccakSignature("fheBitOr(uint256,uint256,bytes1)") +var signatureFheBitXor = makeKeccakSignature("fheBitXor(uint256,uint256,bytes1)") +var signatureFheRand = makeKeccakSignature("fheRand(bytes1)") +var signatureVerifyCiphertext = makeKeccakSignature("verifyCiphertext(bytes)") +var signatureReencrypt = makeKeccakSignature("reencrypt(uint256,uint256)") +var signatureFhePubKey = makeKeccakSignature("fhePubKey(bytes1)") +var signatureOptimisticRequire = makeKeccakSignature("optimisticRequire(uint256)") +var signatureCast = makeKeccakSignature("cast(uint256,bytes1)") +var signatureTrivialEncrypt = makeKeccakSignature("trivialEncrypt(uint256,bytes1)") +var signatureDecrypt = makeKeccakSignature("decrypt(uint256)") + type fheLib struct{} func (e *fheLib) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { @@ -1488,9 +1520,113 @@ func (e *fheLib) RequiredGas(accessibleState PrecompileAccessibleState, input [] signature := binary.BigEndian.Uint32(input[0:4]) switch signature { // first 4 bytes of keccak256('fheAdd(uint256,uint256,bytes1)') - case 0xf953e427: + case signatureFheAdd: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheAdd{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheSub(uint256,uint256,bytes1)') + case signatureFheSub: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheSub{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheMul(uint256,uint256,bytes1)') + case signatureFheMul: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheMul{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheLe(uint256,uint256,bytes1)') + case signatureFheLe: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheLe{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheLt(uint256,uint256,bytes1)') + case signatureFheLt: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheLt{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheEq(uint256,uint256,bytes1)') + case signatureFheEq: bwCompatBytes := input[4:minInt(69, len(input))] - return (*fheAdd)(nil).RequiredGas(accessibleState, bwCompatBytes) + return (&fheEq{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheGe(uint256,uint256,bytes1)') + case signatureFheGe: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheGe{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheGt(uint256,uint256,bytes1)') + case signatureFheGt: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheGt{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheShl(uint256,uint256,bytes1)') + case signatureFheShl: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheShl{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheShr(uint256,uint256,bytes1)') + case signatureFheShr: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheShr{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheNe(uint256,uint256,bytes1)') + case signatureFheNe: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheNe{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheMin(uint256,uint256,bytes1)') + case signatureFheMin: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheMin{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheMax(uint256,uint256,bytes1)') + case signatureFheMax: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheMax{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheNeg(uint256)') + case signatureFheNeg: + bwCompatBytes := input[4:minInt(36, len(input))] + return (&fheNeg{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheNot(uint256)') + case signatureFheNot: + bwCompatBytes := input[4:minInt(36, len(input))] + return (&fheNot{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheDiv(uint256,uint256,bytes1)') + case signatureFheDiv: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheDiv{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheBitAnd(uint256,uint256,bytes1)') + case signatureFheBitAnd: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheBitAnd{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheBitOr(uint256,uint256,bytes1)') + case signatureFheBitOr: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheBitOr{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheBitXor(uint256,uint256,bytes1)') + case signatureFheBitXor: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheBitXor{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fheRand(bytes1)') + case signatureFheRand: + bwCompatBytes := input[4:minInt(5, len(input))] + return (&fheRand{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('verifyCiphertext(bytes)') + case signatureVerifyCiphertext: + bwCompatBytes := input[4:] + return (&verifyCiphertext{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('reencrypt(uint256,uint256)') + case signatureReencrypt: + bwCompatBytes := input[4:minInt(68, len(input))] + return (&reencrypt{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('fhePubKey(bytes1)') + case signatureFhePubKey: + bwCompatBytes := input[4:minInt(5, len(input))] + return (&fhePubKey{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('optimisticRequire(uint256)') + case signatureOptimisticRequire: + bwCompatBytes := input[4:minInt(36, len(input))] + return (&optimisticRequire{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('cast(uint256,bytes1)') + case signatureCast: + bwCompatBytes := input[4:minInt(37, len(input))] + return (&cast{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('trivialEncrypt(uint256,bytes1)') + case signatureTrivialEncrypt: + bwCompatBytes := input[4:minInt(37, len(input))] + return (&trivialEncrypt{}).RequiredGas(accessibleState, bwCompatBytes) + // first 4 bytes of keccak256('decrypt(uint256)') + case signatureDecrypt: + bwCompatBytes := input[4:minInt(36, len(input))] + return (&decrypt{}).RequiredGas(accessibleState, bwCompatBytes) default: err := errors.New("precompile method not found") logger.Error("fheLib precompile error", "err", err, "input", hex.EncodeToString(input)) @@ -1508,10 +1644,144 @@ func (e *fheLib) Run(accessibleState PrecompileAccessibleState, caller common.Ad signature := binary.BigEndian.Uint32(input[0:4]) switch signature { // first 4 bytes of keccak256('fheAdd(uint256,uint256,bytes1)') - case 0xf953e427: + case signatureFheAdd: bwCompatBytes := input[4:minInt(69, len(input))] // state of fheAdd struct is never needed or accessed so we use nil - return (*fheAdd)(nil).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + return (&fheAdd{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheSub(uint256,uint256,bytes1)') + case signatureFheSub: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheSub{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheMul(uint256,uint256,bytes1)') + case signatureFheMul: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheMul{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheLe(uint256,uint256,bytes1)') + case signatureFheLe: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheLe{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheLt(uint256,uint256,bytes1)') + case signatureFheLt: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheLt{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheEq(uint256,uint256,bytes1)') + case signatureFheEq: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheEq{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheGe(uint256,uint256,bytes1)') + case signatureFheGe: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheGe{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheGt(uint256,uint256,bytes1)') + case signatureFheGt: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheGt{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheShl(uint256,uint256,bytes1)') + case signatureFheShl: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheShl{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheShr(uint256,uint256,bytes1)') + case signatureFheShr: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheShr{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheNe(uint256,uint256,bytes1)') + case signatureFheNe: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheNe{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheMin(uint256,uint256,bytes1)') + case signatureFheMin: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheMin{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheMax(uint256,uint256,bytes1)') + case signatureFheMax: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheMax{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheNeg(uint256)') + case signatureFheNeg: + bwCompatBytes := input[4:minInt(36, len(input))] + return (&fheNeg{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheNot(uint256)') + case signatureFheNot: + bwCompatBytes := input[4:minInt(36, len(input))] + return (&fheNot{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheDiv(uint256,uint256,bytes1)') + case signatureFheDiv: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheDiv{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheBitAnd(uint256,uint256,bytes1)') + case signatureFheBitAnd: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheBitAnd{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheBitOr(uint256,uint256,bytes1)') + case signatureFheBitOr: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheBitOr{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheBitXor(uint256,uint256,bytes1)') + case signatureFheBitXor: + bwCompatBytes := input[4:minInt(69, len(input))] + return (&fheBitXor{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('fheRand(bytes1)') + case signatureFheRand: + bwCompatBytes := input[4:minInt(5, len(input))] + return (&fheRand{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('verifyCiphertext(bytes)') + case signatureVerifyCiphertext: + // first 32 bytes of the payload is offset, then 32 bytes are size of byte array + if len(input) <= 68 { + err := errors.New("verifyCiphertext(bytes) must contain at least 68 bytes for selector, byte offset and size") + logger.Error("fheLib precompile error", "err", err, "input", hex.EncodeToString(input)) + return nil, err + } + bytesPaddingSize := 32 + bytesSizeSlotSize := 32 + // read only last 4 bytes of padded number for byte array size + sizeStart := 4 + bytesPaddingSize + bytesSizeSlotSize - 4 + sizeEnd := sizeStart + 4 + bytesSize := binary.BigEndian.Uint32(input[sizeStart:sizeEnd]) + bytesStart := 4 + bytesPaddingSize + bytesSizeSlotSize + bytesEnd := bytesStart + int(bytesSize) + bwCompatBytes := input[bytesStart:minInt(bytesEnd, len(input))] + return (&verifyCiphertext{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('reencrypt(uint256,uint256)') + case signatureReencrypt: + bwCompatBytes := input[4:minInt(68, len(input))] + precompileBytes, err := (&reencrypt{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + if err != nil { + return precompileBytes, err + } + // pad according to abi specification, first add offset to the dynamic bytes argument + outputBytes := make([]byte, 32, len(precompileBytes)+32) + outputBytes[31] = 0x20 + outputBytes = append(outputBytes, precompileBytes...) + return padArrayTo32Multiple(outputBytes), nil + // first 4 bytes of keccak256('fhePubKey(bytes1)') + case signatureFhePubKey: + bwCompatBytes := input[4:minInt(5, len(input))] + precompileBytes, err := (&fhePubKey{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + if err != nil { + return precompileBytes, err + } + // pad according to abi specification, first add offset to the dynamic bytes argument + outputBytes := make([]byte, 32, len(precompileBytes)+32) + outputBytes[31] = 0x20 + outputBytes = append(outputBytes, precompileBytes...) + return padArrayTo32Multiple(outputBytes), nil + // first 4 bytes of keccak256('optimisticRequire(uint256)') + case signatureOptimisticRequire: + bwCompatBytes := input[4:minInt(36, len(input))] + return (&optimisticRequire{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('cast(uint256,bytes1)') + case signatureCast: + bwCompatBytes := input[4:minInt(37, len(input))] + return (&cast{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('trivialEncrypt(uint256,bytes1)') + case signatureTrivialEncrypt: + bwCompatBytes := input[4:minInt(37, len(input))] + return (&trivialEncrypt{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) + // first 4 bytes of keccak256('decrypt(uint256)') + case signatureDecrypt: + bwCompatBytes := input[4:minInt(36, len(input))] + return (&decrypt{}).Run(accessibleState, caller, addr, bwCompatBytes, readOnly) default: err := errors.New("precompile method not found") logger.Error("fheLib precompile error", "err", err, "input", hex.EncodeToString(input)) @@ -1706,15 +1976,29 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller // Return a memory with a layout that matches the `bytes` EVM type, namely: // - 32 byte integer in big-endian order as length // - the actual bytes in the `bytes` value +// - add zero byte padding until nearest multiple of 32 func toEVMBytes(input []byte) []byte { - len := uint64(len(input)) - lenBytes32 := uint256.NewInt(len).Bytes32() - ret := make([]byte, 0, len+32) + arrLen := uint64(len(input)) + lenBytes32 := uint256.NewInt(arrLen).Bytes32() + ret := make([]byte, 0, arrLen+32) ret = append(ret, lenBytes32[:]...) ret = append(ret, input...) return ret } +// apply padding to slice to the multiple of 32 +func padArrayTo32Multiple(input []byte) []byte { + modRes := len(input) % 32 + if modRes > 0 { + padding := 32 - modRes + for padding > 0 { + padding-- + input = append(input, 0x0) + } + } + return input +} + type reencrypt struct{} func (e *reencrypt) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index 8d459c063..67f5bcff8 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "os" + "strings" "testing" "time" @@ -473,9 +474,8 @@ var scalarBytePadding = make([]byte, 31) func toLibPrecompileInput(method string, isScalar bool, hashes ...common.Hash) []byte { ret := make([]byte, 0) - state := crypto.NewKeccakState() - hashRes := crypto.HashData(state, []byte(method)) - signature := hashRes.Bytes()[0:4] + hashRes := crypto.Keccak256([]byte(method)) + signature := hashRes[0:4] ret = append(ret, signature...) for _, hash := range hashes { ret = append(ret, hash.Bytes()...) @@ -491,69 +491,1090 @@ func toLibPrecompileInput(method string, isScalar bool, hashes ...common.Hash) [ return ret } +func toLibPrecompileInputNoScalar(method string, hashes ...common.Hash) []byte { + ret := make([]byte, 0) + hashRes := crypto.Keccak256([]byte(method)) + signature := hashRes[0:4] + ret = append(ret, signature...) + for _, hash := range hashes { + ret = append(ret, hash.Bytes()...) + } + return ret +} + func VerifyCiphertext(t *testing.T, fheUintType fheUintType) { var value uint32 switch fheUintType { case FheUint8: - value = 2 + value = 2 + case FheUint16: + value = 4283 + case FheUint32: + value = 1333337 + } + c := &verifyCiphertext{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + compact := encryptAndSerializeCompact(value, fheUintType) + input := append(compact, byte(fheUintType)) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + ct := new(tfheCiphertext) + if err = ct.deserializeCompact(compact, fheUintType); err != nil { + t.Fatalf(err.Error()) + } + if common.BytesToHash(out) != ct.getHash() { + t.Fatalf("output hash in verifyCipertext is incorrect") + } + res := getVerifiedCiphertextFromEVM(state.interpreter, ct.getHash()) + if res == nil { + t.Fatalf("verifyCiphertext must have verified given ciphertext") + } +} + +func VerifyCiphertextBadType(t *testing.T, actualType fheUintType, metadataType fheUintType) { + var value uint32 + switch actualType { + case FheUint8: + value = 2 + case FheUint16: + value = 4283 + case FheUint32: + value = 1333337 + } + c := &verifyCiphertext{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + compact := encryptAndSerializeCompact(value, actualType) + input := append(compact, byte(metadataType)) + _, err := c.Run(state, addr, addr, input, readOnly) + if err == nil { + t.Fatalf("verifyCiphertext must have failed on type mismatch") + } + if len(state.interpreter.verifiedCiphertexts) != 0 { + t.Fatalf("verifyCiphertext mustn't have verified given ciphertext") + } +} + +func TrivialEncrypt(t *testing.T, fheUintType fheUintType) { + var value big.Int + switch fheUintType { + case FheUint8: + value = *big.NewInt(2) + case FheUint16: + value = *big.NewInt(4283) + case FheUint32: + value = *big.NewInt(1333337) + } + c := &trivialEncrypt{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + valueBytes := make([]byte, 32) + input := append(value.FillBytes(valueBytes), byte(fheUintType)) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + ct := new(tfheCiphertext).trivialEncrypt(value, fheUintType) + if common.BytesToHash(out) != ct.getHash() { + t.Fatalf("output hash in verifyCipertext is incorrect") + } + res := getVerifiedCiphertextFromEVM(state.interpreter, ct.getHash()) + if res == nil { + t.Fatalf("verifyCiphertext must have verified given ciphertext") + } +} + +func FheLibAdd(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs + rhs + c := &fheLib{} + signature := "fheAdd(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheLibSub(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs - rhs + c := &fheLib{} + signature := "fheSub(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheLibMul(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 3 + rhs = 2 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs * rhs + c := &fheLib{} + signature := "fheMul(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheLibLe(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + c := &fheLib{} + signature := "fheLe(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + + // lhs <= rhs + input1 := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } + + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + // rhs <= lhs + input2 := toLibPrecompileInput(signature, false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + } +} + +func FheLibLt(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + + c := &fheLib{} + signature := "fheLt(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + + // lhs < rhs + input1 := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } + + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + // rhs < lhs + input2 := toLibPrecompileInput(signature, false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + } +} + +func FheLibEq(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + c := &fheLib{} + signature := "fheLt(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + // lhs == rhs + input1 := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } +} + +func FheLibGe(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + c := &fheLib{} + signature := "fheGe(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + // lhs >= rhs + input1 := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + // rhs >= lhs + input2 := toLibPrecompileInput(signature, false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } + } +} + +func FheLibGt(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + + c := &fheLib{} + signature := "fheGt(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + // lhs > rhs + input1 := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } + + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + // rhs > lhs + input2 := toLibPrecompileInput(signature, false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 0 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0) + } + } +} + +func FheLibShl(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 2 + case FheUint32: + lhs = 1333337 + rhs = 3 + } + expected := lhs << rhs + c := &fheLib{} + signature := "fheShl(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheLibShr(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 2 + case FheUint32: + lhs = 1333337 + rhs = 3 + } + expected := lhs >> rhs + c := &fheLib{} + signature := "fheShr(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheLibNe(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + c := &fheLib{} + signature := "fheNe(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + // lhs == rhs + input1 := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input1, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != 1 { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1) + } +} + +func FheLibMin(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + + c := &fheLib{} + signature := "fheMin(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != rhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), rhs) + } + + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + input2 := toLibPrecompileInput(signature, false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != rhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), rhs) + } + } +} + +func FheLibMax(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + + c := &fheLib{} + signature := "fheMax(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != lhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), lhs) + } + + // Inverting operands is only possible in the non scalar case as scalar + // operators expect the scalar to be on the rhs. + if !scalar { + input2 := toLibPrecompileInput(signature, false, rhsHash, lhsHash) + out, err = c.Run(state, addr, addr, input2, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res = getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err = res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != lhs { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), lhs) + } + } +} + +func FheLibNeg(t *testing.T, fheUintType fheUintType) { + var pt, expected uint64 + switch fheUintType { + case FheUint8: + pt = 2 + expected = uint64(-uint8(pt)) + case FheUint16: + pt = 4283 + expected = uint64(-uint16(pt)) + case FheUint32: + pt = 1333337 + expected = uint64(-uint32(pt)) + } + + c := &fheLib{} + signature := "fheNeg(uint256)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + ptHash := verifyCiphertextInTestMemory(state.interpreter, pt, depth, fheUintType).getHash() + + input := toLibPrecompileInputNoScalar(signature, ptHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheLibNot(t *testing.T, fheUintType fheUintType) { + var pt, expected uint64 + switch fheUintType { + case FheUint8: + pt = 2 + expected = uint64(^uint8(pt)) + case FheUint16: + pt = 4283 + expected = uint64(^uint16(pt)) + case FheUint32: + pt = 1333337 + expected = uint64(^uint32(pt)) + } + + c := &fheLib{} + signature := "fheNot(uint256)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + ptHash := verifyCiphertextInTestMemory(state.interpreter, pt, depth, fheUintType).getHash() + + input := toLibPrecompileInputNoScalar(signature, ptHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } +} + +func FheLibDiv(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 4 + rhs = 2 + case FheUint16: + lhs = 721 + rhs = 1000 + case FheUint32: + lhs = 137 + rhs = 17 + } + expected := lhs / rhs + c := &fheLib{} + signature := "fheDiv(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if scalar { + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } + } else { + if err == nil { + t.Fatal("Non scalar multiplication should fail") + } + } +} + +func FheLibBitAnd(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs & rhs + c := &fheLib{} + signature := "fheBitAnd(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if scalar { + if err == nil { + t.Fatalf("scalar bit and should have failed") + } + } else { + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } + } +} + +func FheLibBitOr(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 + case FheUint16: + lhs = 4283 + rhs = 1337 + case FheUint32: + lhs = 1333337 + rhs = 133337 + } + expected := lhs | rhs + c := &fheLib{} + signature := "fheBitOr(uint256,uint256,bytes1)" + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + } + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if scalar { + if err == nil { + t.Fatalf("scalar bit or should have failed") + } + } else { + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } + } +} + +func FheLibBitXor(t *testing.T, fheUintType fheUintType, scalar bool) { + var lhs, rhs uint64 + switch fheUintType { + case FheUint8: + lhs = 2 + rhs = 1 case FheUint16: - value = 4283 + lhs = 4283 + rhs = 1337 case FheUint32: - value = 1333337 + lhs = 1333337 + rhs = 133337 } - c := &verifyCiphertext{} + expected := lhs ^ rhs + c := &fheLib{} + signature := "fheBitXor(uint256,uint256,bytes1)" depth := 1 state := newTestState() state.interpreter.evm.depth = depth addr := common.Address{} readOnly := false - compact := encryptAndSerializeCompact(value, fheUintType) - input := append(compact, byte(fheUintType)) - out, err := c.Run(state, addr, addr, input, readOnly) - if err != nil { - t.Fatalf(err.Error()) - } - ct := new(tfheCiphertext) - if err = ct.deserializeCompact(compact, fheUintType); err != nil { - t.Fatalf(err.Error()) - } - if common.BytesToHash(out) != ct.getHash() { - t.Fatalf("output hash in verifyCipertext is incorrect") + lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() + var rhsHash common.Hash + if scalar { + rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) + } else { + rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() } - res := getVerifiedCiphertextFromEVM(state.interpreter, ct.getHash()) - if res == nil { - t.Fatalf("verifyCiphertext must have verified given ciphertext") + input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + out, err := c.Run(state, addr, addr, input, readOnly) + if scalar { + if err == nil { + t.Fatalf("scalar bit xor should have failed") + } + } else { + if err != nil { + t.Fatalf(err.Error()) + } + res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) + if res == nil { + t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + } + decrypted, err := res.ciphertext.decrypt() + if err != nil || decrypted.Uint64() != expected { + t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) + } } } -func VerifyCiphertextBadType(t *testing.T, actualType fheUintType, metadataType fheUintType) { - var value uint32 - switch actualType { - case FheUint8: - value = 2 - case FheUint16: - value = 4283 - case FheUint32: - value = 1333337 - } - c := &verifyCiphertext{} +func FheLibRand(t *testing.T, fheUintType fheUintType) { + c := &fheLib{} + signature := "fheRand(bytes1)" depth := 1 state := newTestState() state.interpreter.evm.depth = depth addr := common.Address{} readOnly := false - compact := encryptAndSerializeCompact(value, actualType) - input := append(compact, byte(metadataType)) - _, err := c.Run(state, addr, addr, input, readOnly) - if err == nil { - t.Fatalf("verifyCiphertext must have failed on type mismatch") + hashRes := crypto.Keccak256([]byte(signature)) + signatureBytes := hashRes[0:4] + input := make([]byte, 0) + input = append(input, signatureBytes...) + input = append(input, byte(fheUintType)) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } else if len(out) != 32 { + t.Fatalf("fheRand expected output len of 32, got %v", len(out)) } - if len(state.interpreter.verifiedCiphertexts) != 0 { - t.Fatalf("verifyCiphertext mustn't have verified given ciphertext") + if len(state.interpreter.verifiedCiphertexts) != 1 { + t.Fatalf("fheRand expected 1 verified ciphertext") + } + + hash := common.BytesToHash(out) + _, err = state.interpreter.verifiedCiphertexts[hash].ciphertext.decrypt() + if err != nil { + t.Fatalf(err.Error()) } } -func TrivialEncrypt(t *testing.T, fheUintType fheUintType) { +func LibTrivialEncrypt(t *testing.T, fheUintType fheUintType) { var value big.Int switch fheUintType { case FheUint8: @@ -563,14 +1584,20 @@ func TrivialEncrypt(t *testing.T, fheUintType fheUintType) { case FheUint32: value = *big.NewInt(1333337) } - c := &trivialEncrypt{} + c := &fheLib{} + signature := "trivialEncrypt(uint256,bytes1)" + hashRes := crypto.Keccak256([]byte(signature)) + signatureBytes := hashRes[0:4] depth := 1 state := newTestState() state.interpreter.evm.depth = depth addr := common.Address{} readOnly := false valueBytes := make([]byte, 32) - input := append(value.FillBytes(valueBytes), byte(fheUintType)) + input := make([]byte, 0) + input = append(input, signatureBytes...) + input = append(input, value.FillBytes(valueBytes)...) + input = append(input, byte(fheUintType)) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -585,46 +1612,143 @@ func TrivialEncrypt(t *testing.T, fheUintType fheUintType) { } } -func FheLibAdd(t *testing.T, fheUintType fheUintType, scalar bool) { - var lhs, rhs uint64 +func LibDecrypt(t *testing.T, fheUintType fheUintType) { + var value uint64 switch fheUintType { case FheUint8: - lhs = 2 - rhs = 1 + value = 2 case FheUint16: - lhs = 4283 - rhs = 1337 + value = 4283 case FheUint32: - lhs = 1333337 - rhs = 133337 + value = 1333337 } - expected := lhs + rhs c := &fheLib{} - signature := "fheAdd(uint256,uint256,bytes1)" + signature := "decrypt(uint256)" + hashRes := crypto.Keccak256([]byte(signature)) + signatureBytes := hashRes[0:4] depth := 1 state := newTestState() state.interpreter.evm.depth = depth addr := common.Address{} readOnly := false - lhsHash := verifyCiphertextInTestMemory(state.interpreter, lhs, depth, fheUintType).getHash() - var rhsHash common.Hash - if scalar { - rhsHash = common.BytesToHash(big.NewInt(int64(rhs)).Bytes()) - } else { - rhsHash = verifyCiphertextInTestMemory(state.interpreter, rhs, depth, fheUintType).getHash() + input := make([]byte, 0) + hash := verifyCiphertextInTestMemory(state.interpreter, value, depth, fheUintType).getHash() + input = append(input, signatureBytes...) + input = append(input, hash.Bytes()...) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } else if len(out) != 32 { + t.Fatalf("decrypt expected output len of 32, got %v", len(out)) } - input := toLibPrecompileInput(signature, scalar, lhsHash, rhsHash) + result := big.Int{} + result.SetBytes(out) + if result.Uint64() != value { + t.Fatalf("decrypt result not equal to value, result %v != value %v", result.Uint64(), value) + } +} + +func TestLibVerifyCiphertextInvalidType(t *testing.T) { + c := &fheLib{} + signature := "verifyCiphertext(bytes)" + hashRes := crypto.Keccak256([]byte(signature)) + signatureBytes := hashRes[0:4] + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + invalidType := fheUintType(255) + input := make([]byte, 0) + input = append(input, signatureBytes...) + compact := encryptAndSerializeCompact(0, FheUint32) + input = append(input, compact...) + input = append(input, byte(invalidType)) + _, err := c.Run(state, addr, addr, input, readOnly) + if err == nil { + t.Fatalf("verifyCiphertext must have failed on invalid ciphertext type") + } + + if !strings.Contains(err.Error(), "ciphertext type is invalid") { + t.Fatalf("Unexpected test error: %s", err.Error()) + } +} + +func TestLibReencrypt(t *testing.T) { + c := &fheLib{} + signature := "reencrypt(uint256,uint256)" + hashRes := crypto.Keccak256([]byte(signature)) + signatureBytes := hashRes[0:4] + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + state.interpreter.evm.EthCall = true + toEncrypt := 7 + fheUintType := FheUint8 + encCiphertext := verifyCiphertextInTestMemory(state.interpreter, uint64(toEncrypt), depth, fheUintType).getHash() + addr := common.Address{} + readOnly := false + input := make([]byte, 0) + input = append(input, signatureBytes...) + input = append(input, encCiphertext.Bytes()...) + // just append twice not to generate public key + input = append(input, encCiphertext.Bytes()...) + _, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf("Reencrypt error: %s", err.Error()) + } +} +func TestLibOneTrueOptimisticRequire(t *testing.T) { + var value uint64 = 1 + c := &fheLib{} + signature := "optimisticRequire(uint256)" + hashRes := crypto.Keccak256([]byte(signature)) + signatureBytes := hashRes[0:4] + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + input := make([]byte, 0) + hash := verifyCiphertextInTestMemory(state.interpreter, value, depth, FheUint8).getHash() + input = append(input, signatureBytes...) + input = append(input, hash.Bytes()...) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) + } else if len(out) != 0 { + t.Fatalf("require expected output len of 0, got %v", len(out)) } - res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) - if res == nil { - t.Fatalf("output ciphertext is not found in verifiedCiphertexts") + // Call the interpreter with a single STOP opcode and expect that the optimistic require doesn't revert. + out, err = state.interpreter.Run(newStopOpcodeContract(), make([]byte, 0), readOnly) + if err != nil { + t.Fatalf(err.Error()) + } else if out != nil { + t.Fatalf("expected empty response") } - decrypted, err := res.ciphertext.decrypt() - if err != nil || decrypted.Uint64() != expected { - t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected) +} + +func TestLibCast(t *testing.T) { + c := &fheLib{} + signature := "cast(uint256,bytes1)" + hashRes := crypto.Keccak256([]byte(signature)) + signatureBytes := hashRes[0:4] + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + state.interpreter.evm.EthCall = true + toEncrypt := 7 + fheUintType := FheUint8 + encCiphertext := verifyCiphertextInTestMemory(state.interpreter, uint64(toEncrypt), depth, fheUintType).getHash() + addr := common.Address{} + readOnly := false + input := make([]byte, 0) + input = append(input, signatureBytes...) + input = append(input, encCiphertext.Bytes()...) + input = append(input, byte(FheUint32)) + _, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf("Reencrypt error: %s", err.Error()) } } @@ -1908,6 +3032,90 @@ func TestFheLibAdd8(t *testing.T) { FheLibAdd(t, FheUint8, false) } +func TestFheLibSub8(t *testing.T) { + FheLibSub(t, FheUint8, false) +} + +func TestFheLibMul8(t *testing.T) { + FheLibMul(t, FheUint8, false) +} + +func TestFheLibLe8(t *testing.T) { + FheLibLe(t, FheUint8, false) +} + +func TestFheLibLt8(t *testing.T) { + FheLibLt(t, FheUint8, false) +} + +func TestFheLibEq8(t *testing.T) { + FheLibEq(t, FheUint8, false) +} + +func TestFheLibGe8(t *testing.T) { + FheLibGe(t, FheUint8, false) +} + +func TestFheLibGt8(t *testing.T) { + FheLibGt(t, FheUint8, false) +} + +func TestFheLibShl8(t *testing.T) { + FheLibShl(t, FheUint8, false) +} + +func TestFheLibShr8(t *testing.T) { + FheLibShr(t, FheUint8, false) +} + +func TestFheLibNe8(t *testing.T) { + FheLibNe(t, FheUint8, false) +} + +func TestFheLibMin8(t *testing.T) { + FheLibMin(t, FheUint8, false) +} + +func TestFheLibMax8(t *testing.T) { + FheLibMax(t, FheUint8, false) +} + +func TestFheLibNeg8(t *testing.T) { + FheLibNeg(t, FheUint8) +} + +func TestFheLibNot8(t *testing.T) { + FheLibNot(t, FheUint8) +} + +func TestFheLibDiv8(t *testing.T) { + FheLibDiv(t, FheUint8, true) +} + +func TestFheLibBitAnd8(t *testing.T) { + FheLibBitAnd(t, FheUint8, false) +} + +func TestFheLibBitOr8(t *testing.T) { + FheLibBitOr(t, FheUint8, false) +} + +func TestFheLibBitXor8(t *testing.T) { + FheLibBitXor(t, FheUint8, false) +} + +func TestFheLibRand8(t *testing.T) { + FheLibRand(t, FheUint8) +} + +func TestFheLibTrivialEncrypt8(t *testing.T) { + LibTrivialEncrypt(t, FheUint8) +} + +func TestLibDecrypt8(t *testing.T) { + LibDecrypt(t, FheUint8) +} + func TestFheAdd8(t *testing.T) { FheAdd(t, FheUint8, false) }