diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE/bug.md
similarity index 50%
rename from .github/ISSUE_TEMPLATE.md
rename to .github/ISSUE_TEMPLATE/bug.md
index 59285e456d71..c5a3654bde61 100644
--- a/.github/ISSUE_TEMPLATE.md
+++ b/.github/ISSUE_TEMPLATE/bug.md
@@ -1,8 +1,10 @@
-Hi there,
-
-Please note that this is an issue tracker reserved for bug reports and feature requests.
-
-For general questions please use [discord](https://discord.gg/nthXNEv) or the Ethereum stack exchange at https://ethereum.stackexchange.com.
+---
+name: Report a bug
+about: Something with go-ethereum is not working as expected
+title: ''
+labels: 'type:bug'
+assignees: ''
+---

 #### System information

diff --git a/.github/ISSUE_TEMPLATE/feature.md b/.github/ISSUE_TEMPLATE/feature.md
new file mode 100644
index 000000000000..aacd885f9e5e
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature.md
@@ -0,0 +1,17 @@
+---
+name: Request a feature
+about: Report a missing feature - e.g. as a step before submitting a PR
+title: ''
+labels: 'type:feature'
+assignees: ''
+---
+
+# Rationale
+
+Why should this feature exist?
+What are the use-cases?
+
+# Implementation
+
+Do you have ideas regarding the implementation of this feature?
+Are you willing to implement this feature?
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md
new file mode 100644
index 000000000000..8f460ab558ec
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/question.md
@@ -0,0 +1,9 @@
+---
+name: Ask a question
+about: Something is unclear
+title: ''
+labels: 'type:docs'
+assignees: ''
+---
+
+This should only be used in very rare cases, e.g. if you are not 100% sure whether something is a bug, or if your question will lead to improving the documentation. For general questions please use [discord](https://discord.gg/nthXNEv) or the Ethereum stack exchange at https://ethereum.stackexchange.com.
diff --git a/.travis.yml b/.travis.yml
index 245f52b362a7..1268c6d65724 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -24,69 +24,6 @@ jobs:
       script:
         - go run build/ci.go lint

-    - stage: build
-      os: linux
-      dist: xenial
-      go: 1.13.x
-      env:
-        - GO111MODULE=on
-      script:
-        - go run build/ci.go install
-        - go run build/ci.go test -coverage $TEST_PACKAGES
-
-    - stage: build
-      os: linux
-      dist: xenial
-      go: 1.14.x
-      env:
-        - GO111MODULE=on
-      script:
-        - go run build/ci.go install
-        - go run build/ci.go test -coverage $TEST_PACKAGES
-
-    # These are the latest Go versions.
- - stage: build - os: linux - arch: amd64 - dist: xenial - go: 1.15.x - env: - - GO111MODULE=on - script: - - go run build/ci.go install - - go run build/ci.go test -coverage $TEST_PACKAGES - - - stage: build - if: type = pull_request - os: linux - arch: arm64 - dist: xenial - go: 1.15.x - env: - - GO111MODULE=on - script: - - go run build/ci.go install - - go run build/ci.go test -coverage $TEST_PACKAGES - - - stage: build - os: osx - osx_image: xcode11.3 - go: 1.15.x - env: - - GO111MODULE=on - script: - - echo "Increase the maximum number of open file descriptors on macOS" - - NOFILE=20480 - - sudo sysctl -w kern.maxfiles=$NOFILE - - sudo sysctl -w kern.maxfilesperproc=$NOFILE - - sudo launchctl limit maxfiles $NOFILE $NOFILE - - sudo launchctl limit maxfiles - - ulimit -S -n $NOFILE - - ulimit -n - - unset -f cd # workaround for https://github.com/travis-ci/travis-ci/issues/8703 - - go run build/ci.go install - - go run build/ci.go test -coverage $TEST_PACKAGES - # This builder does the Ubuntu PPA upload - stage: build if: type = push @@ -109,7 +46,7 @@ jobs: - python-paramiko script: - echo '|1|7SiYPr9xl3uctzovOTj4gMwAC1M=|t6ReES75Bo/PxlOPJ6/GsGbTrM0= ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEA0aKz5UTUndYgIGG7dQBV+HaeuEZJ2xPHo2DS2iSKvUL4xNMSAY4UguNW+pX56nAQmZKIZZ8MaEvSj6zMEDiq6HFfn5JcTlM80UwlnyKe8B8p7Nk06PPQLrnmQt5fh0HmEcZx+JU9TZsfCHPnX7MNz4ELfZE6cFsclClrKim3BHUIGq//t93DllB+h4O9LHjEUsQ1Sr63irDLSutkLJD6RXchjROXkNirlcNVHH/jwLWR5RcYilNX7S5bIkK8NlWPjsn/8Ua5O7I9/YoE97PpO6i73DTGLh5H9JN/SITwCKBkgSDWUt61uPK3Y11Gty7o2lWsBjhBUm2Y38CBsoGmBw==' >> ~/.ssh/known_hosts - - go run build/ci.go debsrc -goversion 1.15 -upload ethereum/ethereum -sftp-user geth-ci -signer "Go Ethereum Linux Builder " + - go run build/ci.go debsrc -upload ethereum/ethereum -sftp-user geth-ci -signer "Go Ethereum Linux Builder " # This builder does the Linux Azure uploads - stage: build @@ -129,23 +66,23 @@ jobs: - gcc-multilib script: # Build for the primary platforms that Trusty can manage - - go run build/ci.go install - - go run build/ci.go archive -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds - - go run build/ci.go install -arch 386 - - go run build/ci.go archive -arch 386 -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds + - go run build/ci.go install -dlgo + - go run build/ci.go archive -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds + - go run build/ci.go install -dlgo -arch 386 + - go run build/ci.go archive -arch 386 -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds # Switch over GCC to cross compilation (breaks 386, hence why do it here only) - sudo -E apt-get -yq --no-install-suggests --no-install-recommends --force-yes install gcc-arm-linux-gnueabi libc6-dev-armel-cross gcc-arm-linux-gnueabihf libc6-dev-armhf-cross gcc-aarch64-linux-gnu libc6-dev-arm64-cross - sudo ln -s /usr/include/asm-generic /usr/include/asm - - GOARM=5 go run build/ci.go install -arch arm -cc arm-linux-gnueabi-gcc - - GOARM=5 go run build/ci.go archive -arch arm -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds - - GOARM=6 go run build/ci.go install -arch arm -cc arm-linux-gnueabi-gcc - - GOARM=6 go run build/ci.go archive -arch arm -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds - - GOARM=7 go run build/ci.go install -arch arm -cc arm-linux-gnueabihf-gcc - - GOARM=7 go run build/ci.go archive -arch arm -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds - - go run build/ci.go install -arch arm64 -cc aarch64-linux-gnu-gcc 
-        - go run build/ci.go archive -arch arm64 -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds
+        - GOARM=5 go run build/ci.go install -dlgo -arch arm -cc arm-linux-gnueabi-gcc
+        - GOARM=5 go run build/ci.go archive -arch arm -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds
+        - GOARM=6 go run build/ci.go install -dlgo -arch arm -cc arm-linux-gnueabi-gcc
+        - GOARM=6 go run build/ci.go archive -arch arm -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds
+        - GOARM=7 go run build/ci.go install -dlgo -arch arm -cc arm-linux-gnueabihf-gcc
+        - GOARM=7 go run build/ci.go archive -arch arm -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds
+        - go run build/ci.go install -dlgo -arch arm64 -cc aarch64-linux-gnu-gcc
+        - go run build/ci.go archive -arch arm64 -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds

     # This builder does the Linux Azure MIPS xgo uploads
     - stage: build
@@ -163,19 +100,19 @@ jobs:
       script:
         - go run build/ci.go xgo --alltools -- --targets=linux/mips --ldflags '-extldflags "-static"' -v
         - for bin in build/bin/*-linux-mips; do mv -f "${bin}" "${bin/-linux-mips/}"; done
-        - go run build/ci.go archive -arch mips -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds
+        - go run build/ci.go archive -arch mips -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds

         - go run build/ci.go xgo --alltools -- --targets=linux/mipsle --ldflags '-extldflags "-static"' -v
         - for bin in build/bin/*-linux-mipsle; do mv -f "${bin}" "${bin/-linux-mipsle/}"; done
-        - go run build/ci.go archive -arch mipsle -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds
+        - go run build/ci.go archive -arch mipsle -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds

         - go run build/ci.go xgo --alltools -- --targets=linux/mips64 --ldflags '-extldflags "-static"' -v
         - for bin in build/bin/*-linux-mips64; do mv -f "${bin}" "${bin/-linux-mips64/}"; done
-        - go run build/ci.go archive -arch mips64 -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds
+        - go run build/ci.go archive -arch mips64 -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds

         - go run build/ci.go xgo --alltools -- --targets=linux/mips64le --ldflags '-extldflags "-static"' -v
         - for bin in build/bin/*-linux-mips64le; do mv -f "${bin}" "${bin/-linux-mips64le/}"; done
-        - go run build/ci.go archive -arch mips64le -type tar -signer LINUX_SIGNING_KEY -upload gethstore/builds
+        - go run build/ci.go archive -arch mips64le -type tar -signer LINUX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds

     # This builder does the Android Maven and Azure uploads
     - stage: build
@@ -202,7 +139,7 @@ jobs:
       git:
         submodules: false # avoid cloning ethereum/tests
       before_install:
-        - curl https://dl.google.com/go/go1.15.linux-amd64.tar.gz | tar -xz
+        - curl https://dl.google.com/go/go1.15.5.linux-amd64.tar.gz | tar -xz
         - export PATH=`pwd`/go/bin:$PATH
         - export GOROOT=`pwd`/go
         - export GOPATH=$HOME/go
@@ -214,7 +151,7 @@ jobs:
         - mkdir -p $GOPATH/src/github.com/ethereum
         - ln -s `pwd` $GOPATH/src/github.com/ethereum/go-ethereum

-        - go run build/ci.go aar -signer ANDROID_SIGNING_KEY -deploy https://oss.sonatype.org -upload gethstore/builds
+        - go run build/ci.go aar -signer ANDROID_SIGNING_KEY -signify SIGNIFY_KEY -deploy https://oss.sonatype.org -upload gethstore/builds

     # This builder does the OSX Azure, iOS CocoaPods and iOS Azure uploads
     -
stage: build @@ -229,8 +166,8 @@ jobs: git: submodules: false # avoid cloning ethereum/tests script: - - go run build/ci.go install - - go run build/ci.go archive -type tar -signer OSX_SIGNING_KEY -upload gethstore/builds + - go run build/ci.go install -dlgo + - go run build/ci.go archive -type tar -signer OSX_SIGNING_KEY -signify SIGNIFY_KEY -upload gethstore/builds # Build the iOS framework and upload it to CocoaPods and Azure - gem uninstall cocoapods -a -x @@ -245,7 +182,38 @@ jobs: # Workaround for https://github.com/golang/go/issues/23749 - export CGO_CFLAGS_ALLOW='-fmodules|-fblocks|-fobjc-arc' - - go run build/ci.go xcode -signer IOS_SIGNING_KEY -deploy trunk -upload gethstore/builds + - go run build/ci.go xcode -signer IOS_SIGNING_KEY -signify SIGNIFY_KEY -deploy trunk -upload gethstore/builds + + # These builders run the tests + - stage: build + os: linux + arch: amd64 + dist: xenial + go: 1.15.x + env: + - GO111MODULE=on + script: + - go run build/ci.go test -coverage $TEST_PACKAGES + + - stage: build + if: type = pull_request + os: linux + arch: arm64 + dist: xenial + go: 1.15.x + env: + - GO111MODULE=on + script: + - go run build/ci.go test -coverage $TEST_PACKAGES + + - stage: build + os: linux + dist: xenial + go: 1.14.x + env: + - GO111MODULE=on + script: + - go run build/ci.go test -coverage $TEST_PACKAGES # This builder does the Azure archive purges to avoid accumulating junk - stage: build diff --git a/Makefile b/Makefile index 67095f4d006e..ecee5f189435 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,9 @@ android: $(GORUN) build/ci.go aar --local @echo "Done building." @echo "Import \"$(GOBIN)/geth.aar\" to use the library." - + @echo "Import \"$(GOBIN)/geth-sources.jar\" to add javadocs" + @echo "For more info see https://stackoverflow.com/questions/20994336/android-studio-how-to-attach-javadoc" + ios: $(GORUN) build/ci.go xcode --local @echo "Done building." diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go index 38196ed25c88..5ca7c241db95 100644 --- a/accounts/abi/abi.go +++ b/accounts/abi/abi.go @@ -80,36 +80,56 @@ func (abi ABI) Pack(name string, args ...interface{}) ([]byte, error) { return append(method.ID, arguments...), nil } -// Unpack output in v according to the abi specification. -func (abi ABI) Unpack(v interface{}, name string, data []byte) (err error) { +func (abi ABI) getArguments(name string, data []byte) (Arguments, error) { // since there can't be naming collisions with contracts and events, // we need to decide whether we're calling a method or an event + var args Arguments if method, ok := abi.Methods[name]; ok { if len(data)%32 != 0 { - return fmt.Errorf("abi: improperly formatted output: %s - Bytes: [%+v]", string(data), data) + return nil, fmt.Errorf("abi: improperly formatted output: %s - Bytes: [%+v]", string(data), data) } - return method.Outputs.Unpack(v, data) + args = method.Outputs } if event, ok := abi.Events[name]; ok { - return event.Inputs.Unpack(v, data) + args = event.Inputs } - return fmt.Errorf("abi: could not locate named method or event") + if args == nil { + return nil, errors.New("abi: could not locate named method or event") + } + return args, nil +} + +// Unpack unpacks the output according to the abi specification. +func (abi ABI) Unpack(name string, data []byte) ([]interface{}, error) { + args, err := abi.getArguments(name, data) + if err != nil { + return nil, err + } + return args.Unpack(data) +} + +// UnpackIntoInterface unpacks the output in v according to the abi specification. 
+// It performs an additional copy. Please only use this if you want to unpack into a
+// structure that does not strictly conform to the abi structure (e.g. has
+// additional arguments).
+func (abi ABI) UnpackIntoInterface(v interface{}, name string, data []byte) error {
+	args, err := abi.getArguments(name, data)
+	if err != nil {
+		return err
+	}
+	unpacked, err := args.Unpack(data)
+	if err != nil {
+		return err
+	}
+	return args.Copy(v, unpacked)
+}

 // UnpackIntoMap unpacks a log into the provided map[string]interface{}.
 func (abi ABI) UnpackIntoMap(v map[string]interface{}, name string, data []byte) (err error) {
-	// since there can't be naming collisions with contracts and events,
-	// we need to decide whether we're calling a method or an event
-	if method, ok := abi.Methods[name]; ok {
-		if len(data)%32 != 0 {
-			return fmt.Errorf("abi: improperly formatted output")
-		}
-		return method.Outputs.UnpackIntoMap(v, data)
-	}
-	if event, ok := abi.Events[name]; ok {
-		return event.Inputs.UnpackIntoMap(v, data)
+	args, err := abi.getArguments(name, data)
+	if err != nil {
+		return err
 	}
-	return fmt.Errorf("abi: could not locate named method or event")
+	return args.UnpackIntoMap(v, data)
 }

 // UnmarshalJSON implements json.Unmarshaler interface.
@@ -250,10 +270,10 @@ func UnpackRevert(data []byte) (string, error) {
 	if !bytes.Equal(data[:4], revertSelector) {
 		return "", errors.New("invalid data for unpacking")
 	}
-	var reason string
 	typ, _ := NewType("string", "", nil)
-	if err := (Arguments{{Type: typ}}).Unpack(&reason, data[4:]); err != nil {
+	unpacked, err := (Arguments{{Type: typ}}).Unpack(data[4:])
+	if err != nil {
 		return "", err
 	}
-	return reason, nil
+	return unpacked[0].(string), nil
 }
diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go
index f7d7f7aa620f..ad8acdf52229 100644
--- a/accounts/abi/abi_test.go
+++ b/accounts/abi/abi_test.go
@@ -181,18 +181,15 @@ func TestConstructor(t *testing.T) {
 	if err != nil {
 		t.Error(err)
 	}
-	v := struct {
-		A *big.Int
-		B *big.Int
-	}{new(big.Int), new(big.Int)}
-	//abi.Unpack(&v, "", packed)
-	if err := abi.Constructor.Inputs.Unpack(&v, packed); err != nil {
+	unpacked, err := abi.Constructor.Inputs.Unpack(packed)
+	if err != nil {
 		t.Error(err)
 	}
-	if !reflect.DeepEqual(v.A, big.NewInt(1)) {
+
+	if !reflect.DeepEqual(unpacked[0], big.NewInt(1)) {
 		t.Error("Unable to pack/unpack from constructor")
 	}
-	if !reflect.DeepEqual(v.B, big.NewInt(2)) {
+	if !reflect.DeepEqual(unpacked[1], big.NewInt(2)) {
 		t.Error("Unable to pack/unpack from constructor")
 	}
 }
@@ -743,7 +740,7 @@ func TestUnpackEvent(t *testing.T) {
 	}

 	var ev ReceivedEvent
-	err = abi.Unpack(&ev, "received", data)
+	err = abi.UnpackIntoInterface(&ev, "received", data)
 	if err != nil {
 		t.Error(err)
 	}
@@ -752,7 +749,7 @@ func TestUnpackEvent(t *testing.T) {
 		Sender  common.Address
 	}
 	var receivedAddrEv ReceivedAddrEvent
-	err = abi.Unpack(&receivedAddrEv, "receivedAddr", data)
+	err = abi.UnpackIntoInterface(&receivedAddrEv, "receivedAddr", data)
 	if err != nil {
 		t.Error(err)
 	}
diff --git a/accounts/abi/argument.go b/accounts/abi/argument.go
index aaef3bd41cdf..e6d52455965d 100644
--- a/accounts/abi/argument.go
+++ b/accounts/abi/argument.go
@@ -76,28 +76,20 @@ func (arguments Arguments) isTuple() bool {
 }

 // Unpack performs the operation hexdata -> Go format.
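To make the reworked surface concrete, here is a minimal sketch of both unpacking styles introduced above. The ABI definition, the method name `get`, and the return data are hypothetical stand-ins, not part of this change:

```go
package main

import (
	"fmt"
	"math/big"
	"strings"

	"github.com/ethereum/go-ethereum/accounts/abi"
	"github.com/ethereum/go-ethereum/common"
)

func main() {
	// Hypothetical method returning (uint256 a, bool b).
	const def = `[{"name":"get","type":"function","outputs":[{"name":"a","type":"uint256"},{"name":"b","type":"bool"}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	// 64 bytes of return data: a = 1, b = true.
	data := common.Hex2Bytes(
		"0000000000000000000000000000000000000000000000000000000000000001" +
			"0000000000000000000000000000000000000000000000000000000000000001")

	// New style: Unpack returns the decoded values directly.
	values, err := parsed.Unpack("get", data)
	if err != nil {
		panic(err)
	}
	fmt.Println(values[0].(*big.Int), values[1].(bool))

	// Struct-directed unpacking moved to UnpackIntoInterface.
	var out struct {
		A *big.Int
		B bool
	}
	if err := parsed.UnpackIntoInterface(&out, "get", data); err != nil {
		panic(err)
	}
	fmt.Println(out.A, out.B)
}
```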
-func (arguments Arguments) Unpack(v interface{}, data []byte) error {
+func (arguments Arguments) Unpack(data []byte) ([]interface{}, error) {
 	if len(data) == 0 {
 		if len(arguments) != 0 {
-			return fmt.Errorf("abi: attempting to unmarshall an empty string while arguments are expected")
+			return nil, fmt.Errorf("abi: attempting to unmarshal an empty string while arguments are expected")
 		}
-		return nil // Nothing to unmarshal, return
-	}
-	// make sure the passed value is arguments pointer
-	if reflect.Ptr != reflect.ValueOf(v).Kind() {
-		return fmt.Errorf("abi: Unpack(non-pointer %T)", v)
-	}
-	marshalledValues, err := arguments.UnpackValues(data)
-	if err != nil {
-		return err
-	}
-	if len(marshalledValues) == 0 {
-		return fmt.Errorf("abi: Unpack(no-values unmarshalled %T)", v)
-	}
-	if arguments.isTuple() {
-		return arguments.unpackTuple(v, marshalledValues)
+		// Nothing to unmarshal, return default variables
+		nonIndexedArgs := arguments.NonIndexed()
+		defaultVars := make([]interface{}, len(nonIndexedArgs))
+		for index, arg := range nonIndexedArgs {
+			defaultVars[index] = reflect.New(arg.Type.GetType()).Interface()
+		}
+		return defaultVars, nil
 	}
-	return arguments.unpackAtomic(v, marshalledValues[0])
+	return arguments.UnpackValues(data)
 }

 // UnpackIntoMap performs the operation hexdata -> mapping of argument name to argument value.
@@ -122,8 +114,26 @@ func (arguments Arguments) UnpackIntoMap(v map[string]interface{}, data []byte)
 	return nil
 }

-// unpackAtomic unpacks ( hexdata -> go ) a single value.
-func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues interface{}) error {
+// Copy performs the operation go format -> provided struct.
+func (arguments Arguments) Copy(v interface{}, values []interface{}) error {
+	// make sure the passed value is arguments pointer
+	if reflect.Ptr != reflect.ValueOf(v).Kind() {
+		return fmt.Errorf("abi: Unpack(non-pointer %T)", v)
+	}
+	if len(values) == 0 {
+		if len(arguments) != 0 {
+			return fmt.Errorf("abi: attempting to copy no values while %d arguments are expected", len(arguments))
+		}
+		return nil // Nothing to copy, return
+	}
+	if arguments.isTuple() {
+		return arguments.copyTuple(v, values)
+	}
+	return arguments.copyAtomic(v, values[0])
+}
+
+// copyAtomic copies ( hexdata -> go ) a single value
+func (arguments Arguments) copyAtomic(v interface{}, marshalledValues interface{}) error {
 	dst := reflect.ValueOf(v).Elem()
 	src := reflect.ValueOf(marshalledValues)

@@ -133,8 +143,8 @@ func (arguments Arguments) unpackAtomic(v interface{}, marshalledValues interfac
 	return set(dst, src)
 }

-// unpackTuple unpacks ( hexdata -> go ) a batch of values.
-func (arguments Arguments) unpackTuple(v interface{}, marshalledValues []interface{}) error {
+// copyTuple copies a batch of values from marshalledValues to v.
+func (arguments Arguments) copyTuple(v interface{}, marshalledValues []interface{}) error {
 	value := reflect.ValueOf(v).Elem()
 	nonIndexedArgs := arguments.NonIndexed()
diff --git a/accounts/abi/bind/auth.go b/accounts/abi/bind/auth.go
index c891b0a3e910..1190772676c5 100644
--- a/accounts/abi/bind/auth.go
+++ b/accounts/abi/bind/auth.go
@@ -21,6 +21,7 @@ import (
 	"errors"
 	"io"
 	"io/ioutil"
+	"math/big"

 	"github.com/ethereum/go-ethereum/accounts"
 	"github.com/ethereum/go-ethereum/accounts/external"
@@ -28,11 +29,21 @@ import (
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/log"
 )

+// ErrNoChainID is returned whenever the user failed to specify a chain id.
+var ErrNoChainID = errors.New("no chain id specified")
+
+// ErrNotAuthorized is returned when an account is not properly unlocked.
+var ErrNotAuthorized = errors.New("not authorized to sign this account")
+
 // NewTransactor is a utility method to easily create a transaction signer from
 // an encrypted json key stream and the associated passphrase.
+//
+// Deprecated: Use NewTransactorWithChainID instead.
 func NewTransactor(keyin io.Reader, passphrase string) (*TransactOpts, error) {
+	log.Warn("WARNING: NewTransactor has been deprecated in favour of NewTransactorWithChainID")
 	json, err := ioutil.ReadAll(keyin)
 	if err != nil {
 		return nil, err
@@ -45,13 +56,17 @@ func NewTransactor(keyin io.Reader, passphrase string) (*TransactOpts, error) {
 }

 // NewKeyStoreTransactor is a utility method to easily create a transaction signer from
 // a decrypted key from a keystore.
+//
+// Deprecated: Use NewKeyStoreTransactorWithChainID instead.
 func NewKeyStoreTransactor(keystore *keystore.KeyStore, account accounts.Account) (*TransactOpts, error) {
+	log.Warn("WARNING: NewKeyStoreTransactor has been deprecated in favour of NewKeyStoreTransactorWithChainID")
+	signer := types.HomesteadSigner{}
 	return &TransactOpts{
 		From: account.Address,
-		Signer: func(signer types.Signer, address common.Address, tx *types.Transaction) (*types.Transaction, error) {
+		Signer: func(address common.Address, tx *types.Transaction) (*types.Transaction, error) {
 			if address != account.Address {
-				return nil, errors.New("not authorized to sign this account")
+				return nil, ErrNotAuthorized
 			}
 			signature, err := keystore.SignHash(account, signer.Hash(tx).Bytes())
 			if err != nil {
@@ -64,13 +79,17 @@ func NewKeyStoreTransactor(keystore *keystore.KeyStore, account accounts.Account

 // NewKeyedTransactor is a utility method to easily create a transaction signer
 // from a single private key.
+//
+// Deprecated: Use NewKeyedTransactorWithChainID instead.
 func NewKeyedTransactor(key *ecdsa.PrivateKey) *TransactOpts {
+	log.Warn("WARNING: NewKeyedTransactor has been deprecated in favour of NewKeyedTransactorWithChainID")
 	keyAddr := crypto.PubkeyToAddress(key.PublicKey)
+	signer := types.HomesteadSigner{}
 	return &TransactOpts{
 		From: keyAddr,
-		Signer: func(signer types.Signer, address common.Address, tx *types.Transaction) (*types.Transaction, error) {
+		Signer: func(address common.Address, tx *types.Transaction) (*types.Transaction, error) {
 			if address != keyAddr {
-				return nil, errors.New("not authorized to sign this account")
+				return nil, ErrNotAuthorized
 			}
 			signature, err := crypto.Sign(signer.Hash(tx).Bytes(), key)
 			if err != nil {
@@ -81,14 +100,73 @@ func NewKeyedTransactor(key *ecdsa.PrivateKey) *TransactOpts {
 	}
 }

+// NewTransactorWithChainID is a utility method to easily create a transaction signer from
+// an encrypted json key stream and the associated passphrase.
+func NewTransactorWithChainID(keyin io.Reader, passphrase string, chainID *big.Int) (*TransactOpts, error) {
+	json, err := ioutil.ReadAll(keyin)
+	if err != nil {
+		return nil, err
+	}
+	key, err := keystore.DecryptKey(json, passphrase)
+	if err != nil {
+		return nil, err
+	}
+	return NewKeyedTransactorWithChainID(key.PrivateKey, chainID)
+}
+
+// NewKeyStoreTransactorWithChainID is a utility method to easily create a transaction signer from
+// a decrypted key from a keystore.
+func NewKeyStoreTransactorWithChainID(keystore *keystore.KeyStore, account accounts.Account, chainID *big.Int) (*TransactOpts, error) {
+	if chainID == nil {
+		return nil, ErrNoChainID
+	}
+	signer := types.NewEIP155Signer(chainID)
+	return &TransactOpts{
+		From: account.Address,
+		Signer: func(address common.Address, tx *types.Transaction) (*types.Transaction, error) {
+			if address != account.Address {
+				return nil, ErrNotAuthorized
+			}
+			signature, err := keystore.SignHash(account, signer.Hash(tx).Bytes())
+			if err != nil {
+				return nil, err
+			}
+			return tx.WithSignature(signer, signature)
+		},
+	}, nil
+}
+
+// NewKeyedTransactorWithChainID is a utility method to easily create a transaction signer
+// from a single private key.
+func NewKeyedTransactorWithChainID(key *ecdsa.PrivateKey, chainID *big.Int) (*TransactOpts, error) {
+	keyAddr := crypto.PubkeyToAddress(key.PublicKey)
+	if chainID == nil {
+		return nil, ErrNoChainID
+	}
+	signer := types.NewEIP155Signer(chainID)
+	return &TransactOpts{
+		From: keyAddr,
+		Signer: func(address common.Address, tx *types.Transaction) (*types.Transaction, error) {
+			if address != keyAddr {
+				return nil, ErrNotAuthorized
+			}
+			signature, err := crypto.Sign(signer.Hash(tx).Bytes(), key)
+			if err != nil {
+				return nil, err
+			}
+			return tx.WithSignature(signer, signature)
+		},
+	}, nil
+}
+
 // NewClefTransactor is a utility method to easily create a transaction signer
 // with a clef backend.
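As a usage note for the new chain-id-aware constructors, a minimal sketch follows; the 1337 chain id matches the simulated backend, while real networks require their actual chain id:

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/accounts/abi/bind"
	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	key, err := crypto.GenerateKey()
	if err != nil {
		panic(err)
	}
	// The chain id is now explicit; passing nil yields bind.ErrNoChainID.
	auth, err := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337))
	if err != nil {
		panic(err)
	}
	fmt.Println("EIP-155 transactor for", auth.From.Hex())
}
```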
func NewClefTransactor(clef *external.ExternalSigner, account accounts.Account) *TransactOpts { return &TransactOpts{ From: account.Address, - Signer: func(signer types.Signer, address common.Address, transaction *types.Transaction) (*types.Transaction, error) { + Signer: func(address common.Address, transaction *types.Transaction) (*types.Transaction, error) { if address != account.Address { - return nil, errors.New("not authorized to sign this account") + return nil, ErrNotAuthorized } return clef.SignTx(account, transaction, nil) // Clef enforces its own chain id }, diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index c7efca440b1f..6e87e037f05b 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -74,6 +74,7 @@ type SimulatedBackend struct { // NewSimulatedBackendWithDatabase creates a new binding backend based on the given database // and uses a simulated blockchain for testing purposes. +// A simulated backend always uses chainID 1337. func NewSimulatedBackendWithDatabase(database ethdb.Database, alloc core.GenesisAlloc, gasLimit uint64) *SimulatedBackend { genesis := core.Genesis{Config: params.AllEthashProtocolChanges, GasLimit: gasLimit, Alloc: alloc} genesis.MustCommit(database) @@ -91,6 +92,7 @@ func NewSimulatedBackendWithDatabase(database ethdb.Database, alloc core.Genesis // NewSimulatedBackend creates a new binding backend using a simulated blockchain // for testing purposes. +// A simulated backend always uses chainID 1337. func NewSimulatedBackend(alloc core.GenesisAlloc, gasLimit uint64) *SimulatedBackend { return NewSimulatedBackendWithDatabase(rawdb.NewMemoryDatabase(), alloc, gasLimit) } @@ -479,7 +481,7 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMs b.pendingState.RevertToSnapshot(snapshot) if err != nil { - if err == core.ErrIntrinsicGas { + if errors.Is(err, core.ErrIntrinsicGas) { return true, nil, nil // Special case, raise gas limit } return true, nil, err // Bail out @@ -542,10 +544,11 @@ func (b *SimulatedBackend) callContract(ctx context.Context, call ethereum.CallM // Execute the call. msg := callMsg{call} - evmContext := core.NewEVMContext(msg, block.Header(), b.blockchain, nil) + txContext := core.NewEVMTxContext(msg) + evmContext := core.NewEVMBlockContext(block.Header(), b.blockchain, nil) // Create a new environment which holds all relevant information // about the transaction and calling mechanisms. 
- vmEnv := vm.NewEVM(evmContext, stateDB, b.config, vm.Config{}) + vmEnv := vm.NewEVM(evmContext, txContext, stateDB, b.config, vm.Config{}) gasPool := new(core.GasPool).AddGas(math.MaxUint64) return core.NewStateTransition(vmEnv, msg, gasPool).TransitionDb() diff --git a/accounts/abi/bind/backends/simulated_test.go b/accounts/abi/bind/backends/simulated_test.go index e2597cca0163..64ddf8bb2c3c 100644 --- a/accounts/abi/bind/backends/simulated_test.go +++ b/accounts/abi/bind/backends/simulated_test.go @@ -39,7 +39,7 @@ import ( func TestSimulatedBackend(t *testing.T) { var gasLimit uint64 = 8000029 key, _ := crypto.GenerateKey() // nolint: gosec - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) genAlloc := make(core.GenesisAlloc) genAlloc[auth.From] = core.GenesisAccount{Balance: big.NewInt(9223372036854775807)} @@ -411,7 +411,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { key, _ := crypto.GenerateKey() addr := crypto.PubkeyToAddress(key.PublicKey) - opts := bind.NewKeyedTransactor(key) + opts, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(params.Ether)}}, 10000000) defer sim.Close() @@ -888,7 +888,7 @@ func TestSimulatedBackend_PendingCodeAt(t *testing.T) { if err != nil { t.Errorf("could not get code at test addr: %v", err) } - auth := bind.NewKeyedTransactor(testKey) + auth, _ := bind.NewKeyedTransactorWithChainID(testKey, big.NewInt(1337)) contractAddr, tx, contract, err := bind.DeployContract(auth, parsed, common.FromHex(abiBin), sim) if err != nil { t.Errorf("could not deploy contract: %v tx: %v contract: %v", err, tx, contract) @@ -924,7 +924,7 @@ func TestSimulatedBackend_CodeAt(t *testing.T) { if err != nil { t.Errorf("could not get code at test addr: %v", err) } - auth := bind.NewKeyedTransactor(testKey) + auth, _ := bind.NewKeyedTransactorWithChainID(testKey, big.NewInt(1337)) contractAddr, tx, contract, err := bind.DeployContract(auth, parsed, common.FromHex(abiBin), sim) if err != nil { t.Errorf("could not deploy contract: %v tx: %v contract: %v", err, tx, contract) @@ -956,7 +956,7 @@ func TestSimulatedBackend_PendingAndCallContract(t *testing.T) { if err != nil { t.Errorf("could not get code at test addr: %v", err) } - contractAuth := bind.NewKeyedTransactor(testKey) + contractAuth, _ := bind.NewKeyedTransactorWithChainID(testKey, big.NewInt(1337)) addr, _, _, err := bind.DeployContract(contractAuth, parsed, common.FromHex(abiBin), sim) if err != nil { t.Errorf("could not deploy contract: %v", err) @@ -1043,7 +1043,7 @@ func TestSimulatedBackend_CallContractRevert(t *testing.T) { if err != nil { t.Errorf("could not get code at test addr: %v", err) } - contractAuth := bind.NewKeyedTransactor(testKey) + contractAuth, _ := bind.NewKeyedTransactorWithChainID(testKey, big.NewInt(1337)) addr, _, _, err := bind.DeployContract(contractAuth, parsed, common.FromHex(reverterBin), sim) if err != nil { t.Errorf("could not deploy contract: %v", err) diff --git a/accounts/abi/bind/base.go b/accounts/abi/bind/base.go index d9935e3de726..f5a6fe22fc24 100644 --- a/accounts/abi/bind/base.go +++ b/accounts/abi/bind/base.go @@ -32,7 +32,7 @@ import ( // SignerFn is a signer function callback when a contract requires a method to // sign the transaction before submission. 
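For orientation, here is a compile-only sketch of the split-context wiring used in callContract above. The helper name applyMessage and its parameter list are illustrative, and the signatures are assumed from the changed call sites rather than documented here:

```go
package main

import (
	"math"

	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/core/vm"
	"github.com/ethereum/go-ethereum/params"
)

// applyMessage mirrors the callContract change: the EVM is now assembled from
// a block-level context and a per-transaction context instead of one merged
// core.NewEVMContext value.
func applyMessage(msg core.Message, header *types.Header, chain core.ChainContext,
	stateDB *state.StateDB, config *params.ChainConfig) (*core.ExecutionResult, error) {
	txContext := core.NewEVMTxContext(msg)
	blockContext := core.NewEVMBlockContext(header, chain, nil)
	vmEnv := vm.NewEVM(blockContext, txContext, stateDB, config, vm.Config{})
	gasPool := new(core.GasPool).AddGas(math.MaxUint64)
	return core.NewStateTransition(vmEnv, msg, gasPool).TransitionDb()
}

func main() {} // placeholder so the sketch compiles standalone
```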
-type SignerFn func(types.Signer, common.Address, *types.Transaction) (*types.Transaction, error) +type SignerFn func(common.Address, *types.Transaction) (*types.Transaction, error) // CallOpts is the collection of options to fine tune a contract call request. type CallOpts struct { @@ -117,11 +117,14 @@ func DeployContract(opts *TransactOpts, abi abi.ABI, bytecode []byte, backend Co // sets the output to result. The result type might be a single field for simple // returns, a slice of interfaces for anonymous returns and a struct for named // returns. -func (c *BoundContract) Call(opts *CallOpts, result interface{}, method string, params ...interface{}) error { +func (c *BoundContract) Call(opts *CallOpts, results *[]interface{}, method string, params ...interface{}) error { // Don't crash on a lazy user if opts == nil { opts = new(CallOpts) } + if results == nil { + results = new([]interface{}) + } // Pack the input, call and unpack the results input, err := c.abi.Pack(method, params...) if err != nil { @@ -149,7 +152,10 @@ func (c *BoundContract) Call(opts *CallOpts, result interface{}, method string, } } else { output, err = c.caller.CallContract(ctx, msg, opts.BlockNumber) - if err == nil && len(output) == 0 { + if err != nil { + return err + } + if len(output) == 0 { // Make sure we have a contract to operate on, and bail out otherwise. if code, err = c.caller.CodeAt(ctx, c.address, opts.BlockNumber); err != nil { return err @@ -158,10 +164,14 @@ func (c *BoundContract) Call(opts *CallOpts, result interface{}, method string, } } } - if err != nil { + + if len(*results) == 0 { + res, err := c.abi.Unpack(method, output) + *results = res return err } - return c.abi.Unpack(result, method, output) + res := *results + return c.abi.UnpackIntoInterface(res[0], method, output) } // Transact invokes the (paid) contract method with params as input values. @@ -246,7 +256,7 @@ func (c *BoundContract) transact(opts *TransactOpts, contract *common.Address, i if opts.Signer == nil { return nil, errors.New("no signer to authorize the transaction with") } - signedTx, err := opts.Signer(types.HomesteadSigner{}, opts.From, rawTx) + signedTx, err := opts.Signer(opts.From, rawTx) if err != nil { return nil, err } @@ -339,7 +349,7 @@ func (c *BoundContract) WatchLogs(opts *WatchOpts, name string, query ...[]inter // UnpackLog unpacks a retrieved log into the provided output structure. 
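A sketch of what the reworked Call shape looks like from caller code: results are collected into a *[]interface{} and converted afterwards, instead of being written into a caller-supplied struct or pointer. The balanceOf method name is a hypothetical ERC20-style example:

```go
package main

import (
	"math/big"

	"github.com/ethereum/go-ethereum/accounts/abi"
	"github.com/ethereum/go-ethereum/accounts/abi/bind"
	"github.com/ethereum/go-ethereum/common"
)

// balanceOf shows the new raw-Call shape against an already-bound contract.
func balanceOf(c *bind.BoundContract, opts *bind.CallOpts, owner common.Address) (*big.Int, error) {
	var out []interface{}
	if err := c.Call(opts, &out, "balanceOf", owner); err != nil {
		return nil, err
	}
	// ConvertType bridges the anonymous runtime type to the declared one,
	// the same pattern the regenerated bindings use.
	return *abi.ConvertType(out[0], new(*big.Int)).(**big.Int), nil
}

func main() {}
```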
func (c *BoundContract) UnpackLog(out interface{}, event string, log types.Log) error { if len(log.Data) > 0 { - if err := c.abi.Unpack(out, event, log.Data); err != nil { + if err := c.abi.UnpackIntoInterface(out, event, log.Data); err != nil { return err } } diff --git a/accounts/abi/bind/base_test.go b/accounts/abi/bind/base_test.go index 7d287850f44b..c4740f68b750 100644 --- a/accounts/abi/bind/base_test.go +++ b/accounts/abi/bind/base_test.go @@ -71,11 +71,10 @@ func TestPassingBlockNumber(t *testing.T) { }, }, }, mc, nil, nil) - var ret string blockNumber := big.NewInt(42) - bc.Call(&bind.CallOpts{BlockNumber: blockNumber}, &ret, "something") + bc.Call(&bind.CallOpts{BlockNumber: blockNumber}, nil, "something") if mc.callContractBlockNumber != blockNumber { t.Fatalf("CallContract() was not passed the block number") @@ -85,7 +84,7 @@ func TestPassingBlockNumber(t *testing.T) { t.Fatalf("CodeAt() was not passed the block number") } - bc.Call(&bind.CallOpts{}, &ret, "something") + bc.Call(&bind.CallOpts{}, nil, "something") if mc.callContractBlockNumber != nil { t.Fatalf("CallContract() was passed a block number when it should not have been") @@ -95,7 +94,7 @@ func TestPassingBlockNumber(t *testing.T) { t.Fatalf("CodeAt() was passed a block number when it should not have been") } - bc.Call(&bind.CallOpts{BlockNumber: blockNumber, Pending: true}, &ret, "something") + bc.Call(&bind.CallOpts{BlockNumber: blockNumber, Pending: true}, nil, "something") if !mc.pendingCallContractCalled { t.Fatalf("CallContract() was not passed the block number") diff --git a/accounts/abi/bind/bind_test.go b/accounts/abi/bind/bind_test.go index a5f08499d228..4a504516bbfc 100644 --- a/accounts/abi/bind/bind_test.go +++ b/accounts/abi/bind/bind_test.go @@ -296,7 +296,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -351,7 +351,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -397,7 +397,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -455,7 +455,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -503,7 +503,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := 
backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000)
 		defer sim.Close()

@@ -569,6 +569,45 @@ var bindTests = []struct {
 		nil,
 		nil,
 	},
+	{
+		`NonExistentStruct`,
+		`
+		contract NonExistentStruct {
+			function Struct() public view returns(uint256 a, uint256 b) {
+				return (10, 10);
+			}
+		}
+		`,
+		[]string{`6080604052348015600f57600080fd5b5060888061001e6000396000f3fe6080604052348015600f57600080fd5b506004361060285760003560e01c8063d5f6622514602d575b600080fd5b6033604c565b6040805192835260208301919091528051918290030190f35b600a809156fea264697066735822beefbeefbeefbeefbeefbeefbeefbeefbeefbeefbeefbeefbeefbeefbeefbeef64736f6c6343decafe0033`},
+		[]string{`[{"inputs":[],"name":"Struct","outputs":[{"internalType":"uint256","name":"a","type":"uint256"},{"internalType":"uint256","name":"b","type":"uint256"}],"stateMutability":"pure","type":"function"}]`},
+		`
+		"github.com/ethereum/go-ethereum/accounts/abi/bind"
+		"github.com/ethereum/go-ethereum/accounts/abi/bind/backends"
+		"github.com/ethereum/go-ethereum/common"
+		"github.com/ethereum/go-ethereum/core"
+		`,
+		`
+		// Create a simulator and wrap a non-deployed contract
+
+		sim := backends.NewSimulatedBackend(core.GenesisAlloc{}, uint64(10000000000))
+		defer sim.Close()
+
+		nonexistent, err := NewNonExistentStruct(common.Address{}, sim)
+		if err != nil {
+			t.Fatalf("Failed to access non-existent contract: %v", err)
+		}
+		// Ensure that contract calls fail with the appropriate error
+		if res, err := nonexistent.Struct(nil); err == nil {
+			t.Fatalf("Call succeeded on non-existent contract: %v", res)
+		} else if err != bind.ErrNoCode {
+			t.Fatalf("Error mismatch: have %v, want %v", err, bind.ErrNoCode)
+		}
+		`,
+		nil,
+		nil,
+		nil,
+		nil,
+	},

 	// Tests that gas estimation works for contracts with weird gas mechanics too.
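The same ErrNoCode behaviour can be reproduced without generated bindings; a sketch using the exported bind.NewBoundContract against an address that holds no code:

```go
package main

import (
	"fmt"
	"strings"

	"github.com/ethereum/go-ethereum/accounts/abi"
	"github.com/ethereum/go-ethereum/accounts/abi/bind"
	"github.com/ethereum/go-ethereum/accounts/abi/bind/backends"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core"
)

func main() {
	// Simulated chain with no contracts deployed at all.
	sim := backends.NewSimulatedBackend(core.GenesisAlloc{}, 10000000)
	defer sim.Close()

	const def = `[{"inputs":[],"name":"Struct","outputs":[{"name":"a","type":"uint256"},{"name":"b","type":"uint256"}],"stateMutability":"pure","type":"function"}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		panic(err)
	}
	// Bind to the zero address, which holds no code.
	c := bind.NewBoundContract(common.Address{}, parsed, sim, sim, sim)

	var out []interface{}
	if err := c.Call(nil, &out, "Struct"); err == bind.ErrNoCode {
		fmt.Println("call rejected: no contract code at the given address")
	}
}
```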
{ `FunkyGasPattern`, @@ -598,7 +637,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -648,7 +687,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -723,7 +762,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -817,7 +856,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -1007,7 +1046,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -1142,7 +1181,7 @@ var bindTests = []struct { ` key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -1284,7 +1323,7 @@ var bindTests = []struct { ` // Generate a new random account and a funded simulator key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -1350,7 +1389,7 @@ var bindTests = []struct { ` // Initialize test accounts key, _ := crypto.GenerateKey() - auth := bind.NewKeyedTransactor(key) + auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000) defer sim.Close() @@ -1444,7 +1483,7 @@ var bindTests = []struct { sim := backends.NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(1000000000)}}, 10000000) defer sim.Close() - transactOpts := bind.NewKeyedTransactor(key) + transactOpts, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) _, _, _, err := DeployIdentifierCollision(transactOpts, sim) if err != nil { t.Fatalf("failed to deploy contract: %v", err) @@ -1506,7 +1545,7 @@ var bindTests = []struct { sim := backends.NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(1000000000)}}, 10000000) defer sim.Close() - 
transactOpts := bind.NewKeyedTransactor(key)
+		transactOpts, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337))
 		_, _, c1, err := DeployContractOne(transactOpts, sim)
 		if err != nil {
 			t.Fatal("Failed to deploy contract")
 		}
@@ -1563,7 +1602,7 @@ var bindTests = []struct {
 		`
 		// Generate a new random account and a funded simulator
 		key, _ := crypto.GenerateKey()
-		auth := bind.NewKeyedTransactor(key)
+		auth, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337))

 		sim := backends.NewSimulatedBackend(core.GenesisAlloc{auth.From: {Balance: big.NewInt(10000000000)}}, 10000000)
 		defer sim.Close()
@@ -1632,7 +1671,7 @@ var bindTests = []struct {
 		sim := backends.NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(1000000000)}}, 1000000)
 		defer sim.Close()

-		opts := bind.NewKeyedTransactor(key)
+		opts, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337))
 		_, _, c, err := DeployNewFallbacks(opts, sim)
 		if err != nil {
 			t.Fatalf("Failed to deploy contract: %v", err)
 		}
@@ -1696,11 +1735,11 @@ func TestGolangBindings(t *testing.T) {
 		t.Skip("go sdk not found for testing")
 	}
 	// Create a temporary workspace for the test suite
-	ws, err := ioutil.TempDir("", "")
+	ws, err := ioutil.TempDir("", "binding-test")
 	if err != nil {
 		t.Fatalf("failed to create temporary workspace: %v", err)
 	}
 	defer os.RemoveAll(ws)

 	pkg := filepath.Join(ws, "bindtest")
 	if err = os.MkdirAll(pkg, 0700); err != nil {
diff --git a/accounts/abi/bind/template.go b/accounts/abi/bind/template.go
index f19f1315ae88..351eabd258c8 100644
--- a/accounts/abi/bind/template.go
+++ b/accounts/abi/bind/template.go
@@ -261,7 +261,7 @@ var (
 	// sets the output to result. The result type might be a single field for simple
 	// returns, a slice of interfaces for anonymous returns and a struct for named
 	// returns.
-	func (_{{$contract.Type}} *{{$contract.Type}}Raw) Call(opts *bind.CallOpts, result interface{}, method string, params ...interface{}) error {
+	func (_{{$contract.Type}} *{{$contract.Type}}Raw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error {
 		return _{{$contract.Type}}.Contract.{{$contract.Type}}Caller.contract.Call(opts, result, method, params...)
 	}

@@ -280,7 +280,7 @@ var (
 	// sets the output to result. The result type might be a single field for simple
 	// returns, a slice of interfaces for anonymous returns and a struct for named
 	// returns.
-	func (_{{$contract.Type}} *{{$contract.Type}}CallerRaw) Call(opts *bind.CallOpts, result interface{}, method string, params ...interface{}) error {
+	func (_{{$contract.Type}} *{{$contract.Type}}CallerRaw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error {
 		return _{{$contract.Type}}.Contract.contract.Call(opts, result, method, params...)
} @@ -300,19 +300,26 @@ var ( // // Solidity: {{.Original.String}} func (_{{$contract.Type}} *{{$contract.Type}}Caller) {{.Normalized.Name}}(opts *bind.CallOpts {{range .Normalized.Inputs}}, {{.Name}} {{bindtype .Type $structs}} {{end}}) ({{if .Structured}}struct{ {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type $structs}};{{end}} },{{else}}{{range .Normalized.Outputs}}{{bindtype .Type $structs}},{{end}}{{end}} error) { - {{if .Structured}}ret := new(struct{ - {{range .Normalized.Outputs}}{{.Name}} {{bindtype .Type $structs}} - {{end}} - }){{else}}var ( - {{range $i, $_ := .Normalized.Outputs}}ret{{$i}} = new({{bindtype .Type $structs}}) - {{end}} - ){{end}} - out := {{if .Structured}}ret{{else}}{{if eq (len .Normalized.Outputs) 1}}ret0{{else}}&[]interface{}{ - {{range $i, $_ := .Normalized.Outputs}}ret{{$i}}, - {{end}} - }{{end}}{{end}} - err := _{{$contract.Type}}.contract.Call(opts, out, "{{.Original.Name}}" {{range .Normalized.Inputs}}, {{.Name}}{{end}}) - return {{if .Structured}}*ret,{{else}}{{range $i, $_ := .Normalized.Outputs}}*ret{{$i}},{{end}}{{end}} err + var out []interface{} + err := _{{$contract.Type}}.contract.Call(opts, &out, "{{.Original.Name}}" {{range .Normalized.Inputs}}, {{.Name}}{{end}}) + {{if .Structured}} + outstruct := new(struct{ {{range .Normalized.Outputs}} {{.Name}} {{bindtype .Type $structs}}; {{end}} }) + if err != nil { + return *outstruct, err + } + {{range $i, $t := .Normalized.Outputs}} + outstruct.{{.Name}} = out[{{$i}}].({{bindtype .Type $structs}}){{end}} + + return *outstruct, err + {{else}} + if err != nil { + return {{range $i, $_ := .Normalized.Outputs}}*new({{bindtype .Type $structs}}), {{end}} err + } + {{range $i, $t := .Normalized.Outputs}} + out{{$i}} := *abi.ConvertType(out[{{$i}}], new({{bindtype .Type $structs}})).(*{{bindtype .Type $structs}}){{end}} + + return {{range $i, $t := .Normalized.Outputs}}out{{$i}}, {{end}} err + {{end}} } // {{.Normalized.Name}} is a free data retrieval call binding the contract method 0x{{printf "%x" .Original.ID}}. @@ -537,6 +544,7 @@ var ( if err := _{{$contract.Type}}.contract.UnpackLog(event, "{{.Original.Name}}", log); err != nil { return nil, err } + event.Raw = log return event, nil } diff --git a/accounts/abi/event_test.go b/accounts/abi/event_test.go index 79504c28ce3f..3332f8a07216 100644 --- a/accounts/abi/event_test.go +++ b/accounts/abi/event_test.go @@ -147,10 +147,6 @@ func TestEventString(t *testing.T) { // TestEventMultiValueWithArrayUnpack verifies that array fields will be counted after parsing array. 
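For reference, the structured branch of the new template expands to roughly the following for a hypothetical method `Struct()` returning named outputs (uint256 a, uint256 b); the wrapper function name is illustrative, not generated code copied verbatim:

```go
package main

import (
	"math/big"

	"github.com/ethereum/go-ethereum/accounts/abi/bind"
)

// structOutputs mirrors the template's named-output branch: the raw out slice
// is type-asserted field by field into a fresh out-struct.
func structOutputs(c *bind.BoundContract, opts *bind.CallOpts) (struct {
	A *big.Int
	B *big.Int
}, error) {
	var out []interface{}
	err := c.Call(opts, &out, "Struct")

	outstruct := new(struct {
		A *big.Int
		B *big.Int
	})
	if err != nil {
		return *outstruct, err
	}
	outstruct.A = out[0].(*big.Int)
	outstruct.B = out[1].(*big.Int)
	return *outstruct, err
}

func main() {}
```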
func TestEventMultiValueWithArrayUnpack(t *testing.T) { definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": false, "name":"value1", "type":"uint8[2]"},{"indexed": false, "name":"value2", "type":"uint8"}]}]` - type testStruct struct { - Value1 [2]uint8 - Value2 uint8 - } abi, err := JSON(strings.NewReader(definition)) require.NoError(t, err) var b bytes.Buffer @@ -158,10 +154,10 @@ func TestEventMultiValueWithArrayUnpack(t *testing.T) { for ; i <= 3; i++ { b.Write(packNum(reflect.ValueOf(i))) } - var rst testStruct - require.NoError(t, abi.Unpack(&rst, "test", b.Bytes())) - require.Equal(t, [2]uint8{1, 2}, rst.Value1) - require.Equal(t, uint8(3), rst.Value2) + unpacked, err := abi.Unpack("test", b.Bytes()) + require.NoError(t, err) + require.Equal(t, [2]uint8{1, 2}, unpacked[0]) + require.Equal(t, uint8(3), unpacked[1]) } func TestEventTupleUnpack(t *testing.T) { @@ -351,14 +347,14 @@ func unpackTestEventData(dest interface{}, hexData string, jsonEvent []byte, ass var e Event assert.NoError(json.Unmarshal(jsonEvent, &e), "Should be able to unmarshal event ABI") a := ABI{Events: map[string]Event{"e": e}} - return a.Unpack(dest, "e", data) + return a.UnpackIntoInterface(dest, "e", data) } // TestEventUnpackIndexed verifies that indexed field will be skipped by event decoder. func TestEventUnpackIndexed(t *testing.T) { definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": true, "name":"value1", "type":"uint8"},{"indexed": false, "name":"value2", "type":"uint8"}]}]` type testStruct struct { - Value1 uint8 + Value1 uint8 // indexed Value2 uint8 } abi, err := JSON(strings.NewReader(definition)) @@ -366,7 +362,7 @@ func TestEventUnpackIndexed(t *testing.T) { var b bytes.Buffer b.Write(packNum(reflect.ValueOf(uint8(8)))) var rst testStruct - require.NoError(t, abi.Unpack(&rst, "test", b.Bytes())) + require.NoError(t, abi.UnpackIntoInterface(&rst, "test", b.Bytes())) require.Equal(t, uint8(0), rst.Value1) require.Equal(t, uint8(8), rst.Value2) } @@ -375,7 +371,7 @@ func TestEventUnpackIndexed(t *testing.T) { func TestEventIndexedWithArrayUnpack(t *testing.T) { definition := `[{"name": "test", "type": "event", "inputs": [{"indexed": true, "name":"value1", "type":"uint8[2]"},{"indexed": false, "name":"value2", "type":"string"}]}]` type testStruct struct { - Value1 [2]uint8 + Value1 [2]uint8 // indexed Value2 string } abi, err := JSON(strings.NewReader(definition)) @@ -388,7 +384,7 @@ func TestEventIndexedWithArrayUnpack(t *testing.T) { b.Write(common.RightPadBytes([]byte(stringOut), 32)) var rst testStruct - require.NoError(t, abi.Unpack(&rst, "test", b.Bytes())) + require.NoError(t, abi.UnpackIntoInterface(&rst, "test", b.Bytes())) require.Equal(t, [2]uint8{0, 0}, rst.Value1) require.Equal(t, stringOut, rst.Value2) } diff --git a/accounts/abi/pack_test.go b/accounts/abi/pack_test.go index 284215a7d7b4..5c7cb1cc1a24 100644 --- a/accounts/abi/pack_test.go +++ b/accounts/abi/pack_test.go @@ -44,18 +44,7 @@ func TestPack(t *testing.T) { t.Fatalf("invalid ABI definition %s, %v", inDef, err) } var packed []byte - if reflect.TypeOf(test.unpacked).Kind() != reflect.Struct { - packed, err = inAbi.Pack("method", test.unpacked) - } else { - // if want is a struct we need to use the components. - elem := reflect.ValueOf(test.unpacked) - var values []interface{} - for i := 0; i < elem.NumField(); i++ { - field := elem.Field(i) - values = append(values, field.Interface()) - } - packed, err = inAbi.Pack("method", values...) 
- } + packed, err = inAbi.Pack("method", test.unpacked) if err != nil { t.Fatalf("test %d (%v) failed: %v", i, test.def, err) diff --git a/accounts/abi/packing_test.go b/accounts/abi/packing_test.go index 16b4dc43d75d..eae3b0df2056 100644 --- a/accounts/abi/packing_test.go +++ b/accounts/abi/packing_test.go @@ -620,7 +620,7 @@ var packUnpackTests = []packUnpackTest{ { def: `[{"type": "bytes32[]"}]`, - unpacked: []common.Hash{{1}, {2}}, + unpacked: [][32]byte{{1}, {2}}, packed: "0000000000000000000000000000000000000000000000000000000000000020" + "0000000000000000000000000000000000000000000000000000000000000002" + "0100000000000000000000000000000000000000000000000000000000000000" + @@ -722,7 +722,7 @@ var packUnpackTests = []packUnpackTest{ }, // struct outputs { - def: `[{"name":"int1","type":"int256"},{"name":"int2","type":"int256"}]`, + def: `[{"components": [{"name":"int1","type":"int256"},{"name":"int2","type":"int256"}], "type":"tuple"}]`, packed: "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000002", unpacked: struct { @@ -731,28 +731,28 @@ var packUnpackTests = []packUnpackTest{ }{big.NewInt(1), big.NewInt(2)}, }, { - def: `[{"name":"int_one","type":"int256"}]`, + def: `[{"components": [{"name":"int_one","type":"int256"}], "type":"tuple"}]`, packed: "0000000000000000000000000000000000000000000000000000000000000001", unpacked: struct { IntOne *big.Int }{big.NewInt(1)}, }, { - def: `[{"name":"int__one","type":"int256"}]`, + def: `[{"components": [{"name":"int__one","type":"int256"}], "type":"tuple"}]`, packed: "0000000000000000000000000000000000000000000000000000000000000001", unpacked: struct { IntOne *big.Int }{big.NewInt(1)}, }, { - def: `[{"name":"int_one_","type":"int256"}]`, + def: `[{"components": [{"name":"int_one_","type":"int256"}], "type":"tuple"}]`, packed: "0000000000000000000000000000000000000000000000000000000000000001", unpacked: struct { IntOne *big.Int }{big.NewInt(1)}, }, { - def: `[{"name":"int_one","type":"int256"}, {"name":"intone","type":"int256"}]`, + def: `[{"components": [{"name":"int_one","type":"int256"}, {"name":"intone","type":"int256"}], "type":"tuple"}]`, packed: "0000000000000000000000000000000000000000000000000000000000000001" + "0000000000000000000000000000000000000000000000000000000000000002", unpacked: struct { @@ -831,11 +831,11 @@ var packUnpackTests = []packUnpackTest{ }, { // static tuple - def: `[{"name":"a","type":"int64"}, + def: `[{"components": [{"name":"a","type":"int64"}, {"name":"b","type":"int256"}, {"name":"c","type":"int256"}, {"name":"d","type":"bool"}, - {"name":"e","type":"bytes32[3][2]"}]`, + {"name":"e","type":"bytes32[3][2]"}], "type":"tuple"}]`, unpacked: struct { A int64 B *big.Int @@ -855,21 +855,22 @@ var packUnpackTests = []packUnpackTest{ "0500000000000000000000000000000000000000000000000000000000000000", // struct[e] array[1][2] }, { - def: `[{"name":"a","type":"string"}, + def: `[{"components": [{"name":"a","type":"string"}, {"name":"b","type":"int64"}, {"name":"c","type":"bytes"}, {"name":"d","type":"string[]"}, {"name":"e","type":"int256[]"}, - {"name":"f","type":"address[]"}]`, + {"name":"f","type":"address[]"}], "type":"tuple"}]`, unpacked: struct { - FieldA string `abi:"a"` // Test whether abi tag works - FieldB int64 `abi:"b"` - C []byte - D []string - E []*big.Int - F []common.Address + A string + B int64 + C []byte + D []string + E []*big.Int + F []common.Address }{"foobar", 1, []byte{1}, []string{"foo", "bar"}, 
[]*big.Int{big.NewInt(1), big.NewInt(-1)}, []common.Address{{1}, {2}}}, - packed: "00000000000000000000000000000000000000000000000000000000000000c0" + // struct[a] offset + packed: "0000000000000000000000000000000000000000000000000000000000000020" + // struct a + "00000000000000000000000000000000000000000000000000000000000000c0" + // struct[a] offset "0000000000000000000000000000000000000000000000000000000000000001" + // struct[b] "0000000000000000000000000000000000000000000000000000000000000100" + // struct[c] offset "0000000000000000000000000000000000000000000000000000000000000140" + // struct[d] offset @@ -894,23 +895,24 @@ var packUnpackTests = []packUnpackTest{ "0000000000000000000000000200000000000000000000000000000000000000", // common.Address{2} }, { - def: `[{"components": [{"name": "a","type": "uint256"}, + def: `[{"components": [{ "type": "tuple","components": [{"name": "a","type": "uint256"}, {"name": "b","type": "uint256[]"}], "name": "a","type": "tuple"}, - {"name": "b","type": "uint256[]"}]`, + {"name": "b","type": "uint256[]"}], "type": "tuple"}]`, unpacked: struct { A struct { - FieldA *big.Int `abi:"a"` - B []*big.Int + A *big.Int + B []*big.Int } B []*big.Int }{ A: struct { - FieldA *big.Int `abi:"a"` // Test whether abi tag works for nested tuple - B []*big.Int + A *big.Int + B []*big.Int }{big.NewInt(1), []*big.Int{big.NewInt(1), big.NewInt(2)}}, B: []*big.Int{big.NewInt(1), big.NewInt(2)}}, - packed: "0000000000000000000000000000000000000000000000000000000000000040" + // a offset + packed: "0000000000000000000000000000000000000000000000000000000000000020" + // struct a + "0000000000000000000000000000000000000000000000000000000000000040" + // a offset "00000000000000000000000000000000000000000000000000000000000000e0" + // b offset "0000000000000000000000000000000000000000000000000000000000000001" + // a.a value "0000000000000000000000000000000000000000000000000000000000000040" + // a.b offset diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go index f4812b06bf72..11248e07302e 100644 --- a/accounts/abi/reflect.go +++ b/accounts/abi/reflect.go @@ -24,6 +24,29 @@ import ( "strings" ) +// ConvertType converts an interface of a runtime type into a interface of the +// given type +// e.g. 
+// var fields []reflect.StructField
+// fields = append(fields, reflect.StructField{
+// 	Name: "X",
+//	Type: reflect.TypeOf(new(big.Int)),
+//	Tag:  reflect.StructTag("json:\"" + "x" + "\""),
+// })
+// into
+// type TupleT struct { X *big.Int }
+func ConvertType(in interface{}, proto interface{}) interface{} {
+	protoType := reflect.TypeOf(proto)
+	if reflect.TypeOf(in).ConvertibleTo(protoType) {
+		return reflect.ValueOf(in).Convert(protoType).Interface()
+	}
+	// Use set as a last ditch effort
+	if err := set(reflect.ValueOf(proto), reflect.ValueOf(in)); err != nil {
+		panic(err)
+	}
+	return proto
+}
+
 // indirect recursively dereferences the value until it either gets the value
 // or finds a big.Int
 func indirect(v reflect.Value) reflect.Value {
@@ -119,6 +142,9 @@ func setSlice(dst, src reflect.Value) error {
 }

 func setArray(dst, src reflect.Value) error {
+	if src.Kind() == reflect.Ptr {
+		return set(dst, indirect(src))
+	}
 	array := reflect.New(dst.Type()).Elem()
 	min := src.Len()
 	if src.Len() > dst.Len() {
diff --git a/accounts/abi/reflect_test.go b/accounts/abi/reflect_test.go
index c425e6e54bff..cf13a79da84e 100644
--- a/accounts/abi/reflect_test.go
+++ b/accounts/abi/reflect_test.go
@@ -17,6 +17,7 @@ package abi

 import (
+	"math/big"
 	"reflect"
 	"testing"
 )
@@ -189,3 +190,72 @@ func TestReflectNameToStruct(t *testing.T) {
 		})
 	}
 }
+
+func TestConvertType(t *testing.T) {
+	// Test Basic Struct
+	type T struct {
+		X *big.Int
+		Y *big.Int
+	}
+	// Create on-the-fly structure
+	var fields []reflect.StructField
+	fields = append(fields, reflect.StructField{
+		Name: "X",
+		Type: reflect.TypeOf(new(big.Int)),
+		Tag:  "json:\"" + "x" + "\"",
+	})
+	fields = append(fields, reflect.StructField{
+		Name: "Y",
+		Type: reflect.TypeOf(new(big.Int)),
+		Tag:  "json:\"" + "y" + "\"",
+	})
+	val := reflect.New(reflect.StructOf(fields))
+	val.Elem().Field(0).Set(reflect.ValueOf(big.NewInt(1)))
+	val.Elem().Field(1).Set(reflect.ValueOf(big.NewInt(2)))
+	// ConvertType
+	out := *ConvertType(val.Interface(), new(T)).(*T)
+	if out.X.Cmp(big.NewInt(1)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out.X, big.NewInt(1))
+	}
+	if out.Y.Cmp(big.NewInt(2)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out.Y, big.NewInt(2))
+	}
+	// Slice Type
+	val2 := reflect.MakeSlice(reflect.SliceOf(reflect.StructOf(fields)), 2, 2)
+	val2.Index(0).Field(0).Set(reflect.ValueOf(big.NewInt(1)))
+	val2.Index(0).Field(1).Set(reflect.ValueOf(big.NewInt(2)))
+	val2.Index(1).Field(0).Set(reflect.ValueOf(big.NewInt(3)))
+	val2.Index(1).Field(1).Set(reflect.ValueOf(big.NewInt(4)))
+	out2 := *ConvertType(val2.Interface(), new([]T)).(*[]T)
+	if out2[0].X.Cmp(big.NewInt(1)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out2[0].X, big.NewInt(1))
+	}
+	if out2[0].Y.Cmp(big.NewInt(2)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out2[0].Y, big.NewInt(2))
+	}
+	if out2[1].X.Cmp(big.NewInt(3)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out2[1].X, big.NewInt(3))
+	}
+	if out2[1].Y.Cmp(big.NewInt(4)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out2[1].Y, big.NewInt(4))
+	}
+	// Array Type
+	val3 := reflect.New(reflect.ArrayOf(2, reflect.StructOf(fields)))
+	val3.Elem().Index(0).Field(0).Set(reflect.ValueOf(big.NewInt(1)))
+	val3.Elem().Index(0).Field(1).Set(reflect.ValueOf(big.NewInt(2)))
+	val3.Elem().Index(1).Field(0).Set(reflect.ValueOf(big.NewInt(3)))
+	val3.Elem().Index(1).Field(1).Set(reflect.ValueOf(big.NewInt(4)))
+	out3 := *ConvertType(val3.Interface(), new([2]T)).(*[2]T)
+	if out3[0].X.Cmp(big.NewInt(1)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out3[0].X, big.NewInt(1))
+	}
+	if out3[0].Y.Cmp(big.NewInt(2)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out3[0].Y, big.NewInt(2))
+	}
+	if out3[1].X.Cmp(big.NewInt(3)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out3[1].X, big.NewInt(3))
+	}
+	if out3[1].Y.Cmp(big.NewInt(4)) != 0 {
+		t.Errorf("ConvertType failed, got %v want %v", out3[1].Y, big.NewInt(4))
+	}
+}
diff --git a/accounts/abi/type_test.go b/accounts/abi/type_test.go
index 48df3aa383df..8c3aedca6a4d 100644
--- a/accounts/abi/type_test.go
+++ b/accounts/abi/type_test.go
@@ -255,7 +255,7 @@ func TestTypeCheck(t *testing.T) {
 		{"bytes", nil, [2]byte{0, 1}, "abi: cannot use array as type slice as argument"},
 		{"bytes", nil, common.Hash{1}, "abi: cannot use array as type slice as argument"},
 		{"string", nil, "hello world", ""},
-		{"string", nil, string(""), ""},
+		{"string", nil, "", ""},
 		{"string", nil, []byte{}, "abi: cannot use slice as type string as argument"},
 		{"bytes32[]", nil, [][32]byte{{}}, ""},
 		{"function", nil, [24]byte{}, ""},
diff --git a/accounts/abi/unpack_test.go b/accounts/abi/unpack_test.go
index 767d1540e60c..b88f77805bfa 100644
--- a/accounts/abi/unpack_test.go
+++ b/accounts/abi/unpack_test.go
@@ -44,15 +44,13 @@ func TestUnpack(t *testing.T) {
 			if err != nil {
 				t.Fatalf("invalid hex %s: %v", test.packed, err)
 			}
-			outptr := reflect.New(reflect.TypeOf(test.unpacked))
-			err = abi.Unpack(outptr.Interface(), "method", encb)
+			out, err := abi.Unpack("method", encb)
 			if err != nil {
 				t.Errorf("test %d (%v) failed: %v", i, test.def, err)
 				return
 			}
-			out := outptr.Elem().Interface()
-			if !reflect.DeepEqual(test.unpacked, out) {
-				t.Errorf("test %d (%v) failed: expected %v, got %v", i, test.def, test.unpacked, out)
+			if !reflect.DeepEqual(test.unpacked, ConvertType(out[0], test.unpacked)) {
+				t.Errorf("test %d (%v) failed: expected %v, got %v", i, test.def, test.unpacked, out[0])
 			}
 		})
 	}
@@ -221,7 +219,7 @@ func TestLocalUnpackTests(t *testing.T) {
 			t.Fatalf("invalid hex %s: %v", test.enc, err)
 		}
 		outptr := reflect.New(reflect.TypeOf(test.want))
-		err = abi.Unpack(outptr.Interface(), "method", encb)
+		err = abi.UnpackIntoInterface(outptr.Interface(), "method", encb)
 		if err := test.checkError(err); err != nil {
 			t.Errorf("test %d (%v) failed: %v", i, test.def, err)
 			return
@@ -234,7 +232,7 @@ func TestLocalUnpackTests(t *testing.T) {
 	}
 }

-func TestUnpackSetDynamicArrayOutput(t *testing.T) {
+func TestUnpackIntoInterfaceSetDynamicArrayOutput(t *testing.T) {
 	abi, err := JSON(strings.NewReader(`[{"constant":true,"inputs":[],"name":"testDynamicFixedBytes15","outputs":[{"name":"","type":"bytes15[]"}],"payable":false,"stateMutability":"view","type":"function"},{"constant":true,"inputs":[],"name":"testDynamicFixedBytes32","outputs":[{"name":"","type":"bytes32[]"}],"payable":false,"stateMutability":"view","type":"function"}]`))
 	if err != nil {
 		t.Fatal(err)
@@ -249,7 +247,7 @@ func TestUnpackSetDynamicArrayOutput(t *testing.T) {
 	)
 	// test 32
-	err = abi.Unpack(&out32, "testDynamicFixedBytes32", marshalledReturn32)
+	err = abi.UnpackIntoInterface(&out32, "testDynamicFixedBytes32", marshalledReturn32)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -266,7 +264,7 @@ func TestUnpackSetDynamicArrayOutput(t *testing.T) {
 	}

 	// test 15
-	err = abi.Unpack(&out15, "testDynamicFixedBytes32", marshalledReturn15)
+	err = abi.UnpackIntoInterface(&out15, "testDynamicFixedBytes32", marshalledReturn15)
 	if err != nil {
 		t.Fatal(err)
 	}
@@
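
The new `Unpack` returns the decoded values as a `[]interface{}` instead of filling a caller-supplied pointer; `ConvertType` bridges those values into caller-defined types, and `UnpackIntoInterface` keeps the old pointer-filling behaviour under a new name. A minimal sketch of both styles, assuming a hypothetical method `get` that returns a `(uint256 x, uint256 y)` tuple (the ABI JSON and return data below are invented for illustration):

```go
package main

import (
	"fmt"
	"log"
	"math/big"
	"strings"

	"github.com/ethereum/go-ethereum/accounts/abi"
	"github.com/ethereum/go-ethereum/common"
)

// Point is a caller-defined shape for the method's tuple output.
type Point struct {
	X *big.Int
	Y *big.Int
}

func main() {
	// Hypothetical single-output method, for illustration only.
	def := `[{"name":"get","type":"function","inputs":[],"outputs":[{"type":"tuple","components":[{"name":"x","type":"uint256"},{"name":"y","type":"uint256"}]}]}]`
	parsed, err := abi.JSON(strings.NewReader(def))
	if err != nil {
		log.Fatal(err)
	}
	data := common.Hex2Bytes(
		"0000000000000000000000000000000000000000000000000000000000000001" +
			"0000000000000000000000000000000000000000000000000000000000000002")

	// New style: Unpack hands back the outputs as []interface{}, and
	// ConvertType reshapes the anonymous tuple struct into Point.
	out, err := parsed.Unpack("get", data)
	if err != nil {
		log.Fatal(err)
	}
	pt := *abi.ConvertType(out[0], new(Point)).(*Point)
	fmt.Println(pt.X, pt.Y) // 1 2

	// Old style, now renamed: UnpackIntoInterface fills a pointer.
	var pt2 Point
	if err := parsed.UnpackIntoInterface(&pt2, "get", data); err != nil {
		log.Fatal(err)
	}
	fmt.Println(pt2.X, pt2.Y) // 1 2
}
```
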
-367,7 +365,7 @@ func TestMethodMultiReturn(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { require := require.New(t) - err := abi.Unpack(tc.dest, "multi", data) + err := abi.UnpackIntoInterface(tc.dest, "multi", data) if tc.error == "" { require.Nil(err, "Should be able to unpack method outputs.") require.Equal(tc.expected, tc.dest) @@ -390,7 +388,7 @@ func TestMultiReturnWithArray(t *testing.T) { ret1, ret1Exp := new([3]uint64), [3]uint64{9, 9, 9} ret2, ret2Exp := new(uint64), uint64(8) - if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { + if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { t.Fatal(err) } if !reflect.DeepEqual(*ret1, ret1Exp) { @@ -414,7 +412,7 @@ func TestMultiReturnWithStringArray(t *testing.T) { ret2, ret2Exp := new(common.Address), common.HexToAddress("ab1257528b3782fb40d7ed5f72e624b744dffb2f") ret3, ret3Exp := new([2]string), [2]string{"Ethereum", "Hello, Ethereum!"} ret4, ret4Exp := new(bool), false - if err := abi.Unpack(&[]interface{}{ret1, ret2, ret3, ret4}, "multi", buff.Bytes()); err != nil { + if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2, ret3, ret4}, "multi", buff.Bytes()); err != nil { t.Fatal(err) } if !reflect.DeepEqual(*ret1, ret1Exp) { @@ -452,7 +450,7 @@ func TestMultiReturnWithStringSlice(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000065")) // output[1][1] value ret1, ret1Exp := new([]string), []string{"ethereum", "go-ethereum"} ret2, ret2Exp := new([]*big.Int), []*big.Int{big.NewInt(100), big.NewInt(101)} - if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { + if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { t.Fatal(err) } if !reflect.DeepEqual(*ret1, ret1Exp) { @@ -492,7 +490,7 @@ func TestMultiReturnWithDeeplyNestedArray(t *testing.T) { {{0x411, 0x412, 0x413}, {0x421, 0x422, 0x423}}, } ret2, ret2Exp := new(uint64), uint64(0x9876) - if err := abi.Unpack(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { + if err := abi.UnpackIntoInterface(&[]interface{}{ret1, ret2}, "multi", buff.Bytes()); err != nil { t.Fatal(err) } if !reflect.DeepEqual(*ret1, ret1Exp) { @@ -531,7 +529,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000a")) buff.Write(common.Hex2Bytes("0102000000000000000000000000000000000000000000000000000000000000")) - err = abi.Unpack(&mixedBytes, "mixedBytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&mixedBytes, "mixedBytes", buff.Bytes()) if err != nil { t.Error(err) } else { @@ -546,7 +544,7 @@ func TestUnmarshal(t *testing.T) { // marshal int var Int *big.Int - err = abi.Unpack(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + err = abi.UnpackIntoInterface(&Int, "int", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) if err != nil { t.Error(err) } @@ -557,7 +555,7 @@ func TestUnmarshal(t *testing.T) { // marshal bool var Bool bool - err = abi.Unpack(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) + err = abi.UnpackIntoInterface(&Bool, "bool", common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001")) if err != nil { t.Error(err) } @@ -574,7 +572,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(bytesOut) var Bytes []byte - 
err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err != nil { t.Error(err) } @@ -590,7 +588,7 @@ func TestUnmarshal(t *testing.T) { bytesOut = common.RightPadBytes([]byte("hello"), 64) buff.Write(bytesOut) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err != nil { t.Error(err) } @@ -606,7 +604,7 @@ func TestUnmarshal(t *testing.T) { bytesOut = common.RightPadBytes([]byte("hello"), 64) buff.Write(bytesOut) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err != nil { t.Error(err) } @@ -616,7 +614,7 @@ func TestUnmarshal(t *testing.T) { } // marshal dynamic bytes output empty - err = abi.Unpack(&Bytes, "bytes", nil) + err = abi.UnpackIntoInterface(&Bytes, "bytes", nil) if err == nil { t.Error("expected error") } @@ -627,7 +625,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000005")) buff.Write(common.RightPadBytes([]byte("hello"), 32)) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err != nil { t.Error(err) } @@ -641,7 +639,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.RightPadBytes([]byte("hello"), 32)) var hash common.Hash - err = abi.Unpack(&hash, "fixed", buff.Bytes()) + err = abi.UnpackIntoInterface(&hash, "fixed", buff.Bytes()) if err != nil { t.Error(err) } @@ -654,12 +652,12 @@ func TestUnmarshal(t *testing.T) { // marshal error buff.Reset() buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000020")) - err = abi.Unpack(&Bytes, "bytes", buff.Bytes()) + err = abi.UnpackIntoInterface(&Bytes, "bytes", buff.Bytes()) if err == nil { t.Error("expected error") } - err = abi.Unpack(&Bytes, "multi", make([]byte, 64)) + err = abi.UnpackIntoInterface(&Bytes, "multi", make([]byte, 64)) if err == nil { t.Error("expected error") } @@ -670,7 +668,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000003")) // marshal int array var intArray [3]*big.Int - err = abi.Unpack(&intArray, "intArraySingle", buff.Bytes()) + err = abi.UnpackIntoInterface(&intArray, "intArraySingle", buff.Bytes()) if err != nil { t.Error(err) } @@ -691,7 +689,7 @@ func TestUnmarshal(t *testing.T) { buff.Write(common.Hex2Bytes("0000000000000000000000000100000000000000000000000000000000000000")) var outAddr []common.Address - err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) + err = abi.UnpackIntoInterface(&outAddr, "addressSliceSingle", buff.Bytes()) if err != nil { t.Fatal("didn't expect error:", err) } @@ -718,7 +716,7 @@ func TestUnmarshal(t *testing.T) { A []common.Address B []common.Address } - err = abi.Unpack(&outAddrStruct, "addressSliceDouble", buff.Bytes()) + err = abi.UnpackIntoInterface(&outAddrStruct, "addressSliceDouble", buff.Bytes()) if err != nil { t.Fatal("didn't expect error:", err) } @@ -746,7 +744,7 @@ func TestUnmarshal(t *testing.T) { buff.Reset() buff.Write(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000100")) - err = abi.Unpack(&outAddr, "addressSliceSingle", buff.Bytes()) + err = abi.UnpackIntoInterface(&outAddr, "addressSliceSingle", buff.Bytes()) if err == nil { t.Fatal("expected error:", err) } @@ -769,7 +767,7 @@ func TestUnpackTuple(t *testing.T) { B *big.Int }{new(big.Int), 
new(big.Int)} - err = abi.Unpack(&v, "tuple", buff.Bytes()) + err = abi.UnpackIntoInterface(&v, "tuple", buff.Bytes()) if err != nil { t.Error(err) } else { @@ -841,7 +839,7 @@ func TestUnpackTuple(t *testing.T) { A: big.NewInt(1), } - err = abi.Unpack(&ret, "tuple", buff.Bytes()) + err = abi.UnpackIntoInterface(&ret, "tuple", buff.Bytes()) if err != nil { t.Error(err) } diff --git a/accounts/accounts.go b/accounts/accounts.go index 7a14e4e3e5d5..08a1f0f2b193 100644 --- a/accounts/accounts.go +++ b/accounts/accounts.go @@ -21,7 +21,7 @@ import ( "fmt" "math/big" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" @@ -88,7 +88,7 @@ type Wallet interface { // to discover non zero accounts and automatically add them to list of tracked // accounts. // - // Note, self derivaton will increment the last component of the specified path + // Note, self derivation will increment the last component of the specified path // opposed to decending into a child path to allow discovering accounts starting // from non zero components. // diff --git a/accounts/hd.go b/accounts/hd.go index 75c47611061c..54acea3b261d 100644 --- a/accounts/hd.go +++ b/accounts/hd.go @@ -150,3 +150,31 @@ func (path *DerivationPath) UnmarshalJSON(b []byte) error { *path, err = ParseDerivationPath(dp) return err } + +// DefaultIterator creates a BIP-32 path iterator, which progresses by increasing the last component: +// i.e. m/44'/60'/0'/0/0, m/44'/60'/0'/0/1, m/44'/60'/0'/0/2, ... m/44'/60'/0'/0/N. +func DefaultIterator(base DerivationPath) func() DerivationPath { + path := make(DerivationPath, len(base)) + copy(path[:], base[:]) + // Set it back by one, so the first call gives the first result + path[len(path)-1]-- + return func() DerivationPath { + path[len(path)-1]++ + return path + } +} + +// LedgerLiveIterator creates a bip44 path iterator for Ledger Live. +// Ledger Live increments the third component rather than the fifth component +// i.e. m/44'/60'/0'/0/0, m/44'/60'/1'/0/0, m/44'/60'/2'/0/0, ... m/44'/60'/N'/0/0. 
+func LedgerLiveIterator(base DerivationPath) func() DerivationPath { + path := make(DerivationPath, len(base)) + copy(path[:], base[:]) + // Set it back by one, so the first call gives the first result + path[2]-- + return func() DerivationPath { + // ledgerLivePathIterator iterates on the third component + path[2]++ + return path + } +} diff --git a/accounts/hd_test.go b/accounts/hd_test.go index 3156a487ee7c..0743bbe66628 100644 --- a/accounts/hd_test.go +++ b/accounts/hd_test.go @@ -17,6 +17,7 @@ package accounts import ( + "fmt" "reflect" "testing" ) @@ -77,3 +78,41 @@ func TestHDPathParsing(t *testing.T) { } } } + +func testDerive(t *testing.T, next func() DerivationPath, expected []string) { + t.Helper() + for i, want := range expected { + if have := next(); fmt.Sprintf("%v", have) != want { + t.Errorf("step %d, have %v, want %v", i, have, want) + } + } +} + +func TestHdPathIteration(t *testing.T) { + testDerive(t, DefaultIterator(DefaultBaseDerivationPath), + []string{ + "m/44'/60'/0'/0/0", "m/44'/60'/0'/0/1", + "m/44'/60'/0'/0/2", "m/44'/60'/0'/0/3", + "m/44'/60'/0'/0/4", "m/44'/60'/0'/0/5", + "m/44'/60'/0'/0/6", "m/44'/60'/0'/0/7", + "m/44'/60'/0'/0/8", "m/44'/60'/0'/0/9", + }) + + testDerive(t, DefaultIterator(LegacyLedgerBaseDerivationPath), + []string{ + "m/44'/60'/0'/0", "m/44'/60'/0'/1", + "m/44'/60'/0'/2", "m/44'/60'/0'/3", + "m/44'/60'/0'/4", "m/44'/60'/0'/5", + "m/44'/60'/0'/6", "m/44'/60'/0'/7", + "m/44'/60'/0'/8", "m/44'/60'/0'/9", + }) + + testDerive(t, LedgerLiveIterator(DefaultBaseDerivationPath), + []string{ + "m/44'/60'/0'/0/0", "m/44'/60'/1'/0/0", + "m/44'/60'/2'/0/0", "m/44'/60'/3'/0/0", + "m/44'/60'/4'/0/0", "m/44'/60'/5'/0/0", + "m/44'/60'/6'/0/0", "m/44'/60'/7'/0/0", + "m/44'/60'/8'/0/0", "m/44'/60'/9'/0/0", + }) +} diff --git a/accounts/keystore/file_cache.go b/accounts/keystore/file_cache.go index 73ff6ae9ee6f..8b309321d370 100644 --- a/accounts/keystore/file_cache.go +++ b/accounts/keystore/file_cache.go @@ -32,7 +32,7 @@ import ( type fileCache struct { all mapset.Set // Set of all files from the keystore folder lastMod time.Time // Last time instance when a file was modified - mu sync.RWMutex + mu sync.Mutex } // scan performs a new scan on the given directory, compares against the already diff --git a/accounts/keystore/keystore_test.go b/accounts/keystore/keystore_test.go index 29c251d7c1a4..cb5de11c0ddb 100644 --- a/accounts/keystore/keystore_test.go +++ b/accounts/keystore/keystore_test.go @@ -336,7 +336,9 @@ func TestWalletNotifications(t *testing.T) { // Shut down the event collector and check events. 
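
A short sketch of how the two iterators behave, using the exported base path exercised by the test above. Note that each iterator mutates and returns the same backing slice on every call, so callers that retain a path should copy it first:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/accounts"
)

func main() {
	// Standard iterator: bumps the last (address index) component.
	next := accounts.DefaultIterator(accounts.DefaultBaseDerivationPath)
	for i := 0; i < 3; i++ {
		fmt.Println(next()) // m/44'/60'/0'/0/0, m/44'/60'/0'/0/1, m/44'/60'/0'/0/2
	}

	// Ledger Live iterator: bumps the third (account) component instead.
	nextLive := accounts.LedgerLiveIterator(accounts.DefaultBaseDerivationPath)
	for i := 0; i < 3; i++ {
		// The returned DerivationPath shares its backing array across
		// calls; append a copy if the value must outlive the loop.
		fmt.Println(nextLive()) // m/44'/60'/0'/0/0, m/44'/60'/1'/0/0, m/44'/60'/2'/0/0
	}
}
```
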
sub.Unsubscribe() - <-updates + for ev := range updates { + events = append(events, walletEvent{ev, ev.Wallet.Accounts()[0]}) + } checkAccounts(t, live, ks.Wallets()) checkEvents(t, wantEvents, events) } diff --git a/accounts/keystore/passphrase.go b/accounts/keystore/passphrase.go index 89cdf0bfca14..dd4d7764e48b 100644 --- a/accounts/keystore/passphrase.go +++ b/accounts/keystore/passphrase.go @@ -230,7 +230,7 @@ func DecryptKey(keyjson []byte, auth string) (*Key, error) { key := crypto.ToECDSAUnsafe(keyBytes) return &Key{ - Id: uuid.UUID(keyId), + Id: keyId, Address: crypto.PubkeyToAddress(key.PublicKey), PrivateKey: key, }, nil diff --git a/accounts/keystore/wallet.go b/accounts/keystore/wallet.go index 498067d49730..1066095f6d07 100644 --- a/accounts/keystore/wallet.go +++ b/accounts/keystore/wallet.go @@ -19,7 +19,7 @@ package keystore import ( "math/big" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" @@ -58,7 +58,7 @@ func (w *keystoreWallet) Open(passphrase string) error { return nil } func (w *keystoreWallet) Close() error { return nil } // Accounts implements accounts.Wallet, returning an account list consisting of -// a single account that the plain kestore wallet contains. +// a single account that the plain keystore wallet contains. func (w *keystoreWallet) Accounts() []accounts.Account { return []accounts.Account{w.account} } @@ -93,12 +93,12 @@ func (w *keystoreWallet) signHash(account accounts.Account, hash []byte) ([]byte return w.keystore.SignHash(account, hash) } -// SignData signs keccak256(data). The mimetype parameter describes the type of data being signed +// SignData signs keccak256(data). The mimetype parameter describes the type of data being signed. func (w *keystoreWallet) SignData(account accounts.Account, mimeType string, data []byte) ([]byte, error) { return w.signHash(account, crypto.Keccak256(data)) } -// SignDataWithPassphrase signs keccak256(data). The mimetype parameter describes the type of data being signed +// SignDataWithPassphrase signs keccak256(data). The mimetype parameter describes the type of data being signed. func (w *keystoreWallet) SignDataWithPassphrase(account accounts.Account, passphrase, mimeType string, data []byte) ([]byte, error) { // Make sure the requested account is contained within if !w.Contains(account) { @@ -108,12 +108,14 @@ func (w *keystoreWallet) SignDataWithPassphrase(account accounts.Account, passph return w.keystore.SignHashWithPassphrase(account, passphrase, crypto.Keccak256(data)) } +// SignText implements accounts.Wallet, attempting to sign the hash of +// the given text with the given account. func (w *keystoreWallet) SignText(account accounts.Account, text []byte) ([]byte, error) { return w.signHash(account, accounts.TextHash(text)) } // SignTextWithPassphrase implements accounts.Wallet, attempting to sign the -// given hash with the given account using passphrase as extra authentication. +// hash of the given text with the given account using passphrase as extra authentication. 
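
For reference, `SignText` differs from `SignData` only in the digest it feeds to `signHash`: it hashes the EIP-191 "personal message" form of the text rather than the raw bytes. A minimal sketch, assuming `accounts.TextHash` applies the standard `\x19Ethereum Signed Message:\n<length>` prefix (which is what it is documented to do):

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/ethereum/go-ethereum/accounts"
	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	msg := []byte("hello")
	// The EIP-191 prefix is applied before hashing so that a signed
	// message can never double as a valid transaction signature.
	prefixed := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(msg), msg)
	manual := crypto.Keccak256([]byte(prefixed))
	fmt.Println(bytes.Equal(manual, accounts.TextHash(msg))) // true
}
```
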
func (w *keystoreWallet) SignTextWithPassphrase(account accounts.Account, passphrase string, text []byte) ([]byte, error) { // Make sure the requested account is contained within if !w.Contains(account) { diff --git a/accounts/scwallet/wallet.go b/accounts/scwallet/wallet.go index 80009fc5ebbf..6476646d7f07 100644 --- a/accounts/scwallet/wallet.go +++ b/accounts/scwallet/wallet.go @@ -33,7 +33,7 @@ import ( "sync" "time" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -637,7 +637,7 @@ func (w *Wallet) Derive(path accounts.DerivationPath, pin bool) (accounts.Accoun // to discover non zero accounts and automatically add them to list of tracked // accounts. // -// Note, self derivaton will increment the last component of the specified path +// Note, self derivation will increment the last component of the specified path // opposed to decending into a child path to allow discovering accounts starting // from non zero components. // diff --git a/accounts/usbwallet/ledger.go b/accounts/usbwallet/ledger.go index 64eae64f689a..71f0f9392fc4 100644 --- a/accounts/usbwallet/ledger.go +++ b/accounts/usbwallet/ledger.go @@ -162,7 +162,7 @@ func (w *ledgerDriver) SignTx(path accounts.DerivationPath, tx *types.Transactio return common.Address{}, nil, accounts.ErrWalletClosed } // Ensure the wallet is capable of signing the given transaction - if chainID != nil && w.version[0] <= 1 && w.version[2] <= 2 { + if chainID != nil && w.version[0] <= 1 && w.version[1] <= 0 && w.version[2] <= 2 { //lint:ignore ST1005 brand name displayed on the console return common.Address{}, nil, fmt.Errorf("Ledger v%d.%d.%d doesn't support signing this transaction, please update to v1.0.3 at least", w.version[0], w.version[1], w.version[2]) } diff --git a/accounts/usbwallet/wallet.go b/accounts/usbwallet/wallet.go index 993c599346a0..9f74e5554f14 100644 --- a/accounts/usbwallet/wallet.go +++ b/accounts/usbwallet/wallet.go @@ -25,7 +25,7 @@ import ( "sync" "time" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" @@ -493,7 +493,7 @@ func (w *wallet) Derive(path accounts.DerivationPath, pin bool) (accounts.Accoun // to discover non zero accounts and automatically add them to list of tracked // accounts. // -// Note, self derivaton will increment the last component of the specified path +// Note, self derivation will increment the last component of the specified path // opposed to decending into a child path to allow discovering accounts starting // from non zero components. 
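
The extra `w.version[1] <= 0` term in the Ledger check above matters because the old condition ignored the minor component: firmware 1.1.0, for instance, satisfied `w.version[0] <= 1 && w.version[2] <= 2` and was wrongly rejected. A small sketch of the corrected predicate (a simplified stand-in, not the wallet code itself):

```go
package main

import "fmt"

// olderThan103 mirrors the corrected firmware check above: reject only
// app versions at or below 1.0.2.
func olderThan103(v [3]byte) bool {
	return v[0] <= 1 && v[1] <= 0 && v[2] <= 2
}

func main() {
	fmt.Println(olderThan103([3]byte{1, 0, 2})) // true: update required
	fmt.Println(olderThan103([3]byte{1, 1, 0})) // false: the old check got this wrong
	fmt.Println(olderThan103([3]byte{1, 0, 3})) // false: first supported release
	// Component-wise comparison is still only approximate: 0.1.0
	// predates 1.0.3 but passes, since its minor component is nonzero.
}
```
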
// diff --git a/appveyor.yml b/appveyor.yml index 7d6bf87639a1..2bf67d45684d 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -24,13 +24,13 @@ environment: install: - git submodule update --init - rmdir C:\go /s /q - - appveyor DownloadFile https://dl.google.com/go/go1.15.windows-%GETH_ARCH%.zip - - 7z x go1.15.windows-%GETH_ARCH%.zip -y -oC:\ > NUL + - appveyor DownloadFile https://dl.google.com/go/go1.15.5.windows-%GETH_ARCH%.zip + - 7z x go1.15.5.windows-%GETH_ARCH%.zip -y -oC:\ > NUL - go version - gcc --version build_script: - - go run build\ci.go install + - go run build\ci.go install -dlgo after_build: - go run build\ci.go archive -type zip -signer WINDOWS_SIGNING_KEY -upload gethstore/builds diff --git a/build/checksums.txt b/build/checksums.txt index 39f855cd0cf4..a7a6a657e9bf 100644 --- a/build/checksums.txt +++ b/build/checksums.txt @@ -1,6 +1,17 @@ # This file contains sha256 checksums of optional build dependencies. -69438f7ed4f532154ffaf878f3dfd83747e7a00b70b3556eddabf7aaee28ac3a go1.15.src.tar.gz +890bba73c5e2b19ffb1180e385ea225059eb008eb91b694875dd86ea48675817 go1.15.6.src.tar.gz +940a73b45993a3bae5792cf324140dded34af97c548af4864d22fd6d49f3bd9f go1.15.6.darwin-amd64.tar.gz +ad187f02158b9a9013ef03f41d14aa69c402477f178825a3940280814bcbb755 go1.15.6.linux-386.tar.gz +3918e6cc85e7eaaa6f859f1bdbaac772e7a825b0eb423c63d3ae68b21f84b844 go1.15.6.linux-amd64.tar.gz +f87515b9744154ffe31182da9341d0a61eb0795551173d242c8cad209239e492 go1.15.6.linux-arm64.tar.gz +40ba9a57764e374195018ef37c38a5fbac9bbce908eab436370631a84bfc5788 go1.15.6.linux-armv6l.tar.gz +5872eff6746a0a5f304272b27cbe9ce186f468454e95749cce01e903fbfc0e17 go1.15.6.windows-386.zip +b7b3808bb072c2bab73175009187fd5a7f20ffe0a31739937003a14c5c4d9006 go1.15.6.windows-amd64.zip +9d9dd5c217c1392f1b2ed5e03e1c71bf4cf8553884e57a38e68fd37fdcfe31a8 go1.15.6.freebsd-386.tar.gz +609f065d855aed5a0b40ef0245aacbcc0b4b7882dc3b1e75ae50576cf25265ee go1.15.6.freebsd-amd64.tar.gz +d4174fc217e749ac049eacc8827df879689f2246ac230d04991ae7df336f7cb2 go1.15.6.linux-ppc64le.tar.gz +839cc6b67687d8bb7cb044e4a9a2eac0c090765cc8ec55ffe714dfb7cd51cf3a go1.15.6.linux-s390x.tar.gz d998a84eea42f2271aca792a7b027ca5c1edfcba229e8e5a844c9ac3f336df35 golangci-lint-1.27.0-linux-armv7.tar.gz bf781f05b0d393b4bf0a327d9e62926949a4f14d7774d950c4e009fc766ed1d4 golangci-lint.exe-1.27.0-windows-amd64.zip diff --git a/build/ci.go b/build/ci.go index ea708d5e7b2f..73d896162983 100644 --- a/build/ci.go +++ b/build/ci.go @@ -26,7 +26,7 @@ Available commands are: install [ -arch architecture ] [ -cc compiler ] [ packages... ] -- builds packages and executables test [ -coverage ] [ packages... 
] -- runs the tests lint -- runs certain pre-selected linters - archive [ -arch architecture ] [ -type zip|tar ] [ -signer key-envvar ] [ -upload dest ] -- archives build artifacts + archive [ -arch architecture ] [ -type zip|tar ] [ -signer key-envvar ] [ -signify key-envvar ] [ -upload dest ] -- archives build artifacts importkeys -- imports signing keys from env debsrc [ -signer key-id ] [ -upload dest ] -- creates a debian source package nsis -- creates a Windows NSIS installer @@ -46,12 +46,11 @@ import ( "encoding/base64" "flag" "fmt" - "go/parser" - "go/token" "io/ioutil" "log" "os" "os/exec" + "path" "path/filepath" "regexp" "runtime" @@ -59,6 +58,7 @@ import ( "time" "github.com/cespare/cp" + "github.com/ethereum/go-ethereum/crypto/signify" "github.com/ethereum/go-ethereum/internal/build" "github.com/ethereum/go-ethereum/params" ) @@ -135,11 +135,11 @@ var ( // Note: artful is unsupported because it was officially deprecated on Launchpad. // Note: cosmic is unsupported because it was officially deprecated on Launchpad. // Note: disco is unsupported because it was officially deprecated on Launchpad. + // Note: eoan is unsupported because it was officially deprecated on Launchpad. debDistroGoBoots = map[string]string{ "trusty": "golang-1.11", "xenial": "golang-go", "bionic": "golang-go", - "eoan": "golang-go", "focal": "golang-go", "groovy": "golang-go", } @@ -148,6 +148,11 @@ var ( "golang-1.11": "/usr/lib/go-1.11", "golang-go": "/usr/lib/go", } + + // This is the version of go that will be downloaded by + // + // go run ci.go install -dlgo + dlgoVersion = "1.15.6" ) var GOBIN, _ = filepath.Abs(filepath.Join("build", "bin")) @@ -198,19 +203,19 @@ func main() { func doInstall(cmdline []string) { var ( + dlgo = flag.Bool("dlgo", false, "Download Go and build with it") arch = flag.String("arch", "", "Architecture to cross build for") cc = flag.String("cc", "", "C compiler to cross build with") ) flag.CommandLine.Parse(cmdline) env := build.Env() - // Check Go version. People regularly open issues about compilation + // Check local Go version. People regularly open issues about compilation // failure with outdated Go. This should save them the trouble. if !strings.Contains(runtime.Version(), "devel") { // Figure out the minor version number since we can't textually compare (1.10 < 1.9) var minor int fmt.Sscanf(strings.TrimPrefix(runtime.Version(), "go1."), "%d", &minor) - if minor < 13 { log.Println("You have Go version", runtime.Version()) log.Println("go-ethereum requires at least Go version 1.13 and cannot") @@ -218,90 +223,110 @@ func doInstall(cmdline []string) { os.Exit(1) } } - // Compile packages given as arguments, or everything if there are no arguments. - packages := []string{"./..."} - if flag.NArg() > 0 { - packages = flag.Args() + + // Choose which go command we're going to use. + var gobuild *exec.Cmd + if !*dlgo { + // Default behavior: use the go version which runs ci.go right now. + gobuild = goTool("build") + } else { + // Download of Go requested. This is for build environments where the + // installed version is too old and cannot be upgraded easily. + cachedir := filepath.Join("build", "cache") + goroot := downloadGo(runtime.GOARCH, runtime.GOOS, cachedir) + gobuild = localGoTool(goroot, "build") } - if *arch == "" || *arch == runtime.GOARCH { - goinstall := goTool("install", buildFlags(env)...) 
-		if runtime.GOARCH == "arm64" {
-			goinstall.Args = append(goinstall.Args, "-p", "1")
-		}
-		goinstall.Args = append(goinstall.Args, "-trimpath")
-		goinstall.Args = append(goinstall.Args, "-v")
-		goinstall.Args = append(goinstall.Args, packages...)
-		build.MustRun(goinstall)
-		return
+	// Configure environment for cross build.
+	if *arch != "" && *arch != runtime.GOARCH {
+		gobuild.Env = append(gobuild.Env, "CGO_ENABLED=1")
+		gobuild.Env = append(gobuild.Env, "GOARCH="+*arch)
 	}
-	// Seems we are cross compiling, work around forbidden GOBIN
-	goinstall := goToolArch(*arch, *cc, "install", buildFlags(env)...)
-	goinstall.Args = append(goinstall.Args, "-trimpath")
-	goinstall.Args = append(goinstall.Args, "-v")
-	goinstall.Args = append(goinstall.Args, []string{"-buildmode", "archive"}...)
-	goinstall.Args = append(goinstall.Args, packages...)
-	build.MustRun(goinstall)

+	// Configure C compiler.
+	if *cc != "" {
+		gobuild.Env = append(gobuild.Env, "CC="+*cc)
+	} else if os.Getenv("CC") != "" {
+		gobuild.Env = append(gobuild.Env, "CC="+os.Getenv("CC"))
+	}
-	if cmds, err := ioutil.ReadDir("cmd"); err == nil {
-		for _, cmd := range cmds {
-			pkgs, err := parser.ParseDir(token.NewFileSet(), filepath.Join(".", "cmd", cmd.Name()), nil, parser.PackageClauseOnly)
-			if err != nil {
-				log.Fatal(err)
-			}
-			for name := range pkgs {
-				if name == "main" {
-					gobuild := goToolArch(*arch, *cc, "build", buildFlags(env)...)
-					gobuild.Args = append(gobuild.Args, "-v")
-					gobuild.Args = append(gobuild.Args, []string{"-o", executablePath(cmd.Name())}...)
-					gobuild.Args = append(gobuild.Args, "."+string(filepath.Separator)+filepath.Join("cmd", cmd.Name()))
-					build.MustRun(gobuild)
-					break
-				}
-			}
-		}
+	// arm64 CI builders are memory-constrained and can't handle concurrent builds,
+	// better disable it. This check isn't the best, it should probably
+	// check for something in env instead.
+	if runtime.GOARCH == "arm64" {
+		gobuild.Args = append(gobuild.Args, "-p", "1")
+	}
+
+	// Put the default settings in.
+	gobuild.Args = append(gobuild.Args, buildFlags(env)...)
+
+	// We use -trimpath to avoid leaking local paths into the built executables.
+	gobuild.Args = append(gobuild.Args, "-trimpath")
+
+	// Show packages during build.
+	gobuild.Args = append(gobuild.Args, "-v")
+
+	// Now we choose what we're even building.
+	// Default: collect all 'main' packages in cmd/ and build those.
+	packages := flag.Args()
+	if len(packages) == 0 {
+		packages = build.FindMainPackages("./cmd")
+	}
+
+	// Do the build!
+	for _, pkg := range packages {
+		args := make([]string, len(gobuild.Args))
+		copy(args, gobuild.Args)
+		args = append(args, "-o", executablePath(path.Base(pkg)))
+		args = append(args, pkg)
+		build.MustRun(&exec.Cmd{Path: gobuild.Path, Args: args, Env: gobuild.Env})
+	}
 }

+// buildFlags returns the go tool flags for building.
 func buildFlags(env build.Environment) (flags []string) {
 	var ld []string
 	if env.Commit != "" {
 		ld = append(ld, "-X", "main.gitCommit="+env.Commit)
 		ld = append(ld, "-X", "main.gitDate="+env.Date)
 	}
+	// Strip DWARF on darwin. This used to be required for certain things,
+	// and there is no downside to this, so we just keep doing it.
 	if runtime.GOOS == "darwin" {
 		ld = append(ld, "-s")
 	}
-
 	if len(ld) > 0 {
 		flags = append(flags, "-ldflags", strings.Join(ld, " "))
 	}
 	return flags
 }

+// goTool returns the go tool. This uses the Go version which runs ci.go.
 func goTool(subcmd string, args ...string) *exec.Cmd {
-	return goToolArch(runtime.GOARCH, os.Getenv("CC"), subcmd, args...)
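
The rewritten doInstall drives a single `go build` invocation whose environment is adjusted for cross compilation. A standalone sketch of the same pattern, where the target arch, compiler and output path are illustrative:

```go
package main

import (
	"log"
	"os"
	"os/exec"
)

func main() {
	cmd := exec.Command("go", "build", "-trimpath", "-v",
		"-o", "build/bin/geth", "./cmd/geth")
	cmd.Env = append(os.Environ(),
		"CGO_ENABLED=1",            // cgo is needed when cross-compiling geth
		"GOARCH=arm64",             // cross-compile target
		"CC=aarch64-linux-gnu-gcc", // matching C cross compiler
	)
	cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
	if err := cmd.Run(); err != nil {
		log.Fatal(err)
	}
}
```
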
+ cmd := build.GoTool(subcmd, args...) + goToolSetEnv(cmd) + return cmd } -func goToolArch(arch string, cc string, subcmd string, args ...string) *exec.Cmd { - cmd := build.GoTool(subcmd, args...) - if arch == "" || arch == runtime.GOARCH { - cmd.Env = append(cmd.Env, "GOBIN="+GOBIN) - } else { - cmd.Env = append(cmd.Env, "CGO_ENABLED=1") - cmd.Env = append(cmd.Env, "GOARCH="+arch) - } - if cc != "" { - cmd.Env = append(cmd.Env, "CC="+cc) - } +// localGoTool returns the go tool from the given GOROOT. +func localGoTool(goroot string, subcmd string, args ...string) *exec.Cmd { + gotool := filepath.Join(goroot, "bin", "go") + cmd := exec.Command(gotool, subcmd) + goToolSetEnv(cmd) + cmd.Env = append(cmd.Env, "GOROOT="+goroot) + cmd.Args = append(cmd.Args, args...) + return cmd +} + +// goToolSetEnv forwards the build environment to the go tool. +func goToolSetEnv(cmd *exec.Cmd) { + cmd.Env = append(cmd.Env, "GOBIN="+GOBIN) for _, e := range os.Environ() { - if strings.HasPrefix(e, "GOBIN=") { + if strings.HasPrefix(e, "GOBIN=") || strings.HasPrefix(e, "CC=") { continue } cmd.Env = append(cmd.Env, e) } - return cmd } // Running The Tests @@ -363,7 +388,7 @@ func downloadLinter(cachedir string) string { if err := csdb.DownloadFile(url, archivePath); err != nil { log.Fatal(err) } - if err := build.ExtractTarballArchive(archivePath, cachedir); err != nil { + if err := build.ExtractArchive(archivePath, cachedir); err != nil { log.Fatal(err) } return filepath.Join(cachedir, base, "golangci-lint") @@ -372,11 +397,12 @@ func downloadLinter(cachedir string) string { // Release Packaging func doArchive(cmdline []string) { var ( - arch = flag.String("arch", runtime.GOARCH, "Architecture cross packaging") - atype = flag.String("type", "zip", "Type of archive to write (zip|tar)") - signer = flag.String("signer", "", `Environment variable holding the signing key (e.g. LINUX_SIGNING_KEY)`) - upload = flag.String("upload", "", `Destination to upload the archives (usually "gethstore/builds")`) - ext string + arch = flag.String("arch", runtime.GOARCH, "Architecture cross packaging") + atype = flag.String("type", "zip", "Type of archive to write (zip|tar)") + signer = flag.String("signer", "", `Environment variable holding the signing key (e.g. LINUX_SIGNING_KEY)`) + signify = flag.String("signify", "", `Environment variable holding the signify key (e.g. 
LINUX_SIGNIFY_KEY)`) + upload = flag.String("upload", "", `Destination to upload the archives (usually "gethstore/builds")`) + ext string ) flag.CommandLine.Parse(cmdline) switch *atype { @@ -403,7 +429,7 @@ func doArchive(cmdline []string) { log.Fatal(err) } for _, archive := range []string{geth, alltools} { - if err := archiveUpload(archive, *upload, *signer); err != nil { + if err := archiveUpload(archive, *upload, *signer, *signify); err != nil { log.Fatal(err) } } @@ -423,7 +449,7 @@ func archiveBasename(arch string, archiveVersion string) string { return platform + "-" + archiveVersion } -func archiveUpload(archive string, blobstore string, signer string) error { +func archiveUpload(archive string, blobstore string, signer string, signifyVar string) error { // If signing was requested, generate the signature files if signer != "" { key := getenvBase64(signer) @@ -431,6 +457,14 @@ func archiveUpload(archive string, blobstore string, signer string) error { return err } } + if signifyVar != "" { + key := os.Getenv(signifyVar) + untrustedComment := "verify with geth-release.pub" + trustedComment := fmt.Sprintf("%s (%s)", archive, time.Now().UTC().Format(time.RFC1123)) + if err := signify.SignFile(archive, archive+".sig", key, untrustedComment, trustedComment); err != nil { + return err + } + } // If uploading to Azure was requested, push the archive possibly with its signature if blobstore != "" { auth := build.AzureBlobstoreConfig{ @@ -446,6 +480,11 @@ func archiveUpload(archive string, blobstore string, signer string) error { return err } } + if signifyVar != "" { + if err := build.AzureBlobstoreUpload(archive+".sig", filepath.Base(archive+".sig"), auth); err != nil { + return err + } + } } return nil } @@ -469,13 +508,12 @@ func maybeSkipArchive(env build.Environment) { // Debian Packaging func doDebianSource(cmdline []string) { var ( - goversion = flag.String("goversion", "", `Go version to build with (will be included in the source package)`) - cachedir = flag.String("cachedir", "./build/cache", `Filesystem path to cache the downloaded Go bundles at`) - signer = flag.String("signer", "", `Signing key name, also used as package author`) - upload = flag.String("upload", "", `Where to upload the source package (usually "ethereum/ethereum")`) - sshUser = flag.String("sftp-user", "", `Username for SFTP upload (usually "geth-ci")`) - workdir = flag.String("workdir", "", `Output directory for packages (uses temp dir if unset)`) - now = time.Now() + cachedir = flag.String("cachedir", "./build/cache", `Filesystem path to cache the downloaded Go bundles at`) + signer = flag.String("signer", "", `Signing key name, also used as package author`) + upload = flag.String("upload", "", `Where to upload the source package (usually "ethereum/ethereum")`) + sshUser = flag.String("sftp-user", "", `Username for SFTP upload (usually "geth-ci")`) + workdir = flag.String("workdir", "", `Output directory for packages (uses temp dir if unset)`) + now = time.Now() ) flag.CommandLine.Parse(cmdline) *workdir = makeWorkdir(*workdir) @@ -490,10 +528,10 @@ func doDebianSource(cmdline []string) { } // Download and verify the Go source package. 
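
The signify step mirrors the PGP one: the key comes from an environment variable named by the new `-signify` flag, and the detached signature is written next to the archive before upload. A minimal sketch of that call, with an illustrative archive name and key variable:

```go
package main

import (
	"fmt"
	"log"
	"os"
	"time"

	"github.com/ethereum/go-ethereum/crypto/signify"
)

func main() {
	archive := "geth-linux-amd64.tar.gz" // illustrative file name
	key := os.Getenv("LINUX_SIGNIFY_KEY")
	untrusted := "verify with geth-release.pub"
	trusted := fmt.Sprintf("%s (%s)", archive, time.Now().UTC().Format(time.RFC1123))
	// Writes a detached signify signature next to the archive.
	if err := signify.SignFile(archive, archive+".sig", key, untrusted, trusted); err != nil {
		log.Fatal(err)
	}
}
```
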
-	gobundle := downloadGoSources(*goversion, *cachedir)
+	gobundle := downloadGoSources(*cachedir)

 	// Download all the dependencies needed to build the sources and run the ci script
-	srcdepfetch := goTool("install", "-n", "./...")
+	srcdepfetch := goTool("mod", "download")
 	srcdepfetch.Env = append(os.Environ(), "GOPATH="+filepath.Join(*workdir, "modgopath"))
 	build.MustRun(srcdepfetch)

@@ -509,7 +547,7 @@ func doDebianSource(cmdline []string) {
 		pkgdir := stageDebianSource(*workdir, meta)

 		// Add Go source code
-		if err := build.ExtractTarballArchive(gobundle, pkgdir); err != nil {
+		if err := build.ExtractArchive(gobundle, pkgdir); err != nil {
 			log.Fatalf("Failed to extract Go sources: %v", err)
 		}
 		if err := os.Rename(filepath.Join(pkgdir, "go"), filepath.Join(pkgdir, ".go")); err != nil {
@@ -541,9 +579,10 @@ func doDebianSource(cmdline []string) {
 	}
 }

-func downloadGoSources(version string, cachedir string) string {
+// downloadGoSources downloads the Go source tarball.
+func downloadGoSources(cachedir string) string {
 	csdb := build.MustLoadChecksums("build/checksums.txt")
-	file := fmt.Sprintf("go%s.src.tar.gz", version)
+	file := fmt.Sprintf("go%s.src.tar.gz", dlgoVersion)
 	url := "https://dl.google.com/go/" + file
 	dst := filepath.Join(cachedir, file)
 	if err := csdb.DownloadFile(url, dst); err != nil {
@@ -552,6 +591,41 @@
 	return dst
 }

+// downloadGo downloads the Go binary distribution and unpacks it into a temporary
+// directory. It returns the GOROOT of the unpacked toolchain.
+func downloadGo(goarch, goos, cachedir string) string {
+	if goarch == "arm" {
+		goarch = "armv6l"
+	}
+
+	csdb := build.MustLoadChecksums("build/checksums.txt")
+	file := fmt.Sprintf("go%s.%s-%s", dlgoVersion, goos, goarch)
+	if goos == "windows" {
+		file += ".zip"
+	} else {
+		file += ".tar.gz"
+	}
+	url := "https://golang.org/dl/" + file
+	dst := filepath.Join(cachedir, file)
+	if err := csdb.DownloadFile(url, dst); err != nil {
+		log.Fatal(err)
+	}
+
+	ucache, err := os.UserCacheDir()
+	if err != nil {
+		log.Fatal(err)
+	}
+	godir := filepath.Join(ucache, fmt.Sprintf("geth-go-%s-%s-%s", dlgoVersion, goos, goarch))
+	if err := build.ExtractArchive(dst, godir); err != nil {
+		log.Fatal(err)
+	}
+	goroot, err := filepath.Abs(filepath.Join(godir, "go"))
+	if err != nil {
+		log.Fatal(err)
+	}
+	return goroot
+}
+
 func ppaUpload(workdir, ppa, sshUser string, files []string) {
 	p := strings.Split(ppa, "/")
 	if len(p) != 2 {
@@ -747,6 +821,7 @@ func doWindowsInstaller(cmdline []string) {
 	var (
 		arch    = flag.String("arch", runtime.GOARCH, "Architecture for cross build packaging")
 		signer  = flag.String("signer", "", `Environment variable holding the signing key (e.g. WINDOWS_SIGNING_KEY)`)
+		signify = flag.String("signify", "", `Environment variable holding the signify signing key (e.g. WINDOWS_SIGNIFY_KEY)`)
 		upload  = flag.String("upload", "", `Destination to upload the archives (usually "gethstore/builds")`)
 		workdir = flag.String("workdir", "", `Output directory for packages (uses temp dir if unset)`)
 	)
@@ -808,7 +883,7 @@ func doWindowsInstaller(cmdline []string) {
 		filepath.Join(*workdir, "geth.nsi"),
 	)
 	// Sign and publish installer.
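
Every toolchain download above is checked against `build/checksums.txt` before use. A stand-in sketch of that verification using plain `crypto/sha256` (not the actual `internal/build` API), with the go1.15.6 source checksum from the table:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"log"
	"os"
)

// verify hashes the file and compares against the expected sha256.
func verify(path, wantHex string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()
	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return err
	}
	if got := hex.EncodeToString(h.Sum(nil)); got != wantHex {
		return fmt.Errorf("checksum mismatch: have %s, want %s", got, wantHex)
	}
	return nil
}

func main() {
	if err := verify("go1.15.6.src.tar.gz",
		"890bba73c5e2b19ffb1180e385ea225059eb008eb91b694875dd86ea48675817"); err != nil {
		log.Fatal(err)
	}
	fmt.Println("checksum OK")
}
```
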
- if err := archiveUpload(installer, *upload, *signer); err != nil { + if err := archiveUpload(installer, *upload, *signer, *signify); err != nil { log.Fatal(err) } } @@ -817,10 +892,11 @@ func doWindowsInstaller(cmdline []string) { func doAndroidArchive(cmdline []string) { var ( - local = flag.Bool("local", false, `Flag whether we're only doing a local build (skip Maven artifacts)`) - signer = flag.String("signer", "", `Environment variable holding the signing key (e.g. ANDROID_SIGNING_KEY)`) - deploy = flag.String("deploy", "", `Destination to deploy the archive (usually "https://oss.sonatype.org")`) - upload = flag.String("upload", "", `Destination to upload the archive (usually "gethstore/builds")`) + local = flag.Bool("local", false, `Flag whether we're only doing a local build (skip Maven artifacts)`) + signer = flag.String("signer", "", `Environment variable holding the signing key (e.g. ANDROID_SIGNING_KEY)`) + signify = flag.String("signify", "", `Environment variable holding the signify signing key (e.g. ANDROID_SIGNIFY_KEY)`) + deploy = flag.String("deploy", "", `Destination to deploy the archive (usually "https://oss.sonatype.org")`) + upload = flag.String("upload", "", `Destination to upload the archive (usually "gethstore/builds")`) ) flag.CommandLine.Parse(cmdline) env := build.Env() @@ -836,6 +912,7 @@ func doAndroidArchive(cmdline []string) { if *local { // If we're building locally, copy bundle to build dir and skip Maven os.Rename("geth.aar", filepath.Join(GOBIN, "geth.aar")) + os.Rename("geth-sources.jar", filepath.Join(GOBIN, "geth-sources.jar")) return } meta := newMavenMetadata(env) @@ -848,7 +925,7 @@ func doAndroidArchive(cmdline []string) { archive := "geth-" + archiveBasename("android", params.ArchiveVersion(env.Commit)) + ".aar" os.Rename("geth.aar", archive) - if err := archiveUpload(archive, *upload, *signer); err != nil { + if err := archiveUpload(archive, *upload, *signer, *signify); err != nil { log.Fatal(err) } // Sign and upload all the artifacts to Maven Central @@ -941,10 +1018,11 @@ func newMavenMetadata(env build.Environment) mavenMetadata { func doXCodeFramework(cmdline []string) { var ( - local = flag.Bool("local", false, `Flag whether we're only doing a local build (skip Maven artifacts)`) - signer = flag.String("signer", "", `Environment variable holding the signing key (e.g. IOS_SIGNING_KEY)`) - deploy = flag.String("deploy", "", `Destination to deploy the archive (usually "trunk")`) - upload = flag.String("upload", "", `Destination to upload the archives (usually "gethstore/builds")`) + local = flag.Bool("local", false, `Flag whether we're only doing a local build (skip Maven artifacts)`) + signer = flag.String("signer", "", `Environment variable holding the signing key (e.g. IOS_SIGNING_KEY)`) + signify = flag.String("signify", "", `Environment variable holding the signify signing key (e.g. 
IOS_SIGNIFY_KEY)`) + deploy = flag.String("deploy", "", `Destination to deploy the archive (usually "trunk")`) + upload = flag.String("upload", "", `Destination to upload the archives (usually "gethstore/builds")`) ) flag.CommandLine.Parse(cmdline) env := build.Env() @@ -972,14 +1050,14 @@ func doXCodeFramework(cmdline []string) { maybeSkipArchive(env) // Sign and upload the framework to Azure - if err := archiveUpload(archive+".tar.gz", *upload, *signer); err != nil { + if err := archiveUpload(archive+".tar.gz", *upload, *signer, *signify); err != nil { log.Fatal(err) } // Prepare and upload a PodSpec to CocoaPods if *deploy != "" { meta := newPodMetadata(env, archive) build.Render("build/pod.podspec", "Geth.podspec", 0755, meta) - build.MustRunCommand("pod", *deploy, "push", "Geth.podspec", "--allow-warnings", "--verbose") + build.MustRunCommand("pod", *deploy, "push", "Geth.podspec", "--allow-warnings") } } diff --git a/cmd/abigen/main.go b/cmd/abigen/main.go index a74b0396d40f..7b3b35e4e54f 100644 --- a/cmd/abigen/main.go +++ b/cmd/abigen/main.go @@ -96,7 +96,7 @@ var ( } aliasFlag = cli.StringFlag{ Name: "alias", - Usage: "Comma separated aliases for function and event renaming, e.g. foo=bar", + Usage: "Comma separated aliases for function and event renaming, e.g. original1=alias1, original2=alias2", } ) diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go index f6e2a14c3bd2..6c9ff615a18b 100644 --- a/cmd/bootnode/main.go +++ b/cmd/bootnode/main.go @@ -44,7 +44,7 @@ func main() { natdesc = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:)") netrestrict = flag.String("netrestrict", "", "restrict network communication to the given IP networks (CIDR masks)") runv5 = flag.Bool("v5", false, "run a v5 topic discovery bootnode") - verbosity = flag.Int("verbosity", int(log.LvlInfo), "log verbosity (0-9)") + verbosity = flag.Int("verbosity", int(log.LvlInfo), "log verbosity (0-5)") vmodule = flag.String("vmodule", "", "log verbosity pattern") nodeKey *ecdsa.PrivateKey diff --git a/cmd/clef/extapi_changelog.md b/cmd/clef/extapi_changelog.md index dbc302631bc1..31554f079020 100644 --- a/cmd/clef/extapi_changelog.md +++ b/cmd/clef/extapi_changelog.md @@ -10,6 +10,64 @@ TL;DR: Given a version number MAJOR.MINOR.PATCH, increment the: Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format. +### 6.1.0 + +The API-method `account_signGnosisSafeTx` was added. This method takes two parameters, +`[address, safeTx]`. The latter, `safeTx`, can be copy-pasted from the gnosis relay. 
For example:
+
+```
+{
+  "jsonrpc": "2.0",
+  "method": "account_signGnosisSafeTx",
+  "params": ["0xfd1c4226bfD1c436672092F4eCbfC270145b7256",
+    {
+      "safe": "0x25a6c4BBd32B2424A9c99aEB0584Ad12045382B3",
+      "to": "0xB372a646f7F05Cc1785018dBDA7EBc734a2A20E2",
+      "value": "20000000000000000",
+      "data": null,
+      "operation": 0,
+      "gasToken": "0x0000000000000000000000000000000000000000",
+      "safeTxGas": 27845,
+      "baseGas": 0,
+      "gasPrice": "0",
+      "refundReceiver": "0x0000000000000000000000000000000000000000",
+      "nonce": 2,
+      "executionDate": null,
+      "submissionDate": "2020-09-15T21:54:49.617634Z",
+      "modified": "2020-09-15T21:54:49.617634Z",
+      "blockNumber": null,
+      "transactionHash": null,
+      "safeTxHash": "0x2edfbd5bc113ff18c0631595db32eb17182872d88d9bf8ee4d8c2dd5db6d95e2",
+      "executor": null,
+      "isExecuted": false,
+      "isSuccessful": null,
+      "ethGasPrice": null,
+      "gasUsed": null,
+      "fee": null,
+      "origin": null,
+      "dataDecoded": null,
+      "confirmationsRequired": null,
+      "confirmations": [
+        {
+          "owner": "0xAd2e180019FCa9e55CADe76E4487F126Fd08DA34",
+          "submissionDate": "2020-09-15T21:54:49.663299Z",
+          "transactionHash": null,
+          "confirmationType": "CONFIRMATION",
+          "signature": "0x95a7250bb645f831c86defc847350e7faff815b2fb586282568e96cc859e39315876db20a2eed5f7a0412906ec5ab57652a6f645ad4833f345bda059b9da2b821c",
+          "signatureType": "EOA"
+        }
+      ],
+      "signatures": null
+    }
+  ],
+  "id": 67
+}
+```
+
+Not all fields are required, though. This method is really just a UX helper, which massages the
+input to conform to the `EIP-712` [specification](https://docs.gnosis.io/safe/docs/contracts_tx_execution/#transaction-hash)
+for the Gnosis Safe, and makes the output directly importable by a relay service.
+
 ### 6.0.0

diff --git a/cmd/clef/main.go b/cmd/clef/main.go
index 8c8778c24980..273919cb410f 100644
--- a/cmd/clef/main.go
+++ b/cmd/clef/main.go
@@ -29,7 +29,6 @@ import (
 	"math/big"
 	"os"
 	"os/signal"
-	"os/user"
 	"path/filepath"
 	"runtime"
 	"sort"
@@ -55,7 +54,7 @@ import (
 	"github.com/ethereum/go-ethereum/signer/rules"
 	"github.com/ethereum/go-ethereum/signer/storage"

-	colorable "github.com/mattn/go-colorable"
+	"github.com/mattn/go-colorable"
 	"github.com/mattn/go-isatty"
 	"gopkg.in/urfave/cli.v1"
 )
@@ -666,8 +665,8 @@ func signer(c *cli.Context) error {
 			Version: "1.0"},
 	}
 	if c.GlobalBool(utils.HTTPEnabledFlag.Name) {
-		vhosts := splitAndTrim(c.GlobalString(utils.HTTPVirtualHostsFlag.Name))
-		cors := splitAndTrim(c.GlobalString(utils.HTTPCORSDomainFlag.Name))
+		vhosts := utils.SplitAndTrim(c.GlobalString(utils.HTTPVirtualHostsFlag.Name))
+		cors := utils.SplitAndTrim(c.GlobalString(utils.HTTPCORSDomainFlag.Name))

 		srv := rpc.NewServer()
 		err := node.RegisterApisFromWhitelist(rpcAPI, []string{"account"}, srv, false)
@@ -736,21 +735,11 @@ func signer(c *cli.Context) error {
 	return nil
 }

-// splitAndTrim splits input separated by a comma
-// and trims excessive white space from the substrings.
-func splitAndTrim(input string) []string {
-	result := strings.Split(input, ",")
-	for i, r := range result {
-		result[i] = strings.TrimSpace(r)
-	}
-	return result
-}
-
 // DefaultConfigDir is the default config directory to use for the vaults and other
 // persistence requirements.
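
A sketch of invoking the new method over clef's external API from Go, assuming clef's default HTTP endpoint and truncating the safeTx payload to a few of the fields shown above:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ethereum/go-ethereum/rpc"
)

func main() {
	// Default clef HTTP endpoint; adjust for your setup.
	client, err := rpc.Dial("http://localhost:8550")
	if err != nil {
		log.Fatal(err)
	}
	safeTx := map[string]interface{}{
		"safe":  "0x25a6c4BBd32B2424A9c99aEB0584Ad12045382B3",
		"to":    "0xB372a646f7F05Cc1785018dBDA7EBc734a2A20E2",
		"value": "20000000000000000",
		"nonce": 2,
		// ... remaining fields as in the relay payload above.
	}
	var result interface{}
	err = client.CallContext(context.Background(), &result,
		"account_signGnosisSafeTx", "0xfd1c4226bfD1c436672092F4eCbfC270145b7256", safeTx)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(result)
}
```
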
func DefaultConfigDir() string { // Try to place the data folder in the user's home dir - home := homeDir() + home := utils.HomeDir() if home != "" { if runtime.GOOS == "darwin" { return filepath.Join(home, "Library", "Signer") @@ -758,26 +747,15 @@ func DefaultConfigDir() string { appdata := os.Getenv("APPDATA") if appdata != "" { return filepath.Join(appdata, "Signer") - } else { - return filepath.Join(home, "AppData", "Roaming", "Signer") } - } else { - return filepath.Join(home, ".clef") + return filepath.Join(home, "AppData", "Roaming", "Signer") } + return filepath.Join(home, ".clef") } // As we cannot guess a stable location, return empty and handle later return "" } -func homeDir() string { - if home := os.Getenv("HOME"); home != "" { - return home - } - if usr, err := user.Current(); err == nil { - return usr.HomeDir - } - return "" -} func readMasterKey(ctx *cli.Context, ui core.UIClientAPI) ([]byte, error) { var ( file string diff --git a/cmd/devp2p/README.md b/cmd/devp2p/README.md new file mode 100644 index 000000000000..e1372d015899 --- /dev/null +++ b/cmd/devp2p/README.md @@ -0,0 +1,105 @@ +# The devp2p command + +The devp2p command line tool is a utility for low-level peer-to-peer debugging and +protocol development purposes. It can do many things. + +### ENR Decoding + +Use `devp2p enrdump ` to verify and display an Ethereum Node Record. + +### Node Key Management + +The `devp2p key ...` command family deals with node key files. + +Run `devp2p key generate mynode.key` to create a new node key in the `mynode.key` file. + +Run `devp2p key to-enode mynode.key -ip 127.0.0.1 -tcp 30303` to create an enode:// URL +corresponding to the given node key and address information. + +### Maintaining DNS Discovery Node Lists + +The devp2p command can create and publish DNS discovery node lists. + +Run `devp2p dns sign ` to update the signature of a DNS discovery tree. + +Run `devp2p dns sync ` to download a complete DNS discovery tree. + +Run `devp2p dns to-cloudflare ` to publish a tree to CloudFlare DNS. + +Run `devp2p dns to-route53 ` to publish a tree to Amazon Route53. + +You can find more information about these commands in the [DNS Discovery Setup Guide][dns-tutorial]. + +### Discovery v4 Utilities + +The `devp2p discv4 ...` command family deals with the [Node Discovery v4][discv4] +protocol. + +Run `devp2p discv4 ping ` to ping a node. + +Run `devp2p discv4 resolve ` to find the most recent node record of a node in +the DHT. + +Run `devp2p discv4 crawl ` to create or update a JSON node set. + +### Discovery v5 Utilities + +The `devp2p discv5 ...` command family deals with the [Node Discovery v5][discv5] +protocol. This protocol is currently under active development. + +Run `devp2p discv5 ping ` to ping a node. + +Run `devp2p discv5 resolve ` to find the most recent node record of a node in +the discv5 DHT. + +Run `devp2p discv5 listen` to run a Discovery v5 node. + +Run `devp2p discv5 crawl ` to create or update a JSON node set containing +discv5 nodes. + +### Discovery Test Suites + +The devp2p command also contains interactive test suites for Discovery v4 and Discovery +v5. + +To run these tests against your implementation, you need to set up a networking +environment where two separate UDP listening addresses are available on the same machine. +The two listening addresses must also be routed such that they are able to reach the node +you want to test. 
+ +For example, if you want to run the test on your local host, and the node under test is +also on the local host, you need to assign two IP addresses (or a larger range) to your +loopback interface. On macOS, this can be done by executing the following command: + + sudo ifconfig lo0 add 127.0.0.2 + +You can now run either test suite as follows: Start the node under test first, ensuring +that it won't talk to the Internet (i.e. disable bootstrapping). An easy way to prevent +unintended connections to the global DHT is listening on `127.0.0.1`. + +Now get the ENR of your node and store it in the `NODE` environment variable. + +Start the test by running `devp2p discv5 test -listen1 127.0.0.1 -listen2 127.0.0.2 $NODE`. + +### Eth Protocol Test Suite + +The Eth Protocol test suite is a conformance test suite for the [eth protocol][eth]. + +To run the eth protocol test suite against your implementation, the node needs to be initialized as such: + +1. initialize the geth node with the `genesis.json` file contained in the `testdata` directory +2. import the `halfchain.rlp` file in the `testdata` directory +3. run geth with the following flags: +``` +geth --datadir --nodiscover --nat=none --networkid 19763 --verbosity 5 +``` + +Then, run the following command, replacing `` with the enode of the geth node: + ``` + devp2p rlpx eth-test cmd/devp2p/internal/ethtest/testdata/fullchain.rlp cmd/devp2p/internal/ethtest/testdata/genesis.json +``` + +[eth]: https://github.com/ethereum/devp2p/blob/master/caps/eth.md +[dns-tutorial]: https://geth.ethereum.org/docs/developers/dns-discovery-setup +[discv4]: https://github.com/ethereum/devp2p/tree/master/discv4.md +[discv5]: https://github.com/ethereum/devp2p/tree/master/discv5/discv5.md diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 99b0957ab37e..3b6dc09a1cc8 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -19,14 +19,12 @@ package main import ( "fmt" "net" - "os" "strings" "time" "github.com/ethereum/go-ethereum/cmd/devp2p/internal/v4test" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/internal/utesting" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" @@ -82,7 +80,13 @@ var ( Name: "test", Usage: "Runs tests against a node", Action: discv4Test, - Flags: []cli.Flag{remoteEnodeFlag, testPatternFlag, testListen1Flag, testListen2Flag}, + Flags: []cli.Flag{ + remoteEnodeFlag, + testPatternFlag, + testTAPFlag, + testListen1Flag, + testListen2Flag, + }, } ) @@ -113,20 +117,6 @@ var ( Usage: "Enode of the remote node under test", EnvVar: "REMOTE_ENODE", } - testPatternFlag = cli.StringFlag{ - Name: "run", - Usage: "Pattern of test suite(s) to run", - } - testListen1Flag = cli.StringFlag{ - Name: "listen1", - Usage: "IP address of the first tester", - Value: v4test.Listen1, - } - testListen2Flag = cli.StringFlag{ - Name: "listen2", - Usage: "IP address of the second tester", - Value: v4test.Listen2, - } ) func discv4Ping(ctx *cli.Context) error { @@ -213,6 +203,7 @@ func discv4Crawl(ctx *cli.Context) error { return nil } +// discv4Test runs the protocol test suite. func discv4Test(ctx *cli.Context) error { // Configure test package globals. 
if !ctx.IsSet(remoteEnodeFlag.Name) { @@ -221,18 +212,7 @@ func discv4Test(ctx *cli.Context) error { v4test.Remote = ctx.String(remoteEnodeFlag.Name) v4test.Listen1 = ctx.String(testListen1Flag.Name) v4test.Listen2 = ctx.String(testListen2Flag.Name) - - // Filter and run test cases. - tests := v4test.AllTests - if ctx.IsSet(testPatternFlag.Name) { - tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name)) - } - results := utesting.RunTests(tests, os.Stdout) - if fails := utesting.CountFailures(results); fails > 0 { - return fmt.Errorf("%v/%v tests passed.", len(tests)-fails, len(tests)) - } - fmt.Printf("%v/%v passed\n", len(tests), len(tests)) - return nil + return runTests(ctx, v4test.AllTests) } // startV4 starts an ephemeral discovery V4 node. @@ -286,7 +266,11 @@ func listen(ln *enode.LocalNode, addr string) *net.UDPConn { } usocket := socket.(*net.UDPConn) uaddr := socket.LocalAddr().(*net.UDPAddr) - ln.SetFallbackIP(net.IP{127, 0, 0, 1}) + if uaddr.IP.IsUnspecified() { + ln.SetFallbackIP(net.IP{127, 0, 0, 1}) + } else { + ln.SetFallbackIP(uaddr.IP) + } ln.SetFallbackUDP(uaddr.Port) return usocket } @@ -294,7 +278,11 @@ func listen(ln *enode.LocalNode, addr string) *net.UDPConn { func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { s := params.RinkebyBootnodes if ctx.IsSet(bootnodesFlag.Name) { - s = strings.Split(ctx.String(bootnodesFlag.Name), ",") + input := ctx.String(bootnodesFlag.Name) + if input == "" { + return nil, nil + } + s = strings.Split(input, ",") } nodes := make([]*enode.Node, len(s)) var err error diff --git a/cmd/devp2p/discv5cmd.go b/cmd/devp2p/discv5cmd.go index f871821ea220..e20d7c9cfae6 100644 --- a/cmd/devp2p/discv5cmd.go +++ b/cmd/devp2p/discv5cmd.go @@ -20,6 +20,7 @@ import ( "fmt" "time" + "github.com/ethereum/go-ethereum/cmd/devp2p/internal/v5test" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/p2p/discover" "gopkg.in/urfave/cli.v1" @@ -33,6 +34,7 @@ var ( discv5PingCommand, discv5ResolveCommand, discv5CrawlCommand, + discv5TestCommand, discv5ListenCommand, }, } @@ -53,6 +55,17 @@ var ( Action: discv5Crawl, Flags: []cli.Flag{bootnodesFlag, crawlTimeoutFlag}, } + discv5TestCommand = cli.Command{ + Name: "test", + Usage: "Runs protocol tests against a node", + Action: discv5Test, + Flags: []cli.Flag{ + testPatternFlag, + testTAPFlag, + testListen1Flag, + testListen2Flag, + }, + } discv5ListenCommand = cli.Command{ Name: "listen", Usage: "Runs a node", @@ -103,6 +116,16 @@ func discv5Crawl(ctx *cli.Context) error { return nil } +// discv5Test runs the protocol test suite. 
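
Both test commands now funnel into a shared `runTests` helper whose body is not part of this diff; the removed discv4 code above shows the underlying `utesting` pattern, sketched standalone below (the dummy test and filter pattern are illustrative, and TAP output handling is omitted):

```go
package main

import (
	"fmt"
	"os"

	"github.com/ethereum/go-ethereum/internal/utesting"
)

func main() {
	tests := []utesting.Test{
		{Name: "Ping", Fn: func(t *utesting.T) { /* dial the node, send a ping... */ }},
	}
	// Filter by a pattern, then run the suite and count failures.
	tests = utesting.MatchTests(tests, "Ping")
	results := utesting.RunTests(tests, os.Stdout)
	if fails := utesting.CountFailures(results); fails > 0 {
		fmt.Printf("%v/%v tests passed\n", len(tests)-fails, len(tests))
		os.Exit(1)
	}
	fmt.Printf("%v/%v passed\n", len(tests), len(tests))
}
```
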
+func discv5Test(ctx *cli.Context) error {
+	suite := &v5test.Suite{
+		Dest:    getNodeArg(ctx),
+		Listen1: ctx.String(testListen1Flag.Name),
+		Listen2: ctx.String(testListen2Flag.Name),
+	}
+	return runTests(ctx, suite.AllTests())
+}
+
 func discv5Listen(ctx *cli.Context) error {
 	disc := startV5(ctx)
 	defer disc.Close()
diff --git a/cmd/devp2p/dnscmd.go b/cmd/devp2p/dnscmd.go
index 13110f21c64c..f56f0f34e456 100644
--- a/cmd/devp2p/dnscmd.go
+++ b/cmd/devp2p/dnscmd.go
@@ -30,7 +30,7 @@ import (
 	"github.com/ethereum/go-ethereum/console/prompt"
 	"github.com/ethereum/go-ethereum/p2p/dnsdisc"
 	"github.com/ethereum/go-ethereum/p2p/enode"
-	cli "gopkg.in/urfave/cli.v1"
+	"gopkg.in/urfave/cli.v1"
 )
 
 var (
diff --git a/cmd/devp2p/internal/ethtest/chain.go b/cmd/devp2p/internal/ethtest/chain.go
new file mode 100644
index 000000000000..d67387e80bf4
--- /dev/null
+++ b/cmd/devp2p/internal/ethtest/chain.go
@@ -0,0 +1,167 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package ethtest
+
+import (
+	"compress/gzip"
+	"encoding/json"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"math/big"
+	"os"
+	"strings"
+
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/forkid"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+type Chain struct {
+	blocks      []*types.Block
+	chainConfig *params.ChainConfig
+}
+
+func (c *Chain) WriteTo(writer io.Writer) error {
+	for _, block := range c.blocks {
+		if err := rlp.Encode(writer, block); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// Len returns the length of the chain.
+func (c *Chain) Len() int {
+	return len(c.blocks)
+}
+
+// TD calculates the total difficulty of the chain.
+func (c *Chain) TD(height int) *big.Int { // TODO later on change scheme so that the height is included in range
+	sum := big.NewInt(0)
+	for _, block := range c.blocks[:height] {
+		sum.Add(sum, block.Difficulty())
+	}
+	return sum
+}
+
+// ForkID gets the fork id of the chain.
+func (c *Chain) ForkID() forkid.ID {
+	return forkid.NewID(c.chainConfig, c.blocks[0].Hash(), uint64(c.Len()))
+}
+
+// Shorten returns a copy of the chain, truncated to the desired height.
+func (c *Chain) Shorten(height int) *Chain {
+	blocks := make([]*types.Block, height)
+	copy(blocks, c.blocks[:height])
+
+	config := *c.chainConfig
+	return &Chain{
+		blocks:      blocks,
+		chainConfig: &config,
+	}
+}
+
+// Head returns the chain head.
+func (c *Chain) Head() *types.Block { + return c.blocks[c.Len()-1] +} + +func (c *Chain) GetHeaders(req GetBlockHeaders) (BlockHeaders, error) { + if req.Amount < 1 { + return nil, fmt.Errorf("no block headers requested") + } + + headers := make(BlockHeaders, req.Amount) + var blockNumber uint64 + + // range over blocks to check if our chain has the requested header + for _, block := range c.blocks { + if block.Hash() == req.Origin.Hash || block.Number().Uint64() == req.Origin.Number { + headers[0] = block.Header() + blockNumber = block.Number().Uint64() + } + } + if headers[0] == nil { + return nil, fmt.Errorf("no headers found for given origin number %v, hash %v", req.Origin.Number, req.Origin.Hash) + } + + if req.Reverse { + for i := 1; i < int(req.Amount); i++ { + blockNumber -= (1 - req.Skip) + headers[i] = c.blocks[blockNumber].Header() + + } + + return headers, nil + } + + for i := 1; i < int(req.Amount); i++ { + blockNumber += (1 + req.Skip) + headers[i] = c.blocks[blockNumber].Header() + } + + return headers, nil +} + +// loadChain takes the given chain.rlp file, and decodes and returns +// the blocks from the file. +func loadChain(chainfile string, genesis string) (*Chain, error) { + chainConfig, err := ioutil.ReadFile(genesis) + if err != nil { + return nil, err + } + var gen core.Genesis + if err := json.Unmarshal(chainConfig, &gen); err != nil { + return nil, err + } + gblock := gen.ToBlock(nil) + + // Load chain.rlp. + fh, err := os.Open(chainfile) + if err != nil { + return nil, err + } + defer fh.Close() + var reader io.Reader = fh + if strings.HasSuffix(chainfile, ".gz") { + if reader, err = gzip.NewReader(reader); err != nil { + return nil, err + } + } + stream := rlp.NewStream(reader, 0) + var blocks = make([]*types.Block, 1) + blocks[0] = gblock + for i := 0; ; i++ { + var b types.Block + if err := stream.Decode(&b); err == io.EOF { + break + } else if err != nil { + return nil, fmt.Errorf("at block index %d: %v", i, err) + } + if b.NumberU64() != uint64(i+1) { + return nil, fmt.Errorf("block at index %d has wrong number %d", i, b.NumberU64()) + } + blocks = append(blocks, &b) + } + + c := &Chain{blocks: blocks, chainConfig: gen.Config} + return c, nil +} diff --git a/cmd/devp2p/internal/ethtest/chain_test.go b/cmd/devp2p/internal/ethtest/chain_test.go new file mode 100644 index 000000000000..ac3907ce8dc7 --- /dev/null +++ b/cmd/devp2p/internal/ethtest/chain_test.go @@ -0,0 +1,150 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
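With `loadChain` and the `Chain` helpers in place, the suite can answer header queries locally. A short usage sketch (the wrapper function and paths are illustrative; note that `GetHeaders` advances `Skip+1` blocks per step, so origin 2, amount 5, skip 1 selects blocks 2, 4, 6, 8 and 10, exactly the first case in `chain_test.go` below):

```go
// Sketch: load the test chain and query it the way the suite does.
func exampleChainUsage() error {
	chain, err := loadChain("testdata/chain.rlp", "testdata/genesis.json")
	if err != nil {
		return err
	}
	short := chain.Shorten(1000) // working copy: the first 1000 blocks
	fmt.Println("head", chain.Head().NumberU64(), "td", chain.TD(chain.Len()))

	// Non-reverse queries step forward Skip+1 blocks at a time.
	headers, err := short.GetHeaders(GetBlockHeaders{
		Origin: hashOrNumber{Number: 2},
		Amount: 5,
		Skip:   1,
	})
	if err != nil {
		return err
	}
	fmt.Println("got", len(headers), "headers") // 5
	return nil
}
```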
+ +package ethtest + +import ( + "path/filepath" + "strconv" + "testing" + + "github.com/ethereum/go-ethereum/p2p" + "github.com/stretchr/testify/assert" +) + +// TestEthProtocolNegotiation tests whether the test suite +// can negotiate the highest eth protocol in a status message exchange +func TestEthProtocolNegotiation(t *testing.T) { + var tests = []struct { + conn *Conn + caps []p2p.Cap + expected uint32 + }{ + { + conn: &Conn{}, + caps: []p2p.Cap{ + {Name: "eth", Version: 63}, + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + expected: uint32(65), + }, + { + conn: &Conn{}, + caps: []p2p.Cap{ + {Name: "eth", Version: 0}, + {Name: "eth", Version: 89}, + {Name: "eth", Version: 65}, + }, + expected: uint32(65), + }, + { + conn: &Conn{}, + caps: []p2p.Cap{ + {Name: "eth", Version: 63}, + {Name: "eth", Version: 64}, + {Name: "wrongProto", Version: 65}, + }, + expected: uint32(64), + }, + } + + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tt.conn.negotiateEthProtocol(tt.caps) + assert.Equal(t, tt.expected, uint32(tt.conn.ethProtocolVersion)) + }) + } +} + +// TestChain_GetHeaders tests whether the test suite can correctly +// respond to a GetBlockHeaders request from a node. +func TestChain_GetHeaders(t *testing.T) { + chainFile, err := filepath.Abs("./testdata/chain.rlp") + if err != nil { + t.Fatal(err) + } + genesisFile, err := filepath.Abs("./testdata/genesis.json") + if err != nil { + t.Fatal(err) + } + + chain, err := loadChain(chainFile, genesisFile) + if err != nil { + t.Fatal(err) + } + + var tests = []struct { + req GetBlockHeaders + expected BlockHeaders + }{ + { + req: GetBlockHeaders{ + Origin: hashOrNumber{ + Number: uint64(2), + }, + Amount: uint64(5), + Skip: 1, + Reverse: false, + }, + expected: BlockHeaders{ + chain.blocks[2].Header(), + chain.blocks[4].Header(), + chain.blocks[6].Header(), + chain.blocks[8].Header(), + chain.blocks[10].Header(), + }, + }, + { + req: GetBlockHeaders{ + Origin: hashOrNumber{ + Number: uint64(chain.Len() - 1), + }, + Amount: uint64(3), + Skip: 0, + Reverse: true, + }, + expected: BlockHeaders{ + chain.blocks[chain.Len()-1].Header(), + chain.blocks[chain.Len()-2].Header(), + chain.blocks[chain.Len()-3].Header(), + }, + }, + { + req: GetBlockHeaders{ + Origin: hashOrNumber{ + Hash: chain.Head().Hash(), + }, + Amount: uint64(1), + Skip: 0, + Reverse: false, + }, + expected: BlockHeaders{ + chain.Head().Header(), + }, + }, + } + + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + headers, err := chain.GetHeaders(tt.req) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, headers, tt.expected) + }) + } +} diff --git a/cmd/devp2p/internal/ethtest/large.go b/cmd/devp2p/internal/ethtest/large.go new file mode 100644 index 000000000000..deca00be5384 --- /dev/null +++ b/cmd/devp2p/internal/ethtest/large.go @@ -0,0 +1,80 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. 
+// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package ethtest + +import ( + "crypto/rand" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core/types" +) + +// largeNumber returns a very large big.Int. +func largeNumber(megabytes int) *big.Int { + buf := make([]byte, megabytes*1024*1024) + rand.Read(buf) + bigint := new(big.Int) + bigint.SetBytes(buf) + return bigint +} + +// largeBuffer returns a very large buffer. +func largeBuffer(megabytes int) []byte { + buf := make([]byte, megabytes*1024*1024) + rand.Read(buf) + return buf +} + +// largeString returns a very large string. +func largeString(megabytes int) string { + buf := make([]byte, megabytes*1024*1024) + rand.Read(buf) + return hexutil.Encode(buf) +} + +func largeBlock() *types.Block { + return types.NewBlockWithHeader(largeHeader()) +} + +// Returns a random hash +func randHash() common.Hash { + var h common.Hash + rand.Read(h[:]) + return h +} + +func largeHeader() *types.Header { + return &types.Header{ + MixDigest: randHash(), + ReceiptHash: randHash(), + TxHash: randHash(), + Nonce: types.BlockNonce{}, + Extra: []byte{}, + Bloom: types.Bloom{}, + GasUsed: 0, + Coinbase: common.Address{}, + GasLimit: 0, + UncleHash: randHash(), + Time: 1337, + ParentHash: randHash(), + Root: randHash(), + Number: largeNumber(2), + Difficulty: largeNumber(2), + } +} diff --git a/cmd/devp2p/internal/ethtest/suite.go b/cmd/devp2p/internal/ethtest/suite.go new file mode 100644 index 000000000000..5d0cdda720ab --- /dev/null +++ b/cmd/devp2p/internal/ethtest/suite.go @@ -0,0 +1,426 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package ethtest + +import ( + "fmt" + "net" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/rlpx" + "github.com/stretchr/testify/assert" +) + +var pretty = spew.ConfigState{ + Indent: " ", + DisableCapacities: true, + DisablePointerAddresses: true, + SortKeys: true, +} + +var timeout = 20 * time.Second + +// Suite represents a structure used to test the eth +// protocol of a node(s). +type Suite struct { + Dest *enode.Node + + chain *Chain + fullChain *Chain +} + +// NewSuite creates and returns a new eth-test suite that can +// be used to test the given node against the given blockchain +// data. 
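The `large*` helpers above exist to build pathologically oversized payloads for the malicious-input tests later in this file; all of them draw from `crypto/rand`, so every run produces fresh values. To get a sense of the scale (a sketch; the wrapper function is illustrative):

```go
func exampleLargeHelpers() {
	// largeNumber(2) is backed by 2 MiB of random bytes, i.e. a big.Int
	// of roughly 16.7 million bits, far beyond any plausible total
	// difficulty or balance.
	fmt.Println(largeNumber(2).BitLen()) // ~ 2 * 1024 * 1024 * 8

	// largeString hex-encodes its buffer, so the result is about 4 MiB
	// of text plus the "0x" prefix.
	fmt.Println(len(largeString(2)))
}
```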
+func NewSuite(dest *enode.Node, chainfile string, genesisfile string) *Suite { + chain, err := loadChain(chainfile, genesisfile) + if err != nil { + panic(err) + } + return &Suite{ + Dest: dest, + chain: chain.Shorten(1000), + fullChain: chain, + } +} + +func (s *Suite) AllTests() []utesting.Test { + return []utesting.Test{ + {Name: "Status", Fn: s.TestStatus}, + {Name: "GetBlockHeaders", Fn: s.TestGetBlockHeaders}, + {Name: "Broadcast", Fn: s.TestBroadcast}, + {Name: "GetBlockBodies", Fn: s.TestGetBlockBodies}, + {Name: "TestLargeAnnounce", Fn: s.TestLargeAnnounce}, + {Name: "TestMaliciousHandshake", Fn: s.TestMaliciousHandshake}, + {Name: "TestMaliciousStatus", Fn: s.TestMaliciousStatus}, + {Name: "TestTransactions", Fn: s.TestTransaction}, + {Name: "TestMaliciousTransactions", Fn: s.TestMaliciousTx}, + } +} + +// TestStatus attempts to connect to the given node and exchange +// a status message with it, and then check to make sure +// the chain head is correct. +func (s *Suite) TestStatus(t *utesting.T) { + conn, err := s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + // get protoHandshake + conn.handshake(t) + // get status + switch msg := conn.statusExchange(t, s.chain, nil).(type) { + case *Status: + t.Logf("got status message: %s", pretty.Sdump(msg)) + default: + t.Fatalf("unexpected: %s", pretty.Sdump(msg)) + } +} + +// TestMaliciousStatus sends a status package with a large total difficulty. +func (s *Suite) TestMaliciousStatus(t *utesting.T) { + conn, err := s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + // get protoHandshake + conn.handshake(t) + status := &Status{ + ProtocolVersion: uint32(conn.ethProtocolVersion), + NetworkID: s.chain.chainConfig.ChainID.Uint64(), + TD: largeNumber(2), + Head: s.chain.blocks[s.chain.Len()-1].Hash(), + Genesis: s.chain.blocks[0].Hash(), + ForkID: s.chain.ForkID(), + } + // get status + switch msg := conn.statusExchange(t, s.chain, status).(type) { + case *Status: + t.Logf("%+v\n", msg) + default: + t.Fatalf("expected status, got: %#v ", msg) + } + // wait for disconnect + switch msg := conn.ReadAndServe(s.chain, timeout).(type) { + case *Disconnect: + case *Error: + return + default: + t.Fatalf("expected disconnect, got: %s", pretty.Sdump(msg)) + } +} + +// TestGetBlockHeaders tests whether the given node can respond to +// a `GetBlockHeaders` request and that the response is accurate. +func (s *Suite) TestGetBlockHeaders(t *utesting.T) { + conn, err := s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + + conn.handshake(t) + conn.statusExchange(t, s.chain, nil) + + // get block headers + req := &GetBlockHeaders{ + Origin: hashOrNumber{ + Hash: s.chain.blocks[1].Hash(), + }, + Amount: 2, + Skip: 1, + Reverse: false, + } + + if err := conn.Write(req); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + + switch msg := conn.ReadAndServe(s.chain, timeout).(type) { + case *BlockHeaders: + headers := msg + for _, header := range *headers { + num := header.Number.Uint64() + t.Logf("received header (%d): %s", num, pretty.Sdump(header)) + assert.Equal(t, s.chain.blocks[int(num)].Header(), header) + } + default: + t.Fatalf("unexpected: %s", pretty.Sdump(msg)) + } +} + +// TestGetBlockBodies tests whether the given node can respond to +// a `GetBlockBodies` request and that the response is accurate. 
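For reference, when no custom status is passed in, `statusExchange` (defined in `types.go` later in this diff) assembles the honest counterpart of the message `TestMaliciousStatus` forges. A sketch of that shape (`honestStatus` is a hypothetical helper name):

```go
// Sketch: the well-formed Status for this chain; TestMaliciousStatus
// above uses the same fields but swaps in largeNumber(2) as the TD.
func honestStatus(s *Suite, negotiatedVersion uint) *Status {
	return &Status{
		ProtocolVersion: uint32(negotiatedVersion), // 64 or 65
		NetworkID:       s.chain.chainConfig.ChainID.Uint64(),
		TD:              s.chain.TD(s.chain.Len()),
		Head:            s.chain.Head().Hash(),
		Genesis:         s.chain.blocks[0].Hash(),
		ForkID:          s.chain.ForkID(),
	}
}
```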
+func (s *Suite) TestGetBlockBodies(t *utesting.T) { + conn, err := s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + + conn.handshake(t) + conn.statusExchange(t, s.chain, nil) + // create block bodies request + req := &GetBlockBodies{s.chain.blocks[54].Hash(), s.chain.blocks[75].Hash()} + if err := conn.Write(req); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + + switch msg := conn.ReadAndServe(s.chain, timeout).(type) { + case *BlockBodies: + t.Logf("received %d block bodies", len(*msg)) + default: + t.Fatalf("unexpected: %s", pretty.Sdump(msg)) + } +} + +// TestBroadcast tests whether a block announcement is correctly +// propagated to the given node's peer(s). +func (s *Suite) TestBroadcast(t *utesting.T) { + sendConn, receiveConn := s.setupConnection(t), s.setupConnection(t) + nextBlock := len(s.chain.blocks) + blockAnnouncement := &NewBlock{ + Block: s.fullChain.blocks[nextBlock], + TD: s.fullChain.TD(nextBlock + 1), + } + s.testAnnounce(t, sendConn, receiveConn, blockAnnouncement) + // update test suite chain + s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock]) + // wait for client to update its chain + if err := receiveConn.waitForBlock(s.chain.Head()); err != nil { + t.Fatal(err) + } +} + +// TestMaliciousHandshake tries to send malicious data during the handshake. +func (s *Suite) TestMaliciousHandshake(t *utesting.T) { + conn, err := s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + // write hello to client + pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:] + handshakes := []*Hello{ + { + Version: 5, + Caps: []p2p.Cap{ + {Name: largeString(2), Version: 64}, + }, + ID: pub0, + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: append(pub0, byte(0)), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: append(pub0, pub0...), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: largeBuffer(2), + }, + { + Version: 5, + Caps: []p2p.Cap{ + {Name: largeString(2), Version: 64}, + }, + ID: largeBuffer(2), + }, + } + for i, handshake := range handshakes { + t.Logf("Testing malicious handshake %v\n", i) + // Init the handshake + if err := conn.Write(handshake); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + // check that the peer disconnected + timeout := 20 * time.Second + // Discard one hello + for i := 0; i < 2; i++ { + switch msg := conn.ReadAndServe(s.chain, timeout).(type) { + case *Disconnect: + case *Error: + case *Hello: + // Hello's are send concurrently, so ignore them + continue + default: + t.Fatalf("unexpected: %s", pretty.Sdump(msg)) + } + } + // Dial for the next round + conn, err = s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + } +} + +// TestLargeAnnounce tests the announcement mechanism with a large block. 
+func (s *Suite) TestLargeAnnounce(t *utesting.T) { + nextBlock := len(s.chain.blocks) + blocks := []*NewBlock{ + { + Block: largeBlock(), + TD: s.fullChain.TD(nextBlock + 1), + }, + { + Block: s.fullChain.blocks[nextBlock], + TD: largeNumber(2), + }, + { + Block: largeBlock(), + TD: largeNumber(2), + }, + { + Block: s.fullChain.blocks[nextBlock], + TD: s.fullChain.TD(nextBlock + 1), + }, + } + + for i, blockAnnouncement := range blocks[0:3] { + t.Logf("Testing malicious announcement: %v\n", i) + sendConn := s.setupConnection(t) + if err := sendConn.Write(blockAnnouncement); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + // Invalid announcement, check that peer disconnected + switch msg := sendConn.ReadAndServe(s.chain, timeout).(type) { + case *Disconnect: + case *Error: + break + default: + t.Fatalf("unexpected: %s wanted disconnect", pretty.Sdump(msg)) + } + } + // Test the last block as a valid block + sendConn := s.setupConnection(t) + receiveConn := s.setupConnection(t) + s.testAnnounce(t, sendConn, receiveConn, blocks[3]) + // update test suite chain + s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock]) + // wait for client to update its chain + if err := receiveConn.waitForBlock(s.fullChain.blocks[nextBlock]); err != nil { + t.Fatal(err) + } +} + +func (s *Suite) testAnnounce(t *utesting.T, sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) { + // Announce the block. + if err := sendConn.Write(blockAnnouncement); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + s.waitAnnounce(t, receiveConn, blockAnnouncement) +} + +func (s *Suite) waitAnnounce(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) { + timeout := 20 * time.Second + switch msg := conn.ReadAndServe(s.chain, timeout).(type) { + case *NewBlock: + t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block)) + assert.Equal(t, + blockAnnouncement.Block.Header(), msg.Block.Header(), + "wrong block header in announcement", + ) + assert.Equal(t, + blockAnnouncement.TD, msg.TD, + "wrong TD in announcement", + ) + case *NewBlockHashes: + hashes := *msg + t.Logf("received NewBlockHashes message: %s", pretty.Sdump(hashes)) + assert.Equal(t, + blockAnnouncement.Block.Hash(), hashes[0].Hash, + "wrong block hash in announcement", + ) + default: + t.Fatalf("unexpected: %s", pretty.Sdump(msg)) + } +} + +func (s *Suite) setupConnection(t *utesting.T) *Conn { + // create conn + sendConn, err := s.dial() + if err != nil { + t.Fatalf("could not dial: %v", err) + } + sendConn.handshake(t) + sendConn.statusExchange(t, s.chain, nil) + return sendConn +} + +// dial attempts to dial the given node and perform a handshake, +// returning the created Conn if successful. 
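For orientation: the `rlpx eth-test` subcommand wired up in `rlpxcmd.go` (added at the tail of this diff, truncated below) is what drives this suite from the command line described in the README. Its handler plausibly amounts to the following sketch, not the actual code:

```go
// Sketch: how the CLI plausibly invokes the suite. The real handler is
// in rlpxcmd.go; the argument order matches the README (enode, chain
// file, genesis file).
func rlpxEthTestSketch(ctx *cli.Context) error {
	suite := ethtest.NewSuite(getNodeArg(ctx), ctx.Args()[1], ctx.Args()[2])
	return runTests(ctx, suite.AllTests())
}
```

This also explains why `getNodeArg` in `main.go` is relaxed from `!= 1` to `< 1` later in this diff: the subcommand takes extra file arguments after the node.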
+func (s *Suite) dial() (*Conn, error) { + var conn Conn + + fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP())) + if err != nil { + return nil, err + } + conn.Conn = rlpx.NewConn(fd, s.Dest.Pubkey()) + + // do encHandshake + conn.ourKey, _ = crypto.GenerateKey() + _, err = conn.Handshake(conn.ourKey) + if err != nil { + return nil, err + } + + return &conn, nil +} + +func (s *Suite) TestTransaction(t *utesting.T) { + tests := []*types.Transaction{ + getNextTxFromChain(t, s), + unknownTx(t, s), + } + for i, tx := range tests { + t.Logf("Testing tx propagation: %v\n", i) + sendSuccessfulTx(t, s, tx) + } +} + +func (s *Suite) TestMaliciousTx(t *utesting.T) { + tests := []*types.Transaction{ + getOldTxFromChain(t, s), + invalidNonceTx(t, s), + hugeAmount(t, s), + hugeGasPrice(t, s), + hugeData(t, s), + } + for i, tx := range tests { + t.Logf("Testing malicious tx propagation: %v\n", i) + sendFailingTx(t, s, tx) + } +} diff --git a/cmd/devp2p/internal/ethtest/testdata/chain.rlp b/cmd/devp2p/internal/ethtest/testdata/chain.rlp new file mode 100644 index 000000000000..5ebc2f3bb788 Binary files /dev/null and b/cmd/devp2p/internal/ethtest/testdata/chain.rlp differ diff --git a/cmd/devp2p/internal/ethtest/testdata/genesis.json b/cmd/devp2p/internal/ethtest/testdata/genesis.json new file mode 100644 index 000000000000..d8b5d225024e --- /dev/null +++ b/cmd/devp2p/internal/ethtest/testdata/genesis.json @@ -0,0 +1,26 @@ +{ + "config": { + "chainId": 19763, + "homesteadBlock": 0, + "eip150Block": 0, + "eip155Block": 0, + "eip158Block": 0, + "byzantiumBlock": 0, + "ethash": {} + }, + "nonce": "0xdeadbeefdeadbeef", + "timestamp": "0x0", + "extraData": "0x0000000000000000000000000000000000000000000000000000000000000000", + "gasLimit": "0x80000000", + "difficulty": "0x20000", + "mixHash": "0x0000000000000000000000000000000000000000000000000000000000000000", + "coinbase": "0x0000000000000000000000000000000000000000", + "alloc": { + "71562b71999873db5b286df957af199ec94617f7": { + "balance": "0xffffffffffffffffffffffffff" + } + }, + "number": "0x0", + "gasUsed": "0x0", + "parentHash": "0x0000000000000000000000000000000000000000000000000000000000000000" +} diff --git a/cmd/devp2p/internal/ethtest/testdata/halfchain.rlp b/cmd/devp2p/internal/ethtest/testdata/halfchain.rlp new file mode 100644 index 000000000000..1a820734e10c Binary files /dev/null and b/cmd/devp2p/internal/ethtest/testdata/halfchain.rlp differ diff --git a/cmd/devp2p/internal/ethtest/transaction.go b/cmd/devp2p/internal/ethtest/transaction.go new file mode 100644 index 000000000000..4aaab8bf97d4 --- /dev/null +++ b/cmd/devp2p/internal/ethtest/transaction.go @@ -0,0 +1,180 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
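The transaction tests below sign everything with a hard-coded faucet key, and the `genesis.json` above pre-funds the matching account. A standalone sanity check of that correspondence (sketch):

```go
// Sketch: verify the faucet key matches the funded genesis account.
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	key, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
	// Prints 0x71562b71999873DB5b286dF957af199Ec94617F7, the address
	// funded in the "alloc" section of testdata/genesis.json.
	fmt.Println(crypto.PubkeyToAddress(key.PublicKey))
}
```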
+ +package ethtest + +import ( + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/utesting" +) + +//var faucetAddr = common.HexToAddress("0x71562b71999873DB5b286dF957af199Ec94617F7") +var faucetKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + +func sendSuccessfulTx(t *utesting.T, s *Suite, tx *types.Transaction) { + sendConn := s.setupConnection(t) + t.Logf("sending tx: %v %v %v\n", tx.Hash().String(), tx.GasPrice(), tx.Gas()) + // Send the transaction + if err := sendConn.Write(Transactions([]*types.Transaction{tx})); err != nil { + t.Fatal(err) + } + time.Sleep(100 * time.Millisecond) + recvConn := s.setupConnection(t) + // Wait for the transaction announcement + switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) { + case *Transactions: + recTxs := *msg + if len(recTxs) < 1 { + t.Fatalf("received transactions do not match send: %v", recTxs) + } + if tx.Hash() != recTxs[len(recTxs)-1].Hash() { + t.Fatalf("received transactions do not match send: got %v want %v", recTxs, tx) + } + case *NewPooledTransactionHashes: + txHashes := *msg + if len(txHashes) < 1 { + t.Fatalf("received transactions do not match send: %v", txHashes) + } + if tx.Hash() != txHashes[len(txHashes)-1] { + t.Fatalf("wrong announcement received, wanted %v got %v", tx, txHashes) + } + default: + t.Fatalf("unexpected message in sendSuccessfulTx: %s", pretty.Sdump(msg)) + } +} + +func sendFailingTx(t *utesting.T, s *Suite, tx *types.Transaction) { + sendConn, recvConn := s.setupConnection(t), s.setupConnection(t) + // Wait for a transaction announcement + switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) { + case *NewPooledTransactionHashes: + break + default: + t.Logf("unexpected message, logging: %v", pretty.Sdump(msg)) + } + // Send the transaction + if err := sendConn.Write(Transactions([]*types.Transaction{tx})); err != nil { + t.Fatal(err) + } + // Wait for another transaction announcement + switch msg := recvConn.ReadAndServe(s.chain, timeout).(type) { + case *Transactions: + t.Fatalf("Received unexpected transaction announcement: %v", msg) + case *NewPooledTransactionHashes: + t.Fatalf("Received unexpected pooledTx announcement: %v", msg) + case *Error: + // Transaction should not be announced -> wait for timeout + return + default: + t.Fatalf("unexpected message in sendFailingTx: %s", pretty.Sdump(msg)) + } +} + +func unknownTx(t *utesting.T, s *Suite) *types.Transaction { + tx := getNextTxFromChain(t, s) + var to common.Address + if tx.To() != nil { + to = *tx.To() + } + txNew := types.NewTransaction(tx.Nonce()+1, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data()) + return signWithFaucet(t, txNew) +} + +func getNextTxFromChain(t *utesting.T, s *Suite) *types.Transaction { + // Get a new transaction + var tx *types.Transaction + for _, blocks := range s.fullChain.blocks[s.chain.Len():] { + txs := blocks.Transactions() + if txs.Len() != 0 { + tx = txs[0] + break + } + } + if tx == nil { + t.Fatal("could not find transaction") + } + return tx +} + +func getOldTxFromChain(t *utesting.T, s *Suite) *types.Transaction { + var tx *types.Transaction + for _, blocks := range s.fullChain.blocks[:s.chain.Len()-1] { + txs := blocks.Transactions() + if txs.Len() != 0 { + tx = txs[0] + break + } + } + if tx == nil { + t.Fatal("could not find transaction") + } + return tx +} + +func invalidNonceTx(t *utesting.T, s *Suite) 
*types.Transaction { + tx := getNextTxFromChain(t, s) + var to common.Address + if tx.To() != nil { + to = *tx.To() + } + txNew := types.NewTransaction(tx.Nonce()-2, to, tx.Value(), tx.Gas(), tx.GasPrice(), tx.Data()) + return signWithFaucet(t, txNew) +} + +func hugeAmount(t *utesting.T, s *Suite) *types.Transaction { + tx := getNextTxFromChain(t, s) + amount := largeNumber(2) + var to common.Address + if tx.To() != nil { + to = *tx.To() + } + txNew := types.NewTransaction(tx.Nonce(), to, amount, tx.Gas(), tx.GasPrice(), tx.Data()) + return signWithFaucet(t, txNew) +} + +func hugeGasPrice(t *utesting.T, s *Suite) *types.Transaction { + tx := getNextTxFromChain(t, s) + gasPrice := largeNumber(2) + var to common.Address + if tx.To() != nil { + to = *tx.To() + } + txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), gasPrice, tx.Data()) + return signWithFaucet(t, txNew) +} + +func hugeData(t *utesting.T, s *Suite) *types.Transaction { + tx := getNextTxFromChain(t, s) + var to common.Address + if tx.To() != nil { + to = *tx.To() + } + txNew := types.NewTransaction(tx.Nonce(), to, tx.Value(), tx.Gas(), tx.GasPrice(), largeBuffer(2)) + return signWithFaucet(t, txNew) +} + +func signWithFaucet(t *utesting.T, tx *types.Transaction) *types.Transaction { + signer := types.HomesteadSigner{} + signedTx, err := types.SignTx(tx, signer, faucetKey) + if err != nil { + t.Fatalf("could not sign tx: %v\n", err) + } + return signedTx +} diff --git a/cmd/devp2p/internal/ethtest/types.go b/cmd/devp2p/internal/ethtest/types.go new file mode 100644 index 000000000000..f768d61d5ae7 --- /dev/null +++ b/cmd/devp2p/internal/ethtest/types.go @@ -0,0 +1,395 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package ethtest + +import ( + "crypto/ecdsa" + "fmt" + "io" + "math/big" + "reflect" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/forkid" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/rlpx" + "github.com/ethereum/go-ethereum/rlp" +) + +type Message interface { + Code() int +} + +type Error struct { + err error +} + +func (e *Error) Unwrap() error { return e.err } +func (e *Error) Error() string { return e.err.Error() } +func (e *Error) Code() int { return -1 } +func (e *Error) String() string { return e.Error() } + +func errorf(format string, args ...interface{}) *Error { + return &Error{fmt.Errorf(format, args...)} +} + +// Hello is the RLP structure of the protocol handshake. +type Hello struct { + Version uint64 + Name string + Caps []p2p.Cap + ListenPort uint64 + ID []byte // secp256k1 public key + + // Ignore additional fields (for forward compatibility). 
+ Rest []rlp.RawValue `rlp:"tail"` +} + +func (h Hello) Code() int { return 0x00 } + +// Disconnect is the RLP structure for a disconnect message. +type Disconnect struct { + Reason p2p.DiscReason +} + +func (d Disconnect) Code() int { return 0x01 } + +type Ping struct{} + +func (p Ping) Code() int { return 0x02 } + +type Pong struct{} + +func (p Pong) Code() int { return 0x03 } + +// Status is the network packet for the status message for eth/64 and later. +type Status struct { + ProtocolVersion uint32 + NetworkID uint64 + TD *big.Int + Head common.Hash + Genesis common.Hash + ForkID forkid.ID +} + +func (s Status) Code() int { return 16 } + +// NewBlockHashes is the network packet for the block announcements. +type NewBlockHashes []struct { + Hash common.Hash // Hash of one particular block being announced + Number uint64 // Number of one particular block being announced +} + +func (nbh NewBlockHashes) Code() int { return 17 } + +type Transactions []*types.Transaction + +func (t Transactions) Code() int { return 18 } + +// GetBlockHeaders represents a block header query. +type GetBlockHeaders struct { + Origin hashOrNumber // Block from which to retrieve headers + Amount uint64 // Maximum number of headers to retrieve + Skip uint64 // Blocks to skip between consecutive headers + Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) +} + +func (g GetBlockHeaders) Code() int { return 19 } + +type BlockHeaders []*types.Header + +func (bh BlockHeaders) Code() int { return 20 } + +// GetBlockBodies represents a GetBlockBodies request +type GetBlockBodies []common.Hash + +func (gbb GetBlockBodies) Code() int { return 21 } + +// BlockBodies is the network packet for block content distribution. +type BlockBodies []*types.Body + +func (bb BlockBodies) Code() int { return 22 } + +// NewBlock is the network packet for the block propagation message. +type NewBlock struct { + Block *types.Block + TD *big.Int +} + +func (nb NewBlock) Code() int { return 23 } + +// NewPooledTransactionHashes is the network packet for the tx hash propagation message. +type NewPooledTransactionHashes [][32]byte + +func (nb NewPooledTransactionHashes) Code() int { return 24 } + +// HashOrNumber is a combined field for specifying an origin block. +type hashOrNumber struct { + Hash common.Hash // Block hash from which to retrieve headers (excludes Number) + Number uint64 // Block hash from which to retrieve headers (excludes Hash) +} + +// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the +// two contained union fields. +func (hn *hashOrNumber) EncodeRLP(w io.Writer) error { + if hn.Hash == (common.Hash{}) { + return rlp.Encode(w, hn.Number) + } + if hn.Number != 0 { + return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number) + } + return rlp.Encode(w, hn.Hash) +} + +// DecodeRLP is a specialized decoder for hashOrNumber to decode the contents +// into either a block hash or a block number. 
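Because `EncodeRLP` writes only one of the union's fields, a query origin is either a plain RLP integer or a 32-byte RLP string on the wire, never both. For instance (fragment; `w` is any `io.Writer` and `someHash` a hypothetical `common.Hash`):

```go
// A number encodes as an RLP integer: 1000 -> 0x8203e8.
rlp.Encode(w, &hashOrNumber{Number: 1000})
// A hash encodes as a 32-byte string: 0xa0 followed by the hash bytes.
rlp.Encode(w, &hashOrNumber{Hash: someHash})
```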
+func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error { + _, size, _ := s.Kind() + origin, err := s.Raw() + if err == nil { + switch { + case size == 32: + err = rlp.DecodeBytes(origin, &hn.Hash) + case size <= 8: + err = rlp.DecodeBytes(origin, &hn.Number) + default: + err = fmt.Errorf("invalid input size %d for origin", size) + } + } + return err +} + +// Conn represents an individual connection with a peer +type Conn struct { + *rlpx.Conn + ourKey *ecdsa.PrivateKey + ethProtocolVersion uint +} + +func (c *Conn) Read() Message { + code, rawData, _, err := c.Conn.Read() + if err != nil { + return errorf("could not read from connection: %v", err) + } + + var msg Message + switch int(code) { + case (Hello{}).Code(): + msg = new(Hello) + case (Ping{}).Code(): + msg = new(Ping) + case (Pong{}).Code(): + msg = new(Pong) + case (Disconnect{}).Code(): + msg = new(Disconnect) + case (Status{}).Code(): + msg = new(Status) + case (GetBlockHeaders{}).Code(): + msg = new(GetBlockHeaders) + case (BlockHeaders{}).Code(): + msg = new(BlockHeaders) + case (GetBlockBodies{}).Code(): + msg = new(GetBlockBodies) + case (BlockBodies{}).Code(): + msg = new(BlockBodies) + case (NewBlock{}).Code(): + msg = new(NewBlock) + case (NewBlockHashes{}).Code(): + msg = new(NewBlockHashes) + case (Transactions{}).Code(): + msg = new(Transactions) + case (NewPooledTransactionHashes{}).Code(): + msg = new(NewPooledTransactionHashes) + default: + return errorf("invalid message code: %d", code) + } + + if err := rlp.DecodeBytes(rawData, msg); err != nil { + return errorf("could not rlp decode message: %v", err) + } + return msg +} + +// ReadAndServe serves GetBlockHeaders requests while waiting +// on another message from the node. +func (c *Conn) ReadAndServe(chain *Chain, timeout time.Duration) Message { + start := time.Now() + for time.Since(start) < timeout { + timeout := time.Now().Add(10 * time.Second) + c.SetReadDeadline(timeout) + switch msg := c.Read().(type) { + case *Ping: + c.Write(&Pong{}) + case *GetBlockHeaders: + req := *msg + headers, err := chain.GetHeaders(req) + if err != nil { + return errorf("could not get headers for inbound header request: %v", err) + } + + if err := c.Write(headers); err != nil { + return errorf("could not write to connection: %v", err) + } + default: + return msg + } + } + return errorf("no message received within %v", timeout) +} + +func (c *Conn) Write(msg Message) error { + payload, err := rlp.EncodeToBytes(msg) + if err != nil { + return err + } + _, err = c.Conn.Write(uint64(msg.Code()), payload) + return err +} + +// handshake checks to make sure a `HELLO` is received. 
+func (c *Conn) handshake(t *utesting.T) Message { + defer c.SetDeadline(time.Time{}) + c.SetDeadline(time.Now().Add(10 * time.Second)) + + // write hello to client + pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:] + ourHandshake := &Hello{ + Version: 5, + Caps: []p2p.Cap{ + {Name: "eth", Version: 64}, + {Name: "eth", Version: 65}, + }, + ID: pub0, + } + if err := c.Write(ourHandshake); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + // read hello from client + switch msg := c.Read().(type) { + case *Hello: + // set snappy if version is at least 5 + if msg.Version >= 5 { + c.SetSnappy(true) + } + c.negotiateEthProtocol(msg.Caps) + if c.ethProtocolVersion == 0 { + t.Fatalf("unexpected eth protocol version") + } + return msg + default: + t.Fatalf("bad handshake: %#v", msg) + return nil + } +} + +// negotiateEthProtocol sets the Conn's eth protocol version +// to highest advertised capability from peer +func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) { + var highestEthVersion uint + for _, capability := range caps { + if capability.Name != "eth" { + continue + } + if capability.Version > highestEthVersion && capability.Version <= 65 { + highestEthVersion = capability.Version + } + } + c.ethProtocolVersion = highestEthVersion +} + +// statusExchange performs a `Status` message exchange with the given +// node. +func (c *Conn) statusExchange(t *utesting.T, chain *Chain, status *Status) Message { + defer c.SetDeadline(time.Time{}) + c.SetDeadline(time.Now().Add(20 * time.Second)) + + // read status message from client + var message Message +loop: + for { + switch msg := c.Read().(type) { + case *Status: + if msg.Head != chain.blocks[chain.Len()-1].Hash() { + t.Fatalf("wrong head block in status: %s", msg.Head.String()) + } + if msg.TD.Cmp(chain.TD(chain.Len())) != 0 { + t.Fatalf("wrong TD in status: %v", msg.TD) + } + if !reflect.DeepEqual(msg.ForkID, chain.ForkID()) { + t.Fatalf("wrong fork ID in status: %v", msg.ForkID) + } + message = msg + break loop + case *Disconnect: + t.Fatalf("disconnect received: %v", msg.Reason) + case *Ping: + c.Write(&Pong{}) // TODO (renaynay): in the future, this should be an error + // (PINGs should not be a response upon fresh connection) + default: + t.Fatalf("bad status message: %s", pretty.Sdump(msg)) + } + } + // make sure eth protocol version is set for negotiation + if c.ethProtocolVersion == 0 { + t.Fatalf("eth protocol version must be set in Conn") + } + if status == nil { + // write status message to client + status = &Status{ + ProtocolVersion: uint32(c.ethProtocolVersion), + NetworkID: chain.chainConfig.ChainID.Uint64(), + TD: chain.TD(chain.Len()), + Head: chain.blocks[chain.Len()-1].Hash(), + Genesis: chain.blocks[0].Hash(), + ForkID: chain.ForkID(), + } + } + + if err := c.Write(*status); err != nil { + t.Fatalf("could not write to connection: %v", err) + } + + return message +} + +// waitForBlock waits for confirmation from the client that it has +// imported the given block. 
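One consequence of `negotiateEthProtocol` worth spelling out: it ignores non-`eth` capabilities and caps the result at 65, so an exotic advertisement degrades gracefully. This case appears verbatim in the `chain_test.go` table above:

```go
// Negotiation picks the highest "eth" capability <= 65; "wrongProto"
// is skipped, so this resolves to version 64.
var c Conn
c.negotiateEthProtocol([]p2p.Cap{
	{Name: "eth", Version: 63},
	{Name: "eth", Version: 64},
	{Name: "wrongProto", Version: 65},
})
// c.ethProtocolVersion == 64
```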
+func (c *Conn) waitForBlock(block *types.Block) error { + defer c.SetReadDeadline(time.Time{}) + + timeout := time.Now().Add(20 * time.Second) + c.SetReadDeadline(timeout) + for { + req := &GetBlockHeaders{Origin: hashOrNumber{Hash: block.Hash()}, Amount: 1} + if err := c.Write(req); err != nil { + return err + } + switch msg := c.Read().(type) { + case *BlockHeaders: + if len(*msg) > 0 { + return nil + } + time.Sleep(100 * time.Millisecond) + default: + return fmt.Errorf("invalid message: %s", pretty.Sdump(msg)) + } + } +} diff --git a/cmd/devp2p/internal/v5test/discv5tests.go b/cmd/devp2p/internal/v5test/discv5tests.go new file mode 100644 index 000000000000..7866498f7376 --- /dev/null +++ b/cmd/devp2p/internal/v5test/discv5tests.go @@ -0,0 +1,377 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package v5test + +import ( + "bytes" + "net" + "sync" + "time" + + "github.com/ethereum/go-ethereum/internal/utesting" + "github.com/ethereum/go-ethereum/p2p/discover/v5wire" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/netutil" +) + +// Suite is the discv5 test suite. +type Suite struct { + Dest *enode.Node + Listen1, Listen2 string // listening addresses +} + +func (s *Suite) listen1(log logger) (*conn, net.PacketConn) { + c := newConn(s.Dest, log) + l := c.listen(s.Listen1) + return c, l +} + +func (s *Suite) listen2(log logger) (*conn, net.PacketConn, net.PacketConn) { + c := newConn(s.Dest, log) + l1, l2 := c.listen(s.Listen1), c.listen(s.Listen2) + return c, l1, l2 +} + +func (s *Suite) AllTests() []utesting.Test { + return []utesting.Test{ + {Name: "Ping", Fn: s.TestPing}, + {Name: "PingLargeRequestID", Fn: s.TestPingLargeRequestID}, + {Name: "PingMultiIP", Fn: s.TestPingMultiIP}, + {Name: "PingHandshakeInterrupted", Fn: s.TestPingHandshakeInterrupted}, + {Name: "TalkRequest", Fn: s.TestTalkRequest}, + {Name: "FindnodeZeroDistance", Fn: s.TestFindnodeZeroDistance}, + {Name: "FindnodeResults", Fn: s.TestFindnodeResults}, + } +} + +// This test sends PING and expects a PONG response. 
+func (s *Suite) TestPing(t *utesting.T) { + conn, l1 := s.listen1(t) + defer conn.close() + + ping := &v5wire.Ping{ReqID: conn.nextReqID()} + switch resp := conn.reqresp(l1, ping).(type) { + case *v5wire.Pong: + checkPong(t, resp, ping, l1) + default: + t.Fatal("expected PONG, got", resp.Name()) + } +} + +func checkPong(t *utesting.T, pong *v5wire.Pong, ping *v5wire.Ping, c net.PacketConn) { + if !bytes.Equal(pong.ReqID, ping.ReqID) { + t.Fatalf("wrong request ID %x in PONG, want %x", pong.ReqID, ping.ReqID) + } + if !pong.ToIP.Equal(laddr(c).IP) { + t.Fatalf("wrong destination IP %v in PONG, want %v", pong.ToIP, laddr(c).IP) + } + if int(pong.ToPort) != laddr(c).Port { + t.Fatalf("wrong destination port %v in PONG, want %v", pong.ToPort, laddr(c).Port) + } +} + +// This test sends PING with a 9-byte request ID, which isn't allowed by the spec. +// The remote node should not respond. +func (s *Suite) TestPingLargeRequestID(t *utesting.T) { + conn, l1 := s.listen1(t) + defer conn.close() + + ping := &v5wire.Ping{ReqID: make([]byte, 9)} + switch resp := conn.reqresp(l1, ping).(type) { + case *v5wire.Pong: + t.Errorf("PONG response with unknown request ID %x", resp.ReqID) + case *readError: + if resp.err == v5wire.ErrInvalidReqID { + t.Error("response with oversized request ID") + } else if !netutil.IsTimeout(resp.err) { + t.Error(resp) + } + } +} + +// In this test, a session is established from one IP as usual. The session is then reused +// on another IP, which shouldn't work. The remote node should respond with WHOAREYOU for +// the attempt from a different IP. +func (s *Suite) TestPingMultiIP(t *utesting.T) { + conn, l1, l2 := s.listen2(t) + defer conn.close() + + // Create the session on l1. + ping := &v5wire.Ping{ReqID: conn.nextReqID()} + resp := conn.reqresp(l1, ping) + if resp.Kind() != v5wire.PongMsg { + t.Fatal("expected PONG, got", resp) + } + checkPong(t, resp.(*v5wire.Pong), ping, l1) + + // Send on l2. This reuses the session because there is only one codec. + ping2 := &v5wire.Ping{ReqID: conn.nextReqID()} + conn.write(l2, ping2, nil) + switch resp := conn.read(l2).(type) { + case *v5wire.Pong: + t.Fatalf("remote responded to PING from %v for session on IP %v", laddr(l2).IP, laddr(l1).IP) + case *v5wire.Whoareyou: + t.Logf("got WHOAREYOU for new session as expected") + resp.Node = s.Dest + conn.write(l2, ping2, resp) + default: + t.Fatal("expected WHOAREYOU, got", resp) + } + + // Catch the PONG on l2. + switch resp := conn.read(l2).(type) { + case *v5wire.Pong: + checkPong(t, resp, ping2, l2) + default: + t.Fatal("expected PONG, got", resp) + } + + // Try on l1 again. + ping3 := &v5wire.Ping{ReqID: conn.nextReqID()} + conn.write(l1, ping3, nil) + switch resp := conn.read(l1).(type) { + case *v5wire.Pong: + t.Fatalf("remote responded to PING from %v for session on IP %v", laddr(l1).IP, laddr(l2).IP) + case *v5wire.Whoareyou: + t.Logf("got WHOAREYOU for new session as expected") + default: + t.Fatal("expected WHOAREYOU, got", resp) + } +} + +// This test starts a handshake, but doesn't finish it and sends a second ordinary message +// packet instead of a handshake message packet. The remote node should respond with +// another WHOAREYOU challenge for the second packet. +func (s *Suite) TestPingHandshakeInterrupted(t *utesting.T) { + conn, l1 := s.listen1(t) + defer conn.close() + + // First PING triggers challenge. 
+ ping := &v5wire.Ping{ReqID: conn.nextReqID()} + conn.write(l1, ping, nil) + switch resp := conn.read(l1).(type) { + case *v5wire.Whoareyou: + t.Logf("got WHOAREYOU for PING") + default: + t.Fatal("expected WHOAREYOU, got", resp) + } + + // Send second PING. + ping2 := &v5wire.Ping{ReqID: conn.nextReqID()} + switch resp := conn.reqresp(l1, ping2).(type) { + case *v5wire.Pong: + checkPong(t, resp, ping2, l1) + default: + t.Fatal("expected WHOAREYOU, got", resp) + } +} + +// This test sends TALKREQ and expects an empty TALKRESP response. +func (s *Suite) TestTalkRequest(t *utesting.T) { + conn, l1 := s.listen1(t) + defer conn.close() + + // Non-empty request ID. + id := conn.nextReqID() + resp := conn.reqresp(l1, &v5wire.TalkRequest{ReqID: id, Protocol: "test-protocol"}) + switch resp := resp.(type) { + case *v5wire.TalkResponse: + if !bytes.Equal(resp.ReqID, id) { + t.Fatalf("wrong request ID %x in TALKRESP, want %x", resp.ReqID, id) + } + if len(resp.Message) > 0 { + t.Fatalf("non-empty message %x in TALKRESP", resp.Message) + } + default: + t.Fatal("expected TALKRESP, got", resp.Name()) + } + + // Empty request ID. + resp = conn.reqresp(l1, &v5wire.TalkRequest{Protocol: "test-protocol"}) + switch resp := resp.(type) { + case *v5wire.TalkResponse: + if len(resp.ReqID) > 0 { + t.Fatalf("wrong request ID %x in TALKRESP, want empty byte array", resp.ReqID) + } + if len(resp.Message) > 0 { + t.Fatalf("non-empty message %x in TALKRESP", resp.Message) + } + default: + t.Fatal("expected TALKRESP, got", resp.Name()) + } +} + +// This test checks that the remote node returns itself for FINDNODE with distance zero. +func (s *Suite) TestFindnodeZeroDistance(t *utesting.T) { + conn, l1 := s.listen1(t) + defer conn.close() + + nodes, err := conn.findnode(l1, []uint{0}) + if err != nil { + t.Fatal(err) + } + if len(nodes) != 1 { + t.Fatalf("remote returned more than one node for FINDNODE [0]") + } + if nodes[0].ID() != conn.remote.ID() { + t.Errorf("ID of response node is %v, want %v", nodes[0].ID(), conn.remote.ID()) + } +} + +// In this test, multiple nodes ping the node under test. After waiting for them to be +// accepted into the remote table, the test checks that they are returned by FINDNODE. +func (s *Suite) TestFindnodeResults(t *utesting.T) { + // Create bystanders. + nodes := make([]*bystander, 5) + added := make(chan enode.ID, len(nodes)) + for i := range nodes { + nodes[i] = newBystander(t, s, added) + defer nodes[i].close() + } + + // Get them added to the remote table. + timeout := 60 * time.Second + timeoutCh := time.After(timeout) + for count := 0; count < len(nodes); { + select { + case id := <-added: + t.Logf("bystander node %v added to remote table", id) + count++ + case <-timeoutCh: + t.Errorf("remote added %d bystander nodes in %v, need %d to continue", count, timeout, len(nodes)) + t.Logf("this can happen if the node has a non-empty table from previous runs") + return + } + } + t.Logf("all %d bystander nodes were added", len(nodes)) + + // Collect our nodes by distance. + var dists []uint + expect := make(map[enode.ID]*enode.Node) + for _, bn := range nodes { + n := bn.conn.localNode.Node() + expect[n.ID()] = n + d := uint(enode.LogDist(n.ID(), s.Dest.ID())) + if !containsUint(dists, d) { + dists = append(dists, d) + } + } + + // Send FINDNODE for all distances. 
+ conn, l1 := s.listen1(t) + defer conn.close() + foundNodes, err := conn.findnode(l1, dists) + if err != nil { + t.Fatal(err) + } + t.Logf("remote returned %d nodes for distance list %v", len(foundNodes), dists) + for _, n := range foundNodes { + delete(expect, n.ID()) + } + if len(expect) > 0 { + t.Errorf("missing %d nodes in FINDNODE result", len(expect)) + t.Logf("this can happen if the test is run multiple times in quick succession") + t.Logf("and the remote node hasn't removed dead nodes from previous runs yet") + } else { + t.Logf("all %d expected nodes were returned", len(nodes)) + } +} + +// A bystander is a node whose only purpose is filling a spot in the remote table. +type bystander struct { + dest *enode.Node + conn *conn + l net.PacketConn + + addedCh chan enode.ID + done sync.WaitGroup +} + +func newBystander(t *utesting.T, s *Suite, added chan enode.ID) *bystander { + conn, l := s.listen1(t) + conn.setEndpoint(l) // bystander nodes need IP/port to get pinged + bn := &bystander{ + conn: conn, + l: l, + dest: s.Dest, + addedCh: added, + } + bn.done.Add(1) + go bn.loop() + return bn +} + +// id returns the node ID of the bystander. +func (bn *bystander) id() enode.ID { + return bn.conn.localNode.ID() +} + +// close shuts down loop. +func (bn *bystander) close() { + bn.conn.close() + bn.done.Wait() +} + +// loop answers packets from the remote node until quit. +func (bn *bystander) loop() { + defer bn.done.Done() + + var ( + lastPing time.Time + wasAdded bool + ) + for { + // Ping the remote node. + if !wasAdded && time.Since(lastPing) > 10*time.Second { + bn.conn.reqresp(bn.l, &v5wire.Ping{ + ReqID: bn.conn.nextReqID(), + ENRSeq: bn.dest.Seq(), + }) + lastPing = time.Now() + } + // Answer packets. + switch p := bn.conn.read(bn.l).(type) { + case *v5wire.Ping: + bn.conn.write(bn.l, &v5wire.Pong{ + ReqID: p.ReqID, + ENRSeq: bn.conn.localNode.Seq(), + ToIP: bn.dest.IP(), + ToPort: uint16(bn.dest.UDP()), + }, nil) + wasAdded = true + bn.notifyAdded() + case *v5wire.Findnode: + bn.conn.write(bn.l, &v5wire.Nodes{ReqID: p.ReqID, Total: 1}, nil) + wasAdded = true + bn.notifyAdded() + case *v5wire.TalkRequest: + bn.conn.write(bn.l, &v5wire.TalkResponse{ReqID: p.ReqID}, nil) + case *readError: + if !netutil.IsTemporaryError(p.err) { + bn.conn.logf("shutting down: %v", p.err) + return + } + } + } +} + +func (bn *bystander) notifyAdded() { + if bn.addedCh != nil { + bn.addedCh <- bn.id() + bn.addedCh = nil + } +} diff --git a/cmd/devp2p/internal/v5test/framework.go b/cmd/devp2p/internal/v5test/framework.go new file mode 100644 index 000000000000..9eac37520f7b --- /dev/null +++ b/cmd/devp2p/internal/v5test/framework.go @@ -0,0 +1,263 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . 
+ +package v5test + +import ( + "bytes" + "crypto/ecdsa" + "encoding/binary" + "fmt" + "net" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/discover/v5wire" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +// readError represents an error during packet reading. +// This exists to facilitate type-switching on the result of conn.read. +type readError struct { + err error +} + +func (p *readError) Kind() byte { return 99 } +func (p *readError) Name() string { return fmt.Sprintf("error: %v", p.err) } +func (p *readError) Error() string { return p.err.Error() } +func (p *readError) Unwrap() error { return p.err } +func (p *readError) RequestID() []byte { return nil } +func (p *readError) SetRequestID([]byte) {} + +// readErrorf creates a readError with the given text. +func readErrorf(format string, args ...interface{}) *readError { + return &readError{fmt.Errorf(format, args...)} +} + +// This is the response timeout used in tests. +const waitTime = 300 * time.Millisecond + +// conn is a connection to the node under test. +type conn struct { + localNode *enode.LocalNode + localKey *ecdsa.PrivateKey + remote *enode.Node + remoteAddr *net.UDPAddr + listeners []net.PacketConn + + log logger + codec *v5wire.Codec + lastRequest v5wire.Packet + lastChallenge *v5wire.Whoareyou + idCounter uint32 +} + +type logger interface { + Logf(string, ...interface{}) +} + +// newConn sets up a connection to the given node. +func newConn(dest *enode.Node, log logger) *conn { + key, err := crypto.GenerateKey() + if err != nil { + panic(err) + } + db, err := enode.OpenDB("") + if err != nil { + panic(err) + } + ln := enode.NewLocalNode(db, key) + + return &conn{ + localKey: key, + localNode: ln, + remote: dest, + remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()}, + codec: v5wire.NewCodec(ln, key, mclock.System{}), + log: log, + } +} + +func (tc *conn) setEndpoint(c net.PacketConn) { + tc.localNode.SetStaticIP(laddr(c).IP) + tc.localNode.SetFallbackUDP(laddr(c).Port) +} + +func (tc *conn) listen(ip string) net.PacketConn { + l, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", ip)) + if err != nil { + panic(err) + } + tc.listeners = append(tc.listeners, l) + return l +} + +// close shuts down all listeners and the local node. +func (tc *conn) close() { + for _, l := range tc.listeners { + l.Close() + } + tc.localNode.Database().Close() +} + +// nextReqID creates a request id. +func (tc *conn) nextReqID() []byte { + id := make([]byte, 4) + tc.idCounter++ + binary.BigEndian.PutUint32(id, tc.idCounter) + return id +} + +// reqresp performs a request/response interaction on the given connection. +// The request is retried if a handshake is requested. +func (tc *conn) reqresp(c net.PacketConn, req v5wire.Packet) v5wire.Packet { + reqnonce := tc.write(c, req, nil) + switch resp := tc.read(c).(type) { + case *v5wire.Whoareyou: + if resp.Nonce != reqnonce { + return readErrorf("wrong nonce %x in WHOAREYOU (want %x)", resp.Nonce[:], reqnonce[:]) + } + resp.Node = tc.remote + tc.write(c, req, resp) + return tc.read(c) + default: + return resp + } +} + +// findnode sends a FINDNODE request and waits for its responses. 
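At this point the harness is complete: `newConn` generates a fresh node identity, `listen` binds a UDP socket, and `reqresp` transparently answers a WHOAREYOU challenge and retries the request. The smallest possible use (fragment; assumes `dest *enode.Node` and a `utesting.T` named `t`):

```go
c := newConn(dest, t)
defer c.close()
l := c.listen("127.0.0.1")
resp := c.reqresp(l, &v5wire.Ping{ReqID: c.nextReqID()})
t.Logf("got %s", resp.Name())
```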
+func (tc *conn) findnode(c net.PacketConn, dists []uint) ([]*enode.Node, error) {
+ var (
+ findnode = &v5wire.Findnode{ReqID: tc.nextReqID(), Distances: dists}
+ reqnonce = tc.write(c, findnode, nil)
+ first = true
+ total uint8
+ results []*enode.Node
+ )
+ for n := 1; n > 0; {
+ switch resp := tc.read(c).(type) {
+ case *v5wire.Whoareyou:
+ // Handle handshake.
+ if resp.Nonce == reqnonce {
+ resp.Node = tc.remote
+ tc.write(c, findnode, resp)
+ } else {
+ return nil, fmt.Errorf("unexpected WHOAREYOU (nonce %x), waiting for NODES", resp.Nonce[:])
+ }
+ case *v5wire.Ping:
+ // Handle ping from remote.
+ tc.write(c, &v5wire.Pong{
+ ReqID: resp.ReqID,
+ ENRSeq: tc.localNode.Seq(),
+ }, nil)
+ case *v5wire.Nodes:
+ // Got NODES! Check request ID.
+ if !bytes.Equal(resp.ReqID, findnode.ReqID) {
+ return nil, fmt.Errorf("NODES response has wrong request id %x", resp.ReqID)
+ }
+ // Check total count. It should be greater than zero, at most 6,
+ // and needs to be the same across all responses.
+ if first {
+ if resp.Total == 0 || resp.Total > 6 {
+ return nil, fmt.Errorf("invalid NODES response 'total' %d (not in (0,7))", resp.Total)
+ }
+ total = resp.Total
+ n = int(total) - 1
+ first = false
+ } else {
+ n--
+ if resp.Total != total {
+ return nil, fmt.Errorf("invalid NODES response 'total' %d (!= %d)", resp.Total, total)
+ }
+ }
+ // Check nodes.
+ nodes, err := checkRecords(resp.Nodes)
+ if err != nil {
+ return nil, fmt.Errorf("invalid node in NODES response: %v", err)
+ }
+ results = append(results, nodes...)
+ default:
+ return nil, fmt.Errorf("expected NODES, got %v", resp)
+ }
+ }
+ return results, nil
+}
+
+// write sends a packet on the given connection.
+func (tc *conn) write(c net.PacketConn, p v5wire.Packet, challenge *v5wire.Whoareyou) v5wire.Nonce {
+ packet, nonce, err := tc.codec.Encode(tc.remote.ID(), tc.remoteAddr.String(), p, challenge)
+ if err != nil {
+ panic(fmt.Errorf("can't encode %v packet: %v", p.Name(), err))
+ }
+ if _, err := c.WriteTo(packet, tc.remoteAddr); err != nil {
+ tc.logf("Can't send %s: %v", p.Name(), err)
+ } else {
+ tc.logf(">> %s", p.Name())
+ }
+ return nonce
+}
+
+// read waits for an incoming packet on the given connection.
+func (tc *conn) read(c net.PacketConn) v5wire.Packet {
+ buf := make([]byte, 1280)
+ if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil {
+ return &readError{err}
+ }
+ n, fromAddr, err := c.ReadFrom(buf)
+ if err != nil {
+ return &readError{err}
+ }
+ _, _, p, err := tc.codec.Decode(buf[:n], fromAddr.String())
+ if err != nil {
+ return &readError{err}
+ }
+ tc.logf("<< %s", p.Name())
+ return p
+}
+
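Editor's note, not part of the patch: a minimal sketch of how a test could drive the helpers above. It assumes the sketch lives in the v5test package (conn and its methods are unexported), that the package's existing imports are available, and that dest points at a live node under test; the function name and the distance list are placeholders.

    // exampleFindnode is a hypothetical caller, shown for illustration only.
    func exampleFindnode(t *utesting.T, dest *enode.Node) {
        tc := newConn(dest, t) // fresh key and local node, talking to dest
        defer tc.close()
        l := tc.listen("127.0.0.1") // bind a UDP socket for this exchange
        tc.setEndpoint(l)           // advertise IP/port so the remote can ping back
        nodes, err := tc.findnode(l, []uint{256, 255}) // WHOAREYOU handshake handled inside
        if err != nil {
            t.Fatal(err)
        }
        t.Logf("received %d nodes", len(nodes))
    }

+// logf prints to the test log.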
+func (tc *conn) logf(format string, args ...interface{}) { + if tc.log != nil { + tc.log.Logf("(%s) %s", tc.localNode.ID().TerminalString(), fmt.Sprintf(format, args...)) + } +} + +func laddr(c net.PacketConn) *net.UDPAddr { + return c.LocalAddr().(*net.UDPAddr) +} + +func checkRecords(records []*enr.Record) ([]*enode.Node, error) { + nodes := make([]*enode.Node, len(records)) + for i := range records { + n, err := enode.New(enode.ValidSchemes, records[i]) + if err != nil { + return nil, err + } + nodes[i] = n + } + return nodes, nil +} + +func containsUint(ints []uint, x uint) bool { + for i := range ints { + if ints[i] == x { + return true + } + } + return false +} diff --git a/cmd/devp2p/main.go b/cmd/devp2p/main.go index 6064d17c7070..4a4e905a424e 100644 --- a/cmd/devp2p/main.go +++ b/cmd/devp2p/main.go @@ -63,6 +63,7 @@ func init() { discv5Command, dnsCommand, nodesetCommand, + rlpxCommand, } } @@ -80,7 +81,7 @@ func commandHasFlag(ctx *cli.Context, flag cli.Flag) bool { // getNodeArg handles the common case of a single node descriptor argument. func getNodeArg(ctx *cli.Context) *enode.Node { - if ctx.NArg() != 1 { + if ctx.NArg() < 1 { exit("missing node as command-line argument") } n, err := parseNode(ctx.Args()[0]) diff --git a/cmd/devp2p/nodesetcmd.go b/cmd/devp2p/nodesetcmd.go index 228d3319e3bb..ba97405abcd6 100644 --- a/cmd/devp2p/nodesetcmd.go +++ b/cmd/devp2p/nodesetcmd.go @@ -95,6 +95,7 @@ var filterFlags = map[string]nodeFilterC{ "-min-age": {1, minAgeFilter}, "-eth-network": {1, ethFilter}, "-les-server": {0, lesFilter}, + "-snap": {0, snapFilter}, } func parseFilters(args []string) ([]nodeFilter, error) { @@ -104,15 +105,15 @@ func parseFilters(args []string) ([]nodeFilter, error) { if !ok { return nil, fmt.Errorf("invalid filter %q", args[0]) } - if len(args) < fc.narg { - return nil, fmt.Errorf("filter %q wants %d arguments, have %d", args[0], fc.narg, len(args)) + if len(args)-1 < fc.narg { + return nil, fmt.Errorf("filter %q wants %d arguments, have %d", args[0], fc.narg, len(args)-1) } - filter, err := fc.fn(args[1:]) + filter, err := fc.fn(args[1 : 1+fc.narg]) if err != nil { return nil, fmt.Errorf("%s: %v", args[0], err) } filters = append(filters, filter) - args = args[fc.narg+1:] + args = args[1+fc.narg:] } return filters, nil } @@ -191,3 +192,13 @@ func lesFilter(args []string) (nodeFilter, error) { } return f, nil } + +func snapFilter(args []string) (nodeFilter, error) { + f := func(n nodeJSON) bool { + var snap struct { + _ []rlp.RawValue `rlp:"tail"` + } + return n.N.Load(enr.WithEntry("snap", &snap)) == nil + } + return f, nil +} diff --git a/cmd/devp2p/rlpxcmd.go b/cmd/devp2p/rlpxcmd.go new file mode 100644 index 000000000000..d90eb4687cad --- /dev/null +++ b/cmd/devp2p/rlpxcmd.go @@ -0,0 +1,99 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . 
+
+package main
+
+import (
+ "fmt"
+ "net"
+
+ "github.com/ethereum/go-ethereum/cmd/devp2p/internal/ethtest"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/p2p"
+ "github.com/ethereum/go-ethereum/p2p/rlpx"
+ "github.com/ethereum/go-ethereum/rlp"
+ "gopkg.in/urfave/cli.v1"
+)
+
+var (
+ rlpxCommand = cli.Command{
+ Name: "rlpx",
+ Usage: "RLPx Commands",
+ Subcommands: []cli.Command{
+ rlpxPingCommand,
+ rlpxEthTestCommand,
+ },
+ }
+ rlpxPingCommand = cli.Command{
+ Name: "ping",
+ Usage: "ping <node>",
+ Action: rlpxPing,
+ }
+ rlpxEthTestCommand = cli.Command{
+ Name: "eth-test",
+ Usage: "Runs tests against a node",
+ ArgsUsage: "<node> <chain.rlp> <genesis.json>",
+ Action: rlpxEthTest,
+ Flags: []cli.Flag{
+ testPatternFlag,
+ testTAPFlag,
+ },
+ }
+)
+
+func rlpxPing(ctx *cli.Context) error {
+ n := getNodeArg(ctx)
+ fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", n.IP(), n.TCP()))
+ if err != nil {
+ return err
+ }
+ conn := rlpx.NewConn(fd, n.Pubkey())
+ ourKey, _ := crypto.GenerateKey()
+ _, err = conn.Handshake(ourKey)
+ if err != nil {
+ return err
+ }
+ code, data, _, err := conn.Read()
+ if err != nil {
+ return err
+ }
+ switch code {
+ case 0:
+ var h ethtest.Hello
+ if err := rlp.DecodeBytes(data, &h); err != nil {
+ return fmt.Errorf("invalid handshake: %v", err)
+ }
+ fmt.Printf("%+v\n", h)
+ case 1:
+ var msg []p2p.DiscReason
+ if rlp.DecodeBytes(data, &msg); len(msg) == 0 {
+ return fmt.Errorf("invalid disconnect message")
+ }
+ return fmt.Errorf("received disconnect message: %v", msg[0])
+ default:
+ return fmt.Errorf("invalid message code %d, expected handshake (code zero)", code)
+ }
+ return nil
+}
+
+// rlpxEthTest runs the eth protocol test suite.
+func rlpxEthTest(ctx *cli.Context) error {
+ if ctx.NArg() < 3 {
+ exit("missing path to chain.rlp as command-line argument")
+ }
+ suite := ethtest.NewSuite(getNodeArg(ctx), ctx.Args()[1], ctx.Args()[2])
+ return runTests(ctx, suite.AllTests())
+}
diff --git a/cmd/devp2p/runtest.go b/cmd/devp2p/runtest.go
new file mode 100644
index 000000000000..4168f8555bfb
--- /dev/null
+++ b/cmd/devp2p/runtest.go
@@ -0,0 +1,69 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of go-ethereum.
+//
+// go-ethereum is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// go-ethereum is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "os"
+
+ "github.com/ethereum/go-ethereum/cmd/devp2p/internal/v4test"
+ "github.com/ethereum/go-ethereum/internal/utesting"
+ "github.com/ethereum/go-ethereum/log"
+ "gopkg.in/urfave/cli.v1"
+)
+
+var (
+ testPatternFlag = cli.StringFlag{
+ Name: "run",
+ Usage: "Pattern of test suite(s) to run",
+ }
+ testTAPFlag = cli.BoolFlag{
+ Name: "tap",
+ Usage: "Output TAP",
+ }
+ // These two are specific to the discovery tests.
+ testListen1Flag = cli.StringFlag{ + Name: "listen1", + Usage: "IP address of the first tester", + Value: v4test.Listen1, + } + testListen2Flag = cli.StringFlag{ + Name: "listen2", + Usage: "IP address of the second tester", + Value: v4test.Listen2, + } +) + +func runTests(ctx *cli.Context, tests []utesting.Test) error { + // Filter test cases. + if ctx.IsSet(testPatternFlag.Name) { + tests = utesting.MatchTests(tests, ctx.String(testPatternFlag.Name)) + } + // Disable logging unless explicitly enabled. + if !ctx.GlobalIsSet("verbosity") && !ctx.GlobalIsSet("vmodule") { + log.Root().SetHandler(log.DiscardHandler()) + } + // Run the tests. + var run = utesting.RunTests + if ctx.Bool(testTAPFlag.Name) { + run = utesting.RunTAP + } + results := run(tests, os.Stdout) + if utesting.CountFailures(results) > 0 { + os.Exit(1) + } + return nil +} diff --git a/cmd/evm/compiler.go b/cmd/evm/compiler.go index c019a2fe70b7..40ad9313c514 100644 --- a/cmd/evm/compiler.go +++ b/cmd/evm/compiler.go @@ -23,7 +23,7 @@ import ( "github.com/ethereum/go-ethereum/cmd/evm/internal/compiler" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" ) var compileCommand = cli.Command{ diff --git a/cmd/evm/disasm.go b/cmd/evm/disasm.go index 5bc743aa88f7..68a09cbf50cd 100644 --- a/cmd/evm/disasm.go +++ b/cmd/evm/disasm.go @@ -23,7 +23,7 @@ import ( "strings" "github.com/ethereum/go-ethereum/core/asm" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" ) var disasmCommand = cli.Command{ diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go index 75586d588b0f..171443e15625 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -110,7 +110,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, txIndex = 0 ) gaspool.AddGas(pre.Env.GasLimit) - vmContext := vm.Context{ + vmContext := vm.BlockContext{ CanTransfer: core.CanTransfer, Transfer: core.Transfer, Coinbase: pre.Env.Coinbase, @@ -119,7 +119,6 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, Difficulty: pre.Env.Difficulty, GasLimit: pre.Env.GasLimit, GetHash: getHash, - // GasPrice and Origin needs to be set per transaction } // If DAO is supported/enabled, we need to handle it here. In geth 'proper', it's // done in StateProcessor.Process(block, ...), right before transactions are applied. 
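Editor's note, not part of the patch: the hunk above renames vm.Context to vm.BlockContext, and the hunk below builds a separate per-transaction vm.TxContext via core.NewEVMTxContext and hands both to vm.NewEVM. A minimal sketch of how the split fits together, assuming the post-refactor APIs exactly as used in this diff (the helper name is hypothetical, the field set is abridged, and imports follow the surrounding file):

    // newEVMForTx wires up one EVM for a single transaction.
    func newEVMForTx(header *types.Header, msg core.Message, statedb *state.StateDB,
        chainConfig *params.ChainConfig) *vm.EVM {
        // Per-block data, shared by every transaction in the block.
        blockCtx := vm.BlockContext{
            CanTransfer: core.CanTransfer,
            Transfer:    core.Transfer,
            Coinbase:    header.Coinbase,
            BlockNumber: new(big.Int).Set(header.Number),
            Time:        new(big.Int).SetUint64(header.Time),
            Difficulty:  new(big.Int).Set(header.Difficulty),
            GasLimit:    header.GasLimit,
        }
        // Per-transaction data (Origin, GasPrice) now lives in vm.TxContext
        // and is derived from the message.
        txCtx := core.NewEVMTxContext(msg)
        return vm.NewEVM(blockCtx, txCtx, statedb, chainConfig, vm.Config{})
    }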
@@ -143,10 +142,19 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, vmConfig.Tracer = tracer vmConfig.Debug = (tracer != nil) statedb.Prepare(tx.Hash(), blockHash, txIndex) - vmContext.GasPrice = msg.GasPrice() - vmContext.Origin = msg.From() + txContext := core.NewEVMTxContext(msg) - evm := vm.NewEVM(vmContext, statedb, chainConfig, vmConfig) + evm := vm.NewEVM(vmContext, txContext, statedb, chainConfig, vmConfig) + if chainConfig.IsYoloV2(vmContext.BlockNumber) { + statedb.AddAddressToAccessList(msg.From()) + if dst := msg.To(); dst != nil { + statedb.AddAddressToAccessList(*dst) + // If it's a create-tx, the destination will be added inside evm.create + } + for _, addr := range evm.ActivePrecompiles() { + statedb.AddAddressToAccessList(addr) + } + } snapshot := statedb.Snapshot() // (ret []byte, usedGas uint64, failed bool, err error) msgResult, err := core.ApplyMessage(evm, msg, gaspool) @@ -175,7 +183,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, receipt.GasUsed = msgResult.UsedGas // if the transaction created a contract, store the creation address in the receipt. if msg.To() == nil { - receipt.ContractAddress = crypto.CreateAddress(evm.Context.Origin, tx.Nonce()) + receipt.ContractAddress = crypto.CreateAddress(evm.TxContext.Origin, tx.Nonce()) } // Set the receipt logs and create a bloom for filtering receipt.Logs = statedb.GetLogs(tx.Hash()) diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index d0be6ca1e1e0..4063767cb8f4 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -38,7 +38,7 @@ import ( "github.com/ethereum/go-ethereum/core/vm/runtime" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" ) var runCommand = cli.Command{ diff --git a/cmd/evm/staterunner.go b/cmd/evm/staterunner.go index f9a6b06b8fe4..c4df936c75ff 100644 --- a/cmd/evm/staterunner.go +++ b/cmd/evm/staterunner.go @@ -28,7 +28,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/tests" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" ) var stateTestCommand = cli.Command{ diff --git a/cmd/faucet/faucet.go b/cmd/faucet/faucet.go index 7dc5536eba13..008cb142960f 100644 --- a/cmd/faucet/faucet.go +++ b/cmd/faucet/faucet.go @@ -43,6 +43,7 @@ import ( "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts/keystore" + "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" @@ -82,6 +83,8 @@ var ( noauthFlag = flag.Bool("noauth", false, "Enables funding requests without authentication") logFlag = flag.Int("loglevel", 3, "Log level to use for Ethereum and the faucet") + + twitterBearerToken = flag.String("twitter.token", "", "Twitter bearer token to authenticate with the twitter API") ) var ( @@ -241,6 +244,7 @@ func newFaucet(genesis *core.Genesis, port int, enodes []*discv5.Node, network u cfg.SyncMode = downloader.LightSync cfg.NetworkId = network cfg.Genesis = genesis + utils.SetDNSDiscoveryDefaults(&cfg, genesis.ToBlock(nil).Hash()) lesBackend, err := les.New(stack, &cfg) if err != nil { return nil, fmt.Errorf("Failed to register the Ethereum service: %w", err) @@ -441,6 +445,7 @@ func (f *faucet) apiHandler(w http.ResponseWriter, r *http.Request) { } // Retrieve the Ethereum address to fund, the requesting user and a profile picture var ( + id string username string 
avatar string
 address common.Address
@@ -460,11 +465,13 @@
 }
 continue
 case strings.HasPrefix(msg.URL, "https://twitter.com/"):
- username, avatar, address, err = authTwitter(msg.URL)
+ id, username, avatar, address, err = authTwitter(msg.URL, *twitterBearerToken)
 case strings.HasPrefix(msg.URL, "https://www.facebook.com/"):
 username, avatar, address, err = authFacebook(msg.URL)
+ id = username
 case *noauthFlag:
 username, avatar, address, err = authNoAuth(msg.URL)
+ id = username
 default:
 //lint:ignore ST1005 This error is to be displayed in the browser
 err = errors.New("Something funky happened, please open an issue at https://github.com/ethereum/go-ethereum/issues")
@@ -484,7 +491,7 @@
 fund bool
 timeout time.Time
 )
- if timeout = f.timeouts[username]; time.Now().After(timeout) {
+ if timeout = f.timeouts[id]; time.Now().After(timeout) {
 // User wasn't funded recently, create the funding transaction
 amount := new(big.Int).Mul(big.NewInt(int64(*payoutFlag)), ether)
 amount = new(big.Int).Mul(amount, new(big.Int).Exp(big.NewInt(5), big.NewInt(int64(msg.Tier)), nil))
@@ -509,16 +516,16 @@
 }
 continue
 }
- f.reqs = append(f.reqs, &request{
+ f.reqs = append([]*request{{
 Avatar: avatar,
 Account: address,
 Time: time.Now(),
 Tx: signed,
- })
+ }}, f.reqs...)
 timeout := time.Duration(*minutesFlag*int(math.Pow(3, float64(msg.Tier)))) * time.Minute
 grace := timeout / 288 // 24h timeout => 5m grace
- f.timeouts[username] = time.Now().Add(timeout - grace)
+ f.timeouts[id] = time.Now().Add(timeout - grace)
 fund = true
 }
 f.lock.Unlock()
@@ -682,23 +689,32 @@
 }
 
 // authTwitter tries to authenticate a faucet request using Twitter posts, returning
-// the username, avatar URL and Ethereum address to fund on success.
-func authTwitter(url string) (string, string, common.Address, error) {
+// the uniqueness identifier (user id/username), username, avatar URL and Ethereum address to fund on success.
+func authTwitter(url string, token string) (string, string, string, common.Address, error) {
 // Ensure the user specified a meaningful URL, no fancy nonsense
 parts := strings.Split(url, "/")
 if len(parts) < 4 || parts[len(parts)-2] != "status" {
 //lint:ignore ST1005 This error is to be displayed in the browser
- return "", "", common.Address{}, errors.New("Invalid Twitter status URL")
+ return "", "", "", common.Address{}, errors.New("Invalid Twitter status URL")
+ }
+
+ // Twitter's API isn't really friendly with direct links.
+ // It is restricted to 300 queries / 15 minutes with an app API key.
+ // Anything more would require read-only authorization from the users, which we want to avoid.
+
+ // If a Twitter bearer token is provided, use the Twitter API.
+ if token != "" {
+ return authTwitterWithToken(parts[len(parts)-1], token)
 }
- // Twitter's API isn't really friendly with direct links. Still, we don't
- // want to do ask read permissions from users, so just load the public posts
+
+ // A Twitter API token isn't provided, so we just load the public posts
 // and scrape it for the Ethereum address and profile URL. We need to load
 // the mobile page though since the main page loads tweet contents via JS.
url = strings.Replace(url, "https://twitter.com/", "https://mobile.twitter.com/", 1) res, err := http.Get(url) if err != nil { - return "", "", common.Address{}, err + return "", "", "", common.Address{}, err } defer res.Body.Close() @@ -706,31 +722,87 @@ func authTwitter(url string) (string, string, common.Address, error) { parts = strings.Split(res.Request.URL.String(), "/") if len(parts) < 4 || parts[len(parts)-2] != "status" { //lint:ignore ST1005 This error is to be displayed in the browser - return "", "", common.Address{}, errors.New("Invalid Twitter status URL") + return "", "", "", common.Address{}, errors.New("Invalid Twitter status URL") } username := parts[len(parts)-3] body, err := ioutil.ReadAll(res.Body) if err != nil { - return "", "", common.Address{}, err + return "", "", "", common.Address{}, err } address := common.HexToAddress(string(regexp.MustCompile("0x[0-9a-fA-F]{40}").Find(body))) if address == (common.Address{}) { //lint:ignore ST1005 This error is to be displayed in the browser - return "", "", common.Address{}, errors.New("No Ethereum address found to fund") + return "", "", "", common.Address{}, errors.New("No Ethereum address found to fund") } var avatar string if parts = regexp.MustCompile("src=\"([^\"]+twimg.com/profile_images[^\"]+)\"").FindStringSubmatch(string(body)); len(parts) == 2 { avatar = parts[1] } - return username + "@twitter", avatar, address, nil + return username + "@twitter", username, avatar, address, nil +} + +// authTwitterWithToken tries to authenticate a faucet request using Twitter's API, returning +// the uniqueness identifier (user id/username), username, avatar URL and Ethereum address to fund on success. +func authTwitterWithToken(tweetID string, token string) (string, string, string, common.Address, error) { + // Strip any query parameters from the tweet id + sanitizedTweetID := strings.Split(tweetID, "?")[0] + + // Ensure numeric tweetID + if !regexp.MustCompile("^[0-9]+$").MatchString(sanitizedTweetID) { + return "", "", "", common.Address{}, errors.New("Invalid Tweet URL") + } + + // Query the tweet details from Twitter + url := fmt.Sprintf("https://api.twitter.com/2/tweets/%s?expansions=author_id&user.fields=profile_image_url", sanitizedTweetID) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return "", "", "", common.Address{}, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + res, err := http.DefaultClient.Do(req) + if err != nil { + return "", "", "", common.Address{}, err + } + defer res.Body.Close() + + var result struct { + Data struct { + AuthorID string `json:"author_id"` + ID string `json:"id"` + Text string `json:"text"` + } `json:"data"` + Includes struct { + Users []struct { + ProfileImageURL string `json:"profile_image_url"` + Username string `json:"username"` + ID string `json:"id"` + Name string `json:"name"` + } `json:"users"` + } `json:"includes"` + } + + err = json.NewDecoder(res.Body).Decode(&result) + if err != nil { + return "", "", "", common.Address{}, err + } + + address := common.HexToAddress(regexp.MustCompile("0x[0-9a-fA-F]{40}").FindString(result.Data.Text)) + if address == (common.Address{}) { + //lint:ignore ST1005 This error is to be displayed in the browser + return "", "", "", common.Address{}, errors.New("No Ethereum address found to fund") + } + return result.Data.AuthorID + "@twitter", result.Includes.Users[0].Username, result.Includes.Users[0].ProfileImageURL, address, nil } // authFacebook tries to authenticate a faucet request using Facebook 
posts, // returning the username, avatar URL and Ethereum address to fund on success. func authFacebook(url string) (string, string, common.Address, error) { // Ensure the user specified a meaningful URL, no fancy nonsense - parts := strings.Split(url, "/") + parts := strings.Split(strings.Split(url, "?")[0], "/") + if parts[len(parts)-1] == "" { + parts = parts[0 : len(parts)-1] + } if len(parts) < 4 || parts[len(parts)-2] != "posts" { //lint:ignore ST1005 This error is to be displayed in the browser return "", "", common.Address{}, errors.New("Invalid Facebook post URL") diff --git a/cmd/geth/accountcmd_test.go b/cmd/geth/accountcmd_test.go index 6213e5195dea..04f55e9e7a2b 100644 --- a/cmd/geth/accountcmd_test.go +++ b/cmd/geth/accountcmd_test.go @@ -43,13 +43,13 @@ func tmpDatadirWithKeystore(t *testing.T) string { } func TestAccountListEmpty(t *testing.T) { - geth := runGeth(t, "account", "list") + geth := runGeth(t, "--nousb", "account", "list") geth.ExpectExit() } func TestAccountList(t *testing.T) { datadir := tmpDatadirWithKeystore(t) - geth := runGeth(t, "account", "list", "--datadir", datadir) + geth := runGeth(t, "--nousb", "account", "list", "--datadir", datadir) defer geth.ExpectExit() if runtime.GOOS == "windows" { geth.Expect(` @@ -102,6 +102,7 @@ func TestAccountImport(t *testing.T) { }, } for _, test := range tests { + test := test t.Run(test.name, func(t *testing.T) { t.Parallel() importAccountWithExpect(t, test.key, test.output) @@ -138,7 +139,7 @@ Fatal: Passwords do not match func TestAccountUpdate(t *testing.T) { datadir := tmpDatadirWithKeystore(t) - geth := runGeth(t, "account", "update", + geth := runGeth(t, "--nousb", "account", "update", "--datadir", datadir, "--lightkdf", "f466859ead1932d743d622cb74fc058882e8648a") defer geth.ExpectExit() @@ -153,7 +154,7 @@ Repeat password: {{.InputLine "foobar2"}} } func TestWalletImport(t *testing.T) { - geth := runGeth(t, "wallet", "import", "--lightkdf", "testdata/guswallet.json") + geth := runGeth(t, "--nousb", "wallet", "import", "--lightkdf", "testdata/guswallet.json") defer geth.ExpectExit() geth.Expect(` !! Unsupported terminal, password will be echoed. @@ -168,7 +169,7 @@ Address: {d4584b5f6229b7be90727b0fc8c6b91bb427821f} } func TestWalletImportBadPassword(t *testing.T) { - geth := runGeth(t, "wallet", "import", "--lightkdf", "testdata/guswallet.json") + geth := runGeth(t, "--nousb", "wallet", "import", "--lightkdf", "testdata/guswallet.json") defer geth.ExpectExit() geth.Expect(` !! Unsupported terminal, password will be echoed. @@ -178,11 +179,8 @@ Fatal: could not decrypt key with given password } func TestUnlockFlag(t *testing.T) { - datadir := tmpDatadirWithKeystore(t) - geth := runGeth(t, - "--datadir", datadir, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", - "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", - "js", "testdata/empty.js") + geth := runMinimalGeth(t, "--port", "0", "--ipcdisable", "--datadir", tmpDatadirWithKeystore(t), + "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "js", "testdata/empty.js") geth.Expect(` Unlocking account f466859ead1932d743d622cb74fc058882e8648a | Attempt 1/3 !! Unsupported terminal, password will be echoed. 
@@ -202,10 +200,9 @@ Password: {{.InputLine "foobar"}} } func TestUnlockFlagWrongPassword(t *testing.T) { - datadir := tmpDatadirWithKeystore(t) - geth := runGeth(t, - "--datadir", datadir, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", - "--unlock", "f466859ead1932d743d622cb74fc058882e8648a") + geth := runMinimalGeth(t, "--port", "0", "--ipcdisable", "--datadir", tmpDatadirWithKeystore(t), + "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "js", "testdata/empty.js") + defer geth.ExpectExit() geth.Expect(` Unlocking account f466859ead1932d743d622cb74fc058882e8648a | Attempt 1/3 @@ -221,11 +218,9 @@ Fatal: Failed to unlock account f466859ead1932d743d622cb74fc058882e8648a (could // https://github.com/ethereum/go-ethereum/issues/1785 func TestUnlockFlagMultiIndex(t *testing.T) { - datadir := tmpDatadirWithKeystore(t) - geth := runGeth(t, - "--datadir", datadir, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", - "--unlock", "0,2", - "js", "testdata/empty.js") + geth := runMinimalGeth(t, "--port", "0", "--ipcdisable", "--datadir", tmpDatadirWithKeystore(t), + "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "--unlock", "0,2", "js", "testdata/empty.js") + geth.Expect(` Unlocking account 0 | Attempt 1/3 !! Unsupported terminal, password will be echoed. @@ -248,11 +243,9 @@ Password: {{.InputLine "foobar"}} } func TestUnlockFlagPasswordFile(t *testing.T) { - datadir := tmpDatadirWithKeystore(t) - geth := runGeth(t, - "--datadir", datadir, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", - "--password", "testdata/passwords.txt", "--unlock", "0,2", - "js", "testdata/empty.js") + geth := runMinimalGeth(t, "--port", "0", "--ipcdisable", "--datadir", tmpDatadirWithKeystore(t), + "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "--password", "testdata/passwords.txt", "--unlock", "0,2", "js", "testdata/empty.js") + geth.ExpectExit() wantMessages := []string{ @@ -268,10 +261,9 @@ func TestUnlockFlagPasswordFile(t *testing.T) { } func TestUnlockFlagPasswordFileWrongPassword(t *testing.T) { - datadir := tmpDatadirWithKeystore(t) - geth := runGeth(t, - "--datadir", datadir, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", - "--password", "testdata/wrong-passwords.txt", "--unlock", "0,2") + geth := runMinimalGeth(t, "--port", "0", "--ipcdisable", "--datadir", tmpDatadirWithKeystore(t), + "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "--password", + "testdata/wrong-passwords.txt", "--unlock", "0,2") defer geth.ExpectExit() geth.Expect(` Fatal: Failed to unlock account 0 (could not decrypt key with given password) @@ -280,9 +272,9 @@ Fatal: Failed to unlock account 0 (could not decrypt key with given password) func TestUnlockFlagAmbiguous(t *testing.T) { store := filepath.Join("..", "..", "accounts", "keystore", "testdata", "dupes") - geth := runGeth(t, - "--keystore", store, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", - "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", + geth := runMinimalGeth(t, "--port", "0", "--ipcdisable", "--datadir", tmpDatadirWithKeystore(t), + "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "--keystore", + store, "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "js", "testdata/empty.js") defer geth.ExpectExit() @@ -318,9 +310,10 @@ In order to avoid this warning, you need to remove the following duplicate key f func TestUnlockFlagAmbiguousWrongPassword(t *testing.T) { store := filepath.Join("..", "..", "accounts", "keystore", "testdata", 
"dupes") - geth := runGeth(t, - "--keystore", store, "--nat", "none", "--nodiscover", "--maxpeers", "0", "--port", "0", - "--unlock", "f466859ead1932d743d622cb74fc058882e8648a") + geth := runMinimalGeth(t, "--port", "0", "--ipcdisable", "--datadir", tmpDatadirWithKeystore(t), + "--unlock", "f466859ead1932d743d622cb74fc058882e8648a", "--keystore", + store, "--unlock", "f466859ead1932d743d622cb74fc058882e8648a") + defer geth.ExpectExit() // Helper for the expect template, returns absolute keystore path. diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index b98597e307e6..6418f909579a 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -163,7 +163,7 @@ The export-preimages command export hash preimages to an RLP encoded stream`, utils.RinkebyFlag, utils.TxLookupLimitFlag, utils.GoerliFlag, - utils.YoloV1Flag, + utils.YoloV2Flag, utils.LegacyTestnetFlag, }, Category: "BLOCKCHAIN COMMANDS", @@ -213,7 +213,7 @@ Use "ethereum dump 0" to dump the genesis block.`, utils.RopstenFlag, utils.RinkebyFlag, utils.GoerliFlag, - utils.YoloV1Flag, + utils.YoloV2Flag, utils.LegacyTestnetFlag, utils.SyncModeFlag, }, diff --git a/cmd/geth/config.go b/cmd/geth/config.go index 6b51843aa456..bf1fc55b172a 100644 --- a/cmd/geth/config.go +++ b/cmd/geth/config.go @@ -24,7 +24,7 @@ import ( "reflect" "unicode" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/eth" diff --git a/cmd/geth/consolecmd.go b/cmd/geth/consolecmd.go index e2f733f844a4..cbecbe0a5fa7 100644 --- a/cmd/geth/consolecmd.go +++ b/cmd/geth/consolecmd.go @@ -136,8 +136,8 @@ func remoteConsole(ctx *cli.Context) error { path = filepath.Join(path, "rinkeby") } else if ctx.GlobalBool(utils.GoerliFlag.Name) { path = filepath.Join(path, "goerli") - } else if ctx.GlobalBool(utils.YoloV1Flag.Name) { - path = filepath.Join(path, "yolo-v1") + } else if ctx.GlobalBool(utils.YoloV2Flag.Name) { + path = filepath.Join(path, "yolo-v2") } } endpoint = fmt.Sprintf("%s/geth.ipc", path) diff --git a/cmd/geth/consolecmd_test.go b/cmd/geth/consolecmd_test.go index 6c100e18d9ee..b0555c45d7aa 100644 --- a/cmd/geth/consolecmd_test.go +++ b/cmd/geth/consolecmd_test.go @@ -35,16 +35,25 @@ const ( httpAPIs = "eth:1.0 net:1.0 rpc:1.0 web3:1.0" ) +// spawns geth with the given command line args, using a set of flags to minimise +// memory and disk IO. If the args don't set --datadir, the +// child g gets a temporary data directory. +func runMinimalGeth(t *testing.T, args ...string) *testgeth { + // --ropsten to make the 'writing genesis to disk' faster (no accounts) + // --networkid=1337 to avoid cache bump + // --syncmode=full to avoid allocating fast sync bloom + allArgs := []string{"--ropsten", "--nousb", "--networkid", "1337", "--syncmode=full", "--port", "0", + "--nat", "none", "--nodiscover", "--maxpeers", "0", "--cache", "64"} + return runGeth(t, append(allArgs, args...)...) +} + // Tests that a node embedded within a console can be started up properly and // then terminated by closing the input stream. 
func TestConsoleWelcome(t *testing.T) { coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" // Start a geth console, make sure it's cleaned up and terminate the console - geth := runGeth(t, - "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", - "--etherbase", coinbase, - "console") + geth := runMinimalGeth(t, "--etherbase", coinbase, "console") // Gather all the infos the welcome message needs to contain geth.SetTemplateFunc("goos", func() string { return runtime.GOOS }) @@ -66,16 +75,20 @@ at block: 0 ({{niltime}}) datadir: {{.Datadir}} modules: {{apis}} +To exit, press ctrl-d > {{.InputLine "exit"}} `) geth.ExpectExit() } // Tests that a console can be attached to a running node via various means. -func TestIPCAttachWelcome(t *testing.T) { +func TestAttachWelcome(t *testing.T) { + var ( + ipc string + httpPort string + wsPort string + ) // Configure the instance for IPC attachment - coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" - var ipc string if runtime.GOOS == "windows" { ipc = `\\.\pipe\geth` + strconv.Itoa(trulyRandInt(100000, 999999)) } else { @@ -83,51 +96,28 @@ func TestIPCAttachWelcome(t *testing.T) { defer os.RemoveAll(ws) ipc = filepath.Join(ws, "geth.ipc") } - geth := runGeth(t, - "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", - "--etherbase", coinbase, "--ipcpath", ipc) - - defer func() { - geth.Interrupt() - geth.ExpectExit() - }() - - waitForEndpoint(t, ipc, 3*time.Second) - testAttachWelcome(t, geth, "ipc:"+ipc, ipcAPIs) - -} - -func TestHTTPAttachWelcome(t *testing.T) { - coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" - port := strconv.Itoa(trulyRandInt(1024, 65536)) // Yeah, sometimes this will fail, sorry :P - geth := runGeth(t, - "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", - "--etherbase", coinbase, "--http", "--http.port", port) - defer func() { - geth.Interrupt() - geth.ExpectExit() - }() - - endpoint := "http://127.0.0.1:" + port - waitForEndpoint(t, endpoint, 3*time.Second) - testAttachWelcome(t, geth, endpoint, httpAPIs) -} - -func TestWSAttachWelcome(t *testing.T) { - coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" - port := strconv.Itoa(trulyRandInt(1024, 65536)) // Yeah, sometimes this will fail, sorry :P - - geth := runGeth(t, - "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", - "--etherbase", coinbase, "--ws", "--ws.port", port) - defer func() { - geth.Interrupt() - geth.ExpectExit() - }() - - endpoint := "ws://127.0.0.1:" + port - waitForEndpoint(t, endpoint, 3*time.Second) - testAttachWelcome(t, geth, endpoint, httpAPIs) + // And HTTP + WS attachment + p := trulyRandInt(1024, 65533) // Yeah, sometimes this will fail, sorry :P + httpPort = strconv.Itoa(p) + wsPort = strconv.Itoa(p + 1) + geth := runMinimalGeth(t, "--etherbase", "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182", + "--ipcpath", ipc, + "--http", "--http.port", httpPort, + "--ws", "--ws.port", wsPort) + t.Run("ipc", func(t *testing.T) { + waitForEndpoint(t, ipc, 3*time.Second) + testAttachWelcome(t, geth, "ipc:"+ipc, ipcAPIs) + }) + t.Run("http", func(t *testing.T) { + endpoint := "http://127.0.0.1:" + httpPort + waitForEndpoint(t, endpoint, 3*time.Second) + testAttachWelcome(t, geth, endpoint, httpAPIs) + }) + t.Run("ws", func(t *testing.T) { + endpoint := "ws://127.0.0.1:" + wsPort + waitForEndpoint(t, endpoint, 3*time.Second) + testAttachWelcome(t, geth, endpoint, httpAPIs) + }) } func testAttachWelcome(t *testing.T, geth *testgeth, endpoint, apis string) { @@ -159,6 +149,7 @@ at block: 
0 ({{niltime}}){{if ipc}} datadir: {{datadir}}{{end}} modules: {{apis}} +To exit, press ctrl-d > {{.InputLine "exit" }} `) attach.ExpectExit() diff --git a/cmd/geth/dao_test.go b/cmd/geth/dao_test.go index 6c36771e9726..df7f14fdb831 100644 --- a/cmd/geth/dao_test.go +++ b/cmd/geth/dao_test.go @@ -115,10 +115,10 @@ func testDAOForkBlockNewChain(t *testing.T, test int, genesis string, expectBloc if err := ioutil.WriteFile(json, []byte(genesis), 0600); err != nil { t.Fatalf("test %d: failed to write genesis file: %v", test, err) } - runGeth(t, "--datadir", datadir, "init", json).WaitExit() + runGeth(t, "--datadir", datadir, "--nousb", "--networkid", "1337", "init", json).WaitExit() } else { // Force chain initialization - args := []string{"--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", "--ipcdisable", "--datadir", datadir} + args := []string{"--port", "0", "--nousb", "--networkid", "1337", "--maxpeers", "0", "--nodiscover", "--nat", "none", "--ipcdisable", "--datadir", datadir} runGeth(t, append(args, []string{"--exec", "2+2", "console"}...)...).WaitExit() } // Retrieve the DAO config flag from the database diff --git a/cmd/geth/genesis_test.go b/cmd/geth/genesis_test.go index ee3991acd1bc..0651c32cad00 100644 --- a/cmd/geth/genesis_test.go +++ b/cmd/geth/genesis_test.go @@ -84,7 +84,7 @@ func TestCustomGenesis(t *testing.T) { runGeth(t, "--nousb", "--datadir", datadir, "init", json).WaitExit() // Query the custom genesis block - geth := runGeth(t, "--nousb", + geth := runGeth(t, "--nousb", "--networkid", "1337", "--syncmode=full", "--datadir", datadir, "--maxpeers", "0", "--port", "0", "--nodiscover", "--nat", "none", "--ipcdisable", "--exec", tt.query, "console") diff --git a/cmd/geth/les_test.go b/cmd/geth/les_test.go index 259d4a806737..d2f63ac7bde5 100644 --- a/cmd/geth/les_test.go +++ b/cmd/geth/les_test.go @@ -2,7 +2,12 @@ package main import ( "context" + "fmt" + "os" "path/filepath" + "runtime" + "strings" + "sync/atomic" "testing" "time" @@ -95,24 +100,57 @@ func (g *gethrpc) waitSynced() { } } +// ipcEndpoint resolves an IPC endpoint based on a configured value, taking into +// account the set data folders as well as the designated platform we're currently +// running on. +func ipcEndpoint(ipcPath, datadir string) string { + // On windows we can only use plain top-level pipes + if runtime.GOOS == "windows" { + if strings.HasPrefix(ipcPath, `\\.\pipe\`) { + return ipcPath + } + return `\\.\pipe\` + ipcPath + } + // Resolve names into the data directory full paths otherwise + if filepath.Base(ipcPath) == ipcPath { + if datadir == "" { + return filepath.Join(os.TempDir(), ipcPath) + } + return filepath.Join(datadir, ipcPath) + } + return ipcPath +} + +// nextIPC ensures that each ipc pipe gets a unique name. +// On linux, it works well to use ipc pipes all over the filesystem (in datadirs), +// but windows require pipes to sit in "\\.\pipe\". Therefore, to run several +// nodes simultaneously, we need to distinguish between them, which we do by +// the pipe filename instead of folder. +var nextIPC = uint32(0) + func startGethWithIpc(t *testing.T, name string, args ...string) *gethrpc { - g := &gethrpc{name: name} - args = append([]string{"--networkid=42", "--port=0", "--nousb"}, args...) + ipcName := fmt.Sprintf("geth-%d.ipc", atomic.AddUint32(&nextIPC, 1)) + args = append([]string{"--networkid=42", "--port=0", "--nousb", "--ipcpath", ipcName}, args...) t.Logf("Starting %v with rpc: %v", name, args) - g.geth = runGeth(t, args...) 
+ + g := &gethrpc{ + name: name, + geth: runGeth(t, args...), + } // wait before we can attach to it. TODO: probe for it properly time.Sleep(1 * time.Second) var err error - ipcpath := filepath.Join(g.geth.Datadir, "geth.ipc") - g.rpc, err = rpc.Dial(ipcpath) - if err != nil { - t.Fatalf("%v rpc connect: %v", name, err) + ipcpath := ipcEndpoint(ipcName, g.geth.Datadir) + if g.rpc, err = rpc.Dial(ipcpath); err != nil { + t.Fatalf("%v rpc connect to %v: %v", name, ipcpath, err) } return g } func initGeth(t *testing.T) string { - g := runGeth(t, "--nousb", "--networkid=42", "init", "./testdata/clique.json") + args := []string{"--nousb", "--networkid=42", "init", "./testdata/clique.json"} + t.Logf("Initializing geth: %v ", args) + g := runGeth(t, args...) datadir := g.Datadir g.WaitExit() return datadir @@ -120,15 +158,16 @@ func initGeth(t *testing.T) string { func startLightServer(t *testing.T) *gethrpc { datadir := initGeth(t) - runGeth(t, "--nousb", "--datadir", datadir, "--password", "./testdata/password.txt", "account", "import", "./testdata/key.prv").WaitExit() + t.Logf("Importing keys to geth") + runGeth(t, "--nousb", "--datadir", datadir, "--password", "./testdata/password.txt", "account", "import", "./testdata/key.prv", "--lightkdf").WaitExit() account := "0x02f0d131f1f97aef08aec6e3291b957d9efe7105" - server := startGethWithIpc(t, "lightserver", "--allow-insecure-unlock", "--datadir", datadir, "--password", "./testdata/password.txt", "--unlock", account, "--mine", "--light.serve=100", "--light.maxpeers=1", "--nodiscover", "--nat=extip:127.0.0.1") + server := startGethWithIpc(t, "lightserver", "--allow-insecure-unlock", "--datadir", datadir, "--password", "./testdata/password.txt", "--unlock", account, "--mine", "--light.serve=100", "--light.maxpeers=1", "--nodiscover", "--nat=extip:127.0.0.1", "--verbosity=4") return server } func startClient(t *testing.T, name string) *gethrpc { datadir := initGeth(t) - return startGethWithIpc(t, name, "--datadir", datadir, "--nodiscover", "--syncmode=light", "--nat=extip:127.0.0.1") + return startGethWithIpc(t, name, "--datadir", datadir, "--nodiscover", "--syncmode=light", "--nat=extip:127.0.0.1", "--verbosity=4") } func TestPriorityClient(t *testing.T) { @@ -151,7 +190,7 @@ func TestPriorityClient(t *testing.T) { prioCli := startClient(t, "prioCli") defer prioCli.killAndWait() // 3_000_000_000 once we move to Go 1.13 - tokens := 3000000000 + tokens := uint64(3000000000) lightServer.callRPC(nil, "les_addBalance", prioCli.getNodeInfo().ID, tokens) prioCli.addPeer(lightServer) diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 677a19eed2cd..e587a5f86dd6 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -42,7 +42,7 @@ import ( "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/node" gopsutil "github.com/shirou/gopsutil/mem" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" ) const ( @@ -113,6 +113,7 @@ var ( utils.CacheGCFlag, utils.CacheSnapshotFlag, utils.CacheNoPrefetchFlag, + utils.CachePreimagesFlag, utils.ListenPortFlag, utils.MaxPeersFlag, utils.MaxPendingPeersFlag, @@ -144,7 +145,7 @@ var ( utils.RopstenFlag, utils.RinkebyFlag, utils.GoerliFlag, - utils.YoloV1Flag, + utils.YoloV2Flag, utils.VMEnableDebugFlag, utils.NetworkIdFlag, utils.EthStatsURLFlag, @@ -188,8 +189,8 @@ var ( utils.IPCDisabledFlag, utils.IPCPathFlag, utils.InsecureUnlockAllowedFlag, - utils.RPCGlobalGasCap, - utils.RPCGlobalTxFeeCap, + utils.RPCGlobalGasCapFlag, + utils.RPCGlobalTxFeeCapFlag, } whisperFlags = []cli.Flag{ @@ -241,11 
+242,10 @@ func init() { makecacheCommand, makedagCommand, versionCommand, + versionCheckCommand, licenseCommand, // See config.go dumpConfigCommand, - // See retesteth.go - retestethCommand, // See cmd/utils/flags_legacy.go utils.ShowDeprecated, } diff --git a/cmd/geth/misccmd.go b/cmd/geth/misccmd.go index 0e7ee965133c..b347d31d97e1 100644 --- a/cmd/geth/misccmd.go +++ b/cmd/geth/misccmd.go @@ -25,12 +25,23 @@ import ( "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/consensus/ethash" - "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/params" "gopkg.in/urfave/cli.v1" ) var ( + VersionCheckUrlFlag = cli.StringFlag{ + Name: "check.url", + Usage: "URL to use when checking vulnerabilities", + Value: "https://geth.ethereum.org/docs/vulnerabilities/vulnerabilities.json", + } + VersionCheckVersionFlag = cli.StringFlag{ + Name: "check.version", + Usage: "Version to check", + Value: fmt.Sprintf("Geth/v%v/%v-%v/%v", + params.VersionWithCommit(gitCommit, gitDate), + runtime.GOOS, runtime.GOARCH, runtime.Version()), + } makecacheCommand = cli.Command{ Action: utils.MigrateFlags(makecache), Name: "makecache", @@ -65,6 +76,21 @@ Regular users do not need to execute it. Category: "MISCELLANEOUS COMMANDS", Description: ` The output of this command is supposed to be machine-readable. +`, + } + versionCheckCommand = cli.Command{ + Action: utils.MigrateFlags(versionCheck), + Flags: []cli.Flag{ + VersionCheckUrlFlag, + VersionCheckVersionFlag, + }, + Name: "version-check", + Usage: "Checks (online) whether the current version suffers from any known security vulnerabilities", + ArgsUsage: "", + Category: "MISCELLANEOUS COMMANDS", + Description: ` +The version-check command fetches vulnerability-information from https://geth.ethereum.org/docs/vulnerabilities/vulnerabilities.json, +and displays information about any security vulnerabilities that affect the currently executing version. `, } licenseCommand = cli.Command{ @@ -116,7 +142,6 @@ func version(ctx *cli.Context) error { fmt.Println("Git Commit Date:", gitDate) } fmt.Println("Architecture:", runtime.GOARCH) - fmt.Println("Protocol Versions:", eth.ProtocolVersions) fmt.Println("Go Version:", runtime.Version()) fmt.Println("Operating System:", runtime.GOOS) fmt.Printf("GOPATH=%s\n", os.Getenv("GOPATH")) diff --git a/cmd/geth/retesteth.go b/cmd/geth/retesteth.go deleted file mode 100644 index 1d4c15d1e60a..000000000000 --- a/cmd/geth/retesteth.go +++ /dev/null @@ -1,927 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of go-ethereum. -// -// go-ethereum is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// go-ethereum is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with go-ethereum. If not, see . 
- -package main - -import ( - "bytes" - "context" - "fmt" - "math/big" - "os" - "os/signal" - "strings" - "time" - - "github.com/ethereum/go-ethereum/cmd/utils" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/common/math" - "github.com/ethereum/go-ethereum/consensus" - "github.com/ethereum/go-ethereum/consensus/ethash" - "github.com/ethereum/go-ethereum/consensus/misc" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/core/vm" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/node" - "github.com/ethereum/go-ethereum/params" - "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/rpc" - "github.com/ethereum/go-ethereum/trie" - - cli "gopkg.in/urfave/cli.v1" -) - -var ( - rpcPortFlag = cli.IntFlag{ - Name: "rpcport", - Usage: "HTTP-RPC server listening port", - Value: node.DefaultHTTPPort, - } - retestethCommand = cli.Command{ - Action: utils.MigrateFlags(retesteth), - Name: "retesteth", - Usage: "Launches geth in retesteth mode", - ArgsUsage: "", - Flags: []cli.Flag{rpcPortFlag}, - Category: "MISCELLANEOUS COMMANDS", - Description: `Launches geth in retesteth mode (no database, no network, only retesteth RPC interface)`, - } -) - -type RetestethTestAPI interface { - SetChainParams(ctx context.Context, chainParams ChainParams) (bool, error) - MineBlocks(ctx context.Context, number uint64) (bool, error) - ModifyTimestamp(ctx context.Context, interval uint64) (bool, error) - ImportRawBlock(ctx context.Context, rawBlock hexutil.Bytes) (common.Hash, error) - RewindToBlock(ctx context.Context, number uint64) (bool, error) - GetLogHash(ctx context.Context, txHash common.Hash) (common.Hash, error) -} - -type RetestethEthAPI interface { - SendRawTransaction(ctx context.Context, rawTx hexutil.Bytes) (common.Hash, error) - BlockNumber(ctx context.Context) (uint64, error) - GetBlockByNumber(ctx context.Context, blockNr math.HexOrDecimal64, fullTx bool) (map[string]interface{}, error) - GetBlockByHash(ctx context.Context, blockHash common.Hash, fullTx bool) (map[string]interface{}, error) - GetBalance(ctx context.Context, address common.Address, blockNr math.HexOrDecimal64) (*math.HexOrDecimal256, error) - GetCode(ctx context.Context, address common.Address, blockNr math.HexOrDecimal64) (hexutil.Bytes, error) - GetTransactionCount(ctx context.Context, address common.Address, blockNr math.HexOrDecimal64) (uint64, error) -} - -type RetestethDebugAPI interface { - AccountRange(ctx context.Context, - blockHashOrNumber *math.HexOrDecimal256, txIndex uint64, - addressHash *math.HexOrDecimal256, maxResults uint64, - ) (AccountRangeResult, error) - StorageRangeAt(ctx context.Context, - blockHashOrNumber *math.HexOrDecimal256, txIndex uint64, - address common.Address, - begin *math.HexOrDecimal256, maxResults uint64, - ) (StorageRangeResult, error) -} - -type RetestWeb3API interface { - ClientVersion(ctx context.Context) (string, error) -} - -type RetestethAPI struct { - ethDb ethdb.Database - db state.Database - chainConfig *params.ChainConfig - author common.Address - extraData []byte - genesisHash common.Hash - engine *NoRewardEngine - blockchain *core.BlockChain - txMap map[common.Address]map[uint64]*types.Transaction // Sender -> Nonce 
-> Transaction - txSenders map[common.Address]struct{} // Set of transaction senders - blockInterval uint64 -} - -type ChainParams struct { - SealEngine string `json:"sealEngine"` - Params CParamsParams `json:"params"` - Genesis CParamsGenesis `json:"genesis"` - Accounts map[common.Address]CParamsAccount `json:"accounts"` -} - -type CParamsParams struct { - AccountStartNonce math.HexOrDecimal64 `json:"accountStartNonce"` - HomesteadForkBlock *math.HexOrDecimal64 `json:"homesteadForkBlock"` - EIP150ForkBlock *math.HexOrDecimal64 `json:"EIP150ForkBlock"` - EIP158ForkBlock *math.HexOrDecimal64 `json:"EIP158ForkBlock"` - DaoHardforkBlock *math.HexOrDecimal64 `json:"daoHardforkBlock"` - ByzantiumForkBlock *math.HexOrDecimal64 `json:"byzantiumForkBlock"` - ConstantinopleForkBlock *math.HexOrDecimal64 `json:"constantinopleForkBlock"` - ConstantinopleFixForkBlock *math.HexOrDecimal64 `json:"constantinopleFixForkBlock"` - IstanbulBlock *math.HexOrDecimal64 `json:"istanbulForkBlock"` - ChainID *math.HexOrDecimal256 `json:"chainID"` - MaximumExtraDataSize math.HexOrDecimal64 `json:"maximumExtraDataSize"` - TieBreakingGas bool `json:"tieBreakingGas"` - MinGasLimit math.HexOrDecimal64 `json:"minGasLimit"` - MaxGasLimit math.HexOrDecimal64 `json:"maxGasLimit"` - GasLimitBoundDivisor math.HexOrDecimal64 `json:"gasLimitBoundDivisor"` - MinimumDifficulty math.HexOrDecimal256 `json:"minimumDifficulty"` - DifficultyBoundDivisor math.HexOrDecimal256 `json:"difficultyBoundDivisor"` - DurationLimit math.HexOrDecimal256 `json:"durationLimit"` - BlockReward math.HexOrDecimal256 `json:"blockReward"` - NetworkID math.HexOrDecimal256 `json:"networkID"` -} - -type CParamsGenesis struct { - Nonce math.HexOrDecimal64 `json:"nonce"` - Difficulty *math.HexOrDecimal256 `json:"difficulty"` - MixHash *math.HexOrDecimal256 `json:"mixHash"` - Author common.Address `json:"author"` - Timestamp math.HexOrDecimal64 `json:"timestamp"` - ParentHash common.Hash `json:"parentHash"` - ExtraData hexutil.Bytes `json:"extraData"` - GasLimit math.HexOrDecimal64 `json:"gasLimit"` -} - -type CParamsAccount struct { - Balance *math.HexOrDecimal256 `json:"balance"` - Precompiled *CPAccountPrecompiled `json:"precompiled"` - Code hexutil.Bytes `json:"code"` - Storage map[string]string `json:"storage"` - Nonce *math.HexOrDecimal64 `json:"nonce"` -} - -type CPAccountPrecompiled struct { - Name string `json:"name"` - StartingBlock math.HexOrDecimal64 `json:"startingBlock"` - Linear *CPAPrecompiledLinear `json:"linear"` -} - -type CPAPrecompiledLinear struct { - Base uint64 `json:"base"` - Word uint64 `json:"word"` -} - -type AccountRangeResult struct { - AddressMap map[common.Hash]common.Address `json:"addressMap"` - NextKey common.Hash `json:"nextKey"` -} - -type StorageRangeResult struct { - Complete bool `json:"complete"` - Storage map[common.Hash]SRItem `json:"storage"` -} - -type SRItem struct { - Key string `json:"key"` - Value string `json:"value"` -} - -type NoRewardEngine struct { - inner consensus.Engine - rewardsOn bool -} - -func (e *NoRewardEngine) Author(header *types.Header) (common.Address, error) { - return e.inner.Author(header) -} - -func (e *NoRewardEngine) VerifyHeader(chain consensus.ChainHeaderReader, header *types.Header, seal bool) error { - return e.inner.VerifyHeader(chain, header, seal) -} - -func (e *NoRewardEngine) VerifyHeaders(chain consensus.ChainHeaderReader, headers []*types.Header, seals []bool) (chan<- struct{}, <-chan error) { - return e.inner.VerifyHeaders(chain, headers, seals) -} - -func (e *NoRewardEngine) 
VerifyUncles(chain consensus.ChainReader, block *types.Block) error { - return e.inner.VerifyUncles(chain, block) -} - -func (e *NoRewardEngine) VerifySeal(chain consensus.ChainHeaderReader, header *types.Header) error { - return e.inner.VerifySeal(chain, header) -} - -func (e *NoRewardEngine) Prepare(chain consensus.ChainHeaderReader, header *types.Header) error { - return e.inner.Prepare(chain, header) -} - -func (e *NoRewardEngine) accumulateRewards(config *params.ChainConfig, state *state.StateDB, header *types.Header, uncles []*types.Header) { - // Simply touch miner and uncle coinbase accounts - reward := big.NewInt(0) - for _, uncle := range uncles { - state.AddBalance(uncle.Coinbase, reward) - } - state.AddBalance(header.Coinbase, reward) -} - -func (e *NoRewardEngine) Finalize(chain consensus.ChainHeaderReader, header *types.Header, statedb *state.StateDB, txs []*types.Transaction, - uncles []*types.Header) { - if e.rewardsOn { - e.inner.Finalize(chain, header, statedb, txs, uncles) - } else { - e.accumulateRewards(chain.Config(), statedb, header, uncles) - header.Root = statedb.IntermediateRoot(chain.Config().IsEIP158(header.Number)) - } -} - -func (e *NoRewardEngine) FinalizeAndAssemble(chain consensus.ChainHeaderReader, header *types.Header, statedb *state.StateDB, txs []*types.Transaction, - uncles []*types.Header, receipts []*types.Receipt) (*types.Block, error) { - if e.rewardsOn { - return e.inner.FinalizeAndAssemble(chain, header, statedb, txs, uncles, receipts) - } else { - e.accumulateRewards(chain.Config(), statedb, header, uncles) - header.Root = statedb.IntermediateRoot(chain.Config().IsEIP158(header.Number)) - - // Header seems complete, assemble into a block and return - return types.NewBlock(header, txs, uncles, receipts, new(trie.Trie)), nil - } -} - -func (e *NoRewardEngine) Seal(chain consensus.ChainHeaderReader, block *types.Block, results chan<- *types.Block, stop <-chan struct{}) error { - return e.inner.Seal(chain, block, results, stop) -} - -func (e *NoRewardEngine) SealHash(header *types.Header) common.Hash { - return e.inner.SealHash(header) -} - -func (e *NoRewardEngine) CalcDifficulty(chain consensus.ChainHeaderReader, time uint64, parent *types.Header) *big.Int { - return e.inner.CalcDifficulty(chain, time, parent) -} - -func (e *NoRewardEngine) APIs(chain consensus.ChainHeaderReader) []rpc.API { - return e.inner.APIs(chain) -} - -func (e *NoRewardEngine) Close() error { - return e.inner.Close() -} - -func (api *RetestethAPI) SetChainParams(ctx context.Context, chainParams ChainParams) (bool, error) { - // Clean up - if api.blockchain != nil { - api.blockchain.Stop() - } - if api.engine != nil { - api.engine.Close() - } - if api.ethDb != nil { - api.ethDb.Close() - } - ethDb := rawdb.NewMemoryDatabase() - accounts := make(core.GenesisAlloc) - for address, account := range chainParams.Accounts { - balance := big.NewInt(0) - if account.Balance != nil { - balance.Set((*big.Int)(account.Balance)) - } - var nonce uint64 - if account.Nonce != nil { - nonce = uint64(*account.Nonce) - } - if account.Precompiled == nil || account.Balance != nil { - storage := make(map[common.Hash]common.Hash) - for k, v := range account.Storage { - storage[common.HexToHash(k)] = common.HexToHash(v) - } - accounts[address] = core.GenesisAccount{ - Balance: balance, - Code: account.Code, - Nonce: nonce, - Storage: storage, - } - } - } - chainId := big.NewInt(1) - if chainParams.Params.ChainID != nil { - chainId.Set((*big.Int)(chainParams.Params.ChainID)) - } - var ( - 
homesteadBlock *big.Int
- daoForkBlock *big.Int
- eip150Block *big.Int
- eip155Block *big.Int
- eip158Block *big.Int
- byzantiumBlock *big.Int
- constantinopleBlock *big.Int
- petersburgBlock *big.Int
- istanbulBlock *big.Int
- )
- if chainParams.Params.HomesteadForkBlock != nil {
- homesteadBlock = big.NewInt(int64(*chainParams.Params.HomesteadForkBlock))
- }
- if chainParams.Params.DaoHardforkBlock != nil {
- daoForkBlock = big.NewInt(int64(*chainParams.Params.DaoHardforkBlock))
- }
- if chainParams.Params.EIP150ForkBlock != nil {
- eip150Block = big.NewInt(int64(*chainParams.Params.EIP150ForkBlock))
- }
- if chainParams.Params.EIP158ForkBlock != nil {
- eip158Block = big.NewInt(int64(*chainParams.Params.EIP158ForkBlock))
- eip155Block = eip158Block
- }
- if chainParams.Params.ByzantiumForkBlock != nil {
- byzantiumBlock = big.NewInt(int64(*chainParams.Params.ByzantiumForkBlock))
- }
- if chainParams.Params.ConstantinopleForkBlock != nil {
- constantinopleBlock = big.NewInt(int64(*chainParams.Params.ConstantinopleForkBlock))
- }
- if chainParams.Params.ConstantinopleFixForkBlock != nil {
- petersburgBlock = big.NewInt(int64(*chainParams.Params.ConstantinopleFixForkBlock))
- }
- if constantinopleBlock != nil && petersburgBlock == nil {
- petersburgBlock = big.NewInt(100000000000)
- }
- if chainParams.Params.IstanbulBlock != nil {
- istanbulBlock = big.NewInt(int64(*chainParams.Params.IstanbulBlock))
- }
-
- genesis := &core.Genesis{
- Config: &params.ChainConfig{
- ChainID: chainId,
- HomesteadBlock: homesteadBlock,
- DAOForkBlock: daoForkBlock,
- DAOForkSupport: true,
- EIP150Block: eip150Block,
- EIP155Block: eip155Block,
- EIP158Block: eip158Block,
- ByzantiumBlock: byzantiumBlock,
- ConstantinopleBlock: constantinopleBlock,
- PetersburgBlock: petersburgBlock,
- IstanbulBlock: istanbulBlock,
- },
- Nonce: uint64(chainParams.Genesis.Nonce),
- Timestamp: uint64(chainParams.Genesis.Timestamp),
- ExtraData: chainParams.Genesis.ExtraData,
- GasLimit: uint64(chainParams.Genesis.GasLimit),
- Difficulty: big.NewInt(0).Set((*big.Int)(chainParams.Genesis.Difficulty)),
- Mixhash: common.BigToHash((*big.Int)(chainParams.Genesis.MixHash)),
- Coinbase: chainParams.Genesis.Author,
- ParentHash: chainParams.Genesis.ParentHash,
- Alloc: accounts,
- }
- chainConfig, genesisHash, err := core.SetupGenesisBlock(ethDb, genesis)
- if err != nil {
- return false, err
- }
- fmt.Printf("Chain config: %v\n", chainConfig)
-
- var inner consensus.Engine
- switch chainParams.SealEngine {
- case "NoProof", "NoReward":
- inner = ethash.NewFaker()
- case "Ethash":
- inner = ethash.New(ethash.Config{
- CacheDir: "ethash",
- CachesInMem: 2,
- CachesOnDisk: 3,
- CachesLockMmap: false,
- DatasetsInMem: 1,
- DatasetsOnDisk: 2,
- DatasetsLockMmap: false,
- }, nil, false)
- default:
- return false, fmt.Errorf("unrecognised seal engine: %s", chainParams.SealEngine)
- }
- engine := &NoRewardEngine{inner: inner, rewardsOn: chainParams.SealEngine != "NoReward"}
-
- blockchain, err := core.NewBlockChain(ethDb, nil, chainConfig, engine, vm.Config{}, nil, nil)
- if err != nil {
- return false, err
- }
-
- api.chainConfig = chainConfig
- api.genesisHash = genesisHash
- api.author = chainParams.Genesis.Author
- api.extraData = chainParams.Genesis.ExtraData
- api.ethDb = ethDb
- api.engine = engine
- api.blockchain = blockchain
- api.db = state.NewDatabase(api.ethDb)
- api.txMap = make(map[common.Address]map[uint64]*types.Transaction)
- api.txSenders = make(map[common.Address]struct{})
- api.blockInterval = 0
- return true, nil
-}
-
-func
(api *RetestethAPI) SendRawTransaction(ctx context.Context, rawTx hexutil.Bytes) (common.Hash, error) { - tx := new(types.Transaction) - if err := rlp.DecodeBytes(rawTx, tx); err != nil { - // Returning nil here is deliberate - some tests send transactions whose gasLimit overflows uint64 - return common.Hash{}, nil - } - signer := types.MakeSigner(api.chainConfig, big.NewInt(int64(api.currentNumber()))) - sender, err := types.Sender(signer, tx) - if err != nil { - return common.Hash{}, err - } - if nonceMap, ok := api.txMap[sender]; ok { - nonceMap[tx.Nonce()] = tx - } else { - nonceMap = make(map[uint64]*types.Transaction) - nonceMap[tx.Nonce()] = tx - api.txMap[sender] = nonceMap - } - api.txSenders[sender] = struct{}{} - return tx.Hash(), nil -} - -func (api *RetestethAPI) MineBlocks(ctx context.Context, number uint64) (bool, error) { - for i := 0; i < int(number); i++ { - if err := api.mineBlock(); err != nil { - return false, err - } - } - fmt.Printf("Mined %d blocks\n", number) - return true, nil -} - -func (api *RetestethAPI) currentNumber() uint64 { - if current := api.blockchain.CurrentBlock(); current != nil { - return current.NumberU64() - } - return 0 -} - -func (api *RetestethAPI) mineBlock() error { - number := api.currentNumber() - parentHash := rawdb.ReadCanonicalHash(api.ethDb, number) - parent := rawdb.ReadBlock(api.ethDb, parentHash, number) - var timestamp uint64 - if api.blockInterval == 0 { - timestamp = uint64(time.Now().Unix()) - } else { - timestamp = parent.Time() + api.blockInterval - } - gasLimit := core.CalcGasLimit(parent, 9223372036854775807, 9223372036854775807) - header := &types.Header{ - ParentHash: parent.Hash(), - Number: big.NewInt(int64(number + 1)), - GasLimit: gasLimit, - Extra: api.extraData, - Time: timestamp, - } - header.Coinbase = api.author - if api.engine != nil { - api.engine.Prepare(api.blockchain, header) - } - // If we care about TheDAO hard-fork, check whether to override the extra-data or not - if daoBlock := api.chainConfig.DAOForkBlock; daoBlock != nil { - // Check whether the block is among the fork extra-override range - limit := new(big.Int).Add(daoBlock, params.DAOForkExtraRange) - if header.Number.Cmp(daoBlock) >= 0 && header.Number.Cmp(limit) < 0 { - // Depending on whether we support or oppose the fork, override differently - if api.chainConfig.DAOForkSupport { - header.Extra = common.CopyBytes(params.DAOForkBlockExtra) - } else if bytes.Equal(header.Extra, params.DAOForkBlockExtra) { - header.Extra = []byte{} // If miner opposes, don't let it use the reserved extra-data - } - } - } - statedb, err := api.blockchain.StateAt(parent.Root()) - if err != nil { - return err - } - if api.chainConfig.DAOForkSupport && api.chainConfig.DAOForkBlock != nil && api.chainConfig.DAOForkBlock.Cmp(header.Number) == 0 { - misc.ApplyDAOHardFork(statedb) - } - gasPool := new(core.GasPool).AddGas(header.GasLimit) - txCount := 0 - var txs []*types.Transaction - var receipts []*types.Receipt - var blockFull = gasPool.Gas() < params.TxGas - for address := range api.txSenders { - if blockFull { - break - } - m := api.txMap[address] - for nonce := statedb.GetNonce(address); ; nonce++ { - if tx, ok := m[nonce]; ok { - // Try to apply transactions to the state - statedb.Prepare(tx.Hash(), common.Hash{}, txCount) - snap := statedb.Snapshot() - - receipt, err := core.ApplyTransaction( - api.chainConfig, - api.blockchain, - &api.author, - gasPool, - statedb, - header, tx, &header.GasUsed, *api.blockchain.GetVMConfig(), - ) - if err != nil { - 
statedb.RevertToSnapshot(snap) - break - } - txs = append(txs, tx) - receipts = append(receipts, receipt) - delete(m, nonce) - if len(m) == 0 { - // Last tx for the sender - delete(api.txMap, address) - delete(api.txSenders, address) - } - txCount++ - if gasPool.Gas() < params.TxGas { - blockFull = true - break - } - } else { - break // Gap in the nonces - } - } - } - block, err := api.engine.FinalizeAndAssemble(api.blockchain, header, statedb, txs, []*types.Header{}, receipts) - if err != nil { - return err - } - return api.importBlock(block) -} - -func (api *RetestethAPI) importBlock(block *types.Block) error { - if _, err := api.blockchain.InsertChain([]*types.Block{block}); err != nil { - return err - } - fmt.Printf("Imported block %d, head is %d\n", block.NumberU64(), api.currentNumber()) - return nil -} - -func (api *RetestethAPI) ModifyTimestamp(ctx context.Context, interval uint64) (bool, error) { - api.blockInterval = interval - return true, nil -} - -func (api *RetestethAPI) ImportRawBlock(ctx context.Context, rawBlock hexutil.Bytes) (common.Hash, error) { - block := new(types.Block) - if err := rlp.DecodeBytes(rawBlock, block); err != nil { - return common.Hash{}, err - } - fmt.Printf("Importing block %d with parent hash: %x, genesisHash: %x\n", block.NumberU64(), block.ParentHash(), api.genesisHash) - if err := api.importBlock(block); err != nil { - return common.Hash{}, err - } - return block.Hash(), nil -} - -func (api *RetestethAPI) RewindToBlock(ctx context.Context, newHead uint64) (bool, error) { - if err := api.blockchain.SetHead(newHead); err != nil { - return false, err - } - // When we rewind, the transaction pool should be cleaned out. - api.txMap = make(map[common.Address]map[uint64]*types.Transaction) - api.txSenders = make(map[common.Address]struct{}) - return true, nil -} - -var emptyListHash common.Hash = common.HexToHash("0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347") - -func (api *RetestethAPI) GetLogHash(ctx context.Context, txHash common.Hash) (common.Hash, error) { - receipt, _, _, _ := rawdb.ReadReceipt(api.ethDb, txHash, api.chainConfig) - if receipt == nil { - return emptyListHash, nil - } else { - if logListRlp, err := rlp.EncodeToBytes(receipt.Logs); err != nil { - return common.Hash{}, err - } else { - return common.BytesToHash(crypto.Keccak256(logListRlp)), nil - } - } -} - -func (api *RetestethAPI) BlockNumber(ctx context.Context) (uint64, error) { - return api.currentNumber(), nil -} - -func (api *RetestethAPI) GetBlockByNumber(ctx context.Context, blockNr math.HexOrDecimal64, fullTx bool) (map[string]interface{}, error) { - block := api.blockchain.GetBlockByNumber(uint64(blockNr)) - if block != nil { - response, err := RPCMarshalBlock(block, true, fullTx) - if err != nil { - return nil, err - } - response["author"] = response["miner"] - response["totalDifficulty"] = (*hexutil.Big)(api.blockchain.GetTd(block.Hash(), uint64(blockNr))) - return response, err - } - return nil, fmt.Errorf("block %d not found", blockNr) -} - -func (api *RetestethAPI) GetBlockByHash(ctx context.Context, blockHash common.Hash, fullTx bool) (map[string]interface{}, error) { - block := api.blockchain.GetBlockByHash(blockHash) - if block != nil { - response, err := RPCMarshalBlock(block, true, fullTx) - if err != nil { - return nil, err - } - response["author"] = response["miner"] - response["totalDifficulty"] = (*hexutil.Big)(api.blockchain.GetTd(block.Hash(), block.Number().Uint64())) - return response, err - } - return nil, fmt.Errorf("block 0x%x not 
found", blockHash) -} - -func (api *RetestethAPI) AccountRange(ctx context.Context, - blockHashOrNumber *math.HexOrDecimal256, txIndex uint64, - addressHash *math.HexOrDecimal256, maxResults uint64, -) (AccountRangeResult, error) { - var ( - header *types.Header - block *types.Block - ) - if (*big.Int)(blockHashOrNumber).Cmp(big.NewInt(math.MaxInt64)) > 0 { - blockHash := common.BigToHash((*big.Int)(blockHashOrNumber)) - header = api.blockchain.GetHeaderByHash(blockHash) - block = api.blockchain.GetBlockByHash(blockHash) - //fmt.Printf("Account range: %x, txIndex %d, start: %x, maxResults: %d\n", blockHash, txIndex, common.BigToHash((*big.Int)(addressHash)), maxResults) - } else { - blockNumber := (*big.Int)(blockHashOrNumber).Uint64() - header = api.blockchain.GetHeaderByNumber(blockNumber) - block = api.blockchain.GetBlockByNumber(blockNumber) - //fmt.Printf("Account range: %d, txIndex %d, start: %x, maxResults: %d\n", blockNumber, txIndex, common.BigToHash((*big.Int)(addressHash)), maxResults) - } - parentHeader := api.blockchain.GetHeaderByHash(header.ParentHash) - var root common.Hash - var statedb *state.StateDB - var err error - if parentHeader == nil || int(txIndex) >= len(block.Transactions()) { - root = header.Root - statedb, err = api.blockchain.StateAt(root) - if err != nil { - return AccountRangeResult{}, err - } - } else { - root = parentHeader.Root - statedb, err = api.blockchain.StateAt(root) - if err != nil { - return AccountRangeResult{}, err - } - // Recompute transactions up to the target index. - signer := types.MakeSigner(api.blockchain.Config(), block.Number()) - for idx, tx := range block.Transactions() { - // Assemble the transaction call message and return if the requested offset - msg, _ := tx.AsMessage(signer) - context := core.NewEVMContext(msg, block.Header(), api.blockchain, nil) - // Not yet the searched for transaction, execute on top of the current state - vmenv := vm.NewEVM(context, statedb, api.blockchain.Config(), vm.Config{}) - if _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())); err != nil { - return AccountRangeResult{}, fmt.Errorf("transaction %#x failed: %v", tx.Hash(), err) - } - // Ensure any modifications are committed to the state - // Only delete empty objects if EIP158/161 (a.k.a Spurious Dragon) is in effect - root = statedb.IntermediateRoot(vmenv.ChainConfig().IsEIP158(block.Number())) - if idx == int(txIndex) { - // This is to make sure root can be opened by OpenTrie - root, err = statedb.Commit(api.chainConfig.IsEIP158(block.Number())) - if err != nil { - return AccountRangeResult{}, err - } - break - } - } - } - accountTrie, err := statedb.Database().OpenTrie(root) - if err != nil { - return AccountRangeResult{}, err - } - it := trie.NewIterator(accountTrie.NodeIterator(common.BigToHash((*big.Int)(addressHash)).Bytes())) - result := AccountRangeResult{AddressMap: make(map[common.Hash]common.Address)} - for i := 0; i < int(maxResults) && it.Next(); i++ { - if preimage := accountTrie.GetKey(it.Key); preimage != nil { - result.AddressMap[common.BytesToHash(it.Key)] = common.BytesToAddress(preimage) - } - } - //fmt.Printf("Number of entries returned: %d\n", len(result.AddressMap)) - // Add the 'next key' so clients can continue downloading. 
- if it.Next() { - next := common.BytesToHash(it.Key) - result.NextKey = next - } - return result, nil -} - -func (api *RetestethAPI) GetBalance(ctx context.Context, address common.Address, blockNr math.HexOrDecimal64) (*math.HexOrDecimal256, error) { - //fmt.Printf("GetBalance %x, block %d\n", address, blockNr) - header := api.blockchain.GetHeaderByNumber(uint64(blockNr)) - statedb, err := api.blockchain.StateAt(header.Root) - if err != nil { - return nil, err - } - return (*math.HexOrDecimal256)(statedb.GetBalance(address)), nil -} - -func (api *RetestethAPI) GetCode(ctx context.Context, address common.Address, blockNr math.HexOrDecimal64) (hexutil.Bytes, error) { - header := api.blockchain.GetHeaderByNumber(uint64(blockNr)) - statedb, err := api.blockchain.StateAt(header.Root) - if err != nil { - return nil, err - } - return statedb.GetCode(address), nil -} - -func (api *RetestethAPI) GetTransactionCount(ctx context.Context, address common.Address, blockNr math.HexOrDecimal64) (uint64, error) { - header := api.blockchain.GetHeaderByNumber(uint64(blockNr)) - statedb, err := api.blockchain.StateAt(header.Root) - if err != nil { - return 0, err - } - return statedb.GetNonce(address), nil -} - -func (api *RetestethAPI) StorageRangeAt(ctx context.Context, - blockHashOrNumber *math.HexOrDecimal256, txIndex uint64, - address common.Address, - begin *math.HexOrDecimal256, maxResults uint64, -) (StorageRangeResult, error) { - var ( - header *types.Header - block *types.Block - ) - if (*big.Int)(blockHashOrNumber).Cmp(big.NewInt(math.MaxInt64)) > 0 { - blockHash := common.BigToHash((*big.Int)(blockHashOrNumber)) - header = api.blockchain.GetHeaderByHash(blockHash) - block = api.blockchain.GetBlockByHash(blockHash) - //fmt.Printf("Storage range: %x, txIndex %d, addr: %x, start: %x, maxResults: %d\n", - // blockHash, txIndex, address, common.BigToHash((*big.Int)(begin)), maxResults) - } else { - blockNumber := (*big.Int)(blockHashOrNumber).Uint64() - header = api.blockchain.GetHeaderByNumber(blockNumber) - block = api.blockchain.GetBlockByNumber(blockNumber) - //fmt.Printf("Storage range: %d, txIndex %d, addr: %x, start: %x, maxResults: %d\n", - // blockNumber, txIndex, address, common.BigToHash((*big.Int)(begin)), maxResults) - } - parentHeader := api.blockchain.GetHeaderByHash(header.ParentHash) - var root common.Hash - var statedb *state.StateDB - var err error - if parentHeader == nil || int(txIndex) >= len(block.Transactions()) { - root = header.Root - statedb, err = api.blockchain.StateAt(root) - if err != nil { - return StorageRangeResult{}, err - } - } else { - root = parentHeader.Root - statedb, err = api.blockchain.StateAt(root) - if err != nil { - return StorageRangeResult{}, err - } - // Recompute transactions up to the target index. 
- signer := types.MakeSigner(api.blockchain.Config(), block.Number()) - for idx, tx := range block.Transactions() { - // Assemble the transaction call message and return if the requested offset - msg, _ := tx.AsMessage(signer) - context := core.NewEVMContext(msg, block.Header(), api.blockchain, nil) - // Not yet the searched for transaction, execute on top of the current state - vmenv := vm.NewEVM(context, statedb, api.blockchain.Config(), vm.Config{}) - if _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())); err != nil { - return StorageRangeResult{}, fmt.Errorf("transaction %#x failed: %v", tx.Hash(), err) - } - // Ensure any modifications are committed to the state - // Only delete empty objects if EIP158/161 (a.k.a Spurious Dragon) is in effect - _ = statedb.IntermediateRoot(vmenv.ChainConfig().IsEIP158(block.Number())) - if idx == int(txIndex) { - // This is to make sure root can be opened by OpenTrie - _, err = statedb.Commit(vmenv.ChainConfig().IsEIP158(block.Number())) - if err != nil { - return StorageRangeResult{}, err - } - } - } - } - storageTrie := statedb.StorageTrie(address) - it := trie.NewIterator(storageTrie.NodeIterator(common.BigToHash((*big.Int)(begin)).Bytes())) - result := StorageRangeResult{Storage: make(map[common.Hash]SRItem)} - for i := 0; /*i < int(maxResults) && */ it.Next(); i++ { - if preimage := storageTrie.GetKey(it.Key); preimage != nil { - key := (*math.HexOrDecimal256)(big.NewInt(0).SetBytes(preimage)) - v, _, err := rlp.SplitString(it.Value) - if err != nil { - return StorageRangeResult{}, err - } - value := (*math.HexOrDecimal256)(big.NewInt(0).SetBytes(v)) - ks, _ := key.MarshalText() - vs, _ := value.MarshalText() - if len(ks)%2 != 0 { - ks = append(append(append([]byte{}, ks[:2]...), byte('0')), ks[2:]...) - } - if len(vs)%2 != 0 { - vs = append(append(append([]byte{}, vs[:2]...), byte('0')), vs[2:]...) - } - result.Storage[common.BytesToHash(it.Key)] = SRItem{ - Key: string(ks), - Value: string(vs), - } - } - } - if it.Next() { - result.Complete = false - } else { - result.Complete = true - } - return result, nil -} - -func (api *RetestethAPI) ClientVersion(ctx context.Context) (string, error) { - return "Geth-" + params.VersionWithCommit(gitCommit, gitDate), nil -} - -// splitAndTrim splits input separated by a comma -// and trims excessive white space from the substrings. 
-func splitAndTrim(input string) []string { - result := strings.Split(input, ",") - for i, r := range result { - result[i] = strings.TrimSpace(r) - } - return result -} - -func retesteth(ctx *cli.Context) error { - log.Info("Welcome to retesteth!") - // register signer API with server - var ( - extapiURL string - ) - apiImpl := &RetestethAPI{} - var testApi RetestethTestAPI = apiImpl - var ethApi RetestethEthAPI = apiImpl - var debugApi RetestethDebugAPI = apiImpl - var web3Api RetestWeb3API = apiImpl - rpcAPI := []rpc.API{ - { - Namespace: "test", - Public: true, - Service: testApi, - Version: "1.0", - }, - { - Namespace: "eth", - Public: true, - Service: ethApi, - Version: "1.0", - }, - { - Namespace: "debug", - Public: true, - Service: debugApi, - Version: "1.0", - }, - { - Namespace: "web3", - Public: true, - Service: web3Api, - Version: "1.0", - }, - } - vhosts := splitAndTrim(ctx.GlobalString(utils.HTTPVirtualHostsFlag.Name)) - cors := splitAndTrim(ctx.GlobalString(utils.HTTPCORSDomainFlag.Name)) - - // register apis and create handler stack - srv := rpc.NewServer() - err := node.RegisterApisFromWhitelist(rpcAPI, []string{"test", "eth", "debug", "web3"}, srv, false) - if err != nil { - utils.Fatalf("Could not register RPC apis: %v", err) - } - handler := node.NewHTTPHandlerStack(srv, cors, vhosts) - - // start http server - var RetestethHTTPTimeouts = rpc.HTTPTimeouts{ - ReadTimeout: 120 * time.Second, - WriteTimeout: 120 * time.Second, - IdleTimeout: 120 * time.Second, - } - httpEndpoint := fmt.Sprintf("%s:%d", ctx.GlobalString(utils.HTTPListenAddrFlag.Name), ctx.Int(rpcPortFlag.Name)) - httpServer, _, err := node.StartHTTPEndpoint(httpEndpoint, RetestethHTTPTimeouts, handler) - if err != nil { - utils.Fatalf("Could not start RPC api: %v", err) - } - extapiURL = fmt.Sprintf("http://%s", httpEndpoint) - log.Info("HTTP endpoint opened", "url", extapiURL) - - defer func() { - // Don't bother imposing a timeout here. - httpServer.Shutdown(context.Background()) - log.Info("HTTP endpoint closed", "url", httpEndpoint) - }() - - abortChan := make(chan os.Signal, 11) - signal.Notify(abortChan, os.Interrupt) - - sig := <-abortChan - log.Info("Exiting...", "signal", sig) - return nil -} diff --git a/cmd/geth/retesteth_copypaste.go b/cmd/geth/retesteth_copypaste.go deleted file mode 100644 index e2795af7f968..000000000000 --- a/cmd/geth/retesteth_copypaste.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of go-ethereum. -// -// go-ethereum is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// go-ethereum is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
- -package main - -import ( - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/core/types" -) - -// RPCTransaction represents a transaction that will serialize to the RPC representation of a transaction -type RPCTransaction struct { - BlockHash common.Hash `json:"blockHash"` - BlockNumber *hexutil.Big `json:"blockNumber"` - From common.Address `json:"from"` - Gas hexutil.Uint64 `json:"gas"` - GasPrice *hexutil.Big `json:"gasPrice"` - Hash common.Hash `json:"hash"` - Input hexutil.Bytes `json:"input"` - Nonce hexutil.Uint64 `json:"nonce"` - To *common.Address `json:"to"` - TransactionIndex hexutil.Uint `json:"transactionIndex"` - Value *hexutil.Big `json:"value"` - V *hexutil.Big `json:"v"` - R *hexutil.Big `json:"r"` - S *hexutil.Big `json:"s"` -} - -// newRPCTransaction returns a transaction that will serialize to the RPC -// representation, with the given location metadata set (if available). -func newRPCTransaction(tx *types.Transaction, blockHash common.Hash, blockNumber uint64, index uint64) *RPCTransaction { - var signer types.Signer = types.FrontierSigner{} - if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) - } - from, _ := types.Sender(signer, tx) - v, r, s := tx.RawSignatureValues() - - result := &RPCTransaction{ - From: from, - Gas: hexutil.Uint64(tx.Gas()), - GasPrice: (*hexutil.Big)(tx.GasPrice()), - Hash: tx.Hash(), - Input: hexutil.Bytes(tx.Data()), - Nonce: hexutil.Uint64(tx.Nonce()), - To: tx.To(), - Value: (*hexutil.Big)(tx.Value()), - V: (*hexutil.Big)(v), - R: (*hexutil.Big)(r), - S: (*hexutil.Big)(s), - } - if blockHash != (common.Hash{}) { - result.BlockHash = blockHash - result.BlockNumber = (*hexutil.Big)(new(big.Int).SetUint64(blockNumber)) - result.TransactionIndex = hexutil.Uint(index) - } - return result -} - -// newRPCTransactionFromBlockIndex returns a transaction that will serialize to the RPC representation. -func newRPCTransactionFromBlockIndex(b *types.Block, index uint64) *RPCTransaction { - txs := b.Transactions() - if index >= uint64(len(txs)) { - return nil - } - return newRPCTransaction(txs[index], b.Hash(), b.NumberU64(), index) -} - -// newRPCTransactionFromBlockHash returns a transaction that will serialize to the RPC representation. -func newRPCTransactionFromBlockHash(b *types.Block, hash common.Hash) *RPCTransaction { - for idx, tx := range b.Transactions() { - if tx.Hash() == hash { - return newRPCTransactionFromBlockIndex(b, uint64(idx)) - } - } - return nil -} - -// RPCMarshalBlock converts the given block to the RPC output which depends on fullTx. If inclTx is true transactions are -// returned. When fullTx is true the returned block contains full transaction details, otherwise it will only contain -// transaction hashes. 
-func RPCMarshalBlock(b *types.Block, inclTx bool, fullTx bool) (map[string]interface{}, error) { - head := b.Header() // copies the header once - fields := map[string]interface{}{ - "number": (*hexutil.Big)(head.Number), - "hash": b.Hash(), - "parentHash": head.ParentHash, - "nonce": head.Nonce, - "mixHash": head.MixDigest, - "sha3Uncles": head.UncleHash, - "logsBloom": head.Bloom, - "stateRoot": head.Root, - "miner": head.Coinbase, - "difficulty": (*hexutil.Big)(head.Difficulty), - "extraData": hexutil.Bytes(head.Extra), - "size": hexutil.Uint64(b.Size()), - "gasLimit": hexutil.Uint64(head.GasLimit), - "gasUsed": hexutil.Uint64(head.GasUsed), - "timestamp": hexutil.Uint64(head.Time), - "transactionsRoot": head.TxHash, - "receiptsRoot": head.ReceiptHash, - } - - if inclTx { - formatTx := func(tx *types.Transaction) (interface{}, error) { - return tx.Hash(), nil - } - if fullTx { - formatTx = func(tx *types.Transaction) (interface{}, error) { - return newRPCTransactionFromBlockHash(b, tx.Hash()), nil - } - } - txs := b.Transactions() - transactions := make([]interface{}, len(txs)) - var err error - for i, tx := range txs { - if transactions[i], err = formatTx(tx); err != nil { - return nil, err - } - } - fields["transactions"] = transactions - } - - uncles := b.Uncles() - uncleHashes := make([]common.Hash, len(uncles)) - for i, uncle := range uncles { - uncleHashes[i] = uncle.Hash() - } - fields["uncles"] = uncleHashes - - return fields, nil -} diff --git a/cmd/geth/run_test.go b/cmd/geth/run_test.go index f7b735b84c14..79b892c59bc2 100644 --- a/cmd/geth/run_test.go +++ b/cmd/geth/run_test.go @@ -70,12 +70,12 @@ func runGeth(t *testing.T, args ...string) *testgeth { tt := &testgeth{} tt.TestCmd = cmdtest.NewTestCmd(t, tt) for i, arg := range args { - switch { - case arg == "-datadir" || arg == "--datadir": + switch arg { + case "--datadir": if i < len(args)-1 { tt.Datadir = args[i+1] } - case arg == "-etherbase" || arg == "--etherbase": + case "--etherbase": if i < len(args)-1 { tt.Etherbase = args[i+1] } @@ -84,7 +84,7 @@ func runGeth(t *testing.T, args ...string) *testgeth { if tt.Datadir == "" { tt.Datadir = tmpdir(t) tt.Cleanup = func() { os.RemoveAll(tt.Datadir) } - args = append([]string{"-datadir", tt.Datadir}, args...) + args = append([]string{"--datadir", tt.Datadir}, args...) // Remove the temporary datadir if something fails below. defer func() { if t.Failed() { diff --git a/cmd/geth/testdata/vcheck/data.json b/cmd/geth/testdata/vcheck/data.json new file mode 100644 index 000000000000..e7ee2bf7e44c --- /dev/null +++ b/cmd/geth/testdata/vcheck/data.json @@ -0,0 +1,61 @@ +[ + { + "name": "CorruptedDAG", + "uid": "GETH-2020-01", + "summary": "Mining nodes will generate erroneous PoW on epochs > `385`.", + "description": "A mining flaw could cause miners to erroneously calculate PoW, due to an index overflow, if DAG size is exceeding the maximum 32 bit unsigned value.\n\nThis occurred on the ETC chain on 2020-11-06. 
This is likely to trigger for ETH mainnet around block `11550000`/epoch `385`, slated to occur early January 2021.\n\nThis issue is relevant only for miners, non-mining nodes are unaffected, since non-mining nodes use a smaller verification cache instead of a full DAG.", + "links": [ + "https://github.com/ethereum/go-ethereum/pull/21793", + "https://blog.ethereum.org/2020/11/12/geth_security_release/", + "https://github.com/ethereum/go-ethereum/commit/567d41d9363706b4b13ce0903804e8acf214af49" + ], + "introduced": "v1.6.0", + "fixed": "v1.9.24", + "published": "2020-11-12", + "severity": "Medium", + "check": "Geth\\/v1\\.(6|7|8)\\..*|Geth\\/v1\\.9\\.2(1|2|3)-.*" + }, + { + "name": "GoCrash", + "uid": "GETH-2020-02", + "summary": "A denial-of-service issue can be used to crash Geth nodes during block processing, due to an underlying bug in Go (CVE-2020-28362) versions < `1.15.5`, or `<1.14.12`", + "description": "The DoS issue can be used to crash all Geth nodes during block processing, the effects of which would be that a major part of the Ethereum network went offline.\n\nOutside of Go-Ethereum, the issue is most likely relevant for all forks of Geth (such as TurboGeth or ETC’s core-geth) which is built with versions of Go which contains the vulnerability.", + "links": [ + "https://blog.ethereum.org/2020/11/12/geth_security_release/", + "https://groups.google.com/g/golang-announce/c/NpBGTTmKzpM", + "https://github.com/golang/go/issues/42552" + ], + "fixed": "v1.9.24", + "published": "2020-11-12", + "severity": "Critical", + "check": "Geth.*\\/go1\\.(11(.*)|12(.*)|13(.*)|14|14\\.(\\d|10|11|)|15|15\\.[0-4])$" + }, + { + "name": "ShallowCopy", + "uid": "GETH-2020-03", + "summary": "A consensus flaw in Geth, related to `datacopy` precompile", + "description": "Geth erroneously performed a 'shallow' copy when the precompiled `datacopy` (at `0x00...04`) was invoked. 
An attacker could deploy a contract that uses the shallow copy to corrupt the contents of the `RETURNDATA`, thus causing a consensus failure.", + "links": [ + "https://blog.ethereum.org/2020/11/12/geth_security_release/" + ], + "introduced": "v1.9.7", + "fixed": "v1.9.17", + "published": "2020-11-12", + "severity": "Critical", + "check": "Geth\\/v1\\.9\\.(7|8|9|10|11|12|13|14|15|16).*$" + }, + { + "name": "GethCrash", + "uid": "GETH-2020-04", + "summary": "A denial-of-service issue can be used to crash Geth nodes during block processing", + "description": "Full details to be disclosed at a later date", + "links": [ + "https://blog.ethereum.org/2020/11/12/geth_security_release/" + ], + "introduced": "v1.9.16", + "fixed": "v1.9.18", + "published": "2020-11-12", + "severity": "Critical", + "check": "Geth\\/v1\\.9.(16|17).*$" + } +] diff --git a/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.1 b/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.1 new file mode 100644 index 000000000000..f9066d4fe0b1 --- /dev/null +++ b/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.1 @@ -0,0 +1,4 @@ +untrusted comment: signature from minisign secret key +RWQkliYstQBOKFQFQTjmCd6TPw07VZyWFSB3v4+1BM1kv8eHLE5FDy2OkPEqtdaL53xftlrHoJQie0uCcovdlSV8kpyxiLrxEQ0= +trusted comment: timestamp:1605618622 file:vulnerabilities.json +osAPs4QPdDkmiWQxqeMIzYv/b+ZGxJ+19Sbrk1Cpq4t2gHBT+lqFtwL3OCzKWWyjGRTmHfsVGBYpzEdPRQ0/BQ== diff --git a/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.2 b/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.2 new file mode 100644 index 000000000000..a89a83d21a57 --- /dev/null +++ b/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.2 @@ -0,0 +1,4 @@ +untrusted comment: Here's a comment +RWQkliYstQBOKFQFQTjmCd6TPw07VZyWFSB3v4+1BM1kv8eHLE5FDy2OkPEqtdaL53xftlrHoJQie0uCcovdlSV8kpyxiLrxEQ0= +trusted comment: Here's a trusted comment +3CnkIuz9MEDa7uNyGZAbKZhuirwfiqm7E1uQHrd2SiO4Y8+Akw9vs052AyKw0s5nhbYHCZE2IMQdHNjKwxEGAQ== diff --git a/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.3 b/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.3 new file mode 100644 index 000000000000..6fd33b19a3cd --- /dev/null +++ b/cmd/geth/testdata/vcheck/minisig-sigs/vulnerabilities.json.minisig.3 @@ -0,0 +1,4 @@ +untrusted comment: One more (untrusted) comment +RWQkliYstQBOKFQFQTjmCd6TPw07VZyWFSB3v4+1BM1kv8eHLE5FDy2OkPEqtdaL53xftlrHoJQie0uCcovdlSV8kpyxiLrxEQ0= +trusted comment: Here's a trusted comment +3CnkIuz9MEDa7uNyGZAbKZhuirwfiqm7E1uQHrd2SiO4Y8+Akw9vs052AyKw0s5nhbYHCZE2IMQdHNjKwxEGAQ== diff --git a/cmd/geth/testdata/vcheck/minisign.pub b/cmd/geth/testdata/vcheck/minisign.pub new file mode 100644 index 000000000000..183dce5f6b5d --- /dev/null +++ b/cmd/geth/testdata/vcheck/minisign.pub @@ -0,0 +1,2 @@ +untrusted comment: minisign public key 284E00B52C269624 +RWQkliYstQBOKOdtClfgC3IypIPX6TAmoEi7beZ4gyR3wsaezvqOMWsp diff --git a/cmd/geth/testdata/vcheck/minisign.sec b/cmd/geth/testdata/vcheck/minisign.sec new file mode 100644 index 000000000000..5c50715b2096 --- /dev/null +++ b/cmd/geth/testdata/vcheck/minisign.sec @@ -0,0 +1,2 @@ +untrusted comment: minisign encrypted secret key +RWRTY0Iyz8kmPMKrqk6DCtlO9a33akKiaOQG1aLolqDxs52qvPoAAAACAAAAAAAAAEAAAAAArEiggdvyn6+WzTprirLtgiYQoU+ihz/HyGgjhuF+Pz2ddMduyCO+xjCHeq+vgVVW039fbsI8hW6LRGJZLBKV5/jdxCXAVVQE7qTQ6xpEdO0z8Z731/pV1hlspQXG2PNd16NMtwd9dWw= diff --git a/cmd/geth/testdata/vcheck/signify-sigs/data.json.sig 
b/cmd/geth/testdata/vcheck/signify-sigs/data.json.sig new file mode 100644 index 000000000000..3d5fcacf9ae2 --- /dev/null +++ b/cmd/geth/testdata/vcheck/signify-sigs/data.json.sig @@ -0,0 +1,2 @@ +untrusted comment: verify with ./signifykey.pub +RWSKLNhZb0KdAbhRUhW2LQZXdnwttu2SYhM9EuC4mMgOJB85h7/YIPupf8/ldTs4N8e9Y/fhgdY40q5LQpt5IFC62fq0v8U1/w8= diff --git a/cmd/geth/testdata/vcheck/signifykey.pub b/cmd/geth/testdata/vcheck/signifykey.pub new file mode 100644 index 000000000000..328f973ab4dc --- /dev/null +++ b/cmd/geth/testdata/vcheck/signifykey.pub @@ -0,0 +1,2 @@ +untrusted comment: signify public key +RWSKLNhZb0KdATtRT7mZC/bybI3t3+Hv/O2i3ye04Dq9fnT9slpZ1a2/ diff --git a/cmd/geth/testdata/vcheck/signifykey.sec b/cmd/geth/testdata/vcheck/signifykey.sec new file mode 100644 index 000000000000..3279a2e58b55 --- /dev/null +++ b/cmd/geth/testdata/vcheck/signifykey.sec @@ -0,0 +1,2 @@ +untrusted comment: signify secret key +RWRCSwAAACpLQDLawSQCtI7eAVIvaiHzjTsTyJsfV5aKLNhZb0KdAWeICXJGa93/bHAcsY6jUh9I8RdEcDWEoGxmaXZC+IdVBPxDpkix9fBRGEUdKWHi3dOfqME0YRzErWI5AVg3cRw= diff --git a/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.1 b/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.1 new file mode 100644 index 000000000000..f9066d4fe0b1 --- /dev/null +++ b/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.1 @@ -0,0 +1,4 @@ +untrusted comment: signature from minisign secret key +RWQkliYstQBOKFQFQTjmCd6TPw07VZyWFSB3v4+1BM1kv8eHLE5FDy2OkPEqtdaL53xftlrHoJQie0uCcovdlSV8kpyxiLrxEQ0= +trusted comment: timestamp:1605618622 file:vulnerabilities.json +osAPs4QPdDkmiWQxqeMIzYv/b+ZGxJ+19Sbrk1Cpq4t2gHBT+lqFtwL3OCzKWWyjGRTmHfsVGBYpzEdPRQ0/BQ== diff --git a/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.2 b/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.2 new file mode 100644 index 000000000000..a89a83d21a57 --- /dev/null +++ b/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.2 @@ -0,0 +1,4 @@ +untrusted comment: Here's a comment +RWQkliYstQBOKFQFQTjmCd6TPw07VZyWFSB3v4+1BM1kv8eHLE5FDy2OkPEqtdaL53xftlrHoJQie0uCcovdlSV8kpyxiLrxEQ0= +trusted comment: Here's a trusted comment +3CnkIuz9MEDa7uNyGZAbKZhuirwfiqm7E1uQHrd2SiO4Y8+Akw9vs052AyKw0s5nhbYHCZE2IMQdHNjKwxEGAQ== diff --git a/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.3 b/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.3 new file mode 100644 index 000000000000..6fd33b19a3cd --- /dev/null +++ b/cmd/geth/testdata/vcheck/sigs/vulnerabilities.json.minisig.3 @@ -0,0 +1,4 @@ +untrusted comment: One more (untrusted) comment +RWQkliYstQBOKFQFQTjmCd6TPw07VZyWFSB3v4+1BM1kv8eHLE5FDy2OkPEqtdaL53xftlrHoJQie0uCcovdlSV8kpyxiLrxEQ0= +trusted comment: Here's a trusted comment +3CnkIuz9MEDa7uNyGZAbKZhuirwfiqm7E1uQHrd2SiO4Y8+Akw9vs052AyKw0s5nhbYHCZE2IMQdHNjKwxEGAQ== diff --git a/cmd/geth/testdata/vcheck/vulnerabilities.json b/cmd/geth/testdata/vcheck/vulnerabilities.json new file mode 100644 index 000000000000..36509f95a98a --- /dev/null +++ b/cmd/geth/testdata/vcheck/vulnerabilities.json @@ -0,0 +1,70 @@ +[ + { + "name": "CorruptedDAG", + "uid": "GETH-2020-01", + "summary": "Mining nodes will generate erroneous PoW on epochs > `385`.", + "description": "A mining flaw could cause miners to erroneously calculate PoW, due to an index overflow, if DAG size is exceeding the maximum 32 bit unsigned value.\n\nThis occurred on the ETC chain on 2020-11-06. 
This is likely to trigger for ETH mainnet around block `11550000`/epoch `385`, slated to occur early January 2021.\n\nThis issue is relevant only for miners, non-mining nodes are unaffected, since non-mining nodes use a smaller verification cache instead of a full DAG.", + "links": [ + "https://github.com/ethereum/go-ethereum/pull/21793", + "https://blog.ethereum.org/2020/11/12/geth_security_release/", + "https://github.com/ethereum/go-ethereum/commit/567d41d9363706b4b13ce0903804e8acf214af49", + "https://github.com/ethereum/go-ethereum/security/advisories/GHSA-v592-xf75-856p" + ], + "introduced": "v1.6.0", + "fixed": "v1.9.24", + "published": "2020-11-12", + "severity": "Medium", + "CVE": "CVE-2020-26240", + "check": "Geth\\/v1\\.(6|7|8)\\..*|Geth\\/v1\\.9\\.\\d-.*|Geth\\/v1\\.9\\.1.*|Geth\\/v1\\.9\\.2(0|1|2|3)-.*" + }, + { + "name": "Denial of service due to Go CVE-2020-28362", + "uid": "GETH-2020-02", + "summary": "A denial-of-service issue can be used to crash Geth nodes during block processing, due to an underlying bug in Go (CVE-2020-28362) versions < `1.15.5`, or `<1.14.12`", + "description": "The DoS issue can be used to crash all Geth nodes during block processing, the effects of which would be that a major part of the Ethereum network went offline.\n\nOutside of Go-Ethereum, the issue is most likely relevant for all forks of Geth (such as TurboGeth or ETC’s core-geth) which is built with versions of Go which contains the vulnerability.", + "links": [ + "https://blog.ethereum.org/2020/11/12/geth_security_release/", + "https://groups.google.com/g/golang-announce/c/NpBGTTmKzpM", + "https://github.com/golang/go/issues/42552", + "https://github.com/ethereum/go-ethereum/security/advisories/GHSA-m6gx-rhvj-fh52" + ], + "introduced": "v0.0.0", + "fixed": "v1.9.24", + "published": "2020-11-12", + "severity": "Critical", + "CVE": "CVE-2020-28362", + "check": "Geth.*\\/go1\\.(11(.*)|12(.*)|13(.*)|14|14\\.(\\d|10|11|)|15|15\\.[0-4])$" + }, + { + "name": "ShallowCopy", + "uid": "GETH-2020-03", + "summary": "A consensus flaw in Geth, related to `datacopy` precompile", + "description": "Geth erroneously performed a 'shallow' copy when the precompiled `datacopy` (at `0x00...04`) was invoked. 
An attacker could deploy a contract that uses the shallow copy to corrupt the contents of the `RETURNDATA`, thus causing a consensus failure.", + "links": [ + "https://blog.ethereum.org/2020/11/12/geth_security_release/", + "https://github.com/ethereum/go-ethereum/security/advisories/GHSA-69v6-xc2j-r2jf" + ], + "introduced": "v1.9.7", + "fixed": "v1.9.17", + "published": "2020-11-12", + "severity": "Critical", + "CVE": "CVE-2020-26241", + "check": "Geth\\/v1\\.9\\.(7|8|9|10|11|12|13|14|15|16).*$" + }, + { + "name": "GethCrash", + "uid": "GETH-2020-04", + "summary": "A denial-of-service issue can be used to crash Geth nodes during block processing", + "description": "Full details to be disclosed at a later date", + "links": [ + "https://blog.ethereum.org/2020/11/12/geth_security_release/", + "https://github.com/ethereum/go-ethereum/security/advisories/GHSA-jm5c-rv3w-w83m" + ], + "introduced": "v1.9.16", + "fixed": "v1.9.18", + "published": "2020-11-12", + "severity": "Critical", + "CVE": "CVE-2020-26242", + "check": "Geth\\/v1\\.9.(16|17).*$" + } +] diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index 334a729c24d8..0e70451ed385 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -25,7 +25,7 @@ import ( "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/internal/debug" "github.com/ethereum/go-ethereum/internal/flags" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" ) // AppHelpFlagGroups is the application flags, grouped by functionality. @@ -42,7 +42,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.NetworkIdFlag, utils.GoerliFlag, utils.RinkebyFlag, - utils.YoloV1Flag, + utils.YoloV2Flag, utils.RopstenFlag, utils.SyncModeFlag, utils.ExitWhenSyncedFlag, @@ -114,6 +114,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.CacheGCFlag, utils.CacheSnapshotFlag, utils.CacheNoPrefetchFlag, + utils.CachePreimagesFlag, }, }, { @@ -144,8 +145,8 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.GraphQLEnabledFlag, utils.GraphQLCORSDomainFlag, utils.GraphQLVirtualHostsFlag, - utils.RPCGlobalGasCap, - utils.RPCGlobalTxFeeCap, + utils.RPCGlobalGasCapFlag, + utils.RPCGlobalTxFeeCapFlag, utils.JSpathFlag, utils.ExecFlag, utils.PreloadJSFlag, diff --git a/cmd/geth/version_check.go b/cmd/geth/version_check.go new file mode 100644 index 000000000000..2101a69e9886 --- /dev/null +++ b/cmd/geth/version_check.go @@ -0,0 +1,169 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
+ +package main + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "regexp" + "strings" + + "github.com/ethereum/go-ethereum/log" + "github.com/jedisct1/go-minisign" + "gopkg.in/urfave/cli.v1" +) + +var gethPubKeys []string = []string{ + //@holiman, minisign public key FB1D084D39BAEC24 + "RWQk7Lo5TQgd+wxBNZM+Zoy+7UhhMHaWKzqoes9tvSbFLJYZhNTbrIjx", + //minisign public key 138B1CA303E51687 + "RWSHFuUDoxyLEzjszuWZI1xStS66QTyXFFZG18uDfO26CuCsbckX1e9J", + //minisign public key FD9813B2D2098484 + "RWSEhAnSshOY/b+GmaiDkObbCWefsAoavjoLcPjBo1xn71yuOH5I+Lts", +} + +type vulnJson struct { + Name string + Uid string + Summary string + Description string + Links []string + Introduced string + Fixed string + Published string + Severity string + Check string + CVE string +} + +func versionCheck(ctx *cli.Context) error { + url := ctx.String(VersionCheckUrlFlag.Name) + version := ctx.String(VersionCheckVersionFlag.Name) + log.Info("Checking vulnerabilities", "version", version, "url", url) + return checkCurrent(url, version) +} + +func checkCurrent(url, current string) error { + var ( + data []byte + sig []byte + err error + ) + if data, err = fetch(url); err != nil { + return fmt.Errorf("could not retrieve data: %w", err) + } + if sig, err = fetch(fmt.Sprintf("%v.minisig", url)); err != nil { + return fmt.Errorf("could not retrieve signature: %w", err) + } + if err = verifySignature(gethPubKeys, data, sig); err != nil { + return err + } + var vulns []vulnJson + if err = json.Unmarshal(data, &vulns); err != nil { + return err + } + allOk := true + for _, vuln := range vulns { + r, err := regexp.Compile(vuln.Check) + if err != nil { + return err + } + if r.MatchString(current) { + allOk = false + fmt.Printf("## Vulnerable to %v (%v)\n\n", vuln.Uid, vuln.Name) + fmt.Printf("Severity: %v\n", vuln.Severity) + fmt.Printf("Summary : %v\n", vuln.Summary) + fmt.Printf("Fixed in: %v\n", vuln.Fixed) + if len(vuln.CVE) > 0 { + fmt.Printf("CVE: %v\n", vuln.CVE) + } + if len(vuln.Links) > 0 { + fmt.Printf("References:\n") + for _, ref := range vuln.Links { + fmt.Printf("\t- %v\n", ref) + } + } + fmt.Println() + } + } + if allOk { + fmt.Println("No vulnerabilities found") + } + return nil +} + +// fetch makes an HTTP request to the given url and returns the response body +func fetch(url string) ([]byte, error) { + if filep := strings.TrimPrefix(url, "file://"); filep != url { + return ioutil.ReadFile(filep) + } + res, err := http.Get(url) + if err != nil { + return nil, err + } + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, err + } + return body, nil +} + +// verifySignature checks that the sigData is a valid signature of the given +// data, for pubkey GethPubkey +func verifySignature(pubkeys []string, data, sigdata []byte) error { + sig, err := minisign.DecodeSignature(string(sigdata)) + if err != nil { + return err + } + // find the used key + var key *minisign.PublicKey + for _, pubkey := range pubkeys { + pub, err := minisign.NewPublicKey(pubkey) + if err != nil { + // our pubkeys should be parseable + return err + } + if pub.KeyId != sig.KeyId { + continue + } + key = &pub + break + } + if key == nil { + log.Info("Signing key not trusted", "keyid", keyID(sig.KeyId), "error", err) + return errors.New("signature could not be verified") + } + if ok, err := key.Verify(data, sig); !ok || err != nil { + log.Info("Verification failed error", "keyid", keyID(key.KeyId), "error", err) + return errors.New("signature could not be verified") + } + 
return nil +} + +// keyID turns a binary minisign key ID into a hex string. +// Note: key IDs are printed in reverse byte order. +func keyID(id [8]byte) string { + var rev [8]byte + for i := range id { + rev[len(rev)-1-i] = id[i] + } + return fmt.Sprintf("%X", rev) +} diff --git a/cmd/geth/version_check_test.go b/cmd/geth/version_check_test.go new file mode 100644 index 000000000000..0f056d1967d4 --- /dev/null +++ b/cmd/geth/version_check_test.go @@ -0,0 +1,130 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>. + +package main + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "path/filepath" + "regexp" + "strconv" + "strings" + "testing" +) + +func TestVerification(t *testing.T) { + // Signatures generated with `minisign` + t.Run("minisig", func(t *testing.T) { + // For this test, the pubkey is in testdata/minisign.pub + // (the privkey is `minisign.sec`, if we want to expand this test. Password 'test' ) + pub := "RWQkliYstQBOKOdtClfgC3IypIPX6TAmoEi7beZ4gyR3wsaezvqOMWsp" + testVerification(t, pub, "./testdata/vcheck/minisig-sigs/") + }) + // Signatures generated with `signify-openbsd` + t.Run("signify-openbsd", func(t *testing.T) { + t.Skip("This currently fails, minisign expects 4 lines of data, signify provides only 2") + // For this test, the pubkey is in testdata/signifykey.pub + // (the privkey is `signifykey.sec`, if we want to expand this test. 
Password 'test' ) + pub := "RWSKLNhZb0KdATtRT7mZC/bybI3t3+Hv/O2i3ye04Dq9fnT9slpZ1a2/" + testVerification(t, pub, "./testdata/vcheck/signify-sigs/") + }) +} + +func testVerification(t *testing.T, pubkey, sigdir string) { + // Data to verify + data, err := ioutil.ReadFile("./testdata/vcheck/data.json") + if err != nil { + t.Fatal(err) + } + // Signatures, with and without comments, both trusted and untrusted + files, err := ioutil.ReadDir(sigdir) + if err != nil { + t.Fatal(err) + } + for _, f := range files { + sig, err := ioutil.ReadFile(filepath.Join(sigdir, f.Name())) + if err != nil { + t.Fatal(err) + } + err = verifySignature([]string{pubkey}, data, sig) + if err != nil { + t.Fatal(err) + } + } +} + +func versionUint(v string) int { + mustInt := func(s string) int { + a, err := strconv.Atoi(s) + if err != nil { + panic(v) + } + return a + } + components := strings.Split(strings.TrimPrefix(v, "v"), ".") + a := mustInt(components[0]) + b := mustInt(components[1]) + c := mustInt(components[2]) + return a*100*100 + b*100 + c +} + +// TestMatching can be used to check that the regexps are correct +func TestMatching(t *testing.T) { + data, _ := ioutil.ReadFile("./testdata/vcheck/vulnerabilities.json") + var vulns []vulnJson + if err := json.Unmarshal(data, &vulns); err != nil { + t.Fatal(err) + } + check := func(version string) { + vFull := fmt.Sprintf("Geth/%v-unstable-15339cf1-20201204/linux-amd64/go1.15.4", version) + for _, vuln := range vulns { + r, err := regexp.Compile(vuln.Check) + vulnIntro := versionUint(vuln.Introduced) + vulnFixed := versionUint(vuln.Fixed) + current := versionUint(version) + if err != nil { + t.Fatal(err) + } + if vuln.Name == "Denial of service due to Go CVE-2020-28362" { + // this one is not tied to geth-versions + continue + } + if vulnIntro <= current && vulnFixed > current { + // Should be vulnerable + if !r.MatchString(vFull) { + t.Errorf("Should be vulnerable, version %v, intro: %v, fixed: %v %v %v", + version, vuln.Introduced, vuln.Fixed, vuln.Name, vuln.Check) + } + } else { + if r.MatchString(vFull) { + t.Errorf("Should not be flagged vulnerable, version %v, intro: %v, fixed: %v %v %d %d %d", + version, vuln.Introduced, vuln.Fixed, vuln.Name, vulnIntro, current, vulnFixed) + } + } + + } + } + for major := 1; major < 2; major++ { + for minor := 0; minor < 30; minor++ { + for patch := 0; patch < 30; patch++ { + vShort := fmt.Sprintf("v%d.%d.%d", major, minor, patch) + check(vShort) + } + } + } +} diff --git a/cmd/puppeth/genesis.go b/cmd/puppeth/genesis.go index b3e1709dbf4f..46268ec11bc4 100644 --- a/cmd/puppeth/genesis.go +++ b/cmd/puppeth/genesis.go @@ -152,7 +152,7 @@ func newAlethGenesisSpec(network string, genesis *core.Genesis) (*alethGenesisSp spec.Genesis.Author = genesis.Coinbase spec.Genesis.Timestamp = (hexutil.Uint64)(genesis.Timestamp) spec.Genesis.ParentHash = genesis.ParentHash - spec.Genesis.ExtraData = (hexutil.Bytes)(genesis.ExtraData) + spec.Genesis.ExtraData = genesis.ExtraData spec.Genesis.GasLimit = (hexutil.Uint64)(genesis.GasLimit) for address, account := range genesis.Alloc { @@ -430,7 +430,7 @@ func newParityChainSpec(network string, genesis *core.Genesis, bootnodes []strin spec.Genesis.Author = genesis.Coinbase spec.Genesis.Timestamp = (hexutil.Uint64)(genesis.Timestamp) spec.Genesis.ParentHash = genesis.ParentHash - spec.Genesis.ExtraData = (hexutil.Bytes)(genesis.ExtraData) + spec.Genesis.ExtraData = genesis.ExtraData spec.Genesis.GasLimit = (hexutil.Uint64)(genesis.GasLimit) spec.Accounts = 
make(map[common.UnprefixedAddress]*parityChainSpecAccount) diff --git a/cmd/puppeth/module_faucet.go b/cmd/puppeth/module_faucet.go index 987bed14aa57..2527e137f2c6 100644 --- a/cmd/puppeth/module_faucet.go +++ b/cmd/puppeth/module_faucet.go @@ -46,6 +46,7 @@ ENTRYPOINT [ \ "--faucet.name", "{{.FaucetName}}", "--faucet.amount", "{{.FaucetAmount}}", "--faucet.minutes", "{{.FaucetMinutes}}", "--faucet.tiers", "{{.FaucetTiers}}", \ "--account.json", "/account.json", "--account.pass", "/account.pass" \ {{if .CaptchaToken}}, "--captcha.token", "{{.CaptchaToken}}", "--captcha.secret", "{{.CaptchaSecret}}"{{end}}{{if .NoAuth}}, "--noauth"{{end}} \ + {{if .TwitterToken}}, "--twitter.token", "{{.TwitterToken}}"{{end}} \ ]` // faucetComposefile is the docker-compose.yml file required to deploy and maintain @@ -71,6 +72,7 @@ services: - FAUCET_TIERS={{.FaucetTiers}} - CAPTCHA_TOKEN={{.CaptchaToken}} - CAPTCHA_SECRET={{.CaptchaSecret}} + - TWITTER_TOKEN={{.TwitterToken}} - NO_AUTH={{.NoAuth}}{{if .VHost}} - VIRTUAL_HOST={{.VHost}} - VIRTUAL_PORT=8080{{end}} @@ -103,6 +105,7 @@ func deployFaucet(client *sshClient, network string, bootnodes []string, config "FaucetMinutes": config.minutes, "FaucetTiers": config.tiers, "NoAuth": config.noauth, + "TwitterToken": config.twitterToken, }) files[filepath.Join(workdir, "Dockerfile")] = dockerfile.Bytes() @@ -120,6 +123,7 @@ func deployFaucet(client *sshClient, network string, bootnodes []string, config "FaucetMinutes": config.minutes, "FaucetTiers": config.tiers, "NoAuth": config.noauth, + "TwitterToken": config.twitterToken, }) files[filepath.Join(workdir, "docker-compose.yaml")] = composefile.Bytes() @@ -152,6 +156,7 @@ type faucetInfos struct { noauth bool captchaToken string captchaSecret string + twitterToken string } // Report converts the typed struct into a plain string->string map, containing @@ -165,6 +170,7 @@ func (info *faucetInfos) Report() map[string]string { "Funding cooldown (base tier)": fmt.Sprintf("%d mins", info.minutes), "Funding tiers": strconv.Itoa(info.tiers), "Captcha protection": fmt.Sprintf("%v", info.captchaToken != ""), + "Using Twitter API": fmt.Sprintf("%v", info.twitterToken != ""), "Ethstats username": info.node.ethstats, } if info.noauth { @@ -243,5 +249,6 @@ func checkFaucet(client *sshClient, network string) (*faucetInfos, error) { captchaToken: infos.envvars["CAPTCHA_TOKEN"], captchaSecret: infos.envvars["CAPTCHA_SECRET"], noauth: infos.envvars["NO_AUTH"] == "true", + twitterToken: infos.envvars["TWITTER_TOKEN"], }, nil } diff --git a/cmd/puppeth/wizard_faucet.go b/cmd/puppeth/wizard_faucet.go index 9f753ad68bb9..47e05cd9c106 100644 --- a/cmd/puppeth/wizard_faucet.go +++ b/cmd/puppeth/wizard_faucet.go @@ -102,6 +102,29 @@ func (w *wizard) deployFaucet() { infos.captchaSecret = w.readPassword() } } + + // Accessing the twitter api requires a bearer token, request it + if infos.twitterToken != "" { + fmt.Println() + fmt.Println("Reuse previous twitter API Bearer token (y/n)? (default = yes)") + if !w.readDefaultYesNo(true) { + infos.twitterToken = "" + } + } + if infos.twitterToken == "" { + // No previous twitter token (or old one discarded) + fmt.Println() + fmt.Println("Enable twitter API (y/n)? 
(default = no)") + if !w.readDefaultYesNo(false) { + log.Warn("The faucet will fallback to using direct calls") + } else { + // Twitter api explicitly requested, read the bearer token + fmt.Println() + fmt.Printf("What is the twitter API Bearer token?\n") + infos.twitterToken = w.readString() + } + } + // Figure out where the user wants to store the persistent data fmt.Println() if infos.node.datadir == "" { diff --git a/cmd/puppeth/wizard_genesis.go b/cmd/puppeth/wizard_genesis.go index 40327d25d226..2d014e83bca4 100644 --- a/cmd/puppeth/wizard_genesis.go +++ b/cmd/puppeth/wizard_genesis.go @@ -236,8 +236,8 @@ func (w *wizard) manageGenesis() { w.conf.Genesis.Config.IstanbulBlock = w.readDefaultBigInt(w.conf.Genesis.Config.IstanbulBlock) fmt.Println() - fmt.Printf("Which block should YOLOv1 come into effect? (default = %v)\n", w.conf.Genesis.Config.YoloV1Block) - w.conf.Genesis.Config.YoloV1Block = w.readDefaultBigInt(w.conf.Genesis.Config.YoloV1Block) + fmt.Printf("Which block should YOLOv2 come into effect? (default = %v)\n", w.conf.Genesis.Config.YoloV2Block) + w.conf.Genesis.Config.YoloV2Block = w.readDefaultBigInt(w.conf.Genesis.Config.YoloV2Block) out, _ := json.MarshalIndent(w.conf.Genesis.Config, "", " ") fmt.Printf("Chain configuration updated:\n\n%s\n", out) diff --git a/cmd/utils/customflags.go b/cmd/utils/customflags.go index 17dcba5f74fe..0a72e80349d4 100644 --- a/cmd/utils/customflags.go +++ b/cmd/utils/customflags.go @@ -192,14 +192,14 @@ func GlobalBig(ctx *cli.Context, name string) *big.Int { // Note, it has limitations, e.g. ~someuser/tmp will not be expanded func expandPath(p string) string { if strings.HasPrefix(p, "~/") || strings.HasPrefix(p, "~\\") { - if home := homeDir(); home != "" { + if home := HomeDir(); home != "" { p = home + p[1:] } } return path.Clean(os.ExpandEnv(p)) } -func homeDir() string { +func HomeDir() string { if home := os.Getenv("HOME"); home != "" { return home } diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index d47b3a2f66a2..c51d7916ca80 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -64,7 +64,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/params" pcsclite "github.com/gballet/go-libpcsclite" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" ) func init() { @@ -135,9 +135,9 @@ var ( Name: "goerli", Usage: "Görli network: pre-configured proof-of-authority test network", } - YoloV1Flag = cli.BoolFlag{ - Name: "yolov1", - Usage: "YOLOv1 network: pre-configured proof-of-authority shortlived test network.", + YoloV2Flag = cli.BoolFlag{ + Name: "yolov2", + Usage: "YOLOv2 network: pre-configured proof-of-authority shortlived test network.", } RinkebyFlag = cli.BoolFlag{ Name: "rinkeby", @@ -162,7 +162,7 @@ var ( DocRootFlag = DirectoryFlag{ Name: "docroot", Usage: "Document Root for HTTPClient file scheme", - Value: DirectoryString(homeDir()), + Value: DirectoryString(HomeDir()), } ExitWhenSyncedFlag = cli.BoolFlag{ Name: "exitwhensynced", @@ -187,7 +187,7 @@ var ( defaultSyncMode = eth.DefaultConfig.SyncMode SyncModeFlag = TextMarshalerFlag{ Name: "syncmode", - Usage: `Blockchain sync mode ("fast", "full", or "light")`, + Usage: `Blockchain sync mode ("fast", "full", "snap" or "light")`, Value: &defaultSyncMode, } GCModeFlag = cli.StringFlag{ @@ -383,6 +383,10 @@ var ( Name: "cache.noprefetch", Usage: "Disable heuristic state prefetch during block import (less CPU and disk IO, more time waiting for data)", } + CachePreimagesFlag = cli.BoolTFlag{ + Name: "cache.preimages", + 
Usage: "Enable recording the SHA3/keccak preimages of trie keys (default: true)", + } // Miner settings MiningEnabledFlag = cli.BoolFlag{ Name: "mine", @@ -454,12 +458,12 @@ var ( Name: "allow-insecure-unlock", Usage: "Allow insecure account unlocking when account-related RPCs are exposed by http", } - RPCGlobalGasCap = cli.Uint64Flag{ + RPCGlobalGasCapFlag = cli.Uint64Flag{ Name: "rpc.gascap", Usage: "Sets a cap on gas that can be used in eth_call/estimateGas (0=infinite)", Value: eth.DefaultConfig.RPCGasCap, } - RPCGlobalTxFeeCap = cli.Float64Flag{ + RPCGlobalTxFeeCapFlag = cli.Float64Flag{ Name: "rpc.txfeecap", Usage: "Sets a cap on transaction fee (in ether) that can be sent via the RPC APIs (0 = no cap)", Value: eth.DefaultConfig.RPCTxFeeCap, @@ -744,8 +748,8 @@ func MakeDataDir(ctx *cli.Context) string { if ctx.GlobalBool(GoerliFlag.Name) { return filepath.Join(path, "goerli") } - if ctx.GlobalBool(YoloV1Flag.Name) { - return filepath.Join(path, "yolo-v1") + if ctx.GlobalBool(YoloV2Flag.Name) { + return filepath.Join(path, "yolo-v2") } return path } @@ -793,9 +797,9 @@ func setBootstrapNodes(ctx *cli.Context, cfg *p2p.Config) { switch { case ctx.GlobalIsSet(BootnodesFlag.Name) || ctx.GlobalIsSet(LegacyBootnodesV4Flag.Name): if ctx.GlobalIsSet(LegacyBootnodesV4Flag.Name) { - urls = splitAndTrim(ctx.GlobalString(LegacyBootnodesV4Flag.Name)) + urls = SplitAndTrim(ctx.GlobalString(LegacyBootnodesV4Flag.Name)) } else { - urls = splitAndTrim(ctx.GlobalString(BootnodesFlag.Name)) + urls = SplitAndTrim(ctx.GlobalString(BootnodesFlag.Name)) } case ctx.GlobalBool(LegacyTestnetFlag.Name) || ctx.GlobalBool(RopstenFlag.Name): urls = params.RopstenBootnodes @@ -803,8 +807,8 @@ func setBootstrapNodes(ctx *cli.Context, cfg *p2p.Config) { urls = params.RinkebyBootnodes case ctx.GlobalBool(GoerliFlag.Name): urls = params.GoerliBootnodes - case ctx.GlobalBool(YoloV1Flag.Name): - urls = params.YoloV1Bootnodes + case ctx.GlobalBool(YoloV2Flag.Name): + urls = params.YoloV2Bootnodes case cfg.BootstrapNodes != nil: return // already set, don't apply defaults. } @@ -829,9 +833,9 @@ func setBootstrapNodesV5(ctx *cli.Context, cfg *p2p.Config) { switch { case ctx.GlobalIsSet(BootnodesFlag.Name) || ctx.GlobalIsSet(LegacyBootnodesV5Flag.Name): if ctx.GlobalIsSet(LegacyBootnodesV5Flag.Name) { - urls = splitAndTrim(ctx.GlobalString(LegacyBootnodesV5Flag.Name)) + urls = SplitAndTrim(ctx.GlobalString(LegacyBootnodesV5Flag.Name)) } else { - urls = splitAndTrim(ctx.GlobalString(BootnodesFlag.Name)) + urls = SplitAndTrim(ctx.GlobalString(BootnodesFlag.Name)) } case ctx.GlobalBool(RopstenFlag.Name): urls = params.RopstenBootnodes @@ -839,8 +843,8 @@ func setBootstrapNodesV5(ctx *cli.Context, cfg *p2p.Config) { urls = params.RinkebyBootnodes case ctx.GlobalBool(GoerliFlag.Name): urls = params.GoerliBootnodes - case ctx.GlobalBool(YoloV1Flag.Name): - urls = params.YoloV1Bootnodes + case ctx.GlobalBool(YoloV2Flag.Name): + urls = params.YoloV2Bootnodes case cfg.BootstrapNodesV5 != nil: return // already set, don't apply defaults. } @@ -877,13 +881,12 @@ func setNAT(ctx *cli.Context, cfg *p2p.Config) { } } -// splitAndTrim splits input separated by a comma +// SplitAndTrim splits input separated by a comma // and trims excessive white space from the substrings. 
-func splitAndTrim(input string) (ret []string) { +func SplitAndTrim(input string) (ret []string) { l := strings.Split(input, ",") for _, r := range l { - r = strings.TrimSpace(r) - if len(r) > 0 { + if r = strings.TrimSpace(r); r != "" { ret = append(ret, r) } } @@ -917,27 +920,27 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { } if ctx.GlobalIsSet(LegacyRPCCORSDomainFlag.Name) { - cfg.HTTPCors = splitAndTrim(ctx.GlobalString(LegacyRPCCORSDomainFlag.Name)) + cfg.HTTPCors = SplitAndTrim(ctx.GlobalString(LegacyRPCCORSDomainFlag.Name)) log.Warn("The flag --rpccorsdomain is deprecated and will be removed in the future, please use --http.corsdomain") } if ctx.GlobalIsSet(HTTPCORSDomainFlag.Name) { - cfg.HTTPCors = splitAndTrim(ctx.GlobalString(HTTPCORSDomainFlag.Name)) + cfg.HTTPCors = SplitAndTrim(ctx.GlobalString(HTTPCORSDomainFlag.Name)) } if ctx.GlobalIsSet(LegacyRPCApiFlag.Name) { - cfg.HTTPModules = splitAndTrim(ctx.GlobalString(LegacyRPCApiFlag.Name)) + cfg.HTTPModules = SplitAndTrim(ctx.GlobalString(LegacyRPCApiFlag.Name)) log.Warn("The flag --rpcapi is deprecated and will be removed in the future, please use --http.api") } if ctx.GlobalIsSet(HTTPApiFlag.Name) { - cfg.HTTPModules = splitAndTrim(ctx.GlobalString(HTTPApiFlag.Name)) + cfg.HTTPModules = SplitAndTrim(ctx.GlobalString(HTTPApiFlag.Name)) } if ctx.GlobalIsSet(LegacyRPCVirtualHostsFlag.Name) { - cfg.HTTPVirtualHosts = splitAndTrim(ctx.GlobalString(LegacyRPCVirtualHostsFlag.Name)) + cfg.HTTPVirtualHosts = SplitAndTrim(ctx.GlobalString(LegacyRPCVirtualHostsFlag.Name)) log.Warn("The flag --rpcvhosts is deprecated and will be removed in the future, please use --http.vhosts") } if ctx.GlobalIsSet(HTTPVirtualHostsFlag.Name) { - cfg.HTTPVirtualHosts = splitAndTrim(ctx.GlobalString(HTTPVirtualHostsFlag.Name)) + cfg.HTTPVirtualHosts = SplitAndTrim(ctx.GlobalString(HTTPVirtualHostsFlag.Name)) } } @@ -945,10 +948,10 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { // command line flags, returning empty if the GraphQL endpoint is disabled. 
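For reference, the trimming semantics of the now-exported helper: a minimal standalone sketch (an illustrative copy, not part of the patch) of how comma-separated flag values such as `--http.api` are normalized:

```go
package main

import (
	"fmt"
	"strings"
)

// splitAndTrim mirrors the exported SplitAndTrim above: split on commas,
// trim surrounding whitespace, and drop empty entries.
func splitAndTrim(input string) (ret []string) {
	for _, r := range strings.Split(input, ",") {
		if r = strings.TrimSpace(r); r != "" {
			ret = append(ret, r)
		}
	}
	return ret
}

func main() {
	fmt.Println(splitAndTrim(" eth, net ,, web3 ")) // [eth net web3]
}
```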
func setGraphQL(ctx *cli.Context, cfg *node.Config) { if ctx.GlobalIsSet(GraphQLCORSDomainFlag.Name) { - cfg.GraphQLCors = splitAndTrim(ctx.GlobalString(GraphQLCORSDomainFlag.Name)) + cfg.GraphQLCors = SplitAndTrim(ctx.GlobalString(GraphQLCORSDomainFlag.Name)) } if ctx.GlobalIsSet(GraphQLVirtualHostsFlag.Name) { - cfg.GraphQLVirtualHosts = splitAndTrim(ctx.GlobalString(GraphQLVirtualHostsFlag.Name)) + cfg.GraphQLVirtualHosts = SplitAndTrim(ctx.GlobalString(GraphQLVirtualHostsFlag.Name)) } } @@ -974,19 +977,19 @@ func setWS(ctx *cli.Context, cfg *node.Config) { } if ctx.GlobalIsSet(LegacyWSAllowedOriginsFlag.Name) { - cfg.WSOrigins = splitAndTrim(ctx.GlobalString(LegacyWSAllowedOriginsFlag.Name)) + cfg.WSOrigins = SplitAndTrim(ctx.GlobalString(LegacyWSAllowedOriginsFlag.Name)) log.Warn("The flag --wsorigins is deprecated and will be removed in the future, please use --ws.origins") } if ctx.GlobalIsSet(WSAllowedOriginsFlag.Name) { - cfg.WSOrigins = splitAndTrim(ctx.GlobalString(WSAllowedOriginsFlag.Name)) + cfg.WSOrigins = SplitAndTrim(ctx.GlobalString(WSAllowedOriginsFlag.Name)) } if ctx.GlobalIsSet(LegacyWSApiFlag.Name) { - cfg.WSModules = splitAndTrim(ctx.GlobalString(LegacyWSApiFlag.Name)) + cfg.WSModules = SplitAndTrim(ctx.GlobalString(LegacyWSApiFlag.Name)) log.Warn("The flag --wsapi is deprecated and will be removed in the future, please use --ws.api") } if ctx.GlobalIsSet(WSApiFlag.Name) { - cfg.WSModules = splitAndTrim(ctx.GlobalString(WSApiFlag.Name)) + cfg.WSModules = SplitAndTrim(ctx.GlobalString(WSApiFlag.Name)) } } @@ -1270,8 +1273,8 @@ func setDataDir(ctx *cli.Context, cfg *node.Config) { cfg.DataDir = filepath.Join(node.DefaultDataDir(), "rinkeby") case ctx.GlobalBool(GoerliFlag.Name) && cfg.DataDir == node.DefaultDataDir(): cfg.DataDir = filepath.Join(node.DefaultDataDir(), "goerli") - case ctx.GlobalBool(YoloV1Flag.Name) && cfg.DataDir == node.DefaultDataDir(): - cfg.DataDir = filepath.Join(node.DefaultDataDir(), "yolo-v1") + case ctx.GlobalBool(YoloV2Flag.Name) && cfg.DataDir == node.DefaultDataDir(): + cfg.DataDir = filepath.Join(node.DefaultDataDir(), "yolo-v2") } } @@ -1484,14 +1487,13 @@ func SetShhConfig(ctx *cli.Context, stack *node.Node) { // SetEthConfig applies eth-related command line flags to the config. func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { // Avoid conflicting network flags - CheckExclusive(ctx, DeveloperFlag, LegacyTestnetFlag, RopstenFlag, RinkebyFlag, GoerliFlag, YoloV1Flag) + CheckExclusive(ctx, DeveloperFlag, LegacyTestnetFlag, RopstenFlag, RinkebyFlag, GoerliFlag, YoloV2Flag) CheckExclusive(ctx, LegacyLightServFlag, LightServeFlag, SyncModeFlag, "light") CheckExclusive(ctx, DeveloperFlag, ExternalSignerFlag) // Can't use both ephemeral unlocked and external signer CheckExclusive(ctx, GCModeFlag, "archive", TxLookupLimitFlag) - // todo(rjl493456442) make it available for les server - // Ancient tx indices pruning is not available for les server now - // since light client relies on the server for transaction status query. 
- CheckExclusive(ctx, LegacyLightServFlag, LightServeFlag, TxLookupLimitFlag) + if (ctx.GlobalIsSet(LegacyLightServFlag.Name) || ctx.GlobalIsSet(LightServeFlag.Name)) && ctx.GlobalIsSet(TxLookupLimitFlag.Name) { + log.Warn("LES server cannot serve old transaction status and cannot connect below les/4 protocol version if transaction lookup index is limited") + } var ks *keystore.KeyStore if keystores := stack.AccountManager().Backends(keystore.KeyStoreType); len(keystores) > 0 { ks = keystores[0].(*keystore.KeyStore) @@ -1527,6 +1529,12 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { if ctx.GlobalIsSet(CacheNoPrefetchFlag.Name) { cfg.NoPrefetch = ctx.GlobalBool(CacheNoPrefetchFlag.Name) } + // Read the value from the flag no matter if it's set or not. + cfg.Preimages = ctx.GlobalBool(CachePreimagesFlag.Name) + if cfg.NoPruning && !cfg.Preimages { + cfg.Preimages = true + log.Info("Enabling recording of key preimages since archive mode is used") + } if ctx.GlobalIsSet(TxLookupLimitFlag.Name) { cfg.TxLookupLimit = ctx.GlobalUint64(TxLookupLimitFlag.Name) } @@ -1546,8 +1554,14 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { cfg.SnapshotCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheSnapshotFlag.Name) / 100 } if !ctx.GlobalIsSet(SnapshotFlag.Name) { - cfg.TrieCleanCache += cfg.SnapshotCache - cfg.SnapshotCache = 0 // Disabled + // If snap-sync is requested, this flag is also required + if cfg.SyncMode == downloader.SnapSync { + log.Info("Snap sync requested, enabling --snapshot") + ctx.Set(SnapshotFlag.Name, "true") + } else { + cfg.TrieCleanCache += cfg.SnapshotCache + cfg.SnapshotCache = 0 // Disabled + } } if ctx.GlobalIsSet(DocRootFlag.Name) { cfg.DocRoot = ctx.GlobalString(DocRootFlag.Name) @@ -1564,26 +1578,27 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { if ctx.GlobalIsSet(EVMInterpreterFlag.Name) { cfg.EVMInterpreter = ctx.GlobalString(EVMInterpreterFlag.Name) } - if ctx.GlobalIsSet(RPCGlobalGasCap.Name) { - cfg.RPCGasCap = ctx.GlobalUint64(RPCGlobalGasCap.Name) + if ctx.GlobalIsSet(RPCGlobalGasCapFlag.Name) { + cfg.RPCGasCap = ctx.GlobalUint64(RPCGlobalGasCapFlag.Name) } if cfg.RPCGasCap != 0 { log.Info("Set global gas cap", "cap", cfg.RPCGasCap) } else { log.Info("Global gas cap disabled") } - if ctx.GlobalIsSet(RPCGlobalTxFeeCap.Name) { - cfg.RPCTxFeeCap = ctx.GlobalFloat64(RPCGlobalTxFeeCap.Name) + if ctx.GlobalIsSet(RPCGlobalTxFeeCapFlag.Name) { + cfg.RPCTxFeeCap = ctx.GlobalFloat64(RPCGlobalTxFeeCapFlag.Name) } - if ctx.GlobalIsSet(DNSDiscoveryFlag.Name) { + if ctx.GlobalIsSet(NoDiscoverFlag.Name) { + cfg.EthDiscoveryURLs, cfg.SnapDiscoveryURLs = []string{}, []string{} + } else if ctx.GlobalIsSet(DNSDiscoveryFlag.Name) { urls := ctx.GlobalString(DNSDiscoveryFlag.Name) if urls == "" { - cfg.DiscoveryURLs = []string{} + cfg.EthDiscoveryURLs = []string{} } else { - cfg.DiscoveryURLs = splitAndTrim(urls) + cfg.EthDiscoveryURLs = SplitAndTrim(urls) } } - // Override any default configs for hard coded networks. 
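The preimage wiring above applies a single rule in both SetEthConfig and (later) MakeChain: the flag is always read, and archive mode force-enables recording. A hedged sketch of that rule, with illustrative names standing in for the relevant eth.Config fields:

```go
package main

import "fmt"

// config stands in for the relevant eth.Config fields; illustrative only.
type config struct {
	NoPruning bool // --gcmode=archive
	Preimages bool // --cache.preimages
}

// applyPreimagePolicy mirrors the rule above: always read the flag value,
// then force preimage recording on for archive nodes, which need to map
// hashed trie keys back to addresses and storage slots.
func applyPreimagePolicy(cfg *config, flagValue bool) {
	cfg.Preimages = flagValue
	if cfg.NoPruning && !cfg.Preimages {
		cfg.Preimages = true
	}
}

func main() {
	cfg := &config{NoPruning: true}
	applyPreimagePolicy(cfg, false)
	fmt.Println(cfg.Preimages) // true: archive mode forces preimages on
}
```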
switch { case ctx.GlobalBool(LegacyTestnetFlag.Name) || ctx.GlobalBool(RopstenFlag.Name): @@ -1591,24 +1606,24 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { cfg.NetworkId = 3 } cfg.Genesis = core.DefaultRopstenGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.RopstenGenesisHash) + SetDNSDiscoveryDefaults(cfg, params.RopstenGenesisHash) case ctx.GlobalBool(RinkebyFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 4 } cfg.Genesis = core.DefaultRinkebyGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.RinkebyGenesisHash) + SetDNSDiscoveryDefaults(cfg, params.RinkebyGenesisHash) case ctx.GlobalBool(GoerliFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 5 } cfg.Genesis = core.DefaultGoerliGenesisBlock() - setDNSDiscoveryDefaults(cfg, params.GoerliGenesisHash) - case ctx.GlobalBool(YoloV1Flag.Name): + SetDNSDiscoveryDefaults(cfg, params.GoerliGenesisHash) + case ctx.GlobalBool(YoloV2Flag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { - cfg.NetworkId = 133519467574833 // "yolov1" + cfg.NetworkId = 133519467574834 // "yolov2" } - cfg.Genesis = core.DefaultYoloV1GenesisBlock() + cfg.Genesis = core.DefaultYoloV2GenesisBlock() case ctx.GlobalBool(DeveloperFlag.Name): if !ctx.GlobalIsSet(NetworkIdFlag.Name) { cfg.NetworkId = 1337 @@ -1657,24 +1672,28 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { } default: if cfg.NetworkId == 1 { - setDNSDiscoveryDefaults(cfg, params.MainnetGenesisHash) + SetDNSDiscoveryDefaults(cfg, params.MainnetGenesisHash) } } } -// setDNSDiscoveryDefaults configures DNS discovery with the given URL if +// SetDNSDiscoveryDefaults configures DNS discovery with the given URL if // no URLs are set. -func setDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) { - if cfg.DiscoveryURLs != nil { +func SetDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) { + if cfg.EthDiscoveryURLs != nil { return // already set through flags/config } - protocol := "all" if cfg.SyncMode == downloader.LightSync { protocol = "les" } if url := params.KnownDNSNetwork(genesis, protocol); url != "" { - cfg.DiscoveryURLs = []string{url} + cfg.EthDiscoveryURLs = []string{url} + } + if cfg.SyncMode == downloader.SnapSync { + if url := params.KnownDNSNetwork(genesis, "snap"); url != "" { + cfg.SnapDiscoveryURLs = []string{url} + } } } @@ -1686,19 +1705,18 @@ func RegisterEthService(stack *node.Node, cfg *eth.Config) ethapi.Backend { Fatalf("Failed to register the Ethereum service: %v", err) } return backend.ApiBackend - } else { - backend, err := eth.New(stack, cfg) + } + backend, err := eth.New(stack, cfg) + if err != nil { + Fatalf("Failed to register the Ethereum service: %v", err) + } + if cfg.LightServ > 0 { + _, err := les.NewLesServer(stack, backend, cfg) if err != nil { - Fatalf("Failed to register the Ethereum service: %v", err) + Fatalf("Failed to create the LES server: %v", err) } - if cfg.LightServ > 0 { - _, err := les.NewLesServer(stack, backend, cfg) - if err != nil { - Fatalf("Failed to create the LES server: %v", err) - } - } - return backend.APIBackend } + return backend.APIBackend } // RegisterEthStatsService configures the Ethereum Stats daemon and adds it to @@ -1792,8 +1810,8 @@ func MakeGenesis(ctx *cli.Context) *core.Genesis { genesis = core.DefaultRinkebyGenesisBlock() case ctx.GlobalBool(GoerliFlag.Name): genesis = core.DefaultGoerliGenesisBlock() - case ctx.GlobalBool(YoloV1Flag.Name): - genesis = core.DefaultYoloV1GenesisBlock() + case 
ctx.GlobalBool(YoloV2Flag.Name): + genesis = core.DefaultYoloV2GenesisBlock() case ctx.GlobalBool(DeveloperFlag.Name): Fatalf("Developer chains are ephemeral") } @@ -1836,6 +1854,11 @@ func MakeChain(ctx *cli.Context, stack *node.Node, readOnly bool) (chain *core.B TrieDirtyDisabled: ctx.GlobalString(GCModeFlag.Name) == "archive", TrieTimeLimit: eth.DefaultConfig.TrieTimeout, SnapshotLimit: eth.DefaultConfig.SnapshotCache, + Preimages: ctx.GlobalBool(CachePreimagesFlag.Name), + } + if cache.TrieDirtyDisabled && !cache.Preimages { + cache.Preimages = true + log.Info("Enabling recording of key preimages since archive mode is used") } if !ctx.GlobalIsSet(SnapshotFlag.Name) { cache.SnapshotLimit = 0 // Disabled diff --git a/common/bitutil/compress_fuzz.go b/common/bitutil/compress_fuzz.go index 1b87f50edc9b..714bbcd131d5 100644 --- a/common/bitutil/compress_fuzz.go +++ b/common/bitutil/compress_fuzz.go @@ -24,7 +24,7 @@ import "bytes" // invocations. func Fuzz(data []byte) int { if len(data) == 0 { - return -1 + return 0 } if data[0]%2 == 0 { return fuzzEncode(data[1:]) @@ -39,7 +39,7 @@ func fuzzEncode(data []byte) int { if !bytes.Equal(data, proc) { panic("content mismatch") } - return 0 + return 1 } // fuzzDecode implements a go-fuzz fuzzer method to test the bit decoding and @@ -52,5 +52,5 @@ func fuzzDecode(data []byte) int { if comp := bitsetEncodeBytes(blob); !bytes.Equal(comp, data) { panic("content mismatch") } - return 0 + return 1 } diff --git a/common/bytes.go b/common/bytes.go index 634041804d0b..7827bb572e13 100644 --- a/common/bytes.go +++ b/common/bytes.go @@ -17,28 +17,9 @@ // Package common contains various helper functions. package common -import "encoding/hex" - -// ToHex returns the hex representation of b, prefixed with '0x'. -// For empty slices, the return value is "0x0". -// -// Deprecated: use hexutil.Encode instead. -func ToHex(b []byte) string { - hex := Bytes2Hex(b) - if len(hex) == 0 { - hex = "0" - } - return "0x" + hex -} - -// ToHexArray creates a array of hex-string based on []byte -func ToHexArray(b [][]byte) []string { - r := make([]string, len(b)) - for i := range b { - r[i] = ToHex(b[i]) - } - return r -} +import ( + "encoding/hex" +) // FromHex returns the bytes represented by the hexadecimal string s. // s may be prefixed with "0x". diff --git a/common/hexutil/json_test.go b/common/hexutil/json_test.go index 8a6b8643a1e9..ed7d6fad1a8e 100644 --- a/common/hexutil/json_test.go +++ b/common/hexutil/json_test.go @@ -88,7 +88,7 @@ func TestUnmarshalBytes(t *testing.T) { if !checkError(t, test.input, err, test.wantErr) { continue } - if !bytes.Equal(test.want.([]byte), []byte(v)) { + if !bytes.Equal(test.want.([]byte), v) { t.Errorf("input %s: value mismatch: got %x, want %x", test.input, &v, test.want) continue } diff --git a/common/math/big.go b/common/math/big.go index 17a57df9dc78..1af5b4d879d6 100644 --- a/common/math/big.go +++ b/common/math/big.go @@ -67,6 +67,40 @@ func (i *HexOrDecimal256) MarshalText() ([]byte, error) { return []byte(fmt.Sprintf("%#x", (*big.Int)(i))), nil } +// Decimal256 marshals big.Int as a decimal string. When unmarshalling, +// however, it accepts either "0x"-prefixed (hex encoded) or non-prefixed (decimal) input +type Decimal256 big.Int + +// NewDecimal256 creates a new Decimal256 +func NewDecimal256(x int64) *Decimal256 { + b := big.NewInt(x) + d := Decimal256(*b) + return &d +} + +// UnmarshalText implements encoding.TextUnmarshaler.
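A quick round-trip sketch of the Decimal256 semantics just documented (hex or decimal in, decimal out) before the method bodies continue below; the import path is the usual go-ethereum module path and the snippet is illustrative only:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common/math"
)

func main() {
	var d math.Decimal256
	// Hex input is accepted on the way in...
	if err := d.UnmarshalText([]byte("0xff")); err != nil {
		panic(err)
	}
	// ...but output is always decimal.
	out, _ := d.MarshalText()
	fmt.Println(string(out)) // 255
}
```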
+func (i *Decimal256) UnmarshalText(input []byte) error { + bigint, ok := ParseBig256(string(input)) + if !ok { + return fmt.Errorf("invalid hex or decimal integer %q", input) + } + *i = Decimal256(*bigint) + return nil +} + +// MarshalText implements encoding.TextMarshaler. +func (i *Decimal256) MarshalText() ([]byte, error) { + return []byte(i.String()), nil +} + +// String implements Stringer. +func (i *Decimal256) String() string { + if i == nil { + return "0" + } + return fmt.Sprintf("%#d", (*big.Int)(i)) +} + // ParseBig256 parses s as a 256 bit integer in decimal or hexadecimal syntax. // Leading zeros are accepted. The empty string parses as zero. func ParseBig256(s string) (*big.Int, bool) { diff --git a/common/types.go b/common/types.go index cdcc6c20ad53..d920e8b1f117 100644 --- a/common/types.go +++ b/common/types.go @@ -17,6 +17,7 @@ package common import ( + "bytes" "database/sql/driver" "encoding/hex" "encoding/json" @@ -84,10 +85,34 @@ func (h Hash) String() string { return h.Hex() } -// Format implements fmt.Formatter, forcing the byte slice to be formatted as is, -// without going through the stringer interface used for logging. +// Format implements fmt.Formatter. +// Hash supports the %v, %s, %q, %x, %X and %d format verbs. func (h Hash) Format(s fmt.State, c rune) { - fmt.Fprintf(s, "%"+string(c), h[:]) + hexb := make([]byte, 2+len(h)*2) + copy(hexb, "0x") + hex.Encode(hexb[2:], h[:]) + + switch c { + case 'x', 'X': + if !s.Flag('#') { + hexb = hexb[2:] + } + if c == 'X' { + hexb = bytes.ToUpper(hexb) + } + fallthrough + case 'v', 's': + s.Write(hexb) + case 'q': + q := []byte{'"'} + s.Write(q) + s.Write(hexb) + s.Write(q) + case 'd': + fmt.Fprint(s, ([len(h)]byte)(h)) + default: + fmt.Fprintf(s, "%%!%c(hash=%x)", c, h) + } } // UnmarshalText parses a hash in hex syntax. @@ -208,39 +233,72 @@ func (a Address) Hash() Hash { return BytesToHash(a[:]) } // Hex returns an EIP55-compliant hex string representation of the address. func (a Address) Hex() string { - unchecksummed := hex.EncodeToString(a[:]) + return string(a.checksumHex()) +} + +// String implements fmt.Stringer. +func (a Address) String() string { + return a.Hex() +} + +func (a *Address) checksumHex() []byte { + buf := a.hex() + + // compute checksum sha := sha3.NewLegacyKeccak256() - sha.Write([]byte(unchecksummed)) + sha.Write(buf[2:]) hash := sha.Sum(nil) - - result := []byte(unchecksummed) - for i := 0; i < len(result); i++ { - hashByte := hash[i/2] + for i := 2; i < len(buf); i++ { + hashByte := hash[(i-2)/2] if i%2 == 0 { hashByte = hashByte >> 4 } else { hashByte &= 0xf } - if result[i] > '9' && hashByte > 7 { - result[i] -= 32 + if buf[i] > '9' && hashByte > 7 { + buf[i] -= 32 } } - return "0x" + string(result) + return buf[:] } -// String implements fmt.Stringer. -func (a Address) String() string { - return a.Hex() +func (a Address) hex() []byte { + var buf [len(a)*2 + 2]byte + copy(buf[:2], "0x") + hex.Encode(buf[2:], a[:]) + return buf[:] } -// Format implements fmt.Formatter, forcing the byte slice to be formatted as is, -// without going through the stringer interface used for logging. +// Format implements fmt.Formatter. +// Address supports the %v, %s, %q, %x, %X and %d format verbs. func (a Address) Format(s fmt.State, c rune) { - fmt.Fprintf(s, "%"+string(c), a[:]) + switch c { + case 'v', 's': + s.Write(a.checksumHex()) + case 'q': + q := []byte{'"'} + s.Write(q) + s.Write(a.checksumHex()) + s.Write(q) + case 'x', 'X': + // %x disables the checksum.
+ hex := a.hex() + if !s.Flag('#') { + hex = hex[2:] + } + if c == 'X' { + hex = bytes.ToUpper(hex) + } + s.Write(hex) + case 'd': + fmt.Fprint(s, ([len(a)]byte)(a)) + default: + fmt.Fprintf(s, "%%!%c(address=%x)", c, a) + } } // SetBytes sets the address to the value of b. -// If b is larger than len(a) it will panic. +// If b is larger than len(a), b will be cropped from the left. func (a *Address) SetBytes(b []byte) { if len(b) > len(a) { b = b[len(b)-AddressLength:] diff --git a/common/types_test.go b/common/types_test.go index fffd673c6ee2..318e985f870b 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -17,8 +17,10 @@ package common import ( + "bytes" "database/sql/driver" "encoding/json" + "fmt" "math/big" "reflect" "strings" @@ -371,3 +373,167 @@ func TestAddress_Value(t *testing.T) { }) } } + +func TestAddress_Format(t *testing.T) { + b := []byte{ + 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + } + var addr Address + addr.SetBytes(b) + + tests := []struct { + name string + out string + want string + }{ + { + name: "println", + out: fmt.Sprintln(addr), + want: "0xB26f2b342AAb24BCF63ea218c6A9274D30Ab9A15\n", + }, + { + name: "print", + out: fmt.Sprint(addr), + want: "0xB26f2b342AAb24BCF63ea218c6A9274D30Ab9A15", + }, + { + name: "printf-s", + out: func() string { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "%s", addr) + return buf.String() + }(), + want: "0xB26f2b342AAb24BCF63ea218c6A9274D30Ab9A15", + }, + { + name: "printf-q", + out: fmt.Sprintf("%q", addr), + want: `"0xB26f2b342AAb24BCF63ea218c6A9274D30Ab9A15"`, + }, + { + name: "printf-x", + out: fmt.Sprintf("%x", addr), + want: "b26f2b342aab24bcf63ea218c6a9274d30ab9a15", + }, + { + name: "printf-X", + out: fmt.Sprintf("%X", addr), + want: "B26F2B342AAB24BCF63EA218C6A9274D30AB9A15", + }, + { + name: "printf-#x", + out: fmt.Sprintf("%#x", addr), + want: "0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15", + }, + { + name: "printf-v", + out: fmt.Sprintf("%v", addr), + want: "0xB26f2b342AAb24BCF63ea218c6A9274D30Ab9A15", + }, + // The original default formatter for byte slice + { + name: "printf-d", + out: fmt.Sprintf("%d", addr), + want: "[178 111 43 52 42 171 36 188 246 62 162 24 198 169 39 77 48 171 154 21]", + }, + // Invalid format char. 
+ { + name: "printf-t", + out: fmt.Sprintf("%t", addr), + want: "%!t(address=b26f2b342aab24bcf63ea218c6a9274d30ab9a15)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.out != tt.want { + t.Errorf("%s does not render as expected:\n got %s\nwant %s", tt.name, tt.out, tt.want) + } + }) + } +} + +func TestHash_Format(t *testing.T) { + var hash Hash + hash.SetBytes([]byte{ + 0xb2, 0x6f, 0x2b, 0x34, 0x2a, 0xab, 0x24, 0xbc, 0xf6, 0x3e, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + 0xa2, 0x18, 0xc6, 0xa9, 0x27, 0x4d, 0x30, 0xab, 0x9a, 0x15, + 0x10, 0x00, + }) + + tests := []struct { + name string + out string + want string + }{ + { + name: "println", + out: fmt.Sprintln(hash), + want: "0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15a218c6a9274d30ab9a151000\n", + }, + { + name: "print", + out: fmt.Sprint(hash), + want: "0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15a218c6a9274d30ab9a151000", + }, + { + name: "printf-s", + out: func() string { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "%s", hash) + return buf.String() + }(), + want: "0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15a218c6a9274d30ab9a151000", + }, + { + name: "printf-q", + out: fmt.Sprintf("%q", hash), + want: `"0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15a218c6a9274d30ab9a151000"`, + }, + { + name: "printf-x", + out: fmt.Sprintf("%x", hash), + want: "b26f2b342aab24bcf63ea218c6a9274d30ab9a15a218c6a9274d30ab9a151000", + }, + { + name: "printf-X", + out: fmt.Sprintf("%X", hash), + want: "B26F2B342AAB24BCF63EA218C6A9274D30AB9A15A218C6A9274D30AB9A151000", + }, + { + name: "printf-#x", + out: fmt.Sprintf("%#x", hash), + want: "0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15a218c6a9274d30ab9a151000", + }, + { + name: "printf-#X", + out: fmt.Sprintf("%#X", hash), + want: "0XB26F2B342AAB24BCF63EA218C6A9274D30AB9A15A218C6A9274D30AB9A151000", + }, + { + name: "printf-v", + out: fmt.Sprintf("%v", hash), + want: "0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15a218c6a9274d30ab9a151000", + }, + // The original default formatter for byte slice + { + name: "printf-d", + out: fmt.Sprintf("%d", hash), + want: "[178 111 43 52 42 171 36 188 246 62 162 24 198 169 39 77 48 171 154 21 162 24 198 169 39 77 48 171 154 21 16 0]", + }, + // Invalid format char. + { + name: "printf-t", + out: fmt.Sprintf("%t", hash), + want: "%!t(hash=b26f2b342aab24bcf63ea218c6a9274d30ab9a15a218c6a9274d30ab9a151000)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.out != tt.want { + t.Errorf("%s does not render as expected:\n got %s\nwant %s", tt.name, tt.out, tt.want) + } + }) + } +} diff --git a/consensus/clique/clique.go b/consensus/clique/clique.go index 02f2451133e4..6c667804d8df 100644 --- a/consensus/clique/clique.go +++ b/consensus/clique/clique.go @@ -520,7 +520,7 @@ func (c *Clique) Prepare(chain consensus.ChainHeaderReader, header *types.Header c.lock.RUnlock() } // Set the correct difficulty - header.Difficulty = CalcDifficulty(snap, c.signer) + header.Difficulty = calcDifficulty(snap, c.signer) // Ensure the extra data has all its components if len(header.Extra) < extraVanity { @@ -561,9 +561,8 @@ func (c *Clique) Finalize(chain consensus.ChainHeaderReader, header *types.Heade // FinalizeAndAssemble implements consensus.Engine, ensuring no uncles are set, // nor block rewards given, and returns the final block. 
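Stepping back to the Format implementations exercised by the tests above: a small usage sketch of the new verb handling (assuming go-ethereum's common package; output comments follow the test expectations):

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
)

func main() {
	addr := common.HexToAddress("0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15")
	fmt.Printf("%v\n", addr)  // 0xB26f2b342AAb24BCF63ea218c6A9274D30Ab9A15 (EIP-55 checksummed)
	fmt.Printf("%x\n", addr)  // b26f2b342aab24bcf63ea218c6a9274d30ab9a15 (no checksum, no prefix)
	fmt.Printf("%#x\n", addr) // 0xb26f2b342aab24bcf63ea218c6a9274d30ab9a15
	fmt.Printf("%q\n", addr)  // "0xB26f2b342AAb24BCF63ea218c6A9274D30Ab9A15"
}
```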
func (c *Clique) FinalizeAndAssemble(chain consensus.ChainHeaderReader, header *types.Header, state *state.StateDB, txs []*types.Transaction, uncles []*types.Header, receipts []*types.Receipt) (*types.Block, error) { - // No block rewards in PoA, so the state remains as is and uncles are dropped - header.Root = state.IntermediateRoot(chain.Config().IsEIP158(header.Number)) - header.UncleHash = types.CalcUncleHash(nil) + // Finalize block + c.Finalize(chain, header, state, txs, uncles) // Assemble and return the final block for sealing return types.NewBlock(header, txs, nil, receipts, new(trie.Trie)), nil @@ -652,20 +651,18 @@ func (c *Clique) Seal(chain consensus.ChainHeaderReader, block *types.Block, res } // CalcDifficulty is the difficulty adjustment algorithm. It returns the difficulty -that a new block should have based on the previous blocks in the chain and the -// current signer. +that a new block should have: +// * DIFF_NOTURN(1) if BLOCK_NUMBER % SIGNER_COUNT != SIGNER_INDEX +// * DIFF_INTURN(2) if BLOCK_NUMBER % SIGNER_COUNT == SIGNER_INDEX func (c *Clique) CalcDifficulty(chain consensus.ChainHeaderReader, time uint64, parent *types.Header) *big.Int { snap, err := c.snapshot(chain, parent.Number.Uint64(), parent.Hash(), nil) if err != nil { return nil } - return CalcDifficulty(snap, c.signer) + return calcDifficulty(snap, c.signer) } -// CalcDifficulty is the difficulty adjustment algorithm. It returns the difficulty -// that a new block should have based on the previous blocks in the chain and the -// current signer. -func CalcDifficulty(snap *Snapshot, signer common.Address) *big.Int { +func calcDifficulty(snap *Snapshot, signer common.Address) *big.Int { if snap.inturn(snap.Number+1, signer) { return new(big.Int).Set(diffInTurn) } diff --git a/consensus/clique/snapshot_test.go b/consensus/clique/snapshot_test.go index 3890fc51dd58..039ba919bf8d 100644 --- a/consensus/clique/snapshot_test.go +++ b/consensus/clique/snapshot_test.go @@ -423,7 +423,7 @@ func TestClique(t *testing.T) { }) // Iterate through the blocks and seal them individually for j, block := range blocks { - // Geth the header and prepare it for signing + // Get the header and prepare it for signing header := block.Header() if j > 0 { header.ParentHash = blocks[j-1].Hash() diff --git a/consensus/ethash/algorithm.go b/consensus/ethash/algorithm.go index d6c871092ed3..80379597e241 100644 --- a/consensus/ethash/algorithm.go +++ b/consensus/ethash/algorithm.go @@ -151,10 +151,12 @@ func generateCache(dest []uint32, epoch uint64, seed []byte) { logFn("Generated ethash verification cache", "elapsed", common.PrettyDuration(elapsed)) }() // Convert our destination slice to a byte buffer - header := *(*reflect.SliceHeader)(unsafe.Pointer(&dest)) - header.Len *= 4 - header.Cap *= 4 - cache := *(*[]byte)(unsafe.Pointer(&header)) + var cache []byte + cacheHdr := (*reflect.SliceHeader)(unsafe.Pointer(&cache)) + dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dest)) + cacheHdr.Data = dstHdr.Data + cacheHdr.Len = dstHdr.Len * 4 + cacheHdr.Cap = dstHdr.Cap * 4 // Calculate the number of theoretical rows (we'll store in one buffer nonetheless) size := uint64(len(cache)) @@ -283,10 +285,12 @@ func generateDataset(dest []uint32, epoch uint64, cache []uint32) { swapped := !isLittleEndian() // Convert our destination slice to a byte buffer - header := *(*reflect.SliceHeader)(unsafe.Pointer(&dest)) - header.Len *= 4 - header.Cap *= 4 - dataset := *(*[]byte)(unsafe.Pointer(&header)) + var dataset []byte + datasetHdr =
(*reflect.SliceHeader)(unsafe.Pointer(&dataset)) + destHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dest)) + datasetHdr.Data = destHdr.Data + datasetHdr.Len = destHdr.Len * 4 + datasetHdr.Cap = destHdr.Cap * 4 // Generate the dataset on many goroutines since it takes a while threads := runtime.NumCPU() @@ -295,7 +299,7 @@ func generateDataset(dest []uint32, epoch uint64, cache []uint32) { var pend sync.WaitGroup pend.Add(threads) - var progress uint32 + var progress uint64 for i := 0; i < threads; i++ { go func(id int) { defer pend.Done() @@ -304,23 +308,23 @@ func generateDataset(dest []uint32, epoch uint64, cache []uint32) { keccak512 := makeHasher(sha3.NewLegacyKeccak512()) // Calculate the data segment this thread should generate - batch := uint32((size + hashBytes*uint64(threads) - 1) / (hashBytes * uint64(threads))) - first := uint32(id) * batch + batch := (size + hashBytes*uint64(threads) - 1) / (hashBytes * uint64(threads)) + first := uint64(id) * batch limit := first + batch - if limit > uint32(size/hashBytes) { - limit = uint32(size / hashBytes) + if limit > size/hashBytes { + limit = size / hashBytes } // Calculate the dataset segment - percent := uint32(size / hashBytes / 100) + percent := size / hashBytes / 100 for index := first; index < limit; index++ { - item := generateDatasetItem(cache, index, keccak512) + item := generateDatasetItem(cache, uint32(index), keccak512) if swapped { swap(item) } copy(dataset[index*hashBytes:], item) - if status := atomic.AddUint32(&progress, 1); status%percent == 0 { - logger.Info("Generating DAG in progress", "percentage", uint64(status*100)/(size/hashBytes), "elapsed", common.PrettyDuration(time.Since(start))) + if status := atomic.AddUint64(&progress, 1); status%percent == 0 { + logger.Info("Generating DAG in progress", "percentage", (status*100)/(size/hashBytes), "elapsed", common.PrettyDuration(time.Since(start))) } } }(i) diff --git a/consensus/ethash/consensus.go b/consensus/ethash/consensus.go index bdc02098afdf..ae0905ee3ae0 100644 --- a/consensus/ethash/consensus.go +++ b/consensus/ethash/consensus.go @@ -485,6 +485,11 @@ func calcDifficultyFrontier(time uint64, parent *types.Header) *big.Int { return diff } +// Exported for fuzzing +var FrontierDifficultyCalulator = calcDifficultyFrontier +var HomesteadDifficultyCalulator = calcDifficultyHomestead +var DynamicDifficultyCalculator = makeDifficultyCalculator + // VerifySeal implements consensus.Engine, checking whether the given block satisfies // the PoW difficulty requirements. func (ethash *Ethash) VerifySeal(chain consensus.ChainHeaderReader, header *types.Header) error { @@ -579,9 +584,8 @@ func (ethash *Ethash) Finalize(chain consensus.ChainHeaderReader, header *types. // FinalizeAndAssemble implements consensus.Engine, accumulating the block and // uncle rewards, setting the final state and assembling the block. 
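Before the assembly code continues below, a standalone sketch of the []uint32-to-[]byte reinterpretation pattern the cache/dataset generators above now use (header surgery on an explicitly declared destination slice, avoiding a copy; the helper name is illustrative and the byte order shown assumes a little-endian host):

```go
package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

// byteView reinterprets a []uint32 as a []byte without copying, mirroring
// the SliceHeader construction in generateCache/generateDataset.
func byteView(words []uint32) []byte {
	var view []byte
	viewHdr := (*reflect.SliceHeader)(unsafe.Pointer(&view))
	srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&words))
	viewHdr.Data = srcHdr.Data
	viewHdr.Len = srcHdr.Len * 4
	viewHdr.Cap = srcHdr.Cap * 4
	return view
}

func main() {
	words := []uint32{0x01020304}
	fmt.Printf("% x\n", byteView(words)) // 04 03 02 01 on little-endian hosts
}
```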
func (ethash *Ethash) FinalizeAndAssemble(chain consensus.ChainHeaderReader, header *types.Header, state *state.StateDB, txs []*types.Transaction, uncles []*types.Header, receipts []*types.Receipt) (*types.Block, error) { - // Accumulate any block and uncle rewards and commit the final state root - accumulateRewards(chain.Config(), state, header, uncles) - header.Root = state.IntermediateRoot(chain.Config().IsEIP158(header.Number)) + // Finalize block + ethash.Finalize(chain, header, state, txs, uncles) // Header seems complete, assemble into a block and return return types.NewBlock(header, txs, uncles, receipts, new(trie.Trie)), nil diff --git a/consensus/ethash/consensus_test.go b/consensus/ethash/consensus_test.go index 675737d9e1aa..6f6dc79fd86b 100644 --- a/consensus/ethash/consensus_test.go +++ b/consensus/ethash/consensus_test.go @@ -17,12 +17,15 @@ package ethash import ( + "encoding/binary" "encoding/json" "math/big" + "math/rand" "os" "path/filepath" "testing" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/params" @@ -84,3 +87,102 @@ func TestCalcDifficulty(t *testing.T) { } } } + +func randSlice(min, max uint32) []byte { + var b = make([]byte, 4) + rand.Read(b) + a := binary.LittleEndian.Uint32(b) + size := min + a%(max-min) + out := make([]byte, size) + rand.Read(out) + return out +} + +func TestDifficultyCalculators(t *testing.T) { + rand.Seed(2) + for i := 0; i < 5000; i++ { + // 1 to 300 seconds diff + var timeDelta = uint64(1 + rand.Uint32()%3000) + diffBig := big.NewInt(0).SetBytes(randSlice(2, 10)) + if diffBig.Cmp(params.MinimumDifficulty) < 0 { + diffBig.Set(params.MinimumDifficulty) + } + //rand.Read(difficulty) + header := &types.Header{ + Difficulty: diffBig, + Number: new(big.Int).SetUint64(rand.Uint64() % 50_000_000), + Time: rand.Uint64() - timeDelta, + } + if rand.Uint32()&1 == 0 { + header.UncleHash = types.EmptyUncleHash + } + bombDelay := new(big.Int).SetUint64(rand.Uint64() % 50_000_000) + for i, pair := range []struct { + bigFn func(time uint64, parent *types.Header) *big.Int + u256Fn func(time uint64, parent *types.Header) *big.Int + }{ + {FrontierDifficultyCalulator, CalcDifficultyFrontierU256}, + {HomesteadDifficultyCalulator, CalcDifficultyHomesteadU256}, + {DynamicDifficultyCalculator(bombDelay), MakeDifficultyCalculatorU256(bombDelay)}, + } { + time := header.Time + timeDelta + want := pair.bigFn(time, header) + have := pair.u256Fn(time, header) + if want.BitLen() > 256 { + continue + } + if want.Cmp(have) != 0 { + t.Fatalf("pair %d: want %x have %x\nparent.Number: %x\np.Time: %x\nc.Time: %x\nBombdelay: %v\n", i, want, have, + header.Number, header.Time, time, bombDelay) + } + } + } +} + +func BenchmarkDifficultyCalculator(b *testing.B) { + x1 := makeDifficultyCalculator(big.NewInt(1000000)) + x2 := MakeDifficultyCalculatorU256(big.NewInt(1000000)) + h := &types.Header{ + ParentHash: common.Hash{}, + UncleHash: types.EmptyUncleHash, + Difficulty: big.NewInt(0xffffff), + Number: big.NewInt(500000), + Time: 1000000, + } + b.Run("big-frontier", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + calcDifficultyFrontier(1000014, h) + } + }) + b.Run("u256-frontier", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + CalcDifficultyFrontierU256(1000014, h) + } + }) + b.Run("big-homestead", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + calcDifficultyHomestead(1000014, h) + } + }) 
+ b.Run("u256-homestead", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + CalcDifficultyHomesteadU256(1000014, h) + } + }) + b.Run("big-generic", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + x1(1000014, h) + } + }) + b.Run("u256-generic", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + x2(1000014, h) + } + }) +} diff --git a/consensus/ethash/difficulty.go b/consensus/ethash/difficulty.go new file mode 100644 index 000000000000..59c4ac74191f --- /dev/null +++ b/consensus/ethash/difficulty.go @@ -0,0 +1,193 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package ethash + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/core/types" + "github.com/holiman/uint256" +) + +const ( + // frontierDurationLimit is for Frontier: + // The decision boundary on the blocktime duration used to determine + // whether difficulty should go up or down. + frontierDurationLimit = 13 + // minimumDifficulty The minimum that the difficulty may ever be. + minimumDifficulty = 131072 + // expDiffPeriod is the exponential difficulty period + expDiffPeriodUint = 100000 + // difficultyBoundDivisorBitShift is the bound divisor of the difficulty (2048), + // This constant is the right-shifts to use for the division. + difficultyBoundDivisor = 11 +) + +// CalcDifficultyFrontierU256 is the difficulty adjustment algorithm. It returns the +// difficulty that a new block should have when created at time given the parent +// block's time and difficulty. The calculation uses the Frontier rules. +func CalcDifficultyFrontierU256(time uint64, parent *types.Header) *big.Int { + /* + Algorithm + block_diff = pdiff + pdiff / 2048 * (1 if time - ptime < 13 else -1) + int(2^((num // 100000) - 2)) + + Where: + - pdiff = parent.difficulty + - ptime = parent.time + - time = block.timestamp + - num = block.number + */ + + pDiff := uint256.NewInt() + pDiff.SetFromBig(parent.Difficulty) // pDiff: pdiff + adjust := pDiff.Clone() + adjust.Rsh(adjust, difficultyBoundDivisor) // adjust: pDiff / 2048 + + if time-parent.Time < frontierDurationLimit { + pDiff.Add(pDiff, adjust) + } else { + pDiff.Sub(pDiff, adjust) + } + if pDiff.LtUint64(minimumDifficulty) { + pDiff.SetUint64(minimumDifficulty) + } + // 'pdiff' now contains: + // pdiff + pdiff / 2048 * (1 if time - ptime < 13 else -1) + + if periodCount := (parent.Number.Uint64() + 1) / expDiffPeriodUint; periodCount > 1 { + // diff = diff + 2^(periodCount - 2) + expDiff := adjust.SetOne() + expDiff.Lsh(expDiff, uint(periodCount-2)) // expdiff: 2 ^ (periodCount -2) + pDiff.Add(pDiff, expDiff) + } + return pDiff.ToBig() +} + +// CalcDifficultyHomesteadU256 is the difficulty adjustment algorithm. 
It returns + the difficulty that a new block should have when created at time given the + parent block's time and difficulty. The calculation uses the Homestead rules. +func CalcDifficultyHomesteadU256(time uint64, parent *types.Header) *big.Int { + /* + https://github.com/ethereum/EIPs/blob/master/EIPS/eip-2.md + Algorithm: + block_diff = pdiff + pdiff / 2048 * max(1 - (time - ptime) / 10, -99) + 2^int((num / 100000) - 2) + + Our modification, to use unsigned ints: + block_diff = pdiff - pdiff / 2048 * max((time - ptime) / 10 - 1, 99) + 2^int((num / 100000) - 2) + + Where: + - pdiff = parent.difficulty + - ptime = parent.time + - time = block.timestamp + - num = block.number + */ + + pDiff := uint256.NewInt() + pDiff.SetFromBig(parent.Difficulty) // pDiff: pdiff + adjust := pDiff.Clone() + adjust.Rsh(adjust, difficultyBoundDivisor) // adjust: pDiff / 2048 + + x := (time - parent.Time) / 10 // (time - ptime) / 10 + var neg = true + if x == 0 { + x = 1 + neg = false + } else if x >= 100 { + x = 99 + } else { + x = x - 1 + } + z := new(uint256.Int).SetUint64(x) + adjust.Mul(adjust, z) // adjust: (pdiff / 2048) * max((time - ptime) / 10 - 1, 99) + if neg { + pDiff.Sub(pDiff, adjust) // pdiff - pdiff / 2048 * max((time - ptime) / 10 - 1, 99) + } else { + pDiff.Add(pDiff, adjust) // pdiff + pdiff / 2048 * max((time - ptime) / 10 - 1, 99) + } + if pDiff.LtUint64(minimumDifficulty) { + pDiff.SetUint64(minimumDifficulty) + } + // for the exponential factor, a.k.a "the bomb" + // diff = diff + 2^(periodCount - 2) + if periodCount := (1 + parent.Number.Uint64()) / expDiffPeriodUint; periodCount > 1 { + expFactor := adjust.Lsh(adjust.SetOne(), uint(periodCount-2)) + pDiff.Add(pDiff, expFactor) + } + return pDiff.ToBig() +} + +// MakeDifficultyCalculatorU256 creates a difficultyCalculator with the given bomb-delay. +// The difficulty is calculated with Byzantium rules, which differ from Homestead in +// how uncles affect the calculation +func MakeDifficultyCalculatorU256(bombDelay *big.Int) func(time uint64, parent *types.Header) *big.Int { + // Note, the calculations below look at the parent number, which is 1 below + // the block number.
Thus we remove one from the delay given + bombDelayFromParent := bombDelay.Uint64() - 1 + return func(time uint64, parent *types.Header) *big.Int { + /* + https://github.com/ethereum/EIPs/issues/100 + pDiff = parent.difficulty + BLOCK_DIFF_FACTOR = 9 + a = pDiff + (pDiff // BLOCK_DIFF_FACTOR) * adj_factor + b = min(parent.difficulty, MIN_DIFF) + child_diff = max(a,b ) + */ + x := (time - parent.Time) / 9 // (block_timestamp - parent_timestamp) // 9 + c := uint64(1) // if parent.unclehash == emptyUncleHash + if parent.UncleHash != types.EmptyUncleHash { + c = 2 + } + xNeg := x >= c + if xNeg { + // x is now _negative_ adjustment factor + x = x - c // - ( (t-p)/9 - (2 or 1) ) + } else { + x = c - x // (2 or 1) - (t-p)/9 + } + if x > 99 { + x = 99 // min(x, 99) + } + // parent_diff + (parent_diff / 2048 * max((2 if len(parent.uncles) else 1) - ((timestamp - parent.timestamp) // 9), -99)) + y := new(uint256.Int) + y.SetFromBig(parent.Difficulty) // y: p_diff + pDiff := y.Clone() // pdiff: p_diff + z := new(uint256.Int).SetUint64(x) //z : +-adj_factor (either pos or negative) + y.Rsh(y, difficultyBoundDivisor) // y: p_diff / 2048 + z.Mul(y, z) // z: (p_diff / 2048 ) * (+- adj_factor) + + if xNeg { + y.Sub(pDiff, z) // y: parent_diff - parent_diff/2048 * adjustment_factor + } else { + y.Add(pDiff, z) // y: parent_diff + parent_diff/2048 * adjustment_factor + } + // minimum difficulty can ever be (before exponential factor) + if y.LtUint64(minimumDifficulty) { + y.SetUint64(minimumDifficulty) + } + // calculate a fake block number for the ice-age delay + // Specification: https://eips.ethereum.org/EIPS/eip-1234 + var pNum = parent.Number.Uint64() + if pNum >= bombDelayFromParent { + if fakeBlockNumber := pNum - bombDelayFromParent; fakeBlockNumber >= 2*expDiffPeriodUint { + z.SetOne() + z.Lsh(z, uint(fakeBlockNumber/expDiffPeriodUint-2)) + y.Add(z, y) + } + } + return y.ToBig() + } +} diff --git a/consensus/ethash/ethash.go b/consensus/ethash/ethash.go index aa3f002c0dac..550d99893d89 100644 --- a/consensus/ethash/ethash.go +++ b/consensus/ethash/ethash.go @@ -33,7 +33,7 @@ import ( "time" "unsafe" - mmap "github.com/edsrzf/mmap-go" + "github.com/edsrzf/mmap-go" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" diff --git a/console/bridge.go b/console/bridge.go index 9303496b287a..21ef0e8e7b90 100644 --- a/console/bridge.go +++ b/console/bridge.go @@ -144,15 +144,14 @@ func (b *bridge) OpenWallet(call jsre.Call) (goja.Value, error) { if val, err = openWallet(goja.Null(), wallet, passwd); err != nil { if !strings.HasSuffix(err.Error(), scwallet.ErrPINNeeded.Error()) { return nil, err - } else { - // PIN input requested, fetch from the user and call open again - input, err := b.prompter.PromptPassword("Please enter current PIN: ") - if err != nil { - return nil, err - } - if val, err = openWallet(goja.Null(), wallet, call.VM.ToValue(input)); err != nil { - return nil, err - } + } + // PIN input requested, fetch from the user and call open again + input, err := b.prompter.PromptPassword("Please enter current PIN: ") + if err != nil { + return nil, err + } + if val, err = openWallet(goja.Null(), wallet, call.VM.ToValue(input)); err != nil { + return nil, err } } @@ -353,14 +352,14 @@ func (b *bridge) SleepBlocks(call jsre.Call) (goja.Value, error) { } // Poll the current block number until it changes or the timeout is reached.
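To make the u256 difficulty rules above concrete, here is a worked Frontier example, re-implemented with math/big to mirror the pseudocode (an illustrative re-implementation under those stated rules, not the geth function itself; the bomb term is included for completeness):

```go
package main

import (
	"fmt"
	"math/big"
)

// frontierDiff mirrors the Frontier pseudocode above:
// pdiff + pdiff/2048 * (1 if time-ptime < 13 else -1) + 2^((pnum+1)//100000 - 2)
func frontierDiff(pdiff *big.Int, ptime, time, pnum uint64) *big.Int {
	adjust := new(big.Int).Rsh(pdiff, 11) // pdiff / 2048
	diff := new(big.Int).Set(pdiff)
	if time-ptime < 13 {
		diff.Add(diff, adjust)
	} else {
		diff.Sub(diff, adjust)
	}
	if floor := big.NewInt(131072); diff.Cmp(floor) < 0 {
		diff.Set(floor) // minimumDifficulty
	}
	if period := (pnum + 1) / 100000; period > 1 {
		bomb := new(big.Int).Lsh(big.NewInt(1), uint(period-2))
		diff.Add(diff, bomb) // exponential ice-age component
	}
	return diff
}

func main() {
	// Parent at minimum difficulty, child sealed 10s later: 131072 + 64.
	fmt.Println(frontierDiff(big.NewInt(131072), 1000, 1010, 0)) // 131136
}
```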
- var ( - deadline = time.Now().Add(time.Duration(sleep) * time.Second) - lastNumber = ^hexutil.Uint64(0) - ) + deadline := time.Now().Add(time.Duration(sleep) * time.Second) + var lastNumber hexutil.Uint64 + if err := b.client.Call(&lastNumber, "eth_blockNumber"); err != nil { + return nil, err + } for time.Now().Before(deadline) { var number hexutil.Uint64 - err := b.client.Call(&number, "eth_blockNumber") - if err != nil { + if err := b.client.Call(&number, "eth_blockNumber"); err != nil { return nil, err } if number != lastNumber { diff --git a/console/console.go b/console/console.go index 1dcad3065e6e..ae9f28da0486 100644 --- a/console/console.go +++ b/console/console.go @@ -324,6 +324,7 @@ func (c *Console) Welcome() { sort.Strings(modules) message += " modules: " + strings.Join(modules, " ") + "\n" } + message += "\nTo exit, press ctrl-d" fmt.Fprintln(c.printer, message) } @@ -372,7 +373,7 @@ func (c *Console) Interactive() { return case err := <-inputErr: - if err == liner.ErrPromptAborted && indents > 0 { + if err == liner.ErrPromptAborted { // When prompting for multi-line input, the first Ctrl-C resets // the multi-line state. prompt, indents, input = c.prompt, 0, "" diff --git a/contracts/checkpointoracle/contract/oracle.go b/contracts/checkpointoracle/contract/oracle.go index 998ccb93c264..a4a308f5c524 100644 --- a/contracts/checkpointoracle/contract/oracle.go +++ b/contracts/checkpointoracle/contract/oracle.go @@ -27,10 +27,17 @@ var ( ) // CheckpointOracleABI is the input ABI used to generate the binding from. -const CheckpointOracleABI = "[{\"constant\":true,\"inputs\":[],\"name\":\"GetAllAdmin\",\"outputs\":[{\"name\":\"\",\"type\":\"address[]\"}],\"payable\":false,\"stateMutability\":\"view\",\"type\":\"function\"},{\"constant\":true,\"inputs\":[],\"name\":\"GetLatestCheckpoint\",\"outputs\":[{\"name\":\"\",\"type\":\"uint64\"},{\"name\":\"\",\"type\":\"bytes32\"},{\"name\":\"\",\"type\":\"uint256\"}],\"payable\":false,\"stateMutability\":\"view\",\"type\":\"function\"},{\"constant\":false,\"inputs\":[{\"name\":\"_recentNumber\",\"type\":\"uint256\"},{\"name\":\"_recentHash\",\"type\":\"bytes32\"},{\"name\":\"_hash\",\"type\":\"bytes32\"},{\"name\":\"_sectionIndex\",\"type\":\"uint64\"},{\"name\":\"v\",\"type\":\"uint8[]\"},{\"name\":\"r\",\"type\":\"bytes32[]\"},{\"name\":\"s\",\"type\":\"bytes32[]\"}],\"name\":\"SetCheckpoint\",\"outputs\":[{\"name\":\"\",\"type\":\"bool\"}],\"payable\":false,\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"name\":\"_adminlist\",\"type\":\"address[]\"},{\"name\":\"_sectionSize\",\"type\":\"uint256\"},{\"name\":\"_processConfirms\",\"type\":\"uint256\"},{\"name\":\"_threshold\",\"type\":\"uint256\"}],\"payable\":false,\"stateMutability\":\"nonpayable\",\"type\":\"constructor\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":true,\"name\":\"index\",\"type\":\"uint64\"},{\"indexed\":false,\"name\":\"checkpointHash\",\"type\":\"bytes32\"},{\"indexed\":false,\"name\":\"v\",\"type\":\"uint8\"},{\"indexed\":false,\"name\":\"r\",\"type\":\"bytes32\"},{\"indexed\":false,\"name\":\"s\",\"type\":\"bytes32\"}],\"name\":\"NewCheckpointVote\",\"type\":\"event\"}]" +const CheckpointOracleABI = 
"[{\"inputs\":[{\"internalType\":\"address[]\",\"name\":\"_adminlist\",\"type\":\"address[]\"},{\"internalType\":\"uint256\",\"name\":\"_sectionSize\",\"type\":\"uint256\"},{\"internalType\":\"uint256\",\"name\":\"_processConfirms\",\"type\":\"uint256\"},{\"internalType\":\"uint256\",\"name\":\"_threshold\",\"type\":\"uint256\"}],\"stateMutability\":\"nonpayable\",\"type\":\"constructor\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":true,\"internalType\":\"uint64\",\"name\":\"index\",\"type\":\"uint64\"},{\"indexed\":false,\"internalType\":\"bytes32\",\"name\":\"checkpointHash\",\"type\":\"bytes32\"},{\"indexed\":false,\"internalType\":\"uint8\",\"name\":\"v\",\"type\":\"uint8\"},{\"indexed\":false,\"internalType\":\"bytes32\",\"name\":\"r\",\"type\":\"bytes32\"},{\"indexed\":false,\"internalType\":\"bytes32\",\"name\":\"s\",\"type\":\"bytes32\"}],\"name\":\"NewCheckpointVote\",\"type\":\"event\"},{\"inputs\":[],\"name\":\"GetAllAdmin\",\"outputs\":[{\"internalType\":\"address[]\",\"name\":\"\",\"type\":\"address[]\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[],\"name\":\"GetLatestCheckpoint\",\"outputs\":[{\"internalType\":\"uint64\",\"name\":\"\",\"type\":\"uint64\"},{\"internalType\":\"bytes32\",\"name\":\"\",\"type\":\"bytes32\"},{\"internalType\":\"uint256\",\"name\":\"\",\"type\":\"uint256\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"uint256\",\"name\":\"_recentNumber\",\"type\":\"uint256\"},{\"internalType\":\"bytes32\",\"name\":\"_recentHash\",\"type\":\"bytes32\"},{\"internalType\":\"bytes32\",\"name\":\"_hash\",\"type\":\"bytes32\"},{\"internalType\":\"uint64\",\"name\":\"_sectionIndex\",\"type\":\"uint64\"},{\"internalType\":\"uint8[]\",\"name\":\"v\",\"type\":\"uint8[]\"},{\"internalType\":\"bytes32[]\",\"name\":\"r\",\"type\":\"bytes32[]\"},{\"internalType\":\"bytes32[]\",\"name\":\"s\",\"type\":\"bytes32[]\"}],\"name\":\"SetCheckpoint\",\"outputs\":[{\"internalType\":\"bool\",\"name\":\"\",\"type\":\"bool\"}],\"stateMutability\":\"nonpayable\",\"type\":\"function\"}]" + +// CheckpointOracleFuncSigs maps the 4-byte function signature to its string representation. +var CheckpointOracleFuncSigs = map[string]string{ + "45848dfc": "GetAllAdmin()", + "4d6a304c": "GetLatestCheckpoint()", + "d459fc46": "SetCheckpoint(uint256,bytes32,bytes32,uint64,uint8[],bytes32[],bytes32[])", +} // CheckpointOracleBin is the compiled bytecode used for deploying new contracts. 
-const CheckpointOracleBin = `0x608060405234801561001057600080fd5b506040516108153803806108158339818101604052608081101561003357600080fd5b81019080805164010000000081111561004b57600080fd5b8201602081018481111561005e57600080fd5b815185602082028301116401000000008211171561007b57600080fd5b505060208201516040830151606090930151919450925060005b84518110156101415760016000808784815181106100af57fe5b60200260200101516001600160a01b03166001600160a01b0316815260200190815260200160002060006101000a81548160ff02191690831515021790555060018582815181106100fc57fe5b60209081029190910181015182546001808201855560009485529290932090920180546001600160a01b0319166001600160a01b039093169290921790915501610095565b50600592909255600655600755506106b78061015e6000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c806345848dfc146100465780634d6a304c1461009e578063d459fc46146100cf575b600080fd5b61004e6102b0565b60408051602080825283518183015283519192839290830191858101910280838360005b8381101561008a578181015183820152602001610072565b505050509050019250505060405180910390f35b6100a661034f565b6040805167ffffffffffffffff9094168452602084019290925282820152519081900360600190f35b61029c600480360360e08110156100e557600080fd5b81359160208101359160408201359167ffffffffffffffff6060820135169181019060a08101608082013564010000000081111561012257600080fd5b82018360208201111561013457600080fd5b8035906020019184602083028401116401000000008311171561015657600080fd5b91908080602002602001604051908101604052809392919081815260200183836020028082843760009201919091525092959493602081019350359150506401000000008111156101a657600080fd5b8201836020820111156101b857600080fd5b803590602001918460208302840111640100000000831117156101da57600080fd5b919080806020026020016040519081016040528093929190818152602001838360200280828437600092019190915250929594936020810193503591505064010000000081111561022a57600080fd5b82018360208201111561023c57600080fd5b8035906020019184602083028401116401000000008311171561025e57600080fd5b91908080602002602001604051908101604052809392919081815260200183836020028082843760009201919091525092955061036a945050505050565b604080519115158252519081900360200190f35b6060806001805490506040519080825280602002602001820160405280156102e2578160200160208202803883390190505b50905060005b60015481101561034957600181815481106102ff57fe5b9060005260206000200160009054906101000a90046001600160a01b031682828151811061032957fe5b6001600160a01b03909216602092830291909101909101526001016102e8565b50905090565b60025460045460035467ffffffffffffffff90921691909192565b3360009081526020819052604081205460ff1661038657600080fd5b8688401461039357600080fd5b82518451146103a157600080fd5b81518451146103af57600080fd5b6006546005548660010167ffffffffffffffff1602014310156103d457506000610677565b60025467ffffffffffffffff90811690861610156103f457506000610677565b60025467ffffffffffffffff8681169116148015610426575067ffffffffffffffff8516151580610426575060035415155b1561043357506000610677565b8561044057506000610677565b60408051601960f81b6020808301919091526000602183018190523060601b60228401526001600160c01b031960c08a901b166036840152603e8084018b905284518085039091018152605e909301909352815191012090805b86518110156106715760006001848984815181106104b457fe5b60200260200101518985815181106104c857fe5b60200260200101518986815181106104dc57fe5b602002602001015160405160008152602001604052604051808581526020018460ff1660ff1681526020018381526020018281526020019450505050506020604051602081039080840390855afa15801561053b573d6000803e3d6000fd5b505060408051601f1901516001600160a01b03811660009081526020819052919091205490925060ff16905061057057600080fd5b826001600160a01b0316816001600160a01b0
3161161058e57600080fd5b8092508867ffffffffffffffff167fce51ffa16246bcaf0899f6504f473cd0114f430f566cef71ab7e03d3dde42a418b8a85815181106105ca57fe5b60200260200101518a86815181106105de57fe5b60200260200101518a87815181106105f257fe5b6020026020010151604051808581526020018460ff1660ff16815260200183815260200182815260200194505050505060405180910390a260075482600101106106685750505060048790555050436003556002805467ffffffffffffffff191667ffffffffffffffff86161790556001610677565b5060010161049a565b50600080fd5b97965050505050505056fea265627a7a723058207f6a191ce575596a2f1e907c8c0a01003d16b69fb2c4f432d10878e8c0a99a0264736f6c634300050a0032` +var CheckpointOracleBin = "0x608060405234801561001057600080fd5b506040516108703803806108708339818101604052608081101561003357600080fd5b810190808051604051939291908464010000000082111561005357600080fd5b90830190602082018581111561006857600080fd5b825186602082028301116401000000008211171561008557600080fd5b82525081516020918201928201910280838360005b838110156100b257818101518382015260200161009a565b50505050919091016040908152602083015190830151606090930151909450919250600090505b84518110156101855760016000808784815181106100f357fe5b60200260200101516001600160a01b03166001600160a01b0316815260200190815260200160002060006101000a81548160ff021916908315150217905550600185828151811061014057fe5b60209081029190910181015182546001808201855560009485529290932090920180546001600160a01b0319166001600160a01b0390931692909217909155016100d9565b50600592909255600655600755506106ce806101a26000396000f3fe608060405234801561001057600080fd5b50600436106100415760003560e01c806345848dfc146100465780634d6a304c1461009e578063d459fc46146100cf575b600080fd5b61004e6102b0565b60408051602080825283518183015283519192839290830191858101910280838360005b8381101561008a578181015183820152602001610072565b505050509050019250505060405180910390f35b6100a6610365565b6040805167ffffffffffffffff9094168452602084019290925282820152519081900360600190f35b61029c600480360360e08110156100e557600080fd5b81359160208101359160408201359167ffffffffffffffff6060820135169181019060a08101608082013564010000000081111561012257600080fd5b82018360208201111561013457600080fd5b8035906020019184602083028401116401000000008311171561015657600080fd5b91908080602002602001604051908101604052809392919081815260200183836020028082843760009201919091525092959493602081019350359150506401000000008111156101a657600080fd5b8201836020820111156101b857600080fd5b803590602001918460208302840111640100000000831117156101da57600080fd5b919080806020026020016040519081016040528093929190818152602001838360200280828437600092019190915250929594936020810193503591505064010000000081111561022a57600080fd5b82018360208201111561023c57600080fd5b8035906020019184602083028401116401000000008311171561025e57600080fd5b919080806020026020016040519081016040528093929190818152602001838360200280828437600092019190915250929550610380945050505050565b604080519115158252519081900360200190f35b600154606090819067ffffffffffffffff811180156102ce57600080fd5b506040519080825280602002602001820160405280156102f8578160200160208202803683370190505b50905060005b60015481101561035f576001818154811061031557fe5b9060005260206000200160009054906101000a90046001600160a01b031682828151811061033f57fe5b6001600160a01b03909216602092830291909101909101526001016102fe565b50905090565b60025460045460035467ffffffffffffffff90921691909192565b3360009081526020819052604081205460ff1661039c57600080fd5b868840146103a957600080fd5b82518451146103b757600080fd5b81518451146103c557600080fd5b6006546005548660010167ffffffffffffffff1602014310156103ea5750600061068d565b60025467ffffffffffffffff908116908616101561040a5750600061068d565b60
025467ffffffffffffffff868116911614801561043c575067ffffffffffffffff851615158061043c575060035415155b156104495750600061068d565b856104565750600061068d565b60408051601960f81b6020808301919091526000602183018190523060601b60228401526001600160c01b031960c08a901b166036840152603e8084018b905284518085039091018152605e909301909352815191012090805b86518110156106875760006001848984815181106104ca57fe5b60200260200101518985815181106104de57fe5b60200260200101518986815181106104f257fe5b602002602001015160405160008152602001604052604051808581526020018460ff1660ff1681526020018381526020018281526020019450505050506020604051602081039080840390855afa158015610551573d6000803e3d6000fd5b505060408051601f1901516001600160a01b03811660009081526020819052919091205490925060ff16905061058657600080fd5b826001600160a01b0316816001600160a01b0316116105a457600080fd5b8092508867ffffffffffffffff167fce51ffa16246bcaf0899f6504f473cd0114f430f566cef71ab7e03d3dde42a418b8a85815181106105e057fe5b60200260200101518a86815181106105f457fe5b60200260200101518a878151811061060857fe5b6020026020010151604051808581526020018460ff1660ff16815260200183815260200182815260200194505050505060405180910390a2600754826001011061067e5750505060048790555050436003556002805467ffffffffffffffff191667ffffffffffffffff8616179055600161068d565b506001016104b0565b50600080fd5b97965050505050505056fea26469706673582212202ddf9eda76bf59c0fc65584c0b22d84ecef2c703765de60439596d6ac34c2b7264736f6c634300060b0033" // DeployCheckpointOracle deploys a new Ethereum contract, binding an instance of CheckpointOracle to it. func DeployCheckpointOracle(auth *bind.TransactOpts, backend bind.ContractBackend, _adminlist []common.Address, _sectionSize *big.Int, _processConfirms *big.Int, _threshold *big.Int) (common.Address, *types.Transaction, *CheckpointOracle, error) { @@ -38,6 +45,7 @@ func DeployCheckpointOracle(auth *bind.TransactOpts, backend bind.ContractBacken if err != nil { return common.Address{}, nil, nil, err } + address, tx, contract, err := bind.DeployContract(auth, parsed, common.FromHex(CheckpointOracleBin), backend, _adminlist, _sectionSize, _processConfirms, _threshold) if err != nil { return common.Address{}, nil, nil, err @@ -153,7 +161,7 @@ func bindCheckpointOracle(address common.Address, caller bind.ContractCaller, tr // sets the output to result. The result type might be a single field for simple // returns, a slice of interfaces for anonymous returns and a struct for named // returns. -func (_CheckpointOracle *CheckpointOracleRaw) Call(opts *bind.CallOpts, result interface{}, method string, params ...interface{}) error { +func (_CheckpointOracle *CheckpointOracleRaw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error { return _CheckpointOracle.Contract.CheckpointOracleCaller.contract.Call(opts, result, method, params...) } @@ -172,7 +180,7 @@ func (_CheckpointOracle *CheckpointOracleRaw) Transact(opts *bind.TransactOpts, // sets the output to result. The result type might be a single field for simple // returns, a slice of interfaces for anonymous returns and a struct for named // returns. -func (_CheckpointOracle *CheckpointOracleCallerRaw) Call(opts *bind.CallOpts, result interface{}, method string, params ...interface{}) error { +func (_CheckpointOracle *CheckpointOracleCallerRaw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error { return _CheckpointOracle.Contract.contract.Call(opts, result, method, params...) 
} @@ -189,58 +197,64 @@ func (_CheckpointOracle *CheckpointOracleTransactorRaw) Transact(opts *bind.Tran // GetAllAdmin is a free data retrieval call binding the contract method 0x45848dfc. // -// Solidity: function GetAllAdmin() constant returns(address[]) +// Solidity: function GetAllAdmin() view returns(address[]) func (_CheckpointOracle *CheckpointOracleCaller) GetAllAdmin(opts *bind.CallOpts) ([]common.Address, error) { - var ( - ret0 = new([]common.Address) - ) - out := ret0 - err := _CheckpointOracle.contract.Call(opts, out, "GetAllAdmin") - return *ret0, err + var out []interface{} + err := _CheckpointOracle.contract.Call(opts, &out, "GetAllAdmin") + + if err != nil { + return *new([]common.Address), err + } + + out0 := *abi.ConvertType(out[0], new([]common.Address)).(*[]common.Address) + + return out0, err + } // GetAllAdmin is a free data retrieval call binding the contract method 0x45848dfc. // -// Solidity: function GetAllAdmin() constant returns(address[]) +// Solidity: function GetAllAdmin() view returns(address[]) func (_CheckpointOracle *CheckpointOracleSession) GetAllAdmin() ([]common.Address, error) { return _CheckpointOracle.Contract.GetAllAdmin(&_CheckpointOracle.CallOpts) } // GetAllAdmin is a free data retrieval call binding the contract method 0x45848dfc. // -// Solidity: function GetAllAdmin() constant returns(address[]) +// Solidity: function GetAllAdmin() view returns(address[]) func (_CheckpointOracle *CheckpointOracleCallerSession) GetAllAdmin() ([]common.Address, error) { return _CheckpointOracle.Contract.GetAllAdmin(&_CheckpointOracle.CallOpts) } // GetLatestCheckpoint is a free data retrieval call binding the contract method 0x4d6a304c. // -// Solidity: function GetLatestCheckpoint() constant returns(uint64, bytes32, uint256) +// Solidity: function GetLatestCheckpoint() view returns(uint64, bytes32, uint256) func (_CheckpointOracle *CheckpointOracleCaller) GetLatestCheckpoint(opts *bind.CallOpts) (uint64, [32]byte, *big.Int, error) { - var ( - ret0 = new(uint64) - ret1 = new([32]byte) - ret2 = new(*big.Int) - ) - out := &[]interface{}{ - ret0, - ret1, - ret2, + var out []interface{} + err := _CheckpointOracle.contract.Call(opts, &out, "GetLatestCheckpoint") + + if err != nil { + return *new(uint64), *new([32]byte), *new(*big.Int), err } - err := _CheckpointOracle.contract.Call(opts, out, "GetLatestCheckpoint") - return *ret0, *ret1, *ret2, err + + out0 := *abi.ConvertType(out[0], new(uint64)).(*uint64) + out1 := *abi.ConvertType(out[1], new([32]byte)).(*[32]byte) + out2 := *abi.ConvertType(out[2], new(*big.Int)).(**big.Int) + + return out0, out1, out2, err + } // GetLatestCheckpoint is a free data retrieval call binding the contract method 0x4d6a304c. // -// Solidity: function GetLatestCheckpoint() constant returns(uint64, bytes32, uint256) +// Solidity: function GetLatestCheckpoint() view returns(uint64, bytes32, uint256) func (_CheckpointOracle *CheckpointOracleSession) GetLatestCheckpoint() (uint64, [32]byte, *big.Int, error) { return _CheckpointOracle.Contract.GetLatestCheckpoint(&_CheckpointOracle.CallOpts) } // GetLatestCheckpoint is a free data retrieval call binding the contract method 0x4d6a304c. 
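The bindings in this region were regenerated after the Solidity pragma bump (see the oracle.sol hunk further down): the raw Call methods now take a *[]interface{} and the typed wrappers decode each output with abi.ConvertType instead of pre-allocated return pointers. From a consumer's point of view nothing changes, as the sketch below illustrates; everything in it other than the generated API and bind.NewKeyedTransactorWithChainID (the endpoint, address, chain ID and variable names) is an assumption for illustration, not part of this diff.

// Hypothetical usage sketch of the regenerated bindings.
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/accounts/abi/bind"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/contracts/checkpointoracle/contract"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/ethclient"
)

func main() {
	client, err := ethclient.Dial("ws://127.0.0.1:8546") // placeholder endpoint
	if err != nil {
		panic(err)
	}
	// Placeholder address; the oracle lives at a deployment-specific location.
	oracle, err := contract.NewCheckpointOracle(common.HexToAddress("0x0000000000000000000000000000000000000000"), client)
	if err != nil {
		panic(err)
	}
	// Callers still receive typed values; the []interface{} plus
	// abi.ConvertType plumbing visible in the diff is internal to the binding.
	index, hash, height, err := oracle.GetLatestCheckpoint(nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(index, common.Hash(hash), height)

	// Transactors are now constructed with an explicit chain ID (EIP-155
	// replay protection), matching the oracle_test.go change below.
	key, _ := crypto.GenerateKey()
	opts, err := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337))
	if err != nil {
		panic(err)
	}
	_ = opts
}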
// -// Solidity: function GetLatestCheckpoint() constant returns(uint64, bytes32, uint256) +// Solidity: function GetLatestCheckpoint() view returns(uint64, bytes32, uint256) func (_CheckpointOracle *CheckpointOracleCallerSession) GetLatestCheckpoint() (uint64, [32]byte, *big.Int, error) { return _CheckpointOracle.Contract.GetLatestCheckpoint(&_CheckpointOracle.CallOpts) } diff --git a/contracts/checkpointoracle/contract/oracle.sol b/contracts/checkpointoracle/contract/oracle.sol index 010644727374..65bac09d28bf 100644 --- a/contracts/checkpointoracle/contract/oracle.sol +++ b/contracts/checkpointoracle/contract/oracle.sol @@ -1,4 +1,4 @@ -pragma solidity ^0.5.10; +pragma solidity ^0.6.0; /** * @title CheckpointOracle diff --git a/contracts/checkpointoracle/oracle.go b/contracts/checkpointoracle/oracle.go index 1f273272ab6c..7f3127d0b8b0 100644 --- a/contracts/checkpointoracle/oracle.go +++ b/contracts/checkpointoracle/oracle.go @@ -65,7 +65,7 @@ func (oracle *CheckpointOracle) LookupCheckpointEvents(blockLogs [][]*types.Log, if err != nil { continue } - if event.Index == section && common.Hash(event.CheckpointHash) == hash { + if event.Index == section && event.CheckpointHash == hash { votes = append(votes, event) } } diff --git a/contracts/checkpointoracle/oracle_test.go b/contracts/checkpointoracle/oracle_test.go index 817954d11abe..12184819297c 100644 --- a/contracts/checkpointoracle/oracle_test.go +++ b/contracts/checkpointoracle/oracle_test.go @@ -178,7 +178,7 @@ func TestCheckpointRegister(t *testing.T) { contractBackend := backends.NewSimulatedBackend(core.GenesisAlloc{accounts[0].addr: {Balance: big.NewInt(1000000000)}, accounts[1].addr: {Balance: big.NewInt(1000000000)}, accounts[2].addr: {Balance: big.NewInt(1000000000)}}, 10000000) defer contractBackend.Close() - transactOpts := bind.NewKeyedTransactor(accounts[0].key) + transactOpts, _ := bind.NewKeyedTransactorWithChainID(accounts[0].key, big.NewInt(1337)) // 3 trusted signers, threshold 2 contractAddr, _, c, err := contract.DeployCheckpointOracle(transactOpts, contractBackend, []common.Address{accounts[0].addr, accounts[1].addr, accounts[2].addr}, sectionSize, processConfirms, big.NewInt(2)) diff --git a/core/block_validator.go b/core/block_validator.go index 831209393552..8dbd0f75520e 100644 --- a/core/block_validator.go +++ b/core/block_validator.go @@ -62,7 +62,7 @@ func (v *BlockValidator) ValidateBody(block *types.Block) error { if hash := types.CalcUncleHash(block.Uncles()); hash != header.UncleHash { return fmt.Errorf("uncle root hash mismatch: have %x, want %x", hash, header.UncleHash) } - if hash := types.DeriveSha(block.Transactions(), new(trie.Trie)); hash != header.TxHash { + if hash := types.DeriveSha(block.Transactions(), trie.NewStackTrie(nil)); hash != header.TxHash { return fmt.Errorf("transaction root hash mismatch: have %x, want %x", hash, header.TxHash) } if !v.bc.HasBlockAndState(block.ParentHash(), block.NumberU64()-1) { @@ -90,7 +90,7 @@ func (v *BlockValidator) ValidateState(block *types.Block, statedb *state.StateD return fmt.Errorf("invalid bloom (remote: %x local: %x)", header.Bloom, rbloom) } // Tre receipt Trie's root (R = (Tr [[H1, R1], ... 
[Hn, Rn]])) - receiptSha := types.DeriveSha(receipts, new(trie.Trie)) + receiptSha := types.DeriveSha(receipts, trie.NewStackTrie(nil)) if receiptSha != header.ReceiptHash { return fmt.Errorf("invalid receipt root hash (remote: %x local: %x)", header.ReceiptHash, receiptSha) } diff --git a/core/blockchain.go b/core/blockchain.go index ce1edd9b7fb0..d9505dcf69a1 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -129,6 +129,7 @@ type CacheConfig struct { TrieDirtyDisabled bool // Whether to disable trie write caching and GC altogether (archive node) TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk SnapshotLimit int // Memory allowance (MB) to use for caching snapshot entries in memory + Preimages bool // Whether to store preimage of trie key to the disk SnapshotWait bool // Wait for snapshot construction on startup. TODO(karalabe): This is a dirty hack for testing, nuke it } @@ -207,9 +208,10 @@ type BlockChain struct { processor Processor // Block transaction processor interface vmConfig vm.Config - badBlocks *lru.Cache // Bad block cache - shouldPreserve func(*types.Block) bool // Function used to determine whether should preserve the given block. - terminateInsert func(common.Hash, uint64) bool // Testing hook used to terminate ancient receipt chain insertion. + badBlocks *lru.Cache // Bad block cache + shouldPreserve func(*types.Block) bool // Function used to determine whether should preserve the given block. + terminateInsert func(common.Hash, uint64) bool // Testing hook used to terminate ancient receipt chain insertion. + writeLegacyJournal bool // Testing flag used to flush the snapshot journal in legacy format. } // NewBlockChain returns a fully initialised block chain using information @@ -228,11 +230,15 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par badBlocks, _ := lru.New(badBlockLimit) bc := &BlockChain{ - chainConfig: chainConfig, - cacheConfig: cacheConfig, - db: db, - triegc: prque.New(nil), - stateCache: state.NewDatabaseWithCache(db, cacheConfig.TrieCleanLimit, cacheConfig.TrieCleanJournal), + chainConfig: chainConfig, + cacheConfig: cacheConfig, + db: db, + triegc: prque.New(nil), + stateCache: state.NewDatabaseWithConfig(db, &trie.Config{ + Cache: cacheConfig.TrieCleanLimit, + Journal: cacheConfig.TrieCleanJournal, + Preimages: cacheConfig.Preimages, + }), quit: make(chan struct{}), shouldPreserve: shouldPreserve, bodyCache: bodyCache, @@ -281,9 +287,29 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par // Make sure the state associated with the block is available head := bc.CurrentBlock() if _, err := state.New(head.Root(), bc.stateCache, bc.snaps); err != nil { - log.Warn("Head state missing, repairing", "number", head.Number(), "hash", head.Hash()) - if err := bc.SetHead(head.NumberU64()); err != nil { - return nil, err + // Head state is missing, before the state recovery, find out the + // disk layer point of snapshot(if it's enabled). Make sure the + // rewound point is lower than disk layer. 
+ var diskRoot common.Hash + if bc.cacheConfig.SnapshotLimit > 0 { + diskRoot = rawdb.ReadSnapshotRoot(bc.db) + } + if diskRoot != (common.Hash{}) { + log.Warn("Head state missing, repairing", "number", head.Number(), "hash", head.Hash(), "snaproot", diskRoot) + + snapDisk, err := bc.SetHeadBeyondRoot(head.NumberU64(), diskRoot) + if err != nil { + return nil, err + } + // Chain rewound, persist old snapshot number to indicate recovery procedure + if snapDisk != 0 { + rawdb.WriteSnapshotRecoveryNumber(bc.db, snapDisk) + } + } else { + log.Warn("Head state missing, repairing", "number", head.Number(), "hash", head.Hash()) + if err := bc.SetHead(head.NumberU64()); err != nil { + return nil, err + } + } } // Ensure that a previous crash in SetHead doesn't leave extra ancients @@ -339,12 +365,25 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par } // Load any existing snapshot, regenerating it if loading failed if bc.cacheConfig.SnapshotLimit > 0 { - bc.snaps = snapshot.New(bc.db, bc.stateCache.TrieDB(), bc.cacheConfig.SnapshotLimit, bc.CurrentBlock().Root(), !bc.cacheConfig.SnapshotWait) + // If the chain was rewound past the snapshot persistent layer (causing + // a recovery block number to be persisted to disk), check if we're still + // in recovery mode and in that case, don't invalidate the snapshot on a + // head mismatch. + var recover bool + + head := bc.CurrentBlock() + if layer := rawdb.ReadSnapshotRecoveryNumber(bc.db); layer != nil && *layer > head.NumberU64() { + log.Warn("Enabling snapshot recovery", "chainhead", head.NumberU64(), "diskbase", *layer) + recover = true + } + bc.snaps = snapshot.New(bc.db, bc.stateCache.TrieDB(), bc.cacheConfig.SnapshotLimit, head.Root(), !bc.cacheConfig.SnapshotWait, recover) } // Take ownership of this particular state go bc.update() if txLookupLimit != nil { bc.txLookupLimit = *txLookupLimit + + bc.wg.Add(1) go bc.maintainTxIndex(txIndexBlock) } // If periodic cache journal is required, spin it up. @@ -442,9 +481,25 @@ func (bc *BlockChain) loadLastState() error { // was fast synced or full synced and in which state, the method will try to // delete minimal data from disk whilst retaining chain consistency. func (bc *BlockChain) SetHead(head uint64) error { + _, err := bc.SetHeadBeyondRoot(head, common.Hash{}) + return err +} + +// SetHeadBeyondRoot rewinds the local chain to a new head with the extra condition +// that the rewind must pass the specified state root. This method is meant to be +// used when rewinding with snapshots enabled to ensure that we go back further than +// persistent disk layer. Depending on whether the node was fast synced or full, and +// in which state, the method will try to delete minimal data from disk whilst +// retaining chain consistency. +// +// The method returns the block number where the requested root cap was found.
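The implementation of SetHeadBeyondRoot follows below. As a reading aid, here is the recovery flow from the NewBlockChain hunk above gathered into one place; the helper name repairHeadState is hypothetical and does not exist in the diff, but the calls it makes are exactly the ones added there.

// Hypothetical helper, paraphrasing the diff's recovery flow: when the head
// state is gone but a snapshot disk layer survives, rewind past that layer's
// root instead of just to the head, and leave a recovery marker behind.
func (bc *BlockChain) repairHeadState(head *types.Block) error {
	var diskRoot common.Hash
	if bc.cacheConfig.SnapshotLimit > 0 {
		diskRoot = rawdb.ReadSnapshotRoot(bc.db)
	}
	if diskRoot == (common.Hash{}) {
		// No snapshot to protect: a plain rewind to the head suffices.
		return bc.SetHead(head.NumberU64())
	}
	// Rewind past the snapshot's disk layer root, then persist the number so
	// the next startup knows the head is legitimately behind the snapshot.
	snapDisk, err := bc.SetHeadBeyondRoot(head.NumberU64(), diskRoot)
	if err != nil {
		return err
	}
	if snapDisk != 0 {
		rawdb.WriteSnapshotRecoveryNumber(bc.db, snapDisk)
	}
	return nil
}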
+func (bc *BlockChain) SetHeadBeyondRoot(head uint64, root common.Hash) (uint64, error) { bc.chainmu.Lock() defer bc.chainmu.Unlock() + // Track the block number of the requested root hash + var rootNumber uint64 // (no root == always 0) + // Retrieve the last pivot block to short circuit rollbacks beyond it and the // current freezer limit to start nuking id underflown pivot := rawdb.ReadLastPivotNumber(bc.db) @@ -460,8 +515,16 @@ func (bc *BlockChain) SetHead(head uint64) error { log.Error("Gap in the chain, rewinding to genesis", "number", header.Number, "hash", header.Hash()) newHeadBlock = bc.genesisBlock } else { - // Block exists, keep rewinding until we find one with state + // Block exists, keep rewinding until we find one with state, + // keeping rewinding until we exceed the optional threshold + // root hash + beyondRoot := (root == common.Hash{}) // Flag whether we're beyond the requested root (no root, always true) + for { + // If a root threshold was requested but not yet crossed, check + if root != (common.Hash{}) && !beyondRoot && newHeadBlock.Root() == root { + beyondRoot, rootNumber = true, newHeadBlock.NumberU64() + } if _, err := state.New(newHeadBlock.Root(), bc.stateCache, bc.snaps); err != nil { log.Trace("Block state missing, rewinding further", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) if pivot == nil || newHeadBlock.NumberU64() > *pivot { @@ -472,8 +535,12 @@ func (bc *BlockChain) SetHead(head uint64) error { newHeadBlock = bc.genesisBlock } } - log.Debug("Rewound to block with state", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) - break + if beyondRoot || newHeadBlock.NumberU64() == 0 { + log.Debug("Rewound to block with state", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) + break + } + log.Debug("Skipping block with threshold state", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash(), "root", newHeadBlock.Root()) + newHeadBlock = bc.GetBlock(newHeadBlock.ParentHash(), newHeadBlock.NumberU64()-1) // Keep rewinding } } rawdb.WriteHeadBlockHash(db, newHeadBlock.Hash()) @@ -553,7 +620,7 @@ func (bc *BlockChain) SetHead(head uint64) error { bc.txLookupCache.Purge() bc.futureBlocks.Purge() - return bc.loadLastState() + return rootNumber, bc.loadLastState() } // FastSyncCommitHead sets the current head block to the one defined by the hash @@ -592,12 +659,8 @@ func (bc *BlockChain) CurrentBlock() *types.Block { return bc.currentBlock.Load().(*types.Block) } -// Snapshot returns the blockchain snapshot tree. This method is mainly used for -// testing, to make it possible to verify the snapshot after execution. -// -// Warning: There are no guarantees about the safety of using the returned 'snap' if the -// blockchain is simultaneously importing blocks, so take care. -func (bc *BlockChain) Snapshot() *snapshot.Tree { +// Snapshots returns the blockchain snapshot tree. 
+func (bc *BlockChain) Snapshots() *snapshot.Tree { return bc.snaps } @@ -938,8 +1001,14 @@ func (bc *BlockChain) Stop() { var snapBase common.Hash if bc.snaps != nil { var err error - if snapBase, err = bc.snaps.Journal(bc.CurrentBlock().Root()); err != nil { - log.Error("Failed to journal state snapshot", "err", err) + if bc.writeLegacyJournal { + if snapBase, err = bc.snaps.LegacyJournal(bc.CurrentBlock().Root()); err != nil { + log.Error("Failed to journal state snapshot", "err", err) + } + } else { + if snapBase, err = bc.snaps.Journal(bc.CurrentBlock().Root()); err != nil { + log.Error("Failed to journal state snapshot", "err", err) + } } } // Ensure the state of a recent block is also stored to disk before exiting. @@ -2230,6 +2299,8 @@ func (bc *BlockChain) update() { // sync, Geth will automatically construct the missing indices and delete // the extra indices. func (bc *BlockChain) maintainTxIndex(ancients uint64) { + defer bc.wg.Done() + // Before starting the actual maintenance, we need to handle a special case, // where user might init Geth with an external ancient database. If so, we // need to reindex all necessary transactions before starting to process any @@ -2239,7 +2310,7 @@ func (bc *BlockChain) maintainTxIndex(ancients uint64) { if bc.txLookupLimit != 0 && ancients > bc.txLookupLimit { from = ancients - bc.txLookupLimit } - rawdb.IndexTransactions(bc.db, from, ancients) + rawdb.IndexTransactions(bc.db, from, ancients, bc.quit) } // indexBlocks reindexes or unindexes transactions depending on user configuration indexBlocks := func(tail *uint64, head uint64, done chan struct{}) { @@ -2253,24 +2324,24 @@ func (bc *BlockChain) maintainTxIndex(ancients uint64) { rawdb.WriteTxIndexTail(bc.db, 0) } else { // Prune all stale tx indices and record the tx index tail - rawdb.UnindexTransactions(bc.db, 0, head-bc.txLookupLimit+1) + rawdb.UnindexTransactions(bc.db, 0, head-bc.txLookupLimit+1, bc.quit) } return } // If a previous indexing existed, make sure that we fill in any missing entries if bc.txLookupLimit == 0 || head < bc.txLookupLimit { if *tail > 0 { - rawdb.IndexTransactions(bc.db, 0, *tail) + rawdb.IndexTransactions(bc.db, 0, *tail, bc.quit) } return } // Update the transaction index to the new chain state if head-bc.txLookupLimit+1 < *tail { // Reindex a part of missing indices and rewind index tail to HEAD-limit - rawdb.IndexTransactions(bc.db, head-bc.txLookupLimit+1, *tail) + rawdb.IndexTransactions(bc.db, head-bc.txLookupLimit+1, *tail, bc.quit) } else { // Unindex a part of stale indices and forward index tail to HEAD-limit - rawdb.UnindexTransactions(bc.db, *tail, head-bc.txLookupLimit+1) + rawdb.UnindexTransactions(bc.db, *tail, head-bc.txLookupLimit+1, bc.quit) } } // Any reindexing done, start listening to chain events and moving the index window @@ -2294,6 +2365,10 @@ func (bc *BlockChain) maintainTxIndex(ancients uint64) { case <-done: done = nil case <-bc.quit: + if done != nil { + log.Info("Waiting background transaction indexer to exit") + <-done + } return } } @@ -2359,12 +2434,8 @@ func (bc *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int) (i bc.wg.Add(1) defer bc.wg.Done() - - whFunc := func(header *types.Header) error { - _, err := bc.hc.WriteHeader(header) - return err - } - return bc.hc.InsertHeaderChain(chain, whFunc, start) + _, err := bc.hc.InsertHeaderChain(chain, start) + return 0, err } // CurrentHeader retrieves the current head header of the canonical chain. 
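The indexer changes above matter mainly for shutdown behaviour: rawdb.IndexTransactions and rawdb.UnindexTransactions now receive the chain's quit channel, and Stop waits for an in-flight round instead of orphaning it. A minimal sketch of the interruption pattern, assuming nothing beyond standard Go (this is illustrative, not the rawdb implementation):

// indexRange is a toy stand-in for the interruptible loops this diff adds:
// long-running index work polls an interrupt channel and bails out early
// when the chain is shutting down, instead of running to completion.
func indexRange(from, to uint64, interrupt chan struct{}) {
	for number := from; number < to; number++ {
		select {
		case <-interrupt: // shutdown requested, abandon the run
			return
		default:
		}
		// ... look up block `number` and write its tx index entries ...
	}
}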
diff --git a/core/blockchain_insert.go b/core/blockchain_insert.go index 5685b0a4bdd9..cb8473c08426 100644 --- a/core/blockchain_insert.go +++ b/core/blockchain_insert.go @@ -43,7 +43,7 @@ func (st *insertStats) report(chain []*types.Block, index int, dirty common.Stor // Fetch the timings for the batch var ( now = mclock.Now() - elapsed = time.Duration(now) - time.Duration(st.startTime) + elapsed = now.Sub(st.startTime) ) // If we're at the last block of the batch or report period reached, log if index == len(chain)-1 || elapsed >= statsReportLimit { diff --git a/core/blockchain_repair_test.go b/core/blockchain_repair_test.go index 27903dd06b35..b5cd232a9c4f 100644 --- a/core/blockchain_repair_test.go +++ b/core/blockchain_repair_test.go @@ -25,6 +25,7 @@ import ( "math/big" "os" "testing" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/ethash" @@ -38,7 +39,10 @@ import ( // committed to disk and then the process crashed. In this case we expect the full // chain to be rolled back to the committed block, but the chain data itself left // in the database for replaying. -func TestShortRepair(t *testing.T) { +func TestShortRepair(t *testing.T) { testShortRepair(t, false) } +func TestShortRepairWithSnapshots(t *testing.T) { testShortRepair(t, true) } + +func testShortRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // @@ -68,14 +72,17 @@ func TestShortRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a short canonical chain where the fast sync pivot point was // already committed, after which the process crashed. In this case we expect the full // chain to be rolled back to the committed block, but the chain data itself left in // the database for replaying. -func TestShortFastSyncedRepair(t *testing.T) { +func TestShortFastSyncedRepair(t *testing.T) { testShortFastSyncedRepair(t, false) } +func TestShortFastSyncedRepairWithSnapshots(t *testing.T) { testShortFastSyncedRepair(t, true) } + +func testShortFastSyncedRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // @@ -105,14 +112,17 @@ func TestShortFastSyncedRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a short canonical chain where the fast sync pivot point was // not yet committed, but the process crashed. In this case we expect the chain to // detect that it was fast syncing and not delete anything, since we can just pick // up directly where we left off. -func TestShortFastSyncingRepair(t *testing.T) { +func TestShortFastSyncingRepair(t *testing.T) { testShortFastSyncingRepair(t, false) } +func TestShortFastSyncingRepairWithSnapshots(t *testing.T) { testShortFastSyncingRepair(t, true) } + +func testShortFastSyncingRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // @@ -142,7 +152,7 @@ func TestShortFastSyncingRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a shorter side chain, where a @@ -150,7 +160,10 @@ func TestShortFastSyncingRepair(t *testing.T) { // test scenario the side chain is below the committed block. In this case we expect // the canonical chain to be rolled back to the committed block, but the chain data // itself left in the database for replaying.
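From here to the end of blockchain_repair_test.go the same mechanical change repeats: every TestXxxRepair keeps its original name (so existing tooling still finds it), gains a WithSnapshots twin, and both delegate to an unexported testXxxRepair(t, snapshots bool) body. The hypothetical distillation below shows the shape once so the remaining hunks can be skimmed:

// Hypothetical minimal form of the pattern applied throughout this file;
// the names are illustrative, not from the diff.
func TestSomethingRepair(t *testing.T)              { testSomethingRepair(t, false) }
func TestSomethingRepairWithSnapshots(t *testing.T) { testSomethingRepair(t, true) }

func testSomethingRepair(t *testing.T, snapshots bool) {
	// Shared body; `snapshots` is forwarded to testRepair, which enables the
	// snapshot tree in the CacheConfig when true.
}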
-func TestShortOldForkedRepair(t *testing.T) { +func TestShortOldForkedRepair(t *testing.T) { testShortOldForkedRepair(t, false) } +func TestShortOldForkedRepairWithSnapshots(t *testing.T) { testShortOldForkedRepair(t, true) } + +func testShortOldForkedRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3 @@ -182,7 +195,7 @@ func TestShortOldForkedRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a shorter side chain, where @@ -191,6 +204,13 @@ func TestShortOldForkedRepair(t *testing.T) { // this case we expect the canonical chain to be rolled back to the committed block, // but the chain data itself left in the database for replaying. func TestShortOldForkedFastSyncedRepair(t *testing.T) { + testShortOldForkedFastSyncedRepair(t, false) +} +func TestShortOldForkedFastSyncedRepairWithSnapshots(t *testing.T) { + testShortOldForkedFastSyncedRepair(t, true) +} + +func testShortOldForkedFastSyncedRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3 @@ -222,7 +242,7 @@ func TestShortOldForkedFastSyncedRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a shorter side chain, where @@ -231,6 +251,13 @@ func TestShortOldForkedFastSyncedRepair(t *testing.T) { // the chain to detect that it was fast syncing and not delete anything, since we // can just pick up directly where we left off. func TestShortOldForkedFastSyncingRepair(t *testing.T) { + testShortOldForkedFastSyncingRepair(t, false) +} +func TestShortOldForkedFastSyncingRepairWithSnapshots(t *testing.T) { + testShortOldForkedFastSyncingRepair(t, true) +} + +func testShortOldForkedFastSyncingRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3 @@ -262,7 +289,7 @@ func TestShortOldForkedFastSyncingRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a shorter side chain, where a @@ -270,7 +297,10 @@ func TestShortOldForkedFastSyncingRepair(t *testing.T) { // test scenario the side chain reaches above the committed block. In this case we // expect the canonical chain to be rolled back to the committed block, but the // chain data itself left in the database for replaying. -func TestShortNewlyForkedRepair(t *testing.T) { +func TestShortNewlyForkedRepair(t *testing.T) { testShortNewlyForkedRepair(t, false) } +func TestShortNewlyForkedRepairWithSnapshots(t *testing.T) { testShortNewlyForkedRepair(t, true) } + +func testShortNewlyForkedRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6 @@ -302,7 +332,7 @@ func TestShortNewlyForkedRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a shorter side chain, where @@ -311,6 +341,13 @@ func TestShortNewlyForkedRepair(t *testing.T) { // In this case we expect the canonical chain to be rolled back to the committed // block, but the chain data itself left in the database for replaying. 
func TestShortNewlyForkedFastSyncedRepair(t *testing.T) { + testShortNewlyForkedFastSyncedRepair(t, false) +} +func TestShortNewlyForkedFastSyncedRepairWithSnapshots(t *testing.T) { + testShortNewlyForkedFastSyncedRepair(t, true) +} + +func testShortNewlyForkedFastSyncedRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6 @@ -342,7 +379,7 @@ func TestShortNewlyForkedFastSyncedRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a shorter side chain, where @@ -351,6 +388,13 @@ func TestShortNewlyForkedFastSyncedRepair(t *testing.T) { // case we expect the chain to detect that it was fast syncing and not delete // anything, since we can just pick up directly where we left off. func TestShortNewlyForkedFastSyncingRepair(t *testing.T) { + testShortNewlyForkedFastSyncingRepair(t, false) +} +func TestShortNewlyForkedFastSyncingRepairWithSnapshots(t *testing.T) { + testShortNewlyForkedFastSyncingRepair(t, true) +} + +func testShortNewlyForkedFastSyncingRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6 @@ -382,14 +426,17 @@ func TestShortNewlyForkedFastSyncingRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a longer side chain, where a // recent block was already committed to disk and then the process crashed. In this // case we expect the canonical chain to be rolled back to the committed block, but // the chain data itself left in the database for replaying. -func TestShortReorgedRepair(t *testing.T) { +func TestShortReorgedRepair(t *testing.T) { testShortReorgedRepair(t, false) } +func TestShortReorgedRepairWithSnapshots(t *testing.T) { testShortReorgedRepair(t, true) } + +func testShortReorgedRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10 @@ -421,7 +468,7 @@ func TestShortReorgedRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a longer side chain, where @@ -429,6 +476,13 @@ func TestShortReorgedRepair(t *testing.T) { // crashed. In this case we expect the canonical chain to be rolled back to the // committed block, but the chain data itself left in the database for replaying. func TestShortReorgedFastSyncedRepair(t *testing.T) { + testShortReorgedFastSyncedRepair(t, false) +} +func TestShortReorgedFastSyncedRepairWithSnapshots(t *testing.T) { + testShortReorgedFastSyncedRepair(t, true) +} + +func testShortReorgedFastSyncedRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10 @@ -460,7 +514,7 @@ func TestShortReorgedFastSyncedRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a short canonical chain and a longer side chain, where @@ -468,6 +522,13 @@ func TestShortReorgedFastSyncedRepair(t *testing.T) { // this case we expect the chain to detect that it was fast syncing and not delete // anything, since we can just pick up directly where we left off. 
func TestShortReorgedFastSyncingRepair(t *testing.T) { + testShortReorgedFastSyncingRepair(t, false) +} +func TestShortReorgedFastSyncingRepairWithSnapshots(t *testing.T) { + testShortReorgedFastSyncingRepair(t, true) +} + +func testShortReorgedFastSyncingRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10 @@ -499,14 +560,17 @@ func TestShortReorgedFastSyncingRepair(t *testing.T) { expHeadHeader: 8, expHeadFastBlock: 8, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks where a recent // block - newer than the ancient limit - was already committed to disk and then // the process crashed. In this case we expect the chain to be rolled back to the // committed block, with everything afterwads kept as fast sync data. -func TestLongShallowRepair(t *testing.T) { +func TestLongShallowRepair(t *testing.T) { testLongShallowRepair(t, false) } +func TestLongShallowRepairWithSnapshots(t *testing.T) { testLongShallowRepair(t, true) } + +func testLongShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // @@ -541,14 +605,17 @@ func TestLongShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks where a recent // block - older than the ancient limit - was already committed to disk and then // the process crashed. In this case we expect the chain to be rolled back to the // committed block, with everything afterwads deleted. -func TestLongDeepRepair(t *testing.T) { +func TestLongDeepRepair(t *testing.T) { testLongDeepRepair(t, false) } +func TestLongDeepRepairWithSnapshots(t *testing.T) { testLongDeepRepair(t, true) } + +func testLongDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // @@ -582,7 +649,7 @@ func TestLongDeepRepair(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks where the fast @@ -590,6 +657,13 @@ func TestLongDeepRepair(t *testing.T) { // which the process crashed. In this case we expect the chain to be rolled back // to the committed block, with everything afterwads kept as fast sync data. func TestLongFastSyncedShallowRepair(t *testing.T) { + testLongFastSyncedShallowRepair(t, false) +} +func TestLongFastSyncedShallowRepairWithSnapshots(t *testing.T) { + testLongFastSyncedShallowRepair(t, true) +} + +func testLongFastSyncedShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // @@ -624,14 +698,17 @@ func TestLongFastSyncedShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks where the fast // sync pivot point - older than the ancient limit - was already committed, after // which the process crashed. In this case we expect the chain to be rolled back // to the committed block, with everything afterwads deleted. 
-func TestLongFastSyncedDeepRepair(t *testing.T) { +func TestLongFastSyncedDeepRepair(t *testing.T) { testLongFastSyncedDeepRepair(t, false) } +func TestLongFastSyncedDeepRepairWithSnapshots(t *testing.T) { testLongFastSyncedDeepRepair(t, true) } + +func testLongFastSyncedDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // @@ -665,7 +742,7 @@ func TestLongFastSyncedDeepRepair(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks where the fast @@ -674,6 +751,13 @@ func TestLongFastSyncedDeepRepair(t *testing.T) { // syncing and not delete anything, since we can just pick up directly where we // left off. func TestLongFastSyncingShallowRepair(t *testing.T) { + testLongFastSyncingShallowRepair(t, false) +} +func TestLongFastSyncingShallowRepairWithSnapshots(t *testing.T) { + testLongFastSyncingShallowRepair(t, true) +} + +func testLongFastSyncingShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // @@ -708,7 +792,7 @@ func TestLongFastSyncingShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks where the fast @@ -716,7 +800,10 @@ func TestLongFastSyncingShallowRepair(t *testing.T) { // process crashed. In this case we expect the chain to detect that it was fast // syncing and not delete anything, since we can just pick up directly where we // left off. -func TestLongFastSyncingDeepRepair(t *testing.T) { +func TestLongFastSyncingDeepRepair(t *testing.T) { testLongFastSyncingDeepRepair(t, false) } +func TestLongFastSyncingDeepRepairWithSnapshots(t *testing.T) { testLongFastSyncingDeepRepair(t, true) } + +func testLongFastSyncingDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // @@ -751,7 +838,7 @@ func TestLongFastSyncingDeepRepair(t *testing.T) { expHeadHeader: 24, expHeadFastBlock: 24, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -761,6 +848,13 @@ func TestLongFastSyncingDeepRepair(t *testing.T) { // rolled back to the committed block, with everything afterwads kept as fast // sync data; the side chain completely nuked by the freezer. func TestLongOldForkedShallowRepair(t *testing.T) { + testLongOldForkedShallowRepair(t, false) +} +func TestLongOldForkedShallowRepairWithSnapshots(t *testing.T) { + testLongOldForkedShallowRepair(t, true) +} + +func testLongOldForkedShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3 @@ -796,7 +890,7 @@ func TestLongOldForkedShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -805,7 +899,10 @@ func TestLongOldForkedShallowRepair(t *testing.T) { // chain is below the committed block. In this case we expect the canonical chain // to be rolled back to the committed block, with everything afterwads deleted; // the side chain completely nuked by the freezer. 
-func TestLongOldForkedDeepRepair(t *testing.T) { +func TestLongOldForkedDeepRepair(t *testing.T) { testLongOldForkedDeepRepair(t, false) } +func TestLongOldForkedDeepRepairWithSnapshots(t *testing.T) { testLongOldForkedDeepRepair(t, true) } + +func testLongOldForkedDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3 @@ -840,7 +937,7 @@ func TestLongOldForkedDeepRepair(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -850,6 +947,13 @@ func TestLongOldForkedDeepRepair(t *testing.T) { // to be rolled back to the committed block, with everything afterwads kept as // fast sync data; the side chain completely nuked by the freezer. func TestLongOldForkedFastSyncedShallowRepair(t *testing.T) { + testLongOldForkedFastSyncedShallowRepair(t, false) +} +func TestLongOldForkedFastSyncedShallowRepairWithSnapshots(t *testing.T) { + testLongOldForkedFastSyncedShallowRepair(t, true) +} + +func testLongOldForkedFastSyncedShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3 @@ -885,7 +989,7 @@ func TestLongOldForkedFastSyncedShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -895,6 +999,13 @@ func TestLongOldForkedFastSyncedShallowRepair(t *testing.T) { // chain to be rolled back to the committed block, with everything afterwads deleted; // the side chain completely nuked by the freezer. func TestLongOldForkedFastSyncedDeepRepair(t *testing.T) { + testLongOldForkedFastSyncedDeepRepair(t, false) +} +func TestLongOldForkedFastSyncedDeepRepairWithSnapshots(t *testing.T) { + testLongOldForkedFastSyncedDeepRepair(t, true) +} + +func testLongOldForkedFastSyncedDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3 @@ -929,7 +1040,7 @@ func TestLongOldForkedFastSyncedDeepRepair(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -939,6 +1050,13 @@ func TestLongOldForkedFastSyncedDeepRepair(t *testing.T) { // that it was fast syncing and not delete anything. The side chain is completely // nuked by the freezer. func TestLongOldForkedFastSyncingShallowRepair(t *testing.T) { + testLongOldForkedFastSyncingShallowRepair(t, false) +} +func TestLongOldForkedFastSyncingShallowRepairWithSnapshots(t *testing.T) { + testLongOldForkedFastSyncingShallowRepair(t, true) +} + +func testLongOldForkedFastSyncingShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3 @@ -974,7 +1092,7 @@ func TestLongOldForkedFastSyncingShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -984,6 +1102,13 @@ func TestLongOldForkedFastSyncingShallowRepair(t *testing.T) { // that it was fast syncing and not delete anything. 
The side chain is completely // nuked by the freezer. func TestLongOldForkedFastSyncingDeepRepair(t *testing.T) { + testLongOldForkedFastSyncingDeepRepair(t, false) +} +func TestLongOldForkedFastSyncingDeepRepairWithSnapshots(t *testing.T) { + testLongOldForkedFastSyncingDeepRepair(t, true) +} + +func testLongOldForkedFastSyncingDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3 @@ -1019,7 +1144,7 @@ func TestLongOldForkedFastSyncingDeepRepair(t *testing.T) { expHeadHeader: 24, expHeadFastBlock: 24, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -1029,6 +1154,13 @@ func TestLongOldForkedFastSyncingDeepRepair(t *testing.T) { // rolled back to the committed block, with everything afterwads kept as fast // sync data; the side chain completely nuked by the freezer. func TestLongNewerForkedShallowRepair(t *testing.T) { + testLongNewerForkedShallowRepair(t, false) +} +func TestLongNewerForkedShallowRepairWithSnapshots(t *testing.T) { + testLongNewerForkedShallowRepair(t, true) +} + +func testLongNewerForkedShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1064,7 +1196,7 @@ func TestLongNewerForkedShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -1073,7 +1205,10 @@ func TestLongNewerForkedShallowRepair(t *testing.T) { // chain is above the committed block. In this case we expect the canonical chain // to be rolled back to the committed block, with everything afterwads deleted; // the side chain completely nuked by the freezer. -func TestLongNewerForkedDeepRepair(t *testing.T) { +func TestLongNewerForkedDeepRepair(t *testing.T) { testLongNewerForkedDeepRepair(t, false) } +func TestLongNewerForkedDeepRepairWithSnapshots(t *testing.T) { testLongNewerForkedDeepRepair(t, true) } + +func testLongNewerForkedDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1108,7 +1243,7 @@ func TestLongNewerForkedDeepRepair(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -1118,6 +1253,13 @@ func TestLongNewerForkedDeepRepair(t *testing.T) { // to be rolled back to the committed block, with everything afterwads kept as fast // sync data; the side chain completely nuked by the freezer. 
func TestLongNewerForkedFastSyncedShallowRepair(t *testing.T) { + testLongNewerForkedFastSyncedShallowRepair(t, false) +} +func TestLongNewerForkedFastSyncedShallowRepairWithSnapshots(t *testing.T) { + testLongNewerForkedFastSyncedShallowRepair(t, true) +} + +func testLongNewerForkedFastSyncedShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1153,7 +1295,7 @@ func TestLongNewerForkedFastSyncedShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -1163,6 +1305,13 @@ func TestLongNewerForkedFastSyncedShallowRepair(t *testing.T) { // chain to be rolled back to the committed block, with everything afterwads deleted; // the side chain completely nuked by the freezer. func TestLongNewerForkedFastSyncedDeepRepair(t *testing.T) { + testLongNewerForkedFastSyncedDeepRepair(t, false) +} +func TestLongNewerForkedFastSyncedDeepRepairWithSnapshots(t *testing.T) { + testLongNewerForkedFastSyncedDeepRepair(t, true) +} + +func testLongNewerForkedFastSyncedDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1197,7 +1346,7 @@ func TestLongNewerForkedFastSyncedDeepRepair(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -1207,6 +1356,13 @@ func TestLongNewerForkedFastSyncedDeepRepair(t *testing.T) { // that it was fast syncing and not delete anything. The side chain is completely // nuked by the freezer. func TestLongNewerForkedFastSyncingShallowRepair(t *testing.T) { + testLongNewerForkedFastSyncingShallowRepair(t, false) +} +func TestLongNewerForkedFastSyncingShallowRepairWithSnapshots(t *testing.T) { + testLongNewerForkedFastSyncingShallowRepair(t, true) +} + +func testLongNewerForkedFastSyncingShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1242,7 +1398,7 @@ func TestLongNewerForkedFastSyncingShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a shorter @@ -1252,6 +1408,13 @@ func TestLongNewerForkedFastSyncingShallowRepair(t *testing.T) { // that it was fast syncing and not delete anything. The side chain is completely // nuked by the freezer. 
func TestLongNewerForkedFastSyncingDeepRepair(t *testing.T) { + testLongNewerForkedFastSyncingDeepRepair(t, false) +} +func TestLongNewerForkedFastSyncingDeepRepairWithSnapshots(t *testing.T) { + testLongNewerForkedFastSyncingDeepRepair(t, true) +} + +func testLongNewerForkedFastSyncingDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1287,7 +1450,7 @@ func TestLongNewerForkedFastSyncingDeepRepair(t *testing.T) { expHeadHeader: 24, expHeadFastBlock: 24, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a longer side @@ -1295,7 +1458,10 @@ func TestLongNewerForkedFastSyncingDeepRepair(t *testing.T) { // to disk and then the process crashed. In this case we expect the chain to be // rolled back to the committed block, with everything afterwads kept as fast sync // data. The side chain completely nuked by the freezer. -func TestLongReorgedShallowRepair(t *testing.T) { +func TestLongReorgedShallowRepair(t *testing.T) { testLongReorgedShallowRepair(t, false) } +func TestLongReorgedShallowRepairWithSnapshots(t *testing.T) { testLongReorgedShallowRepair(t, true) } + +func testLongReorgedShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1331,7 +1497,7 @@ func TestLongReorgedShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a longer side @@ -1339,7 +1505,10 @@ func TestLongReorgedShallowRepair(t *testing.T) { // to disk and then the process crashed. In this case we expect the canonical chains // to be rolled back to the committed block, with everything afterwads deleted. The // side chain completely nuked by the freezer. -func TestLongReorgedDeepRepair(t *testing.T) { +func TestLongReorgedDeepRepair(t *testing.T) { testLongReorgedDeepRepair(t, false) } +func TestLongReorgedDeepRepairWithSnapshots(t *testing.T) { testLongReorgedDeepRepair(t, true) } + +func testLongReorgedDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1374,7 +1543,7 @@ func TestLongReorgedDeepRepair(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a longer @@ -1384,6 +1553,13 @@ func TestLongReorgedDeepRepair(t *testing.T) { // afterwads kept as fast sync data. The side chain completely nuked by the // freezer. 
func TestLongReorgedFastSyncedShallowRepair(t *testing.T) { + testLongReorgedFastSyncedShallowRepair(t, false) +} +func TestLongReorgedFastSyncedShallowRepairWithSnapshots(t *testing.T) { + testLongReorgedFastSyncedShallowRepair(t, true) +} + +func testLongReorgedFastSyncedShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1419,7 +1595,7 @@ func TestLongReorgedFastSyncedShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a longer @@ -1428,6 +1604,13 @@ func TestLongReorgedFastSyncedShallowRepair(t *testing.T) { // expect the canonical chains to be rolled back to the committed block, with // everything afterwads deleted. The side chain completely nuked by the freezer. func TestLongReorgedFastSyncedDeepRepair(t *testing.T) { + testLongReorgedFastSyncedDeepRepair(t, false) +} +func TestLongReorgedFastSyncedDeepRepairWithSnapshots(t *testing.T) { + testLongReorgedFastSyncedDeepRepair(t, true) +} + +func testLongReorgedFastSyncedDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1462,7 +1645,7 @@ func TestLongReorgedFastSyncedDeepRepair(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a longer @@ -1471,6 +1654,13 @@ func TestLongReorgedFastSyncedDeepRepair(t *testing.T) { // chain to detect that it was fast syncing and not delete anything, since we // can just pick up directly where we left off. func TestLongReorgedFastSyncingShallowRepair(t *testing.T) { + testLongReorgedFastSyncingShallowRepair(t, false) +} +func TestLongReorgedFastSyncingShallowRepairWithSnapshots(t *testing.T) { + testLongReorgedFastSyncingShallowRepair(t, true) +} + +func testLongReorgedFastSyncingShallowRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1506,7 +1696,7 @@ func TestLongReorgedFastSyncingShallowRepair(t *testing.T) { expHeadHeader: 18, expHeadFastBlock: 18, expHeadBlock: 0, - }) + }, snapshots) } // Tests a recovery for a long canonical chain with frozen blocks and a longer @@ -1515,6 +1705,13 @@ func TestLongReorgedFastSyncingShallowRepair(t *testing.T) { // chain to detect that it was fast syncing and not delete anything, since we // can just pick up directly where we left off. 
func TestLongReorgedFastSyncingDeepRepair(t *testing.T) { + testLongReorgedFastSyncingDeepRepair(t, false) +} +func TestLongReorgedFastSyncingDeepRepairWithSnapshots(t *testing.T) { + testLongReorgedFastSyncingDeepRepair(t, true) +} + +func testLongReorgedFastSyncingDeepRepair(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1550,13 +1747,13 @@ func TestLongReorgedFastSyncingDeepRepair(t *testing.T) { expHeadHeader: 24, expHeadFastBlock: 24, expHeadBlock: 0, - }) + }, snapshots) } -func testRepair(t *testing.T, tt *rewindTest) { +func testRepair(t *testing.T, tt *rewindTest, snapshots bool) { // It's hard to follow the test case, visualize the input //log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) - //fmt.Println(tt.dump(true)) + // fmt.Println(tt.dump(true)) // Create a temporary persistent database datadir, err := ioutil.TempDir("", "") @@ -1575,8 +1772,18 @@ func testRepair(t *testing.T, tt *rewindTest) { var ( genesis = new(Genesis).MustCommit(db) engine = ethash.NewFullFaker() + config = &CacheConfig{ + TrieCleanLimit: 256, + TrieDirtyLimit: 256, + TrieTimeLimit: 5 * time.Minute, + SnapshotLimit: 0, // Disable snapshot by default + } ) - chain, err := NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if snapshots { + config.SnapshotLimit = 256 + config.SnapshotWait = true + } + chain, err := NewBlockChain(db, config, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) if err != nil { t.Fatalf("Failed to create chain: %v", err) } @@ -1599,6 +1806,11 @@ func testRepair(t *testing.T, tt *rewindTest) { } if tt.commitBlock > 0 { chain.stateCache.TrieDB().Commit(canonblocks[tt.commitBlock-1].Root(), true, nil) + if snapshots { + if err := chain.snaps.Cap(canonblocks[tt.commitBlock-1].Root(), 0); err != nil { + t.Fatalf("Failed to flatten snapshots: %v", err) + } + } } if _, err := chain.InsertChain(canonblocks[tt.commitBlock:]); err != nil { t.Fatalf("Failed to import canonical chain tail: %v", err) diff --git a/core/blockchain_sethead_test.go b/core/blockchain_sethead_test.go index dc1368ff4b48..45c4073eb4ce 100644 --- a/core/blockchain_sethead_test.go +++ b/core/blockchain_sethead_test.go @@ -26,6 +26,7 @@ import ( "os" "strings" "testing" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/ethash" @@ -150,7 +151,10 @@ func (tt *rewindTest) dump(crash bool) string { // chain to be rolled back to the committed block. Everything above the sethead // point should be deleted. In between the committed block and the requested head // the data can remain as "fast sync" data to avoid redownloading it. 
-func TestShortSetHead(t *testing.T) { +func TestShortSetHead(t *testing.T) { testShortSetHead(t, false) } +func TestShortSetHeadWithSnapshots(t *testing.T) { testShortSetHead(t, true) } + +func testShortSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // @@ -181,7 +185,7 @@ func TestShortSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a short canonical chain where the fast sync pivot point was @@ -190,7 +194,10 @@ func TestShortSetHead(t *testing.T) { // Everything above the sethead point should be deleted. In between the committed // block and the requested head the data can remain as "fast sync" data to avoid // redownloading it. -func TestShortFastSyncedSetHead(t *testing.T) { +func TestShortFastSyncedSetHead(t *testing.T) { testShortFastSyncedSetHead(t, false) } +func TestShortFastSyncedSetHeadWithSnapshots(t *testing.T) { testShortFastSyncedSetHead(t, true) } + +func testShortFastSyncedSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // @@ -221,7 +228,7 @@ func TestShortFastSyncedSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a short canonical chain where the fast sync pivot point was @@ -229,7 +236,10 @@ func TestShortFastSyncedSetHead(t *testing.T) { // detect that it was fast syncing and delete everything from the new head, since // we can just pick up fast syncing from there. The head full block should be set // to the genesis. -func TestShortFastSyncingSetHead(t *testing.T) { +func TestShortFastSyncingSetHead(t *testing.T) { testShortFastSyncingSetHead(t, false) } +func TestShortFastSyncingSetHeadWithSnapshots(t *testing.T) { testShortFastSyncingSetHead(t, true) } + +func testShortFastSyncingSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // @@ -260,7 +270,7 @@ func TestShortFastSyncingSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a shorter side chain, where a @@ -270,7 +280,10 @@ func TestShortFastSyncingSetHead(t *testing.T) { // above the sethead point should be deleted. In between the committed block and // the requested head the data can remain as "fast sync" data to avoid redownloading // it. The side chain should be left alone as it was shorter. -func TestShortOldForkedSetHead(t *testing.T) { +func TestShortOldForkedSetHead(t *testing.T) { testShortOldForkedSetHead(t, false) } +func TestShortOldForkedSetHeadWithSnapshots(t *testing.T) { testShortOldForkedSetHead(t, true) } + +func testShortOldForkedSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3 @@ -303,7 +316,7 @@ func TestShortOldForkedSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a shorter side chain, where @@ -314,6 +327,13 @@ func TestShortOldForkedSetHead(t *testing.T) { // committed block and the requested head the data can remain as "fast sync" data // to avoid redownloading it. The side chain should be left alone as it was shorter. 
func TestShortOldForkedFastSyncedSetHead(t *testing.T) { + testShortOldForkedFastSyncedSetHead(t, false) +} +func TestShortOldForkedFastSyncedSetHeadWithSnapshots(t *testing.T) { + testShortOldForkedFastSyncedSetHead(t, true) +} + +func testShortOldForkedFastSyncedSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3 @@ -346,7 +366,7 @@ func TestShortOldForkedFastSyncedSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a shorter side chain, where @@ -356,6 +376,13 @@ func TestShortOldForkedFastSyncedSetHead(t *testing.T) { // head, since we can just pick up fast syncing from there. The head full block // should be set to the genesis. func TestShortOldForkedFastSyncingSetHead(t *testing.T) { + testShortOldForkedFastSyncingSetHead(t, false) +} +func TestShortOldForkedFastSyncingSetHeadWithSnapshots(t *testing.T) { + testShortOldForkedFastSyncingSetHead(t, true) +} + +func testShortOldForkedFastSyncingSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3 @@ -388,7 +415,7 @@ func TestShortOldForkedFastSyncingSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a shorter side chain, where a @@ -402,7 +429,10 @@ func TestShortOldForkedFastSyncingSetHead(t *testing.T) { // The side chain could be left to be if the fork point was before the new head // we are deleting to, but it would be exceedingly hard to detect that case and // properly handle it, so we'll trade extra work in exchange for simpler code. -func TestShortNewlyForkedSetHead(t *testing.T) { +func TestShortNewlyForkedSetHead(t *testing.T) { testShortNewlyForkedSetHead(t, false) } +func TestShortNewlyForkedSetHeadWithSnapshots(t *testing.T) { testShortNewlyForkedSetHead(t, true) } + +func testShortNewlyForkedSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8 @@ -435,7 +465,7 @@ func TestShortNewlyForkedSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a shorter side chain, where @@ -449,6 +479,13 @@ func TestShortNewlyForkedSetHead(t *testing.T) { // we are deleting to, but it would be exceedingly hard to detect that case and // properly handle it, so we'll trade extra work in exchange for simpler code. func TestShortNewlyForkedFastSyncedSetHead(t *testing.T) { + testShortNewlyForkedFastSyncedSetHead(t, false) +} +func TestShortNewlyForkedFastSyncedSetHeadWithSnapshots(t *testing.T) { + testShortNewlyForkedFastSyncedSetHead(t, true) +} + +func testShortNewlyForkedFastSyncedSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8 @@ -481,7 +518,7 @@ func TestShortNewlyForkedFastSyncedSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a shorter side chain, where @@ -495,6 +532,13 @@ func TestShortNewlyForkedFastSyncedSetHead(t *testing.T) { // we are deleting to, but it would be exceedingly hard to detect that case and // properly handle it, so we'll trade extra work in exchange for simpler code. 
func TestShortNewlyForkedFastSyncingSetHead(t *testing.T) { + testShortNewlyForkedFastSyncingSetHead(t, false) +} +func TestShortNewlyForkedFastSyncingSetHeadWithSnapshots(t *testing.T) { + testShortNewlyForkedFastSyncingSetHead(t, true) +} + +func testShortNewlyForkedFastSyncingSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8 @@ -527,7 +571,7 @@ func TestShortNewlyForkedFastSyncingSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a longer side chain, where a @@ -540,7 +584,10 @@ func TestShortNewlyForkedFastSyncingSetHead(t *testing.T) { // The side chain could be left to be if the fork point was before the new head // we are deleting to, but it would be exceedingly hard to detect that case and // properly handle it, so we'll trade extra work in exchange for simpler code. -func TestShortReorgedSetHead(t *testing.T) { +func TestShortReorgedSetHead(t *testing.T) { testShortReorgedSetHead(t, false) } +func TestShortReorgedSetHeadWithSnapshots(t *testing.T) { testShortReorgedSetHead(t, true) } + +func testShortReorgedSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10 @@ -573,7 +620,7 @@ func TestShortReorgedSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a longer side chain, where @@ -588,6 +635,13 @@ func TestShortReorgedSetHead(t *testing.T) { // we are deleting to, but it would be exceedingly hard to detect that case and // properly handle it, so we'll trade extra work in exchange for simpler code. func TestShortReorgedFastSyncedSetHead(t *testing.T) { + testShortReorgedFastSyncedSetHead(t, false) +} +func TestShortReorgedFastSyncedSetHeadWithSnapshots(t *testing.T) { + testShortReorgedFastSyncedSetHead(t, true) +} + +func testShortReorgedFastSyncedSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10 @@ -620,7 +674,7 @@ func TestShortReorgedFastSyncedSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a short canonical chain and a longer side chain, where @@ -633,6 +687,13 @@ func TestShortReorgedFastSyncedSetHead(t *testing.T) { // we are deleting to, but it would be exceedingly hard to detect that case and // properly handle it, so we'll trade extra work in exchange for simpler code. func TestShortReorgedFastSyncingSetHead(t *testing.T) { + testShortReorgedFastSyncingSetHead(t, false) +} +func TestShortReorgedFastSyncingSetHeadWithSnapshots(t *testing.T) { + testShortReorgedFastSyncingSetHead(t, true) +} + +func testShortReorgedFastSyncingSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10 @@ -665,7 +726,7 @@ func TestShortReorgedFastSyncingSetHead(t *testing.T) { expHeadHeader: 7, expHeadFastBlock: 7, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks where a recent @@ -674,7 +735,10 @@ func TestShortReorgedFastSyncingSetHead(t *testing.T) { // to the committed block. Everything above the sethead point should be deleted. 
// In between the committed block and the requested head the data can remain as // "fast sync" data to avoid redownloading it. -func TestLongShallowSetHead(t *testing.T) { +func TestLongShallowSetHead(t *testing.T) { testLongShallowSetHead(t, false) } +func TestLongShallowSetHeadWithSnapshots(t *testing.T) { testLongShallowSetHead(t, true) } + +func testLongShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // @@ -710,7 +774,7 @@ func TestLongShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks where a recent @@ -718,7 +782,10 @@ func TestLongShallowSetHead(t *testing.T) { // sethead was called. In this case we expect the full chain to be rolled back // to the committed block. Since the ancient limit was underflown, everything // needs to be deleted onwards to avoid creating a gap. -func TestLongDeepSetHead(t *testing.T) { +func TestLongDeepSetHead(t *testing.T) { testLongDeepSetHead(t, false) } +func TestLongDeepSetHeadWithSnapshots(t *testing.T) { testLongDeepSetHead(t, true) } + +func testLongDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // @@ -753,7 +820,7 @@ func TestLongDeepSetHead(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks where the fast @@ -763,6 +830,13 @@ func TestLongDeepSetHead(t *testing.T) { // deleted. In between the committed block and the requested head the data can // remain as "fast sync" data to avoid redownloading it. func TestLongFastSyncedShallowSetHead(t *testing.T) { + testLongFastSyncedShallowSetHead(t, false) +} +func TestLongFastSyncedShallowSetHeadWithSnapshots(t *testing.T) { + testLongFastSyncedShallowSetHead(t, true) +} + +func testLongFastSyncedShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // @@ -798,7 +872,7 @@ func TestLongFastSyncedShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks where the fast @@ -806,7 +880,10 @@ func TestLongFastSyncedShallowSetHead(t *testing.T) { // which sethead was called. In this case we expect the full chain to be rolled // back to the committed block. Since the ancient limit was underflown, everything // needs to be deleted onwards to avoid creating a gap. 
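The "deep" variants in these tests all reduce to one comparison: if the rewind target drops below the number of blocks already moved into the freezer, no trailing chain segment can be kept, because the freezer only stores a contiguous prefix. A minimal sketch of that check, assuming an ethdb.Database with freezer support (the helper name is illustrative, not part of the diff):

// underflowsAncients reports whether rewinding to head would drop below the
// frozen region; if so, everything from head onwards must be deleted to
// avoid leaving a gap between the freezer tail and the new head.
func underflowsAncients(db ethdb.Database, head uint64) bool {
	frozen, err := db.Ancients() // number of items currently in the freezer
	if err != nil {
		return false // no freezer backing this database, nothing to underflow
	}
	return head < frozen
}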
-func TestLongFastSyncedDeepSetHead(t *testing.T) { +func TestLongFastSyncedDeepSetHead(t *testing.T) { testLongFastSyncedDeepSetHead(t, false) } +func TestLongFastSyncedDeepSetHeadWithSnapshots(t *testing.T) { testLongFastSyncedDeepSetHead(t, true) } + +func testLongFastSyncedDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // @@ -841,7 +918,7 @@ func TestLongFastSyncedDeepSetHead(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks where the fast @@ -850,6 +927,13 @@ func TestLongFastSyncedDeepSetHead(t *testing.T) { // syncing and delete everything from the new head, since we can just pick up fast // syncing from there. func TestLongFastSyncingShallowSetHead(t *testing.T) { + testLongFastSyncingShallowSetHead(t, false) +} +func TestLongFastSyncingShallowSetHeadWithSnapshots(t *testing.T) { + testLongFastSyncingShallowSetHead(t, true) +} + +func testLongFastSyncingShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // @@ -885,7 +969,7 @@ func TestLongFastSyncingShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks where the fast @@ -894,6 +978,13 @@ func TestLongFastSyncingShallowSetHead(t *testing.T) { // syncing and delete everything from the new head, since we can just pick up fast // syncing from there. func TestLongFastSyncingDeepSetHead(t *testing.T) { + testLongFastSyncingDeepSetHead(t, false) +} +func TestLongFastSyncingDeepSetHeadWithSnapshots(t *testing.T) { + testLongFastSyncingDeepSetHead(t, true) +} + +func testLongFastSyncingDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // @@ -928,7 +1019,7 @@ func TestLongFastSyncingDeepSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter side @@ -939,6 +1030,13 @@ func TestLongFastSyncingDeepSetHead(t *testing.T) { // can remain as "fast sync" data to avoid redownloading it. The side chain is nuked // by the freezer. func TestLongOldForkedShallowSetHead(t *testing.T) { + testLongOldForkedShallowSetHead(t, false) +} +func TestLongOldForkedShallowSetHeadWithSnapshots(t *testing.T) { + testLongOldForkedShallowSetHead(t, true) +} + +func testLongOldForkedShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3 @@ -975,7 +1073,7 @@ func TestLongOldForkedShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter side @@ -984,7 +1082,10 @@ func TestLongOldForkedShallowSetHead(t *testing.T) { // chain to be rolled back to the committed block. Since the ancient limit was // underflown, everything needs to be deleted onwards to avoid creating a gap. The // side chain is nuked by the freezer. 
-func TestLongOldForkedDeepSetHead(t *testing.T) { +func TestLongOldForkedDeepSetHead(t *testing.T) { testLongOldForkedDeepSetHead(t, false) } +func TestLongOldForkedDeepSetHeadWithSnapshots(t *testing.T) { testLongOldForkedDeepSetHead(t, true) } + +func testLongOldForkedDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3 @@ -1020,7 +1121,7 @@ func TestLongOldForkedDeepSetHead(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1032,6 +1133,13 @@ func TestLongOldForkedDeepSetHead(t *testing.T) { // requested head the data can remain as "fast sync" data to avoid redownloading // it. The side chain is nuked by the freezer. func TestLongOldForkedFastSyncedShallowSetHead(t *testing.T) { + testLongOldForkedFastSyncedShallowSetHead(t, false) +} +func TestLongOldForkedFastSyncedShallowSetHeadWithSnapshots(t *testing.T) { + testLongOldForkedFastSyncedShallowSetHead(t, true) +} + +func testLongOldForkedFastSyncedShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3 @@ -1068,7 +1176,7 @@ func TestLongOldForkedFastSyncedShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1079,6 +1187,13 @@ func TestLongOldForkedFastSyncedShallowSetHead(t *testing.T) { // underflown, everything needs to be deleted onwards to avoid creating a gap. The // side chain is nuked by the freezer. func TestLongOldForkedFastSyncedDeepSetHead(t *testing.T) { + testLongOldForkedFastSyncedDeepSetHead(t, false) +} +func TestLongOldForkedFastSyncedDeepSetHeadWithSnapshots(t *testing.T) { + testLongOldForkedFastSyncedDeepSetHead(t, true) +} + +func testLongOldForkedFastSyncedDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3 @@ -1114,7 +1229,7 @@ func TestLongOldForkedFastSyncedDeepSetHead(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1125,6 +1240,13 @@ func TestLongOldForkedFastSyncedDeepSetHead(t *testing.T) { // just pick up fast syncing from there. The side chain is completely nuked by the // freezer. func TestLongOldForkedFastSyncingShallowSetHead(t *testing.T) { + testLongOldForkedFastSyncingShallowSetHead(t, false) +} +func TestLongOldForkedFastSyncingShallowSetHeadWithSnapshots(t *testing.T) { + testLongOldForkedFastSyncingShallowSetHead(t, true) +} + +func testLongOldForkedFastSyncingShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3 @@ -1161,7 +1283,7 @@ func TestLongOldForkedFastSyncingShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1172,6 +1294,13 @@ func TestLongOldForkedFastSyncingShallowSetHead(t *testing.T) { // just pick up fast syncing from there. 
The side chain is completely nuked by the // freezer. func TestLongOldForkedFastSyncingDeepSetHead(t *testing.T) { + testLongOldForkedFastSyncingDeepSetHead(t, false) +} +func TestLongOldForkedFastSyncingDeepSetHeadWithSnapshots(t *testing.T) { + testLongOldForkedFastSyncingDeepSetHead(t, true) +} + +func testLongOldForkedFastSyncingDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3 @@ -1207,7 +1336,7 @@ func TestLongOldForkedFastSyncingDeepSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1216,6 +1345,13 @@ func TestLongOldForkedFastSyncingDeepSetHead(t *testing.T) { // chain is above the committed block. In this case the freezer will delete the // sidechain since it's dangling, reverting to TestLongShallowSetHead. func TestLongNewerForkedShallowSetHead(t *testing.T) { + testLongNewerForkedShallowSetHead(t, false) +} +func TestLongNewerForkedShallowSetHeadWithSnapshots(t *testing.T) { + testLongNewerForkedShallowSetHead(t, true) +} + +func testLongNewerForkedShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1252,7 +1388,7 @@ func TestLongNewerForkedShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1261,6 +1397,13 @@ func TestLongNewerForkedShallowSetHead(t *testing.T) { // chain is above the committed block. In this case the freezer will delete the // sidechain since it's dangling, reverting to TestLongDeepSetHead. func TestLongNewerForkedDeepSetHead(t *testing.T) { + testLongNewerForkedDeepSetHead(t, false) +} +func TestLongNewerForkedDeepSetHeadWithSnapshots(t *testing.T) { + testLongNewerForkedDeepSetHead(t, true) +} + +func testLongNewerForkedDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1296,7 +1439,7 @@ func TestLongNewerForkedDeepSetHead(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1305,6 +1448,13 @@ func TestLongNewerForkedDeepSetHead(t *testing.T) { // the side chain is above the committed block. In this case the freezer will delete // the sidechain since it's dangling, reverting to TestLongFastSyncedShallowSetHead. 
func TestLongNewerForkedFastSyncedShallowSetHead(t *testing.T) { + testLongNewerForkedFastSyncedShallowSetHead(t, false) +} +func TestLongNewerForkedFastSyncedShallowSetHeadWithSnapshots(t *testing.T) { + testLongNewerForkedFastSyncedShallowSetHead(t, true) +} + +func testLongNewerForkedFastSyncedShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1341,7 +1491,7 @@ func TestLongNewerForkedFastSyncedShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1350,6 +1500,13 @@ func TestLongNewerForkedFastSyncedShallowSetHead(t *testing.T) { // the side chain is above the committed block. In this case the freezer will delete // the sidechain since it's dangling, reverting to TestLongFastSyncedDeepSetHead. func TestLongNewerForkedFastSyncedDeepSetHead(t *testing.T) { + testLongNewerForkedFastSyncedDeepSetHead(t, false) +} +func TestLongNewerForkedFastSyncedDeepSetHeadWithSnapshots(t *testing.T) { + testLongNewerForkedFastSyncedDeepSetHead(t, true) +} + +func testLongNewerForkedFastSyncedDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1385,7 +1542,7 @@ func TestLongNewerForkedFastSyncedDeepSetHead(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1394,6 +1551,13 @@ func TestLongNewerForkedFastSyncedDeepSetHead(t *testing.T) { // chain is above the committed block. In this case the freezer will delete the // sidechain since it's dangling, reverting to TestLongFastSyncinghallowSetHead. func TestLongNewerForkedFastSyncingShallowSetHead(t *testing.T) { + testLongNewerForkedFastSyncingShallowSetHead(t, false) +} +func TestLongNewerForkedFastSyncingShallowSetHeadWithSnapshots(t *testing.T) { + testLongNewerForkedFastSyncingShallowSetHead(t, true) +} + +func testLongNewerForkedFastSyncingShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1430,7 +1594,7 @@ func TestLongNewerForkedFastSyncingShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a shorter @@ -1439,6 +1603,13 @@ func TestLongNewerForkedFastSyncingShallowSetHead(t *testing.T) { // chain is above the committed block. In this case the freezer will delete the // sidechain since it's dangling, reverting to TestLongFastSyncingDeepSetHead. 
func TestLongNewerForkedFastSyncingDeepSetHead(t *testing.T) { + testLongNewerForkedFastSyncingDeepSetHead(t, false) +} +func TestLongNewerForkedFastSyncingDeepSetHeadWithSnapshots(t *testing.T) { + testLongNewerForkedFastSyncingDeepSetHead(t, true) +} + +func testLongNewerForkedFastSyncingDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12 @@ -1474,14 +1645,17 @@ func TestLongNewerForkedFastSyncingDeepSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a longer side // chain, where a recent block - newer than the ancient limit - was already committed // to disk and then sethead was called. In this case the freezer will delete the // sidechain since it's dangling, reverting to TestLongShallowSetHead. -func TestLongReorgedShallowSetHead(t *testing.T) { +func TestLongReorgedShallowSetHead(t *testing.T) { testLongReorgedShallowSetHead(t, false) } +func TestLongReorgedShallowSetHeadWithSnapshots(t *testing.T) { testLongReorgedShallowSetHead(t, true) } + +func testLongReorgedShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1518,14 +1692,17 @@ func TestLongReorgedShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a longer side // chain, where a recent block - older than the ancient limit - was already committed // to disk and then sethead was called. In this case the freezer will delete the // sidechain since it's dangling, reverting to TestLongDeepSetHead. -func TestLongReorgedDeepSetHead(t *testing.T) { +func TestLongReorgedDeepSetHead(t *testing.T) { testLongReorgedDeepSetHead(t, false) } +func TestLongReorgedDeepSetHeadWithSnapshots(t *testing.T) { testLongReorgedDeepSetHead(t, true) } + +func testLongReorgedDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1561,7 +1738,7 @@ func TestLongReorgedDeepSetHead(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a longer @@ -1570,6 +1747,13 @@ func TestLongReorgedDeepSetHead(t *testing.T) { // freezer will delete the sidechain since it's dangling, reverting to // TestLongFastSyncedShallowSetHead. 
func TestLongReorgedFastSyncedShallowSetHead(t *testing.T) { + testLongReorgedFastSyncedShallowSetHead(t, false) +} +func TestLongReorgedFastSyncedShallowSetHeadWithSnapshots(t *testing.T) { + testLongReorgedFastSyncedShallowSetHead(t, true) +} + +func testLongReorgedFastSyncedShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1606,7 +1790,7 @@ func TestLongReorgedFastSyncedShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a longer @@ -1615,6 +1799,13 @@ func TestLongReorgedFastSyncedShallowSetHead(t *testing.T) { // freezer will delete the sidechain since it's dangling, reverting to // TestLongFastSyncedDeepSetHead. func TestLongReorgedFastSyncedDeepSetHead(t *testing.T) { + testLongReorgedFastSyncedDeepSetHead(t, false) +} +func TestLongReorgedFastSyncedDeepSetHeadWithSnapshots(t *testing.T) { + testLongReorgedFastSyncedDeepSetHead(t, true) +} + +func testLongReorgedFastSyncedDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1650,7 +1841,7 @@ func TestLongReorgedFastSyncedDeepSetHead(t *testing.T) { expHeadHeader: 4, expHeadFastBlock: 4, expHeadBlock: 4, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a longer @@ -1660,6 +1851,13 @@ func TestLongReorgedFastSyncedDeepSetHead(t *testing.T) { // head, since we can just pick up fast syncing from there. The side chain is // completely nuked by the freezer. func TestLongReorgedFastSyncingShallowSetHead(t *testing.T) { + testLongReorgedFastSyncingShallowSetHead(t, false) +} +func TestLongReorgedFastSyncingShallowSetHeadWithSnapshots(t *testing.T) { + testLongReorgedFastSyncingShallowSetHead(t, true) +} + +func testLongReorgedFastSyncingShallowSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1696,7 +1894,7 @@ func TestLongReorgedFastSyncingShallowSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 0, - }) + }, snapshots) } // Tests a sethead for a long canonical chain with frozen blocks and a longer @@ -1706,6 +1904,13 @@ func TestLongReorgedFastSyncingShallowSetHead(t *testing.T) { // head, since we can just pick up fast syncing from there. The side chain is // completely nuked by the freezer. 
func TestLongReorgedFastSyncingDeepSetHead(t *testing.T) { + testLongReorgedFastSyncingDeepSetHead(t, false) +} +func TestLongReorgedFastSyncingDeepSetHeadWithSnapshots(t *testing.T) { + testLongReorgedFastSyncingDeepSetHead(t, true) +} + +func testLongReorgedFastSyncingDeepSetHead(t *testing.T, snapshots bool) { // Chain: // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10->C11->C12->C13->C14->C15->C16->C17->C18->C19->C20->C21->C22->C23->C24 (HEAD) // └->S1->S2->S3->S4->S5->S6->S7->S8->S9->S10->S11->S12->S13->S14->S15->S16->S17->S18->S19->S20->S21->S22->S23->S24->S25->S26 @@ -1741,13 +1946,13 @@ func TestLongReorgedFastSyncingDeepSetHead(t *testing.T) { expHeadHeader: 6, expHeadFastBlock: 6, expHeadBlock: 0, - }) + }, snapshots) } -func testSetHead(t *testing.T, tt *rewindTest) { +func testSetHead(t *testing.T, tt *rewindTest, snapshots bool) { // It's hard to follow the test case, visualize the input - //log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) - //fmt.Println(tt.dump(false)) + // log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) + // fmt.Println(tt.dump(false)) // Create a temporary persistent database datadir, err := ioutil.TempDir("", "") @@ -1766,8 +1971,18 @@ func testSetHead(t *testing.T, tt *rewindTest) { var ( genesis = new(Genesis).MustCommit(db) engine = ethash.NewFullFaker() + config = &CacheConfig{ + TrieCleanLimit: 256, + TrieDirtyLimit: 256, + TrieTimeLimit: 5 * time.Minute, + SnapshotLimit: 0, // Disable snapshot + } ) - chain, err := NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if snapshots { + config.SnapshotLimit = 256 + config.SnapshotWait = true + } + chain, err := NewBlockChain(db, config, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) if err != nil { t.Fatalf("Failed to create chain: %v", err) } @@ -1790,6 +2005,11 @@ func testSetHead(t *testing.T, tt *rewindTest) { } if tt.commitBlock > 0 { chain.stateCache.TrieDB().Commit(canonblocks[tt.commitBlock-1].Root(), true, nil) + if snapshots { + if err := chain.snaps.Cap(canonblocks[tt.commitBlock-1].Root(), 0); err != nil { + t.Fatalf("Failed to flatten snapshots: %v", err) + } + } } if _, err := chain.InsertChain(canonblocks[tt.commitBlock:]); err != nil { t.Fatalf("Failed to import canonical chain tail: %v", err) diff --git a/core/blockchain_snapshot_test.go b/core/blockchain_snapshot_test.go new file mode 100644 index 000000000000..f35dae16785b --- /dev/null +++ b/core/blockchain_snapshot_test.go @@ -0,0 +1,797 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +// Tests that abnormal program termination (i.e. crash) and restart can recover +// the snapshot properly if the snapshot is enabled.
+ +package core + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "strings" + "testing" + "time" + + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/params" +) + +// snapshotTest is a test case for snapshot recovery. It can be used for +// simulating these scenarios: +// (i) Geth restarts normally with valid legacy snapshot +// (ii) Geth restarts normally with valid new-format snapshot +// (iii) Geth restarts after a crash, with broken legacy snapshot +// (iv) Geth restarts after a crash, with broken new-format snapshot +// (v) Geth restarts normally, but it's requested to be rewound to a lower point via SetHead +// (vi) Geth restarts normally with a stale snapshot +type snapshotTest struct { + legacy bool // Flag whether the loaded snapshot is in legacy format + crash bool // Flag whether Geth restarts from a previous crash + restartCrash int // Number of blocks to insert after the normal stop, then the crash happens + gapped int // Number of blocks to insert without enabling snapshot + setHead uint64 // Block number to set head back to + + chainBlocks int // Number of blocks to generate for the canonical chain + snapshotBlock uint64 // Block number of the relevant snapshot disk layer + commitBlock uint64 // Block number for which to commit the state to disk + + expCanonicalBlocks int // Number of canonical blocks expected to remain in the database (excl. genesis) + expHeadHeader uint64 // Block number of the expected head header + expHeadFastBlock uint64 // Block number of the expected head fast sync block + expHeadBlock uint64 // Block number of the expected head full block + expSnapshotBottom uint64 // The block height corresponding to the snapshot disk layer +} + +func (tt *snapshotTest) dump() string { + buffer := new(strings.Builder) + + fmt.Fprint(buffer, "Chain:\n G") + for i := 0; i < tt.chainBlocks; i++ { + fmt.Fprintf(buffer, "->C%d", i+1) + } + fmt.Fprint(buffer, " (HEAD)\n\n") + + fmt.Fprintf(buffer, "Commit: G") + if tt.commitBlock > 0 { + fmt.Fprintf(buffer, ", C%d", tt.commitBlock) + } + fmt.Fprint(buffer, "\n") + + fmt.Fprintf(buffer, "Snapshot: G") + if tt.snapshotBlock > 0 { + fmt.Fprintf(buffer, ", C%d", tt.snapshotBlock) + } + fmt.Fprint(buffer, "\n") + + if tt.crash { + fmt.Fprintf(buffer, "\nCRASH\n\n") + } else { + fmt.Fprintf(buffer, "\nSetHead(%d)\n\n", tt.setHead) + } + fmt.Fprintf(buffer, "------------------------------\n\n") + + fmt.Fprint(buffer, "Expected in leveldb:\n G") + for i := 0; i < tt.expCanonicalBlocks; i++ { + fmt.Fprintf(buffer, "->C%d", i+1) + } + fmt.Fprintf(buffer, "\n\n") + fmt.Fprintf(buffer, "Expected head header : C%d\n", tt.expHeadHeader) + fmt.Fprintf(buffer, "Expected head fast block: C%d\n", tt.expHeadFastBlock) + if tt.expHeadBlock == 0 { + fmt.Fprintf(buffer, "Expected head block : G\n") + } else { + fmt.Fprintf(buffer, "Expected head block : C%d\n", tt.expHeadBlock) + } + if tt.expSnapshotBottom == 0 { + fmt.Fprintf(buffer, "Expected snapshot disk : G\n") + } else { + fmt.Fprintf(buffer, "Expected snapshot disk : C%d\n", tt.expSnapshotBottom) + } + return buffer.String() +} + +// Tests a Geth restart with a valid snapshot. Before the shutdown, the snapshot +// journal will be persisted correctly. In this case no snapshot recovery is +// required.
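(The dump method above is what the commented-out fmt.Println in testSnapshot renders; a hypothetical standalone use, with purely illustrative field values that match no specific test below:)

// dumpDemo renders one made-up scenario for eyeballing while debugging.
func dumpDemo() {
	tt := &snapshotTest{
		legacy:             false,
		crash:              true,
		chainBlocks:        8,
		snapshotBlock:      4,
		commitBlock:        2,
		expCanonicalBlocks: 8,
		expHeadHeader:      8,
		expHeadFastBlock:   8,
		expHeadBlock:       2,
		expSnapshotBottom:  4,
	}
	fmt.Println(tt.dump())
}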
+func TestRestartWithNewSnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G + // + // SetHead(0) + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8 + // + // Expected head header : C8 + // Expected head fast block: C8 + // Expected head block : C8 + // Expected snapshot disk : G + testSnapshot(t, &snapshotTest{ + legacy: false, + crash: false, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 0, + commitBlock: 0, + expCanonicalBlocks: 8, + expHeadHeader: 8, + expHeadFastBlock: 8, + expHeadBlock: 8, + expSnapshotBottom: 0, // Initial disk layer built from genesis + }) +} + +// Tests a Geth restart with a valid but "legacy" snapshot. Before the shutdown, +// the snapshot journal will be persisted correctly. In this case no snapshot +// recovery is required. +func TestRestartWithLegacySnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G + // + // SetHead(0) + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8 + // + // Expected head header : C8 + // Expected head fast block: C8 + // Expected head block : C8 + // Expected snapshot disk : G + testSnapshot(t, &snapshotTest{ + legacy: true, + crash: false, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 0, + commitBlock: 0, + expCanonicalBlocks: 8, + expHeadHeader: 8, + expHeadFastBlock: 8, + expHeadBlock: 8, + expSnapshotBottom: 0, // Initial disk layer built from genesis + }) +} + +// Tests a Geth that crashed and restarts with a broken snapshot. In this case the +// chain head should be rewound to the point with available state, and the new +// head must also be lower than the disk layer. Since there is no committed point, +// the chain should be rewound to genesis and the disk layer should be left +// for recovery. +func TestNoCommitCrashWithNewSnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G, C4 + // + // CRASH + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8 + // + // Expected head header : C8 + // Expected head fast block: C8 + // Expected head block : G + // Expected snapshot disk : C4 + testSnapshot(t, &snapshotTest{ + legacy: false, + crash: true, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 4, + commitBlock: 0, + expCanonicalBlocks: 8, + expHeadHeader: 8, + expHeadFastBlock: 8, + expHeadBlock: 0, + expSnapshotBottom: 4, // Last committed disk layer, wait recovery + }) +} + +// Tests a Geth that crashed and restarts with a broken snapshot. In this case the +// chain head should be rewound to the point with available state, and the new +// head must also be lower than the disk layer. Since there is only a low committed +// point, the chain should be rewound to that committed point and the disk layer +// should be left for recovery.
+func TestLowCommitCrashWithNewSnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G, C2 + // Snapshot: G, C4 + // + // CRASH + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8 + // + // Expected head header : C8 + // Expected head fast block: C8 + // Expected head block : C2 + // Expected snapshot disk : C4 + testSnapshot(t, &snapshotTest{ + legacy: false, + crash: true, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 4, + commitBlock: 2, + expCanonicalBlocks: 8, + expHeadHeader: 8, + expHeadFastBlock: 8, + expHeadBlock: 2, + expSnapshotBottom: 4, // Last committed disk layer, wait recovery + }) +} + +// Tests a Geth that crashed and restarts with a broken snapshot. In this case +// the chain head should be rewound to the point with available state, and the +// new head must also be lower than the disk layer. Since there is only a high +// committed point, the chain should be rewound to genesis and the disk layer +// should be left for recovery. +func TestHighCommitCrashWithNewSnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G, C6 + // Snapshot: G, C4 + // + // CRASH + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8 + // + // Expected head header : C8 + // Expected head fast block: C8 + // Expected head block : G + // Expected snapshot disk : C4 + testSnapshot(t, &snapshotTest{ + legacy: false, + crash: true, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 4, + commitBlock: 6, + expCanonicalBlocks: 8, + expHeadHeader: 8, + expHeadFastBlock: 8, + expHeadBlock: 0, + expSnapshotBottom: 4, // Last committed disk layer, wait recovery + }) +} + +// Tests a Geth that crashed and restarts with a broken, "legacy format" +// snapshot. In this case the entire legacy snapshot should be discarded +// and rebuilt from the new chain head. The new head here refers to the +// genesis because there is no committed point. +func TestNoCommitCrashWithLegacySnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G, C4 + // + // CRASH + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8 + // + // Expected head header : C8 + // Expected head fast block: C8 + // Expected head block : G + // Expected snapshot disk : G + testSnapshot(t, &snapshotTest{ + legacy: true, + crash: true, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 4, + commitBlock: 0, + expCanonicalBlocks: 8, + expHeadHeader: 8, + expHeadFastBlock: 8, + expHeadBlock: 0, + expSnapshotBottom: 0, // Rebuilt snapshot from the latest HEAD(genesis) + }) +} + +// Tests a Geth that crashed and restarts with a broken, "legacy format" +// snapshot. In this case the entire legacy snapshot should be discarded +// and rebuilt from the new chain head. The new head here refers to +// block 2 because it's committed to disk.
+func TestLowCommitCrashWithLegacySnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G, C2 + // Snapshot: G, C4 + // + // CRASH + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8 + // + // Expected head header : C8 + // Expected head fast block: C8 + // Expected head block : C2 + // Expected snapshot disk : C2 + testSnapshot(t, &snapshotTest{ + legacy: true, + crash: true, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 4, + commitBlock: 2, + expCanonicalBlocks: 8, + expHeadHeader: 8, + expHeadFastBlock: 8, + expHeadBlock: 2, + expSnapshotBottom: 2, // Rebuilt snapshot from the latest HEAD + }) +} + +// Tests a Geth that crashed and restarts with a broken, "legacy format" +// snapshot. In this case the entire legacy snapshot should be discarded +// and rebuilt from the new chain head. +// +// The new head here refers to the genesis, because: +// - the state of block-6 is committed into the disk +// - the legacy disk layer of block-4 is committed into the disk +// - the head is rewound to the genesis in order to find an available +// state lower than the disk layer +func TestHighCommitCrashWithLegacySnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G, C6 + // Snapshot: G, C4 + // + // CRASH + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8 + // + // Expected head header : C8 + // Expected head fast block: C8 + // Expected head block : G + // Expected snapshot disk : G + testSnapshot(t, &snapshotTest{ + legacy: true, + crash: true, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 4, + commitBlock: 6, + expCanonicalBlocks: 8, + expHeadHeader: 8, + expHeadFastBlock: 8, + expHeadBlock: 0, + expSnapshotBottom: 0, // Rebuilt snapshot from the latest HEAD(genesis) + }) +} + +// Tests a Geth that was running with snapshots enabled, then restarts without +// enabling them and after that re-enables them again. In this case the +// snapshot should be rebuilt from the latest chain head. +func TestGappedNewSnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G + // + // SetHead(0) + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10 + // + // Expected head header : C10 + // Expected head fast block: C10 + // Expected head block : C10 + // Expected snapshot disk : C10 + testSnapshot(t, &snapshotTest{ + legacy: false, + crash: false, + gapped: 2, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 0, + commitBlock: 0, + expCanonicalBlocks: 10, + expHeadHeader: 10, + expHeadFastBlock: 10, + expHeadBlock: 10, + expSnapshotBottom: 10, // Rebuilt snapshot from the latest HEAD + }) +} + +// Tests a Geth that was running with a legacy snapshot enabled, then restarts +// without enabling snapshots and after that re-enables them again. +// In this case the snapshot should be rebuilt from the latest chain head.
+func TestGappedLegacySnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G + // + // SetHead(0) + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10 + // + // Expected head header : C10 + // Expected head fast block: C10 + // Expected head block : C10 + // Expected snapshot disk : C10 + testSnapshot(t, &snapshotTest{ + legacy: true, + crash: false, + gapped: 2, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 0, + commitBlock: 0, + expCanonicalBlocks: 10, + expHeadHeader: 10, + expHeadFastBlock: 10, + expHeadBlock: 10, + expSnapshotBottom: 10, // Rebuilt snapshot from the latest HEAD + }) +} + +// Tests a Geth that was running with snapshots enabled when SetHead is applied. +// In this case the head is rewound to the target (with state available). After +// that the chain is restarted and the original disk layer is kept. +func TestSetHeadWithNewSnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G + // + // SetHead(4) + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4 + // + // Expected head header : C4 + // Expected head fast block: C4 + // Expected head block : C4 + // Expected snapshot disk : G + testSnapshot(t, &snapshotTest{ + legacy: false, + crash: false, + gapped: 0, + setHead: 4, + chainBlocks: 8, + snapshotBlock: 0, + commitBlock: 0, + expCanonicalBlocks: 4, + expHeadHeader: 4, + expHeadFastBlock: 4, + expHeadBlock: 4, + expSnapshotBottom: 0, // The initial disk layer is built from the genesis + }) +} + +// Tests a Geth that was running with a snapshot (legacy format) enabled when SetHead +// is applied. In this case the head is rewound to the target (with state available). +// After that the chain is restarted and the original disk layer is kept. +func TestSetHeadWithLegacySnapshot(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G + // + // SetHead(4) + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4 + // + // Expected head header : C4 + // Expected head fast block: C4 + // Expected head block : C4 + // Expected snapshot disk : G + testSnapshot(t, &snapshotTest{ + legacy: true, + crash: false, + gapped: 0, + setHead: 4, + chainBlocks: 8, + snapshotBlock: 0, + commitBlock: 0, + expCanonicalBlocks: 4, + expHeadHeader: 4, + expHeadFastBlock: 4, + expHeadBlock: 4, + expSnapshotBottom: 0, // The initial disk layer is built from the genesis + }) +} + +// Tests a Geth that was running with a snapshot (legacy format) enabled and upgrades +// the disk layer journal (journal generator) to the latest format. After that, Geth +// is restarted from a crash. In this case Geth will find the new-format disk layer +// journal but a legacy-format diff journal (the new format is never committed), +// and the invalid diff journal is expected to be dropped.
+func TestRecoverSnapshotFromCrashWithLegacyDiffJournal(t *testing.T) { + // Chain: + // G->C1->C2->C3->C4->C5->C6->C7->C8 (HEAD) + // + // Commit: G + // Snapshot: G + // + // SetHead(0) + // + // ------------------------------ + // + // Expected in leveldb: + // G->C1->C2->C3->C4->C5->C6->C7->C8->C9->C10 + // + // Expected head header : C10 + // Expected head fast block: C10 + // Expected head block : C8 + // Expected snapshot disk : C10 + testSnapshot(t, &snapshotTest{ + legacy: true, + crash: false, + restartCrash: 2, + gapped: 0, + setHead: 0, + chainBlocks: 8, + snapshotBlock: 0, + commitBlock: 0, + expCanonicalBlocks: 10, + expHeadHeader: 10, + expHeadFastBlock: 10, + expHeadBlock: 8, // The persisted state in the first running + expSnapshotBottom: 10, // The persisted disk layer in the second running + }) +} + +func testSnapshot(t *testing.T, tt *snapshotTest) { + // It's hard to follow the test case; visualize the input + // log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(true)))) + // fmt.Println(tt.dump()) + + // Create a temporary persistent database + datadir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("Failed to create temporary datadir: %v", err) + } + os.RemoveAll(datadir) + + db, err := rawdb.NewLevelDBDatabaseWithFreezer(datadir, 0, 0, datadir, "") + if err != nil { + t.Fatalf("Failed to create persistent database: %v", err) + } + defer db.Close() // Might double close, should be fine + + // Initialize a fresh chain + var ( + genesis = new(Genesis).MustCommit(db) + engine = ethash.NewFullFaker() + gendb = rawdb.NewMemoryDatabase() + + // Snapshots are enabled; the first snapshot is created from the genesis. + // The snapshot memory allowance is 256MB, which means no snapshot flush + // will happen during block insertion. + cacheConfig = defaultCacheConfig + ) + chain, err := NewBlockChain(db, cacheConfig, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to create chain: %v", err) + } + blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, gendb, tt.chainBlocks, func(i int, b *BlockGen) {}) + + // Insert the blocks with configured settings. + var breakpoints []uint64 + if tt.commitBlock > tt.snapshotBlock { + breakpoints = append(breakpoints, tt.snapshotBlock, tt.commitBlock) + } else { + breakpoints = append(breakpoints, tt.commitBlock, tt.snapshotBlock) + } + var startPoint uint64 + for _, point := range breakpoints { + if _, err := chain.InsertChain(blocks[startPoint:point]); err != nil { + t.Fatalf("Failed to import canonical chain start: %v", err) + } + startPoint = point + + if tt.commitBlock > 0 && tt.commitBlock == point { + chain.stateCache.TrieDB().Commit(blocks[point-1].Root(), true, nil) + } + if tt.snapshotBlock > 0 && tt.snapshotBlock == point { + if tt.legacy { + // Here we commit the snapshot disk root to simulate + // committing the legacy snapshot.
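+ // (Writing only the bare root below is presumably what makes the
+ // layout look legacy to the loader: the old on-disk format identified
+ // its disk layer by root alone, while the new format also journals a
+ // generator marker.)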
+ rawdb.WriteSnapshotRoot(db, blocks[point-1].Root()) + } else { + chain.snaps.Cap(blocks[point-1].Root(), 0) + diskRoot, blockRoot := chain.snaps.DiskRoot(), blocks[point-1].Root() + if !bytes.Equal(diskRoot.Bytes(), blockRoot.Bytes()) { + t.Fatalf("Failed to flush disk layer change, want %x, got %x", blockRoot, diskRoot) + } + } + } + } + if _, err := chain.InsertChain(blocks[startPoint:]); err != nil { + t.Fatalf("Failed to import canonical chain tail: %v", err) + } + // Set the flag for writing legacy journal if necessary + if tt.legacy { + chain.writeLegacyJournal = true + } + // Pull the plug on the database, simulating a hard crash + if tt.crash { + db.Close() + + // Start a new blockchain back up and see where the repair leads us + db, err = rawdb.NewLevelDBDatabaseWithFreezer(datadir, 0, 0, datadir, "") + if err != nil { + t.Fatalf("Failed to reopen persistent database: %v", err) + } + defer db.Close() + + // The interesting thing is: instead of starting the blockchain once after + // the crash, we restart twice here: once after the crash and once + // after the normal stop. This ensures the broken snapshot + // can be detected in both cases. + chain, err = NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to recreate chain: %v", err) + } + chain.Stop() + + chain, err = NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to recreate chain: %v", err) + } + defer chain.Stop() + } else if tt.gapped > 0 { + // Insert blocks without enabling snapshot if gapping is required. + chain.Stop() + gappedBlocks, _ := GenerateChain(params.TestChainConfig, blocks[len(blocks)-1], engine, gendb, tt.gapped, func(i int, b *BlockGen) {}) + + // Insert a few more blocks without enabling snapshot + var cacheConfig = &CacheConfig{ + TrieCleanLimit: 256, + TrieDirtyLimit: 256, + TrieTimeLimit: 5 * time.Minute, + SnapshotLimit: 0, + } + chain, err = NewBlockChain(db, cacheConfig, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to recreate chain: %v", err) + } + chain.InsertChain(gappedBlocks) + chain.Stop() + + chain, err = NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to recreate chain: %v", err) + } + defer chain.Stop() + } else if tt.setHead != 0 { + // Rewind the chain if setHead operation is required. + chain.SetHead(tt.setHead) + chain.Stop() + + chain, err = NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to recreate chain: %v", err) + } + defer chain.Stop() + } else if tt.restartCrash != 0 { + // First, stop the chain properly, with the snapshot journal + // and state committed.
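+ // (Stopping the chain also persists the snapshot journal; since
+ // writeLegacyJournal was set above for tt.legacy, the journal
+ // written here is in the legacy format this scenario relies on.)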
+ chain.Stop() + + // Restart the chain, forcibly flushing the disk layer journal in the new format + newBlocks, _ := GenerateChain(params.TestChainConfig, blocks[len(blocks)-1], engine, gendb, tt.restartCrash, func(i int, b *BlockGen) {}) + chain, err = NewBlockChain(db, cacheConfig, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to recreate chain: %v", err) + } + chain.InsertChain(newBlocks) + chain.Snapshots().Cap(newBlocks[len(newBlocks)-1].Root(), 0) + + // Simulate a blockchain crash: don't call chain.Stop here, so that + // neither the snapshot journal nor the latest state is committed + + // Restart the chain after the crash + chain, err = NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to recreate chain: %v", err) + } + defer chain.Stop() + } else { + chain.Stop() + + // Restart the chain normally + chain, err = NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{}, nil, nil) + if err != nil { + t.Fatalf("Failed to recreate chain: %v", err) + } + defer chain.Stop() + } + + // Iterate over all the remaining blocks and ensure there are no gaps + verifyNoGaps(t, chain, true, blocks) + verifyCutoff(t, chain, true, blocks, tt.expCanonicalBlocks) + + if head := chain.CurrentHeader(); head.Number.Uint64() != tt.expHeadHeader { + t.Errorf("Head header mismatch: have %d, want %d", head.Number, tt.expHeadHeader) + } + if head := chain.CurrentFastBlock(); head.NumberU64() != tt.expHeadFastBlock { + t.Errorf("Head fast block mismatch: have %d, want %d", head.NumberU64(), tt.expHeadFastBlock) + } + if head := chain.CurrentBlock(); head.NumberU64() != tt.expHeadBlock { + t.Errorf("Head block mismatch: have %d, want %d", head.NumberU64(), tt.expHeadBlock) + } + // Check the disk layer, ensure it matches the expected block + block := chain.GetBlockByNumber(tt.expSnapshotBottom) + if block == nil { + t.Errorf("The corresponding block[%d] of the snapshot disk layer is missing", tt.expSnapshotBottom) + } else if !bytes.Equal(chain.snaps.DiskRoot().Bytes(), block.Root().Bytes()) { + t.Errorf("The snapshot disk layer root is incorrect, want %x, got %x", block.Root(), chain.snaps.DiskRoot()) + } +} diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 7ec62b11ddc3..d60a2359813d 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -17,6 +17,7 @@ package core import ( + "errors" "fmt" "io/ioutil" "math/big" @@ -468,7 +469,7 @@ func testBadHashes(t *testing.T, full bool) { _, err = blockchain.InsertHeaderChain(headers, 1) } - if err != ErrBlacklistedHash { + if !errors.Is(err, ErrBlacklistedHash) { t.Errorf("error mismatch: have: %v, want: %v", err, ErrBlacklistedHash) } } diff --git a/core/bloombits/generator.go b/core/bloombits/generator.go index ae07481ada50..646151db0bfd 100644 --- a/core/bloombits/generator.go +++ b/core/bloombits/generator.go @@ -65,18 +65,23 @@ func (b *Generator) AddBloom(index uint, bloom types.Bloom) error { } // Rotate the bloom and insert into our collection byteIndex := b.nextSec / 8 - bitMask := byte(1) << byte(7-b.nextSec%8) - - for i := 0; i < types.BloomBitLength; i++ { - bloomByteIndex := types.BloomByteLength - 1 - i/8 - bloomBitMask := byte(1) << byte(i%8) - - if (bloom[bloomByteIndex] & bloomBitMask) != 0 { - b.blooms[i][byteIndex] |= bitMask + bitIndex := byte(7 - b.nextSec%8) + for byt := 0; byt < types.BloomByteLength; byt++ { + bloomByte := bloom[types.BloomByteLength-1-byt] + if bloomByte == 0 { + continue } +
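+ // Bit j of bloomByte corresponds to bloom-bit row base+j of the
+ // rotated matrix; OR it into this section's slot at bitIndex. The
+ // eight positions are unrolled below, which matches the old per-bit
+ // loop exactly while skipping its per-iteration index math.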
base := 8 * byt + b.blooms[base+7][byteIndex] |= ((bloomByte >> 7) & 1) << bitIndex + b.blooms[base+6][byteIndex] |= ((bloomByte >> 6) & 1) << bitIndex + b.blooms[base+5][byteIndex] |= ((bloomByte >> 5) & 1) << bitIndex + b.blooms[base+4][byteIndex] |= ((bloomByte >> 4) & 1) << bitIndex + b.blooms[base+3][byteIndex] |= ((bloomByte >> 3) & 1) << bitIndex + b.blooms[base+2][byteIndex] |= ((bloomByte >> 2) & 1) << bitIndex + b.blooms[base+1][byteIndex] |= ((bloomByte >> 1) & 1) << bitIndex + b.blooms[base][byteIndex] |= (bloomByte & 1) << bitIndex } b.nextSec++ - return nil } diff --git a/core/bloombits/generator_test.go b/core/bloombits/generator_test.go index f9bcef96e0b5..88e3ed6b06c0 100644 --- a/core/bloombits/generator_test.go +++ b/core/bloombits/generator_test.go @@ -58,3 +58,42 @@ func TestGenerator(t *testing.T) { } } } + +func BenchmarkGenerator(b *testing.B) { + var input [types.BloomBitLength][types.BloomByteLength]byte + b.Run("empty", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Crunch the input through the generator and verify the result + gen, err := NewGenerator(types.BloomBitLength) + if err != nil { + b.Fatalf("failed to create bloombit generator: %v", err) + } + for j, bloom := range input { + if err := gen.AddBloom(uint(j), bloom); err != nil { + b.Fatalf("bloom %d: failed to add: %v", i, err) + } + } + } + }) + for i := 0; i < types.BloomBitLength; i++ { + rand.Read(input[i][:]) + } + b.Run("random", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Crunch the input through the generator and verify the result + gen, err := NewGenerator(types.BloomBitLength) + if err != nil { + b.Fatalf("failed to create bloombit generator: %v", err) + } + for j, bloom := range input { + if err := gen.AddBloom(uint(j), bloom); err != nil { + b.Fatalf("bloom %d: failed to add: %v", i, err) + } + } + } + }) +} diff --git a/core/bloombits/matcher_test.go b/core/bloombits/matcher_test.go index 91143e525e7f..923579221f51 100644 --- a/core/bloombits/matcher_test.go +++ b/core/bloombits/matcher_test.go @@ -30,6 +30,7 @@ const testSectionSize = 4096 // Tests that wildcard filter rules (nil) can be specified and are handled well. func TestMatcherWildcards(t *testing.T) { + t.Parallel() matcher := NewMatcher(testSectionSize, [][][]byte{ {common.Address{}.Bytes(), common.Address{0x01}.Bytes()}, // Default address is not a wildcard {common.Hash{}.Bytes(), common.Hash{0x01}.Bytes()}, // Default hash is not a wildcard @@ -56,6 +57,7 @@ func TestMatcherWildcards(t *testing.T) { // Tests the matcher pipeline on a single continuous workflow without interrupts. func TestMatcherContinuous(t *testing.T) { + t.Parallel() testMatcherDiffBatches(t, [][]bloomIndexes{{{10, 20, 30}}}, 0, 100000, false, 75) testMatcherDiffBatches(t, [][]bloomIndexes{{{32, 3125, 100}}, {{40, 50, 10}}}, 0, 100000, false, 81) testMatcherDiffBatches(t, [][]bloomIndexes{{{4, 8, 11}, {7, 8, 17}}, {{9, 9, 12}, {15, 20, 13}}, {{18, 15, 15}, {12, 10, 4}}}, 0, 10000, false, 36) @@ -64,6 +66,7 @@ func TestMatcherContinuous(t *testing.T) { // Tests the matcher pipeline on a constantly interrupted and resumed work pattern // with the aim of ensuring data items are requested only once. 
func TestMatcherIntermittent(t *testing.T) { + t.Parallel() testMatcherDiffBatches(t, [][]bloomIndexes{{{10, 20, 30}}}, 0, 100000, true, 75) testMatcherDiffBatches(t, [][]bloomIndexes{{{32, 3125, 100}}, {{40, 50, 10}}}, 0, 100000, true, 81) testMatcherDiffBatches(t, [][]bloomIndexes{{{4, 8, 11}, {7, 8, 17}}, {{9, 9, 12}, {15, 20, 13}}, {{18, 15, 15}, {12, 10, 4}}}, 0, 10000, true, 36) @@ -71,6 +74,7 @@ func TestMatcherIntermittent(t *testing.T) { // Tests the matcher pipeline on random input to hopefully catch anomalies. func TestMatcherRandom(t *testing.T) { + t.Parallel() for i := 0; i < 10; i++ { testMatcherBothModes(t, makeRandomIndexes([]int{1}, 50), 0, 10000, 0) testMatcherBothModes(t, makeRandomIndexes([]int{3}, 50), 0, 10000, 0) @@ -84,6 +88,7 @@ func TestMatcherRandom(t *testing.T) { // shifter from a multiple of 8. This is needed to cover an optimisation with // bitset matching https://github.com/ethereum/go-ethereum/issues/15309. func TestMatcherShifted(t *testing.T) { + t.Parallel() // Block 0 always matches in the tests, skip ahead of first 8 blocks with the // start to get a potential zero byte in the matcher bitset. @@ -97,6 +102,7 @@ func TestMatcherShifted(t *testing.T) { // Tests that matching on everything doesn't crash (special case internally). func TestWildcardMatcher(t *testing.T) { + t.Parallel() testMatcherBothModes(t, nil, 0, 10000, 0) } diff --git a/core/bloombits/scheduler_test.go b/core/bloombits/scheduler_test.go index 70772e4ab939..707e8ea11d04 100644 --- a/core/bloombits/scheduler_test.go +++ b/core/bloombits/scheduler_test.go @@ -35,6 +35,7 @@ func TestSchedulerMultiClientSingleFetcher(t *testing.T) { testScheduler(t, 10, func TestSchedulerMultiClientMultiFetcher(t *testing.T) { testScheduler(t, 10, 10, 5000) } func testScheduler(t *testing.T, clients int, fetchers int, requests int) { + t.Parallel() f := newScheduler(0) // Create a batch of handler goroutines that respond to bloom bit requests and @@ -88,10 +89,10 @@ func testScheduler(t *testing.T, clients int, fetchers int, requests int) { } close(in) }() - + b := new(big.Int) for j := 0; j < requests; j++ { bits := <-out - if want := new(big.Int).SetUint64(uint64(j)).Bytes(); !bytes.Equal(bits, want) { + if want := b.SetUint64(uint64(j)).Bytes(); !bytes.Equal(bits, want) { t.Errorf("vector %d: delivered content mismatch: have %x, want %x", j, bits, want) } } diff --git a/core/chain_indexer.go b/core/chain_indexer.go index 066bca10009b..4b326c970b50 100644 --- a/core/chain_indexer.go +++ b/core/chain_indexer.go @@ -94,7 +94,7 @@ type ChainIndexer struct { throttling time.Duration // Disk throttling to prevent a heavy upgrade from hogging resources log log.Logger - lock sync.RWMutex + lock sync.Mutex } // NewChainIndexer creates a new chain indexer to do background processing on diff --git a/core/evm.go b/core/evm.go index 8abe5a04771b..8f69d514995e 100644 --- a/core/evm.go +++ b/core/evm.go @@ -35,8 +35,8 @@ type ChainContext interface { GetHeader(common.Hash, uint64) *types.Header } -// NewEVMContext creates a new context for use in the EVM. -func NewEVMContext(msg Message, header *types.Header, chain ChainContext, author *common.Address) vm.Context { +// NewEVMBlockContext creates a new context for use in the EVM. +func NewEVMBlockContext(header *types.Header, chain ChainContext, author *common.Address) vm.BlockContext { // If we don't have an explicit author (i.e. 
not mining), extract from the header var beneficiary common.Address if author == nil { @@ -44,17 +44,23 @@ func NewEVMContext(msg Message, header *types.Header, chain ChainContext, author } else { beneficiary = *author } - return vm.Context{ + return vm.BlockContext{ CanTransfer: CanTransfer, Transfer: Transfer, GetHash: GetHashFn(header, chain), - Origin: msg.From(), Coinbase: beneficiary, BlockNumber: new(big.Int).Set(header.Number), Time: new(big.Int).SetUint64(header.Time), Difficulty: new(big.Int).Set(header.Difficulty), GasLimit: header.GasLimit, - GasPrice: new(big.Int).Set(msg.GasPrice()), + } +} + +// NewEVMTxContext creates a new transaction context for a single transaction. +func NewEVMTxContext(msg Message) vm.TxContext { + return vm.TxContext{ + Origin: msg.From(), + GasPrice: new(big.Int).Set(msg.GasPrice()), } } diff --git a/core/forkid/forkid.go b/core/forkid/forkid.go index b8d670f3997b..1bf3406828f0 100644 --- a/core/forkid/forkid.go +++ b/core/forkid/forkid.go @@ -65,19 +65,8 @@ type ID struct { // Filter is a fork id filter to validate a remotely advertised ID. type Filter func(id ID) error -// NewID calculates the Ethereum fork ID from the chain config and head. -func NewID(chain Blockchain) ID { - return newID( - chain.Config(), - chain.Genesis().Hash(), - chain.CurrentHeader().Number.Uint64(), - ) -} - -// newID is the internal version of NewID, which takes extracted values as its -// arguments instead of a chain. The reason is to allow testing the IDs without -// having to simulate an entire blockchain. -func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID { +// NewID calculates the Ethereum fork ID from the chain config, genesis hash, and head. +func NewID(config *params.ChainConfig, genesis common.Hash, head uint64) ID { // Calculate the starting checksum from the genesis hash hash := crc32.ChecksumIEEE(genesis[:]) @@ -95,6 +84,15 @@ func newID(config *params.ChainConfig, genesis common.Hash, head uint64) ID { return ID{Hash: checksumToBytes(hash), Next: next} } +// NewIDWithChain calculates the Ethereum fork ID from an existing chain instance. +func NewIDWithChain(chain Blockchain) ID { + return NewID( + chain.Config(), + chain.Genesis().Hash(), + chain.CurrentHeader().Number.Uint64(), + ) +} + // NewFilter creates a filter that returns if a fork ID should be rejected or not // based on the local chain's status. func NewFilter(chain Blockchain) Filter { diff --git a/core/forkid/forkid_test.go b/core/forkid/forkid_test.go index da5653021429..888b55347524 100644 --- a/core/forkid/forkid_test.go +++ b/core/forkid/forkid_test.go @@ -118,7 +118,7 @@ func TestCreation(t *testing.T) { } for i, tt := range tests { for j, ttt := range tt.cases { - if have := newID(tt.config, tt.genesis, ttt.head); have != ttt.want { + if have := NewID(tt.config, tt.genesis, ttt.head); have != ttt.want { t.Errorf("test %d, case %d: fork ID mismatch: have %x, want %x", i, j, have, ttt.want) } } diff --git a/core/genesis.go b/core/genesis.go index 4525b9c17440..908a969afd12 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -175,7 +175,7 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig // We have the genesis block in database(perhaps in ancient database) // but the corresponding state is missing. 
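	// In that case the genesis state is re-committed below, falling back to
	// the default genesis definition when none was provided.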
header := rawdb.ReadHeader(db, stored, 0) - if _, err := state.New(header.Root, state.NewDatabaseWithCache(db, 0, ""), nil); err != nil { + if _, err := state.New(header.Root, state.NewDatabaseWithConfig(db, nil), nil); err != nil { if genesis == nil { genesis = DefaultGenesisBlock() } @@ -243,8 +243,8 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig { return params.RinkebyChainConfig case ghash == params.GoerliGenesisHash: return params.GoerliChainConfig - case ghash == params.YoloV1GenesisHash: - return params.YoloV1ChainConfig + case ghash == params.YoloV2GenesisHash: + return params.YoloV2ChainConfig default: return params.AllEthashProtocolChanges } @@ -380,10 +380,11 @@ func DefaultGoerliGenesisBlock() *Genesis { } } -func DefaultYoloV1GenesisBlock() *Genesis { +func DefaultYoloV2GenesisBlock() *Genesis { + // TODO: Update with yolov2 values + regenerate alloc data return &Genesis{ - Config: params.YoloV1ChainConfig, - Timestamp: 0x5ed754f1, + Config: params.YoloV2ChainConfig, + Timestamp: 0x5f91b932, ExtraData: hexutil.MustDecode("0x00000000000000000000000000000000000000000000000000000000000000008a37866fd3627c9205a37c8685666f32ec07bb1b0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), GasLimit: 0x47b760, Difficulty: big.NewInt(1), diff --git a/core/headerchain.go b/core/headerchain.go index f5a8e21cfc6c..dc354549c083 100644 --- a/core/headerchain.go +++ b/core/headerchain.go @@ -129,118 +129,192 @@ func (hc *HeaderChain) GetBlockNumber(hash common.Hash) *uint64 { return number } -// WriteHeader writes a header into the local chain, given that its parent is -// already known. If the total difficulty of the newly inserted header becomes -// greater than the current known TD, the canonical chain is re-routed. +type headerWriteResult struct { + status WriteStatus + ignored int + imported int + lastHash common.Hash + lastHeader *types.Header +} + +// WriteHeaders writes a chain of headers into the local chain, given that the parents +// are already known. If the total difficulty of the newly inserted chain becomes +// greater than the current known TD, the canonical chain is reorged. // // Note: This method is not concurrent-safe with inserting blocks simultaneously // into the chain, as side effects caused by reorganisations cannot be emulated // without the real blocks. Hence, writing headers directly should only be done // in two scenarios: pure-header mode of operation (light clients), or properly // separated header/block phases (non-archive clients). 
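+//
+// The returned headerWriteResult reports how many headers were imported or
+// ignored, the last header written, and whether the batch ended up canonical
+// (CanonStatTy), on a side chain (SideStatTy) or was a no-op (NonStatTy).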
-func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, err error) { - // Cache some values to prevent constant recalculation +func (hc *HeaderChain) writeHeaders(headers []*types.Header) (result *headerWriteResult, err error) { + if len(headers) == 0 { + return &headerWriteResult{}, nil + } + ptd := hc.GetTd(headers[0].ParentHash, headers[0].Number.Uint64()-1) + if ptd == nil { + return &headerWriteResult{}, consensus.ErrUnknownAncestor + } var ( - hash = header.Hash() - number = header.Number.Uint64() + lastNumber = headers[0].Number.Uint64() - 1 // Last successfully imported number + lastHash = headers[0].ParentHash // Last imported header hash + newTD = new(big.Int).Set(ptd) // Total difficulty of inserted chain + + lastHeader *types.Header + inserted []numberHash // Ephemeral lookup of number/hash for the chain + firstInserted = -1 // Index of the first non-ignored header ) - // Calculate the total difficulty of the header - ptd := hc.GetTd(header.ParentHash, number-1) - if ptd == nil { - return NonStatTy, consensus.ErrUnknownAncestor - } - head := hc.CurrentHeader().Number.Uint64() - localTd := hc.GetTd(hc.currentHeaderHash, head) - externTd := new(big.Int).Add(header.Difficulty, ptd) - - // Irrelevant of the canonical status, write the td and header to the database - // - // Note all the components of header(td, hash->number index and header) should - // be written atomically. - headerBatch := hc.chainDb.NewBatch() - rawdb.WriteTd(headerBatch, hash, number, externTd) - rawdb.WriteHeader(headerBatch, header) - if err := headerBatch.Write(); err != nil { - log.Crit("Failed to write header into disk", "err", err) + + batch := hc.chainDb.NewBatch() + for i, header := range headers { + var hash common.Hash + // The headers have already been validated at this point, so we already + // know that it's a contiguous chain, where + // headers[i].Hash() == headers[i+1].ParentHash + if i < len(headers)-1 { + hash = headers[i+1].ParentHash + } else { + hash = header.Hash() + } + number := header.Number.Uint64() + newTD.Add(newTD, header.Difficulty) + + // If the header is already known, skip it, otherwise store + if !hc.HasHeader(hash, number) { + // Irrelevant of the canonical status, write the TD and header to the database. + rawdb.WriteTd(batch, hash, number, newTD) + hc.tdCache.Add(hash, new(big.Int).Set(newTD)) + + rawdb.WriteHeader(batch, header) + inserted = append(inserted, numberHash{number, hash}) + hc.headerCache.Add(hash, header) + hc.numberCache.Add(hash, number) + if firstInserted < 0 { + firstInserted = i + } + } + lastHeader, lastHash, lastNumber = header, hash, number + } + + // Skip the slow disk write of all headers if interrupted. + if hc.procInterrupt() { + log.Debug("Premature abort during headers import") + return &headerWriteResult{}, errors.New("aborted") } + // Commit to disk! + if err := batch.Write(); err != nil { + log.Crit("Failed to write headers", "error", err) + } + batch.Reset() + + var ( + head = hc.CurrentHeader().Number.Uint64() + localTD = hc.GetTd(hc.currentHeaderHash, head) + status = SideStatTy + ) // If the total difficulty is higher than our known, add it to the canonical chain // Second clause in the if statement reduces the vulnerability to selfish mining. 
	// Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf
-	reorg := externTd.Cmp(localTd) > 0
-	if !reorg && externTd.Cmp(localTd) == 0 {
-		if header.Number.Uint64() < head {
+	reorg := newTD.Cmp(localTD) > 0
+	if !reorg && newTD.Cmp(localTD) == 0 {
+		if lastNumber < head {
 			reorg = true
-		} else if header.Number.Uint64() == head {
+		} else if lastNumber == head {
 			reorg = mrand.Float64() < 0.5
 		}
 	}
+	// If the parent of the (first) block is already the canon header,
+	// we don't have to go backwards to delete canon blocks, but
+	// simply pile them onto the existing chain
+	chainAlreadyCanon := headers[0].ParentHash == hc.currentHeaderHash
 	if reorg {
 		// If the header can be added into canonical chain, adjust the
 		// header chain markers(canonical indexes and head header flag).
 		//
 		// Note all markers should be written atomically.
-
-		// Delete any canonical number assignments above the new head
-		markerBatch := hc.chainDb.NewBatch()
-		for i := number + 1; ; i++ {
-			hash := rawdb.ReadCanonicalHash(hc.chainDb, i)
-			if hash == (common.Hash{}) {
-				break
+		markerBatch := batch // we can reuse the batch to keep allocs down
+		if !chainAlreadyCanon {
+			// Delete any canonical number assignments above the new head
+			for i := lastNumber + 1; ; i++ {
+				hash := rawdb.ReadCanonicalHash(hc.chainDb, i)
+				if hash == (common.Hash{}) {
+					break
+				}
+				rawdb.DeleteCanonicalHash(markerBatch, i)
+			}
+			// Overwrite any stale canonical number assignments, going
+			// backwards from the first header in this import
+			var (
+				headHash   = headers[0].ParentHash          // inserted[0].parent?
+				headNumber = headers[0].Number.Uint64() - 1 // inserted[0].num-1 ?
+				headHeader = hc.GetHeader(headHash, headNumber)
+			)
+			for rawdb.ReadCanonicalHash(hc.chainDb, headNumber) != headHash {
+				rawdb.WriteCanonicalHash(markerBatch, headHash, headNumber)
+				headHash = headHeader.ParentHash
+				headNumber = headHeader.Number.Uint64() - 1
+				headHeader = hc.GetHeader(headHash, headNumber)
+			}
+			// If some of the older headers were already known, but obtained canon-status
+			// during this import batch, then we need to write that now
+			// Further down, we continue writing the status for the ones that
+			// were not already known
+			for i := 0; i < firstInserted; i++ {
+				hash := headers[i].Hash()
+				num := headers[i].Number.Uint64()
+				rawdb.WriteCanonicalHash(markerBatch, hash, num)
+				rawdb.WriteHeadHeaderHash(markerBatch, hash)
 			}
-			rawdb.DeleteCanonicalHash(markerBatch, i)
 		}
-
-		// Overwrite any stale canonical number assignments
-		var (
-			headHash   = header.ParentHash
-			headNumber = header.Number.Uint64() - 1
-			headHeader = hc.GetHeader(headHash, headNumber)
-		)
-		for rawdb.ReadCanonicalHash(hc.chainDb, headNumber) != headHash {
-			rawdb.WriteCanonicalHash(markerBatch, headHash, headNumber)
-
-			headHash = headHeader.ParentHash
-			headNumber = headHeader.Number.Uint64() - 1
-			headHeader = hc.GetHeader(headHash, headNumber)
+		// Extend the canonical chain with the new headers
+		for _, hn := range inserted {
+			rawdb.WriteCanonicalHash(markerBatch, hn.hash, hn.number)
+			rawdb.WriteHeadHeaderHash(markerBatch, hn.hash)
 		}
-		// Extend the canonical chain with the new header
-		rawdb.WriteCanonicalHash(markerBatch, hash, number)
-		rawdb.WriteHeadHeaderHash(markerBatch, hash)
 		if err := markerBatch.Write(); err != nil {
 			log.Crit("Failed to write header markers into disk", "err", err)
 		}
+		markerBatch.Reset()
 		// Last step update all in-memory head header markers
-		hc.currentHeaderHash = hash
-		hc.currentHeader.Store(types.CopyHeader(header))
-		headHeaderGauge.Update(header.Number.Int64())
+	hc.currentHeaderHash = lastHash
+	hc.currentHeader.Store(types.CopyHeader(lastHeader))
+	headHeaderGauge.Update(lastHeader.Number.Int64())
+
+	// Chain status is canonical since this insert was a reorg.
+	// Note that all inserts which have higher TD than existing are 'reorg'.
 		status = CanonStatTy
-	} else {
-		status = SideStatTy
 	}
-	hc.tdCache.Add(hash, externTd)
-	hc.headerCache.Add(hash, header)
-	hc.numberCache.Add(hash, number)
-	return
-}
-
-// WhCallback is a callback function for inserting individual headers.
-// A callback is used for two reasons: first, in a LightChain, status should be
-// processed and light chain events sent, while in a BlockChain this is not
-// necessary since chain events are sent after inserting blocks. Second, the
-// header writes should be protected by the parent chain mutex individually.
-type WhCallback func(*types.Header) error
+	if len(inserted) == 0 {
+		status = NonStatTy
+	}
+	return &headerWriteResult{
+		status:     status,
+		ignored:    len(headers) - len(inserted),
+		imported:   len(inserted),
+		lastHash:   lastHash,
+		lastHeader: lastHeader,
+	}, nil
+}
 
 func (hc *HeaderChain) ValidateHeaderChain(chain []*types.Header, checkFreq int) (int, error) {
 	// Do a sanity check that the provided chain is actually ordered and linked
 	for i := 1; i < len(chain); i++ {
-		if chain[i].Number.Uint64() != chain[i-1].Number.Uint64()+1 || chain[i].ParentHash != chain[i-1].Hash() {
+		parentHash := chain[i-1].Hash()
+		if chain[i].Number.Uint64() != chain[i-1].Number.Uint64()+1 || chain[i].ParentHash != parentHash {
 			// Chain broke ancestry, log a message (programming error) and skip insertion
 			log.Error("Non contiguous header insert", "number", chain[i].Number, "hash", chain[i].Hash(),
-				"parent", chain[i].ParentHash, "prevnumber", chain[i-1].Number, "prevhash", chain[i-1].Hash())
+				"parent", chain[i].ParentHash, "prevnumber", chain[i-1].Number, "prevhash", parentHash)
 
 			return 0, fmt.Errorf("non contiguous insert: item %d is #%d [%x…], item %d is #%d [%x…] (parent [%x…])", i-1, chain[i-1].Number,
-				chain[i-1].Hash().Bytes()[:4], i, chain[i].Number, chain[i].Hash().Bytes()[:4], chain[i].ParentHash[:4])
+				parentHash.Bytes()[:4], i, chain[i].Number, chain[i].Hash().Bytes()[:4], chain[i].ParentHash[:4])
+		}
+		// If the header is a banned one, straight out abort
+		if BadHashes[parentHash] {
+			return i - 1, ErrBlacklistedHash
+		}
+		// If it's the last header in the chunk, we need to check it too
+		if i == len(chain)-1 && BadHashes[chain[i].Hash()] {
+			return i, ErrBlacklistedHash
 		}
 	}
@@ -263,16 +337,12 @@ func (hc *HeaderChain) ValidateHeaderChain(chain []*types.Header, checkFreq int)
 	defer close(abort)
 
 	// Iterate over the headers and ensure they all check out
-	for i, header := range chain {
+	for i := range chain {
 		// If the chain is terminating, stop processing blocks
 		if hc.procInterrupt() {
 			log.Debug("Premature abort during headers verification")
 			return 0, errors.New("aborted")
 		}
-		// If the header is a banned one, straight out abort
-		if BadHashes[header.Hash()] {
-			return i, ErrBlacklistedHash
-		}
 		// Otherwise wait for headers checks and ensure they pass
 		if err := <-results; err != nil {
 			return i, err
@@ -282,55 +352,41 @@ func (hc *HeaderChain) ValidateHeaderChain(chain []*types.Header, checkFreq int)
 	return 0, nil
 }
 
-// InsertHeaderChain attempts to insert the given header chain in to the local
-// chain, possibly creating a reorg.
If an error is returned, it will return the -// index number of the failing header as well an error describing what went wrong. +// InsertHeaderChain inserts the given headers. // -// The verify parameter can be used to fine tune whether nonce verification -// should be done or not. The reason behind the optional check is because some -// of the header retrieval mechanisms already need to verfy nonces, as well as -// because nonces can be verified sparsely, not needing to check each. -func (hc *HeaderChain) InsertHeaderChain(chain []*types.Header, writeHeader WhCallback, start time.Time) (int, error) { - // Collect some import statistics to report on - stats := struct{ processed, ignored int }{} - // All headers passed verification, import them into the database - for i, header := range chain { - // Short circuit insertion if shutting down - if hc.procInterrupt() { - log.Debug("Premature abort during headers import") - return i, errors.New("aborted") - } - // If the header's already known, skip it, otherwise store - hash := header.Hash() - if hc.HasHeader(hash, header.Number.Uint64()) { - externTd := hc.GetTd(hash, header.Number.Uint64()) - localTd := hc.GetTd(hc.currentHeaderHash, hc.CurrentHeader().Number.Uint64()) - if externTd == nil || externTd.Cmp(localTd) <= 0 { - stats.ignored++ - continue - } - } - if err := writeHeader(header); err != nil { - return i, err - } - stats.processed++ +// The validity of the headers is NOT CHECKED by this method, i.e. they need to be +// validated by ValidateHeaderChain before calling InsertHeaderChain. +// +// This insert is all-or-nothing. If this returns an error, no headers were written, +// otherwise they were all processed successfully. +// +// The returned 'write status' says if the inserted headers are part of the canonical chain +// or a side chain. +func (hc *HeaderChain) InsertHeaderChain(chain []*types.Header, start time.Time) (WriteStatus, error) { + if hc.procInterrupt() { + return 0, errors.New("aborted") } - // Report some public statistics so the user has a clue what's going on - last := chain[len(chain)-1] + res, err := hc.writeHeaders(chain) + // Report some public statistics so the user has a clue what's going on context := []interface{}{ - "count", stats.processed, "elapsed", common.PrettyDuration(time.Since(start)), - "number", last.Number, "hash", last.Hash(), + "count", res.imported, + "elapsed", common.PrettyDuration(time.Since(start)), } - if timestamp := time.Unix(int64(last.Time), 0); time.Since(timestamp) > time.Minute { - context = append(context, []interface{}{"age", common.PrettyAge(timestamp)}...) + if err != nil { + context = append(context, "err", err) } - if stats.ignored > 0 { - context = append(context, []interface{}{"ignored", stats.ignored}...) + if last := res.lastHeader; last != nil { + context = append(context, "number", last.Number, "hash", res.lastHash) + if timestamp := time.Unix(int64(last.Time), 0); time.Since(timestamp) > time.Minute { + context = append(context, []interface{}{"age", common.PrettyAge(timestamp)}...) + } + } + if res.ignored > 0 { + context = append(context, []interface{}{"ignored", res.ignored}...) } log.Info("Imported new block headers", context...) 
- - return 0, nil + return res.status, err } // GetBlockHashesFromHash retrieves a number of block hashes starting at a given @@ -369,9 +425,8 @@ func (hc *HeaderChain) GetAncestor(hash common.Hash, number, ancestor uint64, ma // in this case it is cheaper to just read the header if header := hc.GetHeader(hash, number); header != nil { return header.ParentHash, number - 1 - } else { - return common.Hash{}, 0 } + return common.Hash{}, 0 } for ancestor != 0 { if rawdb.ReadCanonicalHash(hc.chainDb, number) == hash { diff --git a/core/headerchain_test.go b/core/headerchain_test.go new file mode 100644 index 000000000000..0aa25efd1f23 --- /dev/null +++ b/core/headerchain_test.go @@ -0,0 +1,115 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package core + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/ethereum/go-ethereum/consensus" + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/params" +) + +func verifyUnbrokenCanonchain(hc *HeaderChain) error { + h := hc.CurrentHeader() + for { + canonHash := rawdb.ReadCanonicalHash(hc.chainDb, h.Number.Uint64()) + if exp := h.Hash(); canonHash != exp { + return fmt.Errorf("Canon hash chain broken, block %d got %x, expected %x", + h.Number, canonHash[:8], exp[:8]) + } + // Verify that we have the TD + if td := rawdb.ReadTd(hc.chainDb, canonHash, h.Number.Uint64()); td == nil { + return fmt.Errorf("Canon TD missing at block %d", h.Number) + } + if h.Number.Uint64() == 0 { + break + } + h = hc.GetHeader(h.ParentHash, h.Number.Uint64()-1) + } + return nil +} + +func testInsert(t *testing.T, hc *HeaderChain, chain []*types.Header, wantStatus WriteStatus, wantErr error) { + t.Helper() + + status, err := hc.InsertHeaderChain(chain, time.Now()) + if status != wantStatus { + t.Errorf("wrong write status from InsertHeaderChain: got %v, want %v", status, wantStatus) + } + // Always verify that the header chain is unbroken + if err := verifyUnbrokenCanonchain(hc); err != nil { + t.Fatal(err) + } + if !errors.Is(err, wantErr) { + t.Fatalf("unexpected error from InsertHeaderChain: %v", err) + } +} + +// This test checks status reporting of InsertHeaderChain. 
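+// It inserts two competing header chains sharing a common prefix and checks
+// that each insert reports CanonStatTy, SideStatTy or NonStatTy as the two
+// chains alternately overtake one another.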
+func TestHeaderInsertion(t *testing.T) { + var ( + db = rawdb.NewMemoryDatabase() + genesis = new(Genesis).MustCommit(db) + ) + + hc, err := NewHeaderChain(db, params.AllEthashProtocolChanges, ethash.NewFaker(), func() bool { return false }) + if err != nil { + t.Fatal(err) + } + // chain A: G->A1->A2...A128 + chainA := makeHeaderChain(genesis.Header(), 128, ethash.NewFaker(), db, 10) + // chain B: G->A1->B2...B128 + chainB := makeHeaderChain(chainA[0], 128, ethash.NewFaker(), db, 10) + log.Root().SetHandler(log.StdoutHandler) + + // Inserting 64 headers on an empty chain, expecting + // 1 callbacks, 1 canon-status, 0 sidestatus, + testInsert(t, hc, chainA[:64], CanonStatTy, nil) + + // Inserting 64 identical headers, expecting + // 0 callbacks, 0 canon-status, 0 sidestatus, + testInsert(t, hc, chainA[:64], NonStatTy, nil) + + // Inserting the same some old, some new headers + // 1 callbacks, 1 canon, 0 side + testInsert(t, hc, chainA[32:96], CanonStatTy, nil) + + // Inserting side blocks, but not overtaking the canon chain + testInsert(t, hc, chainB[0:32], SideStatTy, nil) + + // Inserting more side blocks, but we don't have the parent + testInsert(t, hc, chainB[34:36], NonStatTy, consensus.ErrUnknownAncestor) + + // Inserting more sideblocks, overtaking the canon chain + testInsert(t, hc, chainB[32:97], CanonStatTy, nil) + + // Inserting more A-headers, taking back the canonicality + testInsert(t, hc, chainA[90:100], CanonStatTy, nil) + + // And B becomes canon again + testInsert(t, hc, chainB[97:107], CanonStatTy, nil) + + // And B becomes even longer + testInsert(t, hc, chainB[107:128], CanonStatTy, nil) +} diff --git a/core/rawdb/accessors_metadata.go b/core/rawdb/accessors_metadata.go index 14a302a1270f..079e335fa6fb 100644 --- a/core/rawdb/accessors_metadata.go +++ b/core/rawdb/accessors_metadata.go @@ -18,6 +18,7 @@ package rawdb import ( "encoding/json" + "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" @@ -30,7 +31,7 @@ import ( func ReadDatabaseVersion(db ethdb.KeyValueReader) *uint64 { var version uint64 - enc, _ := db.Get(databaseVerisionKey) + enc, _ := db.Get(databaseVersionKey) if len(enc) == 0 { return nil } @@ -47,7 +48,7 @@ func WriteDatabaseVersion(db ethdb.KeyValueWriter, version uint64) { if err != nil { log.Crit("Failed to encode database version", "err", err) } - if err = db.Put(databaseVerisionKey, enc); err != nil { + if err = db.Put(databaseVersionKey, enc); err != nil { log.Crit("Failed to store the database version", "err", err) } } @@ -79,3 +80,61 @@ func WriteChainConfig(db ethdb.KeyValueWriter, hash common.Hash, cfg *params.Cha log.Crit("Failed to store chain config", "err", err) } } + +// crashList is a list of unclean-shutdown-markers, for rlp-encoding to the +// database +type crashList struct { + Discarded uint64 // how many ucs have we deleted + Recent []uint64 // unix timestamps of 10 latest unclean shutdowns +} + +const crashesToKeep = 10 + +// PushUncleanShutdownMarker appends a new unclean shutdown marker and returns +// the previous data +// - a list of timestamps +// - a count of how many old unclean-shutdowns have been discarded +func PushUncleanShutdownMarker(db ethdb.KeyValueStore) ([]uint64, uint64, error) { + var uncleanShutdowns crashList + // Read old data + if data, err := db.Get(uncleanShutdownKey); err != nil { + log.Warn("Error reading unclean shutdown markers", "error", err) + } else if err := rlp.DecodeBytes(data, &uncleanShutdowns); err != nil { + return nil, 0, err + } + var discarded = 
uncleanShutdowns.Discarded
+	var previous = make([]uint64, len(uncleanShutdowns.Recent))
+	copy(previous, uncleanShutdowns.Recent)
+	// Add a new marker (but cap the list)
+	uncleanShutdowns.Recent = append(uncleanShutdowns.Recent, uint64(time.Now().Unix()))
+	if count := len(uncleanShutdowns.Recent); count > crashesToKeep+1 {
+		numDel := count - (crashesToKeep + 1)
+		uncleanShutdowns.Recent = uncleanShutdowns.Recent[numDel:]
+		uncleanShutdowns.Discarded += uint64(numDel)
+	}
+	// And save it again
+	data, _ := rlp.EncodeToBytes(uncleanShutdowns)
+	if err := db.Put(uncleanShutdownKey, data); err != nil {
+		log.Warn("Failed to write unclean-shutdown marker", "err", err)
+		return nil, 0, err
+	}
+	return previous, discarded, nil
+}
+
+// PopUncleanShutdownMarker removes the last unclean shutdown marker
+func PopUncleanShutdownMarker(db ethdb.KeyValueStore) {
+	var uncleanShutdowns crashList
+	// Read old data
+	if data, err := db.Get(uncleanShutdownKey); err != nil {
+		log.Warn("Error reading unclean shutdown markers", "error", err)
+	} else if err := rlp.DecodeBytes(data, &uncleanShutdowns); err != nil {
+		log.Error("Error decoding unclean shutdown markers", "error", err) // Should definitely _not_ happen
+	}
+	if l := len(uncleanShutdowns.Recent); l > 0 {
+		uncleanShutdowns.Recent = uncleanShutdowns.Recent[:l-1]
+	}
+	data, _ := rlp.EncodeToBytes(uncleanShutdowns)
+	if err := db.Put(uncleanShutdownKey, data); err != nil {
+		log.Warn("Failed to clear unclean-shutdown marker", "err", err)
+	}
+}
diff --git a/core/rawdb/accessors_snapshot.go b/core/rawdb/accessors_snapshot.go
index ecd4e65978ee..0a91d9353b22 100644
--- a/core/rawdb/accessors_snapshot.go
+++ b/core/rawdb/accessors_snapshot.go
@@ -17,6 +17,8 @@
 package rawdb
 
 import (
+	"encoding/binary"
+
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/log"
@@ -118,3 +120,79 @@ func DeleteSnapshotJournal(db ethdb.KeyValueWriter) {
 		log.Crit("Failed to remove snapshot journal", "err", err)
 	}
 }
+
+// ReadSnapshotGenerator retrieves the serialized snapshot generator saved at
+// the last shutdown.
+func ReadSnapshotGenerator(db ethdb.KeyValueReader) []byte {
+	data, _ := db.Get(snapshotGeneratorKey)
+	return data
+}
+
+// WriteSnapshotGenerator stores the serialized snapshot generator to save at
+// shutdown.
+func WriteSnapshotGenerator(db ethdb.KeyValueWriter, generator []byte) {
+	if err := db.Put(snapshotGeneratorKey, generator); err != nil {
+		log.Crit("Failed to store snapshot generator", "err", err)
+	}
+}
+
+// DeleteSnapshotGenerator deletes the serialized snapshot generator saved at
+// the last shutdown.
+func DeleteSnapshotGenerator(db ethdb.KeyValueWriter) {
+	if err := db.Delete(snapshotGeneratorKey); err != nil {
+		log.Crit("Failed to remove snapshot generator", "err", err)
+	}
+}
+
+// ReadSnapshotRecoveryNumber retrieves the block number of the last persisted
+// snapshot layer.
+func ReadSnapshotRecoveryNumber(db ethdb.KeyValueReader) *uint64 {
+	data, _ := db.Get(snapshotRecoveryKey)
+	if len(data) == 0 {
+		return nil
+	}
+	if len(data) != 8 {
+		return nil
+	}
+	number := binary.BigEndian.Uint64(data)
+	return &number
+}
+
+// WriteSnapshotRecoveryNumber stores the block number of the last persisted
+// snapshot layer.
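+// The number is stored as an 8-byte big-endian value, the encoding that
+// ReadSnapshotRecoveryNumber expects when reading it back.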
+func WriteSnapshotRecoveryNumber(db ethdb.KeyValueWriter, number uint64) { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], number) + if err := db.Put(snapshotRecoveryKey, buf[:]); err != nil { + log.Crit("Failed to store snapshot recovery number", "err", err) + } +} + +// DeleteSnapshotRecoveryNumber deletes the block number of the last persisted +// snapshot layer. +func DeleteSnapshotRecoveryNumber(db ethdb.KeyValueWriter) { + if err := db.Delete(snapshotRecoveryKey); err != nil { + log.Crit("Failed to remove snapshot recovery number", "err", err) + } +} + +// ReadSanpshotSyncStatus retrieves the serialized sync status saved at shutdown. +func ReadSanpshotSyncStatus(db ethdb.KeyValueReader) []byte { + data, _ := db.Get(snapshotSyncStatusKey) + return data +} + +// WriteSnapshotSyncStatus stores the serialized sync status to save at shutdown. +func WriteSnapshotSyncStatus(db ethdb.KeyValueWriter, status []byte) { + if err := db.Put(snapshotSyncStatusKey, status); err != nil { + log.Crit("Failed to store snapshot sync status", "err", err) + } +} + +// DeleteSnapshotSyncStatus deletes the serialized sync status saved at the last +// shutdown +func DeleteSnapshotSyncStatus(db ethdb.KeyValueWriter) { + if err := db.Delete(snapshotSyncStatusKey); err != nil { + log.Crit("Failed to remove snapshot sync status", "err", err) + } +} diff --git a/core/rawdb/chain_iterator.go b/core/rawdb/chain_iterator.go index 3130e922e85a..393b72c26c16 100644 --- a/core/rawdb/chain_iterator.go +++ b/core/rawdb/chain_iterator.go @@ -84,15 +84,17 @@ type blockTxHashes struct { } // iterateTransactions iterates over all transactions in the (canon) block -// number(s) given, and yields the hashes on a channel -func iterateTransactions(db ethdb.Database, from uint64, to uint64, reverse bool) (chan *blockTxHashes, chan struct{}) { +// number(s) given, and yields the hashes on a channel. If there is a signal +// received from interrupt channel, the iteration will be aborted and result +// channel will be closed. +func iterateTransactions(db ethdb.Database, from uint64, to uint64, reverse bool, interrupt chan struct{}) chan *blockTxHashes { // One thread sequentially reads data from db type numberRlp struct { number uint64 rlp rlp.RawValue } if to == from { - return nil, nil + return nil } threads := to - from if cpus := runtime.NumCPU(); threads > uint64(cpus) { @@ -101,7 +103,6 @@ func iterateTransactions(db ethdb.Database, from uint64, to uint64, reverse bool var ( rlpCh = make(chan *numberRlp, threads*2) // we send raw rlp over this channel hashesCh = make(chan *blockTxHashes, threads*2) // send hashes over hashesCh - abortCh = make(chan struct{}) ) // lookup runs in one instance lookup := func() { @@ -115,7 +116,7 @@ func iterateTransactions(db ethdb.Database, from uint64, to uint64, reverse bool // Feed the block to the aggregator, or abort on interrupt select { case rlpCh <- &numberRlp{n, data}: - case <-abortCh: + case <-interrupt: return } if reverse { @@ -168,7 +169,7 @@ func iterateTransactions(db ethdb.Database, from uint64, to uint64, reverse bool // Feed the block to the aggregator, or abort on interrupt select { case hashesCh <- result: - case <-abortCh: + case <-interrupt: return } } @@ -177,25 +178,28 @@ func iterateTransactions(db ethdb.Database, from uint64, to uint64, reverse bool for i := 0; i < int(threads); i++ { go process() } - return hashesCh, abortCh + return hashesCh } -// IndexTransactions creates txlookup indices of the specified block range. 
+// indexTransactions creates txlookup indices of the specified block range. // // This function iterates canonical chain in reverse order, it has one main advantage: // We can write tx index tail flag periodically even without the whole indexing // procedure is finished. So that we can resume indexing procedure next time quickly. -func IndexTransactions(db ethdb.Database, from uint64, to uint64) { +// +// There is a passed channel, the whole procedure will be interrupted if any +// signal received. +func indexTransactions(db ethdb.Database, from uint64, to uint64, interrupt chan struct{}, hook func(uint64) bool) { // short circuit for invalid range if from >= to { return } var ( - hashesCh, abortCh = iterateTransactions(db, from, to, true) - batch = db.NewBatch() - start = time.Now() - logged = start.Add(-7 * time.Second) - // Since we iterate in reverse, we expect the first number to come + hashesCh = iterateTransactions(db, from, to, true, interrupt) + batch = db.NewBatch() + start = time.Now() + logged = start.Add(-7 * time.Second) + // Since we iterate in reverse, we expect the first number to come // in to be [to-1]. Therefore, setting lastNum to means that the // prqueue gap-evaluation will work correctly lastNum = to @@ -203,8 +207,6 @@ func IndexTransactions(db ethdb.Database, from uint64, to uint64) { // for stats reporting blocks, txs = 0, 0 ) - defer close(abortCh) - for chanDelivery := range hashesCh { // Push the delivery into the queue and process contiguous ranges. // Since we iterate in reverse, so lower numbers have lower prio, and @@ -215,6 +217,10 @@ func IndexTransactions(db ethdb.Database, from uint64, to uint64) { if _, priority := queue.Peek(); priority != int64(lastNum-1) { break } + // For testing + if hook != nil && !hook(lastNum-1) { + break + } // Next block available, pop it off and index it delivery := queue.PopItem().(*blockTxHashes) lastNum = delivery.number @@ -223,8 +229,7 @@ func IndexTransactions(db ethdb.Database, from uint64, to uint64) { txs += len(delivery.hashes) // If enough data was accumulated in memory or we're at the last block, dump to disk if batch.ValueSize() > ethdb.IdealBatchSize { - // Also write the tail there - WriteTxIndexTail(batch, lastNum) + WriteTxIndexTail(batch, lastNum) // Also write the tail here if err := batch.Write(); err != nil { log.Crit("Failed writing batch to db", "error", err) return @@ -238,67 +243,122 @@ func IndexTransactions(db ethdb.Database, from uint64, to uint64) { } } } - if lastNum < to { - WriteTxIndexTail(batch, lastNum) - // No need to write the batch if we never entered the loop above... + // If there exists uncommitted data, flush them. + if batch.ValueSize() > 0 { + WriteTxIndexTail(batch, lastNum) // Also write the tail there if err := batch.Write(); err != nil { log.Crit("Failed writing batch to db", "error", err) return } } - log.Info("Indexed transactions", "blocks", blocks, "txs", txs, "tail", lastNum, "elapsed", common.PrettyDuration(time.Since(start))) + select { + case <-interrupt: + log.Debug("Transaction indexing interrupted", "blocks", blocks, "txs", txs, "tail", lastNum, "elapsed", common.PrettyDuration(time.Since(start))) + default: + log.Info("Indexed transactions", "blocks", blocks, "txs", txs, "tail", lastNum, "elapsed", common.PrettyDuration(time.Since(start))) + } } -// UnindexTransactions removes txlookup indices of the specified block range. 
-func UnindexTransactions(db ethdb.Database, from uint64, to uint64) { +// IndexTransactions creates txlookup indices of the specified block range. +// +// This function iterates canonical chain in reverse order, it has one main advantage: +// We can write tx index tail flag periodically even without the whole indexing +// procedure is finished. So that we can resume indexing procedure next time quickly. +// +// There is a passed channel, the whole procedure will be interrupted if any +// signal received. +func IndexTransactions(db ethdb.Database, from uint64, to uint64, interrupt chan struct{}) { + indexTransactions(db, from, to, interrupt, nil) +} + +// indexTransactionsForTesting is the internal debug version with an additional hook. +func indexTransactionsForTesting(db ethdb.Database, from uint64, to uint64, interrupt chan struct{}, hook func(uint64) bool) { + indexTransactions(db, from, to, interrupt, hook) +} + +// unindexTransactions removes txlookup indices of the specified block range. +// +// There is a passed channel, the whole procedure will be interrupted if any +// signal received. +func unindexTransactions(db ethdb.Database, from uint64, to uint64, interrupt chan struct{}, hook func(uint64) bool) { // short circuit for invalid range if from >= to { return } - // Write flag first and then unindex the transaction indices. Some indices - // will be left in the database if crash happens but it's fine. - WriteTxIndexTail(db, to) - // If only one block is unindexed, do it directly - //if from+1 == to { - // data := ReadCanonicalBodyRLP(db, uint64(from)) - // DeleteTxLookupEntries(db, ReadBlock(db, ReadCanonicalHash(db, from), from)) - // log.Info("Unindexed transactions", "blocks", 1, "tail", to) - // return - //} - // TODO @holiman, add this back (if we want it) var ( - hashesCh, abortCh = iterateTransactions(db, from, to, false) - batch = db.NewBatch() - start = time.Now() - logged = start.Add(-7 * time.Second) + hashesCh = iterateTransactions(db, from, to, false, interrupt) + batch = db.NewBatch() + start = time.Now() + logged = start.Add(-7 * time.Second) + // we expect the first number to come in to be [from]. Therefore, setting + // nextNum to from means that the prqueue gap-evaluation will work correctly + nextNum = from + queue = prque.New(nil) + // for stats reporting + blocks, txs = 0, 0 ) - defer close(abortCh) // Otherwise spin up the concurrent iterator and unindexer - blocks, txs := 0, 0 for delivery := range hashesCh { - DeleteTxLookupEntries(batch, delivery.hashes) - txs += len(delivery.hashes) - blocks++ + // Push the delivery into the queue and process contiguous ranges. + queue.Push(delivery, -int64(delivery.number)) + for !queue.Empty() { + // If the next available item is gapped, return + if _, priority := queue.Peek(); -priority != int64(nextNum) { + break + } + // For testing + if hook != nil && !hook(nextNum) { + break + } + delivery := queue.PopItem().(*blockTxHashes) + nextNum = delivery.number + 1 + DeleteTxLookupEntries(batch, delivery.hashes) + txs += len(delivery.hashes) + blocks++ - // If enough data was accumulated in memory or we're at the last block, dump to disk - // A batch counts the size of deletion as '1', so we need to flush more - // often than that. 
- if blocks%1000 == 0 { - if err := batch.Write(); err != nil { - log.Crit("Failed writing batch to db", "error", err) - return + // If enough data was accumulated in memory or we're at the last block, dump to disk + // A batch counts the size of deletion as '1', so we need to flush more + // often than that. + if blocks%1000 == 0 { + WriteTxIndexTail(batch, nextNum) + if err := batch.Write(); err != nil { + log.Crit("Failed writing batch to db", "error", err) + return + } + batch.Reset() + } + // If we've spent too much time already, notify the user of what we're doing + if time.Since(logged) > 8*time.Second { + log.Info("Unindexing transactions", "blocks", blocks, "txs", txs, "total", to-from, "elapsed", common.PrettyDuration(time.Since(start))) + logged = time.Now() } - batch.Reset() } - // If we've spent too much time already, notify the user of what we're doing - if time.Since(logged) > 8*time.Second { - log.Info("Unindexing transactions", "blocks", blocks, "txs", txs, "total", to-from, "elapsed", common.PrettyDuration(time.Since(start))) - logged = time.Now() + } + // Commit the last batch if there exists uncommitted data + if batch.ValueSize() > 0 { + WriteTxIndexTail(batch, nextNum) + if err := batch.Write(); err != nil { + log.Crit("Failed writing batch to db", "error", err) + return } } - if err := batch.Write(); err != nil { - log.Crit("Failed writing batch to db", "error", err) - return + select { + case <-interrupt: + log.Debug("Transaction unindexing interrupted", "blocks", blocks, "txs", txs, "tail", to, "elapsed", common.PrettyDuration(time.Since(start))) + default: + log.Info("Unindexed transactions", "blocks", blocks, "txs", txs, "tail", to, "elapsed", common.PrettyDuration(time.Since(start))) } - log.Info("Unindexed transactions", "blocks", blocks, "txs", txs, "tail", to, "elapsed", common.PrettyDuration(time.Since(start))) +} + +// UnindexTransactions removes txlookup indices of the specified block range. +// +// There is a passed channel, the whole procedure will be interrupted if any +// signal received. +func UnindexTransactions(db ethdb.Database, from uint64, to uint64, interrupt chan struct{}) { + unindexTransactions(db, from, to, interrupt, nil) +} + +// unindexTransactionsForTesting is the internal debug version with an additional hook. 
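+// The hook is invoked with the next block number before that block is
+// processed; returning false aborts the run, letting tests interrupt the
+// procedure at an exact position.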
+func unindexTransactionsForTesting(db ethdb.Database, from uint64, to uint64, interrupt chan struct{}, hook func(uint64) bool) { + unindexTransactions(db, from, to, interrupt, hook) } diff --git a/core/rawdb/chain_iterator_test.go b/core/rawdb/chain_iterator_test.go index c635cd2f1273..90b2639d38cf 100644 --- a/core/rawdb/chain_iterator_test.go +++ b/core/rawdb/chain_iterator_test.go @@ -20,6 +20,7 @@ import ( "math/big" "reflect" "sort" + "sync" "testing" "github.com/ethereum/go-ethereum/common" @@ -59,7 +60,7 @@ func TestChainIterator(t *testing.T) { } for i, c := range cases { var numbers []int - hashCh, _ := iterateTransactions(chainDb, c.from, c.to, c.reverse) + hashCh := iterateTransactions(chainDb, c.from, c.to, c.reverse, nil) if hashCh != nil { for h := range hashCh { numbers = append(numbers, int(h.number)) @@ -80,3 +81,85 @@ func TestChainIterator(t *testing.T) { } } } + +func TestIndexTransactions(t *testing.T) { + // Construct test chain db + chainDb := NewMemoryDatabase() + + var block *types.Block + var txs []*types.Transaction + for i := uint64(0); i <= 10; i++ { + if i == 0 { + block = types.NewBlock(&types.Header{Number: big.NewInt(int64(i))}, nil, nil, nil, newHasher()) // Empty genesis block + } else { + tx := types.NewTransaction(i, common.BytesToAddress([]byte{0x11}), big.NewInt(111), 1111, big.NewInt(11111), []byte{0x11, 0x11, 0x11}) + txs = append(txs, tx) + block = types.NewBlock(&types.Header{Number: big.NewInt(int64(i))}, []*types.Transaction{tx}, nil, nil, newHasher()) + } + WriteBlock(chainDb, block) + WriteCanonicalHash(chainDb, block.Hash(), block.NumberU64()) + } + // verify checks whether the tx indices in the range [from, to) + // is expected. + verify := func(from, to int, exist bool, tail uint64) { + for i := from; i < to; i++ { + if i == 0 { + continue + } + number := ReadTxLookupEntry(chainDb, txs[i-1].Hash()) + if exist && number == nil { + t.Fatalf("Transaction indice missing") + } + if !exist && number != nil { + t.Fatalf("Transaction indice is not deleted") + } + } + number := ReadTxIndexTail(chainDb) + if number == nil || *number != tail { + t.Fatalf("Transaction tail mismatch") + } + } + IndexTransactions(chainDb, 5, 11, nil) + verify(5, 11, true, 5) + verify(0, 5, false, 5) + + IndexTransactions(chainDb, 0, 5, nil) + verify(0, 11, true, 0) + + UnindexTransactions(chainDb, 0, 5, nil) + verify(5, 11, true, 5) + verify(0, 5, false, 5) + + UnindexTransactions(chainDb, 5, 11, nil) + verify(0, 11, false, 11) + + // Testing corner cases + signal := make(chan struct{}) + var once sync.Once + indexTransactionsForTesting(chainDb, 5, 11, signal, func(n uint64) bool { + if n <= 8 { + once.Do(func() { + close(signal) + }) + return false + } + return true + }) + verify(9, 11, true, 9) + verify(0, 9, false, 9) + IndexTransactions(chainDb, 0, 9, nil) + + signal = make(chan struct{}) + var once2 sync.Once + unindexTransactionsForTesting(chainDb, 0, 11, signal, func(n uint64) bool { + if n >= 8 { + once2.Do(func() { + close(signal) + }) + return false + } + return true + }) + verify(8, 11, true, 8) + verify(0, 8, false, 8) +} diff --git a/core/rawdb/database.go b/core/rawdb/database.go index b1ac3e95878f..b01a31ebcdd5 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -355,7 +355,7 @@ func InspectDatabase(db ethdb.Database) error { bloomTrieNodes.Add(size) default: var accounted bool - for _, meta := range [][]byte{databaseVerisionKey, headHeaderKey, headBlockKey, headFastBlockKey, fastTrieProgressKey} { + for _, meta := range 
[][]byte{databaseVersionKey, headHeaderKey, headBlockKey, headFastBlockKey, fastTrieProgressKey, uncleanShutdownKey} { if bytes.Equal(key, meta) { metadata.Add(size) accounted = true diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go index e2b093a34a6c..2aabfd3baa53 100644 --- a/core/rawdb/schema.go +++ b/core/rawdb/schema.go @@ -27,8 +27,8 @@ import ( // The fields below define the low level database schema prefixing. var ( - // databaseVerisionKey tracks the current database version. - databaseVerisionKey = []byte("DatabaseVersion") + // databaseVersionKey tracks the current database version. + databaseVersionKey = []byte("DatabaseVersion") // headHeaderKey tracks the latest known header's hash. headHeaderKey = []byte("LastHeader") @@ -51,6 +51,15 @@ var ( // snapshotJournalKey tracks the in-memory diff layers across restarts. snapshotJournalKey = []byte("SnapshotJournal") + // snapshotGeneratorKey tracks the snapshot generation marker across restarts. + snapshotGeneratorKey = []byte("SnapshotGenerator") + + // snapshotRecoveryKey tracks the snapshot recovery marker across restarts. + snapshotRecoveryKey = []byte("SnapshotRecovery") + + // snapshotSyncStatusKey tracks the snapshot sync status across restarts. + snapshotSyncStatusKey = []byte("SnapshotSyncStatus") + // txIndexTailKey tracks the oldest block whose transactions have been indexed. txIndexTailKey = []byte("TransactionIndexTail") @@ -75,6 +84,8 @@ var ( preimagePrefix = []byte("secure-key-") // preimagePrefix + hash -> preimage configPrefix = []byte("ethereum-config-") // config prefix for the db + uncleanShutdownKey = []byte("unclean-shutdown") // config prefix for the db + // Chain index prefixes (use `i` + single byte to avoid mixing data types). BloomBitsIndexPrefix = []byte("iB") // BloomBitsIndexPrefix is the data table of a chain indexer to track its progress diff --git a/core/state/access_list.go b/core/state/access_list.go new file mode 100644 index 000000000000..419469134595 --- /dev/null +++ b/core/state/access_list.go @@ -0,0 +1,136 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "github.com/ethereum/go-ethereum/common" +) + +type accessList struct { + addresses map[common.Address]int + slots []map[common.Hash]struct{} +} + +// ContainsAddress returns true if the address is in the access list. +func (al *accessList) ContainsAddress(address common.Address) bool { + _, ok := al.addresses[address] + return ok +} + +// Contains checks if a slot within an account is present in the access list, returning +// separate flags for the presence of the account and the slot respectively. 
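+//
+// For example, after only AddAddress(a) it returns (true, false) for any slot
+// of a; once AddSlot(a, s) has been called it returns (true, true) for s.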
+func (al *accessList) Contains(address common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { + idx, ok := al.addresses[address] + if !ok { + // no such address (and hence zero slots) + return false, false + } + if idx == -1 { + // address yes, but no slots + return true, false + } + _, slotPresent = al.slots[idx][slot] + return true, slotPresent +} + +// newAccessList creates a new accessList. +func newAccessList() *accessList { + return &accessList{ + addresses: make(map[common.Address]int), + } +} + +// Copy creates an independent copy of an accessList. +func (a *accessList) Copy() *accessList { + cp := newAccessList() + for k, v := range a.addresses { + cp.addresses[k] = v + } + cp.slots = make([]map[common.Hash]struct{}, len(a.slots)) + for i, slotMap := range a.slots { + newSlotmap := make(map[common.Hash]struct{}, len(slotMap)) + for k := range slotMap { + newSlotmap[k] = struct{}{} + } + cp.slots[i] = newSlotmap + } + return cp +} + +// AddAddress adds an address to the access list, and returns 'true' if the operation +// caused a change (addr was not previously in the list). +func (al *accessList) AddAddress(address common.Address) bool { + if _, present := al.addresses[address]; present { + return false + } + al.addresses[address] = -1 + return true +} + +// AddSlot adds the specified (addr, slot) combo to the access list. +// Return values are: +// - address added +// - slot added +// For any 'true' value returned, a corresponding journal entry must be made. +func (al *accessList) AddSlot(address common.Address, slot common.Hash) (addrChange bool, slotChange bool) { + idx, addrPresent := al.addresses[address] + if !addrPresent || idx == -1 { + // Address not present, or addr present but no slots there + al.addresses[address] = len(al.slots) + slotmap := map[common.Hash]struct{}{slot: {}} + al.slots = append(al.slots, slotmap) + return !addrPresent, true + } + // There is already an (address,slot) mapping + slotmap := al.slots[idx] + if _, ok := slotmap[slot]; !ok { + slotmap[slot] = struct{}{} + // Journal add slot change + return false, true + } + // No changes required + return false, false +} + +// DeleteSlot removes an (address, slot)-tuple from the access list. +// This operation needs to be performed in the same order as the addition happened. +// This method is meant to be used by the journal, which maintains ordering of +// operations. +func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) { + idx, addrOk := al.addresses[address] + // There are two ways this can fail + if !addrOk { + panic("reverting slot change, address not present in list") + } + slotmap := al.slots[idx] + delete(slotmap, slot) + // If that was the last (first) slot, remove it + // Since additions and rollbacks are always performed in order, + // we can delete the item without worrying about screwing up later indices + if len(slotmap) == 0 { + al.slots = al.slots[:idx] + al.addresses[address] = -1 + } +} + +// DeleteAddress removes an address from the access list. This operation +// needs to be performed in the same order as the addition happened. +// This method is meant to be used by the journal, which maintains ordering of +// operations. 
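+//
+// An address that still tracks slots is never deleted directly: the journal
+// reverts the corresponding slot entries first, per the invariant documented
+// in journal.go.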
+func (al *accessList) DeleteAddress(address common.Address) { + delete(al.addresses, address) +} diff --git a/core/state/database.go b/core/state/database.go index a9342f5179ca..83f7b2a839c5 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -104,18 +104,18 @@ type Trie interface { // NewDatabase creates a backing store for state. The returned database is safe for // concurrent use, but does not retain any recent trie nodes in memory. To keep some -// historical state in memory, use the NewDatabaseWithCache constructor. +// historical state in memory, use the NewDatabaseWithConfig constructor. func NewDatabase(db ethdb.Database) Database { - return NewDatabaseWithCache(db, 0, "") + return NewDatabaseWithConfig(db, nil) } -// NewDatabaseWithCache creates a backing store for state. The returned database +// NewDatabaseWithConfig creates a backing store for state. The returned database // is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a // large memory cache. -func NewDatabaseWithCache(db ethdb.Database, cache int, journal string) Database { +func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabaseWithCache(db, cache, journal), + db: trie.NewDatabaseWithConfig(db, config), codeSizeCache: csc, codeCache: fastcache.New(codeCacheSize), } diff --git a/core/state/journal.go b/core/state/journal.go index f242dac5afb0..2070f3087554 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -130,6 +130,14 @@ type ( touchChange struct { account *common.Address } + // Changes to the access list + accessListAddAccountChange struct { + address *common.Address + } + accessListAddSlotChange struct { + address *common.Address + slot *common.Hash + } ) func (ch createObjectChange) revert(s *StateDB) { @@ -234,3 +242,28 @@ func (ch addPreimageChange) revert(s *StateDB) { func (ch addPreimageChange) dirtied() *common.Address { return nil } + +func (ch accessListAddAccountChange) revert(s *StateDB) { + /* + One important invariant here is that whenever a (addr, slot) is added, if the + addr is not already present, the add causes two journal entries: + - one for the address, + - one for the (address,slot) + Therefore, when unrolling the change, we can always blindly delete the + (addr) at this point, since no storage adds can remain when we come upon + a single (addr) change. + */ + s.accessList.DeleteAddress(*ch.address) +} + +func (ch accessListAddAccountChange) dirtied() *common.Address { + return nil +} + +func (ch accessListAddSlotChange) revert(s *StateDB) { + s.accessList.DeleteSlot(*ch.address, *ch.slot) +} + +func (ch accessListAddSlotChange) dirtied() *common.Address { + return nil +} diff --git a/core/state/snapshot/disklayer_test.go b/core/state/snapshot/disklayer_test.go index 8460cd332f99..40ff5ade4c37 100644 --- a/core/state/snapshot/disklayer_test.go +++ b/core/state/snapshot/disklayer_test.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb/leveldb" "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/ethereum/go-ethereum/rlp" ) // reverse reverses the contents of a byte slice. It's used to update random accs @@ -429,6 +430,81 @@ func TestDiskPartialMerge(t *testing.T) { } } +// Tests that when the bottom-most diff layer is merged into the disk +// layer, the corresponding generator is persisted correctly.
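// Editorial aside: a self-contained toy (not geth code) of the revert invariant
// described in the journal comment above: adding a (addr, slot) pair for a fresh
// address emits two undo entries, address first, slot second, so reverting
// newest-first lets the address entry blindly delete the whole account scope.
package main

import "fmt"

func main() {
	list := map[string]map[string]bool{} // address -> set of slots
	var journal []func()                 // undo closures, newest appended last

	addSlot := func(addr, slot string) {
		if _, ok := list[addr]; !ok {
			list[addr] = map[string]bool{}
			journal = append(journal, func() { delete(list, addr) }) // entry 1: address
		}
		if !list[addr][slot] {
			list[addr][slot] = true
			journal = append(journal, func() { delete(list[addr], slot) }) // entry 2: slot
		}
	}
	addSlot("0xaa", "0x01") // fresh address: two journal entries

	// Revert newest-first: the slot undo runs before the address undo, so no
	// storage adds remain by the time the single address change is unrolled.
	for i := len(journal) - 1; i >= 0; i-- {
		journal[i]()
	}
	fmt.Println(len(list)) // 0
}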
+func TestDiskGeneratorPersistence(t *testing.T) { + var ( + accOne = randomHash() + accTwo = randomHash() + accOneSlotOne = randomHash() + accOneSlotTwo = randomHash() + + accThree = randomHash() + accThreeSlot = randomHash() + baseRoot = randomHash() + diffRoot = randomHash() + diffTwoRoot = randomHash() + genMarker = append(randomHash().Bytes(), randomHash().Bytes()...) + ) + // Testing scenario 1: the disk layer is still under construction. + db := rawdb.NewMemoryDatabase() + + rawdb.WriteAccountSnapshot(db, accOne, accOne[:]) + rawdb.WriteStorageSnapshot(db, accOne, accOneSlotOne, accOneSlotOne[:]) + rawdb.WriteStorageSnapshot(db, accOne, accOneSlotTwo, accOneSlotTwo[:]) + rawdb.WriteSnapshotRoot(db, baseRoot) + + // Create a disk layer based on all above updates + snaps := &Tree{ + layers: map[common.Hash]snapshot{ + baseRoot: &diskLayer{ + diskdb: db, + cache: fastcache.New(500 * 1024), + root: baseRoot, + genMarker: genMarker, + }, + }, + } + // Modify or delete some accounts, flatten everything onto disk + if err := snaps.Update(diffRoot, baseRoot, nil, map[common.Hash][]byte{ + accTwo: accTwo[:], + }, nil); err != nil { + t.Fatalf("failed to update snapshot tree: %v", err) + } + if err := snaps.Cap(diffRoot, 0); err != nil { + t.Fatalf("failed to flatten snapshot tree: %v", err) + } + blob := rawdb.ReadSnapshotGenerator(db) + var generator journalGenerator + if err := rlp.DecodeBytes(blob, &generator); err != nil { + t.Fatalf("Failed to decode snapshot generator: %v", err) + } + if !bytes.Equal(generator.Marker, genMarker) { + t.Fatalf("Generator marker mismatch") + } + // Testing scenario 2: the disk layer is fully generated + // Modify or delete some accounts, flatten everything onto disk + if err := snaps.Update(diffTwoRoot, diffRoot, nil, map[common.Hash][]byte{ + accThree: accThree.Bytes(), + }, map[common.Hash]map[common.Hash][]byte{ + accThree: {accThreeSlot: accThreeSlot.Bytes()}, + }); err != nil { + t.Fatalf("failed to update snapshot tree: %v", err) + } + diskLayer := snaps.layers[snaps.diskRoot()].(*diskLayer) + diskLayer.genMarker = nil // Construction finished + if err := snaps.Cap(diffTwoRoot, 0); err != nil { + t.Fatalf("failed to flatten snapshot tree: %v", err) + } + blob = rawdb.ReadSnapshotGenerator(db) + if err := rlp.DecodeBytes(blob, &generator); err != nil { + t.Fatalf("Failed to decode snapshot generator: %v", err) + } + if len(generator.Marker) != 0 { + t.Fatalf("Failed to update snapshot generator") + } +} + // Tests that merging something into a disk layer persists it into the database // and invalidates any previously written and cached values, discarding anything // after the in-progress generation marker. diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go index cf9b2b0393a0..4a2fa78d3ac2 100644 --- a/core/state/snapshot/generate.go +++ b/core/state/snapshot/generate.go @@ -19,6 +19,7 @@ package snapshot import ( "bytes" "encoding/binary" + "fmt" "math/big" "time" @@ -112,9 +113,42 @@ func generateSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache i genAbort: make(chan chan *generatorStats), } go base.generate(&generatorStats{wiping: wiper, start: time.Now()}) + log.Debug("Start snapshot generation", "root", root) return base } +// journalProgress persists the generator stats into the database to resume later. +func journalProgress(db ethdb.KeyValueWriter, marker []byte, stats *generatorStats) { + // Write out the generator marker.
Note the marker is a standalone database + // entry which is not mixed with the journal; it's OK if the generator is + // persisted while the journal is not. + entry := journalGenerator{ + Done: marker == nil, + Marker: marker, + } + if stats != nil { + entry.Wiping = (stats.wiping != nil) + entry.Accounts = stats.accounts + entry.Slots = stats.slots + entry.Storage = uint64(stats.storage) + } + blob, err := rlp.EncodeToBytes(entry) + if err != nil { + panic(err) // Cannot happen, here to catch dev errors + } + var logstr string + switch len(marker) { + case 0: + logstr = "done" + case common.HashLength: + logstr = fmt.Sprintf("%#x", marker) + default: + logstr = fmt.Sprintf("%#x:%#x", marker[:common.HashLength], marker[common.HashLength:]) + } + log.Debug("Journalled generator progress", "progress", logstr) + rawdb.WriteSnapshotGenerator(db, blob) +} + // generate is a background thread that iterates over the state and storage tries, // constructing the state snapshot. All the arguments are purely for statistics // gethering and logging, since the method surfs the blocks as they arrive, often @@ -129,7 +163,7 @@ func (dl *diskLayer) generate(stats *generatorStats) { stats.wiping = nil stats.start = time.Now() - // If generator was aboted during wipe, return + // If generator was aborted during wipe, return case abort := <-dl.genAbort: abort <- stats return @@ -186,11 +220,15 @@ func (dl *diskLayer) generate(stats *generatorStats) { if batch.ValueSize() > ethdb.IdealBatchSize || abort != nil { // Only write and set the marker if we actually did something useful if batch.ValueSize() > 0 { + // Ensure the generator entry is in sync with the data + marker := accountHash[:] + journalProgress(batch, marker, stats) + batch.Write() batch.Reset() dl.lock.Lock() - dl.genMarker = accountHash[:] + dl.genMarker = marker dl.lock.Unlock() } if abort != nil { @@ -203,7 +241,10 @@ func (dl *diskLayer) generate(stats *generatorStats) { if acc.Root != emptyRoot { storeTrie, err := trie.NewSecure(acc.Root, dl.triedb) if err != nil { - log.Crit("Storage trie inaccessible for snapshot generation", "err", err) + log.Error("Generator failed to access storage trie", "root", dl.root, "account", accountHash, "stroot", acc.Root, "err", err) + abort := <-dl.genAbort + abort <- stats + return } var storeMarker []byte if accMarker != nil && bytes.Equal(accountHash[:], accMarker) && len(dl.genMarker) > common.HashLength { @@ -224,11 +265,15 @@ func (dl *diskLayer) generate(stats *generatorStats) { if batch.ValueSize() > ethdb.IdealBatchSize || abort != nil { // Only write and set the marker if we actually did something useful if batch.ValueSize() > 0 { + // Ensure the generator entry is in sync with the data + marker := append(accountHash[:], storeIt.Key...) + journalProgress(batch, marker, stats) + batch.Write() batch.Reset() dl.lock.Lock() - dl.genMarker = append(accountHash[:], storeIt.Key...)
+ dl.genMarker = marker dl.lock.Unlock() } if abort != nil { @@ -238,6 +283,12 @@ func (dl *diskLayer) generate(stats *generatorStats) { } } } + if err := storeIt.Err; err != nil { + log.Error("Generator failed to iterate storage trie", "accroot", dl.root, "acchash", common.BytesToHash(accIt.Key), "stroot", acc.Root, "err", err) + abort := <-dl.genAbort + abort <- stats + return + } } if time.Since(logged) > 8*time.Second { stats.Log("Generating state snapshot", dl.root, accIt.Key) @@ -246,8 +297,17 @@ func (dl *diskLayer) generate(stats *generatorStats) { // Some account processed, unmark the marker accMarker = nil } + if err := accIt.Err; err != nil { + log.Error("Generator failed to iterate account trie", "root", dl.root, "err", err) + abort := <-dl.genAbort + abort <- stats + return + } // Snapshot fully generated, set the marker to nil if batch.ValueSize() > 0 { + // Ensure the generator entry is in sync with the data + journalProgress(batch, nil, stats) + batch.Write() } log.Info("Generated state snapshot", "accounts", stats.accounts, "slots", stats.slots, diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go new file mode 100644 index 000000000000..03263f3976a6 --- /dev/null +++ b/core/state/snapshot/generate_test.go @@ -0,0 +1,190 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snapshot + +import ( + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +// Tests that snapshot generation errors out correctly in case of a missing trie +// node in the account trie. +func TestGenerateCorruptAccountTrie(t *testing.T) { + // We can't use statedb to make a test trie (circular dependency), so make + // a fake one manually. We're going with a small account trie of 3 accounts, + // without any storage slots to keep the test smaller. 
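// Editorial aside: the tests below stop the generator through the same
// channel-of-channels handshake that generate uses above (snap.genAbort <- stop).
// A self-contained sketch of that protocol with stand-in names:
package main

import "fmt"

type stats struct{ accounts uint64 }

func main() {
	genAbort := make(chan chan *stats)

	go func() { // the generator goroutine
		st := new(stats)
		for {
			select {
			case reply := <-genAbort:
				reply <- st // hand back the progress stats, then stop
				return
			default:
				st.accounts++ // simulate generating one more account
			}
		}
	}()
	// Controller side: send a reply channel, then wait for the stats.
	stop := make(chan *stats)
	genAbort <- stop
	fmt.Println("aborted after", (<-stop).accounts, "accounts")
}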
+ var ( + diskdb = memorydb.New() + triedb = trie.NewDatabase(diskdb) + ) + tr, _ := trie.NewSecure(common.Hash{}, triedb) + acc := &Account{Balance: big.NewInt(1), Root: emptyRoot.Bytes(), CodeHash: emptyCode.Bytes()} + val, _ := rlp.EncodeToBytes(acc) + tr.Update([]byte("acc-1"), val) // 0xc7a30f39aff471c95d8a837497ad0e49b65be475cc0953540f80cfcdbdcd9074 + + acc = &Account{Balance: big.NewInt(2), Root: emptyRoot.Bytes(), CodeHash: emptyCode.Bytes()} + val, _ = rlp.EncodeToBytes(acc) + tr.Update([]byte("acc-2"), val) // 0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7 + + acc = &Account{Balance: big.NewInt(3), Root: emptyRoot.Bytes(), CodeHash: emptyCode.Bytes()} + val, _ = rlp.EncodeToBytes(acc) + tr.Update([]byte("acc-3"), val) // 0x19ead688e907b0fab07176120dceec244a72aff2f0aa51e8b827584e378772f4 + tr.Commit(nil) // Root: 0xa04693ea110a31037fb5ee814308a6f1d76bdab0b11676bdf4541d2de55ba978 + + // Delete an account trie leaf and ensure the generator chokes + triedb.Commit(common.HexToHash("0xa04693ea110a31037fb5ee814308a6f1d76bdab0b11676bdf4541d2de55ba978"), false, nil) + diskdb.Delete(common.HexToHash("0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7").Bytes()) + + snap := generateSnapshot(diskdb, triedb, 16, common.HexToHash("0xa04693ea110a31037fb5ee814308a6f1d76bdab0b11676bdf4541d2de55ba978"), nil) + select { + case <-snap.genPending: + // Snapshot generation succeeded + t.Errorf("Snapshot generated against corrupt account trie") + + case <-time.After(250 * time.Millisecond): + // Not generated fast enough, hopefully blocked inside on missing trie node fail + } + // Signal abortion to the generator and wait for it to tear down + stop := make(chan *generatorStats) + snap.genAbort <- stop + <-stop +} + +// Tests that snapshot generation errors out correctly in case of a missing root +// trie node for a storage trie. It's similar to internal corruption but it is +// handled differently inside the generator. +func TestGenerateMissingStorageTrie(t *testing.T) { + // We can't use statedb to make a test trie (circular dependency), so make + // a fake one manually. We're going with a small account trie of 3 accounts, + // two of which also has the same 3-slot storage trie attached. 
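// Editorial aside: the corruption tests in this file share one setup pattern.
// A compact, standalone version of it, using the same trie APIs the tests call:
// build a secure trie, commit it, flush it to disk, then delete one node by its
// hash to simulate a corrupted database.
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethdb/memorydb"
	"github.com/ethereum/go-ethereum/trie"
)

func main() {
	diskdb := memorydb.New()
	triedb := trie.NewDatabase(diskdb)

	tr, _ := trie.NewSecure(common.Hash{}, triedb)
	tr.Update([]byte("key-1"), []byte("val-1"))
	root, _ := tr.Commit(nil)

	// Flush the in-memory nodes out, then corrupt the backing store.
	triedb.Commit(root, false, nil)
	diskdb.Delete(root.Bytes()) // nodes are keyed by hash, so this kills the root
	fmt.Println("corrupted trie at root", root.Hex())
}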
+ var ( + diskdb = memorydb.New() + triedb = trie.NewDatabase(diskdb) + ) + stTrie, _ := trie.NewSecure(common.Hash{}, triedb) + stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 + stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 + stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 + stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 + + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} + val, _ := rlp.EncodeToBytes(acc) + accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e + + acc = &Account{Balance: big.NewInt(2), Root: emptyRoot.Bytes(), CodeHash: emptyCode.Bytes()} + val, _ = rlp.EncodeToBytes(acc) + accTrie.Update([]byte("acc-2"), val) // 0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7 + + acc = &Account{Balance: big.NewInt(3), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} + val, _ = rlp.EncodeToBytes(acc) + accTrie.Update([]byte("acc-3"), val) // 0x50815097425d000edfc8b3a4a13e175fc2bdcfee8bdfbf2d1ff61041d3c235b2 + accTrie.Commit(nil) // Root: 0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd + + // We can only corrupt the disk database, so flush the tries out + triedb.Reference( + common.HexToHash("0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67"), + common.HexToHash("0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e"), + ) + triedb.Reference( + common.HexToHash("0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67"), + common.HexToHash("0x50815097425d000edfc8b3a4a13e175fc2bdcfee8bdfbf2d1ff61041d3c235b2"), + ) + triedb.Commit(common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"), false, nil) + + // Delete a storage trie root and ensure the generator chokes + diskdb.Delete(common.HexToHash("0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67").Bytes()) + + snap := generateSnapshot(diskdb, triedb, 16, common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"), nil) + select { + case <-snap.genPending: + // Snapshot generation succeeded + t.Errorf("Snapshot generated against corrupt storage trie") + + case <-time.After(250 * time.Millisecond): + // Not generated fast enough, hopefully blocked inside on missing trie node fail + } + // Signal abortion to the generator and wait for it to tear down + stop := make(chan *generatorStats) + snap.genAbort <- stop + <-stop +} + +// Tests that snapshot generation errors out correctly in case of a missing trie +// node in a storage trie. +func TestGenerateCorruptStorageTrie(t *testing.T) { + // We can't use statedb to make a test trie (circular dependency), so make + // a fake one manually. We're going with a small account trie of 3 accounts, + // two of which also has the same 3-slot storage trie attached. 
+ var ( + diskdb = memorydb.New() + triedb = trie.NewDatabase(diskdb) + ) + stTrie, _ := trie.NewSecure(common.Hash{}, triedb) + stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 + stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 + stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 + stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 + + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} + val, _ := rlp.EncodeToBytes(acc) + accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e + + acc = &Account{Balance: big.NewInt(2), Root: emptyRoot.Bytes(), CodeHash: emptyCode.Bytes()} + val, _ = rlp.EncodeToBytes(acc) + accTrie.Update([]byte("acc-2"), val) // 0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7 + + acc = &Account{Balance: big.NewInt(3), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} + val, _ = rlp.EncodeToBytes(acc) + accTrie.Update([]byte("acc-3"), val) // 0x50815097425d000edfc8b3a4a13e175fc2bdcfee8bdfbf2d1ff61041d3c235b2 + accTrie.Commit(nil) // Root: 0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd + + // We can only corrupt the disk database, so flush the tries out + triedb.Reference( + common.HexToHash("0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67"), + common.HexToHash("0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e"), + ) + triedb.Reference( + common.HexToHash("0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67"), + common.HexToHash("0x50815097425d000edfc8b3a4a13e175fc2bdcfee8bdfbf2d1ff61041d3c235b2"), + ) + triedb.Commit(common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"), false, nil) + + // Delete a storage trie leaf and ensure the generator chokes + diskdb.Delete(common.HexToHash("0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371").Bytes()) + + snap := generateSnapshot(diskdb, triedb, 16, common.HexToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"), nil) + select { + case <-snap.genPending: + // Snapshot generation succeeded + t.Errorf("Snapshot generated against corrupt storage trie") + + case <-time.After(250 * time.Millisecond): + // Not generated fast enough, hopefully blocked inside on missing trie node fail + } + // Signal abortion to the generator and wait for it to tear down + stop := make(chan *generatorStats) + snap.genAbort <- stop + <-stop +} diff --git a/core/state/snapshot/journal.go b/core/state/snapshot/journal.go index fc1053f818d6..178ba0890276 100644 --- a/core/state/snapshot/journal.go +++ b/core/state/snapshot/journal.go @@ -33,6 +33,8 @@ import ( "github.com/ethereum/go-ethereum/trie" ) +const journalVersion uint64 = 0 + // journalGenerator is a disk layer entry containing the generator progress marker. type journalGenerator struct { Wiping bool // Whether the database was in progress of being wiped @@ -61,8 +63,91 @@ type journalStorage struct { Vals [][]byte } +// loadAndParseLegacyJournal tries to parse the snapshot journal in legacy format. 
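// Editorial aside: a runnable sketch of the journal framing that the
// new-format loader (loadAndParseJournal, further below) expects: a uint64
// version first, then the disk layer root, then the diff layers. Only the
// first two fields are modelled here.
package main

import (
	"bytes"
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/rlp"
)

const journalVersion uint64 = 0

func main() {
	// Encode side, mirroring Tree.Journal: version, then disk layer root.
	buf := new(bytes.Buffer)
	rlp.Encode(buf, journalVersion)
	rlp.Encode(buf, common.HexToHash("0xff"))

	// Decode side, mirroring loadAndParseJournal: resolve version, then root.
	r := rlp.NewStream(bytes.NewReader(buf.Bytes()), 0)
	version, _ := r.Uint()
	var root common.Hash
	r.Decode(&root)
	fmt.Println(version == journalVersion, root.Hex())
}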
+func loadAndParseLegacyJournal(db ethdb.KeyValueStore, base *diskLayer) (snapshot, journalGenerator, error) { + // Retrieve the journal; for the legacy format it must exist, since even + // with zero diff layers it stores whether we've already generated the + // snapshot or are still in progress. + journal := rawdb.ReadSnapshotJournal(db) + if len(journal) == 0 { + return nil, journalGenerator{}, errors.New("missing or corrupted snapshot journal") + } + r := rlp.NewStream(bytes.NewReader(journal), 0) + + // Read the snapshot generation progress for the disk layer + var generator journalGenerator + if err := r.Decode(&generator); err != nil { + return nil, journalGenerator{}, fmt.Errorf("failed to load snapshot progress marker: %v", err) + } + // Load all the snapshot diffs from the journal + snapshot, err := loadDiffLayer(base, r) + if err != nil { + return nil, generator, err + } + return snapshot, generator, nil +} + +// loadAndParseJournal tries to parse the snapshot journal in the latest format. +func loadAndParseJournal(db ethdb.KeyValueStore, base *diskLayer) (snapshot, journalGenerator, error) { + // Retrieve the disk layer generator. It must exist whether or not the + // snapshot is fully generated; otherwise the entire disk layer is invalid. + generatorBlob := rawdb.ReadSnapshotGenerator(db) + if len(generatorBlob) == 0 { + return nil, journalGenerator{}, errors.New("missing snapshot generator") + } + var generator journalGenerator + if err := rlp.DecodeBytes(generatorBlob, &generator); err != nil { + return nil, journalGenerator{}, fmt.Errorf("failed to decode snapshot generator: %v", err) + } + // Retrieve the diff layer journal. It's possible that the journal doesn't + // exist, e.g. the disk layer was still generating when Geth crashed without + // persisting the diff journal. + // So if there is no journal, or the journal is invalid (e.g. it doesn't match + // the disk layer, or it's a legacy-format journal), we just discard all diffs + // and try to recover them later. + journal := rawdb.ReadSnapshotJournal(db) + if len(journal) == 0 { + log.Warn("Loaded snapshot journal", "diskroot", base.root, "diffs", "missing") + return base, generator, nil + } + r := rlp.NewStream(bytes.NewReader(journal), 0) + + // Firstly, resolve the first element as the journal version + version, err := r.Uint() + if err != nil { + log.Warn("Failed to resolve the journal version", "error", err) + return base, generator, nil + } + if version != journalVersion { + log.Warn("Discarded the snapshot journal with wrong version", "required", journalVersion, "got", version) + return base, generator, nil + } + // Secondly, resolve the disk layer root and ensure the journal is + // continuous with it. At this point the journal version is known to be + // correct, so everything else should resolve properly. + var root common.Hash + if err := r.Decode(&root); err != nil { + return nil, journalGenerator{}, errors.New("missing disk layer root") + } + // If the diff journal doesn't match the disk layer, discard it. This can + // happen when Geth crashes without persisting the latest diff journal.
+ if !bytes.Equal(root.Bytes(), base.root.Bytes()) { + log.Warn("Loaded snapshot journal", "diskroot", base.root, "diffs", "unmatched") + return base, generator, nil + } + // Load all the snapshot diffs from the journal + snapshot, err := loadDiffLayer(base, r) + if err != nil { + return nil, journalGenerator{}, err + } + log.Debug("Loaded snapshot journal", "diskroot", base.root, "diffhead", snapshot.Root()) + return snapshot, generator, nil +} + // loadSnapshot loads a pre-existing state snapshot backed by a key-value store. -func loadSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash) (snapshot, error) { +func loadSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash, recovery bool) (snapshot, error) { // Retrieve the block number and hash of the snapshot, failing if no snapshot // is present in the database (or crashed mid-update). baseRoot := rawdb.ReadSnapshotRoot(diskdb) @@ -75,28 +160,36 @@ func loadSnapshot(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, cache: fastcache.New(cache * 1024 * 1024), root: baseRoot, } - // Retrieve the journal, it must exist since even for 0 layer it stores whether - // we've already generated the snapshot or are in progress only - journal := rawdb.ReadSnapshotJournal(diskdb) - if len(journal) == 0 { - return nil, errors.New("missing or corrupted snapshot journal") - } - r := rlp.NewStream(bytes.NewReader(journal), 0) - - // Read the snapshot generation progress for the disk layer - var generator journalGenerator - if err := r.Decode(&generator); err != nil { - return nil, fmt.Errorf("failed to load snapshot progress marker: %v", err) + var legacy bool + snapshot, generator, err := loadAndParseJournal(diskdb, base) + if err != nil { + log.Warn("Failed to load new-format journal", "error", err) + snapshot, generator, err = loadAndParseLegacyJournal(diskdb, base) + legacy = true } - // Load all the snapshot diffs from the journal - snapshot, err := loadDiffLayer(base, r) if err != nil { return nil, err } - // Entire snapshot journal loaded, sanity check the head and return - // Journal doesn't exist, don't worry if it's not supposed to + // Entire snapshot journal loaded, sanity check the head. If the loaded + // snapshot doesn't match the current state root, print a warning log, or + // discard the entire snapshot if it's a legacy one. + // + // Possible scenario: Geth crashed without persisting the journal and on + // restart the head was rewound to a point with available state (trie) + // below the snapshot. In this case the snapshot can be recovered by + // re-executing blocks, but right now it's unavailable. if head := snapshot.Root(); head != root { - return nil, fmt.Errorf("head doesn't match snapshot: have %#x, want %#x", head, root) + // If it's a legacy snapshot, or a new-format snapshot outside of + // recovery mode, return the error here to force a rebuild of the + // entire snapshot. + if legacy || !recovery { + return nil, fmt.Errorf("head doesn't match snapshot: have %#x, want %#x", head, root) + } + // In snapshot recovery mode, the assumption is that the disk layer + // is always higher than the chain head. The snapshot can eventually + // be recovered once the chain head moves beyond the disk layer.
+ log.Warn("Snapshot is not continuous with chain", "snaproot", head, "chainroot", root) } // Everything loaded correctly, resume any suspended operations if !generator.Done { @@ -183,8 +276,8 @@ func loadDiffLayer(parent snapshot, r *rlp.Stream) (snapshot, error) { return loadDiffLayer(newDiffLayer(parent, root, destructSet, accountData, storageData), r) } -// Journal writes the persistent layer generator stats into a buffer to be stored -// in the database as the snapshot journal. +// Journal terminates any in-progress snapshot generation, also implicitly pushing +// the progress into the database. func (dl *diskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { // If the snapshot is currently being generated, abort it var stats *generatorStats @@ -200,6 +293,85 @@ func (dl *diskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { dl.lock.RLock() defer dl.lock.RUnlock() + if dl.stale { + return common.Hash{}, ErrSnapshotStale + } + // Ensure the generator stats is written even if none was ran this cycle + journalProgress(dl.diskdb, dl.genMarker, stats) + + log.Debug("Journalled disk layer", "root", dl.root) + return dl.root, nil +} + +// Journal writes the memory layer contents into a buffer to be stored in the +// database as the snapshot journal. +func (dl *diffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + // Journal the parent first + base, err := dl.parent.Journal(buffer) + if err != nil { + return common.Hash{}, err + } + // Ensure the layer didn't get stale + dl.lock.RLock() + defer dl.lock.RUnlock() + + if dl.Stale() { + return common.Hash{}, ErrSnapshotStale + } + // Everything below was journalled, persist this layer too + if err := rlp.Encode(buffer, dl.root); err != nil { + return common.Hash{}, err + } + destructs := make([]journalDestruct, 0, len(dl.destructSet)) + for hash := range dl.destructSet { + destructs = append(destructs, journalDestruct{Hash: hash}) + } + if err := rlp.Encode(buffer, destructs); err != nil { + return common.Hash{}, err + } + accounts := make([]journalAccount, 0, len(dl.accountData)) + for hash, blob := range dl.accountData { + accounts = append(accounts, journalAccount{Hash: hash, Blob: blob}) + } + if err := rlp.Encode(buffer, accounts); err != nil { + return common.Hash{}, err + } + storage := make([]journalStorage, 0, len(dl.storageData)) + for hash, slots := range dl.storageData { + keys := make([]common.Hash, 0, len(slots)) + vals := make([][]byte, 0, len(slots)) + for key, val := range slots { + keys = append(keys, key) + vals = append(vals, val) + } + storage = append(storage, journalStorage{Hash: hash, Keys: keys, Vals: vals}) + } + if err := rlp.Encode(buffer, storage); err != nil { + return common.Hash{}, err + } + log.Debug("Journalled diff layer", "root", dl.root, "parent", dl.parent.Root()) + return base, nil +} + +// LegacyJournal writes the persistent layer generator stats into a buffer +// to be stored in the database as the snapshot journal. +// +// Note it's the legacy version which is only used in testing right now. 
+func (dl *diskLayer) LegacyJournal(buffer *bytes.Buffer) (common.Hash, error) { + // If the snapshot is currently being generated, abort it + var stats *generatorStats + if dl.genAbort != nil { + abort := make(chan *generatorStats) + dl.genAbort <- abort + + if stats = <-abort; stats != nil { + stats.Log("Journalling in-progress snapshot", dl.root, dl.genMarker) + } + } + // Ensure the layer didn't get stale + dl.lock.RLock() + defer dl.lock.RUnlock() + if dl.stale { return common.Hash{}, ErrSnapshotStale } @@ -214,6 +386,7 @@ func (dl *diskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { entry.Slots = stats.slots entry.Storage = uint64(stats.storage) } + log.Debug("Legacy journalled disk layer", "root", dl.root) if err := rlp.Encode(buffer, entry); err != nil { return common.Hash{}, err } @@ -222,9 +395,11 @@ func (dl *diskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { // Journal writes the memory layer contents into a buffer to be stored in the // database as the snapshot journal. -func (dl *diffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { +// +// Note it's the legacy version which is only used in testing right now. +func (dl *diffLayer) LegacyJournal(buffer *bytes.Buffer) (common.Hash, error) { // Journal the parent first - base, err := dl.parent.Journal(buffer) + base, err := dl.parent.LegacyJournal(buffer) if err != nil { return common.Hash{}, err } @@ -266,5 +441,6 @@ func (dl *diffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { if err := rlp.Encode(buffer, storage); err != nil { return common.Hash{}, err } + log.Debug("Legacy journalled diff layer", "root", dl.root, "parent", dl.parent.Root()) return base, nil } diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index f6c5a6a9a8ca..60b4158b5642 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -86,6 +87,10 @@ var ( // range of accounts covered. ErrNotCoveredYet = errors.New("not covered yet") + // ErrNotConstructed is returned if a caller wants to iterate the snapshot + // while the generation is not finished yet. + ErrNotConstructed = errors.New("snapshot is not constructed") + // errSnapshotCycle is returned if a snapshot is attempted to be inserted // that forms a cycle in the snapshot tree. errSnapshotCycle = errors.New("snapshot cycle") @@ -132,6 +137,10 @@ type snapshot interface { // flattening everything down (bad for reorgs). Journal(buffer *bytes.Buffer) (common.Hash, error) + // LegacyJournal is basically identical to Journal. It's the legacy version for + // flushing the legacy journal. Now the only purpose of this function is for testing. + LegacyJournal(buffer *bytes.Buffer) (common.Hash, error) + // Stale return whether this layer has become stale (was flattened across) or // if it's still live. Stale() bool @@ -164,10 +173,12 @@ type Tree struct { // store (with a number of memory layers from a journal), ensuring that the head // of the snapshot matches the expected one. // -// If the snapshot is missing or inconsistent, the entirety is deleted and will -// be reconstructed from scratch based on the tries in the key-value store, on a -// background thread.
-func New(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash, async bool) *Tree { +// If the snapshot is missing or the disk layer is broken, the entirety is deleted +// and will be reconstructed from scratch based on the tries in the key-value +// store, on a background thread. If the memory layers from the journal are not +// continuous with the disk layer, or the journal is missing, all diffs will be +// discarded iff the tree is in "recovery" mode, otherwise a rebuild is mandatory. +func New(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root common.Hash, async bool, recovery bool) *Tree { // Create a new, empty snapshot tree snap := &Tree{ diskdb: diskdb, @@ -179,7 +190,7 @@ func New(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root comm defer snap.waitBuild() } // Attempt to load a previously persisted snapshot and rebuild one if failed - head, err := loadSnapshot(diskdb, triedb, cache, root) + head, err := loadSnapshot(diskdb, triedb, cache, root, recovery) if err != nil { log.Warn("Failed to load snapshot, regenerating", "err", err) snap.Rebuild(root) @@ -194,7 +205,7 @@ func New(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root comm } // waitBuild blocks until the snapshot finishes rebuilding. This method is meant -// to be used by tests to ensure we're testing what we believe we are. +// to be used by tests to ensure we're testing what we believe we are. func (t *Tree) waitBuild() { // Find the rebuild termination channel var done chan struct{} @@ -411,6 +422,9 @@ func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer { // diffToDisk merges a bottom-most diff into the persistent disk layer underneath // it. The method will panic if called onto a non-bottom-most diff layer. +// +// The disk layer persistence is operated in an atomic way: all updates are +// discarded if the whole transition did not finish. func diffToDisk(bottom *diffLayer) *diskLayer { var ( base = bottom.parent.(*diskLayer) @@ -423,8 +437,7 @@ func diffToDisk(bottom *diffLayer) *diskLayer { base.genAbort <- abort stats = <-abort } - // Start by temporarily deleting the current snapshot block marker. This - // ensures that in the case of a crash, the entire snapshot is invalidated. + // Put the deletion in the batch writer, flush all updates in the final step. rawdb.DeleteSnapshotRoot(batch) // Mark the original base as stale as we're going to create a new wrapper @@ -467,12 +480,6 @@ func diffToDisk(bottom *diffLayer) *diskLayer { base.cache.Set(hash[:], data) snapshotCleanAccountWriteMeter.Mark(int64(len(data))) - if batch.ValueSize() > ethdb.IdealBatchSize { - if err := batch.Write(); err != nil { - log.Crit("Failed to write account snapshot", "err", err) - } - batch.Reset() - } snapshotFlushAccountItemMeter.Mark(1) snapshotFlushAccountSizeMeter.Mark(int64(len(data))) } @@ -501,18 +508,19 @@ func diffToDisk(bottom *diffLayer) *diskLayer { snapshotFlushStorageItemMeter.Mark(1) snapshotFlushStorageSizeMeter.Mark(int64(len(data))) } - if batch.ValueSize() > ethdb.IdealBatchSize { - if err := batch.Write(); err != nil { - log.Crit("Failed to write storage snapshot", "err", err) - } - batch.Reset() - } } // Update the snapshot block marker and write any remainder data rawdb.WriteSnapshotRoot(batch, bottom.root) + + // Write out the generator progress marker and report + journalProgress(batch, base.genMarker, stats) + + // Flush all the updates in a single db operation. Ensure the + // disk layer transition is atomic.
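// Editorial aside: a runnable sketch of the stage-everything-then-write-once
// batching that makes the disk layer transition atomic: no staged key becomes
// visible until the single Write call commits them together.
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/ethdb/memorydb"
)

func main() {
	db := memorydb.New()
	batch := db.NewBatch()

	// Stage every update, deletions and writes alike.
	batch.Delete([]byte("SnapshotRoot"))
	batch.Put([]byte("SnapshotRoot"), []byte{0xaa})
	batch.Put([]byte("SnapshotGenerator"), []byte{0x01})

	has, _ := db.Has([]byte("SnapshotGenerator"))
	fmt.Println("before write:", has) // false, nothing is visible yet

	// A single Write commits the whole transition.
	if err := batch.Write(); err != nil {
		panic(err)
	}
	has, _ = db.Has([]byte("SnapshotGenerator"))
	fmt.Println("after write:", has) // true
}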
if err := batch.Write(); err != nil { log.Crit("Failed to write leftover snapshot", "err", err) } + log.Debug("Journalled disk layer", "root", bottom.root, "complete", base.genMarker == nil) res := &diskLayer{ root: bottom.root, cache: base.cache, @@ -550,7 +558,21 @@ func (t *Tree) Journal(root common.Hash) (common.Hash, error) { t.lock.Lock() defer t.lock.Unlock() + // Firstly, write out the journal metadata journal := new(bytes.Buffer) + if err := rlp.Encode(journal, journalVersion); err != nil { + return common.Hash{}, err + } + diskroot := t.diskRoot() + if diskroot == (common.Hash{}) { + return common.Hash{}, errors.New("invalid disk root") + } + // Secondly, write out the disk layer root, ensuring the + // diff journal is continuous with the disk layer. + if err := rlp.Encode(journal, diskroot); err != nil { + return common.Hash{}, err + } + // Finally write out the journal of each layer in reverse order. base, err := snap.(snapshot).Journal(journal) if err != nil { return common.Hash{}, err @@ -560,6 +582,29 @@ return base, nil } + +// LegacyJournal is basically identical to Journal. It's the legacy +// version for flushing the legacy journal. Now the only purpose of this +// function is for testing. +func (t *Tree) LegacyJournal(root common.Hash) (common.Hash, error) { + // Retrieve the head snapshot to journal from + snap := t.Snapshot(root) + if snap == nil { + return common.Hash{}, fmt.Errorf("snapshot [%#x] missing", root) + } + // Run the journaling + t.lock.Lock() + defer t.lock.Unlock() + + journal := new(bytes.Buffer) + base, err := snap.(snapshot).LegacyJournal(journal) + if err != nil { + return common.Hash{}, err + } + // Store the journal into the database and return + rawdb.WriteSnapshotJournal(t.diskdb, journal.Bytes()) + return base, nil +} + // Rebuild wipes all available snapshot data from the persistent database and // discard all caches and diff layers. Afterwards, it starts a new snapshot // generator with the given root hash. @@ -567,6 +612,10 @@ func (t *Tree) Rebuild(root common.Hash) { t.lock.Lock() defer t.lock.Unlock() + // Firstly, delete any recovery flag in the database, because we are now + // building a brand new snapshot. + rawdb.DeleteSnapshotRecoveryNumber(t.diskdb) + // Track whether there's a wipe currently running and keep it alive if so var wiper chan struct{} @@ -609,11 +658,79 @@ // AccountIterator creates a new account iterator for the specified root hash and // seeks to a starting account hash. func (t *Tree) AccountIterator(root common.Hash, seek common.Hash) (AccountIterator, error) { + ok, err := t.generating() + if err != nil { + return nil, err + } + if ok { + return nil, ErrNotConstructed + } return newFastAccountIterator(t, root, seek) } // StorageIterator creates a new storage iterator for the specified root hash and // account. The iterator will be move to the specific start position. func (t *Tree) StorageIterator(root common.Hash, account common.Hash, seek common.Hash) (StorageIterator, error) { + ok, err := t.generating() + if err != nil { + return nil, err + } + if ok { + return nil, ErrNotConstructed + } return newFastStorageIterator(t, root, account, seek) } + +// disklayer is an internal helper function to return the disk layer. +// The lock of snapTree is assumed to be held already.
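// Editorial aside: a toy model (not geth code) of the layer walk that
// disklayer below performs: every diff layer ultimately hangs off a single
// disk layer origin, so picking any layer and following origins finds it.
package main

import "fmt"

type layer interface{ origin() string }

type diskLayer struct{ root string }

func (d *diskLayer) origin() string { return d.root }

type diffLayer struct{ parent layer }

func (d *diffLayer) origin() string { return d.parent.origin() }

func main() {
	disk := &diskLayer{root: "0xbase"}
	head := &diffLayer{parent: &diffLayer{parent: disk}}
	fmt.Println(head.origin()) // 0xbase, regardless of stacking depth
}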
+func (t *Tree) disklayer() *diskLayer { + var snap snapshot + for _, s := range t.layers { + snap = s + break + } + if snap == nil { + return nil + } + switch layer := snap.(type) { + case *diskLayer: + return layer + case *diffLayer: + return layer.origin + default: + panic(fmt.Sprintf("%T: undefined layer", snap)) + } +} + +// diskRoot is an internal helper function to return the disk layer root. +// The lock of snapTree is assumed to be held already. +func (t *Tree) diskRoot() common.Hash { + disklayer := t.disklayer() + if disklayer == nil { + return common.Hash{} + } + return disklayer.Root() +} + +// generating is an internal helper function which reports whether the snapshot +// is still under construction. +func (t *Tree) generating() (bool, error) { + t.lock.Lock() + defer t.lock.Unlock() + + layer := t.disklayer() + if layer == nil { + return false, errors.New("disk layer is missing") + } + layer.lock.RLock() + defer layer.lock.RUnlock() + return layer.genMarker != nil, nil +} + +// DiskRoot is an external helper function to return the disk layer root. +func (t *Tree) DiskRoot() common.Hash { + t.lock.Lock() + defer t.lock.Unlock() + + return t.diskRoot() +} diff --git a/core/state/state_object.go b/core/state/state_object.go index 26ab67e1adb7..d0d3b4513eee 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -299,7 +299,7 @@ func (s *stateObject) updateTrie(db Database) Trie { if len(s.pendingStorage) == 0 { return s.trie } - // Track the amount of time wasted on updating the storge trie + // Track the amount of time wasted on updating the storage trie if metrics.EnabledExpensive { defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now()) } @@ -347,7 +347,7 @@ func (s *stateObject) updateRoot(db Database) { if s.updateTrie(db) == nil { return } - // Track the amount of time wasted on hashing the storge trie + // Track the amount of time wasted on hashing the storage trie if metrics.EnabledExpensive { defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now()) } @@ -364,7 +364,7 @@ func (s *stateObject) CommitTrie(db Database) error { if s.dbErr != nil { return s.dbErr } - // Track the amount of time wasted on committing the storge trie + // Track the amount of time wasted on committing the storage trie if metrics.EnabledExpensive { defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) } diff --git a/core/state/state_test.go b/core/state/state_test.go index 0dc4c0ad63c3..526d7f8177cd 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -41,7 +41,9 @@ func newStateTest() *stateTest { } func TestDump(t *testing.T) { - s := newStateTest() + db := rawdb.NewMemoryDatabase() + sdb, _ := New(common.Hash{}, NewDatabaseWithConfig(db, nil), nil) + s := &stateTest{db: db, state: sdb} // generate a few entries obj1 := s.state.GetOrNewStateObject(toAddr([]byte{0x01})) diff --git a/core/state/statedb.go b/core/state/statedb.go index 36f7d863af9b..a9d1de2e0643 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -93,6 +93,9 @@ type StateDB struct { preimages map[common.Hash][]byte + // Per-transaction access list + accessList *accessList + // Journal of state modifications. This is the backbone of // Snapshot and RevertToSnapshot.
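// Editorial aside: a minimal toy of the Snapshot/RevertToSnapshot pattern the
// journal field enables: record the journal length when a snapshot is taken,
// then undo entries past that mark on revert. Illustrative only, not the geth
// types.
package main

import "fmt"

func main() {
	balance := 0
	var journal []func() // undo closures, newest last

	set := func(v int) {
		prev := balance
		journal = append(journal, func() { balance = prev })
		balance = v
	}
	set(10)
	snap := len(journal) // taking a snapshot is just remembering the length
	set(20)
	set(30)

	// Revert: undo everything appended after the snapshot mark, newest first.
	for i := len(journal) - 1; i >= snap; i-- {
		journal[i]()
	}
	journal = journal[:snap]
	fmt.Println(balance) // 10
}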
journal *journal @@ -129,6 +132,7 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), journal: newJournal(), + accessList: newAccessList(), } if sdb.snaps != nil { if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { @@ -178,6 +182,7 @@ func (s *StateDB) Reset(root common.Hash) error { s.snapStorage = make(map[common.Hash]map[common.Hash][]byte) } } + s.accessList = newAccessList() return nil } @@ -309,14 +314,19 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { return common.Hash{} } -// GetProof returns the MerkleProof for a given Account -func (s *StateDB) GetProof(a common.Address) ([][]byte, error) { +// GetProof returns the Merkle proof for a given account. +func (s *StateDB) GetProof(addr common.Address) ([][]byte, error) { + return s.GetProofByHash(crypto.Keccak256Hash(addr.Bytes())) +} + +// GetProofByHash returns the Merkle proof for a given account. +func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { var proof proofList - err := s.trie.Prove(crypto.Keccak256(a.Bytes()), 0, &proof) - return [][]byte(proof), err + err := s.trie.Prove(addrHash[:], 0, &proof) + return proof, err } -// GetStorageProof returns the StorageProof for given key +// GetStorageProof returns the Merkle proof for given storage slot. func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { var proof proofList trie := s.StorageTrie(a) @@ -324,7 +334,18 @@ func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, return proof, errors.New("storage trie for requested address does not exist") } err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) - return [][]byte(proof), err + return proof, err +} + +// GetStorageProofByHash returns the Merkle proof for given storage slot. +func (s *StateDB) GetStorageProofByHash(a common.Address, key common.Hash) ([][]byte, error) { + var proof proofList + trie := s.StorageTrie(a) + if trie == nil { + return proof, errors.New("storage trie for requested address does not exist") + } + err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) + return proof, err } // GetCommittedState retrieves a value from the given account's committed storage trie. @@ -697,6 +718,12 @@ func (s *StateDB) Copy() *StateDB { for hash, preimage := range s.preimages { state.preimages[hash] = preimage } + // Do we need to copy the access list? In practice: No. At the start of a + // transaction, the access list is empty. In practice, we only ever copy state + // _between_ transactions/blocks, never in the middle of a transaction. 
+ // However, it doesn't cost us much to copy an empty list, so we do it anyway + // to not blow up if we ever decide to copy it in the middle of a transaction + state.accessList = s.accessList.Copy() return state } @@ -798,6 +825,7 @@ func (s *StateDB) Prepare(thash, bhash common.Hash, ti int) { s.thash = thash s.bhash = bhash s.txIndex = ti + s.accessList = newAccessList() } func (s *StateDB) clearJournalAndRefund() { @@ -869,11 +897,50 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (common.Hash, error) { if err := s.snaps.Update(root, parent, s.snapDestructs, s.snapAccounts, s.snapStorage); err != nil { log.Warn("Failed to update snapshot tree", "from", parent, "to", root, "err", err) } - if err := s.snaps.Cap(root, 127); err != nil { // Persistent layer is 128th, the last available trie - log.Warn("Failed to cap snapshot tree", "root", root, "layers", 127, "err", err) + // Keep 128 diff layers in memory; the persistent layer is the 129th. + // - head layer is paired with HEAD state + // - head-1 layer is paired with HEAD-1 state + // - head-127 layer (bottom-most diff layer) is paired with HEAD-127 state + if err := s.snaps.Cap(root, 128); err != nil { + log.Warn("Failed to cap snapshot tree", "root", root, "layers", 128, "err", err) } } s.snap, s.snapDestructs, s.snapAccounts, s.snapStorage = nil, nil, nil, nil } return root, err } + +// AddAddressToAccessList adds the given address to the access list +func (s *StateDB) AddAddressToAccessList(addr common.Address) { + if s.accessList.AddAddress(addr) { + s.journal.append(accessListAddAccountChange{&addr}) + } +} + +// AddSlotToAccessList adds the given (address, slot)-tuple to the access list +func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) { + addrMod, slotMod := s.accessList.AddSlot(addr, slot) + if addrMod { + // In practice, this should not happen, since there is no way to enter the + // scope of 'address' without 'address' having already been added + // to the access list (via call-variant, create, etc). + // Better safe than sorry, though + s.journal.append(accessListAddAccountChange{&addr}) + } + if slotMod { + s.journal.append(accessListAddSlotChange{ + address: &addr, + slot: &slot, + }) + } +} + +// AddressInAccessList returns true if the given address is in the access list. +func (s *StateDB) AddressInAccessList(addr common.Address) bool { + return s.accessList.ContainsAddress(addr) +} + +// SlotInAccessList returns true if the given (address, slot)-tuple is in the access list.
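// Editorial aside: a runnable sketch of the per-transaction access list
// lifecycle on StateDB, using only APIs introduced or shown in this patch:
// Prepare resets the list, the Add* methods populate it, and the *InAccessList
// getters query it.
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/core/state"
)

func main() {
	db := state.NewDatabase(rawdb.NewMemoryDatabase())
	statedb, _ := state.New(common.Hash{}, db, nil)

	// Prepare marks the start of a new transaction and clears the access list.
	statedb.Prepare(common.Hash{}, common.Hash{}, 0)

	addr := common.HexToAddress("0xaa")
	slot := common.HexToHash("0x01")
	statedb.AddSlotToAccessList(addr, slot) // also warms the address itself

	fmt.Println(statedb.AddressInAccessList(addr)) // true
	fmt.Println(statedb.SlotInAccessList(addr, slot))
}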
+func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { + return s.accessList.Contains(addr, slot) +} diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 36ff27133188..70d01ff3dd9c 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -328,6 +328,20 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { }, args: make([]int64, 1), }, + { + name: "AddAddressToAccessList", + fn: func(a testAction, s *StateDB) { + s.AddAddressToAccessList(addr) + }, + }, + { + name: "AddSlotToAccessList", + fn: func(a testAction, s *StateDB) { + s.AddSlotToAccessList(addr, + common.Hash{byte(a.args[0])}) + }, + args: make([]int64, 1), + }, } action := actions[r.Intn(len(actions))] var nameargs []string @@ -727,3 +741,177 @@ func TestMissingTrieNodes(t *testing.T) { t.Fatalf("expected error, got root :%x", root) } } + +func TestStateDBAccessList(t *testing.T) { + // Some helpers + addr := func(a string) common.Address { + return common.HexToAddress(a) + } + slot := func(a string) common.Hash { + return common.HexToHash(a) + } + + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + state, _ := New(common.Hash{}, db, nil) + state.accessList = newAccessList() + + verifyAddrs := func(astrings ...string) { + t.Helper() + // convert to common.Address form + var addresses []common.Address + var addressMap = make(map[common.Address]struct{}) + for _, astring := range astrings { + address := addr(astring) + addresses = append(addresses, address) + addressMap[address] = struct{}{} + } + // Check that the given addresses are in the access list + for _, address := range addresses { + if !state.AddressInAccessList(address) { + t.Fatalf("expected %x to be in access list", address) + } + } + // Check that only the expected addresses are present in the access list + for address := range state.accessList.addresses { + if _, exist := addressMap[address]; !exist { + t.Fatalf("extra address %x in access list", address) + } + } + } + verifySlots := func(addrString string, slotStrings ...string) { + if !state.AddressInAccessList(addr(addrString)) { + t.Fatalf("scope missing address/slots %v", addrString) + } + var address = addr(addrString) + // convert to common.Hash form + var slots []common.Hash + var slotMap = make(map[common.Hash]struct{}) + for _, slotString := range slotStrings { + s := slot(slotString) + slots = append(slots, s) + slotMap[s] = struct{}{} + } + // Check that the expected items are in the access list + for i, s := range slots { + if _, slotPresent := state.SlotInAccessList(address, s); !slotPresent { + t.Fatalf("input %d: scope missing slot %v (address %v)", i, s, addrString) + } + } + // Check that no extra elements are in the access list + index := state.accessList.addresses[address] + if index >= 0 { + stateSlots := state.accessList.slots[index] + for s := range stateSlots { + if _, slotPresent := slotMap[s]; !slotPresent { + t.Fatalf("scope has extra slot %v (address %v)", s, addrString) + } + } + } + } + + state.AddAddressToAccessList(addr("aa")) // 1 + state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3 + state.AddSlotToAccessList(addr("bb"), slot("02")) // 4 + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + + // Make a copy + stateCopy1 := state.Copy() + if exp, got := 4, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + + // same again, should cause no journal entries +
state.AddSlotToAccessList(addr("bb"), slot("01")) + state.AddSlotToAccessList(addr("bb"), slot("02")) + state.AddAddressToAccessList(addr("aa")) + if exp, got := 4, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + // some new ones + state.AddSlotToAccessList(addr("bb"), slot("03")) // 5 + state.AddSlotToAccessList(addr("aa"), slot("01")) // 6 + state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8 + state.AddAddressToAccessList(addr("cc")) + if exp, got := 8, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + + verifyAddrs("aa", "bb", "cc") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + verifySlots("cc", "01") + + // now start rolling back changes + state.journal.revert(state, 7) + if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb", "cc") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 6) + if state.AddressInAccessList(addr("cc")) { + t.Fatalf("addr present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 5) + if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 4) + if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + + state.journal.revert(state, 3) + if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01") + + state.journal.revert(state, 2) + if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + + state.journal.revert(state, 1) + if state.AddressInAccessList(addr("bb")) { + t.Fatalf("addr present, expected missing") + } + verifyAddrs("aa") + + state.journal.revert(state, 0) + if state.AddressInAccessList(addr("aa")) { + t.Fatalf("addr present, expected missing") + } + if got, exp := len(state.accessList.addresses), 0; got != exp { + t.Fatalf("expected empty, got %d", got) + } + if got, exp := len(state.accessList.slots), 0; got != exp { + t.Fatalf("expected empty, got %d", got) + } + // Check the copy + // Make a copy + state = stateCopy1 + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + if got, exp := len(state.accessList.addresses), 2; got != exp { + t.Fatalf("expected empty, got %d", got) + } + if got, exp := len(state.accessList.slots), 1; got != exp { + t.Fatalf("expected empty, got %d", got) + } +} diff --git a/core/state/sync_test.go b/core/state/sync_test.go index deb4b52b4c10..9c4867093d1e 100644 --- a/core/state/sync_test.go +++ b/core/state/sync_test.go @@ -62,7 +62,8 @@ func makeTestState() (Database, common.Hash, []*testAccount) { } if i%5 == 0 { for j := byte(0); j < 5; j++ { - obj.SetState(db, crypto.Keccak256Hash([]byte{i, i, i, i, i, j, j}), crypto.Keccak256Hash([]byte{i, i, i, i, i, j, j})) + hash := crypto.Keccak256Hash([]byte{i, i, i, i, i, j, j}) + obj.SetState(db, hash, hash) } } state.updateStateObject(obj) @@ -401,15 +402,14 @@ func TestIncompleteStateSync(t *testing.T) { // Create a random state to copy srcDb, srcRoot, 
srcAccounts := makeTestState() - // isCode reports whether the hash is contract code hash. - isCode := func(hash common.Hash) bool { - for _, acc := range srcAccounts { - if hash == crypto.Keccak256Hash(acc.code) { - return true - } + // isCodeLookup to save some hashing + var isCode = make(map[common.Hash]struct{}) + for _, acc := range srcAccounts { + if len(acc.code) > 0 { + isCode[crypto.Keccak256Hash(acc.code)] = struct{}{} } - return false } + isCode[common.BytesToHash(emptyCodeHash)] = struct{}{} checkTrieConsistency(srcDb.TrieDB().DiskDB().(ethdb.Database), srcRoot) // Create a destination state and sync with the scheduler @@ -447,15 +447,13 @@ func TestIncompleteStateSync(t *testing.T) { batch.Write() for _, result := range results { added = append(added, result.Hash) - } - // Check that all known sub-tries added so far are complete or missing entirely. - for _, hash := range added { - if isCode(hash) { + // Check that all known sub-tries added so far are complete or missing entirely. + if _, ok := isCode[result.Hash]; ok { continue } // Can't use checkStateConsistency here because subtrie keys may have odd // length and crash in LeafKey. - if err := checkTrieConsistency(dstDb, hash); err != nil { + if err := checkTrieConsistency(dstDb, result.Hash); err != nil { t.Fatalf("state inconsistent: %v", err) } } @@ -466,9 +464,9 @@ func TestIncompleteStateSync(t *testing.T) { // Sanity check that removing any node from the database is detected for _, node := range added[1:] { var ( - key = node.Bytes() - code = isCode(node) - val []byte + key = node.Bytes() + _, code = isCode[node] + val []byte ) if code { val = rawdb.ReadCode(dstDb, node) diff --git a/core/state_prefetcher.go b/core/state_prefetcher.go index 1c550fa8bc6d..6fa52c2d9b8f 100644 --- a/core/state_prefetcher.go +++ b/core/state_prefetcher.go @@ -86,8 +86,9 @@ func precacheTransaction(config *params.ChainConfig, bc ChainContext, author *co return err } // Create the EVM and execute the transaction - context := NewEVMContext(msg, header, bc, author) - vm := vm.NewEVM(context, statedb, config, cfg) + context := NewEVMBlockContext(header, bc, author) + txContext := NewEVMTxContext(msg) + vm := vm.NewEVM(context, txContext, statedb, config, cfg) _, err = ApplyMessage(vm, msg, gaspool) return err diff --git a/core/state_processor.go b/core/state_processor.go index e655d8f3bfba..cdc86a694ea0 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -17,6 +17,8 @@ package core import ( + "fmt" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus/misc" @@ -65,13 +67,19 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg if p.config.DAOForkSupport && p.config.DAOForkBlock != nil && p.config.DAOForkBlock.Cmp(block.Number()) == 0 { misc.ApplyDAOHardFork(statedb) } + blockContext := NewEVMBlockContext(header, p.bc, nil) + vmenv := vm.NewEVM(blockContext, vm.TxContext{}, statedb, p.config, cfg) // Iterate over and process the individual transactions for i, tx := range block.Transactions() { - statedb.Prepare(tx.Hash(), block.Hash(), i) - receipt, err := ApplyTransaction(p.config, p.bc, nil, gp, statedb, header, tx, usedGas, cfg) + msg, err := tx.AsMessage(types.MakeSigner(p.config, header.Number)) if err != nil { return nil, nil, 0, err } + statedb.Prepare(tx.Hash(), block.Hash(), i) + receipt, err := applyTransaction(msg, p.config, p.bc, nil, gp, statedb, header, tx, usedGas, vmenv) + if err != nil { + return nil, nil, 
0, fmt.Errorf("could not apply tx %d [%v]: %w", i, tx.Hash().Hex(), err) + } receipts = append(receipts, receipt) allLogs = append(allLogs, receipt.Logs...) } @@ -81,22 +89,25 @@ func (p *StateProcessor) Process(block *types.Block, statedb *state.StateDB, cfg return receipts, allLogs, *usedGas, nil } -// ApplyTransaction attempts to apply a transaction to the given state database -// and uses the input parameters for its environment. It returns the receipt -// for the transaction, gas used and an error if the transaction failed, -// indicating the block was invalid. -func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *common.Address, gp *GasPool, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, cfg vm.Config) (*types.Receipt, error) { - msg, err := tx.AsMessage(types.MakeSigner(config, header.Number)) - if err != nil { - return nil, err - } +func applyTransaction(msg types.Message, config *params.ChainConfig, bc ChainContext, author *common.Address, gp *GasPool, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, evm *vm.EVM) (*types.Receipt, error) { // Create a new context to be used in the EVM environment - context := NewEVMContext(msg, header, bc, author) - // Create a new environment which holds all relevant information - // about the transaction and calling mechanisms. - vmenv := vm.NewEVM(context, statedb, config, cfg) + txContext := NewEVMTxContext(msg) + // Add addresses to access list if applicable + if config.IsYoloV2(header.Number) { + statedb.AddAddressToAccessList(msg.From()) + if dst := msg.To(); dst != nil { + statedb.AddAddressToAccessList(*dst) + // If it's a create-tx, the destination will be added inside evm.create + } + for _, addr := range evm.ActivePrecompiles() { + statedb.AddAddressToAccessList(addr) + } + } + + // Update the evm with the new transaction context. + evm.Reset(txContext, statedb) // Apply the transaction to the current state (included in the env) - result, err := ApplyMessage(vmenv, msg, gp) + result, err := ApplyMessage(evm, msg, gp) if err != nil { return nil, err } @@ -116,7 +127,7 @@ func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *commo receipt.GasUsed = result.UsedGas // if the transaction created a contract, store the creation address in the receipt. if msg.To() == nil { - receipt.ContractAddress = crypto.CreateAddress(vmenv.Context.Origin, tx.Nonce()) + receipt.ContractAddress = crypto.CreateAddress(evm.TxContext.Origin, tx.Nonce()) } // Set the receipt logs and create a bloom for filtering receipt.Logs = statedb.GetLogs(tx.Hash()) @@ -127,3 +138,18 @@ func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *commo return receipt, err } + +// ApplyTransaction attempts to apply a transaction to the given state database +// and uses the input parameters for its environment. It returns the receipt +// for the transaction, gas used and an error if the transaction failed, +// indicating the block was invalid. 
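A consequence of the rework above, worth spelling out before the retained `ApplyTransaction` wrapper below: `Process` now builds a single EVM per block and only swaps the cheap per-transaction context for each transaction. A sketch of that shape, using the helper names visible in this diff (`NewEVMBlockContext`, `NewEVMTxContext`, `EVM.Reset`); error handling and receipt accumulation are elided, so this is illustrative, not the actual implementation:

```go
package core

import (
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/core/vm"
	"github.com/ethereum/go-ethereum/params"
)

// processSketch mirrors the new single-EVM-per-block shape of Process: the
// block context is computed once, and only the tx-scoped context is swapped
// via Reset for every transaction in the block.
func processSketch(config *params.ChainConfig, bc ChainContext, block *types.Block, statedb *state.StateDB) {
	header := block.Header()
	blockContext := NewEVMBlockContext(header, bc, nil) // once per block
	vmenv := vm.NewEVM(blockContext, vm.TxContext{}, statedb, config, vm.Config{})
	for i, tx := range block.Transactions() {
		msg, _ := tx.AsMessage(types.MakeSigner(config, header.Number))
		statedb.Prepare(tx.Hash(), block.Hash(), i)
		vmenv.Reset(NewEVMTxContext(msg), statedb) // once per transaction
		// ... ApplyMessage(vmenv, msg, gp), receipt construction ...
	}
}
```

This avoids re-deriving block-level data (coinbase, difficulty, hash function) for every transaction in the block.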
+func ApplyTransaction(config *params.ChainConfig, bc ChainContext, author *common.Address, gp *GasPool, statedb *state.StateDB, header *types.Header, tx *types.Transaction, usedGas *uint64, cfg vm.Config) (*types.Receipt, error) {
+	msg, err := tx.AsMessage(types.MakeSigner(config, header.Number))
+	if err != nil {
+		return nil, err
+	}
+	// Create a new context to be used in the EVM environment
+	blockContext := NewEVMBlockContext(header, bc, author)
+	vmenv := vm.NewEVM(blockContext, vm.TxContext{}, statedb, config, cfg)
+	return applyTransaction(msg, config, bc, author, gp, statedb, header, tx, usedGas, vmenv)
+}
diff --git a/core/state_processor_test.go b/core/state_processor_test.go
new file mode 100644
index 000000000000..6e63975ac13b
--- /dev/null
+++ b/core/state_processor_test.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package core
+
+import (
+	"math/big"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus"
+	"github.com/ethereum/go-ethereum/consensus/ethash"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/core/vm"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/trie"
+	"golang.org/x/crypto/sha3"
+)
+
+// TestStateProcessorErrors tests the output from the 'core' errors
+// as defined in core/error.go. 
These errors are generated when the +// blockchain imports bad blocks, meaning blocks which have valid headers but +// contain invalid transactions +func TestStateProcessorErrors(t *testing.T) { + var ( + signer = types.HomesteadSigner{} + testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + db = rawdb.NewMemoryDatabase() + gspec = &Genesis{ + Config: params.TestChainConfig, + } + genesis = gspec.MustCommit(db) + blockchain, _ = NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}, nil, nil) + ) + defer blockchain.Stop() + var makeTx = func(nonce uint64, to common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *types.Transaction { + tx, _ := types.SignTx(types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data), signer, testKey) + return tx + } + for i, tt := range []struct { + txs []*types.Transaction + want string + }{ + { + txs: []*types.Transaction{ + makeTx(0, common.Address{}, big.NewInt(0), params.TxGas, nil, nil), + makeTx(0, common.Address{}, big.NewInt(0), params.TxGas, nil, nil), + }, + want: "could not apply tx 1 [0x36bfa6d14f1cd35a1be8cc2322982a595fabc0e799f09c1de3bad7bd5b1f7626]: nonce too low: address 0x71562b71999873DB5b286dF957af199Ec94617F7, tx: 0 state: 1", + }, + { + txs: []*types.Transaction{ + makeTx(100, common.Address{}, big.NewInt(0), params.TxGas, nil, nil), + }, + want: "could not apply tx 0 [0x51cd272d41ef6011d8138e18bf4043797aca9b713c7d39a97563f9bbe6bdbe6f]: nonce too high: address 0x71562b71999873DB5b286dF957af199Ec94617F7, tx: 100 state: 0", + }, + { + txs: []*types.Transaction{ + makeTx(0, common.Address{}, big.NewInt(0), 21000000, nil, nil), + }, + want: "could not apply tx 0 [0x54c58b530824b0bb84b7a98183f08913b5d74e1cebc368515ef3c65edf8eb56a]: gas limit reached", + }, + { + txs: []*types.Transaction{ + makeTx(0, common.Address{}, big.NewInt(1), params.TxGas, nil, nil), + }, + want: "could not apply tx 0 [0x3094b17498940d92b13baccf356ce8bfd6f221e926abc903d642fa1466c5b50e]: insufficient funds for transfer: address 0x71562b71999873DB5b286dF957af199Ec94617F7", + }, + { + txs: []*types.Transaction{ + makeTx(0, common.Address{}, big.NewInt(0), params.TxGas, big.NewInt(0xffffff), nil), + }, + want: "could not apply tx 0 [0xaa3f7d86802b1f364576d9071bf231e31d61b392d306831ac9cf706ff5371ce0]: insufficient funds for gas * price + value: address 0x71562b71999873DB5b286dF957af199Ec94617F7 have 0 want 352321515000", + }, + { + txs: []*types.Transaction{ + makeTx(0, common.Address{}, big.NewInt(0), params.TxGas, nil, nil), + makeTx(1, common.Address{}, big.NewInt(0), params.TxGas, nil, nil), + makeTx(2, common.Address{}, big.NewInt(0), params.TxGas, nil, nil), + makeTx(3, common.Address{}, big.NewInt(0), params.TxGas-1000, big.NewInt(0), nil), + }, + want: "could not apply tx 3 [0x836fab5882205362680e49b311a20646de03b630920f18ec6ee3b111a2cf6835]: intrinsic gas too low: have 20000, want 21000", + }, + // The last 'core' error is ErrGasUintOverflow: "gas uint64 overflow", but in order to + // trigger that one, we'd have to allocate a _huge_ chunk of data, such that the + // multiplication len(data) +gas_per_byte overflows uint64. 
Not testable at the moment + } { + block := GenerateBadBlock(genesis, ethash.NewFaker(), tt.txs) + _, err := blockchain.InsertChain(types.Blocks{block}) + if err == nil { + t.Fatal("block imported without errors") + } + if have, want := err.Error(), tt.want; have != want { + t.Errorf("test %d:\nhave \"%v\"\nwant \"%v\"\n", i, have, want) + } + } +} + +// GenerateBadBlock constructs a "block" which contains the transactions. The transactions are not expected to be +// valid, and no proper post-state can be made. But from the perspective of the blockchain, the block is sufficiently +// valid to be considered for import: +// - valid pow (fake), ancestry, difficulty, gaslimit etc +func GenerateBadBlock(parent *types.Block, engine consensus.Engine, txs types.Transactions) *types.Block { + header := &types.Header{ + ParentHash: parent.Hash(), + Coinbase: parent.Coinbase(), + Difficulty: engine.CalcDifficulty(&fakeChainReader{params.TestChainConfig}, parent.Time()+10, &types.Header{ + Number: parent.Number(), + Time: parent.Time(), + Difficulty: parent.Difficulty(), + UncleHash: parent.UncleHash(), + }), + GasLimit: CalcGasLimit(parent, parent.GasLimit(), parent.GasLimit()), + Number: new(big.Int).Add(parent.Number(), common.Big1), + Time: parent.Time() + 10, + UncleHash: types.EmptyUncleHash, + } + var receipts []*types.Receipt + + // The post-state result doesn't need to be correct (this is a bad block), but we do need something there + // Preferably something unique. So let's use a combo of blocknum + txhash + hasher := sha3.NewLegacyKeccak256() + hasher.Write(header.Number.Bytes()) + var cumulativeGas uint64 + for _, tx := range txs { + txh := tx.Hash() + hasher.Write(txh[:]) + receipt := types.NewReceipt(nil, false, cumulativeGas+tx.Gas()) + receipt.TxHash = tx.Hash() + receipt.GasUsed = tx.Gas() + receipts = append(receipts, receipt) + cumulativeGas += tx.Gas() + } + header.Root = common.BytesToHash(hasher.Sum(nil)) + // Assemble and return the final block for sealing + return types.NewBlock(header, txs, nil, receipts, new(trie.Trie)) +} diff --git a/core/state_transition.go b/core/state_transition.go index 9a9bf475e9aa..a8d1936c1f76 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -17,6 +17,7 @@ package core import ( + "fmt" "math" "math/big" @@ -174,8 +175,8 @@ func (st *StateTransition) to() common.Address { func (st *StateTransition) buyGas() error { mgval := new(big.Int).Mul(new(big.Int).SetUint64(st.msg.Gas()), st.gasPrice) - if st.state.GetBalance(st.msg.From()).Cmp(mgval) < 0 { - return ErrInsufficientFunds + if have, want := st.state.GetBalance(st.msg.From()), mgval; have.Cmp(want) < 0 { + return fmt.Errorf("%w: address %v have %v want %v", ErrInsufficientFunds, st.msg.From().Hex(), have, want) } if err := st.gp.SubGas(st.msg.Gas()); err != nil { return err @@ -190,11 +191,13 @@ func (st *StateTransition) buyGas() error { func (st *StateTransition) preCheck() error { // Make sure this transaction's nonce is correct. 
if st.msg.CheckNonce() {
-		nonce := st.state.GetNonce(st.msg.From())
-		if nonce < st.msg.Nonce() {
-			return ErrNonceTooHigh
-		} else if nonce > st.msg.Nonce() {
-			return ErrNonceTooLow
+		stNonce := st.state.GetNonce(st.msg.From())
+		if msgNonce := st.msg.Nonce(); stNonce < msgNonce {
+			return fmt.Errorf("%w: address %v, tx: %d state: %d", ErrNonceTooHigh,
+				st.msg.From().Hex(), msgNonce, stNonce)
+		} else if stNonce > msgNonce {
+			return fmt.Errorf("%w: address %v, tx: %d state: %d", ErrNonceTooLow,
+				st.msg.From().Hex(), msgNonce, stNonce)
 		}
 	}
 	return st.buyGas()
@@ -230,8 +233,8 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) {
 	}
 	msg := st.msg
 	sender := vm.AccountRef(msg.From())
-	homestead := st.evm.ChainConfig().IsHomestead(st.evm.BlockNumber)
-	istanbul := st.evm.ChainConfig().IsIstanbul(st.evm.BlockNumber)
+	homestead := st.evm.ChainConfig().IsHomestead(st.evm.Context.BlockNumber)
+	istanbul := st.evm.ChainConfig().IsIstanbul(st.evm.Context.BlockNumber)
 	contractCreation := msg.To() == nil

 	// Check clauses 4-5, subtract intrinsic gas if everything is correct
@@ -240,13 +243,13 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) {
 		return nil, err
 	}
 	if st.gas < gas {
-		return nil, ErrIntrinsicGas
+		return nil, fmt.Errorf("%w: have %d, want %d", ErrIntrinsicGas, st.gas, gas)
 	}
 	st.gas -= gas

 	// Check clause 6
-	if msg.Value().Sign() > 0 && !st.evm.CanTransfer(st.state, msg.From(), msg.Value()) {
-		return nil, ErrInsufficientFundsForTransfer
+	if msg.Value().Sign() > 0 && !st.evm.Context.CanTransfer(st.state, msg.From(), msg.Value()) {
+		return nil, fmt.Errorf("%w: address %v", ErrInsufficientFundsForTransfer, msg.From().Hex())
 	}
 	var (
 		ret   []byte
@@ -260,7 +263,7 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) {
 		ret, st.gas, vmerr = st.evm.Call(sender, st.to(), st.data, st.gas, st.value)
 	}
 	st.refundGas()
-	st.state.AddBalance(st.evm.Coinbase, new(big.Int).Mul(new(big.Int).SetUint64(st.gasUsed()), st.gasPrice))
+	st.state.AddBalance(st.evm.Context.Coinbase, new(big.Int).Mul(new(big.Int).SetUint64(st.gasUsed()), st.gasPrice))

 	return &ExecutionResult{
 		UsedGas:    st.gasUsed(),
diff --git a/core/tx_list.go b/core/tx_list.go
index bf304eedcf43..894640d570c7 100644
--- a/core/tx_list.go
+++ b/core/tx_list.go
@@ -24,7 +24,6 @@ import (

 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
-	"github.com/ethereum/go-ethereum/log"
 )

 // nonceHeap is a heap.Interface implementation over 64bit unsigned integers for
@@ -433,29 +432,35 @@ func (h *priceHeap) Pop() interface{} {
 	old := *h
 	n := len(old)
 	x := old[n-1]
+	old[n-1] = nil
 	*h = old[0 : n-1]
 	return x
 }

 // txPricedList is a price-sorted heap to allow operating on transactions pool
-// contents in a price-incrementing way.
+// contents in a price-incrementing way. It's built upon all the transactions
+// in the txpool, but only the remote part is of interest: only remote transactions
+// are considered for tracking, sorting and eviction.
 type txPricedList struct {
-	all    *txLookup  // Pointer to the map of all transactions
-	items  *priceHeap // Heap of prices of all the stored transactions
-	stales int        // Number of stale price points to (re-heap trigger)
+	all     *txLookup  // Pointer to the map of all transactions
+	remotes *priceHeap // Heap of prices of all the stored **remote** transactions
+	stales  int        // Number of stale price points (re-heap trigger)
 }

 // newTxPricedList creates a new price-sorted transaction heap. 
func newTxPricedList(all *txLookup) *txPricedList { return &txPricedList{ - all: all, - items: new(priceHeap), + all: all, + remotes: new(priceHeap), } } // Put inserts a new transaction into the heap. -func (l *txPricedList) Put(tx *types.Transaction) { - heap.Push(l.items, tx) +func (l *txPricedList) Put(tx *types.Transaction, local bool) { + if local { + return + } + heap.Push(l.remotes, tx) } // Removed notifies the prices transaction list that an old transaction dropped @@ -464,121 +469,95 @@ func (l *txPricedList) Put(tx *types.Transaction) { func (l *txPricedList) Removed(count int) { // Bump the stale counter, but exit if still too low (< 25%) l.stales += count - if l.stales <= len(*l.items)/4 { + if l.stales <= len(*l.remotes)/4 { return } // Seems we've reached a critical number of stale transactions, reheap - reheap := make(priceHeap, 0, l.all.Count()) - - l.stales, l.items = 0, &reheap - l.all.Range(func(hash common.Hash, tx *types.Transaction) bool { - *l.items = append(*l.items, tx) - return true - }) - heap.Init(l.items) + l.Reheap() } // Cap finds all the transactions below the given price threshold, drops them // from the priced list and returns them for further removal from the entire pool. -func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.Transactions { +// +// Note: only remote transactions will be considered for eviction. +func (l *txPricedList) Cap(threshold *big.Int) types.Transactions { drop := make(types.Transactions, 0, 128) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep - - for len(*l.items) > 0 { + for len(*l.remotes) > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) - if l.all.Get(tx.Hash()) == nil { + cheapest := (*l.remotes)[0] + if l.all.GetRemote(cheapest.Hash()) == nil { // Removed or migrated + heap.Pop(l.remotes) l.stales-- continue } // Stop the discards if we've reached the threshold - if tx.GasPriceIntCmp(threshold) >= 0 { - save = append(save, tx) + if cheapest.GasPriceIntCmp(threshold) >= 0 { break } - // Non stale transaction found, discard unless local - if local.containsTx(tx) { - save = append(save, tx) - } else { - drop = append(drop, tx) - } - } - for _, tx := range save { - heap.Push(l.items, tx) + heap.Pop(l.remotes) + drop = append(drop, cheapest) } return drop } // Underpriced checks whether a transaction is cheaper than (or as cheap as) the -// lowest priced transaction currently being tracked. -func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) bool { - // Local transactions cannot be underpriced - if local.containsTx(tx) { - return false - } +// lowest priced (remote) transaction currently being tracked. +func (l *txPricedList) Underpriced(tx *types.Transaction) bool { // Discard stale price points if found at the heap start - for len(*l.items) > 0 { - head := []*types.Transaction(*l.items)[0] - if l.all.Get(head.Hash()) == nil { + for len(*l.remotes) > 0 { + head := []*types.Transaction(*l.remotes)[0] + if l.all.GetRemote(head.Hash()) == nil { // Removed or migrated l.stales-- - heap.Pop(l.items) + heap.Pop(l.remotes) continue } break } // Check if the transaction is underpriced or not - if len(*l.items) == 0 { - log.Error("Pricing query for empty pool") // This cannot happen, print to catch programming errors - return false + if len(*l.remotes) == 0 { + return false // There is no remote transaction at all. 
}
-	cheapest := []*types.Transaction(*l.items)[0]
+	// The transaction is underpriced if it's cheaper than (or as cheap as)
+	// the cheapest remote transaction currently being tracked.
+	cheapest := []*types.Transaction(*l.remotes)[0]
 	return cheapest.GasPriceCmp(tx) >= 0
 }

 // Discard finds a number of most underpriced transactions, removes them from the
 // priced list and returns them for further removal from the entire pool.
-func (l *txPricedList) Discard(slots int, local *accountSet) types.Transactions {
-	// If we have some local accountset, those will not be discarded
-	if !local.empty() {
-		// In case the list is filled to the brim with 'local' txs, we do this
-		// little check to avoid unpacking / repacking the heap later on, which
-		// is very expensive
-		discardable := 0
-		for _, tx := range *l.items {
-			if !local.containsTx(tx) {
-				discardable++
-			}
-			if discardable >= slots {
-				break
-			}
-		}
-		if slots > discardable {
-			slots = discardable
-		}
-	}
-	if slots == 0 {
-		return nil
-	}
-	drop := make(types.Transactions, 0, slots)               // Remote underpriced transactions to drop
-	save := make(types.Transactions, 0, len(*l.items)-slots) // Local underpriced transactions to keep
-
-	for len(*l.items) > 0 && slots > 0 {
+//
+// Note: local transactions won't be considered for eviction.
+func (l *txPricedList) Discard(slots int, force bool) (types.Transactions, bool) {
+	drop := make(types.Transactions, 0, slots) // Remote underpriced transactions to drop
+	for len(*l.remotes) > 0 && slots > 0 {
 		// Discard stale transactions if found during cleanup
-		tx := heap.Pop(l.items).(*types.Transaction)
-		if l.all.Get(tx.Hash()) == nil {
+		tx := heap.Pop(l.remotes).(*types.Transaction)
+		if l.all.GetRemote(tx.Hash()) == nil { // Removed or migrated
 			l.stales--
 			continue
 		}
-		// Non stale transaction found, discard unless local
-		if local.containsTx(tx) {
-			save = append(save, tx)
-		} else {
-			drop = append(drop, tx)
-			slots -= numSlots(tx)
+		// Non stale transaction found, discard it
+		drop = append(drop, tx)
+		slots -= numSlots(tx)
+	}
+	// If we still can't make enough room for the new transaction
+	if slots > 0 && !force {
+		for _, tx := range drop {
+			heap.Push(l.remotes, tx)
 		}
+		return nil, false
 	}
-	for _, tx := range save {
-		heap.Push(l.items, tx)
-	}
-	return drop
+	return drop, true
+}
+
+// Reheap forcibly rebuilds the heap based on the current remote transaction set.
+func (l *txPricedList) Reheap() {
+	reheap := make(priceHeap, 0, l.all.RemoteCount())
+
+	l.stales, l.remotes = 0, &reheap
+	l.all.Range(func(hash common.Hash, tx *types.Transaction, local bool) bool {
+		*l.remotes = append(*l.remotes, tx)
+		return true
+	}, false, true) // Only iterate remotes
+	heap.Init(l.remotes)
 }
diff --git a/core/tx_pool.go b/core/tx_pool.go
index 0fe1d3db5b68..4a17c31ca881 100644
--- a/core/tx_pool.go
+++ b/core/tx_pool.go
@@ -63,6 +63,10 @@ var (
 	// configured for the transaction pool.
 	ErrUnderpriced = errors.New("transaction underpriced")

+	// ErrTxPoolOverflow is returned if the transaction pool is full and can't accept
+	// another remote transaction.
+	ErrTxPoolOverflow = errors.New("txpool is full")
+
 	// ErrReplaceUnderpriced is returned if a transaction is attempted to be replaced
 	// with a different one without the required price bump. 
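The reworked `Discard` above is two-phase: it pops the cheapest remote candidates, and if that still doesn't free enough slots and the caller didn't pass `force`, it pushes everything back onto the heap and reports failure (which the pool turns into `ErrTxPoolOverflow`). A compressed stand-alone sketch of that control flow, using a sorted slice in place of the real heap and assuming one slot per transaction:

```go
package main

import "fmt"

// evict pops the cheapest candidates until enough slots are freed. If it
// cannot free enough and force is false, it reports failure and drops
// nothing -- the shape of the new txPricedList.Discard. (The real code pushes
// the popped transactions back onto its heap; here the caller's slice is
// untouched because we only advance a local copy.)
func evict(cheapestFirst []string, slotsNeeded int, force bool) ([]string, bool) {
	var dropped []string
	for len(cheapestFirst) > 0 && slotsNeeded > 0 {
		dropped = append(dropped, cheapestFirst[0])
		cheapestFirst = cheapestFirst[1:]
		slotsNeeded-- // assume one slot per transaction
	}
	if slotsNeeded > 0 && !force {
		return nil, false // not enough room: caller returns ErrTxPoolOverflow
	}
	return dropped, true
}

func main() {
	// A remote transaction may not evict more than exists (fails) ...
	_, ok := evict([]string{"tx1"}, 2, false)
	fmt.Println(ok) // false
	// ... while a local one (force=true) drops whatever is available.
	dropped, ok := evict([]string{"tx1"}, 2, true)
	fmt.Println(dropped, ok) // [tx1] true
}
```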
ErrReplaceUnderpriced = errors.New("replacement transaction underpriced")
@@ -105,6 +109,7 @@ var (
 	validTxMeter       = metrics.NewRegisteredMeter("txpool/valid", nil)
 	invalidTxMeter     = metrics.NewRegisteredMeter("txpool/invalid", nil)
 	underpricedTxMeter = metrics.NewRegisteredMeter("txpool/underpriced", nil)
+	overflowedTxMeter  = metrics.NewRegisteredMeter("txpool/overflowed", nil)

 	pendingGauge = metrics.NewRegisteredGauge("txpool/pending", nil)
 	queuedGauge  = metrics.NewRegisteredGauge("txpool/queued", nil)
@@ -421,7 +426,7 @@ func (pool *TxPool) SetGasPrice(price *big.Int) {
 	defer pool.mu.Unlock()

 	pool.gasPrice = price
-	for _, tx := range pool.priced.Cap(price, pool.locals) {
+	for _, tx := range pool.priced.Cap(price) {
 		pool.removeTx(tx.Hash(), false)
 	}
 	log.Info("Transaction pool price threshold updated", "price", price)
@@ -536,7 +541,6 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error {
 		return ErrInvalidSender
 	}
 	// Drop non-local transactions under our own minimal accepted gas price
-	local = local || pool.locals.contains(from) // account may be local even if the transaction arrived from the network
 	if !local && tx.GasPriceIntCmp(pool.gasPrice) < 0 {
 		return ErrUnderpriced
 	}
@@ -575,22 +579,36 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e
 		knownTxMeter.Mark(1)
 		return false, ErrAlreadyKnown
 	}
+	// Compute the local flag: a transaction is treated as local if it arrived
+	// from a local source, or if its sender was previously marked as local.
+	isLocal := local || pool.locals.containsTx(tx)
+
 	// If the transaction fails basic validation, discard it
-	if err := pool.validateTx(tx, local); err != nil {
+	if err := pool.validateTx(tx, isLocal); err != nil {
 		log.Trace("Discarding invalid transaction", "hash", hash, "err", err)
 		invalidTxMeter.Mark(1)
 		return false, err
 	}
 	// If the transaction pool is full, discard underpriced transactions
-	if uint64(pool.all.Count()) >= pool.config.GlobalSlots+pool.config.GlobalQueue {
+	if uint64(pool.all.Count()+numSlots(tx)) > pool.config.GlobalSlots+pool.config.GlobalQueue {
 		// If the new transaction is underpriced, don't accept it
-		if !local && pool.priced.Underpriced(tx, pool.locals) {
+		if !isLocal && pool.priced.Underpriced(tx) {
 			log.Trace("Discarding underpriced transaction", "hash", hash, "price", tx.GasPrice())
 			underpricedTxMeter.Mark(1)
 			return false, ErrUnderpriced
 		}
-		// New transaction is better than our worse ones, make room for it
-		drop := pool.priced.Discard(pool.all.Slots()-int(pool.config.GlobalSlots+pool.config.GlobalQueue)+numSlots(tx), pool.locals)
+		// New transaction is better than our worst ones, make room for it.
+		// If it's a local transaction, forcibly discard all available transactions.
+		// Otherwise, if we can't make enough room for the new one, abort the operation.
+		drop, success := pool.priced.Discard(pool.all.Slots()-int(pool.config.GlobalSlots+pool.config.GlobalQueue)+numSlots(tx), isLocal)
+
+		// Special case: we still couldn't make room for the new remote transaction.
+		if !isLocal && !success {
+			log.Trace("Discarding overflown transaction", "hash", hash)
+			overflowedTxMeter.Mark(1)
+			return false, ErrTxPoolOverflow
+		}
+		// Kick out the underpriced remote transactions. 
for _, tx := range drop { log.Trace("Discarding freshly underpriced transaction", "hash", tx.Hash(), "price", tx.GasPrice()) underpricedTxMeter.Mark(1) @@ -612,8 +630,8 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e pool.priced.Removed(1) pendingReplaceMeter.Mark(1) } - pool.all.Add(tx) - pool.priced.Put(tx) + pool.all.Add(tx, isLocal) + pool.priced.Put(tx, isLocal) pool.journalTx(from, tx) pool.queueTxEvent(tx) log.Trace("Pooled new executable transaction", "hash", hash, "from", from, "to", tx.To()) @@ -623,18 +641,17 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e return old != nil, nil } // New transaction isn't replacing a pending one, push into queue - replaced, err = pool.enqueueTx(hash, tx) + replaced, err = pool.enqueueTx(hash, tx, isLocal, true) if err != nil { return false, err } // Mark local addresses and journal local transactions - if local { - if !pool.locals.contains(from) { - log.Info("Setting new local account", "address", from) - pool.locals.add(from) - } + if local && !pool.locals.contains(from) { + log.Info("Setting new local account", "address", from) + pool.locals.add(from) + pool.priced.Removed(pool.all.RemoteToLocals(pool.locals)) // Migrate the remotes if it's marked as local first time. } - if local || pool.locals.contains(from) { + if isLocal { localGauge.Inc(1) } pool.journalTx(from, tx) @@ -646,7 +663,7 @@ func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err e // enqueueTx inserts a new transaction into the non-executable transaction queue. // // Note, this method assumes the pool lock is held! -func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, error) { +func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction, local bool, addAll bool) (bool, error) { // Try to insert the transaction into the future queue from, _ := types.Sender(pool.signer, tx) // already validated if pool.queue[from] == nil { @@ -667,9 +684,14 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er // Nothing was replaced, bump the queued counter queuedGauge.Inc(1) } - if pool.all.Get(hash) == nil { - pool.all.Add(tx) - pool.priced.Put(tx) + // If the transaction isn't in lookup set but it's expected to be there, + // show the error log. + if pool.all.Get(hash) == nil && !addAll { + log.Error("Missing transaction in lookup set, please report the issue", "hash", hash) + } + if addAll { + pool.all.Add(tx, local) + pool.priced.Put(tx, local) } // If we never record the heartbeat, do it right now. 
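Another subtlety in the `add` path above: the first time a sender is marked local, `pool.all.RemoteToLocals(pool.locals)` migrates that sender's previously-seen remote transactions into the local set, pulling them out of the eviction heap (hence the `pool.priced.Removed(...)` call). A two-map sketch of that migration, with simplified stand-in types rather than the real `txLookup`/`accountSet`:

```go
package main

import "fmt"

func main() {
	// Hashes map to a stand-in "transaction" that only records its sender.
	remotes := map[string]string{"h1": "from:A", "h2": "from:B"}
	locals := map[string]string{}

	// markLocal moves every remote transaction of the given sender into the
	// local map, returning how many migrated -- the caller then shrinks the
	// remote price heap by that count, as the pool does above.
	markLocal := func(sender string) (migrated int) {
		for hash, tx := range remotes {
			if tx == "from:"+sender {
				locals[hash] = tx
				delete(remotes, hash) // deleting during range is safe in Go
				migrated++
			}
		}
		return migrated
	}

	fmt.Println(markLocal("A"))            // 1
	fmt.Println(len(locals), len(remotes)) // 1 1
}
```

From that point on, any further transaction from the sender counts as local even when it arrives over the network.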
if _, exist := pool.beats[from]; !exist { @@ -718,11 +740,6 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T // Nothing was replaced, bump the pending counter pendingGauge.Inc(1) } - // Failsafe to work around direct pending inserts (tests) - if pool.all.Get(hash) == nil { - pool.all.Add(tx) - pool.priced.Put(tx) - } // Set the potentially new pending nonce and notify any subsystems of the new tx pool.pendingNonces.set(addr, tx.Nonce()+1) @@ -817,6 +834,7 @@ func (pool *TxPool) addTxs(txs []*types.Transaction, local, sync bool) []error { nilSlot++ } errs[nilSlot] = err + nilSlot++ } // Reorg the pool internals if needed and return done := pool.requestPromoteExecutables(dirtyAddrs) @@ -903,7 +921,8 @@ func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { } // Postpone any invalidated transactions for _, tx := range invalids { - pool.enqueueTx(tx.Hash(), tx) + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(tx.Hash(), tx, false, false) } // Update the account nonce if needed pool.pendingNonces.setIfLower(addr, tx.Nonce()) @@ -1407,7 +1426,9 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range invalids { hash := tx.Hash() log.Trace("Demoting pending transaction", "hash", hash) - pool.enqueueTx(hash, tx) + + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(hash, tx, false, false) } pendingGauge.Dec(int64(len(olds) + len(drops) + len(invalids))) if pool.locals.contains(addr) { @@ -1419,7 +1440,9 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range gapped { hash := tx.Hash() log.Error("Demoting invalidated transaction", "hash", hash) - pool.enqueueTx(hash, tx) + + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(hash, tx, false, false) } pendingGauge.Dec(int64(len(gapped))) // This might happen in a reorg, so log it to the metering @@ -1518,8 +1541,8 @@ func (as *accountSet) merge(other *accountSet) { as.cache = nil } -// txLookup is used internally by TxPool to track transactions while allowing lookup without -// mutex contention. +// txLookup is used internally by TxPool to track transactions while allowing +// lookup without mutex contention. // // Note, although this type is properly protected against concurrent access, it // is **not** a type that should ever be mutated or even exposed outside of the @@ -1527,27 +1550,43 @@ func (as *accountSet) merge(other *accountSet) { // internal mechanisms. The sole purpose of the type is to permit out-of-bound // peeking into the pool in TxPool.Get without having to acquire the widely scoped // TxPool.mu mutex. +// +// This lookup set combines the notion of "local transactions", which is useful +// to build upper-level structure. type txLookup struct { - all map[common.Hash]*types.Transaction - slots int - lock sync.RWMutex + slots int + lock sync.RWMutex + locals map[common.Hash]*types.Transaction + remotes map[common.Hash]*types.Transaction } // newTxLookup returns a new txLookup structure. func newTxLookup() *txLookup { return &txLookup{ - all: make(map[common.Hash]*types.Transaction), + locals: make(map[common.Hash]*types.Transaction), + remotes: make(map[common.Hash]*types.Transaction), } } -// Range calls f on each key and value present in the map. -func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction) bool) { +// Range calls f on each key and value present in the map. The callback passed +// should return the indicator whether the iteration needs to be continued. 
+// Callers need to specify which set (or both) to be iterated. +func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction, local bool) bool, local bool, remote bool) { t.lock.RLock() defer t.lock.RUnlock() - for key, value := range t.all { - if !f(key, value) { - break + if local { + for key, value := range t.locals { + if !f(key, value, true) { + return + } + } + } + if remote { + for key, value := range t.remotes { + if !f(key, value, false) { + return + } } } } @@ -1557,15 +1596,50 @@ func (t *txLookup) Get(hash common.Hash) *types.Transaction { t.lock.RLock() defer t.lock.RUnlock() - return t.all[hash] + if tx := t.locals[hash]; tx != nil { + return tx + } + return t.remotes[hash] } -// Count returns the current number of items in the lookup. +// GetLocal returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) GetLocal(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.locals[hash] +} + +// GetRemote returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) GetRemote(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.remotes[hash] +} + +// Count returns the current number of transactions in the lookup. func (t *txLookup) Count() int { t.lock.RLock() defer t.lock.RUnlock() - return len(t.all) + return len(t.locals) + len(t.remotes) +} + +// LocalCount returns the current number of local transactions in the lookup. +func (t *txLookup) LocalCount() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.locals) +} + +// RemoteCount returns the current number of remote transactions in the lookup. +func (t *txLookup) RemoteCount() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.remotes) } // Slots returns the current number of slots used in the lookup. @@ -1577,14 +1651,18 @@ func (t *txLookup) Slots() int { } // Add adds a transaction to the lookup. -func (t *txLookup) Add(tx *types.Transaction) { +func (t *txLookup) Add(tx *types.Transaction, local bool) { t.lock.Lock() defer t.lock.Unlock() t.slots += numSlots(tx) slotsGauge.Update(int64(t.slots)) - t.all[tx.Hash()] = tx + if local { + t.locals[tx.Hash()] = tx + } else { + t.remotes[tx.Hash()] = tx + } } // Remove removes a transaction from the lookup. @@ -1592,10 +1670,36 @@ func (t *txLookup) Remove(hash common.Hash) { t.lock.Lock() defer t.lock.Unlock() - t.slots -= numSlots(t.all[hash]) + tx, ok := t.locals[hash] + if !ok { + tx, ok = t.remotes[hash] + } + if !ok { + log.Error("No transaction found to be deleted", "hash", hash) + return + } + t.slots -= numSlots(tx) slotsGauge.Update(int64(t.slots)) - delete(t.all, hash) + delete(t.locals, hash) + delete(t.remotes, hash) +} + +// RemoteToLocals migrates the transactions belongs to the given locals to locals +// set. The assumption is held the locals set is thread-safe to be used. +func (t *txLookup) RemoteToLocals(locals *accountSet) int { + t.lock.Lock() + defer t.lock.Unlock() + + var migrated int + for hash, tx := range t.remotes { + if locals.containsTx(tx) { + t.locals[hash] = tx + delete(t.remotes, hash) + migrated += 1 + } + } + return migrated } // numSlots calculates the number of slots needed for a single transaction. 
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 4fca734e656e..47d3830b0616 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -107,10 +107,11 @@ func validateTxPoolInternals(pool *TxPool) error { if total := pool.all.Count(); total != pending+queued { return fmt.Errorf("total transaction count %d != %d pending + %d queued", total, pending, queued) } - if priced := pool.priced.items.Len() - pool.priced.stales; priced != pending+queued { - return fmt.Errorf("total priced transaction count %d != %d pending + %d queued", priced, pending, queued) + pool.priced.Reheap() + priced, remote := pool.priced.remotes.Len(), pool.all.RemoteCount() + if priced != remote { + return fmt.Errorf("total priced transaction count %d != %d", priced, remote) } - // Ensure the next nonce to assign is the correct one for addr, txs := range pool.pending { // Find the last transaction @@ -242,7 +243,7 @@ func TestInvalidTransactions(t *testing.T) { from, _ := deriveSender(tx) pool.currentState.AddBalance(from, big.NewInt(1)) - if err := pool.AddRemote(tx); err != ErrInsufficientFunds { + if err := pool.AddRemote(tx); !errors.Is(err, ErrInsufficientFunds) { t.Error("expected", ErrInsufficientFunds) } @@ -255,7 +256,7 @@ func TestInvalidTransactions(t *testing.T) { pool.currentState.SetNonce(from, 1) pool.currentState.AddBalance(from, big.NewInt(0xffffffffffffff)) tx = transaction(0, 100000, key) - if err := pool.AddRemote(tx); err != ErrNonceTooLow { + if err := pool.AddRemote(tx); !errors.Is(err, ErrNonceTooLow) { t.Error("expected", ErrNonceTooLow) } @@ -280,7 +281,7 @@ func TestTransactionQueue(t *testing.T) { pool.currentState.AddBalance(from, big.NewInt(1000)) <-pool.requestReset(nil, nil) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) <-pool.requestPromoteExecutables(newAccountSet(pool.signer, from)) if len(pool.pending) != 1 { t.Error("expected valid txs to be 1 is", len(pool.pending)) @@ -289,7 +290,7 @@ func TestTransactionQueue(t *testing.T) { tx = transaction(1, 100, key) from, _ = deriveSender(tx) pool.currentState.SetNonce(from, 2) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) <-pool.requestPromoteExecutables(newAccountSet(pool.signer, from)) if _, ok := pool.pending[from].txs.items[tx.Nonce()]; ok { @@ -313,9 +314,9 @@ func TestTransactionQueue2(t *testing.T) { pool.currentState.AddBalance(from, big.NewInt(1000)) pool.reset(nil, nil) - pool.enqueueTx(tx1.Hash(), tx1) - pool.enqueueTx(tx2.Hash(), tx2) - pool.enqueueTx(tx3.Hash(), tx3) + pool.enqueueTx(tx1.Hash(), tx1, false, true) + pool.enqueueTx(tx2.Hash(), tx2, false, true) + pool.enqueueTx(tx3.Hash(), tx3, false, true) pool.promoteExecutables([]common.Address{from}) if len(pool.pending) != 1 { @@ -488,12 +489,21 @@ func TestTransactionDropping(t *testing.T) { tx11 = transaction(11, 200, key) tx12 = transaction(12, 300, key) ) + pool.all.Add(tx0, false) + pool.priced.Put(tx0, false) pool.promoteTx(account, tx0.Hash(), tx0) + + pool.all.Add(tx1, false) + pool.priced.Put(tx1, false) pool.promoteTx(account, tx1.Hash(), tx1) + + pool.all.Add(tx2, false) + pool.priced.Put(tx2, false) pool.promoteTx(account, tx2.Hash(), tx2) - pool.enqueueTx(tx10.Hash(), tx10) - pool.enqueueTx(tx11.Hash(), tx11) - pool.enqueueTx(tx12.Hash(), tx12) + + pool.enqueueTx(tx10.Hash(), tx10, false, true) + pool.enqueueTx(tx11.Hash(), tx11, false, true) + pool.enqueueTx(tx12.Hash(), tx12, false, true) // Check that pre and post validations leave the pool as is if pool.pending[account].Len() != 3 
{ @@ -1139,7 +1149,7 @@ func TestTransactionAllowedTxSize(t *testing.T) { t.Fatalf("expected rejection on slightly oversize transaction") } // Try adding a transaction of random not allowed size - if err := pool.addRemoteSync(pricedDataTransaction(2, pool.currentMaxGas, big.NewInt(1), key, dataSize+1+uint64(rand.Intn(int(10*txMaxSize))))); err == nil { + if err := pool.addRemoteSync(pricedDataTransaction(2, pool.currentMaxGas, big.NewInt(1), key, dataSize+1+uint64(rand.Intn(10*txMaxSize)))); err == nil { t.Fatalf("expected rejection on oversize transaction") } // Run some sanity checks on the pool internals @@ -1964,7 +1974,7 @@ func benchmarkFuturePromotion(b *testing.B, size int) { for i := 0; i < size; i++ { tx := transaction(uint64(1+i), 100000, key) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) } // Benchmark the speed of pool validation b.ResetTimer() @@ -2007,3 +2017,38 @@ func benchmarkPoolBatchInsert(b *testing.B, size int, local bool) { } } } + +func BenchmarkInsertRemoteWithAllLocals(b *testing.B) { + // Allocate keys for testing + key, _ := crypto.GenerateKey() + account := crypto.PubkeyToAddress(key.PublicKey) + + remoteKey, _ := crypto.GenerateKey() + remoteAddr := crypto.PubkeyToAddress(remoteKey.PublicKey) + + locals := make([]*types.Transaction, 4096+1024) // Occupy all slots + for i := 0; i < len(locals); i++ { + locals[i] = transaction(uint64(i), 100000, key) + } + remotes := make([]*types.Transaction, 1000) + for i := 0; i < len(remotes); i++ { + remotes[i] = pricedTransaction(uint64(i), 100000, big.NewInt(2), remoteKey) // Higher gasprice + } + // Benchmark importing the transactions into the queue + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + pool, _ := setupTxPool() + pool.currentState.AddBalance(account, big.NewInt(100000000)) + for _, local := range locals { + pool.AddLocal(local) + } + b.StartTimer() + // Assign a high enough balance for testing + pool.currentState.AddBalance(remoteAddr, big.NewInt(100000000)) + for i := 0; i < len(remotes); i++ { + pool.AddRemotes([]*types.Transaction{remotes[i]}) + } + pool.Stop() + } +} diff --git a/core/types/bloom9.go b/core/types/bloom9.go index d045c9e6678e..1793c2adc73c 100644 --- a/core/types/bloom9.go +++ b/core/types/bloom9.go @@ -17,6 +17,7 @@ package types import ( + "encoding/binary" "fmt" "math/big" @@ -57,28 +58,36 @@ func (b *Bloom) SetBytes(d []byte) { } // Add adds d to the filter. Future calls of Test(d) will return true. -func (b *Bloom) Add(d *big.Int) { - bin := new(big.Int).SetBytes(b[:]) - bin.Or(bin, bloom9(d.Bytes())) - b.SetBytes(bin.Bytes()) +func (b *Bloom) Add(d []byte) { + b.add(d, make([]byte, 6)) +} + +// add is internal version of Add, which takes a scratch buffer for reuse (needs to be at least 6 bytes) +func (b *Bloom) add(d []byte, buf []byte) { + i1, v1, i2, v2, i3, v3 := bloomValues(d, buf) + b[i1] |= v1 + b[i2] |= v2 + b[i3] |= v3 } // Big converts b to a big integer. 
+// Note: Converting a bloom filter to a big.Int and then calling GetBytes +// does not return the same bytes, since big.Int will trim leading zeroes func (b Bloom) Big() *big.Int { return new(big.Int).SetBytes(b[:]) } +// Bytes returns the backing byte slice of the bloom func (b Bloom) Bytes() []byte { return b[:] } -func (b Bloom) Test(test *big.Int) bool { - return BloomLookup(b, test) -} - -func (b Bloom) TestBytes(test []byte) bool { - return b.Test(new(big.Int).SetBytes(test)) - +// Test checks if the given topic is present in the bloom filter +func (b Bloom) Test(topic []byte) bool { + i1, v1, i2, v2, i3, v3 := bloomValues(topic, make([]byte, 6)) + return v1 == v1&b[i1] && + v2 == v2&b[i2] && + v3 == v3&b[i3] } // MarshalText encodes b as a hex string with 0x prefix. @@ -91,46 +100,61 @@ func (b *Bloom) UnmarshalText(input []byte) error { return hexutil.UnmarshalFixedText("Bloom", input, b[:]) } +// CreateBloom creates a bloom filter out of the give Receipts (+Logs) func CreateBloom(receipts Receipts) Bloom { - bin := new(big.Int) + buf := make([]byte, 6) + var bin Bloom for _, receipt := range receipts { - bin.Or(bin, LogsBloom(receipt.Logs)) + for _, log := range receipt.Logs { + bin.add(log.Address.Bytes(), buf) + for _, b := range log.Topics { + bin.add(b[:], buf) + } + } } - - return BytesToBloom(bin.Bytes()) + return bin } -func LogsBloom(logs []*Log) *big.Int { - bin := new(big.Int) +// LogsBloom returns the bloom bytes for the given logs +func LogsBloom(logs []*Log) []byte { + buf := make([]byte, 6) + var bin Bloom for _, log := range logs { - bin.Or(bin, bloom9(log.Address.Bytes())) + bin.add(log.Address.Bytes(), buf) for _, b := range log.Topics { - bin.Or(bin, bloom9(b[:])) + bin.add(b[:], buf) } } - - return bin + return bin[:] } -func bloom9(b []byte) *big.Int { - b = crypto.Keccak256(b) - - r := new(big.Int) - - for i := 0; i < 6; i += 2 { - t := big.NewInt(1) - b := (uint(b[i+1]) + (uint(b[i]) << 8)) & 2047 - r.Or(r, t.Lsh(t, b)) - } - - return r +// Bloom9 returns the bloom filter for the given data +func Bloom9(data []byte) []byte { + var b Bloom + b.SetBytes(data) + return b.Bytes() } -var Bloom9 = bloom9 +// bloomValues returns the bytes (index-value pairs) to set for the given data +func bloomValues(data []byte, hashbuf []byte) (uint, byte, uint, byte, uint, byte) { + sha := hasherPool.Get().(crypto.KeccakState) + sha.Reset() + sha.Write(data) + sha.Read(hashbuf) + hasherPool.Put(sha) + // The actual bits to flip + v1 := byte(1 << (hashbuf[1] & 0x7)) + v2 := byte(1 << (hashbuf[3] & 0x7)) + v3 := byte(1 << (hashbuf[5] & 0x7)) + // The indices for the bytes to OR in + i1 := BloomByteLength - uint((binary.BigEndian.Uint16(hashbuf)&0x7ff)>>3) - 1 + i2 := BloomByteLength - uint((binary.BigEndian.Uint16(hashbuf[2:])&0x7ff)>>3) - 1 + i3 := BloomByteLength - uint((binary.BigEndian.Uint16(hashbuf[4:])&0x7ff)>>3) - 1 + + return i1, v1, i2, v2, i3, v3 +} +// BloomLookup is a convenience-method to check presence int he bloom filter func BloomLookup(bin Bloom, topic bytesBacked) bool { - bloom := bin.Big() - cmp := bloom9(topic.Bytes()) - - return bloom.And(bloom, cmp).Cmp(cmp) == 0 + return bin.Test(topic.Bytes()) } diff --git a/core/types/bloom9_test.go b/core/types/bloom9_test.go index a28ac0e7afba..893df486dd1b 100644 --- a/core/types/bloom9_test.go +++ b/core/types/bloom9_test.go @@ -17,8 +17,12 @@ package types import ( + "fmt" "math/big" "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" ) func TestBloom(t *testing.T) { @@ 
-35,47 +39,118 @@ func TestBloom(t *testing.T) { var bloom Bloom for _, data := range positive { - bloom.Add(new(big.Int).SetBytes([]byte(data))) + bloom.Add([]byte(data)) } for _, data := range positive { - if !bloom.TestBytes([]byte(data)) { + if !bloom.Test([]byte(data)) { t.Error("expected", data, "to test true") } } for _, data := range negative { - if bloom.TestBytes([]byte(data)) { + if bloom.Test([]byte(data)) { t.Error("did not expect", data, "to test true") } } } -/* -import ( - "testing" - - "github.com/ethereum/go-ethereum/core/state" -) +// TestBloomExtensively does some more thorough tests +func TestBloomExtensively(t *testing.T) { + var exp = common.HexToHash("c8d3ca65cdb4874300a9e39475508f23ed6da09fdbc487f89a2dcf50b09eb263") + var b Bloom + // Add 100 "random" things + for i := 0; i < 100; i++ { + data := fmt.Sprintf("xxxxxxxxxx data %d yyyyyyyyyyyyyy", i) + b.Add([]byte(data)) + //b.Add(new(big.Int).SetBytes([]byte(data))) + } + got := crypto.Keccak256Hash(b.Bytes()) + if got != exp { + t.Errorf("Got %x, exp %x", got, exp) + } + var b2 Bloom + b2.SetBytes(b.Bytes()) + got2 := crypto.Keccak256Hash(b2.Bytes()) + if got != got2 { + t.Errorf("Got %x, exp %x", got, got2) + } +} -func TestBloom9(t *testing.T) { - testCase := []byte("testtest") - bin := LogsBloom([]state.Log{ - {testCase, [][]byte{[]byte("hellohello")}, nil}, - }).Bytes() - res := BloomLookup(bin, testCase) +func BenchmarkBloom9(b *testing.B) { + test := []byte("testestestest") + for i := 0; i < b.N; i++ { + Bloom9(test) + } +} - if !res { - t.Errorf("Bloom lookup failed") +func BenchmarkBloom9Lookup(b *testing.B) { + toTest := []byte("testtest") + bloom := new(Bloom) + for i := 0; i < b.N; i++ { + bloom.Test(toTest) } } +func BenchmarkCreateBloom(b *testing.B) { -func TestAddress(t *testing.T) { - block := &Block{} - block.Coinbase = common.Hex2Bytes("22341ae42d6dd7384bc8584e50419ea3ac75b83f") - fmt.Printf("%x\n", crypto.Keccak256(block.Coinbase)) + var txs = Transactions{ + NewContractCreation(1, big.NewInt(1), 1, big.NewInt(1), nil), + NewTransaction(2, common.HexToAddress("0x2"), big.NewInt(2), 2, big.NewInt(2), nil), + } + var rSmall = Receipts{ + &Receipt{ + Status: ReceiptStatusFailed, + CumulativeGasUsed: 1, + Logs: []*Log{ + {Address: common.BytesToAddress([]byte{0x11})}, + {Address: common.BytesToAddress([]byte{0x01, 0x11})}, + }, + TxHash: txs[0].Hash(), + ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), + GasUsed: 1, + }, + &Receipt{ + PostState: common.Hash{2}.Bytes(), + CumulativeGasUsed: 3, + Logs: []*Log{ + {Address: common.BytesToAddress([]byte{0x22})}, + {Address: common.BytesToAddress([]byte{0x02, 0x22})}, + }, + TxHash: txs[1].Hash(), + ContractAddress: common.BytesToAddress([]byte{0x02, 0x22, 0x22}), + GasUsed: 2, + }, + } - bin := CreateBloom(block) - fmt.Printf("bin = %x\n", common.LeftPadBytes(bin, 64)) + var rLarge = make(Receipts, 200) + // Fill it with 200 receipts x 2 logs + for i := 0; i < 200; i += 2 { + copy(rLarge[i:], rSmall) + } + b.Run("small", func(b *testing.B) { + b.ReportAllocs() + var bl Bloom + for i := 0; i < b.N; i++ { + bl = CreateBloom(rSmall) + } + b.StopTimer() + var exp = common.HexToHash("c384c56ece49458a427c67b90fefe979ebf7104795be65dc398b280f24104949") + got := crypto.Keccak256Hash(bl.Bytes()) + if got != exp { + b.Errorf("Got %x, exp %x", got, exp) + } + }) + b.Run("large", func(b *testing.B) { + b.ReportAllocs() + var bl Bloom + for i := 0; i < b.N; i++ { + bl = CreateBloom(rLarge) + } + b.StopTimer() + var exp = 
common.HexToHash("c384c56ece49458a427c67b90fefe979ebf7104795be65dc398b280f24104949") + got := crypto.Keccak256Hash(bl.Bytes()) + if got != exp { + b.Errorf("Got %x, exp %x", got, exp) + } + }) } -*/ diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index 7d40c7f66085..51a10f3f3da7 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -17,13 +17,10 @@ package types import ( - "bytes" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" ) -// DerivableList is the interface which can derive the hash. type DerivableList interface { Len() int GetRlp(i int) []byte @@ -38,11 +35,24 @@ type Hasher interface { func DeriveSha(list DerivableList, hasher Hasher) common.Hash { hasher.Reset() - keybuf := new(bytes.Buffer) - for i := 0; i < list.Len(); i++ { - keybuf.Reset() - rlp.Encode(keybuf, uint(i)) - hasher.Update(keybuf.Bytes(), list.GetRlp(i)) + + // StackTrie requires values to be inserted in increasing + // hash order, which is not the order that `list` provides + // hashes in. This insertion sequence ensures that the + // order is correct. + + var buf []byte + for i := 1; i < list.Len() && i <= 0x7f; i++ { + buf = rlp.AppendUint64(buf[:0], uint64(i)) + hasher.Update(buf, list.GetRlp(i)) + } + if list.Len() > 0 { + buf = rlp.AppendUint64(buf[:0], 0) + hasher.Update(buf, list.GetRlp(0)) + } + for i := 0x80; i < list.Len(); i++ { + buf = rlp.AppendUint64(buf[:0], uint64(i)) + hasher.Update(buf, list.GetRlp(i)) } return hasher.Hash() } diff --git a/core/types/transaction.go b/core/types/transaction.go index ec19744881cc..177bfbb70b55 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -213,7 +213,7 @@ func (tx *Transaction) Hash() common.Hash { } // Size returns the true RLP encoded storage size of the transaction, either by -// encoding and returning it, or returning a previsouly cached value. +// encoding and returning it, or returning a previously cached value. func (tx *Transaction) Size() common.StorageSize { if size := tx.size.Load(); size != nil { return size.(common.StorageSize) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 8930a06266b8..0c828b1cd9c2 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -58,7 +58,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{2}): &sha256hash{}, common.BytesToAddress([]byte{3}): &ripemd160hash{}, common.BytesToAddress([]byte{4}): &dataCopy{}, - common.BytesToAddress([]byte{5}): &bigModExp{}, + common.BytesToAddress([]byte{5}): &bigModExp{eip2565: false}, common.BytesToAddress([]byte{6}): &bn256AddByzantium{}, common.BytesToAddress([]byte{7}): &bn256ScalarMulByzantium{}, common.BytesToAddress([]byte{8}): &bn256PairingByzantium{}, @@ -71,21 +71,21 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{2}): &sha256hash{}, common.BytesToAddress([]byte{3}): &ripemd160hash{}, common.BytesToAddress([]byte{4}): &dataCopy{}, - common.BytesToAddress([]byte{5}): &bigModExp{}, + common.BytesToAddress([]byte{5}): &bigModExp{eip2565: false}, common.BytesToAddress([]byte{6}): &bn256AddIstanbul{}, common.BytesToAddress([]byte{7}): &bn256ScalarMulIstanbul{}, common.BytesToAddress([]byte{8}): &bn256PairingIstanbul{}, common.BytesToAddress([]byte{9}): &blake2F{}, } -// PrecompiledContractsYoloV1 contains the default set of pre-compiled Ethereum -// contracts used in the Yolo v1 test release. 
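Stepping back to the `DeriveSha` change above: the three loops look odd until you consider how RLP encodes the trie keys. Indices 1..0x7f encode as the single bytes `0x01`..`0x7f`, index 0 encodes as `0x80`, and indices ≥ 0x80 get a length prefix (`0x81 0x80`, ...), so feeding 1..0x7f first, then 0, then 0x80.. hands the StackTrie its keys in strictly increasing byte order. A tiny demonstration using the same `rlp.AppendUint64` helper the diff uses:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/rlp"
)

func main() {
	// Byte-lexicographically the RLP keys order as 1..127, then 0, then 128..,
	// which is exactly the insertion sequence DeriveSha now follows.
	for _, i := range []uint64{0, 1, 127, 128} {
		fmt.Printf("index %3d -> key %x\n", i, rlp.AppendUint64(nil, i))
	}
	// index   0 -> key 80
	// index   1 -> key 01
	// index 127 -> key 7f
	// index 128 -> key 8180
}
```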
-var PrecompiledContractsYoloV1 = map[common.Address]PrecompiledContract{
+var PrecompiledContractsYoloV2 = map[common.Address]PrecompiledContract{
 	common.BytesToAddress([]byte{1}): &ecrecover{},
 	common.BytesToAddress([]byte{2}): &sha256hash{},
 	common.BytesToAddress([]byte{3}): &ripemd160hash{},
 	common.BytesToAddress([]byte{4}): &dataCopy{},
-	common.BytesToAddress([]byte{5}): &bigModExp{},
+	common.BytesToAddress([]byte{5}): &bigModExp{eip2565: false},
 	common.BytesToAddress([]byte{6}): &bn256AddIstanbul{},
 	common.BytesToAddress([]byte{7}): &bn256ScalarMulIstanbul{},
 	common.BytesToAddress([]byte{8}): &bn256PairingIstanbul{},
@@ -101,6 +101,28 @@ var PrecompiledContractsYoloV1 = map[common.Address]PrecompiledContract{
 	common.BytesToAddress([]byte{18}): &bls12381MapG2{},
 }

+var (
+	PrecompiledAddressesYoloV2    []common.Address
+	PrecompiledAddressesIstanbul  []common.Address
+	PrecompiledAddressesByzantium []common.Address
+	PrecompiledAddressesHomestead []common.Address
+)
+
+func init() {
+	for k := range PrecompiledContractsHomestead {
+		PrecompiledAddressesHomestead = append(PrecompiledAddressesHomestead, k)
+	}
+	for k := range PrecompiledContractsByzantium {
+		PrecompiledAddressesByzantium = append(PrecompiledAddressesByzantium, k)
+	}
+	for k := range PrecompiledContractsIstanbul {
+		PrecompiledAddressesIstanbul = append(PrecompiledAddressesIstanbul, k)
+	}
+	for k := range PrecompiledContractsYoloV2 {
+		PrecompiledAddressesYoloV2 = append(PrecompiledAddressesYoloV2, k)
+	}
+}
+
 // RunPrecompiledContract runs and evaluates the output of a precompiled contract.
 // It returns
 // - the returned bytes,
@@ -200,14 +222,19 @@ func (c *dataCopy) Run(in []byte) ([]byte, error) {
 }

 // bigModExp implements a native big integer exponential modular operation.
-type bigModExp struct{}
+type bigModExp struct {
+	eip2565 bool
+}

 var (
 	big0      = big.NewInt(0)
 	big1      = big.NewInt(1)
+	big3      = big.NewInt(3)
 	big4      = big.NewInt(4)
+	big7      = big.NewInt(7)
 	big8      = big.NewInt(8)
 	big16     = big.NewInt(16)
+	big20     = big.NewInt(20)
 	big32     = big.NewInt(32)
 	big64     = big.NewInt(64)
 	big96     = big.NewInt(96)
@@ -217,6 +244,34 @@ var (
 	big199680 = big.NewInt(199680)
 )

+// modexpMultComplexity implements bigModExp's multComplexity formula, as defined in EIP-198
+//
+// def mult_complexity(x):
+//    if x <= 64: return x ** 2
+//    elif x <= 1024: return x ** 2 // 4 + 96 * x - 3072
+//    else: return x ** 2 // 16 + 480 * x - 199680
+//
+// where x is max(length_of_MODULUS, length_of_BASE)
+func modexpMultComplexity(x *big.Int) *big.Int {
+	switch {
+	case x.Cmp(big64) <= 0:
+		x.Mul(x, x) // x ** 2
+	case x.Cmp(big1024) <= 0:
+		// (x ** 2 // 4 ) + ( 96 * x - 3072)
+		x = new(big.Int).Add(
+			new(big.Int).Div(new(big.Int).Mul(x, x), big4),
+			new(big.Int).Sub(new(big.Int).Mul(big96, x), big3072),
+		)
+	default:
+		// (x ** 2 // 16) + (480 * x - 199680)
+		x = new(big.Int).Add(
+			new(big.Int).Div(new(big.Int).Mul(x, x), big16),
+			new(big.Int).Sub(new(big.Int).Mul(big480, x), big199680),
+		)
+	}
+	return x
+}
+
 // RequiredGas returns the gas required to execute the pre-compiled contract. 
 // RequiredGas returns the gas required to execute the pre-compiled contract.
 func (c *bigModExp) RequiredGas(input []byte) uint64 {
 	var (
@@ -251,25 +306,36 @@ func (c *bigModExp) RequiredGas(input []byte) uint64 {
 		adjExpLen.Mul(big8, adjExpLen)
 	}
 	adjExpLen.Add(adjExpLen, big.NewInt(int64(msb)))
-	// Calculate the gas cost of the operation
 	gas := new(big.Int).Set(math.BigMax(modLen, baseLen))
-	switch {
-	case gas.Cmp(big64) <= 0:
+	if c.eip2565 {
+		// EIP-2565 (https://eips.ethereum.org/EIPS/eip-2565) has three changes:
+		// 1. Different multComplexity (inlined here):
+		//
+		// def mult_complexity(x):
+		//    ceiling(x/8)^2
+		//
+		// where x is max(length_of_MODULUS, length_of_BASE)
+		gas = gas.Add(gas, big7)
+		gas = gas.Div(gas, big8)
 		gas.Mul(gas, gas)
-	case gas.Cmp(big1024) <= 0:
-		gas = new(big.Int).Add(
-			new(big.Int).Div(new(big.Int).Mul(gas, gas), big4),
-			new(big.Int).Sub(new(big.Int).Mul(big96, gas), big3072),
-		)
-	default:
-		gas = new(big.Int).Add(
-			new(big.Int).Div(new(big.Int).Mul(gas, gas), big16),
-			new(big.Int).Sub(new(big.Int).Mul(big480, gas), big199680),
-		)
+
+		gas.Mul(gas, math.BigMax(adjExpLen, big1))
+		// 2. Different divisor (`GQUADDIVISOR`) (3)
+		gas.Div(gas, big3)
+		if gas.BitLen() > 64 {
+			return math.MaxUint64
+		}
+		// 3. Minimum price of 200 gas
+		if gas.Uint64() < 200 {
+			return 200
+		}
+		return gas.Uint64()
 	}
+	gas = modexpMultComplexity(gas)
 	gas.Mul(gas, math.BigMax(adjExpLen, big1))
-	gas.Div(gas, new(big.Int).SetUint64(params.ModExpQuadCoeffDiv))
+	gas.Div(gas, big20)
 
 	if gas.BitLen() > 64 {
 		return math.MaxUint64
diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go
index 6320875e1aeb..30d9b49f719b 100644
--- a/core/vm/contracts_test.go
+++ b/core/vm/contracts_test.go
@@ -43,7 +43,29 @@ type precompiledFailureTest struct {
 	Name string
 }
 
-var allPrecompiles = PrecompiledContractsYoloV1
+// allPrecompiles does not map to the actual set of precompiles, as it also contains
+// repriced versions of precompiles at certain slots
+var allPrecompiles = map[common.Address]PrecompiledContract{
+	common.BytesToAddress([]byte{1}):    &ecrecover{},
+	common.BytesToAddress([]byte{2}):    &sha256hash{},
+	common.BytesToAddress([]byte{3}):    &ripemd160hash{},
+	common.BytesToAddress([]byte{4}):    &dataCopy{},
+	common.BytesToAddress([]byte{5}):    &bigModExp{eip2565: false},
+	common.BytesToAddress([]byte{0xf5}): &bigModExp{eip2565: true},
+	common.BytesToAddress([]byte{6}):    &bn256AddIstanbul{},
+	common.BytesToAddress([]byte{7}):    &bn256ScalarMulIstanbul{},
+	common.BytesToAddress([]byte{8}):    &bn256PairingIstanbul{},
+	common.BytesToAddress([]byte{9}):    &blake2F{},
+	common.BytesToAddress([]byte{10}):   &bls12381G1Add{},
+	common.BytesToAddress([]byte{11}):   &bls12381G1Mul{},
+	common.BytesToAddress([]byte{12}):   &bls12381G1MultiExp{},
+	common.BytesToAddress([]byte{13}):   &bls12381G2Add{},
+	common.BytesToAddress([]byte{14}):   &bls12381G2Mul{},
+	common.BytesToAddress([]byte{15}):   &bls12381G2MultiExp{},
+	common.BytesToAddress([]byte{16}):   &bls12381Pairing{},
+	common.BytesToAddress([]byte{17}):   &bls12381MapG1{},
+	common.BytesToAddress([]byte{18}):   &bls12381MapG2{},
+}
 
 // EIP-152 test vectors
 var blake2FMalformedInputTests = []precompiledFailureTest{
@@ -213,6 +235,9 @@ func BenchmarkPrecompiledIdentity(bench *testing.B) {
 func TestPrecompiledModExp(t *testing.T)      { testJson("modexp", "05", t) }
 func BenchmarkPrecompiledModExp(b *testing.B) { benchJson("modexp", "05", b) }
 
+func TestPrecompiledModExpEip2565(t *testing.T)      { testJson("modexp_eip2565", "f5", t) }
+func BenchmarkPrecompiledModExpEip2565(b *testing.B) {
benchJson("modexp_eip2565", "f5", b) } + // Tests the sample inputs from the elliptic curve addition EIP 213. func TestPrecompiledBn256Add(t *testing.T) { testJson("bn256Add", "06", t) } func BenchmarkPrecompiledBn256Add(b *testing.B) { benchJson("bn256Add", "06", b) } diff --git a/core/vm/eips.go b/core/vm/eips.go index 6b5ba62aade1..962c0f14b162 100644 --- a/core/vm/eips.go +++ b/core/vm/eips.go @@ -25,6 +25,7 @@ import ( ) var activators = map[int]func(*JumpTable){ + 2929: enable2929, 2200: enable2200, 1884: enable1884, 1344: enable1344, @@ -134,3 +135,41 @@ func enable2315(jt *JumpTable) { jumps: true, } } + +// enable2929 enables "EIP-2929: Gas cost increases for state access opcodes" +// https://eips.ethereum.org/EIPS/eip-2929 +func enable2929(jt *JumpTable) { + jt[SSTORE].dynamicGas = gasSStoreEIP2929 + + jt[SLOAD].constantGas = 0 + jt[SLOAD].dynamicGas = gasSLoadEIP2929 + + jt[EXTCODECOPY].constantGas = WarmStorageReadCostEIP2929 + jt[EXTCODECOPY].dynamicGas = gasExtCodeCopyEIP2929 + + jt[EXTCODESIZE].constantGas = WarmStorageReadCostEIP2929 + jt[EXTCODESIZE].dynamicGas = gasEip2929AccountCheck + + jt[EXTCODEHASH].constantGas = WarmStorageReadCostEIP2929 + jt[EXTCODEHASH].dynamicGas = gasEip2929AccountCheck + + jt[BALANCE].constantGas = WarmStorageReadCostEIP2929 + jt[BALANCE].dynamicGas = gasEip2929AccountCheck + + jt[CALL].constantGas = WarmStorageReadCostEIP2929 + jt[CALL].dynamicGas = gasCallEIP2929 + + jt[CALLCODE].constantGas = WarmStorageReadCostEIP2929 + jt[CALLCODE].dynamicGas = gasCallCodeEIP2929 + + jt[STATICCALL].constantGas = WarmStorageReadCostEIP2929 + jt[STATICCALL].dynamicGas = gasStaticCallEIP2929 + + jt[DELEGATECALL].constantGas = WarmStorageReadCostEIP2929 + jt[DELEGATECALL].dynamicGas = gasDelegateCallEIP2929 + + // This was previously part of the dynamic cost, but we're using it as a constantGas + // factor here + jt[SELFDESTRUCT].constantGas = params.SelfdestructGasEIP150 + jt[SELFDESTRUCT].dynamicGas = gasSelfdestructEIP2929 +} diff --git a/core/vm/evm.go b/core/vm/evm.go index f5469c500c8a..9afef7664366 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -42,11 +42,26 @@ type ( GetHashFunc func(uint64) common.Hash ) +// ActivePrecompiles returns the addresses of the precompiles enabled with the current +// configuration +func (evm *EVM) ActivePrecompiles() []common.Address { + switch { + case evm.chainRules.IsYoloV2: + return PrecompiledAddressesYoloV2 + case evm.chainRules.IsIstanbul: + return PrecompiledAddressesIstanbul + case evm.chainRules.IsByzantium: + return PrecompiledAddressesByzantium + default: + return PrecompiledAddressesHomestead + } +} + func (evm *EVM) precompile(addr common.Address) (PrecompiledContract, bool) { var precompiles map[common.Address]PrecompiledContract switch { - case evm.chainRules.IsYoloV1: - precompiles = PrecompiledContractsYoloV1 + case evm.chainRules.IsYoloV2: + precompiles = PrecompiledContractsYoloV2 case evm.chainRules.IsIstanbul: precompiles = PrecompiledContractsIstanbul case evm.chainRules.IsByzantium: @@ -76,9 +91,9 @@ func run(evm *EVM, contract *Contract, input []byte, readOnly bool) ([]byte, err return nil, errors.New("no compatible interpreter") } -// Context provides the EVM with auxiliary information. Once provided +// BlockContext provides the EVM with auxiliary information. Once provided // it shouldn't be modified. 
-type Context struct {
+type BlockContext struct {
 	// CanTransfer returns whether the account contains
 	// sufficient ether to transfer the value
 	CanTransfer CanTransferFunc
@@ -87,10 +102,6 @@ type Context struct {
 	// GetHash returns the hash corresponding to n
 	GetHash GetHashFunc
 
-	// Message information
-	Origin   common.Address // Provides information for ORIGIN
-	GasPrice *big.Int       // Provides information for GASPRICE
-
 	// Block information
 	Coinbase    common.Address // Provides information for COINBASE
 	GasLimit    uint64         // Provides information for GASLIMIT
@@ -99,6 +110,14 @@ type Context struct {
 	Difficulty *big.Int // Provides information for DIFFICULTY
 }
 
+// TxContext provides the EVM with information about a transaction.
+// All fields can change between transactions.
+type TxContext struct {
+	// Message information
+	Origin   common.Address // Provides information for ORIGIN
+	GasPrice *big.Int       // Provides information for GASPRICE
+}
+
 // EVM is the Ethereum Virtual Machine base object and provides
 // the necessary tools to run a contract on the given state with
 // the provided context. It should be noted that any error
@@ -110,7 +129,8 @@ type Context struct {
 // The EVM should never be reused and is not thread safe.
 type EVM struct {
 	// Context provides auxiliary blockchain related information
-	Context
+	Context BlockContext
+	TxContext
 	// StateDB gives access to the underlying state
 	StateDB StateDB
 	// Depth is the current call stack
@@ -138,17 +158,18 @@ type EVM struct {
 
 // NewEVM returns a new EVM. The returned EVM is not thread safe and should
 // only ever be used *once*.
-func NewEVM(ctx Context, statedb StateDB, chainConfig *params.ChainConfig, vmConfig Config) *EVM {
+func NewEVM(blockCtx BlockContext, txCtx TxContext, statedb StateDB, chainConfig *params.ChainConfig, vmConfig Config) *EVM {
 	evm := &EVM{
-		Context:      ctx,
+		Context:      blockCtx,
+		TxContext:    txCtx,
 		StateDB:      statedb,
 		vmConfig:     vmConfig,
 		chainConfig:  chainConfig,
-		chainRules:   chainConfig.Rules(ctx.BlockNumber),
+		chainRules:   chainConfig.Rules(blockCtx.BlockNumber),
 		interpreters: make([]Interpreter, 0, 1),
 	}
 
-	if chainConfig.IsEWASM(ctx.BlockNumber) {
+	if chainConfig.IsEWASM(blockCtx.BlockNumber) {
 		// to be implemented by EVM-C and Wagon PRs.
 		// if vmConfig.EWASMInterpreter != "" {
 		//  extIntOpts := strings.Split(vmConfig.EWASMInterpreter, ":")
@@ -172,6 +193,13 @@ func NewEVM(ctx Context, statedb StateDB, chainConfig *params.ChainConfig, vmCon
 	return evm
 }
 
+// Reset resets the EVM with a new transaction context.
+// This is not thread safe and should only be done very cautiously.
+func (evm *EVM) Reset(txCtx TxContext, statedb StateDB) {
+	evm.TxContext = txCtx
+	evm.StateDB = statedb
+}
+
 // Cancel cancels any running EVM operation. This may be called concurrently and
 // it's safe to be called multiple times.
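With the split above, block-scoped data lives in BlockContext while per-transaction data lives in TxContext, so a block processor can construct one EVM per block and just swap transaction contexts. A rough usage sketch (blockCtx, statedb, chainConfig, txs and senderOf are placeholders, not APIs introduced by this change):

evm := vm.NewEVM(blockCtx, vm.TxContext{}, statedb, chainConfig, vm.Config{})
for _, tx := range txs {
	// Re-point the EVM at the next transaction instead of rebuilding it.
	evm.Reset(vm.TxContext{Origin: senderOf(tx), GasPrice: tx.GasPrice()}, statedb)
	// ... execute the message via evm.Call / evm.Create ...
}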
func (evm *EVM) Cancel() { @@ -218,7 +246,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas } evm.StateDB.CreateAccount(addr) } - evm.Transfer(evm.StateDB, caller.Address(), addr, value) + evm.Context.Transfer(evm.StateDB, caller.Address(), addr, value) // Capture the tracer start/end events in debug mode if evm.vmConfig.Debug && evm.depth == 0 { @@ -411,12 +439,16 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, if evm.depth > int(params.CallCreateDepth) { return nil, common.Address{}, gas, ErrDepth } - if !evm.CanTransfer(evm.StateDB, caller.Address(), value) { + if !evm.Context.CanTransfer(evm.StateDB, caller.Address(), value) { return nil, common.Address{}, gas, ErrInsufficientBalance } nonce := evm.StateDB.GetNonce(caller.Address()) evm.StateDB.SetNonce(caller.Address(), nonce+1) - + // We add this to the access list _before_ taking a snapshot. Even if the creation fails, + // the access-list change should not be rolled back + if evm.chainRules.IsYoloV2 { + evm.StateDB.AddAddressToAccessList(address) + } // Ensure there's no existing contract already at the designated address contractHash := evm.StateDB.GetCodeHash(address) if evm.StateDB.GetNonce(address) != 0 || (contractHash != (common.Hash{}) && contractHash != emptyCodeHash) { @@ -428,7 +460,7 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, if evm.chainRules.IsEIP158 { evm.StateDB.SetNonce(address, 1) } - evm.Transfer(evm.StateDB, caller.Address(), address, value) + evm.Context.Transfer(evm.StateDB, caller.Address(), address, value) // Initialise a new contract and set the code that is to be used by the EVM. // The contract is a scoped environment for this execution context only. @@ -493,7 +525,7 @@ func (evm *EVM) Create(caller ContractRef, code []byte, gas uint64, value *big.I // instead of the usual sender-and-nonce-hash as the address where the contract is initialized at. 
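The ordering called out in create above ("before taking a snapshot") matters because access-list additions are journaled like other state changes: whatever is added after Snapshot() is undone by RevertToSnapshot(), while additions made before the snapshot survive a failed creation. Schematically (db, addr and slot are placeholders for a vm.StateDB and its inputs):

db.AddAddressToAccessList(addr) // survives the revert below

snap := db.Snapshot()
db.AddSlotToAccessList(addr, slot) // journaled after the snapshot...
db.RevertToSnapshot(snap)          // ...so the slot entry is rolled back again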
func (evm *EVM) Create2(caller ContractRef, code []byte, gas uint64, endowment *big.Int, salt *uint256.Int) (ret []byte, contractAddr common.Address, leftOverGas uint64, err error) { codeAndHash := &codeAndHash{code: code} - contractAddr = crypto.CreateAddress2(caller.Address(), common.Hash(salt.Bytes32()), codeAndHash.Hash().Bytes()) + contractAddr = crypto.CreateAddress2(caller.Address(), salt.Bytes32(), codeAndHash.Hash().Bytes()) return evm.create(caller, codeAndHash, gas, endowment, contractAddr) } diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index 6655f9bf42d8..944b6cf0a5df 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -96,7 +96,7 @@ var ( func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { var ( y, x = stack.Back(1), stack.Back(0) - current = evm.StateDB.GetState(contract.Address(), common.Hash(x.Bytes32())) + current = evm.StateDB.GetState(contract.Address(), x.Bytes32()) ) // The legacy gas metering only takes into consideration the current state // Legacy rules should be applied if we are in Petersburg (removal of EIP-1283) @@ -135,7 +135,7 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi if current == value { // noop (1) return params.NetSstoreNoopGas, nil } - original := evm.StateDB.GetCommittedState(contract.Address(), common.Hash(x.Bytes32())) + original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) if original == current { if original == (common.Hash{}) { // create slot (2.1.1) return params.NetSstoreInitGas, nil @@ -163,18 +163,18 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi } // 0. If *gasleft* is less than or equal to 2300, fail the current call. -// 1. If current value equals new value (this is a no-op), SSTORE_NOOP_GAS gas is deducted. +// 1. If current value equals new value (this is a no-op), SLOAD_GAS is deducted. // 2. If current value does not equal new value: // 2.1. If original value equals current value (this storage slot has not been changed by the current execution context): -// 2.1.1. If original value is 0, SSTORE_INIT_GAS gas is deducted. -// 2.1.2. Otherwise, SSTORE_CLEAN_GAS gas is deducted. If new value is 0, add SSTORE_CLEAR_REFUND to refund counter. -// 2.2. If original value does not equal current value (this storage slot is dirty), SSTORE_DIRTY_GAS gas is deducted. Apply both of the following clauses: +// 2.1.1. If original value is 0, SSTORE_SET_GAS (20K) gas is deducted. +// 2.1.2. Otherwise, SSTORE_RESET_GAS gas is deducted. If new value is 0, add SSTORE_CLEARS_SCHEDULE to refund counter. +// 2.2. If original value does not equal current value (this storage slot is dirty), SLOAD_GAS gas is deducted. Apply both of the following clauses: // 2.2.1. If original value is not 0: -// 2.2.1.1. If current value is 0 (also means that new value is not 0), subtract SSTORE_CLEAR_REFUND gas from refund counter. We can prove that refund counter will never go below 0. -// 2.2.1.2. If new value is 0 (also means that current value is not 0), add SSTORE_CLEAR_REFUND gas to refund counter. +// 2.2.1.1. If current value is 0 (also means that new value is not 0), subtract SSTORE_CLEARS_SCHEDULE gas from refund counter. +// 2.2.1.2. If new value is 0 (also means that current value is not 0), add SSTORE_CLEARS_SCHEDULE gas to refund counter. // 2.2.2. If original value equals new value (this storage slot is reset): -// 2.2.2.1. If original value is 0, add SSTORE_INIT_REFUND to refund counter. 
-// 2.2.2.2. Otherwise, add SSTORE_CLEAN_REFUND gas to refund counter. +// 2.2.2.1. If original value is 0, add SSTORE_SET_GAS - SLOAD_GAS to refund counter. +// 2.2.2.2. Otherwise, add SSTORE_RESET_GAS - SLOAD_GAS gas to refund counter. func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { // If we fail the minimum gas availability invariant, fail (0) if contract.Gas <= params.SstoreSentryGasEIP2200 { @@ -183,38 +183,38 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m // Gas sentry honoured, do the actual gas calculation based on the stored value var ( y, x = stack.Back(1), stack.Back(0) - current = evm.StateDB.GetState(contract.Address(), common.Hash(x.Bytes32())) + current = evm.StateDB.GetState(contract.Address(), x.Bytes32()) ) value := common.Hash(y.Bytes32()) if current == value { // noop (1) - return params.SstoreNoopGasEIP2200, nil + return params.SloadGasEIP2200, nil } - original := evm.StateDB.GetCommittedState(contract.Address(), common.Hash(x.Bytes32())) + original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) if original == current { if original == (common.Hash{}) { // create slot (2.1.1) - return params.SstoreInitGasEIP2200, nil + return params.SstoreSetGasEIP2200, nil } if value == (common.Hash{}) { // delete slot (2.1.2b) - evm.StateDB.AddRefund(params.SstoreClearRefundEIP2200) + evm.StateDB.AddRefund(params.SstoreClearsScheduleRefundEIP2200) } - return params.SstoreCleanGasEIP2200, nil // write existing slot (2.1.2) + return params.SstoreResetGasEIP2200, nil // write existing slot (2.1.2) } if original != (common.Hash{}) { if current == (common.Hash{}) { // recreate slot (2.2.1.1) - evm.StateDB.SubRefund(params.SstoreClearRefundEIP2200) + evm.StateDB.SubRefund(params.SstoreClearsScheduleRefundEIP2200) } else if value == (common.Hash{}) { // delete slot (2.2.1.2) - evm.StateDB.AddRefund(params.SstoreClearRefundEIP2200) + evm.StateDB.AddRefund(params.SstoreClearsScheduleRefundEIP2200) } } if original == value { if original == (common.Hash{}) { // reset to original inexistent slot (2.2.2.1) - evm.StateDB.AddRefund(params.SstoreInitRefundEIP2200) + evm.StateDB.AddRefund(params.SstoreSetGasEIP2200 - params.SloadGasEIP2200) } else { // reset to original existing slot (2.2.2.2) - evm.StateDB.AddRefund(params.SstoreCleanRefundEIP2200) + evm.StateDB.AddRefund(params.SstoreResetGasEIP2200 - params.SloadGasEIP2200) } } - return params.SstoreDirtyGasEIP2200, nil // dirty update (2.2) + return params.SloadGasEIP2200, nil // dirty update (2.2) } func makeGasLog(n uint64) gasFunc { diff --git a/core/vm/gas_table_test.go b/core/vm/gas_table_test.go index 419c9030625d..6cd126c9b4ca 100644 --- a/core/vm/gas_table_test.go +++ b/core/vm/gas_table_test.go @@ -87,11 +87,11 @@ func TestEIP2200(t *testing.T) { statedb.SetState(address, common.Hash{}, common.BytesToHash([]byte{tt.original})) statedb.Finalise(true) // Push the state into the "original" slot - vmctx := Context{ + vmctx := BlockContext{ CanTransfer: func(StateDB, common.Address, *big.Int) bool { return true }, Transfer: func(StateDB, common.Address, common.Address, *big.Int) {}, } - vmenv := NewEVM(vmctx, statedb, params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}}) + vmenv := NewEVM(vmctx, TxContext{}, statedb, params.AllEthashProtocolChanges, Config{ExtraEips: []int{2200}}) _, gas, err := vmenv.Call(AccountRef(common.Address{}), address, nil, tt.gaspool, new(big.Int)) if err != tt.failure { diff --git 
a/core/vm/gen_structlog.go b/core/vm/gen_structlog.go index ac1a9070c897..44da014de919 100644 --- a/core/vm/gen_structlog.go +++ b/core/vm/gen_structlog.go @@ -24,7 +24,7 @@ func (s StructLog) MarshalJSON() ([]byte, error) { MemorySize int `json:"memSize"` Stack []*math.HexOrDecimal256 `json:"stack"` ReturnStack []math.HexOrDecimal64 `json:"returnStack"` - ReturnData []byte `json:"returnData"` + ReturnData hexutil.Bytes `json:"returnData"` Storage map[common.Hash]common.Hash `json:"-"` Depth int `json:"depth"` RefundCounter uint64 `json:"refund"` @@ -72,7 +72,7 @@ func (s *StructLog) UnmarshalJSON(input []byte) error { MemorySize *int `json:"memSize"` Stack []*math.HexOrDecimal256 `json:"stack"` ReturnStack []math.HexOrDecimal64 `json:"returnStack"` - ReturnData []byte `json:"returnData"` + ReturnData *hexutil.Bytes `json:"returnData"` Storage map[common.Hash]common.Hash `json:"-"` Depth *int `json:"depth"` RefundCounter *uint64 `json:"refund"` @@ -113,7 +113,7 @@ func (s *StructLog) UnmarshalJSON(input []byte) error { } } if dec.ReturnData != nil { - s.ReturnData = dec.ReturnData + s.ReturnData = *dec.ReturnData } if dec.Storage != nil { s.Storage = dec.Storage diff --git a/core/vm/instructions.go b/core/vm/instructions.go index adf44b7f481f..1137505292f1 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -341,7 +341,7 @@ func opReturnDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *call func opExtCodeSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { slot := callContext.stack.peek() - slot.SetUint64(uint64(interpreter.evm.StateDB.GetCodeSize(common.Address(slot.Bytes20())))) + slot.SetUint64(uint64(interpreter.evm.StateDB.GetCodeSize(slot.Bytes20()))) return nil, nil } @@ -438,14 +438,14 @@ func opBlockhash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) return nil, nil } var upper, lower uint64 - upper = interpreter.evm.BlockNumber.Uint64() + upper = interpreter.evm.Context.BlockNumber.Uint64() if upper < 257 { lower = 0 } else { lower = upper - 256 } if num64 >= lower && num64 < upper { - num.SetBytes(interpreter.evm.GetHash(num64).Bytes()) + num.SetBytes(interpreter.evm.Context.GetHash(num64).Bytes()) } else { num.Clear() } @@ -453,30 +453,30 @@ func opBlockhash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) } func opCoinbase(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(interpreter.evm.Coinbase.Bytes())) + callContext.stack.push(new(uint256.Int).SetBytes(interpreter.evm.Context.Coinbase.Bytes())) return nil, nil } func opTimestamp(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - v, _ := uint256.FromBig(interpreter.evm.Time) + v, _ := uint256.FromBig(interpreter.evm.Context.Time) callContext.stack.push(v) return nil, nil } func opNumber(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - v, _ := uint256.FromBig(interpreter.evm.BlockNumber) + v, _ := uint256.FromBig(interpreter.evm.Context.BlockNumber) callContext.stack.push(v) return nil, nil } func opDifficulty(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - v, _ := uint256.FromBig(interpreter.evm.Difficulty) + v, _ := uint256.FromBig(interpreter.evm.Context.Difficulty) callContext.stack.push(v) return nil, nil } func opGasLimit(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - 
callContext.stack.push(new(uint256.Int).SetUint64(interpreter.evm.GasLimit)) + callContext.stack.push(new(uint256.Int).SetUint64(interpreter.evm.Context.GasLimit)) return nil, nil } @@ -517,7 +517,7 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] loc := callContext.stack.pop() val := callContext.stack.pop() interpreter.evm.StateDB.SetState(callContext.contract.Address(), - common.Hash(loc.Bytes32()), common.Hash(val.Bytes32())) + loc.Bytes32(), val.Bytes32()) return nil, nil } @@ -817,7 +817,7 @@ func opStop(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by func opSuicide(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { beneficiary := callContext.stack.pop() balance := interpreter.evm.StateDB.GetBalance(callContext.contract.Address()) - interpreter.evm.StateDB.AddBalance(common.Address(beneficiary.Bytes20()), balance) + interpreter.evm.StateDB.AddBalance(beneficiary.Bytes20(), balance) interpreter.evm.StateDB.Suicide(callContext.contract.Address()) return nil, nil } @@ -832,7 +832,7 @@ func makeLog(size int) executionFunc { mStart, mSize := stack.pop(), stack.pop() for i := 0; i < size; i++ { addr := stack.pop() - topics[i] = common.Hash(addr.Bytes32()) + topics[i] = addr.Bytes32() } d := callContext.memory.GetCopy(int64(mStart.Uint64()), int64(mSize.Uint64())) @@ -842,7 +842,7 @@ func makeLog(size int) executionFunc { Data: d, // This is a non-consensus field, but assigned here because // core/state doesn't know the current block number. - BlockNumber: interpreter.evm.BlockNumber.Uint64(), + BlockNumber: interpreter.evm.Context.BlockNumber.Uint64(), }) return nil, nil diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index 0b6fb1f486ad..985d5a5156c3 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -92,7 +92,7 @@ func init() { func testTwoOperandOp(t *testing.T, tests []TwoOperandTestcase, opFn executionFunc, name string) { var ( - env = NewEVM(Context{}, nil, params.TestChainConfig, Config{}) + env = NewEVM(BlockContext{}, TxContext{}, nil, params.TestChainConfig, Config{}) stack = newstack() rstack = newReturnStack() pc = uint64(0) @@ -192,7 +192,7 @@ func TestSAR(t *testing.T) { func TestAddMod(t *testing.T) { var ( - env = NewEVM(Context{}, nil, params.TestChainConfig, Config{}) + env = NewEVM(BlockContext{}, TxContext{}, nil, params.TestChainConfig, Config{}) stack = newstack() evmInterpreter = NewEVMInterpreter(env, env.vmConfig) pc = uint64(0) @@ -231,7 +231,7 @@ func TestAddMod(t *testing.T) { // getResult is a convenience function to generate the expected values func getResult(args []*twoOperandParams, opFn executionFunc) []TwoOperandTestcase { var ( - env = NewEVM(Context{}, nil, params.TestChainConfig, Config{}) + env = NewEVM(BlockContext{}, TxContext{}, nil, params.TestChainConfig, Config{}) stack, rstack = newstack(), newReturnStack() pc = uint64(0) interpreter = env.interpreter.(*EVMInterpreter) @@ -281,7 +281,7 @@ func TestJsonTestcases(t *testing.T) { func opBenchmark(bench *testing.B, op executionFunc, args ...string) { var ( - env = NewEVM(Context{}, nil, params.TestChainConfig, Config{}) + env = NewEVM(BlockContext{}, TxContext{}, nil, params.TestChainConfig, Config{}) stack, rstack = newstack(), newReturnStack() evmInterpreter = NewEVMInterpreter(env, env.vmConfig) ) @@ -515,7 +515,7 @@ func BenchmarkOpIsZero(b *testing.B) { func TestOpMstore(t *testing.T) { var ( - env = NewEVM(Context{}, nil, params.TestChainConfig, Config{}) + 
env = NewEVM(BlockContext{}, TxContext{}, nil, params.TestChainConfig, Config{}) stack, rstack = newstack(), newReturnStack() mem = NewMemory() evmInterpreter = NewEVMInterpreter(env, env.vmConfig) @@ -539,7 +539,7 @@ func TestOpMstore(t *testing.T) { func BenchmarkOpMstore(bench *testing.B) { var ( - env = NewEVM(Context{}, nil, params.TestChainConfig, Config{}) + env = NewEVM(BlockContext{}, TxContext{}, nil, params.TestChainConfig, Config{}) stack, rstack = newstack(), newReturnStack() mem = NewMemory() evmInterpreter = NewEVMInterpreter(env, env.vmConfig) @@ -560,7 +560,7 @@ func BenchmarkOpMstore(bench *testing.B) { func BenchmarkOpSHA3(bench *testing.B) { var ( - env = NewEVM(Context{}, nil, params.TestChainConfig, Config{}) + env = NewEVM(BlockContext{}, TxContext{}, nil, params.TestChainConfig, Config{}) stack, rstack = newstack(), newReturnStack() mem = NewMemory() evmInterpreter = NewEVMInterpreter(env, env.vmConfig) diff --git a/core/vm/interface.go b/core/vm/interface.go index dd401466adfa..fb5bbca48f65 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -57,6 +57,15 @@ type StateDB interface { // is defined according to EIP161 (balance = nonce = code = 0). Empty(common.Address) bool + AddressInAccessList(addr common.Address) bool + SlotInAccessList(addr common.Address, slot common.Hash) (addressOk bool, slotOk bool) + // AddAddressToAccessList adds the given address to the access list. This operation is safe to perform + // even if the feature/fork is not active yet + AddAddressToAccessList(addr common.Address) + // AddSlotToAccessList adds the given (address,slot) to the access list. This operation is safe to perform + // even if the feature/fork is not active yet + AddSlotToAccessList(addr common.Address, slot common.Hash) + RevertToSnapshot(int) Snapshot() int diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 1e2a661debd1..bffc5013a65a 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -99,8 +99,8 @@ func NewEVMInterpreter(evm *EVM, cfg Config) *EVMInterpreter { if cfg.JumpTable[STOP] == nil { var jt JumpTable switch { - case evm.chainRules.IsYoloV1: - jt = yoloV1InstructionSet + case evm.chainRules.IsYoloV2: + jt = yoloV2InstructionSet case evm.chainRules.IsIstanbul: jt = istanbulInstructionSet case evm.chainRules.IsConstantinople: diff --git a/core/vm/jump_table.go b/core/vm/jump_table.go index 9d9bc12b62e6..83fb2c1ed628 100644 --- a/core/vm/jump_table.go +++ b/core/vm/jump_table.go @@ -56,17 +56,19 @@ var ( byzantiumInstructionSet = newByzantiumInstructionSet() constantinopleInstructionSet = newConstantinopleInstructionSet() istanbulInstructionSet = newIstanbulInstructionSet() - yoloV1InstructionSet = newYoloV1InstructionSet() + yoloV2InstructionSet = newYoloV2InstructionSet() ) // JumpTable contains the EVM opcodes supported at a given fork. 
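One plausible backing structure for the four access-list methods added to the StateDB interface above is a per-transaction pair of sets, roughly along these lines (illustrative only; the real implementation lives in core/state and is additionally journaled for reverts):

type accessList struct {
	addresses map[common.Address]struct{}
	slots     map[common.Address]map[common.Hash]struct{}
}

func newAccessList() *accessList {
	return &accessList{
		addresses: make(map[common.Address]struct{}),
		slots:     make(map[common.Address]map[common.Hash]struct{}),
	}
}

func (al *accessList) AddAddress(addr common.Address) {
	al.addresses[addr] = struct{}{}
}

func (al *accessList) AddSlot(addr common.Address, slot common.Hash) {
	al.AddAddress(addr) // touching a slot always warms the address too
	if al.slots[addr] == nil {
		al.slots[addr] = make(map[common.Hash]struct{})
	}
	al.slots[addr][slot] = struct{}{}
}

func (al *accessList) Contains(addr common.Address, slot common.Hash) (addrOk, slotOk bool) {
	_, addrOk = al.addresses[addr]
	_, slotOk = al.slots[addr][slot]
	return addrOk, slotOk
}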
 type JumpTable [256]*operation
 
-func newYoloV1InstructionSet() JumpTable {
+// newYoloV2InstructionSet creates an instruction set containing
+// - "EIP-2315: Simple Subroutines"
+// - "EIP-2929: Gas cost increases for state access opcodes"
+func newYoloV2InstructionSet() JumpTable {
 	instructionSet := newIstanbulInstructionSet()
 	enable2315(&instructionSet) // Subroutines - https://eips.ethereum.org/EIPS/eip-2315
+	enable2929(&instructionSet) // Access lists for trie accesses https://eips.ethereum.org/EIPS/eip-2929
 	return instructionSet
 }
diff --git a/core/vm/logger.go b/core/vm/logger.go
index e1d7c67ef177..962be6ec8e0e 100644
--- a/core/vm/logger.go
+++ b/core/vm/logger.go
@@ -29,6 +29,7 @@ import (
 	"github.com/ethereum/go-ethereum/common/hexutil"
 	"github.com/ethereum/go-ethereum/common/math"
 	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/params"
 )
 
 var errTraceLimitReached = errors.New("the number of logs reached the specified limit")
@@ -53,6 +54,8 @@ type LogConfig struct {
 	DisableReturnData bool // disable return data capture
 	Debug             bool // print output during capture end
 	Limit             int  // maximum length of output, but zero means unlimited
+	// Chain overrides, can be used to execute a trace using future fork rules
+	Overrides *params.ChainConfig `json:"overrides,omitempty"`
 }
 
 //go:generate gencodec -type StructLog -field-override structLogMarshaling -out gen_structlog.go
@@ -82,6 +85,7 @@ type structLogMarshaling struct {
 	Gas         math.HexOrDecimal64
 	GasCost     math.HexOrDecimal64
 	Memory      hexutil.Bytes
+	ReturnData  hexutil.Bytes
 	OpName      string `json:"opName"` // adds call to OpName() in MarshalJSON
 	ErrorString string `json:"error"`  // adds call to ErrorString() in MarshalJSON
 }
@@ -313,8 +317,8 @@ func (t *mdLogger) CaptureStart(from common.Address, to common.Address, create b
 	}
 	fmt.Fprintf(t.out, `
-| Pc | Op | Cost | Stack | RStack |
-|-------|-------------|------|-----------|-----------|
+| Pc | Op | Cost | Stack | RStack | Refund |
+|-------|-------------|------|-----------|-----------|---------|
 `)
 	return nil
 }
@@ -322,22 +326,24 @@ func (t *mdLogger) CaptureStart(from common.Address, to common.Address, create b
 func (t *mdLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, rStack *ReturnStack, rData []byte, contract *Contract, depth int, err error) error {
 	fmt.Fprintf(t.out, "| %4d | %10v | %3d |", pc, op, cost)
 
-	if !t.cfg.DisableStack { // format stack
+	if !t.cfg.DisableStack {
+		// format stack
 		var a []string
 		for _, elem := range stack.data {
-			a = append(a, fmt.Sprintf("%d", elem))
+			a = append(a, fmt.Sprintf("%v", elem.String()))
 		}
 		b := fmt.Sprintf("[%v]", strings.Join(a, ","))
 		fmt.Fprintf(t.out, "%10v |", b)
-	}
-	if !t.cfg.DisableStack { // format return stack
-		var a []string
+
+		// format return stack
+		a = a[:0]
 		for _, elem := range rStack.data {
 			a = append(a, fmt.Sprintf("%2d", elem))
 		}
-		b := fmt.Sprintf("[%v]", strings.Join(a, ","))
+		b = fmt.Sprintf("[%v]", strings.Join(a, ","))
 		fmt.Fprintf(t.out, "%10v |", b)
 	}
+	fmt.Fprintf(t.out, "%10v |", env.StateDB.GetRefund())
 	fmt.Fprintln(t.out, "")
 	if err != nil {
 		fmt.Fprintf(t.out, "Error: %v\n", err)
@@ -353,11 +359,7 @@ func (t *mdLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64
 }
 
 func (t *mdLogger) CaptureEnd(output []byte, gasUsed uint64, tm time.Duration, err error) error {
-	fmt.Fprintf(t.out, `
-Output: 0x%x
-Consumed gas: %d
-Error: %v
-`,
+	fmt.Fprintf(t.out, "\nOutput: `0x%x`\nConsumed gas: `%d`\nError: `%v`\n",
 		output,
 		gasUsed, err)
 	return nil
 }
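The new LogConfig.Overrides knob above lets a tracer run with chain rules that are not yet active on the main configuration. A hypothetical construction (the zero YoloV2Block is just an example value; how the override is consumed is outside this hunk):

cfg := &vm.LogConfig{
	Debug: true,
	// Trace as if YoloV2 rules were active from genesis.
	Overrides: &params.ChainConfig{YoloV2Block: big.NewInt(0)},
}
logger := vm.NewStructLogger(cfg)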
diff --git a/core/vm/logger_test.go b/core/vm/logger_test.go
index e287f0c7aaed..bf7d5358f896 100644
--- a/core/vm/logger_test.go
+++ b/core/vm/logger_test.go
@@ -51,7 +51,7 @@ func (*dummyStatedb) GetRefund() uint64 { return 1337 }
 
 func TestStoreCapture(t *testing.T) {
 	var (
-		env      = NewEVM(Context{}, &dummyStatedb{}, params.TestChainConfig, Config{})
+		env      = NewEVM(BlockContext{}, TxContext{}, &dummyStatedb{}, params.TestChainConfig, Config{})
 		logger   = NewStructLogger(nil)
 		mem      = NewMemory()
 		stack    = newstack()
diff --git a/core/vm/operations_acl.go b/core/vm/operations_acl.go
new file mode 100644
index 000000000000..191953ce5e4a
--- /dev/null
+++ b/core/vm/operations_acl.go
@@ -0,0 +1,222 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package vm
+
+import (
+	"errors"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/math"
+	"github.com/ethereum/go-ethereum/params"
+)
+
+const (
+	ColdAccountAccessCostEIP2929 = uint64(2600) // COLD_ACCOUNT_ACCESS_COST
+	ColdSloadCostEIP2929         = uint64(2100) // COLD_SLOAD_COST
+	WarmStorageReadCostEIP2929   = uint64(100)  // WARM_STORAGE_READ_COST
+)
+
+// gasSStoreEIP2929 implements gas cost for SSTORE according to EIP-2929
+//
+// When calling SSTORE, check if the (address, storage_key) pair is in accessed_storage_keys.
+// If it is not, charge an additional COLD_SLOAD_COST gas, and add the pair to accessed_storage_keys.
+// Additionally, modify the parameters defined in EIP 2200 as follows:
+//
+// Parameter         Old value  New value
+// SLOAD_GAS         800        = WARM_STORAGE_READ_COST
+// SSTORE_RESET_GAS  5000       5000 - COLD_SLOAD_COST
+//
+// The other parameters defined in EIP 2200 are unchanged.
+// see gasSStoreEIP2200(...)
in core/vm/gas_table.go for more info about how EIP 2200 is specified +func gasSStoreEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + // If we fail the minimum gas availability invariant, fail (0) + if contract.Gas <= params.SstoreSentryGasEIP2200 { + return 0, errors.New("not enough gas for reentrancy sentry") + } + // Gas sentry honoured, do the actual gas calculation based on the stored value + var ( + y, x = stack.Back(1), stack.peek() + slot = common.Hash(x.Bytes32()) + current = evm.StateDB.GetState(contract.Address(), slot) + cost = uint64(0) + ) + // Check slot presence in the access list + if addrPresent, slotPresent := evm.StateDB.SlotInAccessList(contract.Address(), slot); !slotPresent { + cost = ColdSloadCostEIP2929 + // If the caller cannot afford the cost, this change will be rolled back + evm.StateDB.AddSlotToAccessList(contract.Address(), slot) + if !addrPresent { + // Once we're done with YOLOv2 and schedule this for mainnet, might + // be good to remove this panic here, which is just really a + // canary to have during testing + panic("impossible case: address was not present in access list during sstore op") + } + } + value := common.Hash(y.Bytes32()) + + if current == value { // noop (1) + // EIP 2200 original clause: + // return params.SloadGasEIP2200, nil + return cost + WarmStorageReadCostEIP2929, nil // SLOAD_GAS + } + original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + if original == current { + if original == (common.Hash{}) { // create slot (2.1.1) + return cost + params.SstoreSetGasEIP2200, nil + } + if value == (common.Hash{}) { // delete slot (2.1.2b) + evm.StateDB.AddRefund(params.SstoreClearsScheduleRefundEIP2200) + } + // EIP-2200 original clause: + // return params.SstoreResetGasEIP2200, nil // write existing slot (2.1.2) + return cost + (params.SstoreResetGasEIP2200 - ColdSloadCostEIP2929), nil // write existing slot (2.1.2) + } + if original != (common.Hash{}) { + if current == (common.Hash{}) { // recreate slot (2.2.1.1) + evm.StateDB.SubRefund(params.SstoreClearsScheduleRefundEIP2200) + } else if value == (common.Hash{}) { // delete slot (2.2.1.2) + evm.StateDB.AddRefund(params.SstoreClearsScheduleRefundEIP2200) + } + } + if original == value { + if original == (common.Hash{}) { // reset to original inexistent slot (2.2.2.1) + // EIP 2200 Original clause: + //evm.StateDB.AddRefund(params.SstoreSetGasEIP2200 - params.SloadGasEIP2200) + evm.StateDB.AddRefund(params.SstoreSetGasEIP2200 - WarmStorageReadCostEIP2929) + } else { // reset to original existing slot (2.2.2.2) + // EIP 2200 Original clause: + // evm.StateDB.AddRefund(params.SstoreResetGasEIP2200 - params.SloadGasEIP2200) + // - SSTORE_RESET_GAS redefined as (5000 - COLD_SLOAD_COST) + // - SLOAD_GAS redefined as WARM_STORAGE_READ_COST + // Final: (5000 - COLD_SLOAD_COST) - WARM_STORAGE_READ_COST + evm.StateDB.AddRefund((params.SstoreResetGasEIP2200 - ColdSloadCostEIP2929) - WarmStorageReadCostEIP2929) + } + } + // EIP-2200 original clause: + //return params.SloadGasEIP2200, nil // dirty update (2.2) + return cost + WarmStorageReadCostEIP2929, nil // dirty update (2.2) +} + +// gasSLoadEIP2929 calculates dynamic gas for SLOAD according to EIP-2929 +// For SLOAD, if the (address, storage_key) pair (where address is the address of the contract +// whose storage is being read) is not yet in accessed_storage_keys, +// charge 2100 gas and add the pair to accessed_storage_keys. 
+// If the pair is already in accessed_storage_keys, charge 100 gas.
+func gasSLoadEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) {
+	loc := stack.peek()
+	slot := common.Hash(loc.Bytes32())
+	// Check slot presence in the access list
+	if _, slotPresent := evm.StateDB.SlotInAccessList(contract.Address(), slot); !slotPresent {
+		// If the caller cannot afford the cost, this change will be rolled back
+		// If the caller does afford it, we can skip checking the same thing later on, during execution
+		evm.StateDB.AddSlotToAccessList(contract.Address(), slot)
+		return ColdSloadCostEIP2929, nil
+	}
+	return WarmStorageReadCostEIP2929, nil
+}
+
+// gasExtCodeCopyEIP2929 implements extcodecopy according to EIP-2929
+// EIP spec:
+// > If the target is not in accessed_addresses,
+// > charge COLD_ACCOUNT_ACCESS_COST gas, and add the address to accessed_addresses.
+// > Otherwise, charge WARM_STORAGE_READ_COST gas.
+func gasExtCodeCopyEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) {
+	// memory expansion first (dynamic part of pre-2929 implementation)
+	gas, err := gasExtCodeCopy(evm, contract, stack, mem, memorySize)
+	if err != nil {
+		return 0, err
+	}
+	addr := common.Address(stack.peek().Bytes20())
+	// Check address presence in the access list
+	if !evm.StateDB.AddressInAccessList(addr) {
+		evm.StateDB.AddAddressToAccessList(addr)
+		var overflow bool
+		// We charge (cold-warm), since 'warm' is already charged as constantGas
+		if gas, overflow = math.SafeAdd(gas, ColdAccountAccessCostEIP2929-WarmStorageReadCostEIP2929); overflow {
+			return 0, ErrGasUintOverflow
+		}
+		return gas, nil
+	}
+	return gas, nil
+}
+
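Putting numbers on the two helpers above (constants from EIP-2929, base costs from EIP-2200; a sketch rather than the real functions): a cold SSTORE that creates a slot costs 2100 + 20000 = 22100, the same warm write costs 20000, and a warm no-op write costs 100 where plain EIP-2200 charged 800. For EXTCODECOPY the cold surcharge of 2600 - 100 is simply added on top of the pre-2929 dynamic cost, here stood in for by memGas:

// Hypothetical EIP-2929 gas arithmetic; constants inlined from the EIPs.
func sstoreGas(cold, createSlot bool) uint64 {
	var gas uint64
	if cold {
		gas += 2100 // COLD_SLOAD_COST
	}
	if createSlot {
		gas += 20000 // SSTORE_SET_GAS (0 -> nonzero)
	} else {
		gas += 100 // no-op write: WARM_STORAGE_READ_COST (was SLOAD_GAS = 800)
	}
	return gas // cold create: 22100, warm create: 20000, warm no-op: 100
}

// extCodeCopyDynamicGas adds the cold surcharge to the pre-2929 dynamic
// part; the warm 100 is charged separately as constant gas.
func extCodeCopyDynamicGas(cold bool, memGas uint64) uint64 {
	if cold {
		return memGas + (2600 - 100)
	}
	return memGas
}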
+// gasEip2929AccountCheck checks whether the first stack item (as address) is present in the access list.
+// If it is, this method returns '0', otherwise 'cold-warm' gas, presuming that the opcode using it
+// is also using 'warm' as constant factor.
+// This method is used by:
+// - extcodehash,
+// - extcodesize,
+// - (ext) balance
+func gasEip2929AccountCheck(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) {
+	addr := common.Address(stack.peek().Bytes20())
+	// Check address presence in the access list
+	if !evm.StateDB.AddressInAccessList(addr) {
+		// If the caller cannot afford the cost, this change will be rolled back
+		evm.StateDB.AddAddressToAccessList(addr)
+		// The warm storage read cost is already charged as constantGas
+		return ColdAccountAccessCostEIP2929 - WarmStorageReadCostEIP2929, nil
+	}
+	return 0, nil
+}
+
+func makeCallVariantGasCallEIP2929(oldCalculator gasFunc) gasFunc {
+	return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) {
+		addr := common.Address(stack.Back(1).Bytes20())
+		// Check address presence in the access list
+		if !evm.StateDB.AddressInAccessList(addr) {
+			evm.StateDB.AddAddressToAccessList(addr)
+			// The WarmStorageReadCostEIP2929 (100) is already deducted in the form of a constant cost
+			if !contract.UseGas(ColdAccountAccessCostEIP2929 - WarmStorageReadCostEIP2929) {
+				return 0, ErrOutOfGas
+			}
+		}
+		// Now call the old calculator, which takes into account
+		// - create new account
+		// - transfer value
+		// - memory expansion
+		// - 63/64ths rule
+		return oldCalculator(evm, contract, stack, mem, memorySize)
+	}
+}
+
+var (
+	gasCallEIP2929         = makeCallVariantGasCallEIP2929(gasCall)
+	gasDelegateCallEIP2929 = makeCallVariantGasCallEIP2929(gasDelegateCall)
+	gasStaticCallEIP2929   = makeCallVariantGasCallEIP2929(gasStaticCall)
+	gasCallCodeEIP2929     = makeCallVariantGasCallEIP2929(gasCallCode)
+)
+
+func gasSelfdestructEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) {
+	var (
+		gas     uint64
+		address = common.Address(stack.peek().Bytes20())
+	)
+	if !evm.StateDB.AddressInAccessList(address) {
+		// If the caller cannot afford the cost, this change will be rolled back
+		evm.StateDB.AddAddressToAccessList(address)
+		gas = ColdAccountAccessCostEIP2929
+	}
+	// if empty and transfers value
+	if evm.StateDB.Empty(address) && evm.StateDB.GetBalance(contract.Address()).Sign() != 0 {
+		gas += params.CreateBySelfdestructGas
+	}
+	if !evm.StateDB.HasSuicided(contract.Address()) {
+		evm.StateDB.AddRefund(params.SelfdestructRefundGas)
+	}
+	return gas, nil
+}
diff --git a/core/vm/runtime/env.go b/core/vm/runtime/env.go
index 38ee44890425..6c4c72eeaca5 100644
--- a/core/vm/runtime/env.go
+++ b/core/vm/runtime/env.go
@@ -22,18 +22,20 @@ import (
 )
 
 func NewEnv(cfg *Config) *vm.EVM {
-	context := vm.Context{
+	txContext := vm.TxContext{
+		Origin:   cfg.Origin,
+		GasPrice: cfg.GasPrice,
+	}
+	blockContext := vm.BlockContext{
 		CanTransfer: core.CanTransfer,
 		Transfer:    core.Transfer,
 		GetHash:     cfg.GetHashFn,
-		Origin:      cfg.Origin,
 		Coinbase:    cfg.Coinbase,
 		BlockNumber: cfg.BlockNumber,
 		Time:        cfg.Time,
 		Difficulty:  cfg.Difficulty,
 		GasLimit:    cfg.GasLimit,
-		GasPrice:    cfg.GasPrice,
 	}
-	return vm.NewEVM(context, cfg.State, cfg.ChainConfig, cfg.EVMConfig)
+	return vm.NewEVM(blockContext, txContext, cfg.State, cfg.ChainConfig, cfg.EVMConfig)
 }
diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go
index 7ebaa9a7e386..2586535f9db8 100644
--- a/core/vm/runtime/runtime.go
+++ b/core/vm/runtime/runtime.go
@@ -65,7 +65,7 @@ func setDefaults(cfg *Config) {
 			PetersburgBlock:     new(big.Int),
 			IstanbulBlock:       new(big.Int),
 			MuirGlacierBlock:    new(big.Int),
-			YoloV1Block:         nil,
+			YoloV2Block:         nil,
 		}
 	}
@@ -113,6 +113,13 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) {
 		vmenv  = NewEnv(cfg)
 		sender = vm.AccountRef(cfg.Origin)
 	)
+	if cfg.ChainConfig.IsYoloV2(vmenv.Context.BlockNumber) {
+		cfg.State.AddAddressToAccessList(cfg.Origin)
+		cfg.State.AddAddressToAccessList(address)
+		for _, addr := range vmenv.ActivePrecompiles() {
+			cfg.State.AddAddressToAccessList(addr)
+		}
+	}
 	cfg.State.CreateAccount(address)
 	// set the receiver's (the executing contract) code for execution.
 	cfg.State.SetCode(address, code)
@@ -142,6 +149,12 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) {
 		vmenv  = NewEnv(cfg)
 		sender = vm.AccountRef(cfg.Origin)
 	)
+	if cfg.ChainConfig.IsYoloV2(vmenv.Context.BlockNumber) {
+		cfg.State.AddAddressToAccessList(cfg.Origin)
+		for _, addr := range vmenv.ActivePrecompiles() {
+			cfg.State.AddAddressToAccessList(addr)
+		}
+	}
 
 	// Call the code with the given configuration.
 	code, address, leftOverGas, err := vmenv.Create(
@@ -164,6 +177,14 @@ func Call(address common.Address, input []byte, cfg *Config) ([]byte, uint64, er
 	vmenv := NewEnv(cfg)
 	sender := cfg.State.GetOrNewStateObject(cfg.Origin)
 
+	if cfg.ChainConfig.IsYoloV2(vmenv.Context.BlockNumber) {
+		cfg.State.AddAddressToAccessList(cfg.Origin)
+		cfg.State.AddAddressToAccessList(address)
+		for _, addr := range vmenv.ActivePrecompiles() {
+			cfg.State.AddAddressToAccessList(addr)
+		}
+	}
+
 	// Call the code with the given configuration.
 	ret, leftOverGas, err := vmenv.Call(
 		sender,
diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go
index 108ee80e41e6..b185258dadb8 100644
--- a/core/vm/runtime/runtime_test.go
+++ b/core/vm/runtime/runtime_test.go
@@ -722,3 +722,115 @@ func BenchmarkSimpleLoop(b *testing.B) {
 	//benchmarkNonModifyingCode(10000000, staticCallIdentity, "staticcall-identity-10M", b)
 	//benchmarkNonModifyingCode(10000000, loopingCode, "loop-10M", b)
 }
+
+// TestEip2929Cases contains various testcases for the gas repricings of
+// EIP-2929
+func TestEip2929Cases(t *testing.T) {
+
+	id := 1
+	prettyPrint := func(comment string, code []byte) {
+
+		instrs := make([]string, 0)
+		it := asm.NewInstructionIterator(code)
+		for it.Next() {
+			if it.Arg() != nil && 0 < len(it.Arg()) {
+				instrs = append(instrs, fmt.Sprintf("%v 0x%x", it.Op(), it.Arg()))
+			} else {
+				instrs = append(instrs, fmt.Sprintf("%v", it.Op()))
+			}
+		}
+		ops := strings.Join(instrs, ", ")
+		fmt.Printf("### Case %d\n\n", id)
+		id++
+		fmt.Printf("%v\n\nBytecode: \n```\n0x%x\n```\nOperations: \n```\n%v\n```\n\n",
+			comment,
+			code, ops)
+		Execute(code, nil, &Config{
+			EVMConfig: vm.Config{
+				Debug:     true,
+				Tracer:    vm.NewMarkdownLogger(nil, os.Stdout),
+				ExtraEips: []int{2929},
+			},
+		})
+	}
+
+	{ // First eip testcase
+		code := []byte{
+			// Three checks against a precompile
+			byte(vm.PUSH1), 1, byte(vm.EXTCODEHASH), byte(vm.POP),
+			byte(vm.PUSH1), 2, byte(vm.EXTCODESIZE), byte(vm.POP),
+			byte(vm.PUSH1), 3, byte(vm.BALANCE), byte(vm.POP),
+			// Three checks against a non-precompile
+			byte(vm.PUSH1), 0xf1, byte(vm.EXTCODEHASH), byte(vm.POP),
+			byte(vm.PUSH1), 0xf2, byte(vm.EXTCODESIZE), byte(vm.POP),
+			byte(vm.PUSH1), 0xf3, byte(vm.BALANCE), byte(vm.POP),
+			// Same three checks (should be cheaper)
+			byte(vm.PUSH1), 0xf2, byte(vm.EXTCODEHASH), byte(vm.POP),
+			byte(vm.PUSH1), 0xf3, byte(vm.EXTCODESIZE), byte(vm.POP),
+			byte(vm.PUSH1), 0xf1, byte(vm.BALANCE), byte(vm.POP),
+			// Check the origin, and the 'this'
+			byte(vm.ORIGIN), byte(vm.BALANCE), byte(vm.POP),
+			byte(vm.ADDRESS), byte(vm.BALANCE), byte(vm.POP),
+
+			byte(vm.STOP),
+		}
+		prettyPrint("This checks `EXT`(codehash,codesize,balance) of precompiles, which should be `100`, "+
+			"and later checks the same operations twice against some non-precompiles. "+
+			"Those are cheaper the second time they are accessed. Lastly, it checks the `BALANCE` of `origin` and `this`.", code)
+	}
+
+	{ // EXTCODECOPY
+		code := []byte{
+			// extcodecopy( 0xff,0,0,0,0)
+			byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, //length, codeoffset, memoffset
+			byte(vm.PUSH1), 0xff, byte(vm.EXTCODECOPY),
+			// extcodecopy( 0xff,0,0,0,0)
+			byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, //length, codeoffset, memoffset
+			byte(vm.PUSH1), 0xff, byte(vm.EXTCODECOPY),
+			// extcodecopy( this,0,0,0,0)
+			byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, //length, codeoffset, memoffset
+			byte(vm.ADDRESS), byte(vm.EXTCODECOPY),
+
+			byte(vm.STOP),
+		}
+		prettyPrint("This checks `extcodecopy( 0xff,0,0,0,0)` twice, (should be expensive first time), "+
+			"and then does `extcodecopy( this,0,0,0,0)`.", code)
+	}
+
+	{ // SLOAD + SSTORE
+		code := []byte{
+
+			// Add slot `0x1` to access list
+			byte(vm.PUSH1), 0x01, byte(vm.SLOAD), byte(vm.POP), // SLOAD( 0x1) (add to access list)
+			// Write to `0x1` which is already in access list
+			byte(vm.PUSH1), 0x11, byte(vm.PUSH1), 0x01, byte(vm.SSTORE), // SSTORE( loc: 0x01, val: 0x11)
+			// Write to `0x2` which is not in access list
+			byte(vm.PUSH1), 0x11, byte(vm.PUSH1), 0x02, byte(vm.SSTORE), // SSTORE( loc: 0x02, val: 0x11)
+			// Write again to `0x2`
+			byte(vm.PUSH1), 0x11, byte(vm.PUSH1), 0x02, byte(vm.SSTORE), // SSTORE( loc: 0x02, val: 0x11)
+			// Read slot in access list (0x2)
+			byte(vm.PUSH1), 0x02, byte(vm.SLOAD), // SLOAD( 0x2)
+			// Read slot in access list (0x1)
+			byte(vm.PUSH1), 0x01, byte(vm.SLOAD), // SLOAD( 0x1)
+		}
+		prettyPrint("This checks `sload( 0x1)` followed by `sstore(loc: 0x01, val:0x11)`, then 'naked' sstore: "+
+			"`sstore(loc: 0x02, val:0x11)` twice, and `sload(0x2)`, `sload(0x1)`. ", code)
", code) + } + { // Call variants + code := []byte{ + // identity precompile + byte(vm.PUSH1), 0x0, byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), + byte(vm.PUSH1), 0x04, byte(vm.PUSH1), 0x0, byte(vm.CALL), byte(vm.POP), + + // random account - call 1 + byte(vm.PUSH1), 0x0, byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), + byte(vm.PUSH1), 0xff, byte(vm.PUSH1), 0x0, byte(vm.CALL), byte(vm.POP), + + // random account - call 2 + byte(vm.PUSH1), 0x0, byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), + byte(vm.PUSH1), 0xff, byte(vm.PUSH1), 0x0, byte(vm.STATICCALL), byte(vm.POP), + } + prettyPrint("This calls the `identity`-precompile (cheap), then calls an account (expensive) and `staticcall`s the same"+ + "account (cheap)", code) + } +} diff --git a/core/vm/testdata/precompiles/modexp_eip2565.json b/core/vm/testdata/precompiles/modexp_eip2565.json new file mode 100644 index 000000000000..c55441439ebb --- /dev/null +++ b/core/vm/testdata/precompiles/modexp_eip2565.json @@ -0,0 +1,121 @@ +[ + { + "Input": "00000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000002003fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2efffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", + "Expected": "0000000000000000000000000000000000000000000000000000000000000001", + "Name": "eip_example1", + "Gas": 1360, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000020fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2efffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", + "Expected": "0000000000000000000000000000000000000000000000000000000000000000", + "Name": "eip_example2", + "Gas": 1360, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000040e09ad9675465c53a109fac66a445c91b292d2bb2c5268addb30cd82f80fcb0033ff97c80a5fc6f39193ae969c6ede6710a6b7ac27078a06d90ef1c72e5c85fb502fc9e1f6beb81516545975218075ec2af118cd8798df6e08a147c60fd6095ac2bb02c2908cf4dd7c81f11c289e4bce98f3553768f392a80ce22bf5c4f4a248c6b", + "Expected": "60008f1614cc01dcfb6bfb09c625cf90b47d4468db81b5f8b7a39d42f332eab9b2da8f2d95311648a8f243f4bb13cfb3d8f7f2a3c014122ebb3ed41b02783adc", + "Name": "nagydani-1-square", + "Gas": 200, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000040e09ad9675465c53a109fac66a445c91b292d2bb2c5268addb30cd82f80fcb0033ff97c80a5fc6f39193ae969c6ede6710a6b7ac27078a06d90ef1c72e5c85fb503fc9e1f6beb81516545975218075ec2af118cd8798df6e08a147c60fd6095ac2bb02c2908cf4dd7c81f11c289e4bce98f3553768f392a80ce22bf5c4f4a248c6b", + "Expected": "4834a46ba565db27903b1c720c9d593e84e4cbd6ad2e64b31885d944f68cd801f92225a8961c952ddf2797fa4701b330c85c4b363798100b921a1a22a46a7fec", + "Name": "nagydani-1-qube", + "Gas": 200, + "NoBenchmark": false + }, + { + "Input": 
"000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000040e09ad9675465c53a109fac66a445c91b292d2bb2c5268addb30cd82f80fcb0033ff97c80a5fc6f39193ae969c6ede6710a6b7ac27078a06d90ef1c72e5c85fb5010001fc9e1f6beb81516545975218075ec2af118cd8798df6e08a147c60fd6095ac2bb02c2908cf4dd7c81f11c289e4bce98f3553768f392a80ce22bf5c4f4a248c6b", + "Expected": "c36d804180c35d4426b57b50c5bfcca5c01856d104564cd513b461d3c8b8409128a5573e416d0ebe38f5f736766d9dc27143e4da981dfa4d67f7dc474cbee6d2", + "Name": "nagydani-1-pow0x10001", + "Gas": 341, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000080cad7d991a00047dd54d3399b6b0b937c718abddef7917c75b6681f40cc15e2be0003657d8d4c34167b2f0bbbca0ccaa407c2a6a07d50f1517a8f22979ce12a81dcaf707cc0cebfc0ce2ee84ee7f77c38b9281b9822a8d3de62784c089c9b18dcb9a2a5eecbede90ea788a862a9ddd9d609c2c52972d63e289e28f6a590ffbf5102e6d893b80aeed5e6e9ce9afa8a5d5675c93a32ac05554cb20e9951b2c140e3ef4e433068cf0fb73bc9f33af1853f64aa27a0028cbf570d7ac9048eae5dc7b28c87c31e5810f1e7fa2cda6adf9f1076dbc1ec1238560071e7efc4e9565c49be9e7656951985860a558a754594115830bcdb421f741408346dd5997bb01c287087", + "Expected": "981dd99c3b113fae3e3eaa9435c0dc96779a23c12a53d1084b4f67b0b053a27560f627b873e3f16ad78f28c94f14b6392def26e4d8896c5e3c984e50fa0b3aa44f1da78b913187c6128baa9340b1e9c9a0fd02cb78885e72576da4a8f7e5a113e173a7a2889fde9d407bd9f06eb05bc8fc7b4229377a32941a02bf4edcc06d70", + "Name": "nagydani-2-square", + "Gas": 200, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000080cad7d991a00047dd54d3399b6b0b937c718abddef7917c75b6681f40cc15e2be0003657d8d4c34167b2f0bbbca0ccaa407c2a6a07d50f1517a8f22979ce12a81dcaf707cc0cebfc0ce2ee84ee7f77c38b9281b9822a8d3de62784c089c9b18dcb9a2a5eecbede90ea788a862a9ddd9d609c2c52972d63e289e28f6a590ffbf5103e6d893b80aeed5e6e9ce9afa8a5d5675c93a32ac05554cb20e9951b2c140e3ef4e433068cf0fb73bc9f33af1853f64aa27a0028cbf570d7ac9048eae5dc7b28c87c31e5810f1e7fa2cda6adf9f1076dbc1ec1238560071e7efc4e9565c49be9e7656951985860a558a754594115830bcdb421f741408346dd5997bb01c287087", + "Expected": "d89ceb68c32da4f6364978d62aaa40d7b09b59ec61eb3c0159c87ec3a91037f7dc6967594e530a69d049b64adfa39c8fa208ea970cfe4b7bcd359d345744405afe1cbf761647e32b3184c7fbe87cee8c6c7ff3b378faba6c68b83b6889cb40f1603ee68c56b4c03d48c595c826c041112dc941878f8c5be828154afd4a16311f", + "Name": "nagydani-2-qube", + "Gas": 200, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000008000000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000080cad7d991a00047dd54d3399b6b0b937c718abddef7917c75b6681f40cc15e2be0003657d8d4c34167b2f0bbbca0ccaa407c2a6a07d50f1517a8f22979ce12a81dcaf707cc0cebfc0ce2ee84ee7f77c38b9281b9822a8d3de62784c089c9b18dcb9a2a5eecbede90ea788a862a9ddd9d609c2c52972d63e289e28f6a590ffbf51010001e6d893b80aeed5e6e9ce9afa8a5d5675c93a32ac05554cb20e9951b2c140e3ef4e433068cf0fb73bc9f33af1853f64aa27a0028cbf570d7ac9048eae5dc7b28c87c31e5810f1e7fa2cda6adf9f1076dbc1ec1238560071e7efc4e9565c49be9e7656951985860a558a754594115830bcdb421f741408346dd5997bb01c287087", + 
"Expected": "ad85e8ef13fd1dd46eae44af8b91ad1ccae5b7a1c92944f92a19f21b0b658139e0cabe9c1f679507c2de354bf2c91ebd965d1e633978a830d517d2f6f8dd5fd58065d58559de7e2334a878f8ec6992d9b9e77430d4764e863d77c0f87beede8f2f7f2ab2e7222f85cc9d98b8467f4bb72e87ef2882423ebdb6daf02dddac6db2", + "Name": "nagydani-2-pow0x10001", + "Gas": 1365, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000100c9130579f243e12451760976261416413742bd7c91d39ae087f46794062b8c239f2a74abf3918605a0e046a7890e049475ba7fbb78f5de6490bd22a710cc04d30088179a919d86c2da62cf37f59d8f258d2310d94c24891be2d7eeafaa32a8cb4b0cfe5f475ed778f45907dc8916a73f03635f233f7a77a00a3ec9ca6761a5bbd558a2318ecd0caa1c5016691523e7e1fa267dd35e70c66e84380bdcf7c0582f540174e572c41f81e93da0b757dff0b0fe23eb03aa19af0bdec3afb474216febaacb8d0381e631802683182b0fe72c28392539850650b70509f54980241dc175191a35d967288b532a7a8223ce2440d010615f70df269501944d4ec16fe4a3cb02d7a85909174757835187cb52e71934e6c07ef43b4c46fc30bbcd0bc72913068267c54a4aabebb493922492820babdeb7dc9b1558fcf7bd82c37c82d3147e455b623ab0efa752fe0b3a67ca6e4d126639e645a0bf417568adbb2a6a4eef62fa1fa29b2a5a43bebea1f82193a7dd98eb483d09bb595af1fa9c97c7f41f5649d976aee3e5e59e2329b43b13bea228d4a93f16ba139ccb511de521ffe747aa2eca664f7c9e33da59075cc335afcd2bf3ae09765f01ab5a7c3e3938ec168b74724b5074247d200d9970382f683d6059b94dbc336603d1dfee714e4b447ac2fa1d99ecb4961da2854e03795ed758220312d101e1e3d87d5313a6d052aebde75110363d", + "Expected": "affc7507ea6d84751ec6b3f0d7b99dbcc263f33330e450d1b3ff0bc3d0874320bf4edd57debd587306988157958cb3cfd369cc0c9c198706f635c9e0f15d047df5cb44d03e2727f26b083c4ad8485080e1293f171c1ed52aef5993a5815c35108e848c951cf1e334490b4a539a139e57b68f44fee583306f5b85ffa57206b3ee5660458858534e5386b9584af3c7f67806e84c189d695e5eb96e1272d06ec2df5dc5fabc6e94b793718c60c36be0a4d031fc84cd658aa72294b2e16fc240aef70cb9e591248e38bd49c5a554d1afa01f38dab72733092f7555334bbef6c8c430119840492380aa95fa025dcf699f0a39669d812b0c6946b6091e6e235337b6f8", + "Name": "nagydani-3-square", + "Gas": 341, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000100c9130579f243e12451760976261416413742bd7c91d39ae087f46794062b8c239f2a74abf3918605a0e046a7890e049475ba7fbb78f5de6490bd22a710cc04d30088179a919d86c2da62cf37f59d8f258d2310d94c24891be2d7eeafaa32a8cb4b0cfe5f475ed778f45907dc8916a73f03635f233f7a77a00a3ec9ca6761a5bbd558a2318ecd0caa1c5016691523e7e1fa267dd35e70c66e84380bdcf7c0582f540174e572c41f81e93da0b757dff0b0fe23eb03aa19af0bdec3afb474216febaacb8d0381e631802683182b0fe72c28392539850650b70509f54980241dc175191a35d967288b532a7a8223ce2440d010615f70df269501944d4ec16fe4a3cb03d7a85909174757835187cb52e71934e6c07ef43b4c46fc30bbcd0bc72913068267c54a4aabebb493922492820babdeb7dc9b1558fcf7bd82c37c82d3147e455b623ab0efa752fe0b3a67ca6e4d126639e645a0bf417568adbb2a6a4eef62fa1fa29b2a5a43bebea1f82193a7dd98eb483d09bb595af1fa9c97c7f41f5649d976aee3e5e59e2329b43b13bea228d4a93f16ba139ccb511de521ffe747aa2eca664f7c9e33da59075cc335afcd2bf3ae09765f01ab5a7c3e3938ec168b74724b5074247d200d9970382f683d6059b94dbc336603d1dfee714e4b447ac2fa1d99ecb4961da2854e03795ed758220312d101e1e3d87d5313a6d052aebde75110363d", + "Expected": 
"1b280ecd6a6bf906b806d527c2a831e23b238f89da48449003a88ac3ac7150d6a5e9e6b3be4054c7da11dd1e470ec29a606f5115801b5bf53bc1900271d7c3ff3cd5ed790d1c219a9800437a689f2388ba1a11d68f6a8e5b74e9a3b1fac6ee85fc6afbac599f93c391f5dc82a759e3c6c0ab45ce3f5d25d9b0c1bf94cf701ea6466fc9a478dacc5754e593172b5111eeba88557048bceae401337cd4c1182ad9f700852bc8c99933a193f0b94cf1aedbefc48be3bc93ef5cb276d7c2d5462ac8bb0c8fe8923a1db2afe1c6b90d59c534994a6a633f0ead1d638fdc293486bb634ff2c8ec9e7297c04241a61c37e3ae95b11d53343d4ba2b4cc33d2cfa7eb705e", + "Name": "nagydani-3-qube", + "Gas": 341, + "NoBenchmark": false + }, + { + "Input": "000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000100c9130579f243e12451760976261416413742bd7c91d39ae087f46794062b8c239f2a74abf3918605a0e046a7890e049475ba7fbb78f5de6490bd22a710cc04d30088179a919d86c2da62cf37f59d8f258d2310d94c24891be2d7eeafaa32a8cb4b0cfe5f475ed778f45907dc8916a73f03635f233f7a77a00a3ec9ca6761a5bbd558a2318ecd0caa1c5016691523e7e1fa267dd35e70c66e84380bdcf7c0582f540174e572c41f81e93da0b757dff0b0fe23eb03aa19af0bdec3afb474216febaacb8d0381e631802683182b0fe72c28392539850650b70509f54980241dc175191a35d967288b532a7a8223ce2440d010615f70df269501944d4ec16fe4a3cb010001d7a85909174757835187cb52e71934e6c07ef43b4c46fc30bbcd0bc72913068267c54a4aabebb493922492820babdeb7dc9b1558fcf7bd82c37c82d3147e455b623ab0efa752fe0b3a67ca6e4d126639e645a0bf417568adbb2a6a4eef62fa1fa29b2a5a43bebea1f82193a7dd98eb483d09bb595af1fa9c97c7f41f5649d976aee3e5e59e2329b43b13bea228d4a93f16ba139ccb511de521ffe747aa2eca664f7c9e33da59075cc335afcd2bf3ae09765f01ab5a7c3e3938ec168b74724b5074247d200d9970382f683d6059b94dbc336603d1dfee714e4b447ac2fa1d99ecb4961da2854e03795ed758220312d101e1e3d87d5313a6d052aebde75110363d", + "Expected": "37843d7c67920b5f177372fa56e2a09117df585f81df8b300fba245b1175f488c99476019857198ed459ed8d9799c377330e49f4180c4bf8e8f66240c64f65ede93d601f957b95b83efdee1e1bfde74169ff77002eaf078c71815a9220c80b2e3b3ff22c2f358111d816ebf83c2999026b6de50bfc711ff68705d2f40b753424aefc9f70f08d908b5a20276ad613b4ab4309a3ea72f0c17ea9df6b3367d44fb3acab11c333909e02e81ea2ed404a712d3ea96bba87461720e2d98723e7acd0520ac1a5212dbedcd8dc0c1abf61d4719e319ff4758a774790b8d463cdfe131d1b2dcfee52d002694e98e720cb6ae7ccea353bc503269ba35f0f63bf8d7b672a76", + "Name": "nagydani-3-pow0x10001", + "Gas": 5461, + "NoBenchmark": false + }, + { + "Input": 
"000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000200db34d0e438249c0ed685c949cc28776a05094e1c48691dc3f2dca5fc3356d2a0663bd376e4712839917eb9a19c670407e2c377a2de385a3ff3b52104f7f1f4e0c7bf7717fb913896693dc5edbb65b760ef1b00e42e9d8f9af17352385e1cd742c9b006c0f669995cb0bb21d28c0aced2892267637b6470d8cee0ab27fc5d42658f6e88240c31d6774aa60a7ebd25cd48b56d0da11209f1928e61005c6eb709f3e8e0aaf8d9b10f7d7e296d772264dc76897ccdddadc91efa91c1903b7232a9e4c3b941917b99a3bc0c26497dedc897c25750af60237aa67934a26a2bc491db3dcc677491944bc1f51d3e5d76b8d846a62db03dedd61ff508f91a56d71028125035c3a44cbb041497c83bf3e4ae2a9613a401cc721c547a2afa3b16a2969933d3626ed6d8a7428648f74122fd3f2a02a20758f7f693892c8fd798b39abac01d18506c45e71432639e9f9505719ee822f62ccbf47f6850f096ff77b5afaf4be7d772025791717dbe5abf9b3f40cff7d7aab6f67e38f62faf510747276e20a42127e7500c444f9ed92baf65ade9e836845e39c4316d9dce5f8e2c8083e2c0acbb95296e05e51aab13b6b8f53f06c9c4276e12b0671133218cc3ea907da3bd9a367096d9202128d14846cc2e20d56fc8473ecb07cecbfb8086919f3971926e7045b853d85a69d026195c70f9f7a823536e2a8f4b3e12e94d9b53a934353451094b8102df3143a0057457d75e8c708b6337a6f5a4fd1a06727acf9fb93e2993c62f3378b37d56c85e7b1e00f0145ebf8e4095bd723166293c60b6ac1252291ef65823c9e040ddad14969b3b340a4ef714db093a587c37766d68b8d6b5016e741587e7e6bf7e763b44f0247e64bae30f994d248bfd20541a333e5b225ef6a61199e301738b1e688f70ec1d7fb892c183c95dc543c3e12adf8a5e8b9ca9d04f9445cced3ab256f29e998e69efaa633a7b60e1db5a867924ccab0a171d9d6e1098dfa15acde9553de599eaa56490c8f411e4985111f3d40bddfc5e301edb01547b01a886550a61158f7e2033c59707789bf7c854181d0c2e2a42a93cf09209747d7082e147eb8544de25c3eb14f2e35559ea0c0f5877f2f3fc92132c0ae9da4e45b2f6c866a224ea6d1f28c05320e287750fbc647368d41116e528014cc1852e5531d53e4af938374daba6cee4baa821ed07117253bb3601ddd00d59a3d7fb2ef1f5a2fbba7c429f0cf9a5b3462410fd833a69118f8be9c559b1000cc608fd877fb43f8e65c2d1302622b944462579056874b387208d90623fcdaf93920ca7a9e4ba64ea208758222ad868501cc2c345e2d3a5ea2a17e5069248138c8a79c0251185d29ee73e5afab5354769142d2bf0cb6712727aa6bf84a6245fcdae66e4938d84d1b9dd09a884818622080ff5f98942fb20acd7e0c916c2d5ea7ce6f7e173315384518f", + "Expected": "8a5aea5f50dcc03dc7a7a272b5aeebc040554dbc1ffe36753c4fc75f7ed5f6c2cc0de3a922bf96c78bf0643a73025ad21f45a4a5cadd717612c511ab2bff1190fe5f1ae05ba9f8fe3624de1de2a817da6072ddcdb933b50216811dbe6a9ca79d3a3c6b3a476b079fd0d05f04fb154e2dd3e5cb83b148a006f2bcbf0042efb2ae7b916ea81b27aac25c3bf9a8b6d35440062ad8eae34a83f3ffa2cc7b40346b62174a4422584f72f95316f6b2bee9ff232ba9739301c97c99a9ded26c45d72676eb856ad6ecc81d36a6de36d7f9dafafee11baa43a4b0d5e4ecffa7b9b7dcefd58c397dd373e6db4acd2b2c02717712e6289bed7c813b670c4a0c6735aa7f3b0f1ce556eae9fcc94b501b2c8781ba50a8c6220e8246371c3c7359fe4ef9da786ca7d98256754ca4e496be0a9174bedbecb384bdf470779186d6a833f068d2838a88d90ef3ad48ff963b67c39cc5a3ee123baf7bf3125f64e77af7f30e105d72c4b9b5b237ed251e4c122c6d8c1405e736299c3afd6db16a28c6a9cfa68241e53de4cd388271fe534a6a9b0dbea6171d170db1b89858468885d08fecbd54c8e471c3e25d48e97ba450b96d0d87e00ac732aaa0d3ce4309c1064bd8a4c0808a97e0143e43a24cfa847635125cd41c13e0574487963e9d725c01375db99c31da67b4cf65eff555f0c0ac416c727ff8d438ad7c42030551d68c2e7adda0abb1ca7c10", + "Name": "nagydani-4-square", + "Gas": 1365, + "NoBenchmark": false + }, + { + "Input": 
"000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000200db34d0e438249c0ed685c949cc28776a05094e1c48691dc3f2dca5fc3356d2a0663bd376e4712839917eb9a19c670407e2c377a2de385a3ff3b52104f7f1f4e0c7bf7717fb913896693dc5edbb65b760ef1b00e42e9d8f9af17352385e1cd742c9b006c0f669995cb0bb21d28c0aced2892267637b6470d8cee0ab27fc5d42658f6e88240c31d6774aa60a7ebd25cd48b56d0da11209f1928e61005c6eb709f3e8e0aaf8d9b10f7d7e296d772264dc76897ccdddadc91efa91c1903b7232a9e4c3b941917b99a3bc0c26497dedc897c25750af60237aa67934a26a2bc491db3dcc677491944bc1f51d3e5d76b8d846a62db03dedd61ff508f91a56d71028125035c3a44cbb041497c83bf3e4ae2a9613a401cc721c547a2afa3b16a2969933d3626ed6d8a7428648f74122fd3f2a02a20758f7f693892c8fd798b39abac01d18506c45e71432639e9f9505719ee822f62ccbf47f6850f096ff77b5afaf4be7d772025791717dbe5abf9b3f40cff7d7aab6f67e38f62faf510747276e20a42127e7500c444f9ed92baf65ade9e836845e39c4316d9dce5f8e2c8083e2c0acbb95296e05e51aab13b6b8f53f06c9c4276e12b0671133218cc3ea907da3bd9a367096d9202128d14846cc2e20d56fc8473ecb07cecbfb8086919f3971926e7045b853d85a69d026195c70f9f7a823536e2a8f4b3e12e94d9b53a934353451094b8103df3143a0057457d75e8c708b6337a6f5a4fd1a06727acf9fb93e2993c62f3378b37d56c85e7b1e00f0145ebf8e4095bd723166293c60b6ac1252291ef65823c9e040ddad14969b3b340a4ef714db093a587c37766d68b8d6b5016e741587e7e6bf7e763b44f0247e64bae30f994d248bfd20541a333e5b225ef6a61199e301738b1e688f70ec1d7fb892c183c95dc543c3e12adf8a5e8b9ca9d04f9445cced3ab256f29e998e69efaa633a7b60e1db5a867924ccab0a171d9d6e1098dfa15acde9553de599eaa56490c8f411e4985111f3d40bddfc5e301edb01547b01a886550a61158f7e2033c59707789bf7c854181d0c2e2a42a93cf09209747d7082e147eb8544de25c3eb14f2e35559ea0c0f5877f2f3fc92132c0ae9da4e45b2f6c866a224ea6d1f28c05320e287750fbc647368d41116e528014cc1852e5531d53e4af938374daba6cee4baa821ed07117253bb3601ddd00d59a3d7fb2ef1f5a2fbba7c429f0cf9a5b3462410fd833a69118f8be9c559b1000cc608fd877fb43f8e65c2d1302622b944462579056874b387208d90623fcdaf93920ca7a9e4ba64ea208758222ad868501cc2c345e2d3a5ea2a17e5069248138c8a79c0251185d29ee73e5afab5354769142d2bf0cb6712727aa6bf84a6245fcdae66e4938d84d1b9dd09a884818622080ff5f98942fb20acd7e0c916c2d5ea7ce6f7e173315384518f", + "Expected": "5a2664252aba2d6e19d9600da582cdd1f09d7a890ac48e6b8da15ae7c6ff1856fc67a841ac2314d283ffa3ca81a0ecf7c27d89ef91a5a893297928f5da0245c99645676b481b7e20a566ee6a4f2481942bee191deec5544600bb2441fd0fb19e2ee7d801ad8911c6b7750affec367a4b29a22942c0f5f4744a4e77a8b654da2a82571037099e9c6d930794efe5cdca73c7b6c0844e386bdca8ea01b3d7807146bb81365e2cdc6475f8c23e0ff84463126189dc9789f72bbce2e3d2d114d728a272f1345122de23df54c922ec7a16e5c2a8f84da8871482bd258c20a7c09bbcd64c7a96a51029bbfe848736a6ba7bf9d931a9b7de0bcaf3635034d4958b20ae9ab3a95a147b0421dd5f7ebff46c971010ebfc4adbbe0ad94d5498c853e7142c450d8c71de4b2f84edbf8acd2e16d00c8115b150b1c30e553dbb82635e781379fe2a56360420ff7e9f70cc64c00aba7e26ed13c7c19622865ae07248daced36416080f35f8cc157a857ed70ea4f347f17d1bee80fa038abd6e39b1ba06b97264388b21364f7c56e192d4b62d9b161405f32ab1e2594e86243e56fcf2cb30d21adef15b9940f91af681da24328c883d892670c6aa47940867a81830a82b82716895db810df1b834640abefb7db2092dd92912cb9a735175bc447be40a503cf22dfe565b4ed7a3293ca0dfd63a507430b323ee248ec82e843b673c97ad730728cebc", + "Name": "nagydani-4-qube", + "Gas": 1365, + "NoBenchmark": false + }, + { + "Input": 
"000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000200db34d0e438249c0ed685c949cc28776a05094e1c48691dc3f2dca5fc3356d2a0663bd376e4712839917eb9a19c670407e2c377a2de385a3ff3b52104f7f1f4e0c7bf7717fb913896693dc5edbb65b760ef1b00e42e9d8f9af17352385e1cd742c9b006c0f669995cb0bb21d28c0aced2892267637b6470d8cee0ab27fc5d42658f6e88240c31d6774aa60a7ebd25cd48b56d0da11209f1928e61005c6eb709f3e8e0aaf8d9b10f7d7e296d772264dc76897ccdddadc91efa91c1903b7232a9e4c3b941917b99a3bc0c26497dedc897c25750af60237aa67934a26a2bc491db3dcc677491944bc1f51d3e5d76b8d846a62db03dedd61ff508f91a56d71028125035c3a44cbb041497c83bf3e4ae2a9613a401cc721c547a2afa3b16a2969933d3626ed6d8a7428648f74122fd3f2a02a20758f7f693892c8fd798b39abac01d18506c45e71432639e9f9505719ee822f62ccbf47f6850f096ff77b5afaf4be7d772025791717dbe5abf9b3f40cff7d7aab6f67e38f62faf510747276e20a42127e7500c444f9ed92baf65ade9e836845e39c4316d9dce5f8e2c8083e2c0acbb95296e05e51aab13b6b8f53f06c9c4276e12b0671133218cc3ea907da3bd9a367096d9202128d14846cc2e20d56fc8473ecb07cecbfb8086919f3971926e7045b853d85a69d026195c70f9f7a823536e2a8f4b3e12e94d9b53a934353451094b81010001df3143a0057457d75e8c708b6337a6f5a4fd1a06727acf9fb93e2993c62f3378b37d56c85e7b1e00f0145ebf8e4095bd723166293c60b6ac1252291ef65823c9e040ddad14969b3b340a4ef714db093a587c37766d68b8d6b5016e741587e7e6bf7e763b44f0247e64bae30f994d248bfd20541a333e5b225ef6a61199e301738b1e688f70ec1d7fb892c183c95dc543c3e12adf8a5e8b9ca9d04f9445cced3ab256f29e998e69efaa633a7b60e1db5a867924ccab0a171d9d6e1098dfa15acde9553de599eaa56490c8f411e4985111f3d40bddfc5e301edb01547b01a886550a61158f7e2033c59707789bf7c854181d0c2e2a42a93cf09209747d7082e147eb8544de25c3eb14f2e35559ea0c0f5877f2f3fc92132c0ae9da4e45b2f6c866a224ea6d1f28c05320e287750fbc647368d41116e528014cc1852e5531d53e4af938374daba6cee4baa821ed07117253bb3601ddd00d59a3d7fb2ef1f5a2fbba7c429f0cf9a5b3462410fd833a69118f8be9c559b1000cc608fd877fb43f8e65c2d1302622b944462579056874b387208d90623fcdaf93920ca7a9e4ba64ea208758222ad868501cc2c345e2d3a5ea2a17e5069248138c8a79c0251185d29ee73e5afab5354769142d2bf0cb6712727aa6bf84a6245fcdae66e4938d84d1b9dd09a884818622080ff5f98942fb20acd7e0c916c2d5ea7ce6f7e173315384518f", + "Expected": "bed8b970c4a34849fc6926b08e40e20b21c15ed68d18f228904878d4370b56322d0da5789da0318768a374758e6375bfe4641fca5285ec7171828922160f48f5ca7efbfee4d5148612c38ad683ae4e3c3a053d2b7c098cf2b34f2cb19146eadd53c86b2d7ccf3d83b2c370bfb840913ee3879b1057a6b4e07e110b6bcd5e958bc71a14798c91d518cc70abee264b0d25a4110962a764b364ac0b0dd1ee8abc8426d775ec0f22b7e47b32576afaf1b5a48f64573ed1c5c29f50ab412188d9685307323d990802b81dacc06c6e05a1e901830ba9fcc67688dc29c5e27bde0a6e845ca925f5454b6fb3747edfaa2a5820838fb759eadf57f7cb5cec57fc213ddd8a4298fa079c3c0f472b07fb15aa6a7f0a3780bd296ff6a62e58ef443870b02260bd4fd2bbc98255674b8e1f1f9f8d33c7170b0ebbea4523b695911abbf26e41885344823bd0587115fdd83b721a4e8457a31c9a84b3d3520a07e0e35df7f48e5a9d534d0ec7feef1ff74de6a11e7f93eab95175b6ce22c68d78a642ad642837897ec11349205d8593ac19300207572c38d29ca5dfa03bc14cdbc32153c80e5cc3e739403d34c75915e49beb43094cc6dcafb3665b305ddec9286934ae66ec6b777ca528728c851318eb0f207b39f1caaf96db6eeead6b55ed08f451939314577d42bcc9f97c0b52d0234f88fd07e4c1d7780fdebc025cfffcb572cb27a8c33963", + "Name": "nagydani-4-pow0x10001", + "Gas": 21845, + "NoBenchmark": false + }, + { + "Input": 
"000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000400c5a1611f8be90071a43db23cc2fe01871cc4c0e8ab5743f6378e4fef77f7f6db0095c0727e20225beb665645403453e325ad5f9aeb9ba99bf3c148f63f9c07cf4fe8847ad5242d6b7d4499f93bd47056ddab8f7dee878fc2314f344dbee2a7c41a5d3db91eff372c730c2fdd3a141a4b61999e36d549b9870cf2f4e632c4d5df5f024f81c028000073a0ed8847cfb0593d36a47142f578f05ccbe28c0c06aeb1b1da027794c48db880278f79ba78ae64eedfea3c07d10e0562668d839749dc95f40467d15cf65b9cfc52c7c4bcef1cda3596dd52631aac942f146c7cebd46065131699ce8385b0db1874336747ee020a5698a3d1a1082665721e769567f579830f9d259cec1a836845109c21cf6b25da572512bf3c42fd4b96e43895589042ab60dd41f497db96aec102087fe784165bb45f942859268fd2ff6c012d9d00c02ba83eace047cc5f7b2c392c2955c58a49f0338d6fc58749c9db2155522ac17914ec216ad87f12e0ee95574613942fa615898c4d9e8a3be68cd6afa4e7a003dedbdf8edfee31162b174f965b20ae752ad89c967b3068b6f722c16b354456ba8e280f987c08e0a52d40a2e8f3a59b94d590aeef01879eb7a90b3ee7d772c839c85519cbeaddc0c193ec4874a463b53fcaea3271d80ebfb39b33489365fc039ae549a17a9ff898eea2f4cb27b8dbee4c17b998438575b2b8d107e4a0d66ba7fca85b41a58a8d51f191a35c856dfbe8aef2b00048a694bbccff832d23c8ca7a7ff0b6c0b3011d00b97c86c0628444d267c951d9e4fb8f83e154b8f74fb51aa16535e498235c5597dac9606ed0be3173a3836baa4e7d756ffe1e2879b415d3846bccd538c05b847785699aefde3e305decb600cd8fb0e7d8de5efc26971a6ad4e6d7a2d91474f1023a0ac4b78dc937da0ce607a45974d2cac1c33a2631ff7fe6144a3b2e5cf98b531a9627dea92c1dc82204d09db0439b6a11dd64b484e1263aa45fd9539b6020b55e3baece3986a8bffc1003406348f5c61265099ed43a766ee4f93f5f9c5abbc32a0fd3ac2b35b87f9ec26037d88275bd7dd0a54474995ee34ed3727f3f97c48db544b1980193a4b76a8a3ddab3591ce527f16d91882e67f0103b5cda53f7da54d489fc4ac08b6ab358a5a04aa9daa16219d50bd672a7cb804ed769d218807544e5993f1c27427104b349906a0b654df0bf69328afd3013fbe430155339c39f236df5557bf92f1ded7ff609a8502f49064ec3d1dbfb6c15d3a4c11a4f8acd12278cbf68acd5709463d12e3338a6eddb8c112f199645e23154a8e60879d2a654e3ed9296aa28f134168619691cd2c6b9e2eba4438381676173fc63c2588a3c5910dc149cf3760f0aa9fa9c3f5faa9162b0bf1aac9dd32b706a60ef53cbdb394b6b40222b5bc80eea82ba8958386672564cae3794f977871ab62337cf02e30049201ec12937e7ce79d0f55d9c810e20acf52212aca1d3888949e0e4830aad88d804161230eb89d4d329cc83570fe257217d2119134048dd2ed167646975fc7d77136919a049ea74cf08ddd2b896890bb24a0ba18094a22baa351bf29ad96c66bbb1a598f2ca391749620e62d61c3561a7d3653ccc8892c7b99baaf76bf836e2991cb06d6bc0514568ff0d1ec8bb4b3d6984f5eaefb17d3ea2893722375d3ddb8e389a8eef7d7d198f8e687d6a513983df906099f9a2d23f4f9dec6f8ef2f11fc0a21fac45353b94e00486f5e17d386af42502d09db33cf0cf28310e049c07e88682aeeb00cb833c5174266e62407a57583f1f88b304b7c6e0c84bbe1c0fd423072d37a5bd0aacf764229e5c7cd02473460ba3645cd8e8ae144065bf02d0dd238593d8e230354f67e0b2f23012c23274f80e3ee31e35e2606a4a3f31d94ab755e6d163cff52cbb36b6d0cc67ffc512aeed1dce4d7a0d70ce82f2baba12e8d514dc92a056f994adfb17b5b9712bd5186f27a2fda1f7039c5df2c8587fdc62f5627580c13234b55be4df3056050e2d1ef3218f0dd66cb05265fe1acfb0989d8213f2c19d1735a7cf3fa65d88dad5af52dc2bba22b7abf46c3bc77b5091baab9e8f0ddc4d5e581037de91a9f8dcbc69309be29cc815cf19a20a7585b8b3073edf51fc9baeb3e509b97fa4ecfd621e0fd57bd61cac1b895c03248ff12bdbc57509250df3517e8a3fe1d776836b34ab352b973d932ef708b14f7418f9eceb1d87667e61e3e758649cb083f01b133d37ab2f5afa96d6c84bcacf4efc3851ad308c1e7d9113624fce29fab460ab9d2a48d92cdb281103a5250ad44cb2ff6e67ac670c02fdafb3e0f1353953d6d7d5646ca1568dea55275a050ec501b7c6250444f7219f1ba7521ba3b93d089727ca5f3bbe
0d6c1300b423377004954c5628fdb65770b18ced5c9b23a4a5a6d6ef25fe01b4ce278de0bcc4ed86e28a0a68818ffa40970128cf2c38740e80037984428c1bd5113f40ff47512ee6f4e4d8f9b8e8e1b3040d2928d003bd1c1329dc885302fbce9fa81c23b4dc49c7c82d29b52957847898676c89aa5d32b5b0e1c0d5a2b79a19d67562f407f19425687971a957375879d90c5f57c857136c17106c9ab1b99d80e69c8c954ed386493368884b55c939b8d64d26f643e800c56f90c01079d7c534e3b2b7ae352cefd3016da55f6a85eb803b85e2304915fd2001f77c74e28746293c46e4f5f0fd49cf988aafd0026b8e7a3bab2da5cdce1ea26c2e29ec03f4807fac432662b2d6c060be1c7be0e5489de69d0a6e03a4b9117f9244b34a0f1ecba89884f781c6320412413a00c4980287409a2a78c2cd7e65cecebbe4ec1c28cac4dd95f6998e78fc6f1392384331c9436aa10e10e2bf8ad2c4eafbcf276aa7bae64b74428911b3269c749338b0fc5075ad", + "Expected": "d61fe4e3f32ac260915b5b03b78a86d11bfc41d973fce5b0cc59035cf8289a8a2e3878ea15fa46565b0d806e2f85b53873ea20ed653869b688adf83f3ef444535bf91598ff7e80f334fb782539b92f39f55310cc4b35349ab7b278346eda9bc37c0d8acd3557fae38197f412f8d9e57ce6a76b7205c23564cab06e5615be7c6f05c3d05ec690cba91da5e89d55b152ff8dd2157dc5458190025cf94b1ad98f7cbe64e9482faba95e6b33844afc640892872b44a9932096508f4a782a4805323808f23e54b6ff9b841dbfa87db3505ae4f687972c18ea0f0d0af89d36c1c2a5b14560c153c3fee406f5cf15cfd1c0bb45d767426d465f2f14c158495069d0c5955a00150707862ecaae30624ebacdd8ac33e4e6aab3ff90b6ba445a84689386b9e945d01823a65874444316e83767290fcff630d2477f49d5d8ffdd200e08ee1274270f86ed14c687895f6caf5ce528bd970c20d2408a9ba66216324c6a011ac4999098362dbd98a038129a2d40c8da6ab88318aa3046cb660327cc44236d9e5d2163bd0959062195c51ed93d0088b6f92051fc99050ece2538749165976233697ab4b610385366e5ce0b02ad6b61c168ecfbedcdf74278a38de340fd7a5fead8e588e294795f9b011e2e60377a89e25c90e145397cdeabc60fd32444a6b7642a611a83c464d8b8976666351b4865c37b02e6dc21dbcdf5f930341707b618cc0f03c3122646b3385c9df9f2ec730eec9d49e7dfc9153b6e6289da8c4f0ebea9ccc1b751948e3bb7171c9e4d57423b0eeeb79095c030cb52677b3f7e0b45c30f645391f3f9c957afa549c4e0b2465b03c67993cd200b1af01035962edbc4c9e89b31c82ac121987d6529dafdeef67a132dc04b6dc68e77f22862040b75e2ceb9ff16da0fca534e6db7bd12fa7b7f51b6c08c1e23dfcdb7acbd2da0b51c87ffbced065a612e9b1c8bba9b7e2d8d7a2f04fcc4aaf355b60d764879a76b5e16762d5f2f55d585d0c8e82df6940960cddfb72c91dfa71f6b4e1c6ca25dfc39a878e998a663c04fe29d5e83b9586d047b4d7ff70a9f0d44f127e7d741685ca75f11629128d916a0ffef4be586a30c4b70389cc746e84ebf177c01ee8a4511cfbb9d1ecf7f7b33c7dd8177896e10bbc82f838dcd6db7ac67de62bf46b6a640fb580c5d1d2708f3862e3d2b645d0d18e49ef088053e3a220adc0e033c2afcfe61c90e32151152eb3caaf746c5e377d541cafc6cbb0cc0fa48b5caf1728f2e1957f5addfc234f1a9d89e40d49356c9172d0561a695fce6dab1d412321bbf407f63766ffd7b6b3d79bcfa07991c5a9709849c1008689e3b47c50d613980bec239fb64185249d055b30375ccb4354d71fe4d05648fbf6c80634dfc3575f2f24abb714c1e4c95e8896763bf4316e954c7ad19e5780ab7a040ca6fb9271f90a8b22ae738daf6cb", + "Name": "nagydani-5-square", + "Gas": 5461, + "NoBenchmark": false + }, + { + "Input": 
"000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000400c5a1611f8be90071a43db23cc2fe01871cc4c0e8ab5743f6378e4fef77f7f6db0095c0727e20225beb665645403453e325ad5f9aeb9ba99bf3c148f63f9c07cf4fe8847ad5242d6b7d4499f93bd47056ddab8f7dee878fc2314f344dbee2a7c41a5d3db91eff372c730c2fdd3a141a4b61999e36d549b9870cf2f4e632c4d5df5f024f81c028000073a0ed8847cfb0593d36a47142f578f05ccbe28c0c06aeb1b1da027794c48db880278f79ba78ae64eedfea3c07d10e0562668d839749dc95f40467d15cf65b9cfc52c7c4bcef1cda3596dd52631aac942f146c7cebd46065131699ce8385b0db1874336747ee020a5698a3d1a1082665721e769567f579830f9d259cec1a836845109c21cf6b25da572512bf3c42fd4b96e43895589042ab60dd41f497db96aec102087fe784165bb45f942859268fd2ff6c012d9d00c02ba83eace047cc5f7b2c392c2955c58a49f0338d6fc58749c9db2155522ac17914ec216ad87f12e0ee95574613942fa615898c4d9e8a3be68cd6afa4e7a003dedbdf8edfee31162b174f965b20ae752ad89c967b3068b6f722c16b354456ba8e280f987c08e0a52d40a2e8f3a59b94d590aeef01879eb7a90b3ee7d772c839c85519cbeaddc0c193ec4874a463b53fcaea3271d80ebfb39b33489365fc039ae549a17a9ff898eea2f4cb27b8dbee4c17b998438575b2b8d107e4a0d66ba7fca85b41a58a8d51f191a35c856dfbe8aef2b00048a694bbccff832d23c8ca7a7ff0b6c0b3011d00b97c86c0628444d267c951d9e4fb8f83e154b8f74fb51aa16535e498235c5597dac9606ed0be3173a3836baa4e7d756ffe1e2879b415d3846bccd538c05b847785699aefde3e305decb600cd8fb0e7d8de5efc26971a6ad4e6d7a2d91474f1023a0ac4b78dc937da0ce607a45974d2cac1c33a2631ff7fe6144a3b2e5cf98b531a9627dea92c1dc82204d09db0439b6a11dd64b484e1263aa45fd9539b6020b55e3baece3986a8bffc1003406348f5c61265099ed43a766ee4f93f5f9c5abbc32a0fd3ac2b35b87f9ec26037d88275bd7dd0a54474995ee34ed3727f3f97c48db544b1980193a4b76a8a3ddab3591ce527f16d91882e67f0103b5cda53f7da54d489fc4ac08b6ab358a5a04aa9daa16219d50bd672a7cb804ed769d218807544e5993f1c27427104b349906a0b654df0bf69328afd3013fbe430155339c39f236df5557bf92f1ded7ff609a8502f49064ec3d1dbfb6c15d3a4c11a4f8acd12278cbf68acd5709463d12e3338a6eddb8c112f199645e23154a8e60879d2a654e3ed9296aa28f134168619691cd2c6b9e2eba4438381676173fc63c2588a3c5910dc149cf3760f0aa9fa9c3f5faa9162b0bf1aac9dd32b706a60ef53cbdb394b6b40222b5bc80eea82ba8958386672564cae3794f977871ab62337cf03e30049201ec12937e7ce79d0f55d9c810e20acf52212aca1d3888949e0e4830aad88d804161230eb89d4d329cc83570fe257217d2119134048dd2ed167646975fc7d77136919a049ea74cf08ddd2b896890bb24a0ba18094a22baa351bf29ad96c66bbb1a598f2ca391749620e62d61c3561a7d3653ccc8892c7b99baaf76bf836e2991cb06d6bc0514568ff0d1ec8bb4b3d6984f5eaefb17d3ea2893722375d3ddb8e389a8eef7d7d198f8e687d6a513983df906099f9a2d23f4f9dec6f8ef2f11fc0a21fac45353b94e00486f5e17d386af42502d09db33cf0cf28310e049c07e88682aeeb00cb833c5174266e62407a57583f1f88b304b7c6e0c84bbe1c0fd423072d37a5bd0aacf764229e5c7cd02473460ba3645cd8e8ae144065bf02d0dd238593d8e230354f67e0b2f23012c23274f80e3ee31e35e2606a4a3f31d94ab755e6d163cff52cbb36b6d0cc67ffc512aeed1dce4d7a0d70ce82f2baba12e8d514dc92a056f994adfb17b5b9712bd5186f27a2fda1f7039c5df2c8587fdc62f5627580c13234b55be4df3056050e2d1ef3218f0dd66cb05265fe1acfb0989d8213f2c19d1735a7cf3fa65d88dad5af52dc2bba22b7abf46c3bc77b5091baab9e8f0ddc4d5e581037de91a9f8dcbc69309be29cc815cf19a20a7585b8b3073edf51fc9baeb3e509b97fa4ecfd621e0fd57bd61cac1b895c03248ff12bdbc57509250df3517e8a3fe1d776836b34ab352b973d932ef708b14f7418f9eceb1d87667e61e3e758649cb083f01b133d37ab2f5afa96d6c84bcacf4efc3851ad308c1e7d9113624fce29fab460ab9d2a48d92cdb281103a5250ad44cb2ff6e67ac670c02fdafb3e0f1353953d6d7d5646ca1568dea55275a050ec501b7c6250444f7219f1ba7521ba3b93d089727ca5f3bbe
0d6c1300b423377004954c5628fdb65770b18ced5c9b23a4a5a6d6ef25fe01b4ce278de0bcc4ed86e28a0a68818ffa40970128cf2c38740e80037984428c1bd5113f40ff47512ee6f4e4d8f9b8e8e1b3040d2928d003bd1c1329dc885302fbce9fa81c23b4dc49c7c82d29b52957847898676c89aa5d32b5b0e1c0d5a2b79a19d67562f407f19425687971a957375879d90c5f57c857136c17106c9ab1b99d80e69c8c954ed386493368884b55c939b8d64d26f643e800c56f90c01079d7c534e3b2b7ae352cefd3016da55f6a85eb803b85e2304915fd2001f77c74e28746293c46e4f5f0fd49cf988aafd0026b8e7a3bab2da5cdce1ea26c2e29ec03f4807fac432662b2d6c060be1c7be0e5489de69d0a6e03a4b9117f9244b34a0f1ecba89884f781c6320412413a00c4980287409a2a78c2cd7e65cecebbe4ec1c28cac4dd95f6998e78fc6f1392384331c9436aa10e10e2bf8ad2c4eafbcf276aa7bae64b74428911b3269c749338b0fc5075ad", + "Expected": "5f9c70ec884926a89461056ad20ac4c30155e817f807e4d3f5bb743d789c83386762435c3627773fa77da5144451f2a8aad8adba88e0b669f5377c5e9bad70e45c86fe952b613f015a9953b8a5de5eaee4566acf98d41e327d93a35bd5cef4607d025e58951167957df4ff9b1627649d3943805472e5e293d3efb687cfd1e503faafeb2840a3e3b3f85d016051a58e1c9498aab72e63b748d834b31eb05d85dcde65e27834e266b85c75cc4ec0135135e0601cb93eeeb6e0010c8ceb65c4c319623c5e573a2c8c9fbbf7df68a930beb412d3f4dfd146175484f45d7afaa0d2e60684af9b34730f7c8438465ad3e1d0c3237336722f2aa51095bd5759f4b8ab4dda111b684aa3dac62a761722e7ae43495b7709933512c81c4e3c9133a51f7ce9f2b51fcec064f65779666960b4e45df3900f54311f5613e8012dd1b8efd359eda31a778264c72aa8bb419d862734d769076bce2810011989a45374e5c5d8729fec21427f0bf397eacbb4220f603cf463a4b0c94efd858ffd9768cd60d6ce68d755e0fbad007ce5c2223d70c7018345a102e4ab3c60a13a9e7794303156d4c2063e919f2153c13961fb324c80b240742f47773a7a8e25b3e3fb19b00ce839346c6eb3c732fbc6b888df0b1fe0a3d07b053a2e9402c267b2d62f794d8a2840526e3ade15ce2264496ccd7519571dfde47f7a4bb16292241c20b2be59f3f8fb4f6383f232d838c5a22d8c95b6834d9d2ca493f5a505ebe8899503b0e8f9b19e6e2dd81c1628b80016d02097e0134de51054c4e7674824d4d758760fc52377d2cad145e259aa2ffaf54139e1a66b1e0c1c191e32ac59474c6b526f5b3ba07d3e5ec286eddf531fcd5292869be58c9f22ef91026159f7cf9d05ef66b4299f4da48cc1635bf2243051d342d378a22c83390553e873713c0454ce5f3234397111ac3fe3207b86f0ed9fc025c81903e1748103692074f83824fda6341be4f95ff00b0a9a208c267e12fa01825054cc0513629bf3dbb56dc5b90d4316f87654a8be18227978ea0a8a522760cad620d0d14fd38920fb7321314062914275a5f99f677145a6979b156bd82ecd36f23f8e1273cc2759ecc0b2c69d94dad5211d1bed939dd87ed9e07b91d49713a6e16ade0a98aea789f04994e318e4ff2c8a188cd8d43aeb52c6daa3bc29b4af50ea82a247c5cd67b573b34cbadcc0a376d3bbd530d50367b42705d870f2e27a8197ef46070528bfe408360faa2ebb8bf76e9f388572842bcb119f4d84ee34ae31f5cc594f23705a49197b181fb78ed1ec99499c690f843a4d0cf2e226d118e9372271054fbabdcc5c92ae9fefaef0589cd0e722eaf30c1703ec4289c7fd81beaa8a455ccee5298e31e2080c10c366a6fcf56f7d13582ad0bcad037c612b710fc595b70fbefaaca23623b60c6c39b11beb8e5843b6b3dac60f", + "Name": "nagydani-5-qube", + "Gas": 5461, + "NoBenchmark": false + }, + { + "Input": 
"000000000000000000000000000000000000000000000000000000000000040000000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000400c5a1611f8be90071a43db23cc2fe01871cc4c0e8ab5743f6378e4fef77f7f6db0095c0727e20225beb665645403453e325ad5f9aeb9ba99bf3c148f63f9c07cf4fe8847ad5242d6b7d4499f93bd47056ddab8f7dee878fc2314f344dbee2a7c41a5d3db91eff372c730c2fdd3a141a4b61999e36d549b9870cf2f4e632c4d5df5f024f81c028000073a0ed8847cfb0593d36a47142f578f05ccbe28c0c06aeb1b1da027794c48db880278f79ba78ae64eedfea3c07d10e0562668d839749dc95f40467d15cf65b9cfc52c7c4bcef1cda3596dd52631aac942f146c7cebd46065131699ce8385b0db1874336747ee020a5698a3d1a1082665721e769567f579830f9d259cec1a836845109c21cf6b25da572512bf3c42fd4b96e43895589042ab60dd41f497db96aec102087fe784165bb45f942859268fd2ff6c012d9d00c02ba83eace047cc5f7b2c392c2955c58a49f0338d6fc58749c9db2155522ac17914ec216ad87f12e0ee95574613942fa615898c4d9e8a3be68cd6afa4e7a003dedbdf8edfee31162b174f965b20ae752ad89c967b3068b6f722c16b354456ba8e280f987c08e0a52d40a2e8f3a59b94d590aeef01879eb7a90b3ee7d772c839c85519cbeaddc0c193ec4874a463b53fcaea3271d80ebfb39b33489365fc039ae549a17a9ff898eea2f4cb27b8dbee4c17b998438575b2b8d107e4a0d66ba7fca85b41a58a8d51f191a35c856dfbe8aef2b00048a694bbccff832d23c8ca7a7ff0b6c0b3011d00b97c86c0628444d267c951d9e4fb8f83e154b8f74fb51aa16535e498235c5597dac9606ed0be3173a3836baa4e7d756ffe1e2879b415d3846bccd538c05b847785699aefde3e305decb600cd8fb0e7d8de5efc26971a6ad4e6d7a2d91474f1023a0ac4b78dc937da0ce607a45974d2cac1c33a2631ff7fe6144a3b2e5cf98b531a9627dea92c1dc82204d09db0439b6a11dd64b484e1263aa45fd9539b6020b55e3baece3986a8bffc1003406348f5c61265099ed43a766ee4f93f5f9c5abbc32a0fd3ac2b35b87f9ec26037d88275bd7dd0a54474995ee34ed3727f3f97c48db544b1980193a4b76a8a3ddab3591ce527f16d91882e67f0103b5cda53f7da54d489fc4ac08b6ab358a5a04aa9daa16219d50bd672a7cb804ed769d218807544e5993f1c27427104b349906a0b654df0bf69328afd3013fbe430155339c39f236df5557bf92f1ded7ff609a8502f49064ec3d1dbfb6c15d3a4c11a4f8acd12278cbf68acd5709463d12e3338a6eddb8c112f199645e23154a8e60879d2a654e3ed9296aa28f134168619691cd2c6b9e2eba4438381676173fc63c2588a3c5910dc149cf3760f0aa9fa9c3f5faa9162b0bf1aac9dd32b706a60ef53cbdb394b6b40222b5bc80eea82ba8958386672564cae3794f977871ab62337cf010001e30049201ec12937e7ce79d0f55d9c810e20acf52212aca1d3888949e0e4830aad88d804161230eb89d4d329cc83570fe257217d2119134048dd2ed167646975fc7d77136919a049ea74cf08ddd2b896890bb24a0ba18094a22baa351bf29ad96c66bbb1a598f2ca391749620e62d61c3561a7d3653ccc8892c7b99baaf76bf836e2991cb06d6bc0514568ff0d1ec8bb4b3d6984f5eaefb17d3ea2893722375d3ddb8e389a8eef7d7d198f8e687d6a513983df906099f9a2d23f4f9dec6f8ef2f11fc0a21fac45353b94e00486f5e17d386af42502d09db33cf0cf28310e049c07e88682aeeb00cb833c5174266e62407a57583f1f88b304b7c6e0c84bbe1c0fd423072d37a5bd0aacf764229e5c7cd02473460ba3645cd8e8ae144065bf02d0dd238593d8e230354f67e0b2f23012c23274f80e3ee31e35e2606a4a3f31d94ab755e6d163cff52cbb36b6d0cc67ffc512aeed1dce4d7a0d70ce82f2baba12e8d514dc92a056f994adfb17b5b9712bd5186f27a2fda1f7039c5df2c8587fdc62f5627580c13234b55be4df3056050e2d1ef3218f0dd66cb05265fe1acfb0989d8213f2c19d1735a7cf3fa65d88dad5af52dc2bba22b7abf46c3bc77b5091baab9e8f0ddc4d5e581037de91a9f8dcbc69309be29cc815cf19a20a7585b8b3073edf51fc9baeb3e509b97fa4ecfd621e0fd57bd61cac1b895c03248ff12bdbc57509250df3517e8a3fe1d776836b34ab352b973d932ef708b14f7418f9eceb1d87667e61e3e758649cb083f01b133d37ab2f5afa96d6c84bcacf4efc3851ad308c1e7d9113624fce29fab460ab9d2a48d92cdb281103a5250ad44cb2ff6e67ac670c02fdafb3e0f1353953d6d7d5646ca1568dea55275a050ec501b7c6250444f7219f1ba7521ba3b93d089727ca5f
3bbe0d6c1300b423377004954c5628fdb65770b18ced5c9b23a4a5a6d6ef25fe01b4ce278de0bcc4ed86e28a0a68818ffa40970128cf2c38740e80037984428c1bd5113f40ff47512ee6f4e4d8f9b8e8e1b3040d2928d003bd1c1329dc885302fbce9fa81c23b4dc49c7c82d29b52957847898676c89aa5d32b5b0e1c0d5a2b79a19d67562f407f19425687971a957375879d90c5f57c857136c17106c9ab1b99d80e69c8c954ed386493368884b55c939b8d64d26f643e800c56f90c01079d7c534e3b2b7ae352cefd3016da55f6a85eb803b85e2304915fd2001f77c74e28746293c46e4f5f0fd49cf988aafd0026b8e7a3bab2da5cdce1ea26c2e29ec03f4807fac432662b2d6c060be1c7be0e5489de69d0a6e03a4b9117f9244b34a0f1ecba89884f781c6320412413a00c4980287409a2a78c2cd7e65cecebbe4ec1c28cac4dd95f6998e78fc6f1392384331c9436aa10e10e2bf8ad2c4eafbcf276aa7bae64b74428911b3269c749338b0fc5075ad", + "Expected": "5a0eb2bdf0ac1cae8e586689fa16cd4b07dfdedaec8a110ea1fdb059dd5253231b6132987598dfc6e11f86780428982d50cf68f67ae452622c3b336b537ef3298ca645e8f89ee39a26758206a5a3f6409afc709582f95274b57b71fae5c6b74619ae6f089a5393c5b79235d9caf699d23d88fb873f78379690ad8405e34c19f5257d596580c7a6a7206a3712825afe630c76b31cdb4a23e7f0632e10f14f4e282c81a66451a26f8df2a352b5b9f607a7198449d1b926e27036810368e691a74b91c61afa73d9d3b99453e7c8b50fd4f09c039a2f2feb5c419206694c31b92df1d9586140cb3417b38d0c503c7b508cc2ed12e813a1c795e9829eb39ee78eeaf360a169b491a1d4e419574e712402de9d48d54c1ae5e03739b7156615e8267e1fb0a897f067afd11fb33f6e24182d7aaaaa18fe5bc1982f20d6b871e5a398f0f6f718181d31ec225cfa9a0a70124ed9a70031bdf0c1c7829f708b6e17d50419ef361cf77d99c85f44607186c8d683106b8bd38a49b5d0fb503b397a83388c5678dcfcc737499d84512690701ed621a6f0172aecf037184ddf0f2453e4053024018e5ab2e30d6d5363b56e8b41509317c99042f517247474ab3abc848e00a07f69c254f46f2a05cf6ed84e5cc906a518fdcfdf2c61ce731f24c5264f1a25fc04934dc28aec112134dd523f70115074ca34e3807aa4cb925147f3a0ce152d323bd8c675ace446d0fd1ae30c4b57f0eb2c23884bc18f0964c0114796c5b6d080c3d89175665fbf63a6381a6a9da39ad070b645c8bb1779506da14439a9f5b5d481954764ea114fac688930bc68534d403cff4210673b6a6ff7ae416b7cd41404c3d3f282fcd193b86d0f54d0006c2a503b40d5c3930da980565b8f9630e9493a79d1c03e74e5f93ac8e4dc1a901ec5e3b3e57049124c7b72ea345aa359e782285d9e6a5c144a378111dd02c40855ff9c2be9b48425cb0b2fd62dc8678fd151121cf26a65e917d65d8e0dacfae108eb5508b601fb8ffa370be1f9a8b749a2d12eeab81f41079de87e2d777994fa4d28188c579ad327f9957fb7bdecec5c680844dd43cb57cf87aeb763c003e65011f73f8c63442df39a92b946a6bd968a1c1e4d5fa7d88476a68bd8e20e5b70a99259c7d3f85fb1b65cd2e93972e6264e74ebf289b8b6979b9b68a85cd5b360c1987f87235c3c845d62489e33acf85d53fa3561fe3a3aee18924588d9c6eba4edb7a4d106b31173e42929f6f0c48c80ce6a72d54eca7c0fe870068b7a7c89c63cdda593f5b32d3cb4ea8a32c39f00ab449155757172d66763ed9527019d6de6c9f2416aa6203f4d11c9ebee1e1d3845099e55504446448027212616167eb36035726daa7698b075286f5379cd3e93cb3e0cf4f9cb8d017facbb5550ed32d5ec5400ae57e47e2bf78d1eaeff9480cc765ceff39db500", + "Name": "nagydani-5-pow0x10001", + "Gas": 87381, + "NoBenchmark": false + } +] \ No newline at end of file diff --git a/crypto/bls12381/fp_test.go b/crypto/bls12381/fp_test.go index 14bb4d7d6554..97528d9db32e 100644 --- a/crypto/bls12381/fp_test.go +++ b/crypto/bls12381/fp_test.go @@ -1393,6 +1393,15 @@ func BenchmarkMultiplication(t *testing.B) { } } +func BenchmarkInverse(t *testing.B) { + a, _ := new(fe).rand(rand.Reader) + b, _ := new(fe).rand(rand.Reader) + t.ResetTimer() + for i := 0; i < t.N; i++ { + inverse(a, b) + } +} + func padBytes(in []byte, size int) []byte { out := make([]byte, size) if len(in) > size { diff --git a/crypto/bls12381/g1.go b/crypto/bls12381/g1.go index 63d38f90c2f7..d853823cd298 100644 --- 
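Each "Input" above uses the EIP-198 modexp call format: three 32-byte big-endian lengths (base, exponent, modulus) followed by the operands themselves, and "Expected" is base^exp mod m left-padded to the modulus length. A minimal sketch for re-checking a vector offline with math/big (the package and helper names are illustrative, not part of the test suite; it assumes a well-formed vector):

package modexpcheck

import (
	"bytes"
	"encoding/hex"
	"math/big"
)

// verifyVector re-computes one modexp vector. The input layout is
// <32-byte baseLen><32-byte expLen><32-byte modLen><base><exp><mod>.
func verifyVector(input, expected string) bool {
	in, err := hex.DecodeString(input)
	if err != nil {
		return false
	}
	baseLen := new(big.Int).SetBytes(in[0:32]).Uint64()
	expLen := new(big.Int).SetBytes(in[32:64]).Uint64()
	modLen := new(big.Int).SetBytes(in[64:96]).Uint64()

	operands := in[96:]
	base := new(big.Int).SetBytes(operands[:baseLen])
	exp := new(big.Int).SetBytes(operands[baseLen : baseLen+expLen])
	mod := new(big.Int).SetBytes(operands[baseLen+expLen : baseLen+expLen+modLen])

	// The precompile returns base^exp mod m, left-padded to the modulus length.
	out := new(big.Int).Exp(base, exp, mod).FillBytes(make([]byte, modLen))
	want, _ := hex.DecodeString(expected)
	return bytes.Equal(out, want)
}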
a/crypto/bls12381/g1.go
+++ b/crypto/bls12381/g1.go
@@ -266,9 +266,8 @@ func (g *G1) Add(r, p1, p2 *PointG1) *PointG1 {
 	if t[1].equal(t[3]) {
 		if t[0].equal(t[2]) {
 			return g.Double(r, p1)
-		} else {
-			return r.Zero()
 		}
+		return r.Zero()
 	}
 	sub(t[1], t[1], t[3])
 	double(t[4], t[1])
diff --git a/crypto/bls12381/g2.go b/crypto/bls12381/g2.go
index 1d1a3258f7c7..fa110e3edfc5 100644
--- a/crypto/bls12381/g2.go
+++ b/crypto/bls12381/g2.go
@@ -287,9 +287,8 @@ func (g *G2) Add(r, p1, p2 *PointG2) *PointG2 {
 	if t[1].equal(t[3]) {
 		if t[0].equal(t[2]) {
 			return g.Double(r, p1)
-		} else {
-			return r.Zero()
 		}
+		return r.Zero()
 	}
 	g.f.sub(t[1], t[1], t[3])
 	g.f.double(t[4], t[1])
diff --git a/crypto/bn256/bn256_fuzz.go b/crypto/bn256/bn256_fuzz.go
index 6aa142117045..b34043487f20 100644
--- a/crypto/bn256/bn256_fuzz.go
+++ b/crypto/bn256/bn256_fuzz.go
@@ -8,42 +8,52 @@ package bn256

 import (
 	"bytes"
+	"fmt"
+	"io"
 	"math/big"

 	cloudflare "github.com/ethereum/go-ethereum/crypto/bn256/cloudflare"
 	google "github.com/ethereum/go-ethereum/crypto/bn256/google"
 )

-// FuzzAdd fuzzez bn256 addition between the Google and Cloudflare libraries.
-func FuzzAdd(data []byte) int {
-	// Ensure we have enough data in the first place
-	if len(data) != 128 {
-		return 0
+func getG1Points(input io.Reader) (*cloudflare.G1, *google.G1) {
+	_, xc, err := cloudflare.RandomG1(input)
+	if err != nil {
+		// insufficient input
+		return nil, nil
 	}
-	// Ensure both libs can parse the first curve point
-	xc := new(cloudflare.G1)
-	_, errc := xc.Unmarshal(data[:64])
-
-	xg := new(google.G1)
-	_, errg := xg.Unmarshal(data[:64])
-
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
-		return 0
+	xg := new(google.G1)
+	if _, err := xg.Unmarshal(xc.Marshal()); err != nil {
+		panic(fmt.Sprintf("Could not marshal cloudflare -> google: %v", err))
 	}
-	// Ensure both libs can parse the second curve point
-	yc := new(cloudflare.G1)
-	_, errc = yc.Unmarshal(data[64:])
+	return xc, xg
+}

-	yg := new(google.G1)
-	_, errg = yg.Unmarshal(data[64:])
+func getG2Points(input io.Reader) (*cloudflare.G2, *google.G2) {
+	_, xc, err := cloudflare.RandomG2(input)
+	if err != nil {
+		// insufficient input
+		return nil, nil
+	}
+	xg := new(google.G2)
+	if _, err := xg.Unmarshal(xc.Marshal()); err != nil {
+		panic(fmt.Sprintf("Could not marshal cloudflare -> google: %v", err))
+	}
+	return xc, xg
+}

-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
+// FuzzAdd fuzzes bn256 addition between the Google and Cloudflare libraries.
+func FuzzAdd(data []byte) int {
+	input := bytes.NewReader(data)
+	xc, xg := getG1Points(input)
+	if xc == nil {
+		return 0
+	}
+	yc, yg := getG1Points(input)
+	if yc == nil {
 		return 0
 	}
 	// Add the two points and ensure they result in the same output
 	rc := new(cloudflare.G1)
 	rc.Add(xc, yc)
@@ -54,73 +64,56 @@
 	if !bytes.Equal(rc.Marshal(), rg.Marshal()) {
 		panic("add mismatch")
 	}
-	return 0
+	return 1
 }

 // FuzzMul fuzzes bn256 scalar multiplication between the Google and Cloudflare
 // libraries.
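 // The fuzz input is consumed as an entropy stream: getG1Points derives a
 // valid curve point from the stream, and whatever bytes remain become the
 // scalar, capped in size below to keep fuzzing executions fast.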
func FuzzMul(data []byte) int {
-	// Ensure we have enough data in the first place
-	if len(data) != 96 {
+	input := bytes.NewReader(data)
+	pc, pg := getG1Points(input)
+	if pc == nil {
 		return 0
 	}
-	// Ensure both libs can parse the curve point
-	pc := new(cloudflare.G1)
-	_, errc := pc.Unmarshal(data[:64])
-
-	pg := new(google.G1)
-	_, errg := pg.Unmarshal(data[:64])
-
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
+	// Read the scalar from whatever input remains
+	remaining := input.Len()
+	if remaining == 0 {
 		return 0
 	}
-	// Add the two points and ensure they result in the same output
+	if remaining > 128 {
+		// The evm only ever uses 32 byte integers, we need to cap this otherwise
+		// we run into slow exec. A 236Kb integer caused oss-fuzz to report it as slow.
+		// 128 bytes should be fine though
+		return 0
+	}
+	buf := make([]byte, remaining)
+	input.Read(buf)
+
 	rc := new(cloudflare.G1)
-	rc.ScalarMult(pc, new(big.Int).SetBytes(data[64:]))
+	rc.ScalarMult(pc, new(big.Int).SetBytes(buf))

 	rg := new(google.G1)
-	rg.ScalarMult(pg, new(big.Int).SetBytes(data[64:]))
+	rg.ScalarMult(pg, new(big.Int).SetBytes(buf))

 	if !bytes.Equal(rc.Marshal(), rg.Marshal()) {
 		panic("scalar mul mismatch")
 	}
-	return 0
+	return 1
 }

 func FuzzPair(data []byte) int {
-	// Ensure we have enough data in the first place
-	if len(data) != 192 {
+	input := bytes.NewReader(data)
+	pc, pg := getG1Points(input)
+	if pc == nil {
 		return 0
 	}
-	// Ensure both libs can parse the curve point
-	pc := new(cloudflare.G1)
-	_, errc := pc.Unmarshal(data[:64])
-
-	pg := new(google.G1)
-	_, errg := pg.Unmarshal(data[:64])
-
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
-		return 0
-	}
-	// Ensure both libs can parse the twist point
-	tc := new(cloudflare.G2)
-	_, errc = tc.Unmarshal(data[64:])
-
-	tg := new(google.G2)
-	_, errg = tg.Unmarshal(data[64:])
-
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
+	tc, tg := getG2Points(input)
+	if tc == nil {
 		return 0
 	}
 	// Pair the two points and ensure they result in the same output
 	if cloudflare.PairingCheck([]*cloudflare.G1{pc}, []*cloudflare.G2{tc}) != google.PairingCheck([]*google.G1{pg}, []*google.G2{tg}) {
 		panic("pair mismatch")
 	}
-	return 0
+	return 1
 }
diff --git a/crypto/bn256/cloudflare/bn256.go b/crypto/bn256/cloudflare/bn256.go
index 38822a76bfec..4f607af2ad35 100644
--- a/crypto/bn256/cloudflare/bn256.go
+++ b/crypto/bn256/cloudflare/bn256.go
@@ -9,8 +9,13 @@
 //
 // This package specifically implements the Optimal Ate pairing over a 256-bit
 // Barreto-Naehrig curve as described in
-// http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is compatible
-// with the implementation described in that paper.
+// http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is not
+// compatible with the implementation described in that paper, as different
+// parameters are chosen.
+//
+// (This package previously claimed to operate at a 128-bit security level.
+// However, recent improvements in attacks mean that is no longer true. See
+// https://moderncrypto.org/mail-archive/curves/2016/000740.html.)
package bn256

 import (
@@ -23,7 +28,7 @@ import (
 func randomK(r io.Reader) (k *big.Int, err error) {
 	for {
 		k, err = rand.Int(r, Order)
-		if k.Sign() > 0 || err != nil {
+		if err != nil || k.Sign() > 0 {
 			return
 		}
 	}
diff --git a/crypto/bn256/google/bn256.go b/crypto/bn256/google/bn256.go
index e0402e51f0e2..0a9d5cd35dce 100644
--- a/crypto/bn256/google/bn256.go
+++ b/crypto/bn256/google/bn256.go
@@ -12,8 +12,9 @@
 //
 // This package specifically implements the Optimal Ate pairing over a 256-bit
 // Barreto-Naehrig curve as described in
-// http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is compatible
-// with the implementation described in that paper.
+// http://cryptojedi.org/papers/dclxvi-20100714.pdf. Its output is not
+// compatible with the implementation described in that paper, as different
+// parameters are chosen.
 //
 // (This package previously claimed to operate at a 128-bit security level.
 // However, recent improvements in attacks mean that is no longer true. See
diff --git a/crypto/bn256/google/constants.go b/crypto/bn256/google/constants.go
index ab649d7f3f7c..2990bd951282 100644
--- a/crypto/bn256/google/constants.go
+++ b/crypto/bn256/google/constants.go
@@ -13,13 +13,16 @@ func bigFromBase10(s string) *big.Int {
 	return n
 }

-// u is the BN parameter that determines the prime: 1868033³.
+// u is the BN parameter that determines the prime.
 var u = bigFromBase10("4965661367192848881")

-// p is a prime over which we form a basic field: 36u⁴+36u³+24u²+6u+1.
+// P is a prime over which we form a basic field: 36u⁴+36u³+24u²+6u+1.
 var P = bigFromBase10("21888242871839275222246405745257275088696311157297823662689037894645226208583")

 // Order is the number of elements in both G₁ and G₂: 36u⁴+36u³+18u²+6u+1.
+// Needs to be highly 2-adic for efficient SNARK key and proof generation.
+// Order - 1 = 2^28 * 3^2 * 13 * 29 * 983 * 11003 * 237073 * 405928799 * 1670836401704629 * 13818364434197438864469338081.
+// Refer to https://eprint.iacr.org/2013/879.pdf and https://eprint.iacr.org/2013/507.pdf for more information on these parameters.
 var Order = bigFromBase10("21888242871839275222246405745257275088548364400416034343698204186575808495617")

 // xiToPMinus1Over6 is ξ^((p-1)/6) where ξ = i+9.
diff --git a/crypto/secp256k1/curve.go b/crypto/secp256k1/curve.go
index 5409ee1d2cdb..8f83cccad9d1 100644
--- a/crypto/secp256k1/curve.go
+++ b/crypto/secp256k1/curve.go
@@ -116,6 +116,10 @@ func (BitCurve *BitCurve) IsOnCurve(x, y *big.Int) bool {
 // affineFromJacobian reverses the Jacobian transform. See the comment at the
 // top of the file.
 func (BitCurve *BitCurve) affineFromJacobian(x, y, z *big.Int) (xOut, yOut *big.Int) {
+	if z.Sign() == 0 {
+		return new(big.Int), new(big.Int)
+	}
+
 	zinv := new(big.Int).ModInverse(z, BitCurve.P)
 	zinvsq := new(big.Int).Mul(zinv, zinv)
diff --git a/crypto/secp256k1/dummy.go b/crypto/secp256k1/dummy.go
new file mode 100644
index 000000000000..c0f2ee52c0c9
--- /dev/null
+++ b/crypto/secp256k1/dummy.go
@@ -0,0 +1,20 @@
+// +build dummy

+// This file is part of a workaround for `go mod vendor` which won't vendor
+// C files if there's no Go file in the same directory.
+// This would prevent the crypto/secp256k1/libsecp256k1/include/secp256k1.h file from being vendored.
+//
+// This Go file imports the c directory where there is another dummy.go file which
+// is the second part of this workaround.
+//
+// These two files combined make it so `go mod vendor` behaves correctly.
+//
+// See this issue for reference: https://github.com/golang/go/issues/26366
+
+package secp256k1
+
+import (
+	_ "github.com/ethereum/go-ethereum/crypto/secp256k1/libsecp256k1/include"
+	_ "github.com/ethereum/go-ethereum/crypto/secp256k1/libsecp256k1/src"
+	_ "github.com/ethereum/go-ethereum/crypto/secp256k1/libsecp256k1/src/modules/recovery"
+)
diff --git a/crypto/secp256k1/libsecp256k1/contrib/dummy.go b/crypto/secp256k1/libsecp256k1/contrib/dummy.go
new file mode 100644
index 000000000000..fda594be9914
--- /dev/null
+++ b/crypto/secp256k1/libsecp256k1/contrib/dummy.go
@@ -0,0 +1,7 @@
+// +build dummy

+// Package c contains only a C file.
+//
+// This Go file is part of a workaround for `go mod vendor`.
+// Please see the file crypto/secp256k1/dummy.go for more information.
+package contrib
diff --git a/crypto/secp256k1/libsecp256k1/dummy.go b/crypto/secp256k1/libsecp256k1/dummy.go
new file mode 100644
index 000000000000..379b16992f47
--- /dev/null
+++ b/crypto/secp256k1/libsecp256k1/dummy.go
@@ -0,0 +1,7 @@
+// +build dummy

+// Package c contains only a C file.
+//
+// This Go file is part of a workaround for `go mod vendor`.
+// Please see the file crypto/secp256k1/dummy.go for more information.
+package libsecp256k1
diff --git a/crypto/secp256k1/libsecp256k1/include/dummy.go b/crypto/secp256k1/libsecp256k1/include/dummy.go
new file mode 100644
index 000000000000..5af540c73c4a
--- /dev/null
+++ b/crypto/secp256k1/libsecp256k1/include/dummy.go
@@ -0,0 +1,7 @@
+// +build dummy

+// Package c contains only a C file.
+//
+// This Go file is part of a workaround for `go mod vendor`.
+// Please see the file crypto/secp256k1/dummy.go for more information.
+package include
diff --git a/crypto/secp256k1/libsecp256k1/src/dummy.go b/crypto/secp256k1/libsecp256k1/src/dummy.go
new file mode 100644
index 000000000000..65868f38a8ea
--- /dev/null
+++ b/crypto/secp256k1/libsecp256k1/src/dummy.go
@@ -0,0 +1,7 @@
+// +build dummy

+// Package c contains only a C file.
+//
+// This Go file is part of a workaround for `go mod vendor`.
+// Please see the file crypto/secp256k1/dummy.go for more information.
+package src
diff --git a/crypto/secp256k1/libsecp256k1/src/modules/dummy.go b/crypto/secp256k1/libsecp256k1/src/modules/dummy.go
new file mode 100644
index 000000000000..3c7a696439f0
--- /dev/null
+++ b/crypto/secp256k1/libsecp256k1/src/modules/dummy.go
@@ -0,0 +1,7 @@
+// +build dummy

+// Package c contains only a C file.
+//
+// This Go file is part of a workaround for `go mod vendor`.
+// Please see the file crypto/secp256k1/dummy.go for more information.
+package module
diff --git a/crypto/secp256k1/libsecp256k1/src/modules/ecdh/dummy.go b/crypto/secp256k1/libsecp256k1/src/modules/ecdh/dummy.go
new file mode 100644
index 000000000000..b6fc38327ec8
--- /dev/null
+++ b/crypto/secp256k1/libsecp256k1/src/modules/ecdh/dummy.go
@@ -0,0 +1,7 @@
+// +build dummy

+// Package c contains only a C file.
+//
+// This Go file is part of a workaround for `go mod vendor`.
+// Please see the file crypto/secp256k1/dummy.go for more information.
+package ecdh
diff --git a/crypto/secp256k1/libsecp256k1/src/modules/recovery/dummy.go b/crypto/secp256k1/libsecp256k1/src/modules/recovery/dummy.go
new file mode 100644
index 000000000000..b9491f0cb9f4
--- /dev/null
+++ b/crypto/secp256k1/libsecp256k1/src/modules/recovery/dummy.go
@@ -0,0 +1,7 @@
+// +build dummy

+// Package c contains only a C file.
+//
+// This Go file is part of a workaround for `go mod vendor`.
+// Please see the file crypto/secp256k1/dummy.go for more information.
+package recovery
diff --git a/crypto/signify/signify.go b/crypto/signify/signify.go
new file mode 100644
index 000000000000..7ba9705491cf
--- /dev/null
+++ b/crypto/signify/signify.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

+// signFile reads the contents of an input file and signs it (in armored format)
+// with the key provided, placing the signature into the output file.

+package signify

+import (
+	"bytes"
+	"crypto/ed25519"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"strings"
+	"time"
+)

+var (
+	errInvalidKeyHeader = errors.New("incorrect key header")
+	errInvalidKeyLength = errors.New("invalid key length, expected 104 bytes")
+)

+func parsePrivateKey(key string) (k ed25519.PrivateKey, header []byte, keyNum []byte, err error) {
+	keydata, err := base64.StdEncoding.DecodeString(key)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	if len(keydata) != 104 {
+		return nil, nil, nil, errInvalidKeyLength
+	}
+	if string(keydata[:2]) != "Ed" {
+		return nil, nil, nil, errInvalidKeyHeader
+	}
+	return ed25519.PrivateKey(keydata[40:]), keydata[:2], keydata[32:40], nil
+}

+// SignFile creates a signature of the input file.
+//
+// This accepts base64 keys in the format created by the 'signify' tool.
+// The signature is written to the 'output' file.
+func SignFile(input string, output string, key string, untrustedComment string, trustedComment string) error {
+	// Pre-check comments and ensure they're set to something.
+	if strings.IndexByte(untrustedComment, '\n') >= 0 {
+		return errors.New("untrusted comment must not contain newline")
+	}
+	if strings.IndexByte(trustedComment, '\n') >= 0 {
+		return errors.New("trusted comment must not contain newline")
+	}
+	if untrustedComment == "" {
+		untrustedComment = "verify with " + input + ".pub"
+	}
+	if trustedComment == "" {
+		trustedComment = fmt.Sprintf("timestamp:%d", time.Now().Unix())
+	}

+	filedata, err := ioutil.ReadFile(input)
+	if err != nil {
+		return err
+	}
+	skey, header, keyNum, err := parsePrivateKey(key)
+	if err != nil {
+		return err
+	}

+	// Create the main data signature.
+	rawSig := ed25519.Sign(skey, filedata)
+	var dataSig []byte
+	dataSig = append(dataSig, header...)
+	dataSig = append(dataSig, keyNum...)
+	dataSig = append(dataSig, rawSig...)

+	// Create the comment signature.
+	var commentSigInput []byte
+	commentSigInput = append(commentSigInput, rawSig...)
+	commentSigInput = append(commentSigInput, []byte(trustedComment)...)
+	commentSig := ed25519.Sign(skey, commentSigInput)

+	// Create the output file.
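+	// (A signify signature file is four lines: the untrusted comment, the
+	// base64 data signature, the trusted comment, and a base64 signature
+	// computed over the raw data signature followed by the trusted comment.)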
+	var out = new(bytes.Buffer)
+	fmt.Fprintln(out, "untrusted comment:", untrustedComment)
+	fmt.Fprintln(out, base64.StdEncoding.EncodeToString(dataSig))
+	fmt.Fprintln(out, "trusted comment:", trustedComment)
+	fmt.Fprintln(out, base64.StdEncoding.EncodeToString(commentSig))
+	return ioutil.WriteFile(output, out.Bytes(), 0644)
+}
diff --git a/crypto/signify/signify_fuzz.go b/crypto/signify/signify_fuzz.go
new file mode 100644
index 000000000000..d1bcf356a4db
--- /dev/null
+++ b/crypto/signify/signify_fuzz.go
@@ -0,0 +1,148 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

+// +build gofuzz

+package signify

+import (
+	"bufio"
+	"fmt"
+	"io/ioutil"
+	"log"
+	"os"
+	"os/exec"

+	fuzz "github.com/google/gofuzz"
+	"github.com/jedisct1/go-minisign"
+)

+func Fuzz(data []byte) int {
+	if len(data) < 32 {
+		return -1
+	}
+	tmpFile, err := ioutil.TempFile("", "")
+	if err != nil {
+		panic(err)
+	}
+	defer os.Remove(tmpFile.Name())
+	defer tmpFile.Close()

+	testSecKey, testPubKey := createKeyPair()
+	// Create message
+	tmpFile.Write(data)
+	if err = tmpFile.Close(); err != nil {
+		panic(err)
+	}
+	// Fuzz comments
+	var untrustedComment string
+	var trustedComment string
+	f := fuzz.NewFromGoFuzz(data)
+	f.Fuzz(&untrustedComment)
+	f.Fuzz(&trustedComment)
+	fmt.Printf("untrusted: %v\n", untrustedComment)
+	fmt.Printf("trusted: %v\n", trustedComment)

+	err = SignFile(tmpFile.Name(), tmpFile.Name()+".sig", testSecKey, untrustedComment, trustedComment)
+	if err != nil {
+		panic(err)
+	}
+	defer os.Remove(tmpFile.Name() + ".sig")

+	signify := "signify"
+	path := os.Getenv("SIGNIFY")
+	if path != "" {
+		signify = path
+	}

+	_, err = exec.LookPath(signify)
+	if err != nil {
+		panic(err)
+	}

+	// Write the public key into the file to pass it as
+	// an argument to signify-openbsd
+	pubKeyFile, err := ioutil.TempFile("", "")
+	if err != nil {
+		panic(err)
+	}
+	defer os.Remove(pubKeyFile.Name())
+	defer pubKeyFile.Close()
+	pubKeyFile.WriteString("untrusted comment: signify public key\n")
+	pubKeyFile.WriteString(testPubKey)
+	pubKeyFile.WriteString("\n")

+	cmd := exec.Command(signify, "-V", "-p", pubKeyFile.Name(), "-x", tmpFile.Name()+".sig", "-m", tmpFile.Name())
+	if output, err := cmd.CombinedOutput(); err != nil {
+		panic(fmt.Sprintf("could not verify the file: %v, output: \n%s", err, output))
+	}

+	// Verify the signature using a golang library
+	sig, err := minisign.NewSignatureFromFile(tmpFile.Name() + ".sig")
+	if err != nil {
+		panic(err)
+	}

+	pKey, err := minisign.NewPublicKey(testPubKey)
+	if err != nil {
+		panic(err)
+	}

+	valid, err := pKey.VerifyFromFile(tmpFile.Name(), sig)
+	if err != nil {
+		panic(err)
+	}
+	if !valid {
+		panic("invalid signature")
+	}
+	return 1
+}

+func getKey(fileS string) (string, error) {
+	file, err := os.Open(fileS)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer file.Close()

+	scanner := bufio.NewScanner(file)
+	// Discard the first line and return the second one, which holds the key
+	scanner.Scan()
+	scanner.Scan()
+	return scanner.Text(), scanner.Err()
+}

+func createKeyPair() (string, string) {
+	// Create key and put it in correct format
+	tmpKey, err := ioutil.TempFile("", "")
+	if err != nil {
+		panic(err)
+	}
+	defer os.Remove(tmpKey.Name())
+	defer os.Remove(tmpKey.Name() + ".pub")
+	defer os.Remove(tmpKey.Name() + ".sec")
+	cmd := exec.Command("signify", "-G", "-n", "-p", tmpKey.Name()+".pub", "-s", tmpKey.Name()+".sec")
+	if output, err := cmd.CombinedOutput(); err != nil {
+		panic(fmt.Sprintf("could not generate the key pair: %v, output: \n%s", err, output))
+	}
+	secKey, err := getKey(tmpKey.Name() + ".sec")
+	if err != nil {
+		panic(err)
+	}
+	pubKey, err := getKey(tmpKey.Name() + ".pub")
+	if err != nil {
+		panic(err)
+	}
+	return secKey, pubKey
+}
diff --git a/crypto/signify/signify_test.go b/crypto/signify/signify_test.go
new file mode 100644
index 000000000000..615d4e652792
--- /dev/null
+++ b/crypto/signify/signify_test.go
@@ -0,0 +1,154 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

+// signFile reads the contents of an input file and signs it (in armored format)
+// with the key provided, placing the signature into the output file.
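+//
+// The testSecKey/testPubKey pair below is a fixed, throwaway signify key
+// used only by these tests.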
+
+package signify
+
+import (
+	"io/ioutil"
+	"math/rand"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/jedisct1/go-minisign"
+)
+
+var (
+	testSecKey = "RWRCSwAAAABVN5lr2JViGBN8DhX3/Qb/0g0wBdsNAR/APRW2qy9Fjsfr12sK2cd3URUFis1jgzQzaoayK8x4syT4G3Gvlt9RwGIwUYIQW/0mTeI+ECHu1lv5U4Wa2YHEPIesVPyRm5M="
+	testPubKey = "RWTAPRW2qy9FjsBiMFGCEFv9Jk3iPhAh7tZb+VOFmtmBxDyHrFT8kZuT"
+)
+
+func TestSignify(t *testing.T) {
+	tmpFile, err := ioutil.TempFile("", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(tmpFile.Name())
+	defer tmpFile.Close()
+
+	rand.Seed(time.Now().UnixNano())
+
+	data := make([]byte, 1024)
+	rand.Read(data)
+	tmpFile.Write(data)
+
+	if err = tmpFile.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	err = SignFile(tmpFile.Name(), tmpFile.Name()+".sig", testSecKey, "clé", "croissants")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(tmpFile.Name() + ".sig")
+
+	// Verify the signature using a golang library
+	sig, err := minisign.NewSignatureFromFile(tmpFile.Name() + ".sig")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	pKey, err := minisign.NewPublicKey(testPubKey)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	valid, err := pKey.VerifyFromFile(tmpFile.Name(), sig)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !valid {
+		t.Fatal("invalid signature")
+	}
+}
+
+func TestSignifyTrustedCommentTooManyLines(t *testing.T) {
+	tmpFile, err := ioutil.TempFile("", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(tmpFile.Name())
+	defer tmpFile.Close()
+
+	rand.Seed(time.Now().UnixNano())
+
+	data := make([]byte, 1024)
+	rand.Read(data)
+	tmpFile.Write(data)
+
+	if err = tmpFile.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	err = SignFile(tmpFile.Name(), tmpFile.Name()+".sig", testSecKey, "", "crois\nsants")
+	if err == nil || err.Error() == "" {
+		t.Fatalf("should have errored on a multi-line trusted comment, got %v", err)
+	}
+	defer os.Remove(tmpFile.Name() + ".sig")
+}
+
+func TestSignifyTrustedCommentTooManyLinesLF(t *testing.T) {
+	tmpFile, err := ioutil.TempFile("", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(tmpFile.Name())
+	defer tmpFile.Close()
+
+	rand.Seed(time.Now().UnixNano())
+
+	data := make([]byte, 1024)
+	rand.Read(data)
+	tmpFile.Write(data)
+
+	if err = tmpFile.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	err = SignFile(tmpFile.Name(), tmpFile.Name()+".sig", testSecKey, "crois\rsants", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(tmpFile.Name() + ".sig")
+}
+
+func TestSignifyTrustedCommentEmpty(t *testing.T) {
+	tmpFile, err := ioutil.TempFile("", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(tmpFile.Name())
+	defer tmpFile.Close()
+
+	rand.Seed(time.Now().UnixNano())
+
+	data := make([]byte, 1024)
+	rand.Read(data)
+	tmpFile.Write(data)
+
+	if err = tmpFile.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	err = SignFile(tmpFile.Name(), tmpFile.Name()+".sig", testSecKey, "", "")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(tmpFile.Name() + ".sig")
+}
diff --git a/eth/api.go b/eth/api.go
index 76118e2d7fc6..fd3565647647 100644
--- a/eth/api.go
+++ b/eth/api.go
@@ -389,6 +389,8 @@ func (api *PublicDebugAPI) AccountRange(blockNrOrHash rpc.BlockNumberOrHash, sta
 		if err != nil {
 			return state.IteratorDump{}, err
 		}
+	} else {
+		return state.IteratorDump{}, errors.New("either block number or block hash must be specified")
 	}

 	if maxResults > AccountRangeMaxResults || maxResults <= 0 {
diff --git a/eth/api_backend.go b/eth/api_backend.go
index 0e91691d8f51..2f7020475f12 100644
--- a/eth/api_backend.go
+++ b/eth/api_backend.go
-56,7 +56,7 @@ func (b *EthAPIBackend) CurrentBlock() *types.Block { } func (b *EthAPIBackend) SetHead(number uint64) { - b.eth.protocolManager.downloader.Cancel() + b.eth.handler.downloader.Cancel() b.eth.blockchain.SetHead(number) } @@ -194,8 +194,9 @@ func (b *EthAPIBackend) GetTd(ctx context.Context, hash common.Hash) *big.Int { func (b *EthAPIBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header) (*vm.EVM, func() error, error) { vmError := func() error { return nil } - context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil) - return vm.NewEVM(context, state, b.eth.blockchain.Config(), *b.eth.blockchain.GetVMConfig()), vmError, nil + txContext := core.NewEVMTxContext(msg) + context := core.NewEVMBlockContext(header, b.eth.BlockChain(), nil) + return vm.NewEVM(context, txContext, state, b.eth.blockchain.Config(), *b.eth.blockchain.GetVMConfig()), vmError, nil } func (b *EthAPIBackend) SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEvent) event.Subscription { @@ -271,10 +272,6 @@ func (b *EthAPIBackend) Downloader() *downloader.Downloader { return b.eth.Downloader() } -func (b *EthAPIBackend) ProtocolVersion() int { - return b.eth.EthVersion() -} - func (b *EthAPIBackend) SuggestPrice(ctx context.Context) (*big.Int, error) { return b.gpo.SuggestPrice(ctx) } diff --git a/eth/api_test.go b/eth/api_test.go index 42f71e261e60..b44eed40bcd2 100644 --- a/eth/api_test.go +++ b/eth/api_test.go @@ -57,8 +57,10 @@ func (h resultHash) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h resultHash) Less(i, j int) bool { return bytes.Compare(h[i].Bytes(), h[j].Bytes()) < 0 } func TestAccountRange(t *testing.T) { + t.Parallel() + var ( - statedb = state.NewDatabase(rawdb.NewMemoryDatabase()) + statedb = state.NewDatabaseWithConfig(rawdb.NewMemoryDatabase(), nil) state, _ = state.New(common.Hash{}, statedb, nil) addrs = [AccountRangeMaxResults * 2]common.Address{} m = map[common.Address]bool{} @@ -126,6 +128,8 @@ func TestAccountRange(t *testing.T) { } func TestEmptyAccountRange(t *testing.T) { + t.Parallel() + var ( statedb = state.NewDatabase(rawdb.NewMemoryDatabase()) state, _ = state.New(common.Hash{}, statedb, nil) @@ -142,6 +146,8 @@ func TestEmptyAccountRange(t *testing.T) { } func TestStorageRangeAt(t *testing.T) { + t.Parallel() + // Create a state where account 0x010000... has a few storage entries. var ( state, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) diff --git a/eth/api_tracer.go b/eth/api_tracer.go index 748280951c4a..c68b762255f4 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -38,6 +38,7 @@ import ( "github.com/ethereum/go-ethereum/eth/tracers" "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/trie" @@ -64,7 +65,7 @@ type TraceConfig struct { // StdTraceConfig holds extra parameters to standard-json trace functions. 
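The api_backend.go hunk above shows the central mechanical change that repeats through the rest of this diff: vm.NewEVM no longer takes one combined context but a per-block BlockContext plus a per-transaction TxContext. A condensed sketch of the new construction pattern (the helper name is ours; the core/vm calls are the ones used in the diff):

```go
import (
	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/core/vm"
	"github.com/ethereum/go-ethereum/params"
)

// newTracingEVM builds an EVM the way this diff does everywhere: one
// BlockContext per block, one TxContext per transaction message.
func newTracingEVM(msg core.Message, header *types.Header, chain core.ChainContext,
	statedb *state.StateDB, config *params.ChainConfig) *vm.EVM {
	blockCtx := core.NewEVMBlockContext(header, chain, nil) // nil author: coinbase taken from the header
	txCtx := core.NewEVMTxContext(msg)                      // origin and gas price taken from the message
	return vm.NewEVM(blockCtx, txCtx, statedb, config, vm.Config{})
}
```

The payoff is visible in the tracing hunks below: the BlockContext can be built once and reused for every transaction in the block, with only the cheap TxContext rebuilt per message.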
type StdTraceConfig struct { - *vm.LogConfig + vm.LogConfig Reexec *uint64 TxHash common.Hash } @@ -147,7 +148,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl // Ensure we have a valid starting state before doing any work origin := start.NumberU64() - database := state.NewDatabaseWithCache(api.eth.ChainDb(), 16, "") // Chain tracing will probably start at genesis + database := state.NewDatabaseWithConfig(api.eth.ChainDb(), &trie.Config{Cache: 16, Preimages: true}) if number := start.NumberU64(); number > 0 { start = api.eth.blockchain.GetBlock(start.ParentHash(), start.NumberU64()-1) @@ -202,13 +203,11 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl // Fetch and execute the next block trace tasks for task := range tasks { signer := types.MakeSigner(api.eth.blockchain.Config(), task.block.Number()) - + blockCtx := core.NewEVMBlockContext(task.block.Header(), api.eth.blockchain, nil) // Trace all the transactions contained within for i, tx := range task.block.Transactions() { msg, _ := tx.AsMessage(signer) - vmctx := core.NewEVMContext(msg, task.block.Header(), api.eth.blockchain, nil) - - res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) + res, err := api.traceTx(ctx, msg, blockCtx, task.statedb, config) if err != nil { task.results[i] = &txTraceResult{Error: err.Error()} log.Warn("Tracing failed", "hash", tx.Hash(), "block", task.block.NumberU64(), "err", err) @@ -472,17 +471,15 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, if threads > len(txs) { threads = len(txs) } + blockCtx := core.NewEVMBlockContext(block.Header(), api.eth.blockchain, nil) for th := 0; th < threads; th++ { pend.Add(1) go func() { defer pend.Done() - // Fetch and execute the next transaction trace tasks for task := range jobs { msg, _ := txs[task.index].AsMessage(signer) - vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) - - res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) + res, err := api.traceTx(ctx, msg, blockCtx, task.statedb, config) if err != nil { results[task.index] = &txTraceResult{Error: err.Error()} continue @@ -499,9 +496,9 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, // Generate the next state snapshot fast without tracing msg, _ := tx.AsMessage(signer) - vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) + txContext := core.NewEVMTxContext(msg) - vmenv := vm.NewEVM(vmctx, statedb, api.eth.blockchain.Config(), vm.Config{}) + vmenv := vm.NewEVM(blockCtx, txContext, statedb, api.eth.blockchain.Config(), vm.Config{}) if _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas())); err != nil { failed = err break @@ -552,34 +549,53 @@ func (api *PrivateDebugAPI) standardTraceBlockToFile(ctx context.Context, block txHash common.Hash ) if config != nil { - if config.LogConfig != nil { - logConfig = *config.LogConfig - } + logConfig = config.LogConfig txHash = config.TxHash } logConfig.Debug = true // Execute transaction, either tracing all or just the requested one var ( - signer = types.MakeSigner(api.eth.blockchain.Config(), block.Number()) - dumps []string + signer = types.MakeSigner(api.eth.blockchain.Config(), block.Number()) + dumps []string + chainConfig = api.eth.blockchain.Config() + vmctx = core.NewEVMBlockContext(block.Header(), api.eth.blockchain, nil) + canon = true ) + // Check if there are any overrides: the caller may wish to enable a future + // fork when 
executing this block. Note, such overrides are only applicable to the
+	// actual specified block, not any preceding blocks that we have to go through
+	// in order to obtain the state.
+	// Therefore, it's perfectly valid to specify `"futureForkBlock": 0`, to enable `futureFork`
+
+	if config != nil && config.Overrides != nil {
+		// Copy the config, to not screw up the main config
+		// Note: the Clique-part is _not_ deep copied
+		chainConfigCopy := new(params.ChainConfig)
+		*chainConfigCopy = *chainConfig
+		chainConfig = chainConfigCopy
+		if yolov2 := config.LogConfig.Overrides.YoloV2Block; yolov2 != nil {
+			chainConfig.YoloV2Block = yolov2
+			canon = false
+		}
+	}
 	for i, tx := range block.Transactions() {
 		// Prepare the transaction for un-traced execution
 		var (
-			msg, _ = tx.AsMessage(signer)
-			vmctx  = core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil)
-
-			vmConf vm.Config
-			dump   *os.File
-			writer *bufio.Writer
-			err    error
+			msg, _    = tx.AsMessage(signer)
+			txContext = core.NewEVMTxContext(msg)
+			vmConf    vm.Config
+			dump      *os.File
+			writer    *bufio.Writer
+			err       error
 		)
 		// If the transaction needs tracing, swap out the configs
 		if tx.Hash() == txHash || txHash == (common.Hash{}) {
 			// Generate a unique temporary file to dump it into
 			prefix := fmt.Sprintf("block_%#x-%d-%#x-", block.Hash().Bytes()[:4], i, tx.Hash().Bytes()[:4])
-
+			if !canon {
+				prefix = fmt.Sprintf("%valt-", prefix)
+			}
 			dump, err = ioutil.TempFile(os.TempDir(), prefix)
 			if err != nil {
 				return nil, err
@@ -595,7 +611,7 @@ func (api *PrivateDebugAPI) standardTraceBlockToFile(ctx context.Context, block
 			}
 		}
 		// Execute the transaction and flush any traces to disk
-		vmenv := vm.NewEVM(vmctx, statedb, api.eth.blockchain.Config(), vmConf)
+		vmenv := vm.NewEVM(vmctx, txContext, statedb, chainConfig, vmConf)
 		_, err = core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas()))
 		if writer != nil {
 			writer.Flush()
@@ -641,7 +657,7 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
 	}
 	// Otherwise try to reexec blocks until we find a state or reach our limit
 	origin := block.NumberU64()
-	database := state.NewDatabaseWithCache(api.eth.ChainDb(), 16, "")
+	database := state.NewDatabaseWithConfig(api.eth.ChainDb(), &trie.Config{Cache: 16, Preimages: true})
 	for i := uint64(0); i < reexec; i++ {
 		block = api.eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1)
@@ -754,18 +770,19 @@ func (api *PrivateDebugAPI) TraceCall(ctx context.Context, args ethapi.CallArgs,
 	// Execute the trace
 	msg := args.ToMessage(api.eth.APIBackend.RPCGasCap())
-	vmctx := core.NewEVMContext(msg, header, api.eth.blockchain, nil)
+	vmctx := core.NewEVMBlockContext(header, api.eth.blockchain, nil)
 	return api.traceTx(ctx, msg, vmctx, statedb, config)
 }

 // traceTx configures a new tracer according to the provided configuration, and
 // executes the given message in the provided environment. The return value will
 // be tracer dependent.
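The override logic above copies the chain config before mutating it, so enabling an experimental fork for one traced block cannot leak into the node's live configuration. A minimal sketch of the same copy-then-override pattern (the helper name is ours; YoloV2Block is the experimental fork field used in the diff):

```go
import (
	"math/big"

	"github.com/ethereum/go-ethereum/params"
)

// withForkOverride returns a shallow copy of config with the experimental
// YoloV2 fork scheduled at forkBlock, leaving the original untouched. The
// copy is shallow, so nested pointers (e.g. the Clique config) are shared,
// exactly as the diff's comment warns.
func withForkOverride(config *params.ChainConfig, forkBlock *big.Int) *params.ChainConfig {
	copied := new(params.ChainConfig)
	*copied = *config
	copied.YoloV2Block = forkBlock
	return copied
}
```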
-func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) {
+func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, vmctx vm.BlockContext, statedb *state.StateDB, config *TraceConfig) (interface{}, error) {
 	// Assemble the structured logger or the JavaScript tracer
 	var (
-		tracer vm.Tracer
-		err    error
+		tracer    vm.Tracer
+		err       error
+		txContext = core.NewEVMTxContext(message)
 	)
 	switch {
 	case config != nil && config.Tracer != nil:
@@ -777,7 +794,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v
 			}
 		}
 		// Construct the JavaScript tracer to execute with
-		if tracer, err = tracers.New(*config.Tracer); err != nil {
+		if tracer, err = tracers.New(*config.Tracer, txContext); err != nil {
 			return nil, err
 		}
 		// Handle timeouts and RPC cancellations
@@ -795,7 +812,7 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v
 		tracer = vm.NewStructLogger(config.LogConfig)
 	}
 	// Run the transaction with tracing enabled.
-	vmenv := vm.NewEVM(vmctx, statedb, api.eth.blockchain.Config(), vm.Config{Debug: true, Tracer: tracer})
+	vmenv := vm.NewEVM(vmctx, txContext, statedb, api.eth.blockchain.Config(), vm.Config{Debug: true, Tracer: tracer})

 	result, err := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.Gas()))
 	if err != nil {
@@ -825,19 +842,19 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v
 }

 // computeTxEnv returns the execution environment of a certain transaction.
-func (api *PrivateDebugAPI) computeTxEnv(block *types.Block, txIndex int, reexec uint64) (core.Message, vm.Context, *state.StateDB, error) {
+func (api *PrivateDebugAPI) computeTxEnv(block *types.Block, txIndex int, reexec uint64) (core.Message, vm.BlockContext, *state.StateDB, error) {
 	// Create the parent state database
 	parent := api.eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1)
 	if parent == nil {
-		return nil, vm.Context{}, nil, fmt.Errorf("parent %#x not found", block.ParentHash())
+		return nil, vm.BlockContext{}, nil, fmt.Errorf("parent %#x not found", block.ParentHash())
 	}
 	statedb, err := api.computeStateDB(parent, reexec)
 	if err != nil {
-		return nil, vm.Context{}, nil, err
+		return nil, vm.BlockContext{}, nil, err
 	}
 	if txIndex == 0 && len(block.Transactions()) == 0 {
-		return nil, vm.Context{}, statedb, nil
+		return nil, vm.BlockContext{}, statedb, nil
 	}
 	// Recompute transactions up to the target index.
@@ -846,18 +863,19 @@ func (api *PrivateDebugAPI) computeTxEnv(block *types.Block, txIndex int, reexec
 	for idx, tx := range block.Transactions() {
 		// Assemble the transaction call message and return it if it's the requested offset
 		msg, _ := tx.AsMessage(signer)
-		context := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil)
+		txContext := core.NewEVMTxContext(msg)
+		context := core.NewEVMBlockContext(block.Header(), api.eth.blockchain, nil)
 		if idx == txIndex {
 			return msg, context, statedb, nil
 		}
 		// Not yet the searched for transaction, execute on top of the current state
-		vmenv := vm.NewEVM(context, statedb, api.eth.blockchain.Config(), vm.Config{})
+		vmenv := vm.NewEVM(context, txContext, statedb, api.eth.blockchain.Config(), vm.Config{})
 		if _, err := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(tx.Gas())); err != nil {
-			return nil, vm.Context{}, nil, fmt.Errorf("transaction %#x failed: %v", tx.Hash(), err)
+			return nil, vm.BlockContext{}, nil, fmt.Errorf("transaction %#x failed: %v", tx.Hash(), err)
 		}
 		// Ensure any modifications are committed to the state
 		// Only delete empty objects if EIP158/161 (a.k.a Spurious Dragon) is in effect
 		statedb.Finalise(vmenv.ChainConfig().IsEIP158(block.Number()))
 	}
-	return nil, vm.Context{}, nil, fmt.Errorf("transaction index %d out of range for block %#x", txIndex, block.Hash())
+	return nil, vm.BlockContext{}, nil, fmt.Errorf("transaction index %d out of range for block %#x", txIndex, block.Hash())
 }
diff --git a/eth/backend.go b/eth/backend.go
index 3fd027137c7f..987dee6d5515 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -24,6 +24,7 @@ import (
 	"runtime"
 	"sync"
 	"sync/atomic"
+	"time"

 	"github.com/ethereum/go-ethereum/accounts"
 	"github.com/ethereum/go-ethereum/common"
@@ -39,6 +40,8 @@ import (
 	"github.com/ethereum/go-ethereum/eth/downloader"
 	"github.com/ethereum/go-ethereum/eth/filters"
 	"github.com/ethereum/go-ethereum/eth/gasprice"
+	"github.com/ethereum/go-ethereum/eth/protocols/eth"
+	"github.com/ethereum/go-ethereum/eth/protocols/snap"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/event"
 	"github.com/ethereum/go-ethereum/internal/ethapi"
@@ -47,7 +50,6 @@ import (
 	"github.com/ethereum/go-ethereum/node"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p/enode"
-	"github.com/ethereum/go-ethereum/p2p/enr"
 	"github.com/ethereum/go-ethereum/params"
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/rpc"
@@ -58,10 +60,11 @@ type Ethereum struct {
 	config *Config

 	// Handlers
-	txPool          *core.TxPool
-	blockchain      *core.BlockChain
-	protocolManager *ProtocolManager
-	dialCandidates  enode.Iterator
+	txPool             *core.TxPool
+	blockchain         *core.BlockChain
+	handler            *handler
+	ethDialCandidates  enode.Iterator
+	snapDialCandidates enode.Iterator

 	// DB interfaces
 	chainDb ethdb.Database // Block chain database
@@ -144,7 +147,7 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) {
 	if bcVersion != nil {
 		dbVer = fmt.Sprintf("%d", *bcVersion)
 	}
-	log.Info("Initialising Ethereum protocol", "versions", ProtocolVersions, "network", config.NetworkId, "dbversion", dbVer)
+	log.Info("Initialising Ethereum protocol", "network", config.NetworkId, "dbversion", dbVer)

 	if !config.SkipBcVersionCheck {
 		if bcVersion != nil && *bcVersion > core.BlockChainVersion {
@@ -169,6 +172,7 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) {
 			TrieDirtyDisabled: config.NoPruning,
 			TrieTimeLimit:     config.TrieTimeout,
 			SnapshotLimit:     config.SnapshotCache,
+			Preimages:         config.Preimages,
 		}
 	)
eth.blockchain, err = core.NewBlockChain(chainDb, cacheConfig, chainConfig, eth.engine, vmConfig, eth.shouldPreserve, &config.TxLookupLimit) @@ -194,7 +198,17 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) { if checkpoint == nil { checkpoint = params.TrustedCheckpoints[genesisHash] } - if eth.protocolManager, err = NewProtocolManager(chainConfig, checkpoint, config.SyncMode, config.NetworkId, eth.eventMux, eth.txPool, eth.engine, eth.blockchain, chainDb, cacheLimit, config.Whitelist); err != nil { + if eth.handler, err = newHandler(&handlerConfig{ + Database: chainDb, + Chain: eth.blockchain, + TxPool: eth.txPool, + Network: config.NetworkId, + Sync: config.SyncMode, + BloomCache: uint64(cacheLimit), + EventMux: eth.eventMux, + Checkpoint: checkpoint, + Whitelist: config.Whitelist, + }); err != nil { return nil, err } eth.miner = miner.New(eth, &config.Miner, chainConfig, eth.EventMux(), eth.engine, eth.isLocalBlock) @@ -207,18 +221,34 @@ func New(stack *node.Node, config *Config) (*Ethereum, error) { } eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams) - eth.dialCandidates, err = eth.setupDiscovery(&stack.Config().P2P) + eth.ethDialCandidates, err = setupDiscovery(eth.config.EthDiscoveryURLs) + if err != nil { + return nil, err + } + eth.snapDialCandidates, err = setupDiscovery(eth.config.SnapDiscoveryURLs) if err != nil { return nil, err } - // Start the RPC service - eth.netRPCService = ethapi.NewPublicNetAPI(eth.p2pServer, eth.NetVersion()) + eth.netRPCService = ethapi.NewPublicNetAPI(eth.p2pServer) // Register the backend on the node stack.RegisterAPIs(eth.APIs()) stack.RegisterProtocols(eth.Protocols()) stack.RegisterLifecycle(eth) + // Check for unclean shutdown + if uncleanShutdowns, discards, err := rawdb.PushUncleanShutdownMarker(chainDb); err != nil { + log.Error("Could not update unclean-shutdown-marker list", "error", err) + } else { + if discards > 0 { + log.Warn("Old unclean shutdowns found", "count", discards) + } + for _, tstamp := range uncleanShutdowns { + t := time.Unix(int64(tstamp), 0) + log.Warn("Unclean shutdown detected", "booted", t, + "age", common.PrettyAge(t)) + } + } return eth, nil } @@ -295,7 +325,7 @@ func (s *Ethereum) APIs() []rpc.API { }, { Namespace: "eth", Version: "1.0", - Service: downloader.NewPublicDownloaderAPI(s.protocolManager.downloader, s.eventMux), + Service: downloader.NewPublicDownloaderAPI(s.handler.downloader, s.eventMux), Public: true, }, { Namespace: "miner", @@ -458,7 +488,7 @@ func (s *Ethereum) StartMining(threads int) error { } // If mining is started, we can disable the transaction rejection mechanism // introduced to speed sync times. 
-	atomic.StoreUint32(&s.protocolManager.acceptTxs, 1)
+	atomic.StoreUint32(&s.handler.acceptTxs, 1)

 		go s.miner.Start(eb)
 	}
@@ -489,21 +519,17 @@ func (s *Ethereum) EventMux() *event.TypeMux { return s.eventMux }
 func (s *Ethereum) Engine() consensus.Engine { return s.engine }
 func (s *Ethereum) ChainDb() ethdb.Database { return s.chainDb }
 func (s *Ethereum) IsListening() bool { return true } // Always listening
-func (s *Ethereum) EthVersion() int { return int(ProtocolVersions[0]) }
-func (s *Ethereum) NetVersion() uint64 { return s.networkID }
-func (s *Ethereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader }
-func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.protocolManager.acceptTxs) == 1 }
+func (s *Ethereum) Downloader() *downloader.Downloader { return s.handler.downloader }
+func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.handler.acceptTxs) == 1 }
 func (s *Ethereum) ArchiveMode() bool { return s.config.NoPruning }
 func (s *Ethereum) BloomIndexer() *core.ChainIndexer { return s.bloomIndexer }

 // Protocols returns all the currently configured
 // network protocols to start.
 func (s *Ethereum) Protocols() []p2p.Protocol {
-	protos := make([]p2p.Protocol, len(ProtocolVersions))
-	for i, vsn := range ProtocolVersions {
-		protos[i] = s.protocolManager.makeProtocol(vsn)
-		protos[i].Attributes = []enr.Entry{s.currentEthEntry()}
-		protos[i].DialCandidates = s.dialCandidates
+	protos := eth.MakeProtocols((*ethHandler)(s.handler), s.networkID, s.ethDialCandidates)
+	if s.config.SnapshotCache > 0 {
+		protos = append(protos, snap.MakeProtocols((*snapHandler)(s.handler), s.snapDialCandidates)...)
 	}
 	return protos
 }
@@ -511,7 +537,7 @@ func (s *Ethereum) Protocols() []p2p.Protocol {
 // Start implements node.Lifecycle, starting all internal goroutines needed by the
 // Ethereum protocol implementation.
 func (s *Ethereum) Start() error {
-	s.startEthEntryUpdate(s.p2pServer.LocalNode())
+	eth.StartENRUpdater(s.blockchain, s.p2pServer.LocalNode())

 	// Start the bloom bits servicing goroutines
 	s.startBloomHandlers(params.BloomBitsBlocks)
@@ -525,7 +551,7 @@ func (s *Ethereum) Start() error {
 		maxPeers -= s.config.LightPeers
 	}
 	// Start the networking layer and the light server if requested
-	s.protocolManager.Start(maxPeers)
+	s.handler.Start(maxPeers)
 	return nil
 }

@@ -533,7 +559,7 @@ func (s *Ethereum) Start() error {
 // Ethereum protocol.
 func (s *Ethereum) Stop() error {
 	// Stop all the peer-related stuff first.
-	s.protocolManager.Stop()
+	s.handler.Stop()

 	// Then stop everything else.
 	s.bloomIndexer.Close()
@@ -542,7 +568,9 @@ func (s *Ethereum) Stop() error {
 	s.miner.Stop()
 	s.blockchain.Stop()
 	s.engine.Close()
+	rawdb.PopUncleanShutdownMarker(s.chainDb)
 	s.chainDb.Close()
 	s.eventMux.Stop()
+
 	return nil
 }
diff --git a/eth/config.go b/eth/config.go
index 0d99c2a3f1e6..77d03e95697d 100644
--- a/eth/config.go
+++ b/eth/config.go
@@ -115,7 +115,8 @@ type Config struct {
 	// This can be set to a list of enrtree:// URLs which will be queried
 	// for nodes to connect to.
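The Stop hunk above pops the unclean-shutdown marker that New() pushes at boot (see the eth/backend.go hunk earlier in this diff), so any marker still present on the next start indicates a crash. A hedged sketch of that lifecycle; the rawdb helper names are from the diff, while the wrapper and the exact return shapes are our assumption:

```go
import (
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/log"
)

// runWithShutdownMarker brackets a node's lifetime with the marker helpers:
// the marker pushed at boot is only popped on a clean stop, so a marker
// found by the next boot means the previous process died hard.
func runWithShutdownMarker(db ethdb.Database, run func() error) error {
	if _, _, err := rawdb.PushUncleanShutdownMarker(db); err != nil {
		log.Error("Could not update unclean-shutdown-marker list", "error", err)
	}
	defer rawdb.PopUncleanShutdownMarker(db) // skipped on a crash, by design
	return run()
}
```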
- DiscoveryURLs []string + EthDiscoveryURLs []string + SnapDiscoveryURLs []string NoPruning bool // Whether to disable pruning and flush everything to disk NoPrefetch bool // Whether to disable prefetching and only load state on demand @@ -149,6 +150,7 @@ type Config struct { TrieDirtyCache int TrieTimeout time.Duration SnapshotCache int + Preimages bool // Mining options Miner miner.Config diff --git a/eth/discovery.go b/eth/discovery.go index 97d6322ca16d..855ce3b0e1f4 100644 --- a/eth/discovery.go +++ b/eth/discovery.go @@ -19,7 +19,6 @@ package eth import ( "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/forkid" - "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/dnsdisc" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" @@ -60,14 +59,16 @@ func (eth *Ethereum) startEthEntryUpdate(ln *enode.LocalNode) { } func (eth *Ethereum) currentEthEntry() *ethEntry { - return ðEntry{ForkID: forkid.NewID(eth.blockchain)} + return ðEntry{ForkID: forkid.NewID(eth.blockchain.Config(), eth.blockchain.Genesis().Hash(), + eth.blockchain.CurrentHeader().Number.Uint64())} } -// setupDiscovery creates the node discovery source for the eth protocol. -func (eth *Ethereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) { - if cfg.NoDiscovery || len(eth.config.DiscoveryURLs) == 0 { +// setupDiscovery creates the node discovery source for the `eth` and `snap` +// protocols. +func setupDiscovery(urls []string) (enode.Iterator, error) { + if len(urls) == 0 { return nil, nil } client := dnsdisc.NewClient(dnsdisc.Config{}) - return client.NewIterator(eth.config.DiscoveryURLs...) + return client.NewIterator(urls...) } diff --git a/eth/downloader/api.go b/eth/downloader/api.go index 57ff3d71afa8..2024d23deade 100644 --- a/eth/downloader/api.go +++ b/eth/downloader/api.go @@ -20,7 +20,7 @@ import ( "context" "sync" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/rpc" ) diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index c6ca9af0136a..3123598437a4 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth/protocols/snap" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" @@ -38,7 +39,6 @@ import ( ) var ( - MaxHashFetch = 512 // Amount of hashes to be fetched per retrieval request MaxBlockFetch = 128 // Amount of blocks to be fetched per retrieval request MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request MaxSkeletonSize = 128 // Number of header fetches to need for a skeleton assembly @@ -89,7 +89,7 @@ var ( errCancelContentProcessing = errors.New("content processing canceled (requested)") errCanceled = errors.New("syncing canceled (requested)") errNoSyncActive = errors.New("no sync active") - errTooOld = errors.New("peer doesn't speak recent enough protocol version (need version >= 63)") + errTooOld = errors.New("peer doesn't speak recent enough protocol version (need version >= 64)") ) type Downloader struct { @@ -131,20 +131,22 @@ type Downloader struct { ancientLimit uint64 // The maximum block number which can be regarded as ancient data. 
// Channels - headerCh chan dataPack // [eth/62] Channel receiving inbound block headers - bodyCh chan dataPack // [eth/62] Channel receiving inbound block bodies - receiptCh chan dataPack // [eth/63] Channel receiving inbound receipts - bodyWakeCh chan bool // [eth/62] Channel to signal the block body fetcher of new tasks - receiptWakeCh chan bool // [eth/63] Channel to signal the receipt fetcher of new tasks - headerProcCh chan []*types.Header // [eth/62] Channel to feed the header processor new tasks + headerCh chan dataPack // Channel receiving inbound block headers + bodyCh chan dataPack // Channel receiving inbound block bodies + receiptCh chan dataPack // Channel receiving inbound receipts + bodyWakeCh chan bool // Channel to signal the block body fetcher of new tasks + receiptWakeCh chan bool // Channel to signal the receipt fetcher of new tasks + headerProcCh chan []*types.Header // Channel to feed the header processor new tasks // State sync pivotHeader *types.Header // Pivot block header to dynamically push the syncing state root pivotLock sync.RWMutex // Lock protecting pivot header reads from updates + snapSync bool // Whether to run state sync over the snap protocol + SnapSyncer *snap.Syncer // TODO(karalabe): make private! hack for now stateSyncStart chan *stateSync trackStateReq chan *stateReq - stateCh chan dataPack // [eth/63] Channel receiving inbound node state data + stateCh chan dataPack // Channel receiving inbound node state data // Cancellation and termination cancelPeer string // Identifier of the peer currently being used as the master (cancel on drop) @@ -153,7 +155,7 @@ type Downloader struct { cancelWg sync.WaitGroup // Make sure all fetcher goroutines have exited. quitCh chan struct{} // Quit channel to signal termination - quitLock sync.RWMutex // Lock to prevent double closes + quitLock sync.Mutex // Lock to prevent double closes // Testing hooks syncInitHook func(uint64, uint64) // Method to call upon initiating a new sync run @@ -237,6 +239,7 @@ func New(checkpoint uint64, stateDb ethdb.Database, stateBloom *trie.SyncBloom, headerProcCh: make(chan []*types.Header, 1), quitCh: make(chan struct{}), stateCh: make(chan dataPack), + SnapSyncer: snap.NewSyncer(stateDb, stateBloom), stateSyncStart: make(chan *stateSync), syncStatsState: stateSyncStats{ processed: rawdb.ReadFastTrieProgress(stateDb), @@ -286,19 +289,16 @@ func (d *Downloader) Synchronising() bool { return atomic.LoadInt32(&d.synchronising) > 0 } -// SyncBloomContains tests if the syncbloom filter contains the given hash: -// - false: the bloom definitely does not contain hash -// - true: the bloom maybe contains hash -// -// While the bloom is being initialized (or is closed), all queries will return true. -func (d *Downloader) SyncBloomContains(hash []byte) bool { - return d.stateBloom == nil || d.stateBloom.Contains(hash) -} - // RegisterPeer injects a new download peer into the set of block source to be // used for fetching hashes and blocks from. 
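One detail worth noting in the syncWithPeer hunk above: the bare errTooOld return became fmt.Errorf with the %w verb, which adds the offending peer version to the message while keeping the sentinel matchable. A small self-contained illustration:

```go
package main

import (
	"errors"
	"fmt"
)

var errTooOld = errors.New("peer doesn't speak recent enough protocol version (need version >= 64)")

func main() {
	// Wrapping with %w preserves the sentinel for errors.Is, while the
	// formatted message still carries the concrete peer version.
	err := fmt.Errorf("%w, peer version: %d", errTooOld, 63)
	fmt.Println(err)
	fmt.Println(errors.Is(err, errTooOld)) // true
}
```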
-func (d *Downloader) RegisterPeer(id string, version int, peer Peer) error { - logger := log.New("peer", id) +func (d *Downloader) RegisterPeer(id string, version uint, peer Peer) error { + var logger log.Logger + if len(id) < 16 { + // Tests use short IDs, don't choke on them + logger = log.New("peer", id) + } else { + logger = log.New("peer", id[:16]) + } logger.Trace("Registering sync peer") if err := d.peers.Register(newPeerConnection(id, version, peer, logger)); err != nil { logger.Error("Failed to register sync peer", "err", err) @@ -310,7 +310,7 @@ func (d *Downloader) RegisterPeer(id string, version int, peer Peer) error { } // RegisterLightPeer injects a light client peer, wrapping it so it appears as a regular peer. -func (d *Downloader) RegisterLightPeer(id string, version int, peer LightPeer) error { +func (d *Downloader) RegisterLightPeer(id string, version uint, peer LightPeer) error { return d.RegisterPeer(id, version, &lightPeerWrapper{peer}) } @@ -319,7 +319,13 @@ func (d *Downloader) RegisterLightPeer(id string, version int, peer LightPeer) e // the queue. func (d *Downloader) UnregisterPeer(id string) error { // Unregister the peer from the active peer set and revoke any fetch tasks - logger := log.New("peer", id) + var logger log.Logger + if len(id) < 16 { + // Tests use short IDs, don't choke on them + logger = log.New("peer", id) + } else { + logger = log.New("peer", id[:16]) + } logger.Trace("Unregistering sync peer") if err := d.peers.Unregister(id); err != nil { logger.Error("Failed to unregister sync peer", "err", err) @@ -381,6 +387,16 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode if mode == FullSync && d.stateBloom != nil { d.stateBloom.Close() } + // If snap sync was requested, create the snap scheduler and switch to fast + // sync mode. Long term we could drop fast sync or merge the two together, + // but until snap becomes prevalent, we should support both. TODO(karalabe). + if mode == SnapSync { + if !d.snapSync { + log.Warn("Enabling snapshot sync prototype") + d.snapSync = true + } + mode = FastSync + } // Reset the queue, peer set and wake channels to clean any internal leftover state d.queue.Reset(blockCacheMaxItems, blockCacheInitialItems) d.peers.Reset() @@ -443,8 +459,8 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I d.mux.Post(DoneEvent{latest}) } }() - if p.version < 63 { - return errTooOld + if p.version < 64 { + return fmt.Errorf("%w, peer version: %d", errTooOld, p.version) } mode := d.getMode() @@ -515,6 +531,8 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I d.ancientLimit = d.checkpoint } else if height > fullMaxForkAncestry+1 { d.ancientLimit = height - fullMaxForkAncestry - 1 + } else { + d.ancientLimit = 0 } frozen, _ := d.stateDB.Ancients() // Ignore the error here since light client can also hit here. @@ -606,9 +624,6 @@ func (d *Downloader) cancel() { func (d *Downloader) Cancel() { d.cancel() d.cancelWg.Wait() - - d.ancientLimit = 0 - log.Debug("Reset ancient limit to zero") } // Terminate interrupts the downloader, canceling all pending operations. @@ -1911,27 +1926,53 @@ func (d *Downloader) commitPivotBlock(result *fetchResult) error { // DeliverHeaders injects a new batch of block headers received from a remote // node into the download schedule. 
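The same truncate-or-keep peer-ID logger setup now appears in RegisterPeer and UnregisterPeer above, and again in queue.DeliverHeaders later in this diff; if it were ever factored out, it might look like the sketch below (the helper name is ours, not part of the diff):

```go
import "github.com/ethereum/go-ethereum/log"

// peerLogger builds the contextual logger used in the hunks above: real
// peer IDs are truncated to 16 characters, short test IDs pass through.
func peerLogger(id string) log.Logger {
	if len(id) < 16 {
		// Tests use short IDs, don't choke on them
		return log.New("peer", id)
	}
	return log.New("peer", id[:16])
}
```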
-func (d *Downloader) DeliverHeaders(id string, headers []*types.Header) (err error) { - return d.deliver(id, d.headerCh, &headerPack{id, headers}, headerInMeter, headerDropMeter) +func (d *Downloader) DeliverHeaders(id string, headers []*types.Header) error { + return d.deliver(d.headerCh, &headerPack{id, headers}, headerInMeter, headerDropMeter) } // DeliverBodies injects a new batch of block bodies received from a remote node. -func (d *Downloader) DeliverBodies(id string, transactions [][]*types.Transaction, uncles [][]*types.Header) (err error) { - return d.deliver(id, d.bodyCh, &bodyPack{id, transactions, uncles}, bodyInMeter, bodyDropMeter) +func (d *Downloader) DeliverBodies(id string, transactions [][]*types.Transaction, uncles [][]*types.Header) error { + return d.deliver(d.bodyCh, &bodyPack{id, transactions, uncles}, bodyInMeter, bodyDropMeter) } // DeliverReceipts injects a new batch of receipts received from a remote node. -func (d *Downloader) DeliverReceipts(id string, receipts [][]*types.Receipt) (err error) { - return d.deliver(id, d.receiptCh, &receiptPack{id, receipts}, receiptInMeter, receiptDropMeter) +func (d *Downloader) DeliverReceipts(id string, receipts [][]*types.Receipt) error { + return d.deliver(d.receiptCh, &receiptPack{id, receipts}, receiptInMeter, receiptDropMeter) } // DeliverNodeData injects a new batch of node state data received from a remote node. -func (d *Downloader) DeliverNodeData(id string, data [][]byte) (err error) { - return d.deliver(id, d.stateCh, &statePack{id, data}, stateInMeter, stateDropMeter) +func (d *Downloader) DeliverNodeData(id string, data [][]byte) error { + return d.deliver(d.stateCh, &statePack{id, data}, stateInMeter, stateDropMeter) +} + +// DeliverSnapPacket is invoked from a peer's message handler when it transmits a +// data packet for the local node to consume. +func (d *Downloader) DeliverSnapPacket(peer *snap.Peer, packet snap.Packet) error { + switch packet := packet.(type) { + case *snap.AccountRangePacket: + hashes, accounts, err := packet.Unpack() + if err != nil { + return err + } + return d.SnapSyncer.OnAccounts(peer, packet.ID, hashes, accounts, packet.Proof) + + case *snap.StorageRangesPacket: + hashset, slotset := packet.Unpack() + return d.SnapSyncer.OnStorage(peer, packet.ID, hashset, slotset, packet.Proof) + + case *snap.ByteCodesPacket: + return d.SnapSyncer.OnByteCodes(peer, packet.ID, packet.Codes) + + case *snap.TrieNodesPacket: + return d.SnapSyncer.OnTrieNodes(peer, packet.ID, packet.Nodes) + + default: + return fmt.Errorf("unexpected snap packet type: %T", packet) + } } // deliver injects a new batch of data received from a remote node. -func (d *Downloader) deliver(id string, destCh chan dataPack, packet dataPack, inMeter, dropMeter metrics.Meter) (err error) { +func (d *Downloader) deliver(destCh chan dataPack, packet dataPack, inMeter, dropMeter metrics.Meter) (err error) { // Update the delivery metrics for both good and failed deliveries inMeter.Mark(int64(packet.Items())) defer func() { diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 4bf1e4d469dc..6578275d0c32 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -390,7 +390,7 @@ func (dl *downloadTester) Rollback(hashes []common.Hash) { } // newPeer registers a new block download source into the downloader. 
-func (dl *downloadTester) newPeer(id string, version int, chain *testChain) error { +func (dl *downloadTester) newPeer(id string, version uint, chain *testChain) error { dl.lock.Lock() defer dl.lock.Unlock() @@ -411,7 +411,6 @@ func (dl *downloadTester) dropPeer(id string) { type downloadTesterPeer struct { dl *downloadTester id string - lock sync.RWMutex chain *testChain missingStates map[common.Hash]bool // State entries that fast sync should not return } @@ -519,8 +518,6 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng // Tests that simple synchronization against a canonical chain works correctly. // In this test common ancestor lookup should be short circuited and not require // binary searching. -func TestCanonicalSynchronisation63Full(t *testing.T) { testCanonicalSynchronisation(t, 63, FullSync) } -func TestCanonicalSynchronisation63Fast(t *testing.T) { testCanonicalSynchronisation(t, 63, FastSync) } func TestCanonicalSynchronisation64Full(t *testing.T) { testCanonicalSynchronisation(t, 64, FullSync) } func TestCanonicalSynchronisation64Fast(t *testing.T) { testCanonicalSynchronisation(t, 64, FastSync) } func TestCanonicalSynchronisation65Full(t *testing.T) { testCanonicalSynchronisation(t, 65, FullSync) } @@ -529,7 +526,7 @@ func TestCanonicalSynchronisation65Light(t *testing.T) { testCanonicalSynchronisation(t, 65, LightSync) } -func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) { +func testCanonicalSynchronisation(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -548,14 +545,12 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) { // Tests that if a large batch of blocks are being downloaded, it is throttled // until the cached blocks are retrieved. -func TestThrottling63Full(t *testing.T) { testThrottling(t, 63, FullSync) } -func TestThrottling63Fast(t *testing.T) { testThrottling(t, 63, FastSync) } func TestThrottling64Full(t *testing.T) { testThrottling(t, 64, FullSync) } func TestThrottling64Fast(t *testing.T) { testThrottling(t, 64, FastSync) } func TestThrottling65Full(t *testing.T) { testThrottling(t, 65, FullSync) } func TestThrottling65Fast(t *testing.T) { testThrottling(t, 65, FastSync) } -func testThrottling(t *testing.T, protocol int, mode SyncMode) { +func testThrottling(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -570,7 +565,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { <-proceed } // Start a synchronisation concurrently - errc := make(chan error) + errc := make(chan error, 1) go func() { errc <- tester.sync("peer", nil, mode) }() @@ -633,15 +628,13 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { // Tests that simple synchronization against a forked chain works correctly. In // this test common ancestor lookup should *not* be short circuited, and a full // binary search should be executed. 
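A subtle fix buried in the testThrottling hunk above: the sync error channel became buffered, make(chan error, 1), so the syncing goroutine can always complete its send even if the test stops reading. A generic illustration of why the one-slot buffer prevents a goroutine leak:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

func doWork() error { return errors.New("sync failed") }

func main() {
	// Capacity 1 lets the worker finish its send even if the receiver has
	// already timed out, so the goroutine cannot block forever.
	errc := make(chan error, 1)
	go func() { errc <- doWork() }()

	select {
	case err := <-errc:
		fmt.Println("result:", err)
	case <-time.After(time.Second):
		fmt.Println("timed out; the worker still exits cleanly")
	}
}
```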
-func TestForkedSync63Full(t *testing.T) { testForkedSync(t, 63, FullSync) }
-func TestForkedSync63Fast(t *testing.T) { testForkedSync(t, 63, FastSync) }
 func TestForkedSync64Full(t *testing.T) { testForkedSync(t, 64, FullSync) }
 func TestForkedSync64Fast(t *testing.T) { testForkedSync(t, 64, FastSync) }
 func TestForkedSync65Full(t *testing.T) { testForkedSync(t, 65, FullSync) }
 func TestForkedSync65Fast(t *testing.T) { testForkedSync(t, 65, FastSync) }
 func TestForkedSync65Light(t *testing.T) { testForkedSync(t, 65, LightSync) }

-func testForkedSync(t *testing.T, protocol int, mode SyncMode) {
+func testForkedSync(t *testing.T, protocol uint, mode SyncMode) {
 	t.Parallel()

 	tester := newTester()
@@ -666,15 +659,13 @@ func testForkedSync(t *testing.T, protocol int, mode SyncMode) {

 // Tests that synchronising against a much shorter but much heavier fork works
 // correctly and is not dropped.
-func TestHeavyForkedSync63Full(t *testing.T) { testHeavyForkedSync(t, 63, FullSync) }
-func TestHeavyForkedSync63Fast(t *testing.T) { testHeavyForkedSync(t, 63, FastSync) }
 func TestHeavyForkedSync64Full(t *testing.T) { testHeavyForkedSync(t, 64, FullSync) }
 func TestHeavyForkedSync64Fast(t *testing.T) { testHeavyForkedSync(t, 64, FastSync) }
 func TestHeavyForkedSync65Full(t *testing.T) { testHeavyForkedSync(t, 65, FullSync) }
 func TestHeavyForkedSync65Fast(t *testing.T) { testHeavyForkedSync(t, 65, FastSync) }
 func TestHeavyForkedSync65Light(t *testing.T) { testHeavyForkedSync(t, 65, LightSync) }

-func testHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) {
+func testHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) {
 	t.Parallel()

 	tester := newTester()
@@ -701,15 +692,13 @@ func testHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) {

 // Tests that chain forks are contained within a certain interval of the current
 // chain head, ensuring that malicious peers cannot waste resources by feeding
 // long dead chains.
-func TestBoundedForkedSync63Full(t *testing.T) { testBoundedForkedSync(t, 63, FullSync) }
-func TestBoundedForkedSync63Fast(t *testing.T) { testBoundedForkedSync(t, 63, FastSync) }
 func TestBoundedForkedSync64Full(t *testing.T) { testBoundedForkedSync(t, 64, FullSync) }
 func TestBoundedForkedSync64Fast(t *testing.T) { testBoundedForkedSync(t, 64, FastSync) }
 func TestBoundedForkedSync65Full(t *testing.T) { testBoundedForkedSync(t, 65, FullSync) }
 func TestBoundedForkedSync65Fast(t *testing.T) { testBoundedForkedSync(t, 65, FastSync) }
 func TestBoundedForkedSync65Light(t *testing.T) { testBoundedForkedSync(t, 65, LightSync) }

-func testBoundedForkedSync(t *testing.T, protocol int, mode SyncMode) {
+func testBoundedForkedSync(t *testing.T, protocol uint, mode SyncMode) {
 	t.Parallel()

 	tester := newTester()
@@ -735,15 +724,13 @@ func testBoundedForkedSync(t *testing.T, protocol int, mode SyncMode) {

 // Tests that chain forks are contained within a certain interval of the current
 // chain head for short but heavy forks too. These are a bit special because they
 // take different ancestor lookup paths.
-func TestBoundedHeavyForkedSync63Full(t *testing.T) { testBoundedHeavyForkedSync(t, 63, FullSync) } -func TestBoundedHeavyForkedSync63Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 63, FastSync) } func TestBoundedHeavyForkedSync64Full(t *testing.T) { testBoundedHeavyForkedSync(t, 64, FullSync) } func TestBoundedHeavyForkedSync64Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 64, FastSync) } func TestBoundedHeavyForkedSync65Full(t *testing.T) { testBoundedHeavyForkedSync(t, 65, FullSync) } func TestBoundedHeavyForkedSync65Fast(t *testing.T) { testBoundedHeavyForkedSync(t, 65, FastSync) } func TestBoundedHeavyForkedSync65Light(t *testing.T) { testBoundedHeavyForkedSync(t, 65, LightSync) } -func testBoundedHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) { +func testBoundedHeavyForkedSync(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -787,15 +774,13 @@ func TestInactiveDownloader63(t *testing.T) { } // Tests that a canceled download wipes all previously accumulated state. -func TestCancel63Full(t *testing.T) { testCancel(t, 63, FullSync) } -func TestCancel63Fast(t *testing.T) { testCancel(t, 63, FastSync) } func TestCancel64Full(t *testing.T) { testCancel(t, 64, FullSync) } func TestCancel64Fast(t *testing.T) { testCancel(t, 64, FastSync) } func TestCancel65Full(t *testing.T) { testCancel(t, 65, FullSync) } func TestCancel65Fast(t *testing.T) { testCancel(t, 65, FastSync) } func TestCancel65Light(t *testing.T) { testCancel(t, 65, LightSync) } -func testCancel(t *testing.T, protocol int, mode SyncMode) { +func testCancel(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -820,15 +805,13 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) { } // Tests that synchronisation from multiple peers works as intended (multi thread sanity test). -func TestMultiSynchronisation63Full(t *testing.T) { testMultiSynchronisation(t, 63, FullSync) } -func TestMultiSynchronisation63Fast(t *testing.T) { testMultiSynchronisation(t, 63, FastSync) } func TestMultiSynchronisation64Full(t *testing.T) { testMultiSynchronisation(t, 64, FullSync) } func TestMultiSynchronisation64Fast(t *testing.T) { testMultiSynchronisation(t, 64, FastSync) } func TestMultiSynchronisation65Full(t *testing.T) { testMultiSynchronisation(t, 65, FullSync) } func TestMultiSynchronisation65Fast(t *testing.T) { testMultiSynchronisation(t, 65, FastSync) } func TestMultiSynchronisation65Light(t *testing.T) { testMultiSynchronisation(t, 65, LightSync) } -func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) { +func testMultiSynchronisation(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -850,15 +833,13 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) { // Tests that synchronisations behave well in multi-version protocol environments // and not wreak havoc on other nodes in the network. 
-func TestMultiProtoSynchronisation63Full(t *testing.T) { testMultiProtoSync(t, 63, FullSync) } -func TestMultiProtoSynchronisation63Fast(t *testing.T) { testMultiProtoSync(t, 63, FastSync) } func TestMultiProtoSynchronisation64Full(t *testing.T) { testMultiProtoSync(t, 64, FullSync) } func TestMultiProtoSynchronisation64Fast(t *testing.T) { testMultiProtoSync(t, 64, FastSync) } func TestMultiProtoSynchronisation65Full(t *testing.T) { testMultiProtoSync(t, 65, FullSync) } func TestMultiProtoSynchronisation65Fast(t *testing.T) { testMultiProtoSync(t, 65, FastSync) } func TestMultiProtoSynchronisation65Light(t *testing.T) { testMultiProtoSync(t, 65, LightSync) } -func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) { +func testMultiProtoSync(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -889,15 +870,13 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) { // Tests that if a block is empty (e.g. header only), no body request should be // made, and instead the header should be assembled into a whole block in itself. -func TestEmptyShortCircuit63Full(t *testing.T) { testEmptyShortCircuit(t, 63, FullSync) } -func TestEmptyShortCircuit63Fast(t *testing.T) { testEmptyShortCircuit(t, 63, FastSync) } func TestEmptyShortCircuit64Full(t *testing.T) { testEmptyShortCircuit(t, 64, FullSync) } func TestEmptyShortCircuit64Fast(t *testing.T) { testEmptyShortCircuit(t, 64, FastSync) } func TestEmptyShortCircuit65Full(t *testing.T) { testEmptyShortCircuit(t, 65, FullSync) } func TestEmptyShortCircuit65Fast(t *testing.T) { testEmptyShortCircuit(t, 65, FastSync) } func TestEmptyShortCircuit65Light(t *testing.T) { testEmptyShortCircuit(t, 65, LightSync) } -func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { +func testEmptyShortCircuit(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -943,15 +922,13 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { // Tests that headers are enqueued continuously, preventing malicious nodes from // stalling the downloader by feeding gapped header chains. -func TestMissingHeaderAttack63Full(t *testing.T) { testMissingHeaderAttack(t, 63, FullSync) } -func TestMissingHeaderAttack63Fast(t *testing.T) { testMissingHeaderAttack(t, 63, FastSync) } func TestMissingHeaderAttack64Full(t *testing.T) { testMissingHeaderAttack(t, 64, FullSync) } func TestMissingHeaderAttack64Fast(t *testing.T) { testMissingHeaderAttack(t, 64, FastSync) } func TestMissingHeaderAttack65Full(t *testing.T) { testMissingHeaderAttack(t, 65, FullSync) } func TestMissingHeaderAttack65Fast(t *testing.T) { testMissingHeaderAttack(t, 65, FastSync) } func TestMissingHeaderAttack65Light(t *testing.T) { testMissingHeaderAttack(t, 65, LightSync) } -func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) { +func testMissingHeaderAttack(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -975,15 +952,13 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) { // Tests that if requested headers are shifted (i.e. first is missing), the queue // detects the invalid numbering. 
-func TestShiftedHeaderAttack63Full(t *testing.T) { testShiftedHeaderAttack(t, 63, FullSync) } -func TestShiftedHeaderAttack63Fast(t *testing.T) { testShiftedHeaderAttack(t, 63, FastSync) } func TestShiftedHeaderAttack64Full(t *testing.T) { testShiftedHeaderAttack(t, 64, FullSync) } func TestShiftedHeaderAttack64Fast(t *testing.T) { testShiftedHeaderAttack(t, 64, FastSync) } func TestShiftedHeaderAttack65Full(t *testing.T) { testShiftedHeaderAttack(t, 65, FullSync) } func TestShiftedHeaderAttack65Fast(t *testing.T) { testShiftedHeaderAttack(t, 65, FastSync) } func TestShiftedHeaderAttack65Light(t *testing.T) { testShiftedHeaderAttack(t, 65, LightSync) } -func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) { +func testShiftedHeaderAttack(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -1012,11 +987,10 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) { // Tests that upon detecting an invalid header, the recent ones are rolled back // for various failure scenarios. Afterwards a full sync is attempted to make // sure no state was corrupted. -func TestInvalidHeaderRollback63Fast(t *testing.T) { testInvalidHeaderRollback(t, 63, FastSync) } func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) } func TestInvalidHeaderRollback65Fast(t *testing.T) { testInvalidHeaderRollback(t, 65, FastSync) } -func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { +func testInvalidHeaderRollback(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -1104,15 +1078,13 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { // Tests that a peer advertising a high TD doesn't get to stall the downloader // afterwards by not sending any useful hashes. -func TestHighTDStarvationAttack63Full(t *testing.T) { testHighTDStarvationAttack(t, 63, FullSync) } -func TestHighTDStarvationAttack63Fast(t *testing.T) { testHighTDStarvationAttack(t, 63, FastSync) } func TestHighTDStarvationAttack64Full(t *testing.T) { testHighTDStarvationAttack(t, 64, FullSync) } func TestHighTDStarvationAttack64Fast(t *testing.T) { testHighTDStarvationAttack(t, 64, FastSync) } func TestHighTDStarvationAttack65Full(t *testing.T) { testHighTDStarvationAttack(t, 65, FullSync) } func TestHighTDStarvationAttack65Fast(t *testing.T) { testHighTDStarvationAttack(t, 65, FastSync) } func TestHighTDStarvationAttack65Light(t *testing.T) { testHighTDStarvationAttack(t, 65, LightSync) } -func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) { +func testHighTDStarvationAttack(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -1126,11 +1098,10 @@ func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) { } // Tests that misbehaving peers are disconnected, whilst behaving ones are not. 
-func TestBlockHeaderAttackerDropping63(t *testing.T) { testBlockHeaderAttackerDropping(t, 63) }
 func TestBlockHeaderAttackerDropping64(t *testing.T) { testBlockHeaderAttackerDropping(t, 64) }
 func TestBlockHeaderAttackerDropping65(t *testing.T) { testBlockHeaderAttackerDropping(t, 65) }

-func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
+func testBlockHeaderAttackerDropping(t *testing.T, protocol uint) {
 	t.Parallel()

 	// Define the disconnection requirement for individual hash fetch errors
@@ -1180,15 +1151,13 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {

 // Tests that synchronisation progress (origin block number, current block number
 // and highest block number) is tracked and updated correctly.
-func TestSyncProgress63Full(t *testing.T) { testSyncProgress(t, 63, FullSync) }
-func TestSyncProgress63Fast(t *testing.T) { testSyncProgress(t, 63, FastSync) }
 func TestSyncProgress64Full(t *testing.T) { testSyncProgress(t, 64, FullSync) }
 func TestSyncProgress64Fast(t *testing.T) { testSyncProgress(t, 64, FastSync) }
 func TestSyncProgress65Full(t *testing.T) { testSyncProgress(t, 65, FullSync) }
 func TestSyncProgress65Fast(t *testing.T) { testSyncProgress(t, 65, FastSync) }
 func TestSyncProgress65Light(t *testing.T) { testSyncProgress(t, 65, LightSync) }

-func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
+func testSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
 	t.Parallel()

 	tester := newTester()
@@ -1264,21 +1233,19 @@ func checkProgress(t *testing.T, d *Downloader, stage string, want ethereum.Sync

 // Tests that synchronisation progress (origin block number and highest block
 // number) is tracked and updated correctly in case of a fork (or manual head
 // reversal).
-func TestForkedSyncProgress63Full(t *testing.T) { testForkedSyncProgress(t, 63, FullSync) }
-func TestForkedSyncProgress63Fast(t *testing.T) { testForkedSyncProgress(t, 63, FastSync) }
 func TestForkedSyncProgress64Full(t *testing.T) { testForkedSyncProgress(t, 64, FullSync) }
 func TestForkedSyncProgress64Fast(t *testing.T) { testForkedSyncProgress(t, 64, FastSync) }
 func TestForkedSyncProgress65Full(t *testing.T) { testForkedSyncProgress(t, 65, FullSync) }
 func TestForkedSyncProgress65Fast(t *testing.T) { testForkedSyncProgress(t, 65, FastSync) }
 func TestForkedSyncProgress65Light(t *testing.T) { testForkedSyncProgress(t, 65, LightSync) }

-func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
+func testForkedSyncProgress(t *testing.T, protocol uint, mode SyncMode) {
 	t.Parallel()

 	tester := newTester()
 	defer tester.terminate()

-	chainA := testChainForkLightA.shorten(testChainBase.len() + MaxHashFetch)
-	chainB := testChainForkLightB.shorten(testChainBase.len() + MaxHashFetch)
+	chainA := testChainForkLightA.shorten(testChainBase.len() + MaxHeaderFetch)
+	chainB := testChainForkLightB.shorten(testChainBase.len() + MaxHeaderFetch)

 	// Set a sync init hook to catch progress changes
 	starting := make(chan struct{})
@@ -1340,15 +1307,13 @@ func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) {

 // Tests that if synchronisation is aborted due to some failure, then the progress
 // origin is not updated in the next sync cycle, as it should be considered the
 // continuation of the previous sync and not a new instance.
-func TestFailedSyncProgress63Full(t *testing.T) { testFailedSyncProgress(t, 63, FullSync) } -func TestFailedSyncProgress63Fast(t *testing.T) { testFailedSyncProgress(t, 63, FastSync) } func TestFailedSyncProgress64Full(t *testing.T) { testFailedSyncProgress(t, 64, FullSync) } func TestFailedSyncProgress64Fast(t *testing.T) { testFailedSyncProgress(t, 64, FastSync) } func TestFailedSyncProgress65Full(t *testing.T) { testFailedSyncProgress(t, 65, FullSync) } func TestFailedSyncProgress65Fast(t *testing.T) { testFailedSyncProgress(t, 65, FastSync) } func TestFailedSyncProgress65Light(t *testing.T) { testFailedSyncProgress(t, 65, LightSync) } -func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) { +func testFailedSyncProgress(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -1413,15 +1378,13 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) { // Tests that if an attacker fakes a chain height, after the attack is detected, // the progress height is successfully reduced at the next sync invocation. -func TestFakedSyncProgress63Full(t *testing.T) { testFakedSyncProgress(t, 63, FullSync) } -func TestFakedSyncProgress63Fast(t *testing.T) { testFakedSyncProgress(t, 63, FastSync) } func TestFakedSyncProgress64Full(t *testing.T) { testFakedSyncProgress(t, 64, FullSync) } func TestFakedSyncProgress64Fast(t *testing.T) { testFakedSyncProgress(t, 64, FastSync) } func TestFakedSyncProgress65Full(t *testing.T) { testFakedSyncProgress(t, 65, FullSync) } func TestFakedSyncProgress65Fast(t *testing.T) { testFakedSyncProgress(t, 65, FastSync) } func TestFakedSyncProgress65Light(t *testing.T) { testFakedSyncProgress(t, 65, LightSync) } -func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) { +func testFakedSyncProgress(t *testing.T, protocol uint, mode SyncMode) { t.Parallel() tester := newTester() @@ -1490,31 +1453,15 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) { // This test reproduces an issue where unexpected deliveries would // block indefinitely if they arrived at the right time. -func TestDeliverHeadersHang(t *testing.T) { - t.Parallel() +func TestDeliverHeadersHang64Full(t *testing.T) { testDeliverHeadersHang(t, 64, FullSync) } +func TestDeliverHeadersHang64Fast(t *testing.T) { testDeliverHeadersHang(t, 64, FastSync) } +func TestDeliverHeadersHang65Full(t *testing.T) { testDeliverHeadersHang(t, 65, FullSync) } +func TestDeliverHeadersHang65Fast(t *testing.T) { testDeliverHeadersHang(t, 65, FastSync) } +func TestDeliverHeadersHang65Light(t *testing.T) { testDeliverHeadersHang(t, 65, LightSync) } - testCases := []struct { - protocol int - syncMode SyncMode - }{ - {63, FullSync}, - {63, FastSync}, - {64, FullSync}, - {64, FastSync}, - {64, LightSync}, - {65, FullSync}, - {65, FastSync}, - {65, LightSync}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("protocol %d mode %v", tc.protocol, tc.syncMode), func(t *testing.T) { - t.Parallel() - testDeliverHeadersHang(t, tc.protocol, tc.syncMode) - }) - } -} +func testDeliverHeadersHang(t *testing.T, protocol uint, mode SyncMode) { + t.Parallel() -func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) { master := newTester() defer master.terminate() chain := testChainBase.shorten(15) @@ -1601,7 +1548,7 @@ func TestRemoteHeaderRequestSpan(t *testing.T) { {15000, 13006, []int{14823, 14839, 14855, 14871, 14887, 14903, 14919, 14935, 14951, 14967, 14983, 14999}, }, - //Remote is pretty close to us. 
We don't have to fetch as many
+		// Remote is pretty close to us. We don't have to fetch as many
 		{1200, 1150,
 			[]int{1149, 1154, 1159, 1164, 1169, 1174, 1179, 1184, 1189, 1194, 1199},
 		},
@@ -1665,15 +1612,13 @@ func TestRemoteHeaderRequestSpan(t *testing.T) {

 // Tests that peers below a pre-configured checkpoint block are prevented from
 // being fast-synced from, avoiding potential cheap eclipse attacks.
-func TestCheckpointEnforcement63Full(t *testing.T) { testCheckpointEnforcement(t, 63, FullSync) }
-func TestCheckpointEnforcement63Fast(t *testing.T) { testCheckpointEnforcement(t, 63, FastSync) }
 func TestCheckpointEnforcement64Full(t *testing.T) { testCheckpointEnforcement(t, 64, FullSync) }
 func TestCheckpointEnforcement64Fast(t *testing.T) { testCheckpointEnforcement(t, 64, FastSync) }
 func TestCheckpointEnforcement65Full(t *testing.T) { testCheckpointEnforcement(t, 65, FullSync) }
 func TestCheckpointEnforcement65Fast(t *testing.T) { testCheckpointEnforcement(t, 65, FastSync) }
 func TestCheckpointEnforcement65Light(t *testing.T) { testCheckpointEnforcement(t, 65, LightSync) }

-func testCheckpointEnforcement(t *testing.T, protocol int, mode SyncMode) {
+func testCheckpointEnforcement(t *testing.T, protocol uint, mode SyncMode) {
 	t.Parallel()

 	// Create a new tester with a particular hard coded checkpoint block
diff --git a/eth/downloader/modes.go b/eth/downloader/modes.go
index d866ceabce5c..8ea7876a1f76 100644
--- a/eth/downloader/modes.go
+++ b/eth/downloader/modes.go
@@ -24,7 +24,8 @@ type SyncMode uint32

 const (
 	FullSync  SyncMode = iota // Synchronise the entire blockchain history from full blocks
-	FastSync                  // Quickly download the headers, full sync only at the chain head
+	FastSync                  // Quickly download the headers, full sync only at the chain head
+	SnapSync                  // Download the chain and the state via compact snapshots
 	LightSync                 // Download only the headers and terminate afterwards
 )

@@ -39,6 +40,8 @@ func (mode SyncMode) String() string {
 		return "full"
 	case FastSync:
 		return "fast"
+	case SnapSync:
+		return "snap"
 	case LightSync:
 		return "light"
 	default:
@@ -52,6 +55,8 @@ func (mode SyncMode) MarshalText() ([]byte, error) {
 		return []byte("full"), nil
 	case FastSync:
 		return []byte("fast"), nil
+	case SnapSync:
+		return []byte("snap"), nil
 	case LightSync:
 		return []byte("light"), nil
 	default:
@@ -65,6 +70,8 @@ func (mode *SyncMode) UnmarshalText(text []byte) error {
 		*mode = FullSync
 	case "fast":
 		*mode = FastSync
+	case "snap":
+		*mode = SnapSync
 	case "light":
 		*mode = LightSync
 	default:
diff --git a/eth/downloader/peer.go b/eth/downloader/peer.go
index c6671436f9e7..ba90bf31cbeb 100644
--- a/eth/downloader/peer.go
+++ b/eth/downloader/peer.go
@@ -69,7 +69,7 @@ type peerConnection struct {
 	peer Peer

-	version int        // Eth protocol version number to switch strategies
+	version uint       // Eth protocol version number to switch strategies
 	log     log.Logger // Contextual logger to add extra infos to peer logs
 	lock    sync.RWMutex
 }
@@ -112,7 +112,7 @@ func (w *lightPeerWrapper) RequestNodeData([]common.Hash) error {
 }

 // newPeerConnection creates a new downloader peer.
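With the modes.go hunk above, the new "snap" mode round-trips through the same text marshaling that the TOML config and CLI flags rely on. A quick usage sketch against the downloader package:

```go
package main

import (
	"fmt"
	"log"

	"github.com/ethereum/go-ethereum/eth/downloader"
)

func main() {
	// "snap" now parses and serializes like the existing modes.
	var mode downloader.SyncMode
	if err := mode.UnmarshalText([]byte("snap")); err != nil {
		log.Fatal(err)
	}
	text, _ := mode.MarshalText()
	fmt.Println(mode.String(), string(text)) // snap snap
}
```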
-func newPeerConnection(id string, version int, peer Peer, logger log.Logger) *peerConnection { +func newPeerConnection(id string, version uint, peer Peer, logger log.Logger) *peerConnection { return &peerConnection{ id: id, lacking: make(map[common.Hash]struct{}), @@ -457,7 +457,7 @@ func (ps *peerSet) HeaderIdlePeers() ([]*peerConnection, int) { defer p.lock.RUnlock() return p.headerThroughput } - return ps.idlePeers(63, 65, idle, throughput) + return ps.idlePeers(64, 65, idle, throughput) } // BodyIdlePeers retrieves a flat list of all the currently body-idle peers within @@ -471,7 +471,7 @@ func (ps *peerSet) BodyIdlePeers() ([]*peerConnection, int) { defer p.lock.RUnlock() return p.blockThroughput } - return ps.idlePeers(63, 65, idle, throughput) + return ps.idlePeers(64, 65, idle, throughput) } // ReceiptIdlePeers retrieves a flat list of all the currently receipt-idle peers @@ -485,7 +485,7 @@ func (ps *peerSet) ReceiptIdlePeers() ([]*peerConnection, int) { defer p.lock.RUnlock() return p.receiptThroughput } - return ps.idlePeers(63, 65, idle, throughput) + return ps.idlePeers(64, 65, idle, throughput) } // NodeDataIdlePeers retrieves a flat list of all the currently node-data-idle @@ -499,13 +499,13 @@ func (ps *peerSet) NodeDataIdlePeers() ([]*peerConnection, int) { defer p.lock.RUnlock() return p.stateThroughput } - return ps.idlePeers(63, 65, idle, throughput) + return ps.idlePeers(64, 65, idle, throughput) } // idlePeers retrieves a flat list of all currently idle peers satisfying the // protocol version constraints, using the provided function to check idleness. // The resulting set of peers are sorted by their measure throughput. -func (ps *peerSet) idlePeers(minProtocol, maxProtocol int, idleCheck func(*peerConnection) bool, throughput func(*peerConnection) float64) ([]*peerConnection, int) { +func (ps *peerSet) idlePeers(minProtocol, maxProtocol uint, idleCheck func(*peerConnection) bool, throughput func(*peerConnection) float64) ([]*peerConnection, int) { ps.lock.RLock() defer ps.lock.RUnlock() diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index 745f7c7480f7..2150842f8e82 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -113,24 +113,24 @@ type queue struct { mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching // Headers are "special", they download in batches, supported by a skeleton chain - headerHead common.Hash // [eth/62] Hash of the last queued header to verify order - headerTaskPool map[uint64]*types.Header // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers - headerTaskQueue *prque.Prque // [eth/62] Priority queue of the skeleton indexes to fetch the filling headers for - headerPeerMiss map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable - headerPendPool map[string]*fetchRequest // [eth/62] Currently pending header retrieval operations - headerResults []*types.Header // [eth/62] Result cache accumulating the completed headers - headerProced int // [eth/62] Number of headers already processed from the results - headerOffset uint64 // [eth/62] Number of the first header in the result cache - headerContCh chan bool // [eth/62] Channel to notify when header download finishes + headerHead common.Hash // Hash of the last queued header to verify order + headerTaskPool map[uint64]*types.Header // Pending header retrieval tasks, mapping starting indexes to skeleton headers + headerTaskQueue *prque.Prque // Priority 
queue of the skeleton indexes to fetch the filling headers for + headerPeerMiss map[string]map[uint64]struct{} // Set of per-peer header batches known to be unavailable + headerPendPool map[string]*fetchRequest // Currently pending header retrieval operations + headerResults []*types.Header // Result cache accumulating the completed headers + headerProced int // Number of headers already processed from the results + headerOffset uint64 // Number of the first header in the result cache + headerContCh chan bool // Channel to notify when header download finishes // All data retrievals below are based on an already assembles header chain - blockTaskPool map[common.Hash]*types.Header // [eth/62] Pending block (body) retrieval tasks, mapping hashes to headers - blockTaskQueue *prque.Prque // [eth/62] Priority queue of the headers to fetch the blocks (bodies) for - blockPendPool map[string]*fetchRequest // [eth/62] Currently pending block (body) retrieval operations + blockTaskPool map[common.Hash]*types.Header // Pending block (body) retrieval tasks, mapping hashes to headers + blockTaskQueue *prque.Prque // Priority queue of the headers to fetch the blocks (bodies) for + blockPendPool map[string]*fetchRequest // Currently pending block (body) retrieval operations - receiptTaskPool map[common.Hash]*types.Header // [eth/63] Pending receipt retrieval tasks, mapping hashes to headers - receiptTaskQueue *prque.Prque // [eth/63] Priority queue of the headers to fetch the receipts for - receiptPendPool map[string]*fetchRequest // [eth/63] Currently pending receipt retrieval operations + receiptTaskPool map[common.Hash]*types.Header // Pending receipt retrieval tasks, mapping hashes to headers + receiptTaskQueue *prque.Prque // Priority queue of the headers to fetch the receipts for + receiptPendPool map[string]*fetchRequest // Currently pending receipt retrieval operations resultCache *resultStore // Downloaded but not yet delivered fetch results resultSize common.StorageSize // Approximate size of a block (exponential moving average) @@ -690,6 +690,13 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh q.lock.Lock() defer q.lock.Unlock() + var logger log.Logger + if len(id) < 16 { + // Tests use short IDs, don't choke on them + logger = log.New("peer", id) + } else { + logger = log.New("peer", id[:16]) + } // Short circuit if the data was never requested request := q.headerPendPool[id] if request == nil { @@ -704,31 +711,34 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh accepted := len(headers) == MaxHeaderFetch if accepted { if headers[0].Number.Uint64() != request.From { - log.Trace("First header broke chain ordering", "peer", id, "number", headers[0].Number, "hash", headers[0].Hash(), request.From) + logger.Trace("First header broke chain ordering", "number", headers[0].Number, "hash", headers[0].Hash(), "expected", request.From) accepted = false } else if headers[len(headers)-1].Hash() != target { - log.Trace("Last header broke skeleton structure ", "peer", id, "number", headers[len(headers)-1].Number, "hash", headers[len(headers)-1].Hash(), "expected", target) + logger.Trace("Last header broke skeleton structure ", "number", headers[len(headers)-1].Number, "hash", headers[len(headers)-1].Hash(), "expected", target) accepted = false } } if accepted { + parentHash := headers[0].Hash() for i, header := range headers[1:] { hash := header.Hash() if want := request.From + 1 + uint64(i); header.Number.Uint64() != want { - 
log.Warn("Header broke chain ordering", "peer", id, "number", header.Number, "hash", hash, "expected", want) + logger.Warn("Header broke chain ordering", "number", header.Number, "hash", hash, "expected", want) accepted = false break } - if headers[i].Hash() != header.ParentHash { - log.Warn("Header broke chain ancestry", "peer", id, "number", header.Number, "hash", hash) + if parentHash != header.ParentHash { + logger.Warn("Header broke chain ancestry", "number", header.Number, "hash", hash) accepted = false break } + // Set-up parent hash for next round + parentHash = hash } } // If the batch of headers wasn't accepted, mark as unavailable if !accepted { - log.Trace("Skeleton filling not accepted", "peer", id, "from", request.From) + logger.Trace("Skeleton filling not accepted", "from", request.From) miss := q.headerPeerMiss[id] if miss == nil { @@ -755,7 +765,7 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh select { case headerProcCh <- process: - log.Trace("Pre-scheduled new headers", "peer", id, "count", len(process), "from", process[0].Number) + logger.Trace("Pre-scheduled new headers", "count", len(process), "from", process[0].Number) q.headerProced += len(process) default: } @@ -774,7 +784,7 @@ func (q *queue) DeliverBodies(id string, txLists [][]*types.Transaction, uncleLi q.lock.Lock() defer q.lock.Unlock() validate := func(index int, header *types.Header) error { - if types.DeriveSha(types.Transactions(txLists[index]), new(trie.Trie)) != header.TxHash { + if types.DeriveSha(types.Transactions(txLists[index]), trie.NewStackTrie(nil)) != header.TxHash { return errInvalidBody } if types.CalcUncleHash(uncleLists[index]) != header.UncleHash { @@ -799,7 +809,7 @@ func (q *queue) DeliverReceipts(id string, receiptList [][]*types.Receipt) (int, q.lock.Lock() defer q.lock.Unlock() validate := func(index int, header *types.Header) error { - if types.DeriveSha(types.Receipts(receiptList[index]), new(trie.Trie)) != header.ReceiptHash { + if types.DeriveSha(types.Receipts(receiptList[index]), trie.NewStackTrie(nil)) != header.ReceiptHash { return errInvalidReceipt } return nil diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go index 6745aa54ac40..69bd13c2f79b 100644 --- a/eth/downloader/statesync.go +++ b/eth/downloader/statesync.go @@ -101,8 +101,16 @@ func (d *Downloader) runStateSync(s *stateSync) *stateSync { finished []*stateReq // Completed or failed requests timeout = make(chan *stateReq) // Timed out active requests ) - // Run the state sync. log.Trace("State sync starting", "root", s.root) + + defer func() { + // Cancel active request timers on exit. Also set peers to idle so they're + // available for the next sync. 
+ for _, req := range active { + req.timer.Stop() + req.peer.SetNodeDataIdle(int(req.nItems), time.Now()) + } + }() go s.run() defer s.Cancel() @@ -252,8 +260,9 @@ func (d *Downloader) spindownStateSync(active map[string]*stateReq, finished []* type stateSync struct { d *Downloader // Downloader instance to access and manage current peerset - sched *trie.Sync // State trie sync scheduler defining the tasks - keccak hash.Hash // Keccak256 hasher to verify deliveries with + root common.Hash // State root currently being synced + sched *trie.Sync // State trie sync scheduler defining the tasks + keccak hash.Hash // Keccak256 hasher to verify deliveries with trieTasks map[common.Hash]*trieTask // Set of trie node tasks currently queued for retrieval codeTasks map[common.Hash]*codeTask // Set of byte code tasks currently queued for retrieval @@ -268,8 +277,6 @@ type stateSync struct { cancelOnce sync.Once // Ensures cancel only ever gets called once done chan struct{} // Channel to signal termination completion err error // Any error hit during sync (set before completion) - - root common.Hash } // trieTask represents a single trie node download task, containing a set of @@ -290,6 +297,7 @@ type codeTask struct { func newStateSync(d *Downloader, root common.Hash) *stateSync { return &stateSync{ d: d, + root: root, sched: state.NewStateSync(root, d.stateDB, d.stateBloom), keccak: sha3.NewLegacyKeccak256(), trieTasks: make(map[common.Hash]*trieTask), @@ -298,7 +306,6 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync { cancel: make(chan struct{}), done: make(chan struct{}), started: make(chan struct{}), - root: root, } } @@ -306,7 +313,12 @@ func newStateSync(d *Downloader, root common.Hash) *stateSync { // it finishes, and finally notifying any goroutines waiting for the loop to // finish. func (s *stateSync) run() { - s.err = s.loop() + close(s.started) + if s.d.snapSync { + s.err = s.d.SnapSyncer.Sync(s.root, s.cancel) + } else { + s.err = s.loop() + } close(s.done) } @@ -318,7 +330,9 @@ func (s *stateSync) Wait() error { // Cancel cancels the sync and waits until it has shut down. func (s *stateSync) Cancel() error { - s.cancelOnce.Do(func() { close(s.cancel) }) + s.cancelOnce.Do(func() { + close(s.cancel) + }) return s.Wait() } @@ -329,7 +343,6 @@ func (s *stateSync) Cancel() error { // pushed here async. The reason is to decouple processing from data receipt // and timeouts. 
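The reworked Cancel above leans on a standard Go idiom: a sync.Once guards the close of the cancel channel, so any number of goroutines may call Cancel concurrently, and every caller then blocks until the run loop signals completion. A stripped-down sketch of the same shutdown shape, with illustrative names rather than the downloader's:

package main

import (
	"fmt"
	"sync"
)

// task mirrors the stateSync shutdown fields: a cancel channel closed at most
// once, and a done channel closed when the run loop has fully exited.
type task struct {
	cancel     chan struct{}
	cancelOnce sync.Once
	done       chan struct{}
	err        error
}

func newTask() *task {
	t := &task{cancel: make(chan struct{}), done: make(chan struct{})}
	go t.run()
	return t
}

func (t *task) run() {
	defer close(t.done) // signal completion on every exit path
	<-t.cancel          // stand-in for the real event loop
}

// Cancel is safe to call any number of times from any goroutine: the first
// call closes the channel, the rest just wait for termination.
func (t *task) Cancel() error {
	t.cancelOnce.Do(func() { close(t.cancel) })
	<-t.done
	return t.err
}

func main() {
	t := newTask()
	fmt.Println(t.Cancel(), t.Cancel()) // <nil> <nil>
}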
func (s *stateSync) loop() (err error) { - close(s.started) // Listen for new peer events to assign tasks to them newPeer := make(chan *peerConnection, 1024) peerSub := s.d.peers.SubscribeNewPeers(newPeer) diff --git a/eth/filters/api.go b/eth/filters/api.go index 30d7b71c310c..664f229a0400 100644 --- a/eth/filters/api.go +++ b/eth/filters/api.go @@ -25,7 +25,7 @@ import ( "sync" "time" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/types" @@ -292,7 +292,7 @@ func (api *PublicFilterAPI) NewFilter(crit FilterCriteria) (rpc.ID, error) { logs := make(chan []*types.Log) logsSub, err := api.events.SubscribeLogs(ethereum.FilterQuery(crit), logs) if err != nil { - return rpc.ID(""), err + return "", err } api.filtersMu.Lock() diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go index a105ec51c350..12f037d0f9ae 100644 --- a/eth/filters/filter_system.go +++ b/eth/filters/filter_system.go @@ -24,7 +24,7 @@ import ( "sync" "time" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index c8d1d43abbd8..b534c1d47fdd 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -25,7 +25,7 @@ import ( "testing" "time" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" diff --git a/eth/gen_config.go b/eth/gen_config.go index 0093439d141b..dd04635eeed2 100644 --- a/eth/gen_config.go +++ b/eth/gen_config.go @@ -20,7 +20,7 @@ func (c Config) MarshalTOML() (interface{}, error) { Genesis *core.Genesis `toml:",omitempty"` NetworkId uint64 SyncMode downloader.SyncMode - DiscoveryURLs []string + EthDiscoveryURLs []string NoPruning bool NoPrefetch bool TxLookupLimit uint64 `toml:",omitempty"` @@ -43,6 +43,7 @@ func (c Config) MarshalTOML() (interface{}, error) { TrieDirtyCache int TrieTimeout time.Duration SnapshotCache int + Preimages bool Miner miner.Config Ethash ethash.Config TxPool core.TxPoolConfig @@ -60,7 +61,7 @@ func (c Config) MarshalTOML() (interface{}, error) { enc.Genesis = c.Genesis enc.NetworkId = c.NetworkId enc.SyncMode = c.SyncMode - enc.DiscoveryURLs = c.DiscoveryURLs + enc.EthDiscoveryURLs = c.EthDiscoveryURLs enc.NoPruning = c.NoPruning enc.NoPrefetch = c.NoPrefetch enc.TxLookupLimit = c.TxLookupLimit @@ -83,6 +84,7 @@ func (c Config) MarshalTOML() (interface{}, error) { enc.TrieDirtyCache = c.TrieDirtyCache enc.TrieTimeout = c.TrieTimeout enc.SnapshotCache = c.SnapshotCache + enc.Preimages = c.Preimages enc.Miner = c.Miner enc.Ethash = c.Ethash enc.TxPool = c.TxPool @@ -104,7 +106,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { Genesis *core.Genesis `toml:",omitempty"` NetworkId *uint64 SyncMode *downloader.SyncMode - DiscoveryURLs []string + EthDiscoveryURLs []string NoPruning *bool NoPrefetch *bool TxLookupLimit *uint64 `toml:",omitempty"` @@ -127,6 +129,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { TrieDirtyCache *int TrieTimeout *time.Duration SnapshotCache *int + Preimages *bool Miner *miner.Config Ethash *ethash.Config TxPool 
*core.TxPoolConfig @@ -153,8 +156,8 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.SyncMode != nil { c.SyncMode = *dec.SyncMode } - if dec.DiscoveryURLs != nil { - c.DiscoveryURLs = dec.DiscoveryURLs + if dec.EthDiscoveryURLs != nil { + c.EthDiscoveryURLs = dec.EthDiscoveryURLs } if dec.NoPruning != nil { c.NoPruning = *dec.NoPruning @@ -222,6 +225,9 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.SnapshotCache != nil { c.SnapshotCache = *dec.SnapshotCache } + if dec.Preimages != nil { + c.Preimages = *dec.Preimages + } if dec.Miner != nil { c.Miner = *dec.Miner } diff --git a/eth/handler.go b/eth/handler.go index 7482a2f96ef2..76a429f6d335 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -17,9 +17,7 @@ package eth import ( - "encoding/json" "errors" - "fmt" "math" "math/big" "sync" @@ -27,26 +25,22 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/fetcher" + "github.com/ethereum/go-ethereum/eth/protocols/eth" + "github.com/ethereum/go-ethereum/eth/protocols/snap" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" - "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) const ( - softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data. - estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header - // txChanSize is the size of channel listening to NewTxsEvent. // The number is referenced from the size of tx pool. txChanSize = 4096 @@ -56,26 +50,61 @@ var ( syncChallengeTimeout = 15 * time.Second // Time allowance for a node to reply to the sync progress challenge ) -func errResp(code errCode, format string, v ...interface{}) error { - return fmt.Errorf("%v - %v", code, fmt.Sprintf(format, v...)) +// txPool defines the methods needed from a transaction pool implementation to +// support all the operations needed by the Ethereum chain protocols. +type txPool interface { + // Has returns an indicator whether txpool has a transaction + // cached with the given hash. + Has(hash common.Hash) bool + + // Get retrieves the transaction from local txpool with given + // tx hash. + Get(hash common.Hash) *types.Transaction + + // AddRemotes should add the given transactions to the pool. + AddRemotes([]*types.Transaction) []error + + // Pending should return pending transactions. + // The slice should be modifiable by the caller. + Pending() (map[common.Address]types.Transactions, error) + + // SubscribeNewTxsEvent should return an event subscription of + // NewTxsEvent and send events to the given channel. + SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription } -type ProtocolManager struct { +// handlerConfig is the collection of initialization parameters to create a full +// node network handler. 
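Narrowing the handler's pool dependency to the small txPool interface above is what makes the new handler testable without a full core.TxPool. A hypothetical in-memory stub satisfying it, as one might write in an eth package test (the testTxPool name and shape are illustrative, not part of this change):

package eth

import (
	"sync"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/event"
)

// testTxPool is a minimal in-memory stand-in for the txPool interface.
type testTxPool struct {
	mu     sync.Mutex
	pool   map[common.Hash]*types.Transaction
	txFeed event.Feed
}

func newTestTxPool() *testTxPool {
	return &testTxPool{pool: make(map[common.Hash]*types.Transaction)}
}

func (p *testTxPool) Has(hash common.Hash) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.pool[hash] != nil
}

func (p *testTxPool) Get(hash common.Hash) *types.Transaction {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.pool[hash]
}

func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error {
	p.mu.Lock()
	for _, tx := range txs {
		p.pool[tx.Hash()] = tx
	}
	p.mu.Unlock()
	p.txFeed.Send(core.NewTxsEvent{Txs: txs})
	return make([]error, len(txs))
}

func (p *testTxPool) Pending() (map[common.Address]types.Transactions, error) {
	return map[common.Address]types.Transactions{}, nil // sender recovery omitted
}

func (p *testTxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription {
	return p.txFeed.Subscribe(ch)
}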
+type handlerConfig struct { + Database ethdb.Database // Database for direct sync insertions + Chain *core.BlockChain // Blockchain to serve data from + TxPool txPool // Transaction pool to propagate from + Network uint64 // Network identifier to advertise + Sync downloader.SyncMode // Whether to fast or full sync + BloomCache uint64 // Megabytes to alloc for fast sync bloom + EventMux *event.TypeMux // Legacy event mux, deprecate for `feed` + Checkpoint *params.TrustedCheckpoint // Hard coded checkpoint for sync challenges + Whitelist map[uint64]common.Hash // Hard coded whitelist for sync challenges +} + +type handler struct { networkID uint64 forkFilter forkid.Filter // Fork ID filter, constant across the lifetime of the node fastSync uint32 // Flag whether fast sync is enabled (gets disabled if we already have blocks) + snapSync uint32 // Flag whether fast sync should operate on top of the snap protocol acceptTxs uint32 // Flag whether we're considered synchronised (enables transaction processing) checkpointNumber uint64 // Block number for the sync progress validator to cross reference checkpointHash common.Hash // Block hash for the sync progress validator to cross reference - txpool txPool - blockchain *core.BlockChain - chaindb ethdb.Database - maxPeers int + database ethdb.Database + txpool txPool + chain *core.BlockChain + maxPeers int downloader *downloader.Downloader + stateBloom *trie.SyncBloom blockFetcher *fetcher.BlockFetcher txFetcher *fetcher.TxFetcher peers *peerSet @@ -94,29 +123,27 @@ type ProtocolManager struct { chainSync *chainSyncer wg sync.WaitGroup peerWG sync.WaitGroup - - // Test fields or hooks - broadcastTxAnnouncesOnly bool // Testing field, disable transaction propagation } -// NewProtocolManager returns a new Ethereum sub protocol manager. The Ethereum sub protocol manages peers capable -// with the Ethereum network. -func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCheckpoint, mode downloader.SyncMode, networkID uint64, mux *event.TypeMux, txpool txPool, engine consensus.Engine, blockchain *core.BlockChain, chaindb ethdb.Database, cacheLimit int, whitelist map[uint64]common.Hash) (*ProtocolManager, error) { +// newHandler returns a handler for all Ethereum chain management protocol. +func newHandler(config *handlerConfig) (*handler, error) { // Create the protocol manager with the base fields - manager := &ProtocolManager{ - networkID: networkID, - forkFilter: forkid.NewFilter(blockchain), - eventMux: mux, - txpool: txpool, - blockchain: blockchain, - chaindb: chaindb, + if config.EventMux == nil { + config.EventMux = new(event.TypeMux) // Nicety initialization for tests + } + h := &handler{ + networkID: config.Network, + forkFilter: forkid.NewFilter(config.Chain), + eventMux: config.EventMux, + database: config.Database, + txpool: config.TxPool, + chain: config.Chain, peers: newPeerSet(), - whitelist: whitelist, + whitelist: config.Whitelist, txsyncCh: make(chan *txsync), quitSync: make(chan struct{}), } - - if mode == downloader.FullSync { + if config.Sync == downloader.FullSync { // The database seems empty as the current block is the genesis. Yet the fast // block is ahead, so fast sync was enabled for this node at a certain point. // The scenarios where this can happen is @@ -125,42 +152,42 @@ func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCh // * the last fast sync is not finished while user specifies a full sync this // time. But we don't have any recent state for full sync.
// In these cases however it's safe to reenable fast sync. - fullBlock, fastBlock := blockchain.CurrentBlock(), blockchain.CurrentFastBlock() + fullBlock, fastBlock := h.chain.CurrentBlock(), h.chain.CurrentFastBlock() if fullBlock.NumberU64() == 0 && fastBlock.NumberU64() > 0 { - manager.fastSync = uint32(1) + h.fastSync = uint32(1) log.Warn("Switch sync mode from full sync to fast sync") } } else { - if blockchain.CurrentBlock().NumberU64() > 0 { + if h.chain.CurrentBlock().NumberU64() > 0 { // Print warning log if database is not empty to run fast sync. log.Warn("Switch sync mode from fast sync to full sync") } else { // If fast sync was requested and our database is empty, grant it - manager.fastSync = uint32(1) + h.fastSync = uint32(1) + if config.Sync == downloader.SnapSync { + h.snapSync = uint32(1) + } } } - // If we have trusted checkpoints, enforce them on the chain - if checkpoint != nil { - manager.checkpointNumber = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1 - manager.checkpointHash = checkpoint.SectionHead + if config.Checkpoint != nil { + h.checkpointNumber = (config.Checkpoint.SectionIndex+1)*params.CHTFrequency - 1 + h.checkpointHash = config.Checkpoint.SectionHead } - // Construct the downloader (long sync) and its backing state bloom if fast // sync is requested. The downloader is responsible for deallocating the state // bloom when it's done. - var stateBloom *trie.SyncBloom - if atomic.LoadUint32(&manager.fastSync) == 1 { - stateBloom = trie.NewSyncBloom(uint64(cacheLimit), chaindb) + if atomic.LoadUint32(&h.fastSync) == 1 { + h.stateBloom = trie.NewSyncBloom(config.BloomCache, config.Database) } - manager.downloader = downloader.New(manager.checkpointNumber, chaindb, stateBloom, manager.eventMux, blockchain, nil, manager.removePeer) + h.downloader = downloader.New(h.checkpointNumber, config.Database, h.stateBloom, h.eventMux, h.chain, nil, h.removePeer) // Construct the fetcher (short sync) validator := func(header *types.Header) error { - return engine.VerifyHeader(blockchain, header, true) + return h.chain.Engine().VerifyHeader(h.chain, header, true) } heighter := func() uint64 { - return blockchain.CurrentBlock().NumberU64() + return h.chain.CurrentBlock().NumberU64() } inserter := func(blocks types.Blocks) (int, error) { // If sync hasn't reached the checkpoint yet, deny importing weird blocks. @@ -169,7 +196,7 @@ func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCh // the propagated block if the head is too old. Unfortunately there is a corner // case when starting new networks, where the genesis might be ancient (0 unix) // which would prevent full nodes from accepting it. - if manager.blockchain.CurrentBlock().NumberU64() < manager.checkpointNumber { + if h.chain.CurrentBlock().NumberU64() < h.checkpointNumber { log.Warn("Unsynced yet, discarded propagated block", "number", blocks[0].Number(), "hash", blocks[0].Hash()) return 0, nil } @@ -178,179 +205,88 @@ func NewProtocolManager(config *params.ChainConfig, checkpoint *params.TrustedCh // accept each others' blocks until a restart. Unfortunately we haven't figured // out a way yet where nodes can decide unilaterally whether the network is new // or not. This should be fixed if we figure out a solution. 
- if atomic.LoadUint32(&manager.fastSync) == 1 { + if atomic.LoadUint32(&h.fastSync) == 1 { log.Warn("Fast syncing, discarded propagated block", "number", blocks[0].Number(), "hash", blocks[0].Hash()) return 0, nil } - n, err := manager.blockchain.InsertChain(blocks) + n, err := h.chain.InsertChain(blocks) if err == nil { - atomic.StoreUint32(&manager.acceptTxs, 1) // Mark initial sync done on any fetcher import + atomic.StoreUint32(&h.acceptTxs, 1) // Mark initial sync done on any fetcher import } return n, err } - manager.blockFetcher = fetcher.NewBlockFetcher(false, nil, blockchain.GetBlockByHash, validator, manager.BroadcastBlock, heighter, nil, inserter, manager.removePeer) + h.blockFetcher = fetcher.NewBlockFetcher(false, nil, h.chain.GetBlockByHash, validator, h.BroadcastBlock, heighter, nil, inserter, h.removePeer) fetchTx := func(peer string, hashes []common.Hash) error { - p := manager.peers.Peer(peer) + p := h.peers.ethPeer(peer) if p == nil { return errors.New("unknown peer") } return p.RequestTxs(hashes) } - manager.txFetcher = fetcher.NewTxFetcher(txpool.Has, txpool.AddRemotes, fetchTx) - - manager.chainSync = newChainSyncer(manager) - - return manager, nil -} - -func (pm *ProtocolManager) makeProtocol(version uint) p2p.Protocol { - length, ok := protocolLengths[version] - if !ok { - panic("makeProtocol for unknown version") - } - - return p2p.Protocol{ - Name: protocolName, - Version: version, - Length: length, - Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error { - return pm.runPeer(pm.newPeer(int(version), p, rw, pm.txpool.Get)) - }, - NodeInfo: func() interface{} { - return pm.NodeInfo() - }, - PeerInfo: func(id enode.ID) interface{} { - if p := pm.peers.Peer(fmt.Sprintf("%x", id[:8])); p != nil { - return p.Info() - } - return nil - }, - } + h.txFetcher = fetcher.NewTxFetcher(h.txpool.Has, h.txpool.AddRemotes, fetchTx) + h.chainSync = newChainSyncer(h) + return h, nil } -func (pm *ProtocolManager) removePeer(id string) { - // Short circuit if the peer was already removed - peer := pm.peers.Peer(id) - if peer == nil { - return - } - log.Debug("Removing Ethereum peer", "peer", id) - - // Unregister the peer from the downloader and Ethereum peer set - pm.downloader.UnregisterPeer(id) - pm.txFetcher.Drop(id) - - if err := pm.peers.Unregister(id); err != nil { - log.Error("Peer removal failed", "peer", id, "err", err) - } - // Hard disconnect at the networking layer - if peer != nil { - peer.Peer.Disconnect(p2p.DiscUselessPeer) - } -} - -func (pm *ProtocolManager) Start(maxPeers int) { - pm.maxPeers = maxPeers - - // broadcast transactions - pm.wg.Add(1) - pm.txsCh = make(chan core.NewTxsEvent, txChanSize) - pm.txsSub = pm.txpool.SubscribeNewTxsEvent(pm.txsCh) - go pm.txBroadcastLoop() - - // broadcast mined blocks - pm.wg.Add(1) - pm.minedBlockSub = pm.eventMux.Subscribe(core.NewMinedBlockEvent{}) - go pm.minedBroadcastLoop() - - // start sync handlers - pm.wg.Add(2) - go pm.chainSync.loop() - go pm.txsyncLoop64() // TODO(karalabe): Legacy initial tx echange, drop with eth/64. -} - -func (pm *ProtocolManager) Stop() { - pm.txsSub.Unsubscribe() // quits txBroadcastLoop - pm.minedBlockSub.Unsubscribe() // quits blockBroadcastLoop - - // Quit chainSync and txsync64. - // After this is done, no new peers will be accepted. - close(pm.quitSync) - pm.wg.Wait() - - // Disconnect existing sessions. - // This also closes the gate for any new registrations on the peer set. 
- // sessions which are already established but not added to pm.peers yet - // will exit when they try to register. - pm.peers.Close() - pm.peerWG.Wait() - - log.Info("Ethereum protocol stopped") -} - -func (pm *ProtocolManager) newPeer(pv int, p *p2p.Peer, rw p2p.MsgReadWriter, getPooledTx func(hash common.Hash) *types.Transaction) *peer { - return newPeer(pv, p, rw, getPooledTx) -} - -func (pm *ProtocolManager) runPeer(p *peer) error { - if !pm.chainSync.handlePeerEvent(p) { +// runEthPeer registers an eth peer into the joint eth/snap peerset, adds it to +// various subsystems and starts handling messages. +func (h *handler) runEthPeer(peer *eth.Peer, handler eth.Handler) error { + if !h.chainSync.handlePeerEvent(peer) { return p2p.DiscQuitting } - pm.peerWG.Add(1) - defer pm.peerWG.Done() - return pm.handle(p) -} - -// handle is the callback invoked to manage the life cycle of an eth peer. When -// this function terminates, the peer is disconnected. -func (pm *ProtocolManager) handle(p *peer) error { - // Ignore maxPeers if this is a trusted peer - if pm.peers.Len() >= pm.maxPeers && !p.Peer.Info().Network.Trusted { - return p2p.DiscTooManyPeers - } - p.Log().Debug("Ethereum peer connected", "name", p.Name()) + h.peerWG.Add(1) + defer h.peerWG.Done() // Execute the Ethereum handshake var ( - genesis = pm.blockchain.Genesis() - head = pm.blockchain.CurrentHeader() + genesis = h.chain.Genesis() + head = h.chain.CurrentHeader() hash = head.Hash() number = head.Number.Uint64() - td = pm.blockchain.GetTd(hash, number) + td = h.chain.GetTd(hash, number) ) - if err := p.Handshake(pm.networkID, td, hash, genesis.Hash(), forkid.NewID(pm.blockchain), pm.forkFilter); err != nil { - p.Log().Debug("Ethereum handshake failed", "err", err) + forkID := forkid.NewID(h.chain.Config(), h.chain.Genesis().Hash(), h.chain.CurrentHeader().Number.Uint64()) + if err := peer.Handshake(h.networkID, td, hash, genesis.Hash(), forkID, h.forkFilter); err != nil { + peer.Log().Debug("Ethereum handshake failed", "err", err) return err } + // Ignore maxPeers if this is a trusted peer + if h.peers.Len() >= h.maxPeers && !peer.Peer.Info().Network.Trusted { + return p2p.DiscTooManyPeers + } + peer.Log().Debug("Ethereum peer connected", "name", peer.Name()) // Register the peer locally - if err := pm.peers.Register(p, pm.removePeer); err != nil { - p.Log().Error("Ethereum peer registration failed", "err", err) + if err := h.peers.registerEthPeer(peer); err != nil { + peer.Log().Error("Ethereum peer registration failed", "err", err) return err } - defer pm.removePeer(p.id) + defer h.removePeer(peer.ID()) + p := h.peers.ethPeer(peer.ID()) + if p == nil { + return errors.New("peer dropped during handling") + } // Register the peer in the downloader. If the downloader considers it banned, we disconnect - if err := pm.downloader.RegisterPeer(p.id, p.version, p); err != nil { + if err := h.downloader.RegisterPeer(peer.ID(), peer.Version(), peer); err != nil { return err } - pm.chainSync.handlePeerEvent(p) + h.chainSync.handlePeerEvent(peer) // Propagate existing transactions. new transactions appearing // after this will be sent via broadcasts.
- pm.syncTransactions(p) + h.syncTransactions(peer) // If we have a trusted CHT, reject all peers below that (avoid fast sync eclipse) - if pm.checkpointHash != (common.Hash{}) { + if h.checkpointHash != (common.Hash{}) { // Request the peer's checkpoint header for chain height/weight validation - if err := p.RequestHeadersByNumber(pm.checkpointNumber, 1, 0, false); err != nil { + if err := peer.RequestHeadersByNumber(h.checkpointNumber, 1, 0, false); err != nil { return err } // Start a timer to disconnect if the peer doesn't reply in time p.syncDrop = time.AfterFunc(syncChallengeTimeout, func() { - p.Log().Warn("Checkpoint challenge timed out, dropping", "addr", p.RemoteAddr(), "type", p.Name()) - pm.removePeer(p.id) + peer.Log().Warn("Checkpoint challenge timed out, dropping", "addr", peer.RemoteAddr(), "type", peer.Name()) + h.removePeer(peer.ID()) }) // Make sure it's cleaned up if the peer dies off defer func() { @@ -361,474 +297,115 @@ func (pm *ProtocolManager) handle(p *peer) error { }() } // If we have any explicit whitelist block hashes, request them - for number := range pm.whitelist { - if err := p.RequestHeadersByNumber(number, 1, 0, false); err != nil { + for number := range h.whitelist { + if err := peer.RequestHeadersByNumber(number, 1, 0, false); err != nil { return err } } // Handle incoming messages until the connection is torn down - for { - if err := pm.handleMsg(p); err != nil { - p.Log().Debug("Ethereum message handling failed", "err", err) - return err - } - } + return handler(peer) } -// handleMsg is invoked whenever an inbound message is received from a remote -// peer. The remote connection is torn down upon returning any error. -func (pm *ProtocolManager) handleMsg(p *peer) error { - // Read the next message from the remote peer, and ensure it's fully consumed - msg, err := p.rw.ReadMsg() - if err != nil { +// runSnapPeer registers a snap peer into the joint eth/snap peerset, adds it +// to various subsystems and starts handling messages. +func (h *handler) runSnapPeer(peer *snap.Peer, handler snap.Handler) error { + h.peerWG.Add(1) + defer h.peerWG.Done() + + // Register the peer locally + if err := h.peers.registerSnapPeer(peer); err != nil { + peer.Log().Error("Snapshot peer registration failed", "err", err) return err } - if msg.Size > protocolMaxMsgSize { - return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize) - } - defer msg.Discard() - - // Handle the message depending on its contents - switch { - case msg.Code == StatusMsg: - // Status messages should never arrive after the handshake - return errResp(ErrExtraStatusMsg, "uncontrolled status message") - - // Block header query, collect the requested headers and reply - case msg.Code == GetBlockHeadersMsg: - // Decode the complex header query - var query getBlockHeadersData - if err := msg.Decode(&query); err != nil { - return errResp(ErrDecode, "%v: %v", msg, err) - } - hashMode := query.Origin.Hash != (common.Hash{}) - first := true - maxNonCanonical := uint64(100) - - // Gather headers until the fetch or network limits is reached - var ( - bytes common.StorageSize - headers []*types.Header - unknown bool - ) - for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit && len(headers) < downloader.MaxHeaderFetch { - // Retrieve the next header satisfying the query - var origin *types.Header - if hashMode { - if first { - first = false - origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash) - if origin != nil { - query.Origin.Number = origin.Number.Uint64() - } - } else { - origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number) - } - } else { - origin =
pm.blockchain.GetHeaderByNumber(query.Origin.Number) - } - if origin == nil { - break - } - headers = append(headers, origin) - bytes += estHeaderRlpSize - - // Advance to the next header of the query - switch { - case hashMode && query.Reverse: - // Hash based traversal towards the genesis block - ancestor := query.Skip + 1 - if ancestor == 0 { - unknown = true - } else { - query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) - unknown = (query.Origin.Hash == common.Hash{}) - } - case hashMode && !query.Reverse: - // Hash based traversal towards the leaf block - var ( - current = origin.Number.Uint64() - next = current + query.Skip + 1 - ) - if next <= current { - infos, _ := json.MarshalIndent(p.Peer.Info(), "", " ") - p.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos) - unknown = true - } else { - if header := pm.blockchain.GetHeaderByNumber(next); header != nil { - nextHash := header.Hash() - expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) - if expOldHash == query.Origin.Hash { - query.Origin.Hash, query.Origin.Number = nextHash, next - } else { - unknown = true - } - } else { - unknown = true - } - } - case query.Reverse: - // Number based traversal towards the genesis block - if query.Origin.Number >= query.Skip+1 { - query.Origin.Number -= query.Skip + 1 - } else { - unknown = true - } - - case !query.Reverse: - // Number based traversal towards the leaf block - query.Origin.Number += query.Skip + 1 - } - } - return p.SendBlockHeaders(headers) + defer h.removePeer(peer.ID()) - case msg.Code == BlockHeadersMsg: - // A batch of headers arrived to one of our previous requests - var headers []*types.Header - if err := msg.Decode(&headers); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // If no headers were received, but we're expencting a checkpoint header, consider it that - if len(headers) == 0 && p.syncDrop != nil { - // Stop the timer either way, decide later to drop or not - p.syncDrop.Stop() - p.syncDrop = nil - - // If we're doing a fast sync, we must enforce the checkpoint block to avoid - // eclipse attacks. 
Unsynced nodes are welcome to connect after we're done - // joining the network - if atomic.LoadUint32(&pm.fastSync) == 1 { - p.Log().Warn("Dropping unsynced node during fast sync", "addr", p.RemoteAddr(), "type", p.Name()) - return errors.New("unsynced node cannot serve fast sync") - } - } - // Filter out any explicitly requested headers, deliver the rest to the downloader - filter := len(headers) == 1 - if filter { - // If it's a potential sync progress check, validate the content and advertised chain weight - if p.syncDrop != nil && headers[0].Number.Uint64() == pm.checkpointNumber { - // Disable the sync drop timer - p.syncDrop.Stop() - p.syncDrop = nil - - // Validate the header and either drop the peer or continue - if headers[0].Hash() != pm.checkpointHash { - return errors.New("checkpoint hash mismatch") - } - return nil - } - // Otherwise if it's a whitelisted block, validate against the set - if want, ok := pm.whitelist[headers[0].Number.Uint64()]; ok { - if hash := headers[0].Hash(); want != hash { - p.Log().Info("Whitelist mismatch, dropping peer", "number", headers[0].Number.Uint64(), "hash", hash, "want", want) - return errors.New("whitelist block mismatch") - } - p.Log().Debug("Whitelist block verified", "number", headers[0].Number.Uint64(), "hash", want) - } - // Irrelevant of the fork checks, send the header to the fetcher just in case - headers = pm.blockFetcher.FilterHeaders(p.id, headers, time.Now()) - } - if len(headers) > 0 || !filter { - err := pm.downloader.DeliverHeaders(p.id, headers) - if err != nil { - log.Debug("Failed to deliver headers", "err", err) - } - } - - case msg.Code == GetBlockBodiesMsg: - // Decode the retrieval message - msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) - if _, err := msgStream.List(); err != nil { - return err - } - // Gather blocks until the fetch or network limits is reached - var ( - hash common.Hash - bytes int - bodies []rlp.RawValue - ) - for bytes < softResponseLimit && len(bodies) < downloader.MaxBlockFetch { - // Retrieve the hash of the next block - if err := msgStream.Decode(&hash); err == rlp.EOL { - break - } else if err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Retrieve the requested block body, stopping if enough was found - if data := pm.blockchain.GetBodyRLP(hash); len(data) != 0 { - bodies = append(bodies, data) - bytes += len(data) - } - } - return p.SendBlockBodiesRLP(bodies) - - case msg.Code == BlockBodiesMsg: - // A batch of block bodies arrived to one of our previous requests - var request blockBodiesData - if err := msg.Decode(&request); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Deliver them all to the downloader for queuing - transactions := make([][]*types.Transaction, len(request)) - uncles := make([][]*types.Header, len(request)) - - for i, body := range request { - transactions[i] = body.Transactions - uncles[i] = body.Uncles - } - // Filter out any explicitly requested bodies, deliver the rest to the downloader - filter := len(transactions) > 0 || len(uncles) > 0 - if filter { - transactions, uncles = pm.blockFetcher.FilterBodies(p.id, transactions, uncles, time.Now()) - } - if len(transactions) > 0 || len(uncles) > 0 || !filter { - err := pm.downloader.DeliverBodies(p.id, transactions, uncles) - if err != nil { - log.Debug("Failed to deliver bodies", "err", err) - } - } + if err := h.downloader.SnapSyncer.Register(peer); err != nil { + return err + } + // Handle incoming messages until the connection is torn down + return 
handler(peer) +} - case p.version >= eth63 && msg.Code == GetNodeDataMsg: - // Decode the retrieval message - msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) - if _, err := msgStream.List(); err != nil { - return err - } - // Gather state data until the fetch or network limits is reached - var ( - hash common.Hash - bytes int - data [][]byte - ) - for bytes < softResponseLimit && len(data) < downloader.MaxStateFetch { - // Retrieve the hash of the next state entry - if err := msgStream.Decode(&hash); err == rlp.EOL { - break - } else if err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Retrieve the requested state entry, stopping if enough was found - // todo now the code and trienode is mixed in the protocol level, - // separate these two types. - if !pm.downloader.SyncBloomContains(hash[:]) { - // Only lookup the trie node if there's chance that we actually have it - continue - } - entry, err := pm.blockchain.TrieNode(hash) - if len(entry) == 0 || err != nil { - // Read the contract code with prefix only to save unnecessary lookups. - entry, err = pm.blockchain.ContractCodeWithPrefix(hash) - } - if err == nil && len(entry) > 0 { - data = append(data, entry) - bytes += len(entry) - } - } - return p.SendNodeData(data) +func (h *handler) removePeer(id string) { + // Remove the eth peer if it exists + eth := h.peers.ethPeer(id) + if eth != nil { + log.Debug("Removing Ethereum peer", "peer", id) + h.downloader.UnregisterPeer(id) + h.txFetcher.Drop(id) - case p.version >= eth63 && msg.Code == NodeDataMsg: - // A batch of node state data arrived to one of our previous requests - var data [][]byte - if err := msg.Decode(&data); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) + if err := h.peers.unregisterEthPeer(id); err != nil { + log.Error("Peer removal failed", "peer", id, "err", err) } - // Deliver all to the downloader - if err := pm.downloader.DeliverNodeData(p.id, data); err != nil { - log.Debug("Failed to deliver node state data", "err", err) + } + // Remove the snap peer if it exists + snap := h.peers.snapPeer(id) + if snap != nil { + log.Debug("Removing Snapshot peer", "peer", id) + h.downloader.SnapSyncer.Unregister(id) + if err := h.peers.unregisterSnapPeer(id); err != nil { + log.Error("Peer removal failed", "peer", id, "err", err) } + } + // Hard disconnect at the networking layer + if eth != nil { + eth.Peer.Disconnect(p2p.DiscUselessPeer) + } + if snap != nil { + snap.Peer.Disconnect(p2p.DiscUselessPeer) + } +} - case p.version >= eth63 && msg.Code == GetReceiptsMsg: - // Decode the retrieval message - msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) - if _, err := msgStream.List(); err != nil { - return err - } - // Gather state data until the fetch or network limits is reached - var ( - hash common.Hash - bytes int - receipts []rlp.RawValue - ) - for bytes < softResponseLimit && len(receipts) < downloader.MaxReceiptFetch { - // Retrieve the hash of the next block - if err := msgStream.Decode(&hash); err == rlp.EOL { - break - } else if err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Retrieve the requested block's receipts, skipping if unknown to us - results := pm.blockchain.GetReceiptsByHash(hash) - if results == nil { - if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { - continue - } - } - // If known, encode and queue for response packet - if encoded, err := rlp.EncodeToBytes(results); err != nil { - log.Error("Failed to encode 
receipt", "err", err) - } else { - receipts = append(receipts, encoded) - bytes += len(encoded) - } - } - return p.SendReceiptsRLP(receipts) +func (h *handler) Start(maxPeers int) { + h.maxPeers = maxPeers - case p.version >= eth63 && msg.Code == ReceiptsMsg: - // A batch of receipts arrived to one of our previous requests - var receipts [][]*types.Receipt - if err := msg.Decode(&receipts); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Deliver all to the downloader - if err := pm.downloader.DeliverReceipts(p.id, receipts); err != nil { - log.Debug("Failed to deliver receipts", "err", err) - } + // broadcast transactions + h.wg.Add(1) + h.txsCh = make(chan core.NewTxsEvent, txChanSize) + h.txsSub = h.txpool.SubscribeNewTxsEvent(h.txsCh) + go h.txBroadcastLoop() - case msg.Code == NewBlockHashesMsg: - var announces newBlockHashesData - if err := msg.Decode(&announces); err != nil { - return errResp(ErrDecode, "%v: %v", msg, err) - } - // Mark the hashes as present at the remote node - for _, block := range announces { - p.MarkBlock(block.Hash) - } - // Schedule all the unknown hashes for retrieval - unknown := make(newBlockHashesData, 0, len(announces)) - for _, block := range announces { - if !pm.blockchain.HasBlock(block.Hash, block.Number) { - unknown = append(unknown, block) - } - } - for _, block := range unknown { - pm.blockFetcher.Notify(p.id, block.Hash, block.Number, time.Now(), p.RequestOneHeader, p.RequestBodies) - } + // broadcast mined blocks + h.wg.Add(1) + h.minedBlockSub = h.eventMux.Subscribe(core.NewMinedBlockEvent{}) + go h.minedBroadcastLoop() - case msg.Code == NewBlockMsg: - // Retrieve and decode the propagated block - var request newBlockData - if err := msg.Decode(&request); err != nil { - return errResp(ErrDecode, "%v: %v", msg, err) - } - if hash := types.CalcUncleHash(request.Block.Uncles()); hash != request.Block.UncleHash() { - log.Warn("Propagated block has invalid uncles", "have", hash, "exp", request.Block.UncleHash()) - break // TODO(karalabe): return error eventually, but wait a few releases - } - if hash := types.DeriveSha(request.Block.Transactions(), new(trie.Trie)); hash != request.Block.TxHash() { - log.Warn("Propagated block has invalid body", "have", hash, "exp", request.Block.TxHash()) - break // TODO(karalabe): return error eventually, but wait a few releases - } - if err := request.sanityCheck(); err != nil { - return err - } - request.Block.ReceivedAt = msg.ReceivedAt - request.Block.ReceivedFrom = p - - // Mark the peer as owning the block and schedule it for import - p.MarkBlock(request.Block.Hash()) - pm.blockFetcher.Enqueue(p.id, request.Block) - - // Assuming the block is importable by the peer, but possibly not yet done so, - // calculate the head hash and TD that the peer truly must have. - var ( - trueHead = request.Block.ParentHash() - trueTD = new(big.Int).Sub(request.TD, request.Block.Difficulty()) - ) - // Update the peer's total difficulty if better than the previous - if _, td := p.Head(); trueTD.Cmp(td) > 0 { - p.SetHead(trueHead, trueTD) - pm.chainSync.handlePeerEvent(p) - } + // start sync handlers + h.wg.Add(2) + go h.chainSync.loop() + go h.txsyncLoop64() // TODO(karalabe): Legacy initial tx echange, drop with eth/64. 
+} - case msg.Code == NewPooledTransactionHashesMsg && p.version >= eth65: - // New transaction announcement arrived, make sure we have - // a valid and fresh chain to handle them - if atomic.LoadUint32(&pm.acceptTxs) == 0 { - break - } - var hashes []common.Hash - if err := msg.Decode(&hashes); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Schedule all the unknown hashes for retrieval - for _, hash := range hashes { - p.MarkTransaction(hash) - } - pm.txFetcher.Notify(p.id, hashes) +func (h *handler) Stop() { + h.txsSub.Unsubscribe() // quits txBroadcastLoop + h.minedBlockSub.Unsubscribe() // quits blockBroadcastLoop - case msg.Code == GetPooledTransactionsMsg && p.version >= eth65: - // Decode the retrieval message - msgStream := rlp.NewStream(msg.Payload, uint64(msg.Size)) - if _, err := msgStream.List(); err != nil { - return err - } - // Gather transactions until the fetch or network limits is reached - var ( - hash common.Hash - bytes int - hashes []common.Hash - txs []rlp.RawValue - ) - for bytes < softResponseLimit { - // Retrieve the hash of the next block - if err := msgStream.Decode(&hash); err == rlp.EOL { - break - } else if err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - // Retrieve the requested transaction, skipping if unknown to us - tx := pm.txpool.Get(hash) - if tx == nil { - continue - } - // If known, encode and queue for response packet - if encoded, err := rlp.EncodeToBytes(tx); err != nil { - log.Error("Failed to encode transaction", "err", err) - } else { - hashes = append(hashes, hash) - txs = append(txs, encoded) - bytes += len(encoded) - } - } - return p.SendPooledTransactionsRLP(hashes, txs) + // Quit chainSync and txsync64. + // After this is done, no new peers will be accepted. + close(h.quitSync) + h.wg.Wait() - case msg.Code == TransactionMsg || (msg.Code == PooledTransactionsMsg && p.version >= eth65): - // Transactions arrived, make sure we have a valid and fresh chain to handle them - if atomic.LoadUint32(&pm.acceptTxs) == 0 { - break - } - // Transactions can be processed, parse all of them and deliver to the pool - var txs []*types.Transaction - if err := msg.Decode(&txs); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - for i, tx := range txs { - // Validate and mark the remote transaction - if tx == nil { - return errResp(ErrDecode, "transaction %d is nil", i) - } - p.MarkTransaction(tx.Hash()) - } - pm.txFetcher.Enqueue(p.id, txs, msg.Code == PooledTransactionsMsg) + // Disconnect existing sessions. + // This also closes the gate for any new registrations on the peer set. + // sessions which are already established but not added to h.peers yet + // will exit when they try to register. + h.peers.close() + h.peerWG.Wait() - default: - return errResp(ErrInvalidMsgCode, "%v", msg.Code) - } - return nil + log.Info("Ethereum protocol stopped") } // BroadcastBlock will either propagate a block to a subset of its peers, or // will only announce its availability (depending what's requested). 
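The Start/Stop pair above encodes a deliberate shutdown order: unsubscribe the event feeds, close quitSync and drain the long-lived sync loops tracked by wg, then close the peer set and wait out the per-peer sessions tracked by peerWG. A compressed sketch of that two-waitgroup pattern, with illustrative names rather than the handler's:

package main

import (
	"fmt"
	"sync"
)

// service mirrors the handler's lifecycle: long-lived loops on wg,
// per-peer sessions on peerWG, and a quit channel gating both.
type service struct {
	quit   chan struct{}
	wg     sync.WaitGroup // sync/broadcast loops
	peerWG sync.WaitGroup // live peer sessions
}

func (s *service) startLoop() {
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		<-s.quit // stand-in for a select-driven broadcast loop
	}()
}

func (s *service) runPeer() {
	s.peerWG.Add(1)
	go func() {
		defer s.peerWG.Done()
		<-s.quit // stand-in for the per-peer message loop
	}()
}

func (s *service) stop() {
	close(s.quit)   // stop accepting new work
	s.wg.Wait()     // drain the sync loops first
	s.peerWG.Wait() // then let established sessions unwind
	fmt.Println("stopped")
}

func main() {
	s := &service{quit: make(chan struct{})}
	s.startLoop()
	s.runPeer()
	s.stop()
}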
-func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) { +func (h *handler) BroadcastBlock(block *types.Block, propagate bool) { hash := block.Hash() - peers := pm.peers.PeersWithoutBlock(hash) + peers := h.peers.ethPeersWithoutBlock(hash) // If propagation is requested, send to a subset of the peer if propagate { // Calculate the TD of the block (it's not imported yet, so block.Td is not valid) var td *big.Int - if parent := pm.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1); parent != nil { - td = new(big.Int).Add(block.Difficulty(), pm.blockchain.GetTd(block.ParentHash(), block.NumberU64()-1)) + if parent := h.chain.GetBlock(block.ParentHash(), block.NumberU64()-1); parent != nil { + td = new(big.Int).Add(block.Difficulty(), h.chain.GetTd(block.ParentHash(), block.NumberU64()-1)) } else { log.Error("Propagating dangling block", "number", block.Number(), "hash", hash) return @@ -842,7 +419,7 @@ func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) { return } // Otherwise if the block is indeed in out own chain, announce it - if pm.blockchain.HasBlock(hash, block.NumberU64()) { + if h.chain.HasBlock(hash, block.NumberU64()) { for _, peer := range peers { peer.AsyncSendNewBlockHash(block) } @@ -852,15 +429,15 @@ func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) { // BroadcastTransactions will propagate a batch of transactions to all peers which are not known to // already have the given transaction. -func (pm *ProtocolManager) BroadcastTransactions(txs types.Transactions, propagate bool) { +func (h *handler) BroadcastTransactions(txs types.Transactions, propagate bool) { var ( - txset = make(map[*peer][]common.Hash) - annos = make(map[*peer][]common.Hash) + txset = make(map[*ethPeer][]common.Hash) + annos = make(map[*ethPeer][]common.Hash) ) // Broadcast transactions to a batch of peers not knowing about it if propagate { for _, tx := range txs { - peers := pm.peers.PeersWithoutTx(tx.Hash()) + peers := h.peers.ethPeersWithoutTransacion(tx.Hash()) // Send the block to a subset of our peers transfer := peers[:int(math.Sqrt(float64(len(peers))))] @@ -876,13 +453,13 @@ func (pm *ProtocolManager) BroadcastTransactions(txs types.Transactions, propaga } // Otherwise only broadcast the announcement to peers for _, tx := range txs { - peers := pm.peers.PeersWithoutTx(tx.Hash()) + peers := h.peers.ethPeersWithoutTransacion(tx.Hash()) for _, peer := range peers { annos[peer] = append(annos[peer], tx.Hash()) } } for peer, hashes := range annos { - if peer.version >= eth65 { + if peer.Version() >= eth.ETH65 { peer.AsyncSendPooledTransactionHashes(hashes) } else { peer.AsyncSendTransactions(hashes) @@ -891,56 +468,29 @@ func (pm *ProtocolManager) BroadcastTransactions(txs types.Transactions, propaga } // minedBroadcastLoop sends mined blocks to connected peers. -func (pm *ProtocolManager) minedBroadcastLoop() { - defer pm.wg.Done() +func (h *handler) minedBroadcastLoop() { + defer h.wg.Done() - for obj := range pm.minedBlockSub.Chan() { + for obj := range h.minedBlockSub.Chan() { if ev, ok := obj.Data.(core.NewMinedBlockEvent); ok { - pm.BroadcastBlock(ev.Block, true) // First propagate block to peers - pm.BroadcastBlock(ev.Block, false) // Only then announce to the rest + h.BroadcastBlock(ev.Block, true) // First propagate block to peers + h.BroadcastBlock(ev.Block, false) // Only then announce to the rest } } } // txBroadcastLoop announces new transactions to connected peers. 
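Both broadcast paths above use the same fan-out rule: push the full payload to roughly the square root of the unknowing peers and leave the rest for a hash-only announcement, which keeps propagation bandwidth sub-linear in peer count (minedBroadcastLoop calls BroadcastBlock twice for exactly this split). The slicing trick in isolation, with illustrative types:

package main

import (
	"fmt"
	"math"
)

// splitPropagation partitions peers the way the broadcast methods above do:
// roughly sqrt(n) receive the full payload, the rest only a hash announcement.
func splitPropagation(peers []string) (transfer, announce []string) {
	n := int(math.Sqrt(float64(len(peers))))
	return peers[:n], peers[n:]
}

func main() {
	peers := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i"}
	full, hashOnly := splitPropagation(peers)
	fmt.Println(full, hashOnly) // [a b c] [d e f g h i]
}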
-func (pm *ProtocolManager) txBroadcastLoop() { - defer pm.wg.Done() +func (h *handler) txBroadcastLoop() { + defer h.wg.Done() for { select { - case event := <-pm.txsCh: - // For testing purpose only, disable propagation - if pm.broadcastTxAnnouncesOnly { - pm.BroadcastTransactions(event.Txs, false) - continue - } - pm.BroadcastTransactions(event.Txs, true) // First propagate transactions to peers - pm.BroadcastTransactions(event.Txs, false) // Only then announce to the rest + case event := <-h.txsCh: + h.BroadcastTransactions(event.Txs, true) // First propagate transactions to peers + h.BroadcastTransactions(event.Txs, false) // Only then announce to the rest - case <-pm.txsSub.Err(): + case <-h.txsSub.Err(): return } } } - -// NodeInfo represents a short summary of the Ethereum sub-protocol metadata -// known about the host peer. -type NodeInfo struct { - Network uint64 `json:"network"` // Ethereum network ID (1=Frontier, 2=Morden, Ropsten=3, Rinkeby=4) - Difficulty *big.Int `json:"difficulty"` // Total difficulty of the host's blockchain - Genesis common.Hash `json:"genesis"` // SHA3 hash of the host's genesis block - Config *params.ChainConfig `json:"config"` // Chain configuration for the fork rules - Head common.Hash `json:"head"` // SHA3 hash of the host's best owned block -} - -// NodeInfo retrieves some protocol metadata about the running host node. -func (pm *ProtocolManager) NodeInfo() *NodeInfo { - currentBlock := pm.blockchain.CurrentBlock() - return &NodeInfo{ - Network: pm.networkID, - Difficulty: pm.blockchain.GetTd(currentBlock.Hash(), currentBlock.NumberU64()), - Genesis: pm.blockchain.Genesis().Hash(), - Config: pm.blockchain.Config(), - Head: currentBlock.Hash(), - } -} diff --git a/eth/handler_eth.go b/eth/handler_eth.go new file mode 100644 index 000000000000..84bdac659a5a --- /dev/null +++ b/eth/handler_eth.go @@ -0,0 +1,218 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "errors" + "fmt" + "math/big" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/eth/protocols/eth" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/trie" +) + +// ethHandler implements the eth.Backend interface to handle the various network +// packets that are sent as replies or broadcasts. +type ethHandler handler + +func (h *ethHandler) Chain() *core.BlockChain { return h.chain } +func (h *ethHandler) StateBloom() *trie.SyncBloom { return h.stateBloom } +func (h *ethHandler) TxPool() eth.TxPool { return h.txpool } + +// RunPeer is invoked when a peer joins on the `eth` protocol. 
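The `type ethHandler handler` declaration above is a small Go trick: it gives the same underlying struct a second method set, so the eth protocol callbacks (RunPeer, Handle, and friends) can live on ethHandler while sharing all state with handler through pointer conversions like `(*handler)(h)`, as the next hunk shows. The idiom in miniature, with illustrative names:

package main

import "fmt"

// shared holds the state, analogous to the handler struct.
type shared struct{ hits int }

// view is the same layout under a new name, analogous to ethHandler.
type view shared

func (v *view) Record() {
	(*shared)(v).bump() // convert back to reuse the base behaviour
}

func (s *shared) bump() { s.hits++ }

func main() {
	s := &shared{}
	v := (*view)(s) // legal: identical underlying types
	v.Record()
	fmt.Println(s.hits) // 1: both names share the same memory
}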
+func (h *ethHandler) RunPeer(peer *eth.Peer, hand eth.Handler) error {
+	return (*handler)(h).runEthPeer(peer, hand)
+}
+
+// PeerInfo retrieves all known `eth` information about a peer.
+func (h *ethHandler) PeerInfo(id enode.ID) interface{} {
+	if p := h.peers.ethPeer(id.String()); p != nil {
+		return p.info()
+	}
+	return nil
+}
+
+// AcceptTxs retrieves whether transaction processing is enabled on the node
+// or if inbound transactions should simply be dropped.
+func (h *ethHandler) AcceptTxs() bool {
+	return atomic.LoadUint32(&h.acceptTxs) == 1
+}
+
+// Handle is invoked from a peer's message handler when it receives a new remote
+// message that the handler couldn't consume and serve itself.
+func (h *ethHandler) Handle(peer *eth.Peer, packet eth.Packet) error {
+	// Consume any broadcasts and announces, forwarding the rest to the downloader
+	switch packet := packet.(type) {
+	case *eth.BlockHeadersPacket:
+		return h.handleHeaders(peer, *packet)
+
+	case *eth.BlockBodiesPacket:
+		txset, uncleset := packet.Unpack()
+		return h.handleBodies(peer, txset, uncleset)
+
+	case *eth.NodeDataPacket:
+		if err := h.downloader.DeliverNodeData(peer.ID(), *packet); err != nil {
+			log.Debug("Failed to deliver node state data", "err", err)
+		}
+		return nil
+
+	case *eth.ReceiptsPacket:
+		if err := h.downloader.DeliverReceipts(peer.ID(), *packet); err != nil {
+			log.Debug("Failed to deliver receipts", "err", err)
+		}
+		return nil
+
+	case *eth.NewBlockHashesPacket:
+		hashes, numbers := packet.Unpack()
+		return h.handleBlockAnnounces(peer, hashes, numbers)
+
+	case *eth.NewBlockPacket:
+		return h.handleBlockBroadcast(peer, packet.Block, packet.TD)
+
+	case *eth.NewPooledTransactionHashesPacket:
+		return h.txFetcher.Notify(peer.ID(), *packet)
+
+	case *eth.TransactionsPacket:
+		return h.txFetcher.Enqueue(peer.ID(), *packet, false)
+
+	case *eth.PooledTransactionsPacket:
+		return h.txFetcher.Enqueue(peer.ID(), *packet, true)
+
+	default:
+		return fmt.Errorf("unexpected eth packet type: %T", packet)
+	}
+}
+
+// handleHeaders is invoked from a peer's message handler when it transmits a batch
+// of headers for the local node to process.
+func (h *ethHandler) handleHeaders(peer *eth.Peer, headers []*types.Header) error {
+	p := h.peers.ethPeer(peer.ID())
+	if p == nil {
+		return errors.New("unregistered during callback")
+	}
+	// If no headers were received, but we're expecting a checkpoint header, consider it that
+	if len(headers) == 0 && p.syncDrop != nil {
+		// Stop the timer either way, decide later to drop or not
+		p.syncDrop.Stop()
+		p.syncDrop = nil
+
+		// If we're doing a fast (or snap) sync, we must enforce the checkpoint block to avoid
+		// eclipse attacks.
+		// Unsynced nodes are welcome to connect after we're done joining the network.
+		if atomic.LoadUint32(&h.fastSync) == 1 {
+			peer.Log().Warn("Dropping unsynced node during sync", "addr", peer.RemoteAddr(), "type", peer.Name())
+			return errors.New("unsynced node cannot serve sync")
+		}
+	}
+	// Filter out any explicitly requested headers, deliver the rest to the downloader
+	filter := len(headers) == 1
+	if filter {
+		// If it's a potential sync progress check, validate the content and advertised chain weight
+		if p.syncDrop != nil && headers[0].Number.Uint64() == h.checkpointNumber {
+			// Disable the sync drop timer
+			p.syncDrop.Stop()
+			p.syncDrop = nil
+
+			// Validate the header and either drop the peer or continue
+			if headers[0].Hash() != h.checkpointHash {
+				return errors.New("checkpoint hash mismatch")
+			}
+			return nil
+		}
+		// Otherwise if it's a whitelisted block, validate against the set
+		if want, ok := h.whitelist[headers[0].Number.Uint64()]; ok {
+			if hash := headers[0].Hash(); want != hash {
+				peer.Log().Info("Whitelist mismatch, dropping peer", "number", headers[0].Number.Uint64(), "hash", hash, "want", want)
+				return errors.New("whitelist block mismatch")
+			}
+			peer.Log().Debug("Whitelist block verified", "number", headers[0].Number.Uint64(), "hash", want)
+		}
+		// Irrespective of the fork checks, send the header to the fetcher just in case
+		headers = h.blockFetcher.FilterHeaders(peer.ID(), headers, time.Now())
+	}
+	if len(headers) > 0 || !filter {
+		err := h.downloader.DeliverHeaders(peer.ID(), headers)
+		if err != nil {
+			log.Debug("Failed to deliver headers", "err", err)
+		}
+	}
+	return nil
+}
+
+// handleBodies is invoked from a peer's message handler when it transmits a batch
+// of block bodies for the local node to process.
+func (h *ethHandler) handleBodies(peer *eth.Peer, txs [][]*types.Transaction, uncles [][]*types.Header) error {
+	// Filter out any explicitly requested bodies, deliver the rest to the downloader
+	filter := len(txs) > 0 || len(uncles) > 0
+	if filter {
+		txs, uncles = h.blockFetcher.FilterBodies(peer.ID(), txs, uncles, time.Now())
+	}
+	if len(txs) > 0 || len(uncles) > 0 || !filter {
+		err := h.downloader.DeliverBodies(peer.ID(), txs, uncles)
+		if err != nil {
+			log.Debug("Failed to deliver bodies", "err", err)
+		}
+	}
+	return nil
+}
+
+// handleBlockAnnounces is invoked from a peer's message handler when it transmits a
+// batch of block announcements for the local node to process.
+func (h *ethHandler) handleBlockAnnounces(peer *eth.Peer, hashes []common.Hash, numbers []uint64) error {
+	// Schedule all the unknown hashes for retrieval
+	var (
+		unknownHashes  = make([]common.Hash, 0, len(hashes))
+		unknownNumbers = make([]uint64, 0, len(numbers))
+	)
+	for i := 0; i < len(hashes); i++ {
+		if !h.chain.HasBlock(hashes[i], numbers[i]) {
+			unknownHashes = append(unknownHashes, hashes[i])
+			unknownNumbers = append(unknownNumbers, numbers[i])
+		}
+	}
+	for i := 0; i < len(unknownHashes); i++ {
+		h.blockFetcher.Notify(peer.ID(), unknownHashes[i], unknownNumbers[i], time.Now(), peer.RequestOneHeader, peer.RequestBodies)
+	}
+	return nil
+}
+
+// handleBlockBroadcast is invoked from a peer's message handler when it transmits a
+// block broadcast for the local node to process.
+func (h *ethHandler) handleBlockBroadcast(peer *eth.Peer, block *types.Block, td *big.Int) error {
+	// Schedule the block for import
+	h.blockFetcher.Enqueue(peer.ID(), block)
+
+	// Assuming the block is importable by the peer, but possibly not yet done so,
+	// calculate the head hash and TD that the peer truly must have.
+	var (
+		trueHead = block.ParentHash()
+		trueTD   = new(big.Int).Sub(td, block.Difficulty())
+	)
+	// Update the peer's total difficulty if better than the previous
+	if _, td := peer.Head(); trueTD.Cmp(td) > 0 {
+		peer.SetHead(trueHead, trueTD)
+		h.chainSync.handlePeerEvent(peer)
+	}
+	return nil
+}
diff --git a/eth/handler_eth_test.go b/eth/handler_eth_test.go
new file mode 100644
index 000000000000..0e5c0c90eef5
--- /dev/null
+++ b/eth/handler_eth_test.go
@@ -0,0 +1,740 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"fmt"
+	"math/big"
+	"math/rand"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus/ethash"
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/forkid"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/core/vm"
+	"github.com/ethereum/go-ethereum/eth/downloader"
+	"github.com/ethereum/go-ethereum/eth/protocols/eth"
+	"github.com/ethereum/go-ethereum/event"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/trie"
+)
+
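The total-difficulty bookkeeping in handleBlockBroadcast is easy to misread: the TD advertised with a NewBlock message includes the broadcast block itself, which the sender may not even have finished importing, so the head the peer provably owns is the parent, at TD minus the new block's difficulty. The arithmetic in isolation, with made-up numbers:

package main

import (
	"fmt"
	"math/big"
)

func main() {
	advertisedTD := big.NewInt(131136) // TD sent along with the NewBlock message
	difficulty := big.NewInt(131072)   // difficulty of the broadcast block itself

	// The sender provably owns the parent: subtract the new block's
	// difficulty to get the TD it must have reached before broadcasting.
	trueTD := new(big.Int).Sub(advertisedTD, difficulty)

	recordedTD := big.NewInt(32) // what we last recorded for this peer
	if trueTD.Cmp(recordedTD) > 0 {
		fmt.Println("advancing peer head to parent, TD", trueTD) // TD 64
	}
}

+// testEthHandler is a mock event handler to listen for inbound network requests
+// on the `eth` protocol and convert them into a more easily testable form.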
+type testEthHandler struct { + blockBroadcasts event.Feed + txAnnounces event.Feed + txBroadcasts event.Feed +} + +func (h *testEthHandler) Chain() *core.BlockChain { panic("no backing chain") } +func (h *testEthHandler) StateBloom() *trie.SyncBloom { panic("no backing state bloom") } +func (h *testEthHandler) TxPool() eth.TxPool { panic("no backing tx pool") } +func (h *testEthHandler) AcceptTxs() bool { return true } +func (h *testEthHandler) RunPeer(*eth.Peer, eth.Handler) error { panic("not used in tests") } +func (h *testEthHandler) PeerInfo(enode.ID) interface{} { panic("not used in tests") } + +func (h *testEthHandler) Handle(peer *eth.Peer, packet eth.Packet) error { + switch packet := packet.(type) { + case *eth.NewBlockPacket: + h.blockBroadcasts.Send(packet.Block) + return nil + + case *eth.NewPooledTransactionHashesPacket: + h.txAnnounces.Send(([]common.Hash)(*packet)) + return nil + + case *eth.TransactionsPacket: + h.txBroadcasts.Send(([]*types.Transaction)(*packet)) + return nil + + case *eth.PooledTransactionsPacket: + h.txBroadcasts.Send(([]*types.Transaction)(*packet)) + return nil + + default: + panic(fmt.Sprintf("unexpected eth packet type in tests: %T", packet)) + } +} + +// Tests that peers are correctly accepted (or rejected) based on the advertised +// fork IDs in the protocol handshake. +func TestForkIDSplit64(t *testing.T) { testForkIDSplit(t, 64) } +func TestForkIDSplit65(t *testing.T) { testForkIDSplit(t, 65) } + +func testForkIDSplit(t *testing.T, protocol uint) { + t.Parallel() + + var ( + engine = ethash.NewFaker() + + configNoFork = ¶ms.ChainConfig{HomesteadBlock: big.NewInt(1)} + configProFork = ¶ms.ChainConfig{ + HomesteadBlock: big.NewInt(1), + EIP150Block: big.NewInt(2), + EIP155Block: big.NewInt(2), + EIP158Block: big.NewInt(2), + ByzantiumBlock: big.NewInt(3), + } + dbNoFork = rawdb.NewMemoryDatabase() + dbProFork = rawdb.NewMemoryDatabase() + + gspecNoFork = &core.Genesis{Config: configNoFork} + gspecProFork = &core.Genesis{Config: configProFork} + + genesisNoFork = gspecNoFork.MustCommit(dbNoFork) + genesisProFork = gspecProFork.MustCommit(dbProFork) + + chainNoFork, _ = core.NewBlockChain(dbNoFork, nil, configNoFork, engine, vm.Config{}, nil, nil) + chainProFork, _ = core.NewBlockChain(dbProFork, nil, configProFork, engine, vm.Config{}, nil, nil) + + blocksNoFork, _ = core.GenerateChain(configNoFork, genesisNoFork, engine, dbNoFork, 2, nil) + blocksProFork, _ = core.GenerateChain(configProFork, genesisProFork, engine, dbProFork, 2, nil) + + ethNoFork, _ = newHandler(&handlerConfig{ + Database: dbNoFork, + Chain: chainNoFork, + TxPool: newTestTxPool(), + Network: 1, + Sync: downloader.FullSync, + BloomCache: 1, + }) + ethProFork, _ = newHandler(&handlerConfig{ + Database: dbProFork, + Chain: chainProFork, + TxPool: newTestTxPool(), + Network: 1, + Sync: downloader.FullSync, + BloomCache: 1, + }) + ) + ethNoFork.Start(1000) + ethProFork.Start(1000) + + // Clean up everything after ourselves + defer chainNoFork.Stop() + defer chainProFork.Stop() + + defer ethNoFork.Stop() + defer ethProFork.Stop() + + // Both nodes should allow the other to connect (same genesis, next fork is the same) + p2pNoFork, p2pProFork := p2p.MsgPipe() + defer p2pNoFork.Close() + defer p2pProFork.Close() + + peerNoFork := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil) + peerProFork := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil) + defer peerNoFork.Close() + defer peerProFork.Close() + + errc := make(chan error, 2) + go 
func(errc chan error) {
+		errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
+	}(errc)
+	go func(errc chan error) {
+		errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
+	}(errc)
+
+	for i := 0; i < 2; i++ {
+		select {
+		case err := <-errc:
+			if err != nil {
+				t.Fatalf("frontier nofork <-> profork failed: %v", err)
+			}
+		case <-time.After(250 * time.Millisecond):
+			t.Fatalf("frontier nofork <-> profork handler timeout")
+		}
+	}
+	// Progress into Homestead. Forks match, so we don't care what the future holds
+	chainNoFork.InsertChain(blocksNoFork[:1])
+	chainProFork.InsertChain(blocksProFork[:1])
+
+	p2pNoFork, p2pProFork = p2p.MsgPipe()
+	defer p2pNoFork.Close()
+	defer p2pProFork.Close()
+
+	peerNoFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
+	peerProFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
+	defer peerNoFork.Close()
+	defer peerProFork.Close()
+
+	errc = make(chan error, 2)
+	go func(errc chan error) {
+		errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
+	}(errc)
+	go func(errc chan error) {
+		errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
+	}(errc)
+
+	for i := 0; i < 2; i++ {
+		select {
+		case err := <-errc:
+			if err != nil {
+				t.Fatalf("homestead nofork <-> profork failed: %v", err)
+			}
+		case <-time.After(250 * time.Millisecond):
+			t.Fatalf("homestead nofork <-> profork handler timeout")
+		}
+	}
+	// Progress into Spurious. Forks mismatch, signalling differing chains, reject
+	chainNoFork.InsertChain(blocksNoFork[1:2])
+	chainProFork.InsertChain(blocksProFork[1:2])
+
+	p2pNoFork, p2pProFork = p2p.MsgPipe()
+	defer p2pNoFork.Close()
+	defer p2pProFork.Close()
+
+	peerNoFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil)
+	peerProFork = eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil)
+	defer peerNoFork.Close()
+	defer peerProFork.Close()
+
+	errc = make(chan error, 2)
+	go func(errc chan error) {
+		errc <- ethNoFork.runEthPeer(peerProFork, func(peer *eth.Peer) error { return nil })
+	}(errc)
+	go func(errc chan error) {
+		errc <- ethProFork.runEthPeer(peerNoFork, func(peer *eth.Peer) error { return nil })
+	}(errc)
+
+	var successes int
+	for i := 0; i < 2; i++ {
+		select {
+		case err := <-errc:
+			if err == nil {
+				successes++
+				if successes == 2 { // Only one side disconnects
+					t.Fatalf("fork ID rejection didn't happen")
+				}
+			}
+		case <-time.After(250 * time.Millisecond):
+			t.Fatalf("split peers not rejected")
+		}
+	}
+}
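What the three phases above exercise is EIP-2124 fork identification: peers stay compatible while their sets of passed forks agree, and split once they diverge. A toy version of the FORK_HASH construction, a CRC32 over the genesis hash plus each passed fork number (simplified for illustration, not geth's forkid package):

package main

import (
	"encoding/binary"
	"fmt"
	"hash/crc32"
)

// forkHash is a toy EIP-2124 FORK_HASH: a CRC32 checksum seeded with the
// genesis hash and folded with each passed fork's block number.
func forkHash(genesis [32]byte, passedForks []uint64) uint32 {
	sum := crc32.ChecksumIEEE(genesis[:])
	for _, fork := range passedForks {
		var blob [8]byte
		binary.BigEndian.PutUint64(blob[:], fork)
		sum = crc32.Update(sum, crc32.IEEETable, blob[:])
	}
	return sum
}

func main() {
	genesis := [32]byte{1}

	// Same chain, same passed forks: identical hashes, peers accept each other.
	fmt.Printf("no-fork:  %08x\n", forkHash(genesis, nil))
	// Once the pro-fork side passes a fork at block 2, the hashes diverge
	// and the filter rejects the connection.
	fmt.Printf("pro-fork: %08x\n", forkHash(genesis, []uint64{2}))
}

+// Tests that received transactions are added to the local pool.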
+func TestRecvTransactions64(t *testing.T) { testRecvTransactions(t, 64) } +func TestRecvTransactions65(t *testing.T) { testRecvTransactions(t, 65) } + +func testRecvTransactions(t *testing.T, protocol uint) { + t.Parallel() + + // Create a message handler, configure it to accept transactions and watch them + handler := newTestHandler() + defer handler.close() + + handler.handler.acceptTxs = 1 // mark synced to accept transactions + + txs := make(chan core.NewTxsEvent) + sub := handler.txpool.SubscribeNewTxsEvent(txs) + defer sub.Unsubscribe() + + // Create a source peer to send messages through and a sink handler to receive them + p2pSrc, p2pSink := p2p.MsgPipe() + defer p2pSrc.Close() + defer p2pSink.Close() + + src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, handler.txpool) + sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, handler.txpool) + defer src.Close() + defer sink.Close() + + go handler.handler.runEthPeer(sink, func(peer *eth.Peer) error { + return eth.Handle((*ethHandler)(handler.handler), peer) + }) + // Run the handshake locally to avoid spinning up a source handler + var ( + genesis = handler.chain.Genesis() + head = handler.chain.CurrentBlock() + td = handler.chain.GetTd(head.Hash(), head.NumberU64()) + ) + if err := src.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil { + t.Fatalf("failed to run protocol handshake") + } + // Send the transaction to the sink and verify that it's added to the tx pool + tx := types.NewTransaction(0, common.Address{}, big.NewInt(0), 100000, big.NewInt(0), nil) + tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey) + + if err := src.SendTransactions([]*types.Transaction{tx}); err != nil { + t.Fatalf("failed to send transaction: %v", err) + } + select { + case event := <-txs: + if len(event.Txs) != 1 { + t.Errorf("wrong number of added transactions: got %d, want 1", len(event.Txs)) + } else if event.Txs[0].Hash() != tx.Hash() { + t.Errorf("added wrong tx hash: got %v, want %v", event.Txs[0].Hash(), tx.Hash()) + } + case <-time.After(2 * time.Second): + t.Errorf("no NewTxsEvent received within 2 seconds") + } +} + +// This test checks that pending transactions are sent. 
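Both the receive test above and the send test that follows build their fixtures the same way: hand-construct a legacy transaction and sign it with the Homestead scheme. The two-call pattern in isolation (the key here is freshly generated, so the result is structurally valid but unfunded):

package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	key, _ := crypto.GenerateKey()

	// nonce, recipient, value, gas limit, gas price, payload
	tx := types.NewTransaction(0, common.Address{}, big.NewInt(0), 100000, big.NewInt(0), nil)

	// HomesteadSigner matches the chain config these tests run with;
	// production code derives the signer from the chain config instead.
	signed, err := types.SignTx(tx, types.HomesteadSigner{}, key)
	if err != nil {
		panic(err)
	}
	fmt.Println("signed tx hash:", signed.Hash().Hex())
}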
+func TestSendTransactions64(t *testing.T) { testSendTransactions(t, 64) } +func TestSendTransactions65(t *testing.T) { testSendTransactions(t, 65) } + +func testSendTransactions(t *testing.T, protocol uint) { + t.Parallel() + + // Create a message handler and fill the pool with big transactions + handler := newTestHandler() + defer handler.close() + + insert := make([]*types.Transaction, 100) + for nonce := range insert { + tx := types.NewTransaction(uint64(nonce), common.Address{}, big.NewInt(0), 100000, big.NewInt(0), make([]byte, txsyncPackSize/10)) + tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey) + + insert[nonce] = tx + } + go handler.txpool.AddRemotes(insert) // Need goroutine to not block on feed + time.Sleep(250 * time.Millisecond) // Wait until tx events get out of the system (can't use events, tx broadcaster races with peer join) + + // Create a source handler to send messages through and a sink peer to receive them + p2pSrc, p2pSink := p2p.MsgPipe() + defer p2pSrc.Close() + defer p2pSink.Close() + + src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, handler.txpool) + sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, handler.txpool) + defer src.Close() + defer sink.Close() + + go handler.handler.runEthPeer(src, func(peer *eth.Peer) error { + return eth.Handle((*ethHandler)(handler.handler), peer) + }) + // Run the handshake locally to avoid spinning up a source handler + var ( + genesis = handler.chain.Genesis() + head = handler.chain.CurrentBlock() + td = handler.chain.GetTd(head.Hash(), head.NumberU64()) + ) + if err := sink.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil { + t.Fatalf("failed to run protocol handshake") + } + // After the handshake completes, the source handler should stream the sink + // the transactions, subscribe to all inbound network events + backend := new(testEthHandler) + + anns := make(chan []common.Hash) + annSub := backend.txAnnounces.Subscribe(anns) + defer annSub.Unsubscribe() + + bcasts := make(chan []*types.Transaction) + bcastSub := backend.txBroadcasts.Subscribe(bcasts) + defer bcastSub.Unsubscribe() + + go eth.Handle(backend, sink) + + // Make sure we get all the transactions on the correct channels + seen := make(map[common.Hash]struct{}) + for len(seen) < len(insert) { + switch protocol { + case 63, 64: + select { + case <-anns: + t.Errorf("tx announce received on pre eth/65") + case txs := <-bcasts: + for _, tx := range txs { + if _, ok := seen[tx.Hash()]; ok { + t.Errorf("duplicate transaction announced: %x", tx.Hash()) + } + seen[tx.Hash()] = struct{}{} + } + } + case 65: + select { + case hashes := <-anns: + for _, hash := range hashes { + if _, ok := seen[hash]; ok { + t.Errorf("duplicate transaction announced: %x", hash) + } + seen[hash] = struct{}{} + } + case <-bcasts: + t.Errorf("initial tx broadcast received on post eth/65") + } + + default: + panic("unsupported protocol, please extend test") + } + } + for _, tx := range insert { + if _, ok := seen[tx.Hash()]; !ok { + t.Errorf("missing transaction: %x", tx.Hash()) + } + } +} + +// Tests that transactions get propagated to all attached peers, either via direct +// broadcasts or via announcements/retrievals. 
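The channel assertions above reduce to the protocol-version gate from BroadcastTransactions: eth/65 peers are sent hash announcements first, while older peers receive the transactions themselves. The decision in miniature (constant and peer type invented for the sketch):

package main

import "fmt"

const eth65 = 65 // mirrors eth.ETH65

type peer struct{ version uint }

// deliver mirrors the announce-versus-broadcast split asserted by the test.
func deliver(p peer, hashes []string) string {
	if p.version >= eth65 {
		return fmt.Sprintf("announce %d hashes, serve bodies on request", len(hashes))
	}
	return fmt.Sprintf("send %d full transactions directly", len(hashes))
}

func main() {
	fmt.Println(deliver(peer{version: 64}, []string{"a", "b"})) // pre-65: full broadcast
	fmt.Println(deliver(peer{version: 65}, []string{"a", "b"})) // 65+: announce only
}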
+func TestTransactionPropagation64(t *testing.T) { testTransactionPropagation(t, 64) }
+func TestTransactionPropagation65(t *testing.T) { testTransactionPropagation(t, 65) }
+
+func testTransactionPropagation(t *testing.T, protocol uint) {
+	t.Parallel()
+
+	// Create a source handler to send transactions from and a number of sinks
+	// to receive them. We need multiple sinks since a one-to-one peering would
+	// broadcast all transactions without announcement.
+	source := newTestHandler()
+	defer source.close()
+
+	sinks := make([]*testHandler, 10)
+	for i := 0; i < len(sinks); i++ {
+		sinks[i] = newTestHandler()
+		defer sinks[i].close()
+
+		sinks[i].handler.acceptTxs = 1 // mark synced to accept transactions
+	}
+	// Interconnect all the sink handlers with the source handler
+	for i, sink := range sinks {
+		sink := sink // Closure for goroutine below
+
+		sourcePipe, sinkPipe := p2p.MsgPipe()
+		defer sourcePipe.Close()
+		defer sinkPipe.Close()
+
+		sourcePeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{byte(i)}, "", nil), sourcePipe, source.txpool)
+		sinkPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{0}, "", nil), sinkPipe, sink.txpool)
+		defer sourcePeer.Close()
+		defer sinkPeer.Close()
+
+		go source.handler.runEthPeer(sourcePeer, func(peer *eth.Peer) error {
+			return eth.Handle((*ethHandler)(source.handler), peer)
+		})
+		go sink.handler.runEthPeer(sinkPeer, func(peer *eth.Peer) error {
+			return eth.Handle((*ethHandler)(sink.handler), peer)
+		})
+	}
+	// Subscribe to all the transaction pools
+	txChs := make([]chan core.NewTxsEvent, len(sinks))
+	for i := 0; i < len(sinks); i++ {
+		txChs[i] = make(chan core.NewTxsEvent, 1024)
+
+		sub := sinks[i].txpool.SubscribeNewTxsEvent(txChs[i])
+		defer sub.Unsubscribe()
+	}
+	// Fill the source pool with transactions and wait for them at the sinks
+	txs := make([]*types.Transaction, 1024)
+	for nonce := range txs {
+		tx := types.NewTransaction(uint64(nonce), common.Address{}, big.NewInt(0), 100000, big.NewInt(0), nil)
+		tx, _ = types.SignTx(tx, types.HomesteadSigner{}, testKey)
+
+		txs[nonce] = tx
+	}
+	source.txpool.AddRemotes(txs)
+
+	// Iterate through all the sinks and ensure they all got the transactions
+	for i := range sinks {
+		for arrived := 0; arrived < len(txs); {
+			select {
+			case event := <-txChs[i]:
+				arrived += len(event.Txs)
+			case <-time.NewTimer(time.Second).C:
+				t.Errorf("sink %d: transaction propagation timed out: have %d, want %d", i, arrived, len(txs))
+				return
+			}
+		}
+	}
+}
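The arrival loop above is a reusable testing idiom: drain an event channel until an expected count is reached, and fail on a quiet timeout rather than blocking the test forever. In miniature:

package main

import (
	"fmt"
	"time"
)

func main() {
	events := make(chan int, 8)
	go func() {
		for i := 0; i < 5; i++ {
			events <- 1
		}
	}()

	want, have := 5, 0
	for have < want {
		select {
		case n := <-events:
			have += n
		case <-time.After(time.Second):
			fmt.Printf("timed out: have %d, want %d\n", have, want)
			return
		}
	}
	fmt.Println("all events arrived:", have)
}

+// Tests that post eth protocol handshake, clients perform a mutual checkpoint
+// challenge to validate each other's chains. Hash mismatches, or missing ones
+// during a fast sync should lead to the peer getting dropped.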
+func TestCheckpointChallenge(t *testing.T) { + tests := []struct { + syncmode downloader.SyncMode + checkpoint bool + timeout bool + empty bool + match bool + drop bool + }{ + // If checkpointing is not enabled locally, don't challenge and don't drop + {downloader.FullSync, false, false, false, false, false}, + {downloader.FastSync, false, false, false, false, false}, + + // If checkpointing is enabled locally and remote response is empty, only drop during fast sync + {downloader.FullSync, true, false, true, false, false}, + {downloader.FastSync, true, false, true, false, true}, // Special case, fast sync, unsynced peer + + // If checkpointing is enabled locally and remote response mismatches, always drop + {downloader.FullSync, true, false, false, false, true}, + {downloader.FastSync, true, false, false, false, true}, + + // If checkpointing is enabled locally and remote response matches, never drop + {downloader.FullSync, true, false, false, true, false}, + {downloader.FastSync, true, false, false, true, false}, + + // If checkpointing is enabled locally and remote times out, always drop + {downloader.FullSync, true, true, false, true, true}, + {downloader.FastSync, true, true, false, true, true}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("sync %v checkpoint %v timeout %v empty %v match %v", tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match), func(t *testing.T) { + testCheckpointChallenge(t, tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match, tt.drop) + }) + } +} + +func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpoint bool, timeout bool, empty bool, match bool, drop bool) { + // Reduce the checkpoint handshake challenge timeout + defer func(old time.Duration) { syncChallengeTimeout = old }(syncChallengeTimeout) + syncChallengeTimeout = 250 * time.Millisecond + + // Create a test handler and inject a CHT into it. The injection is a bit + // ugly, but it beats creating everything manually just to avoid reaching + // into the internals a bit. 
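Two details of this fixture are easy to gloss over. First, the drop column in the expectation table compresses to a single predicate; second, the injected checkpoint number is always the head of a CHT section, i.e. the last block of a fixed-size run. Both spelled out below (the 32768 frequency mirrors params.CHTFrequency at the time of this change):

package main

import "fmt"

const chtFrequency = 32768 // mirrors params.CHTFrequency

// sectionHead is the number of the last block covered by CHT section i,
// which is what the test injects as the checkpoint to challenge with.
func sectionHead(index uint64) uint64 {
	return (index+1)*chtFrequency - 1
}

// shouldDrop restates the expectation table: no checkpoint means no
// challenge and no drop; with one, drop on timeout, on any mismatching
// reply, or on an empty reply while fast syncing.
func shouldDrop(fastSync, checkpoint, timeout, empty, match bool) bool {
	if !checkpoint {
		return false
	}
	if timeout {
		return true
	}
	if empty {
		return fastSync
	}
	return !match
}

func main() {
	fmt.Println(sectionHead(0))                              // 32767
	fmt.Println(shouldDrop(true, true, false, true, false))  // true: fast sync, empty reply
	fmt.Println(shouldDrop(false, true, false, true, false)) // false: full sync tolerates it
}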
+ handler := newTestHandler() + defer handler.close() + + if syncmode == downloader.FastSync { + atomic.StoreUint32(&handler.handler.fastSync, 1) + } else { + atomic.StoreUint32(&handler.handler.fastSync, 0) + } + var response *types.Header + if checkpoint { + number := (uint64(rand.Intn(500))+1)*params.CHTFrequency - 1 + response = &types.Header{Number: big.NewInt(int64(number)), Extra: []byte("valid")} + + handler.handler.checkpointNumber = number + handler.handler.checkpointHash = response.Hash() + } + // Create a challenger peer and a challenged one + p2pLocal, p2pRemote := p2p.MsgPipe() + defer p2pLocal.Close() + defer p2pRemote.Close() + + local := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{1}, "", nil), p2pLocal, handler.txpool) + remote := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{2}, "", nil), p2pRemote, handler.txpool) + defer local.Close() + defer remote.Close() + + go handler.handler.runEthPeer(local, func(peer *eth.Peer) error { + return eth.Handle((*ethHandler)(handler.handler), peer) + }) + // Run the handshake locally to avoid spinning up a remote handler + var ( + genesis = handler.chain.Genesis() + head = handler.chain.CurrentBlock() + td = handler.chain.GetTd(head.Hash(), head.NumberU64()) + ) + if err := remote.Handshake(1, td, head.Hash(), genesis.Hash(), forkid.NewIDWithChain(handler.chain), forkid.NewFilter(handler.chain)); err != nil { + t.Fatalf("failed to run protocol handshake") + } + // Connect a new peer and check that we receive the checkpoint challenge + if checkpoint { + if err := remote.ExpectRequestHeadersByNumber(response.Number.Uint64(), 1, 0, false); err != nil { + t.Fatalf("challenge mismatch: %v", err) + } + // Create a block to reply to the challenge if no timeout is simulated + if !timeout { + if empty { + if err := remote.SendBlockHeaders([]*types.Header{}); err != nil { + t.Fatalf("failed to answer challenge: %v", err) + } + } else if match { + if err := remote.SendBlockHeaders([]*types.Header{response}); err != nil { + t.Fatalf("failed to answer challenge: %v", err) + } + } else { + if err := remote.SendBlockHeaders([]*types.Header{{Number: response.Number}}); err != nil { + t.Fatalf("failed to answer challenge: %v", err) + } + } + } + } + // Wait until the test timeout passes to ensure proper cleanup + time.Sleep(syncChallengeTimeout + 300*time.Millisecond) + + // Verify that the remote peer is maintained or dropped + if drop { + if peers := handler.handler.peers.Len(); peers != 0 { + t.Fatalf("peer count mismatch: have %d, want %d", peers, 0) + } + } else { + if peers := handler.handler.peers.Len(); peers != 1 { + t.Fatalf("peer count mismatch: have %d, want %d", peers, 1) + } + } +} + +// Tests that blocks are broadcast to a sqrt number of peers only. 
+func TestBroadcastBlock1Peer(t *testing.T)    { testBroadcastBlock(t, 1, 1) }
+func TestBroadcastBlock2Peers(t *testing.T)   { testBroadcastBlock(t, 2, 1) }
+func TestBroadcastBlock3Peers(t *testing.T)   { testBroadcastBlock(t, 3, 1) }
+func TestBroadcastBlock4Peers(t *testing.T)   { testBroadcastBlock(t, 4, 2) }
+func TestBroadcastBlock5Peers(t *testing.T)   { testBroadcastBlock(t, 5, 2) }
+func TestBroadcastBlock9Peers(t *testing.T)   { testBroadcastBlock(t, 9, 3) }
+func TestBroadcastBlock12Peers(t *testing.T)  { testBroadcastBlock(t, 12, 3) }
+func TestBroadcastBlock16Peers(t *testing.T)  { testBroadcastBlock(t, 16, 4) }
+func TestBroadcastBlock26Peers(t *testing.T)  { testBroadcastBlock(t, 26, 5) }
+func TestBroadcastBlock100Peers(t *testing.T) { testBroadcastBlock(t, 100, 10) }
+
+func testBroadcastBlock(t *testing.T, peers, bcasts int) {
+	t.Parallel()
+
+	// Create a source handler to broadcast blocks from and a number of sinks
+	// to receive them.
+	source := newTestHandlerWithBlocks(1)
+	defer source.close()
+
+	sinks := make([]*testEthHandler, peers)
+	for i := 0; i < len(sinks); i++ {
+		sinks[i] = new(testEthHandler)
+	}
+	// Interconnect all the sink handlers with the source handler
+	var (
+		genesis = source.chain.Genesis()
+		td      = source.chain.GetTd(genesis.Hash(), genesis.NumberU64())
+	)
+	for i, sink := range sinks {
+		sink := sink // Closure for goroutine below
+
+		sourcePipe, sinkPipe := p2p.MsgPipe()
+		defer sourcePipe.Close()
+		defer sinkPipe.Close()
+
+		sourcePeer := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{byte(i)}, "", nil), sourcePipe, nil)
+		sinkPeer := eth.NewPeer(eth.ETH64, p2p.NewPeer(enode.ID{0}, "", nil), sinkPipe, nil)
+		defer sourcePeer.Close()
+		defer sinkPeer.Close()
+
+		go source.handler.runEthPeer(sourcePeer, func(peer *eth.Peer) error {
+			return eth.Handle((*ethHandler)(source.handler), peer)
+		})
+		if err := sinkPeer.Handshake(1, td, genesis.Hash(), genesis.Hash(), forkid.NewIDWithChain(source.chain), forkid.NewFilter(source.chain)); err != nil {
+			t.Fatalf("failed to run protocol handshake")
+		}
+		go eth.Handle(sink, sinkPeer)
+	}
+	// Subscribe to all the block broadcast feeds
+	blockChs := make([]chan *types.Block, len(sinks))
+	for i := 0; i < len(sinks); i++ {
+		blockChs[i] = make(chan *types.Block, 1)
+		defer close(blockChs[i])
+
+		sub := sinks[i].blockBroadcasts.Subscribe(blockChs[i])
+		defer sub.Unsubscribe()
+	}
+	// Initiate a block propagation across the peers
+	time.Sleep(100 * time.Millisecond)
+	source.handler.BroadcastBlock(source.chain.CurrentBlock(), true)
+
+	// Iterate through all the sinks and ensure the correct number got the block
+	done := make(chan struct{}, peers)
+	for _, ch := range blockChs {
+		ch := ch
+		go func() {
+			<-ch
+			done <- struct{}{}
+		}()
+	}
+	var received int
+	for {
+		select {
+		case <-done:
+			received++
+
+		case <-time.After(100 * time.Millisecond):
+			if received != bcasts {
+				t.Errorf("broadcast count mismatch: have %d, want %d", received, bcasts)
+			}
+			return
+		}
+	}
+}
+
+// Tests that a propagated malformed block (uncles or transactions don't match
+// with the hashes in the header) gets discarded and not broadcast forward.
+func TestBroadcastMalformedBlock64(t *testing.T) { testBroadcastMalformedBlock(t, 64) }
+func TestBroadcastMalformedBlock65(t *testing.T) { testBroadcastMalformedBlock(t, 65) }
+
+func testBroadcastMalformedBlock(t *testing.T, protocol uint) {
+	t.Parallel()
+
+	// Create a source handler to broadcast blocks from and a sink peer to
+	// check propagation against.
+	source := newTestHandlerWithBlocks(1)
+	defer source.close()
+
+	// Create a source peer to send messages through and a sink peer to receive them
+	p2pSrc, p2pSink := p2p.MsgPipe()
+	defer p2pSrc.Close()
+	defer p2pSink.Close()
+
+	src := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), p2pSrc, source.txpool)
+	sink := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), p2pSink, source.txpool)
+	defer src.Close()
+	defer sink.Close()
+
+	go source.handler.runEthPeer(src, func(peer *eth.Peer) error {
+		return eth.Handle((*ethHandler)(source.handler), peer)
+	})
+	// Run the handshake locally to avoid spinning up a sink handler
+	var (
+		genesis = source.chain.Genesis()
+		td      = source.chain.GetTd(genesis.Hash(), genesis.NumberU64())
+	)
+	if err := sink.Handshake(1, td, genesis.Hash(), genesis.Hash(), forkid.NewIDWithChain(source.chain), forkid.NewFilter(source.chain)); err != nil {
+		t.Fatalf("failed to run protocol handshake")
+	}
+	// After the handshake completes, the source handler should stream the sink
+	// the blocks, subscribe to inbound network events
+	backend := new(testEthHandler)
+
+	blocks := make(chan *types.Block, 1)
+	sub := backend.blockBroadcasts.Subscribe(blocks)
+	defer sub.Unsubscribe()
+
+	go eth.Handle(backend, sink)
+
+	// Create various combinations of malformed blocks
+	head := source.chain.CurrentBlock()
+
+	malformedUncles := head.Header()
+	malformedUncles.UncleHash[0]++
+	malformedTransactions := head.Header()
+	malformedTransactions.TxHash[0]++
+	malformedEverything := head.Header()
+	malformedEverything.UncleHash[0]++
+	malformedEverything.TxHash[0]++
+
+	// Try to broadcast all malformations and ensure they all get discarded
+	for _, header := range []*types.Header{malformedUncles, malformedTransactions, malformedEverything} {
+		block := types.NewBlockWithHeader(header).WithBody(head.Transactions(), head.Uncles())
+		if err := src.SendNewBlock(block, big.NewInt(131136)); err != nil {
+			t.Fatalf("failed to broadcast block: %v", err)
+		}
+		select {
+		case <-blocks:
+			t.Fatalf("malformed block forwarded")
+		case <-time.After(100 * time.Millisecond):
+		}
+	}
+}
diff --git a/eth/handler_snap.go b/eth/handler_snap.go
new file mode 100644
index 000000000000..25975bf60b68
--- /dev/null
+++ b/eth/handler_snap.go
@@ -0,0 +1,48 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/eth/protocols/snap"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+)
+
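The discard logic the malformed-block test relies on is the hash commitment in the header: UncleHash and TxHash must match what the delivered body actually derives to, so flipping one byte is always detectable on the receiving side. A self-contained check using the real types package:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/core/types"
)

func main() {
	// A header whose UncleHash correctly commits to an empty uncle list.
	header := &types.Header{UncleHash: types.CalcUncleHash(nil)}

	// Corrupt one byte of the commitment, as the test does.
	malformed := types.CopyHeader(header)
	malformed.UncleHash[0]++

	// The receiving side recomputes the commitment from the delivered body
	// and drops anything that does not match.
	if types.CalcUncleHash(nil) != malformed.UncleHash {
		fmt.Println("uncle hash mismatch: block discarded, not forwarded")
	}
}

+// snapHandler implements the snap.Backend interface to handle the various network
+// packets that are sent as replies or broadcasts.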
+type snapHandler handler + +func (h *snapHandler) Chain() *core.BlockChain { return h.chain } + +// RunPeer is invoked when a peer joins on the `snap` protocol. +func (h *snapHandler) RunPeer(peer *snap.Peer, hand snap.Handler) error { + return (*handler)(h).runSnapPeer(peer, hand) +} + +// PeerInfo retrieves all known `snap` information about a peer. +func (h *snapHandler) PeerInfo(id enode.ID) interface{} { + if p := h.peers.snapPeer(id.String()); p != nil { + return p.info() + } + return nil +} + +// Handle is invoked from a peer's message handler when it receives a new remote +// message that the handler couldn't consume and serve itself. +func (h *snapHandler) Handle(peer *snap.Peer, packet snap.Packet) error { + return h.downloader.DeliverSnapPacket(peer, packet) +} diff --git a/eth/handler_test.go b/eth/handler_test.go index fc6c6f27454a..a90ef5c348aa 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -17,678 +17,154 @@ package eth import ( - "fmt" - "math" "math/big" - "math/rand" - "testing" - "time" + "sort" + "sync" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus/ethash" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/params" ) -// Tests that block headers can be retrieved from a remote chain based on user queries. -func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) } -func TestGetBlockHeaders64(t *testing.T) { testGetBlockHeaders(t, 64) } +var ( + // testKey is a private key to use for funding a tester account. + testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") -func testGetBlockHeaders(t *testing.T, protocol int) { - pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil) - peer, _ := newTestPeer("peer", protocol, pm, true) - defer peer.close() + // testAddr is the Ethereum address of the tester account. 
+ testAddr = crypto.PubkeyToAddress(testKey.PublicKey) +) - // Create a "random" unknown hash for testing - var unknown common.Hash - for i := range unknown { - unknown[i] = byte(i) - } - // Create a batch of tests for various scenarios - limit := uint64(downloader.MaxHeaderFetch) - tests := []struct { - query *getBlockHeadersData // The query to execute for header retrieval - expect []common.Hash // The hashes of the block whose headers are expected - }{ - // A single random block should be retrievable by hash and number too - { - &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, - []common.Hash{pm.blockchain.GetBlockByNumber(limit / 2).Hash()}, - }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 1}, - []common.Hash{pm.blockchain.GetBlockByNumber(limit / 2).Hash()}, - }, - // Multiple headers should be retrievable in both directions - { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(limit / 2).Hash(), - pm.blockchain.GetBlockByNumber(limit/2 + 1).Hash(), - pm.blockchain.GetBlockByNumber(limit/2 + 2).Hash(), - }, - }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(limit / 2).Hash(), - pm.blockchain.GetBlockByNumber(limit/2 - 1).Hash(), - pm.blockchain.GetBlockByNumber(limit/2 - 2).Hash(), - }, - }, - // Multiple headers with skip lists should be retrievable - { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(limit / 2).Hash(), - pm.blockchain.GetBlockByNumber(limit/2 + 4).Hash(), - pm.blockchain.GetBlockByNumber(limit/2 + 8).Hash(), - }, - }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(limit / 2).Hash(), - pm.blockchain.GetBlockByNumber(limit/2 - 4).Hash(), - pm.blockchain.GetBlockByNumber(limit/2 - 8).Hash(), - }, - }, - // The chain endpoints should be retrievable - { - &getBlockHeadersData{Origin: hashOrNumber{Number: 0}, Amount: 1}, - []common.Hash{pm.blockchain.GetBlockByNumber(0).Hash()}, - }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64()}, Amount: 1}, - []common.Hash{pm.blockchain.CurrentBlock().Hash()}, - }, - // Ensure protocol limits are honored - { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, - pm.blockchain.GetBlockHashesFromHash(pm.blockchain.CurrentBlock().Hash(), limit), - }, - // Check that requesting more than available is handled gracefully - { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 4).Hash(), - pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64()).Hash(), - }, - }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(4).Hash(), - pm.blockchain.GetBlockByNumber(0).Hash(), - }, - }, - // Check that requesting more than available is handled gracefully, even if mid skip - { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3}, - []common.Hash{ - 
pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 4).Hash(), - pm.blockchain.GetBlockByNumber(pm.blockchain.CurrentBlock().NumberU64() - 1).Hash(), - }, - }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(4).Hash(), - pm.blockchain.GetBlockByNumber(1).Hash(), - }, - }, - // Check a corner case where requesting more can iterate past the endpoints - { - &getBlockHeadersData{Origin: hashOrNumber{Number: 2}, Amount: 5, Reverse: true}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(2).Hash(), - pm.blockchain.GetBlockByNumber(1).Hash(), - pm.blockchain.GetBlockByNumber(0).Hash(), - }, - }, - // Check a corner case where skipping overflow loops back into the chain start - { - &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(3).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64 - 1}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(3).Hash(), - }, - }, - // Check a corner case where skipping overflow loops back to the same header - { - &getBlockHeadersData{Origin: hashOrNumber{Hash: pm.blockchain.GetBlockByNumber(1).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64}, - []common.Hash{ - pm.blockchain.GetBlockByNumber(1).Hash(), - }, - }, - // Check that non existing headers aren't returned - { - &getBlockHeadersData{Origin: hashOrNumber{Hash: unknown}, Amount: 1}, - []common.Hash{}, - }, { - &getBlockHeadersData{Origin: hashOrNumber{Number: pm.blockchain.CurrentBlock().NumberU64() + 1}, Amount: 1}, - []common.Hash{}, - }, - } - // Run each of the tests and verify the results against the chain - for i, tt := range tests { - // Collect the headers to expect in the response - headers := []*types.Header{} - for _, hash := range tt.expect { - headers = append(headers, pm.blockchain.GetBlockByHash(hash).Header()) - } - // Send the hash request and verify the response - p2p.Send(peer.app, 0x03, tt.query) - if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil { - t.Errorf("test %d: headers mismatch: %v", i, err) - } - // If the test used number origins, repeat with hashes as the too - if tt.query.Origin.Hash == (common.Hash{}) { - if origin := pm.blockchain.GetBlockByNumber(tt.query.Origin.Number); origin != nil { - tt.query.Origin.Hash, tt.query.Origin.Number = origin.Hash(), 0 +// testTxPool is a mock transaction pool that blindly accepts all transactions. +// Its goal is to get around setting up a valid statedb for the balance and nonce +// checks. +type testTxPool struct { + pool map[common.Hash]*types.Transaction // Hash map of collected transactions - p2p.Send(peer.app, 0x03, tt.query) - if err := p2p.ExpectMsg(peer.app, 0x04, headers); err != nil { - t.Errorf("test %d: headers mismatch: %v", i, err) - } - } - } - } + txFeed event.Feed // Notification feed to allow waiting for inclusion + lock sync.RWMutex // Protects the transaction pool } -// Tests that block contents can be retrieved from a remote chain based on their hashes. 
-func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) } -func TestGetBlockBodies64(t *testing.T) { testGetBlockBodies(t, 64) } - -func testGetBlockBodies(t *testing.T, protocol int) { - pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil) - peer, _ := newTestPeer("peer", protocol, pm, true) - defer peer.close() - - // Create a batch of tests for various scenarios - limit := downloader.MaxBlockFetch - tests := []struct { - random int // Number of blocks to fetch randomly from the chain - explicit []common.Hash // Explicitly requested blocks - available []bool // Availability of explicitly requested blocks - expected int // Total number of existing blocks to expect - }{ - {1, nil, nil, 1}, // A single random block should be retrievable - {10, nil, nil, 10}, // Multiple random blocks should be retrievable - {limit, nil, nil, limit}, // The maximum possible blocks should be retrievable - {limit + 1, nil, nil, limit}, // No more than the possible block count should be returned - {0, []common.Hash{pm.blockchain.Genesis().Hash()}, []bool{true}, 1}, // The genesis block should be retrievable - {0, []common.Hash{pm.blockchain.CurrentBlock().Hash()}, []bool{true}, 1}, // The chains head block should be retrievable - {0, []common.Hash{{}}, []bool{false}, 0}, // A non existent block should not be returned - - // Existing and non-existing blocks interleaved should not cause problems - {0, []common.Hash{ - {}, - pm.blockchain.GetBlockByNumber(1).Hash(), - {}, - pm.blockchain.GetBlockByNumber(10).Hash(), - {}, - pm.blockchain.GetBlockByNumber(100).Hash(), - {}, - }, []bool{false, true, false, true, false, true, false}, 3}, - } - // Run each of the tests and verify the results against the chain - for i, tt := range tests { - // Collect the hashes to request, and the response to expect - hashes, seen := []common.Hash{}, make(map[int64]bool) - bodies := []*blockBody{} - - for j := 0; j < tt.random; j++ { - for { - num := rand.Int63n(int64(pm.blockchain.CurrentBlock().NumberU64())) - if !seen[num] { - seen[num] = true - - block := pm.blockchain.GetBlockByNumber(uint64(num)) - hashes = append(hashes, block.Hash()) - if len(bodies) < tt.expected { - bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()}) - } - break - } - } - } - for j, hash := range tt.explicit { - hashes = append(hashes, hash) - if tt.available[j] && len(bodies) < tt.expected { - block := pm.blockchain.GetBlockByHash(hash) - bodies = append(bodies, &blockBody{Transactions: block.Transactions(), Uncles: block.Uncles()}) - } - } - // Send the hash request and verify the response - p2p.Send(peer.app, 0x05, hashes) - if err := p2p.ExpectMsg(peer.app, 0x06, bodies); err != nil { - t.Errorf("test %d: bodies mismatch: %v", i, err) - } +// newTestTxPool creates a mock transaction pool. +func newTestTxPool() *testTxPool { + return &testTxPool{ + pool: make(map[common.Hash]*types.Transaction), } } -// Tests that the node state database can be retrieved based on hashes. -func TestGetNodeData63(t *testing.T) { testGetNodeData(t, 63) } -func TestGetNodeData64(t *testing.T) { testGetNodeData(t, 64) } +// Has returns an indicator whether txpool has a transaction +// cached with the given hash. 
+func (p *testTxPool) Has(hash common.Hash) bool { + p.lock.Lock() + defer p.lock.Unlock() -func testGetNodeData(t *testing.T, protocol int) { - // Define three accounts to simulate transactions with - acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") - acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") - acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey) - acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey) - - signer := types.HomesteadSigner{} - // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test) - generator := func(i int, block *core.BlockGen) { - switch i { - case 0: - // In block 1, the test bank sends account #1 some ether. - tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey) - block.AddTx(tx) - case 1: - // In block 2, the test bank sends some more ether to account #1. - // acc1Addr passes it on to account #2. - tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey) - tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key) - block.AddTx(tx1) - block.AddTx(tx2) - case 2: - // Block 3 is empty but was mined by account #2. - block.SetCoinbase(acc2Addr) - block.SetExtra([]byte("yeehaw")) - case 3: - // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data). - b2 := block.PrevBlock(1).Header() - b2.Extra = []byte("foo") - block.AddUncle(b2) - b3 := block.PrevBlock(2).Header() - b3.Extra = []byte("foo") - block.AddUncle(b3) - } - } - // Assemble the test environment - pm, db := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil) - peer, _ := newTestPeer("peer", protocol, pm, true) - defer peer.close() - - // Fetch for now the entire chain db - hashes := []common.Hash{} - - it := db.NewIterator(nil, nil) - for it.Next() { - if key := it.Key(); len(key) == common.HashLength { - hashes = append(hashes, common.BytesToHash(key)) - } - } - it.Release() - - p2p.Send(peer.app, 0x0d, hashes) - msg, err := peer.app.ReadMsg() - if err != nil { - t.Fatalf("failed to read node data response: %v", err) - } - if msg.Code != 0x0e { - t.Fatalf("response packet code mismatch: have %x, want %x", msg.Code, 0x0c) - } - var data [][]byte - if err := msg.Decode(&data); err != nil { - t.Fatalf("failed to decode response node data: %v", err) - } - // Verify that all hashes correspond to the requested data, and reconstruct a state tree - for i, want := range hashes { - if hash := crypto.Keccak256Hash(data[i]); hash != want { - t.Errorf("data hash mismatch: have %x, want %x", hash, want) - } - } - statedb := rawdb.NewMemoryDatabase() - for i := 0; i < len(data); i++ { - statedb.Put(hashes[i].Bytes(), data[i]) - } - accounts := []common.Address{testBank, acc1Addr, acc2Addr} - for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { - trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil) - - for j, acc := range accounts { - state, _ := pm.blockchain.State() - bw := state.GetBalance(acc) - bh := trie.GetBalance(acc) - - if (bw != nil && bh == nil) || (bw == nil && bh != nil) { - t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw) - } - if bw != nil && bh != nil && bw.Cmp(bw) != 0 
{ - t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw) - } - } - } + return p.pool[hash] != nil } -// Tests that the transaction receipts can be retrieved based on hashes. -func TestGetReceipt63(t *testing.T) { testGetReceipt(t, 63) } -func TestGetReceipt64(t *testing.T) { testGetReceipt(t, 64) } +// Get retrieves the transaction from local txpool with given +// tx hash. +func (p *testTxPool) Get(hash common.Hash) *types.Transaction { + p.lock.Lock() + defer p.lock.Unlock() -func testGetReceipt(t *testing.T, protocol int) { - // Define three accounts to simulate transactions with - acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") - acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") - acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey) - acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey) - - signer := types.HomesteadSigner{} - // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test) - generator := func(i int, block *core.BlockGen) { - switch i { - case 0: - // In block 1, the test bank sends account #1 some ether. - tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testBankKey) - block.AddTx(tx) - case 1: - // In block 2, the test bank sends some more ether to account #1. - // acc1Addr passes it on to account #2. - tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testBank), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testBankKey) - tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key) - block.AddTx(tx1) - block.AddTx(tx2) - case 2: - // Block 3 is empty but was mined by account #2. - block.SetCoinbase(acc2Addr) - block.SetExtra([]byte("yeehaw")) - case 3: - // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data). - b2 := block.PrevBlock(1).Header() - b2.Extra = []byte("foo") - block.AddUncle(b2) - b3 := block.PrevBlock(2).Header() - b3.Extra = []byte("foo") - block.AddUncle(b3) - } - } - // Assemble the test environment - pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil) - peer, _ := newTestPeer("peer", protocol, pm, true) - defer peer.close() - - // Collect the hashes to request, and the response to expect - hashes, receipts := []common.Hash{}, []types.Receipts{} - for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { - block := pm.blockchain.GetBlockByNumber(i) - - hashes = append(hashes, block.Hash()) - receipts = append(receipts, pm.blockchain.GetReceiptsByHash(block.Hash())) - } - // Send the hash request and verify the response - p2p.Send(peer.app, 0x0f, hashes) - if err := p2p.ExpectMsg(peer.app, 0x10, receipts); err != nil { - t.Errorf("receipts mismatch: %v", err) - } + return p.pool[hash] } -// Tests that post eth protocol handshake, clients perform a mutual checkpoint -// challenge to validate each other's chains. Hash mismatches, or missing ones -// during a fast sync should lead to the peer getting dropped. 
-func TestCheckpointChallenge(t *testing.T) { - tests := []struct { - syncmode downloader.SyncMode - checkpoint bool - timeout bool - empty bool - match bool - drop bool - }{ - // If checkpointing is not enabled locally, don't challenge and don't drop - {downloader.FullSync, false, false, false, false, false}, - {downloader.FastSync, false, false, false, false, false}, - - // If checkpointing is enabled locally and remote response is empty, only drop during fast sync - {downloader.FullSync, true, false, true, false, false}, - {downloader.FastSync, true, false, true, false, true}, // Special case, fast sync, unsynced peer - - // If checkpointing is enabled locally and remote response mismatches, always drop - {downloader.FullSync, true, false, false, false, true}, - {downloader.FastSync, true, false, false, false, true}, - - // If checkpointing is enabled locally and remote response matches, never drop - {downloader.FullSync, true, false, false, true, false}, - {downloader.FastSync, true, false, false, true, false}, +// AddRemotes appends a batch of transactions to the pool, and notifies any +// listeners if the addition channel is non nil +func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error { + p.lock.Lock() + defer p.lock.Unlock() - // If checkpointing is enabled locally and remote times out, always drop - {downloader.FullSync, true, true, false, true, true}, - {downloader.FastSync, true, true, false, true, true}, - } - for _, tt := range tests { - t.Run(fmt.Sprintf("sync %v checkpoint %v timeout %v empty %v match %v", tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match), func(t *testing.T) { - testCheckpointChallenge(t, tt.syncmode, tt.checkpoint, tt.timeout, tt.empty, tt.match, tt.drop) - }) + for _, tx := range txs { + p.pool[tx.Hash()] = tx } + p.txFeed.Send(core.NewTxsEvent{Txs: txs}) + return make([]error, len(txs)) } -func testCheckpointChallenge(t *testing.T, syncmode downloader.SyncMode, checkpoint bool, timeout bool, empty bool, match bool, drop bool) { - // Reduce the checkpoint handshake challenge timeout - defer func(old time.Duration) { syncChallengeTimeout = old }(syncChallengeTimeout) - syncChallengeTimeout = 250 * time.Millisecond - - // Initialize a chain and generate a fake CHT if checkpointing is enabled - var ( - db = rawdb.NewMemoryDatabase() - config = new(params.ChainConfig) - ) - (&core.Genesis{Config: config}).MustCommit(db) // Commit genesis block - // If checkpointing is enabled, create and inject a fake CHT and the corresponding - // chllenge response. 
- var response *types.Header - var cht *params.TrustedCheckpoint - if checkpoint { - index := uint64(rand.Intn(500)) - number := (index+1)*params.CHTFrequency - 1 - response = &types.Header{Number: big.NewInt(int64(number)), Extra: []byte("valid")} +// Pending returns all the transactions known to the pool +func (p *testTxPool) Pending() (map[common.Address]types.Transactions, error) { + p.lock.RLock() + defer p.lock.RUnlock() - cht = ¶ms.TrustedCheckpoint{ - SectionIndex: index, - SectionHead: response.Hash(), - } + batches := make(map[common.Address]types.Transactions) + for _, tx := range p.pool { + from, _ := types.Sender(types.HomesteadSigner{}, tx) + batches[from] = append(batches[from], tx) } - // Create a checkpoint aware protocol manager - blockchain, err := core.NewBlockChain(db, nil, config, ethash.NewFaker(), vm.Config{}, nil, nil) - if err != nil { - t.Fatalf("failed to create new blockchain: %v", err) - } - pm, err := NewProtocolManager(config, cht, syncmode, DefaultConfig.NetworkId, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, ethash.NewFaker(), blockchain, db, 1, nil) - if err != nil { - t.Fatalf("failed to start test protocol manager: %v", err) - } - pm.Start(1000) - defer pm.Stop() - - // Connect a new peer and check that we receive the checkpoint challenge - peer, _ := newTestPeer("peer", eth63, pm, true) - defer peer.close() - - if checkpoint { - challenge := &getBlockHeadersData{ - Origin: hashOrNumber{Number: response.Number.Uint64()}, - Amount: 1, - Skip: 0, - Reverse: false, - } - if err := p2p.ExpectMsg(peer.app, GetBlockHeadersMsg, challenge); err != nil { - t.Fatalf("challenge mismatch: %v", err) - } - // Create a block to reply to the challenge if no timeout is simulated - if !timeout { - if empty { - if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{}); err != nil { - t.Fatalf("failed to answer challenge: %v", err) - } - } else if match { - if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{response}); err != nil { - t.Fatalf("failed to answer challenge: %v", err) - } - } else { - if err := p2p.Send(peer.app, BlockHeadersMsg, []*types.Header{{Number: response.Number}}); err != nil { - t.Fatalf("failed to answer challenge: %v", err) - } - } - } - } - // Wait until the test timeout passes to ensure proper cleanup - time.Sleep(syncChallengeTimeout + 300*time.Millisecond) - - // Verify that the remote peer is maintained or dropped - if drop { - if peers := pm.peers.Len(); peers != 0 { - t.Fatalf("peer count mismatch: have %d, want %d", peers, 0) - } - } else { - if peers := pm.peers.Len(); peers != 1 { - t.Fatalf("peer count mismatch: have %d, want %d", peers, 1) - } + for _, batch := range batches { + sort.Sort(types.TxByNonce(batch)) } + return batches, nil } -func TestBroadcastBlock(t *testing.T) { - var tests = []struct { - totalPeers int - broadcastExpected int - }{ - {1, 1}, - {2, 1}, - {3, 1}, - {4, 2}, - {5, 2}, - {9, 3}, - {12, 3}, - {16, 4}, - {26, 5}, - {100, 10}, - } - for _, test := range tests { - testBroadcastBlock(t, test.totalPeers, test.broadcastExpected) - } +// SubscribeNewTxsEvent should return an event subscription of NewTxsEvent and +// send events to the given channel. 
+func (p *testTxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { + return p.txFeed.Subscribe(ch) } -func testBroadcastBlock(t *testing.T, totalPeers, broadcastExpected int) { - var ( - evmux = new(event.TypeMux) - pow = ethash.NewFaker() - db = rawdb.NewMemoryDatabase() - config = ¶ms.ChainConfig{} - gspec = &core.Genesis{Config: config} - genesis = gspec.MustCommit(db) - ) - blockchain, err := core.NewBlockChain(db, nil, config, pow, vm.Config{}, nil, nil) - if err != nil { - t.Fatalf("failed to create new blockchain: %v", err) - } - pm, err := NewProtocolManager(config, nil, downloader.FullSync, DefaultConfig.NetworkId, evmux, &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, pow, blockchain, db, 1, nil) - if err != nil { - t.Fatalf("failed to start test protocol manager: %v", err) - } - pm.Start(1000) - defer pm.Stop() - var peers []*testPeer - for i := 0; i < totalPeers; i++ { - peer, _ := newTestPeer(fmt.Sprintf("peer %d", i), eth63, pm, true) - defer peer.close() - - peers = append(peers, peer) - } - chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 1, func(i int, gen *core.BlockGen) {}) - pm.BroadcastBlock(chain[0], true /*propagate*/) - - errCh := make(chan error, totalPeers) - doneCh := make(chan struct{}, totalPeers) - for _, peer := range peers { - go func(p *testPeer) { - if err := p2p.ExpectMsg(p.app, NewBlockMsg, &newBlockData{Block: chain[0], TD: big.NewInt(131136)}); err != nil { - errCh <- err - } else { - doneCh <- struct{}{} - } - }(peer) - } - var received int - for { - select { - case <-doneCh: - received++ - if received > broadcastExpected { - // We can bail early here - t.Errorf("broadcast count mismatch: have %d > want %d", received, broadcastExpected) - return - } - case <-time.After(2 * time.Second): - if received != broadcastExpected { - t.Errorf("broadcast count mismatch: have %d, want %d", received, broadcastExpected) - } - return - case err = <-errCh: - t.Fatalf("broadcast failed: %v", err) - } - } +// testHandler is a live implementation of the Ethereum protocol handler, just +// preinitialized with some sane testing defaults and the transaction pool mocked +// out. +type testHandler struct { + db ethdb.Database + chain *core.BlockChain + txpool *testTxPool + handler *handler +} +// newTestHandler creates a new handler for testing purposes with no blocks. +func newTestHandler() *testHandler { + return newTestHandlerWithBlocks(0) } -// Tests that a propagated malformed block (uncles or transactions don't match -// with the hashes in the header) gets discarded and not broadcast forward. -func TestBroadcastMalformedBlock(t *testing.T) { - // Create a live node to test propagation with - var ( - engine = ethash.NewFaker() - db = rawdb.NewMemoryDatabase() - config = ¶ms.ChainConfig{} - gspec = &core.Genesis{Config: config} - genesis = gspec.MustCommit(db) - ) - blockchain, err := core.NewBlockChain(db, nil, config, engine, vm.Config{}, nil, nil) - if err != nil { - t.Fatalf("failed to create new blockchain: %v", err) +// newTestHandlerWithBlocks creates a new handler for testing purposes, with a +// given number of initial blocks. 
+func newTestHandlerWithBlocks(blocks int) *testHandler {
+	// Create a database pre-initialized with a genesis block
+	db := rawdb.NewMemoryDatabase()
+	(&core.Genesis{
+		Config: params.TestChainConfig,
+		Alloc:  core.GenesisAlloc{testAddr: {Balance: big.NewInt(1000000)}},
+	}).MustCommit(db)
+
+	chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil)
+
+	bs, _ := core.GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, nil)
+	if _, err := chain.InsertChain(bs); err != nil {
+		panic(err)
+	}
+	txpool := newTestTxPool()
+
+	handler, _ := newHandler(&handlerConfig{
+		Database:   db,
+		Chain:      chain,
+		TxPool:     txpool,
+		Network:    1,
+		Sync:       downloader.FastSync,
+		BloomCache: 1,
+	})
+	handler.Start(1000)
+
+	return &testHandler{
+		db:      db,
+		chain:   chain,
+		txpool:  txpool,
+		handler: handler,
 	}
-	pm, err := NewProtocolManager(config, nil, downloader.FullSync, DefaultConfig.NetworkId, new(event.TypeMux), new(testTxPool), engine, blockchain, db, 1, nil)
-	if err != nil {
-		t.Fatalf("failed to start test protocol manager: %v", err)
-	}
-	pm.Start(2)
-	defer pm.Stop()
-
-	// Create two peers, one to send the malformed block with and one to check
-	// propagation
-	source, _ := newTestPeer("source", eth63, pm, true)
-	defer source.close()
-
-	sink, _ := newTestPeer("sink", eth63, pm, true)
-	defer sink.close()
-
-	// Create various combinations of malformed blocks
-	chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 1, func(i int, gen *core.BlockGen) {})
-
-	malformedUncles := chain[0].Header()
-	malformedUncles.UncleHash[0]++
-	malformedTransactions := chain[0].Header()
-	malformedTransactions.TxHash[0]++
-	malformedEverything := chain[0].Header()
-	malformedEverything.UncleHash[0]++
-	malformedEverything.TxHash[0]++
+}
 
-	// Keep listening to broadcasts and notify if any arrives
-	notify := make(chan struct{}, 1)
-	go func() {
-		if _, err := sink.app.ReadMsg(); err == nil {
-			notify <- struct{}{}
-		}
-	}()
-	// Try to broadcast all malformations and ensure they all get discarded
-	for _, header := range []*types.Header{malformedUncles, malformedTransactions, malformedEverything} {
-		block := types.NewBlockWithHeader(header).WithBody(chain[0].Transactions(), chain[0].Uncles())
-		if err := p2p.Send(source.app, NewBlockMsg, []interface{}{block, big.NewInt(131136)}); err != nil {
-			t.Fatalf("failed to broadcast block: %v", err)
-		}
-		select {
-		case <-notify:
-			t.Fatalf("malformed block forwarded")
-		case <-time.After(100 * time.Millisecond):
-		}
-	}
+// close tears down the handler and all its internal constructs.
+func (b *testHandler) close() {
+	b.handler.Stop()
+	b.chain.Stop()
 }
diff --git a/eth/helper_test.go b/eth/helper_test.go
deleted file mode 100644
index 65effcc16537..000000000000
--- a/eth/helper_test.go
+++ /dev/null
@@ -1,230 +0,0 @@
-// Copyright 2015 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
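For orientation, a minimal sketch of how the new harness might be driven by a test (illustrative only; it assumes the eth package's usual test imports):

func sketchHandlerHarness(t *testing.T) {
	backend := newTestHandlerWithBlocks(8) // handler over a fresh 8-block chain
	defer backend.close()

	// The harness exposes its internals, so tests can assert on chain state...
	if have, want := backend.chain.CurrentBlock().NumberU64(), uint64(8); have != want {
		t.Fatalf("head block mismatch: have %d, want %d", have, want)
	}
	// ...and observe the mock pool, whose feed fires synchronously from AddRemotes.
	events := make(chan core.NewTxsEvent, 1)
	sub := backend.txpool.SubscribeNewTxsEvent(events)
	defer sub.Unsubscribe()
}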
-// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -// This file contains some shares testing functionality, common to multiple -// different files and modules being tested. - -package eth - -import ( - "crypto/ecdsa" - "crypto/rand" - "fmt" - "math/big" - "sort" - "sync" - "testing" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/consensus/ethash" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/forkid" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/core/vm" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/eth/downloader" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/params" -) - -var ( - testBankKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") - testBank = crypto.PubkeyToAddress(testBankKey.PublicKey) -) - -// newTestProtocolManager creates a new protocol manager for testing purposes, -// with the given number of blocks already known, and potential notification -// channels for different events. -func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, ethdb.Database, error) { - var ( - evmux = new(event.TypeMux) - engine = ethash.NewFaker() - db = rawdb.NewMemoryDatabase() - gspec = &core.Genesis{ - Config: params.TestChainConfig, - Alloc: core.GenesisAlloc{testBank: {Balance: big.NewInt(1000000)}}, - } - genesis = gspec.MustCommit(db) - blockchain, _ = core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil) - ) - chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator) - if _, err := blockchain.InsertChain(chain); err != nil { - panic(err) - } - pm, err := NewProtocolManager(gspec.Config, nil, mode, DefaultConfig.NetworkId, evmux, &testTxPool{added: newtx, pool: make(map[common.Hash]*types.Transaction)}, engine, blockchain, db, 1, nil) - if err != nil { - return nil, nil, err - } - pm.Start(1000) - return pm, db, nil -} - -// newTestProtocolManagerMust creates a new protocol manager for testing purposes, -// with the given number of blocks already known, and potential notification -// channels for different events. In case of an error, the constructor force- -// fails the test. -func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, ethdb.Database) { - pm, db, err := newTestProtocolManager(mode, blocks, generator, newtx) - if err != nil { - t.Fatalf("Failed to create protocol manager: %v", err) - } - return pm, db -} - -// testTxPool is a fake, helper transaction pool for testing purposes -type testTxPool struct { - txFeed event.Feed - pool map[common.Hash]*types.Transaction // Hash map of collected transactions - added chan<- []*types.Transaction // Notification channel for new transactions - - lock sync.RWMutex // Protects the transaction pool -} - -// Has returns an indicator whether txpool has a transaction -// cached with the given hash. 
-func (p *testTxPool) Has(hash common.Hash) bool { - p.lock.Lock() - defer p.lock.Unlock() - - return p.pool[hash] != nil -} - -// Get retrieves the transaction from local txpool with given -// tx hash. -func (p *testTxPool) Get(hash common.Hash) *types.Transaction { - p.lock.Lock() - defer p.lock.Unlock() - - return p.pool[hash] -} - -// AddRemotes appends a batch of transactions to the pool, and notifies any -// listeners if the addition channel is non nil -func (p *testTxPool) AddRemotes(txs []*types.Transaction) []error { - p.lock.Lock() - defer p.lock.Unlock() - - for _, tx := range txs { - p.pool[tx.Hash()] = tx - } - if p.added != nil { - p.added <- txs - } - p.txFeed.Send(core.NewTxsEvent{Txs: txs}) - return make([]error, len(txs)) -} - -// Pending returns all the transactions known to the pool -func (p *testTxPool) Pending() (map[common.Address]types.Transactions, error) { - p.lock.RLock() - defer p.lock.RUnlock() - - batches := make(map[common.Address]types.Transactions) - for _, tx := range p.pool { - from, _ := types.Sender(types.HomesteadSigner{}, tx) - batches[from] = append(batches[from], tx) - } - for _, batch := range batches { - sort.Sort(types.TxByNonce(batch)) - } - return batches, nil -} - -func (p *testTxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { - return p.txFeed.Subscribe(ch) -} - -// newTestTransaction create a new dummy transaction. -func newTestTransaction(from *ecdsa.PrivateKey, nonce uint64, datasize int) *types.Transaction { - tx := types.NewTransaction(nonce, common.Address{}, big.NewInt(0), 100000, big.NewInt(0), make([]byte, datasize)) - tx, _ = types.SignTx(tx, types.HomesteadSigner{}, from) - return tx -} - -// testPeer is a simulated peer to allow testing direct network calls. -type testPeer struct { - net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging - app *p2p.MsgPipeRW // Application layer reader/writer to simulate the local side - *peer -} - -// newTestPeer creates a new peer registered at the given protocol manager. -func newTestPeer(name string, version int, pm *ProtocolManager, shake bool) (*testPeer, <-chan error) { - // Create a message pipe to communicate through - app, net := p2p.MsgPipe() - - // Start the peer on a new thread - var id enode.ID - rand.Read(id[:]) - peer := pm.newPeer(version, p2p.NewPeer(id, name, nil), net, pm.txpool.Get) - errc := make(chan error, 1) - go func() { errc <- pm.runPeer(peer) }() - tp := &testPeer{app: app, net: net, peer: peer} - - // Execute any implicitly requested handshakes and return - if shake { - var ( - genesis = pm.blockchain.Genesis() - head = pm.blockchain.CurrentHeader() - td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) - ) - tp.handshake(nil, td, head.Hash(), genesis.Hash(), forkid.NewID(pm.blockchain), forkid.NewFilter(pm.blockchain)) - } - return tp, errc -} - -// handshake simulates a trivial handshake that expects the same state from the -// remote side as we are simulating locally. 
-func (p *testPeer) handshake(t *testing.T, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) { - var msg interface{} - switch { - case p.version == eth63: - msg = &statusData63{ - ProtocolVersion: uint32(p.version), - NetworkId: DefaultConfig.NetworkId, - TD: td, - CurrentBlock: head, - GenesisBlock: genesis, - } - case p.version >= eth64: - msg = &statusData{ - ProtocolVersion: uint32(p.version), - NetworkID: DefaultConfig.NetworkId, - TD: td, - Head: head, - Genesis: genesis, - ForkID: forkID, - } - default: - panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version)) - } - if err := p2p.ExpectMsg(p.app, StatusMsg, msg); err != nil { - t.Fatalf("status recv: %v", err) - } - if err := p2p.Send(p.app, StatusMsg, msg); err != nil { - t.Fatalf("status send: %v", err) - } -} - -// close terminates the local side of the peer, notifying the remote protocol -// manager of termination. -func (p *testPeer) close() { - p.app.Close() -} diff --git a/eth/peer.go b/eth/peer.go index 21b82a19c547..6970c8afd3f7 100644 --- a/eth/peer.go +++ b/eth/peer.go @@ -17,806 +17,58 @@ package eth import ( - "errors" - "fmt" "math/big" "sync" "time" - mapset "github.com/deckarep/golang-set" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/forkid" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/eth/protocols/eth" + "github.com/ethereum/go-ethereum/eth/protocols/snap" ) -var ( - errClosed = errors.New("peer set is closed") - errAlreadyRegistered = errors.New("peer is already registered") - errNotRegistered = errors.New("peer is not registered") -) - -const ( - maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS) - maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS) - - // maxQueuedTxs is the maximum number of transactions to queue up before dropping - // older broadcasts. - maxQueuedTxs = 4096 - - // maxQueuedTxAnns is the maximum number of transaction announcements to queue up - // before dropping older announcements. - maxQueuedTxAnns = 4096 - - // maxQueuedBlocks is the maximum number of block propagations to queue up before - // dropping broadcasts. There's not much point in queueing stale blocks, so a few - // that might cover uncles should be enough. - maxQueuedBlocks = 4 - - // maxQueuedBlockAnns is the maximum number of block announcements to queue up before - // dropping broadcasts. Similarly to block propagations, there's no point to queue - // above some healthy uncle limit, so use that. - maxQueuedBlockAnns = 4 - - handshakeTimeout = 5 * time.Second -) - -// max is a helper function which returns the larger of the two given integers. -func max(a, b int) int { - if a > b { - return a - } - return b -} - -// PeerInfo represents a short summary of the Ethereum sub-protocol metadata known +// ethPeerInfo represents a short summary of the `eth` sub-protocol metadata known // about a connected peer. -type PeerInfo struct { - Version int `json:"version"` // Ethereum protocol version negotiated +type ethPeerInfo struct { + Version uint `json:"version"` // Ethereum protocol version negotiated Difficulty *big.Int `json:"difficulty"` // Total difficulty of the peer's blockchain - Head string `json:"head"` // SHA3 hash of the peer's best owned block -} - -// propEvent is a block propagation, waiting for its turn in the broadcast queue. 
-type propEvent struct { - block *types.Block - td *big.Int -} - -type peer struct { - id string - - *p2p.Peer - rw p2p.MsgReadWriter - - version int // Protocol version negotiated - syncDrop *time.Timer // Timed connection dropper if sync progress isn't validated in time - - head common.Hash - td *big.Int - lock sync.RWMutex - - knownBlocks mapset.Set // Set of block hashes known to be known by this peer - queuedBlocks chan *propEvent // Queue of blocks to broadcast to the peer - queuedBlockAnns chan *types.Block // Queue of blocks to announce to the peer - - knownTxs mapset.Set // Set of transaction hashes known to be known by this peer - txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests - txAnnounce chan []common.Hash // Channel used to queue transaction announcement requests - getPooledTx func(common.Hash) *types.Transaction // Callback used to retrieve transaction from txpool - - term chan struct{} // Termination channel to stop the broadcaster + Head string `json:"head"` // Hex hash of the peer's best owned block } -func newPeer(version int, p *p2p.Peer, rw p2p.MsgReadWriter, getPooledTx func(hash common.Hash) *types.Transaction) *peer { - return &peer{ - Peer: p, - rw: rw, - version: version, - id: fmt.Sprintf("%x", p.ID().Bytes()[:8]), - knownTxs: mapset.NewSet(), - knownBlocks: mapset.NewSet(), - queuedBlocks: make(chan *propEvent, maxQueuedBlocks), - queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns), - txBroadcast: make(chan []common.Hash), - txAnnounce: make(chan []common.Hash), - getPooledTx: getPooledTx, - term: make(chan struct{}), - } -} - -// broadcastBlocks is a write loop that multiplexes blocks and block accouncements -// to the remote peer. The goal is to have an async writer that does not lock up -// node internals and at the same time rate limits queued data. -func (p *peer) broadcastBlocks(removePeer func(string)) { - for { - select { - case prop := <-p.queuedBlocks: - if err := p.SendNewBlock(prop.block, prop.td); err != nil { - removePeer(p.id) - return - } - p.Log().Trace("Propagated block", "number", prop.block.Number(), "hash", prop.block.Hash(), "td", prop.td) - - case block := <-p.queuedBlockAnns: - if err := p.SendNewBlockHashes([]common.Hash{block.Hash()}, []uint64{block.NumberU64()}); err != nil { - removePeer(p.id) - return - } - p.Log().Trace("Announced block", "number", block.Number(), "hash", block.Hash()) - - case <-p.term: - return - } - } -} - -// broadcastTransactions is a write loop that schedules transaction broadcasts -// to the remote peer. The goal is to have an async writer that does not lock up -// node internals and at the same time rate limits queued data. 
-func (p *peer) broadcastTransactions(removePeer func(string)) { - var ( - queue []common.Hash // Queue of hashes to broadcast as full transactions - done chan struct{} // Non-nil if background broadcaster is running - fail = make(chan error, 1) // Channel used to receive network error - ) - for { - // If there's no in-flight broadcast running, check if a new one is needed - if done == nil && len(queue) > 0 { - // Pile transaction until we reach our allowed network limit - var ( - hashes []common.Hash - txs []*types.Transaction - size common.StorageSize - ) - for i := 0; i < len(queue) && size < txsyncPackSize; i++ { - if tx := p.getPooledTx(queue[i]); tx != nil { - txs = append(txs, tx) - size += tx.Size() - } - hashes = append(hashes, queue[i]) - } - queue = queue[:copy(queue, queue[len(hashes):])] - - // If there's anything available to transfer, fire up an async writer - if len(txs) > 0 { - done = make(chan struct{}) - go func() { - if err := p.sendTransactions(txs); err != nil { - fail <- err - return - } - close(done) - p.Log().Trace("Sent transactions", "count", len(txs)) - }() - } - } - // Transfer goroutine may or may not have been started, listen for events - select { - case hashes := <-p.txBroadcast: - // New batch of transactions to be broadcast, queue them (with cap) - queue = append(queue, hashes...) - if len(queue) > maxQueuedTxs { - // Fancy copy and resize to ensure buffer doesn't grow indefinitely - queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxs:])] - } - - case <-done: - done = nil - - case <-fail: - removePeer(p.id) - return +// ethPeer is a wrapper around eth.Peer to maintain a few extra metadata. +type ethPeer struct { + *eth.Peer - case <-p.term: - return - } - } + syncDrop *time.Timer // Connection dropper if `eth` sync progress isn't validated in time + lock sync.RWMutex // Mutex protecting the internal fields } -// announceTransactions is a write loop that schedules transaction broadcasts -// to the remote peer. The goal is to have an async writer that does not lock up -// node internals and at the same time rate limits queued data. -func (p *peer) announceTransactions(removePeer func(string)) { - var ( - queue []common.Hash // Queue of hashes to announce as transaction stubs - done chan struct{} // Non-nil if background announcer is running - fail = make(chan error, 1) // Channel used to receive network error - ) - for { - // If there's no in-flight announce running, check if a new one is needed - if done == nil && len(queue) > 0 { - // Pile transaction hashes until we reach our allowed network limit - var ( - hashes []common.Hash - pending []common.Hash - size common.StorageSize - ) - for i := 0; i < len(queue) && size < txsyncPackSize; i++ { - if p.getPooledTx(queue[i]) != nil { - pending = append(pending, queue[i]) - size += common.HashLength - } - hashes = append(hashes, queue[i]) - } - queue = queue[:copy(queue, queue[len(hashes):])] - - // If there's anything available to transfer, fire up an async writer - if len(pending) > 0 { - done = make(chan struct{}) - go func() { - if err := p.sendPooledTransactionHashes(pending); err != nil { - fail <- err - return - } - close(done) - p.Log().Trace("Sent transaction announcements", "count", len(pending)) - }() - } - } - // Transfer goroutine may or may not have been started, listen for events - select { - case hashes := <-p.txAnnounce: - // New batch of transactions to be broadcast, queue them (with cap) - queue = append(queue, hashes...) 
- if len(queue) > maxQueuedTxAnns { - // Fancy copy and resize to ensure buffer doesn't grow indefinitely - queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxAnns:])] - } - - case <-done: - done = nil - - case <-fail: - removePeer(p.id) - return - - case <-p.term: - return - } - } -} - -// close signals the broadcast goroutine to terminate. -func (p *peer) close() { - close(p.term) -} - -// Info gathers and returns a collection of metadata known about a peer. -func (p *peer) Info() *PeerInfo { +// info gathers and returns some `eth` protocol metadata known about a peer. +func (p *ethPeer) info() *ethPeerInfo { hash, td := p.Head() - return &PeerInfo{ - Version: p.version, + return ðPeerInfo{ + Version: p.Version(), Difficulty: td, Head: hash.Hex(), } } -// Head retrieves a copy of the current head hash and total difficulty of the -// peer. -func (p *peer) Head() (hash common.Hash, td *big.Int) { - p.lock.RLock() - defer p.lock.RUnlock() - - copy(hash[:], p.head[:]) - return hash, new(big.Int).Set(p.td) -} - -// SetHead updates the head hash and total difficulty of the peer. -func (p *peer) SetHead(hash common.Hash, td *big.Int) { - p.lock.Lock() - defer p.lock.Unlock() - - copy(p.head[:], hash[:]) - p.td.Set(td) -} - -// MarkBlock marks a block as known for the peer, ensuring that the block will -// never be propagated to this particular peer. -func (p *peer) MarkBlock(hash common.Hash) { - // If we reached the memory allowance, drop a previously known block hash - for p.knownBlocks.Cardinality() >= maxKnownBlocks { - p.knownBlocks.Pop() - } - p.knownBlocks.Add(hash) -} - -// MarkTransaction marks a transaction as known for the peer, ensuring that it -// will never be propagated to this particular peer. -func (p *peer) MarkTransaction(hash common.Hash) { - // If we reached the memory allowance, drop a previously known transaction hash - for p.knownTxs.Cardinality() >= maxKnownTxs { - p.knownTxs.Pop() - } - p.knownTxs.Add(hash) -} - -// SendTransactions64 sends transactions to the peer and includes the hashes -// in its transaction hash set for future reference. -// -// This method is legacy support for initial transaction exchange in eth/64 and -// prior. For eth/65 and higher use SendPooledTransactionHashes. -func (p *peer) SendTransactions64(txs types.Transactions) error { - return p.sendTransactions(txs) -} - -// sendTransactions sends transactions to the peer and includes the hashes -// in its transaction hash set for future reference. -// -// This method is a helper used by the async transaction sender. Don't call it -// directly as the queueing (memory) and transmission (bandwidth) costs should -// not be managed directly. -func (p *peer) sendTransactions(txs types.Transactions) error { - // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(txs)) { - p.knownTxs.Pop() - } - for _, tx := range txs { - p.knownTxs.Add(tx.Hash()) - } - return p2p.Send(p.rw, TransactionMsg, txs) -} - -// AsyncSendTransactions queues a list of transactions (by hash) to eventually -// propagate to a remote peer. 
The number of pending sends are capped (new ones -// will force old sends to be dropped) -func (p *peer) AsyncSendTransactions(hashes []common.Hash) { - select { - case p.txBroadcast <- hashes: - // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) { - p.knownTxs.Pop() - } - for _, hash := range hashes { - p.knownTxs.Add(hash) - } - case <-p.term: - p.Log().Debug("Dropping transaction propagation", "count", len(hashes)) - } -} - -// sendPooledTransactionHashes sends transaction hashes to the peer and includes -// them in its transaction hash set for future reference. -// -// This method is a helper used by the async transaction announcer. Don't call it -// directly as the queueing (memory) and transmission (bandwidth) costs should -// not be managed directly. -func (p *peer) sendPooledTransactionHashes(hashes []common.Hash) error { - // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) { - p.knownTxs.Pop() - } - for _, hash := range hashes { - p.knownTxs.Add(hash) - } - return p2p.Send(p.rw, NewPooledTransactionHashesMsg, hashes) -} - -// AsyncSendPooledTransactionHashes queues a list of transactions hashes to eventually -// announce to a remote peer. The number of pending sends are capped (new ones -// will force old sends to be dropped) -func (p *peer) AsyncSendPooledTransactionHashes(hashes []common.Hash) { - select { - case p.txAnnounce <- hashes: - // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) { - p.knownTxs.Pop() - } - for _, hash := range hashes { - p.knownTxs.Add(hash) - } - case <-p.term: - p.Log().Debug("Dropping transaction announcement", "count", len(hashes)) - } -} - -// SendPooledTransactionsRLP sends requested transactions to the peer and adds the -// hashes in its transaction hash set for future reference. -// -// Note, the method assumes the hashes are correct and correspond to the list of -// transactions being sent. -func (p *peer) SendPooledTransactionsRLP(hashes []common.Hash, txs []rlp.RawValue) error { - // Mark all the transactions as known, but ensure we don't overflow our limits - for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) { - p.knownTxs.Pop() - } - for _, hash := range hashes { - p.knownTxs.Add(hash) - } - return p2p.Send(p.rw, PooledTransactionsMsg, txs) -} - -// SendNewBlockHashes announces the availability of a number of blocks through -// a hash notification. -func (p *peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error { - // Mark all the block hashes as known, but ensure we don't overflow our limits - for p.knownBlocks.Cardinality() > max(0, maxKnownBlocks-len(hashes)) { - p.knownBlocks.Pop() - } - for _, hash := range hashes { - p.knownBlocks.Add(hash) - } - request := make(newBlockHashesData, len(hashes)) - for i := 0; i < len(hashes); i++ { - request[i].Hash = hashes[i] - request[i].Number = numbers[i] - } - return p2p.Send(p.rw, NewBlockHashesMsg, request) -} - -// AsyncSendNewBlockHash queues the availability of a block for propagation to a -// remote peer. If the peer's broadcast queue is full, the event is silently -// dropped. 
-func (p *peer) AsyncSendNewBlockHash(block *types.Block) { - select { - case p.queuedBlockAnns <- block: - // Mark all the block hash as known, but ensure we don't overflow our limits - for p.knownBlocks.Cardinality() >= maxKnownBlocks { - p.knownBlocks.Pop() - } - p.knownBlocks.Add(block.Hash()) - default: - p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash()) - } -} - -// SendNewBlock propagates an entire block to a remote peer. -func (p *peer) SendNewBlock(block *types.Block, td *big.Int) error { - // Mark all the block hash as known, but ensure we don't overflow our limits - for p.knownBlocks.Cardinality() >= maxKnownBlocks { - p.knownBlocks.Pop() - } - p.knownBlocks.Add(block.Hash()) - return p2p.Send(p.rw, NewBlockMsg, []interface{}{block, td}) -} - -// AsyncSendNewBlock queues an entire block for propagation to a remote peer. If -// the peer's broadcast queue is full, the event is silently dropped. -func (p *peer) AsyncSendNewBlock(block *types.Block, td *big.Int) { - select { - case p.queuedBlocks <- &propEvent{block: block, td: td}: - // Mark all the block hash as known, but ensure we don't overflow our limits - for p.knownBlocks.Cardinality() >= maxKnownBlocks { - p.knownBlocks.Pop() - } - p.knownBlocks.Add(block.Hash()) - default: - p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash()) - } -} - -// SendBlockHeaders sends a batch of block headers to the remote peer. -func (p *peer) SendBlockHeaders(headers []*types.Header) error { - return p2p.Send(p.rw, BlockHeadersMsg, headers) -} - -// SendBlockBodies sends a batch of block contents to the remote peer. -func (p *peer) SendBlockBodies(bodies []*blockBody) error { - return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies)) -} - -// SendBlockBodiesRLP sends a batch of block contents to the remote peer from -// an already RLP encoded format. -func (p *peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error { - return p2p.Send(p.rw, BlockBodiesMsg, bodies) -} - -// SendNodeDataRLP sends a batch of arbitrary internal data, corresponding to the -// hashes requested. -func (p *peer) SendNodeData(data [][]byte) error { - return p2p.Send(p.rw, NodeDataMsg, data) -} - -// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the -// ones requested from an already RLP encoded format. -func (p *peer) SendReceiptsRLP(receipts []rlp.RawValue) error { - return p2p.Send(p.rw, ReceiptsMsg, receipts) -} - -// RequestOneHeader is a wrapper around the header query functions to fetch a -// single header. It is used solely by the fetcher. -func (p *peer) RequestOneHeader(hash common.Hash) error { - p.Log().Debug("Fetching single header", "hash", hash) - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false}) -} - -// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the -// specified header query, based on the hash of an origin block. 
-func (p *peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error { - p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse) - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) -} - -// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the -// specified header query, based on the number of an origin block. -func (p *peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { - p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse) - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) -} - -// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes -// specified. -func (p *peer) RequestBodies(hashes []common.Hash) error { - p.Log().Debug("Fetching batch of block bodies", "count", len(hashes)) - return p2p.Send(p.rw, GetBlockBodiesMsg, hashes) -} - -// RequestNodeData fetches a batch of arbitrary data from a node's known state -// data, corresponding to the specified hashes. -func (p *peer) RequestNodeData(hashes []common.Hash) error { - p.Log().Debug("Fetching batch of state data", "count", len(hashes)) - return p2p.Send(p.rw, GetNodeDataMsg, hashes) -} - -// RequestReceipts fetches a batch of transaction receipts from a remote node. -func (p *peer) RequestReceipts(hashes []common.Hash) error { - p.Log().Debug("Fetching batch of receipts", "count", len(hashes)) - return p2p.Send(p.rw, GetReceiptsMsg, hashes) -} - -// RequestTxs fetches a batch of transactions from a remote node. -func (p *peer) RequestTxs(hashes []common.Hash) error { - p.Log().Debug("Fetching batch of transactions", "count", len(hashes)) - return p2p.Send(p.rw, GetPooledTransactionsMsg, hashes) -} - -// Handshake executes the eth protocol handshake, negotiating version number, -// network IDs, difficulties, head and genesis blocks. 
-func (p *peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) error { - // Send out own handshake in a new thread - errc := make(chan error, 2) - - var ( - status63 statusData63 // safe to read after two values have been received from errc - status statusData // safe to read after two values have been received from errc - ) - go func() { - switch { - case p.version == eth63: - errc <- p2p.Send(p.rw, StatusMsg, &statusData63{ - ProtocolVersion: uint32(p.version), - NetworkId: network, - TD: td, - CurrentBlock: head, - GenesisBlock: genesis, - }) - case p.version >= eth64: - errc <- p2p.Send(p.rw, StatusMsg, &statusData{ - ProtocolVersion: uint32(p.version), - NetworkID: network, - TD: td, - Head: head, - Genesis: genesis, - ForkID: forkID, - }) - default: - panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version)) - } - }() - go func() { - switch { - case p.version == eth63: - errc <- p.readStatusLegacy(network, &status63, genesis) - case p.version >= eth64: - errc <- p.readStatus(network, &status, genesis, forkFilter) - default: - panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version)) - } - }() - timeout := time.NewTimer(handshakeTimeout) - defer timeout.Stop() - for i := 0; i < 2; i++ { - select { - case err := <-errc: - if err != nil { - return err - } - case <-timeout.C: - return p2p.DiscReadTimeout - } - } - switch { - case p.version == eth63: - p.td, p.head = status63.TD, status63.CurrentBlock - case p.version >= eth64: - p.td, p.head = status.TD, status.Head - default: - panic(fmt.Sprintf("unsupported eth protocol version: %d", p.version)) - } - return nil -} - -func (p *peer) readStatusLegacy(network uint64, status *statusData63, genesis common.Hash) error { - msg, err := p.rw.ReadMsg() - if err != nil { - return err - } - if msg.Code != StatusMsg { - return errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) - } - if msg.Size > protocolMaxMsgSize { - return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize) - } - // Decode the handshake and make sure everything matches - if err := msg.Decode(&status); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - if status.GenesisBlock != genesis { - return errResp(ErrGenesisMismatch, "%x (!= %x)", status.GenesisBlock[:8], genesis[:8]) - } - if status.NetworkId != network { - return errResp(ErrNetworkIDMismatch, "%d (!= %d)", status.NetworkId, network) - } - if int(status.ProtocolVersion) != p.version { - return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version) - } - return nil -} - -func (p *peer) readStatus(network uint64, status *statusData, genesis common.Hash, forkFilter forkid.Filter) error { - msg, err := p.rw.ReadMsg() - if err != nil { - return err - } - if msg.Code != StatusMsg { - return errResp(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) - } - if msg.Size > protocolMaxMsgSize { - return errResp(ErrMsgTooLarge, "%v > %v", msg.Size, protocolMaxMsgSize) - } - // Decode the handshake and make sure everything matches - if err := msg.Decode(&status); err != nil { - return errResp(ErrDecode, "msg %v: %v", msg, err) - } - if status.NetworkID != network { - return errResp(ErrNetworkIDMismatch, "%d (!= %d)", status.NetworkID, network) - } - if int(status.ProtocolVersion) != p.version { - return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, p.version) - } - if status.Genesis != genesis { - 
return errResp(ErrGenesisMismatch, "%x (!= %x)", status.Genesis, genesis) - } - if err := forkFilter(status.ForkID); err != nil { - return errResp(ErrForkIDRejected, "%v", err) - } - return nil -} - -// String implements fmt.Stringer. -func (p *peer) String() string { - return fmt.Sprintf("Peer %s [%s]", p.id, - fmt.Sprintf("eth/%2d", p.version), - ) -} - -// peerSet represents the collection of active peers currently participating in -// the Ethereum sub-protocol. -type peerSet struct { - peers map[string]*peer - lock sync.RWMutex - closed bool -} - -// newPeerSet creates a new peer set to track the active participants. -func newPeerSet() *peerSet { - return &peerSet{ - peers: make(map[string]*peer), - } -} - -// Register injects a new peer into the working set, or returns an error if the -// peer is already known. If a new peer it registered, its broadcast loop is also -// started. -func (ps *peerSet) Register(p *peer, removePeer func(string)) error { - ps.lock.Lock() - defer ps.lock.Unlock() - - if ps.closed { - return errClosed - } - if _, ok := ps.peers[p.id]; ok { - return errAlreadyRegistered - } - ps.peers[p.id] = p - - go p.broadcastBlocks(removePeer) - go p.broadcastTransactions(removePeer) - if p.version >= eth65 { - go p.announceTransactions(removePeer) - } - return nil -} - -// Unregister removes a remote peer from the active set, disabling any further -// actions to/from that particular entity. -func (ps *peerSet) Unregister(id string) error { - ps.lock.Lock() - defer ps.lock.Unlock() - - p, ok := ps.peers[id] - if !ok { - return errNotRegistered - } - delete(ps.peers, id) - p.close() - - return nil -} - -// Peer retrieves the registered peer with the given id. -func (ps *peerSet) Peer(id string) *peer { - ps.lock.RLock() - defer ps.lock.RUnlock() - - return ps.peers[id] -} - -// Len returns if the current number of peers in the set. -func (ps *peerSet) Len() int { - ps.lock.RLock() - defer ps.lock.RUnlock() - - return len(ps.peers) -} - -// PeersWithoutBlock retrieves a list of peers that do not have a given block in -// their set of known hashes. -func (ps *peerSet) PeersWithoutBlock(hash common.Hash) []*peer { - ps.lock.RLock() - defer ps.lock.RUnlock() - - list := make([]*peer, 0, len(ps.peers)) - for _, p := range ps.peers { - if !p.knownBlocks.Contains(hash) { - list = append(list, p) - } - } - return list -} - -// PeersWithoutTx retrieves a list of peers that do not have a given transaction -// in their set of known hashes. -func (ps *peerSet) PeersWithoutTx(hash common.Hash) []*peer { - ps.lock.RLock() - defer ps.lock.RUnlock() - - list := make([]*peer, 0, len(ps.peers)) - for _, p := range ps.peers { - if !p.knownTxs.Contains(hash) { - list = append(list, p) - } - } - return list +// snapPeerInfo represents a short summary of the `snap` sub-protocol metadata known +// about a connected peer. +type snapPeerInfo struct { + Version uint `json:"version"` // Snapshot protocol version negotiated } -// BestPeer retrieves the known peer with the currently highest total difficulty. -func (ps *peerSet) BestPeer() *peer { - ps.lock.RLock() - defer ps.lock.RUnlock() +// snapPeer is a wrapper around snap.Peer to maintain a few extra metadata. 
+type snapPeer struct {
+	*snap.Peer
 
-	var (
-		bestPeer *peer
-		bestTd   *big.Int
-	)
-	for _, p := range ps.peers {
-		if _, td := p.Head(); bestPeer == nil || td.Cmp(bestTd) > 0 {
-			bestPeer, bestTd = p, td
-		}
-	}
-	return bestPeer
+	ethDrop *time.Timer  // Connection dropper if `eth` doesn't connect in time
+	lock    sync.RWMutex // Mutex protecting the internal fields
 }
 
-// Close disconnects all peers.
-// No new peers can be registered after Close has returned.
-func (ps *peerSet) Close() {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	for _, p := range ps.peers {
-		p.Disconnect(p2p.DiscQuitting)
+// info gathers and returns some `snap` protocol metadata known about a peer.
+func (p *snapPeer) info() *snapPeerInfo {
+	return &snapPeerInfo{
+		Version: p.Version(),
 	}
-	ps.closed = true
 }
diff --git a/eth/peerset.go b/eth/peerset.go
new file mode 100644
index 000000000000..9b584ec3203d
--- /dev/null
+++ b/eth/peerset.go
@@ -0,0 +1,301 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"errors"
+	"math/big"
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/eth/protocols/eth"
+	"github.com/ethereum/go-ethereum/eth/protocols/snap"
+	"github.com/ethereum/go-ethereum/event"
+	"github.com/ethereum/go-ethereum/p2p"
+)
+
+var (
+	// errPeerSetClosed is returned if a peer is attempted to be added or removed
+	// from the peer set after it has been terminated.
+	errPeerSetClosed = errors.New("peerset closed")
+
+	// errPeerAlreadyRegistered is returned if a peer is attempted to be added
+	// to the peer set, but one with the same id already exists.
+	errPeerAlreadyRegistered = errors.New("peer already registered")
+
+	// errPeerNotRegistered is returned if a peer is attempted to be removed from
+	// a peer set, but no peer with the given id exists.
+	errPeerNotRegistered = errors.New("peer not registered")
+
+	// ethConnectTimeout is the `snap` timeout for `eth` to connect to.
+	ethConnectTimeout = 3 * time.Second
+)
+
+// peerSet represents the collection of active peers currently participating in
+// the `eth` or `snap` protocols.
+type peerSet struct {
+	ethPeers  map[string]*ethPeer  // Peers connected on the `eth` protocol
+	snapPeers map[string]*snapPeer // Peers connected on the `snap` protocol
+
+	ethJoinFeed  event.Feed // Events when an `eth` peer successfully joins
+	ethDropFeed  event.Feed // Events when an `eth` peer gets dropped
+	snapJoinFeed event.Feed // Events when a `snap` peer joins on both `eth` and `snap`
+	snapDropFeed event.Feed // Events when a `snap` peer gets dropped (only if fully joined)
+
+	scope event.SubscriptionScope // Subscription group to unsubscribe everyone at once
+
+	lock   sync.RWMutex
+	closed bool
+}
+
+// newPeerSet creates a new peer set to track the active participants.
+func newPeerSet() *peerSet {
+	return &peerSet{
+		ethPeers:  make(map[string]*ethPeer),
+		snapPeers: make(map[string]*snapPeer),
+	}
+}
+
+// subscribeEthJoin registers a subscription for peers joining (and completing
+// the handshake) on the `eth` protocol.
+func (ps *peerSet) subscribeEthJoin(ch chan<- *eth.Peer) event.Subscription {
+	return ps.scope.Track(ps.ethJoinFeed.Subscribe(ch))
+}
+
+// subscribeEthDrop registers a subscription for peers being dropped from the
+// `eth` protocol.
+func (ps *peerSet) subscribeEthDrop(ch chan<- *eth.Peer) event.Subscription {
+	return ps.scope.Track(ps.ethDropFeed.Subscribe(ch))
+}
+
+// subscribeSnapJoin registers a subscription for peers joining (and completing
+// the `eth` join) on the `snap` protocol.
+func (ps *peerSet) subscribeSnapJoin(ch chan<- *snap.Peer) event.Subscription {
+	return ps.scope.Track(ps.snapJoinFeed.Subscribe(ch))
+}
+
+// subscribeSnapDrop registers a subscription for peers being dropped from the
+// `snap` protocol.
+func (ps *peerSet) subscribeSnapDrop(ch chan<- *snap.Peer) event.Subscription {
+	return ps.scope.Track(ps.snapDropFeed.Subscribe(ch))
+}
+
+// registerEthPeer injects a new `eth` peer into the working set, or returns an
+// error if the peer is already known. The peer is announced on the `eth` join
+// feed and, if it completes a pending `snap` peer, on the `snap` join feed too.
+func (ps *peerSet) registerEthPeer(peer *eth.Peer) error {
+	ps.lock.Lock()
+	if ps.closed {
+		ps.lock.Unlock()
+		return errPeerSetClosed
+	}
+	id := peer.ID()
+	if _, ok := ps.ethPeers[id]; ok {
+		ps.lock.Unlock()
+		return errPeerAlreadyRegistered
+	}
+	ps.ethPeers[id] = &ethPeer{Peer: peer}
+
+	snap, ok := ps.snapPeers[id]
+	ps.lock.Unlock()
+
+	if ok {
+		// Previously dangling `snap` peer, stop its timer since `eth` connected
+		snap.lock.Lock()
+		if snap.ethDrop != nil {
+			snap.ethDrop.Stop()
+			snap.ethDrop = nil
+		}
+		snap.lock.Unlock()
+	}
+	ps.ethJoinFeed.Send(peer)
+	if ok {
+		ps.snapJoinFeed.Send(snap.Peer)
+	}
+	return nil
+}
+
+// unregisterEthPeer removes a remote peer from the active set, disabling any further
+// actions to/from that particular entity. The drop is announced on the `eth` drop
+// feed and, if the peer was also registered on `snap`, on the `snap` drop feed too.
+func (ps *peerSet) unregisterEthPeer(id string) error {
+	ps.lock.Lock()
+	eth, ok := ps.ethPeers[id]
+	if !ok {
+		ps.lock.Unlock()
+		return errPeerNotRegistered
+	}
+	delete(ps.ethPeers, id)
+
+	snap, ok := ps.snapPeers[id]
+	ps.lock.Unlock()
+
+	ps.ethDropFeed.Send(eth.Peer)
+	if ok {
+		ps.snapDropFeed.Send(snap.Peer)
+	}
+	return nil
+}
+
+// registerSnapPeer injects a new `snap` peer into the working set, or returns
+// an error if the peer is already known. The peer is announced on the `snap`
+// join feed if it completes an existing `eth` peer.
+//
+// If the peer isn't yet connected on `eth` and fails to do so within a given
+// amount of time, it is dropped. This enforces that `snap` is an extension to
+// `eth`, not a standalone leeching protocol.
+func (ps *peerSet) registerSnapPeer(peer *snap.Peer) error {
+	ps.lock.Lock()
+	if ps.closed {
+		ps.lock.Unlock()
+		return errPeerSetClosed
+	}
+	id := peer.ID()
+	if _, ok := ps.snapPeers[id]; ok {
+		ps.lock.Unlock()
+		return errPeerAlreadyRegistered
+	}
+	ps.snapPeers[id] = &snapPeer{Peer: peer}
+
+	_, ok := ps.ethPeers[id]
+	if !ok {
+		// Dangling `snap` peer, start a timer to drop if `eth` doesn't connect
+		ps.snapPeers[id].ethDrop = time.AfterFunc(ethConnectTimeout, func() {
+			peer.Log().Warn("Snapshot peer missing eth, dropping", "addr", peer.RemoteAddr(), "type", peer.Name())
+			peer.Disconnect(p2p.DiscUselessPeer)
+		})
+	}
+	ps.lock.Unlock()
+
+	if ok {
+		ps.snapJoinFeed.Send(peer)
+	}
+	return nil
+}
+
+// unregisterSnapPeer removes a remote peer from the active set, disabling any
+// further actions to/from that particular entity. The drop is announced on the
+// `snap` drop feed.
+func (ps *peerSet) unregisterSnapPeer(id string) error {
+	ps.lock.Lock()
+	peer, ok := ps.snapPeers[id]
+	if !ok {
+		ps.lock.Unlock()
+		return errPeerNotRegistered
+	}
+	delete(ps.snapPeers, id)
+	ps.lock.Unlock()
+
+	peer.lock.Lock()
+	if peer.ethDrop != nil {
+		peer.ethDrop.Stop()
+		peer.ethDrop = nil
+	}
+	peer.lock.Unlock()
+
+	ps.snapDropFeed.Send(peer.Peer)
+	return nil
+}
+
+// ethPeer retrieves the registered `eth` peer with the given id.
+func (ps *peerSet) ethPeer(id string) *ethPeer {
+	ps.lock.RLock()
+	defer ps.lock.RUnlock()
+
+	return ps.ethPeers[id]
+}
+
+// snapPeer retrieves the registered `snap` peer with the given id.
+func (ps *peerSet) snapPeer(id string) *snapPeer {
+	ps.lock.RLock()
+	defer ps.lock.RUnlock()
+
+	return ps.snapPeers[id]
+}
+
+// ethPeersWithoutBlock retrieves a list of `eth` peers that do not have a given
+// block in their set of known hashes so it might be propagated to them.
+func (ps *peerSet) ethPeersWithoutBlock(hash common.Hash) []*ethPeer {
+	ps.lock.RLock()
+	defer ps.lock.RUnlock()
+
+	list := make([]*ethPeer, 0, len(ps.ethPeers))
+	for _, p := range ps.ethPeers {
+		if !p.KnownBlock(hash) {
+			list = append(list, p)
+		}
+	}
+	return list
+}
+
+// ethPeersWithoutTransacion retrieves a list of `eth` peers that do not have a
+// given transaction in their set of known hashes.
+func (ps *peerSet) ethPeersWithoutTransacion(hash common.Hash) []*ethPeer {
+	ps.lock.RLock()
+	defer ps.lock.RUnlock()
+
+	list := make([]*ethPeer, 0, len(ps.ethPeers))
+	for _, p := range ps.ethPeers {
+		if !p.KnownTransaction(hash) {
+			list = append(list, p)
+		}
+	}
+	return list
+}
+
+// Len returns the current number of `eth` peers in the set. Since the `snap`
+// peers are tied to the existence of an `eth` connection, that will always be a
+// subset of `eth`.
+func (ps *peerSet) Len() int {
+	ps.lock.RLock()
+	defer ps.lock.RUnlock()
+
+	return len(ps.ethPeers)
+}
+
+// ethPeerWithHighestTD retrieves the known peer with the currently highest total
+// difficulty.
+func (ps *peerSet) ethPeerWithHighestTD() *eth.Peer {
+	ps.lock.RLock()
+	defer ps.lock.RUnlock()
+
+	var (
+		bestPeer *eth.Peer
+		bestTd   *big.Int
+	)
+	for _, p := range ps.ethPeers {
+		if _, td := p.Head(); bestPeer == nil || td.Cmp(bestTd) > 0 {
+			bestPeer, bestTd = p.Peer, td
+		}
+	}
+	return bestPeer
+}
+
+// close disconnects all peers.
+func (ps *peerSet) close() { + ps.lock.Lock() + defer ps.lock.Unlock() + + for _, p := range ps.ethPeers { + p.Disconnect(p2p.DiscQuitting) + } + for _, p := range ps.snapPeers { + p.Disconnect(p2p.DiscQuitting) + } + ps.closed = true +} diff --git a/eth/protocol.go b/eth/protocol.go deleted file mode 100644 index dc75d6b31a76..000000000000 --- a/eth/protocol.go +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package eth - -import ( - "fmt" - "io" - "math/big" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/forkid" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/rlp" -) - -// Constants to match up protocol versions and messages -const ( - eth63 = 63 - eth64 = 64 - eth65 = 65 -) - -// protocolName is the official short name of the protocol used during capability negotiation. -const protocolName = "eth" - -// ProtocolVersions are the supported versions of the eth protocol (first is primary). -var ProtocolVersions = []uint{eth65, eth64, eth63} - -// protocolLengths are the number of implemented message corresponding to different protocol versions. -var protocolLengths = map[uint]uint64{eth65: 17, eth64: 17, eth63: 17} - -const protocolMaxMsgSize = 10 * 1024 * 1024 // Maximum cap on the size of a protocol message - -// eth protocol message codes -const ( - StatusMsg = 0x00 - NewBlockHashesMsg = 0x01 - TransactionMsg = 0x02 - GetBlockHeadersMsg = 0x03 - BlockHeadersMsg = 0x04 - GetBlockBodiesMsg = 0x05 - BlockBodiesMsg = 0x06 - NewBlockMsg = 0x07 - GetNodeDataMsg = 0x0d - NodeDataMsg = 0x0e - GetReceiptsMsg = 0x0f - ReceiptsMsg = 0x10 - - // New protocol message codes introduced in eth65 - // - // Previously these message ids were used by some legacy and unsupported - // eth protocols, reown them here. 
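Stepping back to the peer set defined above: a component could consume its join/drop feeds along these lines (a sketch under the assumption that the feeds carry *eth.Peer values, as the subscribe signatures indicate; the channel sizes and log messages are illustrative). Note how the dangling-`snap` dropper guarantees that by the time a `snap` join event fires, a matching `eth` peer is always registered.

// sketchWatchEthPeers tracks `eth` peer lifecycle events until quit is closed.
func sketchWatchEthPeers(ps *peerSet, quit chan struct{}) {
	join := make(chan *eth.Peer, 16)
	drop := make(chan *eth.Peer, 16)

	subJoin := ps.subscribeEthJoin(join)
	defer subJoin.Unsubscribe()

	subDrop := ps.subscribeEthDrop(drop)
	defer subDrop.Unsubscribe()

	for {
		select {
		case p := <-join:
			p.Log().Debug("Peer joined eth", "id", p.ID())
		case p := <-drop:
			p.Log().Debug("Peer dropped from eth", "id", p.ID())
		case <-quit:
			return
		}
	}
}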
- NewPooledTransactionHashesMsg = 0x08 - GetPooledTransactionsMsg = 0x09 - PooledTransactionsMsg = 0x0a -) - -type errCode int - -const ( - ErrMsgTooLarge = iota - ErrDecode - ErrInvalidMsgCode - ErrProtocolVersionMismatch - ErrNetworkIDMismatch - ErrGenesisMismatch - ErrForkIDRejected - ErrNoStatusMsg - ErrExtraStatusMsg -) - -func (e errCode) String() string { - return errorToString[int(e)] -} - -// XXX change once legacy code is out -var errorToString = map[int]string{ - ErrMsgTooLarge: "Message too long", - ErrDecode: "Invalid message", - ErrInvalidMsgCode: "Invalid message code", - ErrProtocolVersionMismatch: "Protocol version mismatch", - ErrNetworkIDMismatch: "Network ID mismatch", - ErrGenesisMismatch: "Genesis mismatch", - ErrForkIDRejected: "Fork ID rejected", - ErrNoStatusMsg: "No status message", - ErrExtraStatusMsg: "Extra status message", -} - -type txPool interface { - // Has returns an indicator whether txpool has a transaction - // cached with the given hash. - Has(hash common.Hash) bool - - // Get retrieves the transaction from local txpool with given - // tx hash. - Get(hash common.Hash) *types.Transaction - - // AddRemotes should add the given transactions to the pool. - AddRemotes([]*types.Transaction) []error - - // Pending should return pending transactions. - // The slice should be modifiable by the caller. - Pending() (map[common.Address]types.Transactions, error) - - // SubscribeNewTxsEvent should return an event subscription of - // NewTxsEvent and send events to the given channel. - SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription -} - -// statusData63 is the network packet for the status message for eth/63. -type statusData63 struct { - ProtocolVersion uint32 - NetworkId uint64 - TD *big.Int - CurrentBlock common.Hash - GenesisBlock common.Hash -} - -// statusData is the network packet for the status message for eth/64 and later. -type statusData struct { - ProtocolVersion uint32 - NetworkID uint64 - TD *big.Int - Head common.Hash - Genesis common.Hash - ForkID forkid.ID -} - -// newBlockHashesData is the network packet for the block announcements. -type newBlockHashesData []struct { - Hash common.Hash // Hash of one particular block being announced - Number uint64 // Number of one particular block being announced -} - -// getBlockHeadersData represents a block header query. -type getBlockHeadersData struct { - Origin hashOrNumber // Block from which to retrieve headers - Amount uint64 // Maximum number of headers to retrieve - Skip uint64 // Blocks to skip between consecutive headers - Reverse bool // Query direction (false = rising towards latest, true = falling towards genesis) -} - -// hashOrNumber is a combined field for specifying an origin block. -type hashOrNumber struct { - Hash common.Hash // Block hash from which to retrieve headers (excludes Number) - Number uint64 // Block hash from which to retrieve headers (excludes Hash) -} - -// EncodeRLP is a specialized encoder for hashOrNumber to encode only one of the -// two contained union fields. -func (hn *hashOrNumber) EncodeRLP(w io.Writer) error { - if hn.Hash == (common.Hash{}) { - return rlp.Encode(w, hn.Number) - } - if hn.Number != 0 { - return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number) - } - return rlp.Encode(w, hn.Hash) -} - -// DecodeRLP is a specialized decoder for hashOrNumber to decode the contents -// into either a block hash or a block number. 
-func (hn *hashOrNumber) DecodeRLP(s *rlp.Stream) error { - _, size, _ := s.Kind() - origin, err := s.Raw() - if err == nil { - switch { - case size == 32: - err = rlp.DecodeBytes(origin, &hn.Hash) - case size <= 8: - err = rlp.DecodeBytes(origin, &hn.Number) - default: - err = fmt.Errorf("invalid input size %d for origin", size) - } - } - return err -} - -// newBlockData is the network packet for the block propagation message. -type newBlockData struct { - Block *types.Block - TD *big.Int -} - -// sanityCheck verifies that the values are reasonable, as a DoS protection -func (request *newBlockData) sanityCheck() error { - if err := request.Block.SanityCheck(); err != nil { - return err - } - //TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times - // larger, it will still fit within 100 bits - if tdlen := request.TD.BitLen(); tdlen > 100 { - return fmt.Errorf("too large block TD: bitlen %d", tdlen) - } - return nil -} - -// blockBody represents the data content of a single block. -type blockBody struct { - Transactions []*types.Transaction // Transactions contained within a block - Uncles []*types.Header // Uncles contained within a block -} - -// blockBodiesData is the network packet for block content distribution. -type blockBodiesData []*blockBody diff --git a/eth/protocol_test.go b/eth/protocol_test.go deleted file mode 100644 index fc916a2263c9..000000000000 --- a/eth/protocol_test.go +++ /dev/null @@ -1,459 +0,0 @@ -// Copyright 2014 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package eth - -import ( - "fmt" - "math/big" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/consensus/ethash" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/forkid" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/core/vm" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/eth/downloader" - "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/params" - "github.com/ethereum/go-ethereum/rlp" -) - -func init() { - // log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) -} - -var testAccount, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") - -// Tests that handshake failures are detected and reported correctly. 
-func TestStatusMsgErrors63(t *testing.T) { - pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) - var ( - genesis = pm.blockchain.Genesis() - head = pm.blockchain.CurrentHeader() - td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) - ) - defer pm.Stop() - - tests := []struct { - code uint64 - data interface{} - wantError error - }{ - { - code: TransactionMsg, data: []interface{}{}, - wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"), - }, - { - code: StatusMsg, data: statusData63{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash()}, - wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", 63), - }, - { - code: StatusMsg, data: statusData63{63, 999, td, head.Hash(), genesis.Hash()}, - wantError: errResp(ErrNetworkIDMismatch, "999 (!= %d)", DefaultConfig.NetworkId), - }, - { - code: StatusMsg, data: statusData63{63, DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}}, - wantError: errResp(ErrGenesisMismatch, "0300000000000000 (!= %x)", genesis.Hash().Bytes()[:8]), - }, - } - for i, test := range tests { - p, errc := newTestPeer("peer", 63, pm, false) - // The send call might hang until reset because - // the protocol might not read the payload. - go p2p.Send(p.app, test.code, test.data) - - select { - case err := <-errc: - if err == nil { - t.Errorf("test %d: protocol returned nil error, want %q", i, test.wantError) - } else if err.Error() != test.wantError.Error() { - t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.wantError) - } - case <-time.After(2 * time.Second): - t.Errorf("protocol did not shut down within 2 seconds") - } - p.close() - } -} - -func TestStatusMsgErrors64(t *testing.T) { - pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) - var ( - genesis = pm.blockchain.Genesis() - head = pm.blockchain.CurrentHeader() - td = pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) - forkID = forkid.NewID(pm.blockchain) - ) - defer pm.Stop() - - tests := []struct { - code uint64 - data interface{} - wantError error - }{ - { - code: TransactionMsg, data: []interface{}{}, - wantError: errResp(ErrNoStatusMsg, "first msg has code 2 (!= 0)"), - }, - { - code: StatusMsg, data: statusData{10, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash(), forkID}, - wantError: errResp(ErrProtocolVersionMismatch, "10 (!= %d)", 64), - }, - { - code: StatusMsg, data: statusData{64, 999, td, head.Hash(), genesis.Hash(), forkID}, - wantError: errResp(ErrNetworkIDMismatch, "999 (!= %d)", DefaultConfig.NetworkId), - }, - { - code: StatusMsg, data: statusData{64, DefaultConfig.NetworkId, td, head.Hash(), common.Hash{3}, forkID}, - wantError: errResp(ErrGenesisMismatch, "0300000000000000000000000000000000000000000000000000000000000000 (!= %x)", genesis.Hash()), - }, - { - code: StatusMsg, data: statusData{64, DefaultConfig.NetworkId, td, head.Hash(), genesis.Hash(), forkid.ID{Hash: [4]byte{0x00, 0x01, 0x02, 0x03}}}, - wantError: errResp(ErrForkIDRejected, forkid.ErrLocalIncompatibleOrStale.Error()), - }, - } - for i, test := range tests { - p, errc := newTestPeer("peer", 64, pm, false) - // The send call might hang until reset because - // the protocol might not read the payload. 
- go p2p.Send(p.app, test.code, test.data) - - select { - case err := <-errc: - if err == nil { - t.Errorf("test %d: protocol returned nil error, want %q", i, test.wantError) - } else if err.Error() != test.wantError.Error() { - t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.wantError) - } - case <-time.After(2 * time.Second): - t.Errorf("protocol did not shut down within 2 seconds") - } - p.close() - } -} - -func TestForkIDSplit(t *testing.T) { - var ( - engine = ethash.NewFaker() - - configNoFork = ¶ms.ChainConfig{HomesteadBlock: big.NewInt(1)} - configProFork = ¶ms.ChainConfig{ - HomesteadBlock: big.NewInt(1), - EIP150Block: big.NewInt(2), - EIP155Block: big.NewInt(2), - EIP158Block: big.NewInt(2), - ByzantiumBlock: big.NewInt(3), - } - dbNoFork = rawdb.NewMemoryDatabase() - dbProFork = rawdb.NewMemoryDatabase() - - gspecNoFork = &core.Genesis{Config: configNoFork} - gspecProFork = &core.Genesis{Config: configProFork} - - genesisNoFork = gspecNoFork.MustCommit(dbNoFork) - genesisProFork = gspecProFork.MustCommit(dbProFork) - - chainNoFork, _ = core.NewBlockChain(dbNoFork, nil, configNoFork, engine, vm.Config{}, nil, nil) - chainProFork, _ = core.NewBlockChain(dbProFork, nil, configProFork, engine, vm.Config{}, nil, nil) - - blocksNoFork, _ = core.GenerateChain(configNoFork, genesisNoFork, engine, dbNoFork, 2, nil) - blocksProFork, _ = core.GenerateChain(configProFork, genesisProFork, engine, dbProFork, 2, nil) - - ethNoFork, _ = NewProtocolManager(configNoFork, nil, downloader.FullSync, 1, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, engine, chainNoFork, dbNoFork, 1, nil) - ethProFork, _ = NewProtocolManager(configProFork, nil, downloader.FullSync, 1, new(event.TypeMux), &testTxPool{pool: make(map[common.Hash]*types.Transaction)}, engine, chainProFork, dbProFork, 1, nil) - ) - ethNoFork.Start(1000) - ethProFork.Start(1000) - - // Both nodes should allow the other to connect (same genesis, next fork is the same) - p2pNoFork, p2pProFork := p2p.MsgPipe() - peerNoFork := newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil) - peerProFork := newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil) - - errc := make(chan error, 2) - go func() { errc <- ethNoFork.handle(peerProFork) }() - go func() { errc <- ethProFork.handle(peerNoFork) }() - - select { - case err := <-errc: - t.Fatalf("frontier nofork <-> profork failed: %v", err) - case <-time.After(250 * time.Millisecond): - p2pNoFork.Close() - p2pProFork.Close() - } - // Progress into Homestead. Fork's match, so we don't care what the future holds - chainNoFork.InsertChain(blocksNoFork[:1]) - chainProFork.InsertChain(blocksProFork[:1]) - - p2pNoFork, p2pProFork = p2p.MsgPipe() - peerNoFork = newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil) - peerProFork = newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil) - - errc = make(chan error, 2) - go func() { errc <- ethNoFork.handle(peerProFork) }() - go func() { errc <- ethProFork.handle(peerNoFork) }() - - select { - case err := <-errc: - t.Fatalf("homestead nofork <-> profork failed: %v", err) - case <-time.After(250 * time.Millisecond): - p2pNoFork.Close() - p2pProFork.Close() - } - // Progress into Spurious. 
Forks mismatch, signalling differing chains, reject - chainNoFork.InsertChain(blocksNoFork[1:2]) - chainProFork.InsertChain(blocksProFork[1:2]) - - p2pNoFork, p2pProFork = p2p.MsgPipe() - peerNoFork = newPeer(64, p2p.NewPeer(enode.ID{1}, "", nil), p2pNoFork, nil) - peerProFork = newPeer(64, p2p.NewPeer(enode.ID{2}, "", nil), p2pProFork, nil) - - errc = make(chan error, 2) - go func() { errc <- ethNoFork.handle(peerProFork) }() - go func() { errc <- ethProFork.handle(peerNoFork) }() - - select { - case err := <-errc: - if want := errResp(ErrForkIDRejected, forkid.ErrLocalIncompatibleOrStale.Error()); err.Error() != want.Error() { - t.Fatalf("fork ID rejection error mismatch: have %v, want %v", err, want) - } - case <-time.After(250 * time.Millisecond): - t.Fatalf("split peers not rejected") - } -} - -// This test checks that received transactions are added to the local pool. -func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) } -func TestRecvTransactions64(t *testing.T) { testRecvTransactions(t, 64) } -func TestRecvTransactions65(t *testing.T) { testRecvTransactions(t, 65) } - -func testRecvTransactions(t *testing.T, protocol int) { - txAdded := make(chan []*types.Transaction) - pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded) - pm.acceptTxs = 1 // mark synced to accept transactions - p, _ := newTestPeer("peer", protocol, pm, true) - defer pm.Stop() - defer p.close() - - tx := newTestTransaction(testAccount, 0, 0) - if err := p2p.Send(p.app, TransactionMsg, []interface{}{tx}); err != nil { - t.Fatalf("send error: %v", err) - } - select { - case added := <-txAdded: - if len(added) != 1 { - t.Errorf("wrong number of added transactions: got %d, want 1", len(added)) - } else if added[0].Hash() != tx.Hash() { - t.Errorf("added wrong tx hash: got %v, want %v", added[0].Hash(), tx.Hash()) - } - case <-time.After(2 * time.Second): - t.Errorf("no NewTxsEvent received within 2 seconds") - } -} - -// This test checks that pending transactions are sent. -func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) } -func TestSendTransactions64(t *testing.T) { testSendTransactions(t, 64) } -func TestSendTransactions65(t *testing.T) { testSendTransactions(t, 65) } - -func testSendTransactions(t *testing.T, protocol int) { - pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) - defer pm.Stop() - - // Fill the pool with big transactions (use a subscription to wait until all - // the transactions are announced to avoid spurious events causing extra - // broadcasts). - const txsize = txsyncPackSize / 10 - alltxs := make([]*types.Transaction, 100) - for nonce := range alltxs { - alltxs[nonce] = newTestTransaction(testAccount, uint64(nonce), txsize) - } - pm.txpool.AddRemotes(alltxs) - time.Sleep(100 * time.Millisecond) // Wait until new tx even gets out of the system (lame) - - // Connect several peers. They should all receive the pending transactions. 
- var wg sync.WaitGroup - checktxs := func(p *testPeer) { - defer wg.Done() - defer p.close() - seen := make(map[common.Hash]bool) - for _, tx := range alltxs { - seen[tx.Hash()] = false - } - for n := 0; n < len(alltxs) && !t.Failed(); { - var forAllHashes func(callback func(hash common.Hash)) - switch protocol { - case 63: - fallthrough - case 64: - msg, err := p.app.ReadMsg() - if err != nil { - t.Errorf("%v: read error: %v", p.Peer, err) - continue - } else if msg.Code != TransactionMsg { - t.Errorf("%v: got code %d, want TxMsg", p.Peer, msg.Code) - continue - } - var txs []*types.Transaction - if err := msg.Decode(&txs); err != nil { - t.Errorf("%v: %v", p.Peer, err) - continue - } - forAllHashes = func(callback func(hash common.Hash)) { - for _, tx := range txs { - callback(tx.Hash()) - } - } - case 65: - msg, err := p.app.ReadMsg() - if err != nil { - t.Errorf("%v: read error: %v", p.Peer, err) - continue - } else if msg.Code != NewPooledTransactionHashesMsg { - t.Errorf("%v: got code %d, want NewPooledTransactionHashesMsg", p.Peer, msg.Code) - continue - } - var hashes []common.Hash - if err := msg.Decode(&hashes); err != nil { - t.Errorf("%v: %v", p.Peer, err) - continue - } - forAllHashes = func(callback func(hash common.Hash)) { - for _, h := range hashes { - callback(h) - } - } - } - forAllHashes(func(hash common.Hash) { - seentx, want := seen[hash] - if seentx { - t.Errorf("%v: got tx more than once: %x", p.Peer, hash) - } - if !want { - t.Errorf("%v: got unexpected tx: %x", p.Peer, hash) - } - seen[hash] = true - n++ - }) - } - } - for i := 0; i < 3; i++ { - p, _ := newTestPeer(fmt.Sprintf("peer #%d", i), protocol, pm, true) - wg.Add(1) - go checktxs(p) - } - wg.Wait() -} - -func TestTransactionPropagation(t *testing.T) { testSyncTransaction(t, true) } -func TestTransactionAnnouncement(t *testing.T) { testSyncTransaction(t, false) } - -func testSyncTransaction(t *testing.T, propagtion bool) { - // Create a protocol manager for transaction fetcher and sender - pmFetcher, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil) - defer pmFetcher.Stop() - pmSender, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil) - pmSender.broadcastTxAnnouncesOnly = !propagtion - defer pmSender.Stop() - - // Sync up the two peers - io1, io2 := p2p.MsgPipe() - - go pmSender.handle(pmSender.newPeer(65, p2p.NewPeer(enode.ID{}, "sender", nil), io2, pmSender.txpool.Get)) - go pmFetcher.handle(pmFetcher.newPeer(65, p2p.NewPeer(enode.ID{}, "fetcher", nil), io1, pmFetcher.txpool.Get)) - - time.Sleep(250 * time.Millisecond) - pmFetcher.doSync(peerToSyncOp(downloader.FullSync, pmFetcher.peers.BestPeer())) - atomic.StoreUint32(&pmFetcher.acceptTxs, 1) - - newTxs := make(chan core.NewTxsEvent, 1024) - sub := pmFetcher.txpool.SubscribeNewTxsEvent(newTxs) - defer sub.Unsubscribe() - - // Fill the pool with new transactions - alltxs := make([]*types.Transaction, 1024) - for nonce := range alltxs { - alltxs[nonce] = newTestTransaction(testAccount, uint64(nonce), 0) - } - pmSender.txpool.AddRemotes(alltxs) - - var got int -loop: - for { - select { - case ev := <-newTxs: - got += len(ev.Txs) - if got == 1024 { - break loop - } - case <-time.NewTimer(time.Second).C: - t.Fatal("Failed to retrieve all transaction") - } - } -} - -// Tests that the custom union field encoder and decoder works correctly. 
-func TestGetBlockHeadersDataEncodeDecode(t *testing.T) {
-	// Create a "random" hash for testing
-	var hash common.Hash
-	for i := range hash {
-		hash[i] = byte(i)
-	}
-	// Assemble some table driven tests
-	tests := []struct {
-		packet *getBlockHeadersData
-		fail   bool
-	}{
-		// Providing the origin as either a hash or a number should both work
-		{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}}},
-		{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}}},
-
-		// Providing arbitrary query field should also work
-		{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Number: 314}, Amount: 314, Skip: 1, Reverse: true}},
-		{fail: false, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: 314, Skip: 1, Reverse: true}},
-
-		// Providing both the origin hash and origin number must fail
-		{fail: true, packet: &getBlockHeadersData{Origin: hashOrNumber{Hash: hash, Number: 314}}},
-	}
-	// Iterate over each of the tests and try to encode and then decode
-	for i, tt := range tests {
-		bytes, err := rlp.EncodeToBytes(tt.packet)
-		if err != nil && !tt.fail {
-			t.Fatalf("test %d: failed to encode packet: %v", i, err)
-		} else if err == nil && tt.fail {
-			t.Fatalf("test %d: encode should have failed", i)
-		}
-		if !tt.fail {
-			packet := new(getBlockHeadersData)
-			if err := rlp.DecodeBytes(bytes, packet); err != nil {
-				t.Fatalf("test %d: failed to decode packet: %v", i, err)
-			}
-			if packet.Origin.Hash != tt.packet.Origin.Hash || packet.Origin.Number != tt.packet.Origin.Number || packet.Amount != tt.packet.Amount ||
-				packet.Skip != tt.packet.Skip || packet.Reverse != tt.packet.Reverse {
-				t.Fatalf("test %d: encode decode mismatch: have %+v, want %+v", i, packet, tt.packet)
-			}
-		}
-	}
-}
diff --git a/eth/protocols/eth/broadcast.go b/eth/protocols/eth/broadcast.go
new file mode 100644
index 000000000000..2349398fae68
--- /dev/null
+++ b/eth/protocols/eth/broadcast.go
@@ -0,0 +1,195 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"math/big"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/types"
+)
+
+const (
+	// This is the target size for the packs of transactions or announcements. A
+	// pack can get larger than this if a single transaction exceeds this size.
+	maxTxPacketSize = 100 * 1024
+)
+
+// blockPropagation is a block propagation event, waiting for its turn in the
+// broadcast queue.
+type blockPropagation struct {
+	block *types.Block
+	td    *big.Int
+}
+
+// broadcastBlocks is a write loop that multiplexes blocks and block announcements
+// to the remote peer. The goal is to have an async writer that does not lock up
+// node internals and at the same time rate limits queued data.
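The write loops defined below (broadcastBlocks, broadcastTransactions, announceTransactions) all share one shape: a select loop that keeps accepting new work while at most one background send is in flight, so a slow peer never blocks the producers. A distilled, standalone sketch of that pattern, with error handling elided and illustrative names not taken from the diff:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	var (
		queue []string      // pending items, accumulated while a send runs
		done  chan struct{} // non-nil while a background send is in flight
		work  = make(chan string)
		term  = make(chan struct{})
	)
	go func() {
		work <- "tx-1"
		work <- "tx-2"
		time.Sleep(50 * time.Millisecond)
		close(term)
	}()
	for {
		// Kick off at most one async writer; new work keeps queueing meanwhile.
		if done == nil && len(queue) > 0 {
			batch := queue
			queue = nil
			done = make(chan struct{})
			go func() {
				fmt.Println("sending", batch) // stand-in for the real network write
				close(done)
			}()
		}
		select {
		case item := <-work:
			queue = append(queue, item)
		case <-done: // receiving on a nil channel blocks forever, so this case
			done = nil // stays disabled until a writer is actually running
		case <-term:
			return
		}
	}
}
```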
+func (p *Peer) broadcastBlocks() { + for { + select { + case prop := <-p.queuedBlocks: + if err := p.SendNewBlock(prop.block, prop.td); err != nil { + return + } + p.Log().Trace("Propagated block", "number", prop.block.Number(), "hash", prop.block.Hash(), "td", prop.td) + + case block := <-p.queuedBlockAnns: + if err := p.SendNewBlockHashes([]common.Hash{block.Hash()}, []uint64{block.NumberU64()}); err != nil { + return + } + p.Log().Trace("Announced block", "number", block.Number(), "hash", block.Hash()) + + case <-p.term: + return + } + } +} + +// broadcastTransactions is a write loop that schedules transaction broadcasts +// to the remote peer. The goal is to have an async writer that does not lock up +// node internals and at the same time rate limits queued data. +func (p *Peer) broadcastTransactions() { + var ( + queue []common.Hash // Queue of hashes to broadcast as full transactions + done chan struct{} // Non-nil if background broadcaster is running + fail = make(chan error, 1) // Channel used to receive network error + failed bool // Flag whether a send failed, discard everything onward + ) + for { + // If there's no in-flight broadcast running, check if a new one is needed + if done == nil && len(queue) > 0 { + // Pile transaction until we reach our allowed network limit + var ( + hashes []common.Hash + txs []*types.Transaction + size common.StorageSize + ) + for i := 0; i < len(queue) && size < maxTxPacketSize; i++ { + if tx := p.txpool.Get(queue[i]); tx != nil { + txs = append(txs, tx) + size += tx.Size() + } + hashes = append(hashes, queue[i]) + } + queue = queue[:copy(queue, queue[len(hashes):])] + + // If there's anything available to transfer, fire up an async writer + if len(txs) > 0 { + done = make(chan struct{}) + go func() { + if err := p.SendTransactions(txs); err != nil { + fail <- err + return + } + close(done) + p.Log().Trace("Sent transactions", "count", len(txs)) + }() + } + } + // Transfer goroutine may or may not have been started, listen for events + select { + case hashes := <-p.txBroadcast: + // If the connection failed, discard all transaction events + if failed { + continue + } + // New batch of transactions to be broadcast, queue them (with cap) + queue = append(queue, hashes...) + if len(queue) > maxQueuedTxs { + // Fancy copy and resize to ensure buffer doesn't grow indefinitely + queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxs:])] + } + + case <-done: + done = nil + + case <-fail: + failed = true + + case <-p.term: + return + } + } +} + +// announceTransactions is a write loop that schedules transaction broadcasts +// to the remote peer. The goal is to have an async writer that does not lock up +// node internals and at the same time rate limits queued data. 
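Both broadcastTransactions above and announceTransactions below cap their backlog with the same "fancy copy and resize": copy the newest max elements to the front of the backing array and reslice, dropping the oldest entries without allocating. A minimal sketch (trimOldest is a hypothetical helper, not part of the diff):

```go
package main

import "fmt"

// trimOldest keeps only the newest max elements of queue, reusing the
// backing array instead of allocating a new one.
func trimOldest(queue []int, max int) []int {
	if len(queue) > max {
		queue = queue[:copy(queue, queue[len(queue)-max:])]
	}
	return queue
}

func main() {
	fmt.Println(trimOldest([]int{1, 2, 3, 4, 5}, 3)) // [3 4 5]
}
```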
+func (p *Peer) announceTransactions() {
+	var (
+		queue  []common.Hash         // Queue of hashes to announce as transaction stubs
+		done   chan struct{}         // Non-nil if background announcer is running
+		fail   = make(chan error, 1) // Channel used to receive network error
+		failed bool                  // Flag whether a send failed, discard everything onward
+	)
+	for {
+		// If there's no in-flight announce running, check if a new one is needed
+		if done == nil && len(queue) > 0 {
+			// Pile transaction hashes until we reach our allowed network limit
+			var (
+				hashes  []common.Hash
+				pending []common.Hash
+				size    common.StorageSize
+			)
+			for i := 0; i < len(queue) && size < maxTxPacketSize; i++ {
+				if p.txpool.Get(queue[i]) != nil {
+					pending = append(pending, queue[i])
+					size += common.HashLength
+				}
+				hashes = append(hashes, queue[i])
+			}
+			queue = queue[:copy(queue, queue[len(hashes):])]
+
+			// If there's anything available to transfer, fire up an async writer
+			if len(pending) > 0 {
+				done = make(chan struct{})
+				go func() {
+					if err := p.sendPooledTransactionHashes(pending); err != nil {
+						fail <- err
+						return
+					}
+					close(done)
+					p.Log().Trace("Sent transaction announcements", "count", len(pending))
+				}()
+			}
+		}
+		// Transfer goroutine may or may not have been started, listen for events
+		select {
+		case hashes := <-p.txAnnounce:
+			// If the connection failed, discard all transaction events
+			if failed {
+				continue
+			}
+			// New batch of transactions to be announced, queue them (with cap)
+			queue = append(queue, hashes...)
+			if len(queue) > maxQueuedTxAnns {
+				// Fancy copy and resize to ensure buffer doesn't grow indefinitely
+				queue = queue[:copy(queue, queue[len(queue)-maxQueuedTxAnns:])]
+			}
+
+		case <-done:
+			done = nil
+
+		case <-fail:
+			failed = true
+
+		case <-p.term:
+			return
+		}
+	}
+}
diff --git a/eth/protocols/eth/discovery.go b/eth/protocols/eth/discovery.go
new file mode 100644
index 000000000000..025479b423ef
--- /dev/null
+++ b/eth/protocols/eth/discovery.go
@@ -0,0 +1,65 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/forkid"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// enrEntry is the ENR entry which advertises the `eth` protocol on the discovery.
+type enrEntry struct {
+	ForkID forkid.ID // Fork identifier per EIP-2124
+
+	// Ignore additional fields (for forward compatibility).
+	Rest []rlp.RawValue `rlp:"tail"`
+}
+
+// ENRKey implements enr.Entry.
+func (e enrEntry) ENRKey() string {
+	return "eth"
+}
+
+// StartENRUpdater starts the `eth` ENR updater loop, which listens for chain
+// head events and updates the requested node record whenever a fork is passed.
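Any type with an ENRKey method can be stored in a node record, which is how enrEntry above gets advertised and how the updater that follows refreshes it on fork transitions. A minimal sketch against the p2p/enr package (demoEntry is a hypothetical entry, not the one in the diff):

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/enr"
)

// demoEntry mirrors the shape of enrEntry: a plain struct plus an ENRKey
// method is all the enr package needs.
type demoEntry struct {
	Counter uint
}

func (demoEntry) ENRKey() string { return "demo" }

func main() {
	var r enr.Record
	r.Set(demoEntry{Counter: 42}) // store the entry under its key

	var got demoEntry
	if err := r.Load(&got); err != nil { // read it back out
		panic(err)
	}
	fmt.Println(got.Counter) // 42
}
```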
+func StartENRUpdater(chain *core.BlockChain, ln *enode.LocalNode) { + var newHead = make(chan core.ChainHeadEvent, 10) + sub := chain.SubscribeChainHeadEvent(newHead) + + go func() { + defer sub.Unsubscribe() + for { + select { + case <-newHead: + ln.Set(currentENREntry(chain)) + case <-sub.Err(): + // Would be nice to sync with Stop, but there is no + // good way to do that. + return + } + } + }() +} + +// currentENREntry constructs an `eth` ENR entry based on the current state of the chain. +func currentENREntry(chain *core.BlockChain) *enrEntry { + return &enrEntry{ + ForkID: forkid.NewID(chain.Config(), chain.Genesis().Hash(), chain.CurrentHeader().Number.Uint64()), + } +} diff --git a/eth/protocols/eth/handler.go b/eth/protocols/eth/handler.go new file mode 100644 index 000000000000..25ddcd93ecfc --- /dev/null +++ b/eth/protocols/eth/handler.go @@ -0,0 +1,512 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "encoding/json" + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/trie" +) + +const ( + // softResponseLimit is the target maximum size of replies to data retrievals. + softResponseLimit = 2 * 1024 * 1024 + + // estHeaderSize is the approximate size of an RLP encoded block header. + estHeaderSize = 500 + + // maxHeadersServe is the maximum number of block headers to serve. This number + // is there to limit the number of disk lookups. + maxHeadersServe = 1024 + + // maxBodiesServe is the maximum number of block bodies to serve. This number + // is mostly there to limit the number of disk lookups. With 24KB block sizes + // nowadays, the practical limit will always be softResponseLimit. + maxBodiesServe = 1024 + + // maxNodeDataServe is the maximum number of state trie nodes to serve. This + // number is there to limit the number of disk lookups. + maxNodeDataServe = 1024 + + // maxReceiptsServe is the maximum number of block receipts to serve. This + // number is mostly there to limit the number of disk lookups. With block + // containing 200+ transactions nowadays, the practical limit will always + // be softResponseLimit. + maxReceiptsServe = 1024 +) + +// Handler is a callback to invoke from an outside runner after the boilerplate +// exchanges have passed. +type Handler func(peer *Peer) error + +// Backend defines the data retrieval methods to serve remote requests and the +// callback methods to invoke on remote deliveries. 
+type Backend interface {
+	// Chain retrieves the blockchain object to serve data.
+	Chain() *core.BlockChain
+
+	// StateBloom retrieves the bloom filter - if any - for state trie nodes.
+	StateBloom() *trie.SyncBloom
+
+	// TxPool retrieves the transaction pool object to serve data.
+	TxPool() TxPool
+
+	// AcceptTxs retrieves whether transaction processing is enabled on the node
+	// or if inbound transactions should simply be dropped.
+	AcceptTxs() bool
+
+	// RunPeer is invoked when a peer joins on the `eth` protocol. The handler
+	// should do any peer maintenance work, handshakes and validations. If all
+	// is passed, control should be given back to the `handler` to process the
+	// inbound messages going forward.
+	RunPeer(peer *Peer, handler Handler) error
+
+	// PeerInfo retrieves all known `eth` information about a peer.
+	PeerInfo(id enode.ID) interface{}
+
+	// Handle is a callback to be invoked when a data packet is received from
+	// the remote peer. Only packets not consumed by the protocol handler will
+	// be forwarded to the backend.
+	Handle(peer *Peer, packet Packet) error
+}
+
+// TxPool defines the methods needed by the protocol handler to serve transactions.
+type TxPool interface {
+	// Get retrieves the transaction from the local txpool with the given hash.
+	Get(hash common.Hash) *types.Transaction
+}
+
+// MakeProtocols constructs the P2P protocol definitions for `eth`.
+func MakeProtocols(backend Backend, network uint64, dnsdisc enode.Iterator) []p2p.Protocol {
+	protocols := make([]p2p.Protocol, len(protocolVersions))
+	for i, version := range protocolVersions {
+		version := version // Closure
+
+		protocols[i] = p2p.Protocol{
+			Name:    protocolName,
+			Version: version,
+			Length:  protocolLengths[version],
+			Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+				peer := NewPeer(version, p, rw, backend.TxPool())
+				defer peer.Close()
+
+				return backend.RunPeer(peer, func(peer *Peer) error {
+					return Handle(backend, peer)
+				})
+			},
+			NodeInfo: func() interface{} {
+				return nodeInfo(backend.Chain(), network)
+			},
+			PeerInfo: func(id enode.ID) interface{} {
+				return backend.PeerInfo(id)
+			},
+			Attributes:     []enr.Entry{currentENREntry(backend.Chain())},
+			DialCandidates: dnsdisc,
+		}
+	}
+	return protocols
+}
+
+// NodeInfo represents a short summary of the `eth` sub-protocol metadata
+// known about the host peer.
+type NodeInfo struct {
+	Network    uint64              `json:"network"`    // Ethereum network ID (1=Frontier, 2=Morden, 3=Ropsten, 4=Rinkeby)
+	Difficulty *big.Int            `json:"difficulty"` // Total difficulty of the host's blockchain
+	Genesis    common.Hash         `json:"genesis"`    // SHA3 hash of the host's genesis block
+	Config     *params.ChainConfig `json:"config"`     // Chain configuration for the fork rules
+	Head       common.Hash         `json:"head"`       // Hex hash of the host's best owned block
+}
+
+// nodeInfo retrieves some `eth` protocol metadata about the running host node.
+func nodeInfo(chain *core.BlockChain, network uint64) *NodeInfo {
+	head := chain.CurrentBlock()
+	return &NodeInfo{
+		Network:    network,
+		Difficulty: chain.GetTd(head.Hash(), head.NumberU64()),
+		Genesis:    chain.Genesis().Hash(),
+		Config:     chain.Config(),
+		Head:       head.Hash(),
+	}
+}
+
+// Handle is invoked whenever an `eth` connection is made that successfully passes
+// the protocol handshake. This method will keep processing messages until the
+// connection is torn down.
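Before the message loop below, one subtle line in MakeProtocols above deserves a note: `version := version // Closure`. Under pre-Go 1.22 for-loop semantics (which this codebase targets), all closures created in a loop share the same loop variable, so without the rebind every Run callback would observe the final version. A standalone illustration:

```go
package main

import "fmt"

func main() {
	versions := []uint{65, 64, 63}

	var handlers []func() uint
	for _, version := range versions {
		version := version // without this rebind, every closure below would
		// capture the same variable and report the last loop value (63)
		handlers = append(handlers, func() uint { return version })
	}
	for _, h := range handlers {
		fmt.Println(h()) // 65, 64, 63
	}
}
```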
+func Handle(backend Backend, peer *Peer) error { + for { + if err := handleMessage(backend, peer); err != nil { + peer.Log().Debug("Message handling failed in `eth`", "err", err) + return err + } + } +} + +// handleMessage is invoked whenever an inbound message is received from a remote +// peer. The remote connection is torn down upon returning any error. +func handleMessage(backend Backend, peer *Peer) error { + // Read the next message from the remote peer, and ensure it's fully consumed + msg, err := peer.rw.ReadMsg() + if err != nil { + return err + } + if msg.Size > maxMessageSize { + return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize) + } + defer msg.Discard() + + // Handle the message depending on its contents + switch { + case msg.Code == StatusMsg: + // Status messages should never arrive after the handshake + return fmt.Errorf("%w: uncontrolled status message", errExtraStatusMsg) + + // Block header query, collect the requested headers and reply + case msg.Code == GetBlockHeadersMsg: + // Decode the complex header query + var query GetBlockHeadersPacket + if err := msg.Decode(&query); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + hashMode := query.Origin.Hash != (common.Hash{}) + first := true + maxNonCanonical := uint64(100) + + // Gather headers until the fetch or network limits is reached + var ( + bytes common.StorageSize + headers []*types.Header + unknown bool + lookups int + ) + for !unknown && len(headers) < int(query.Amount) && bytes < softResponseLimit && + len(headers) < maxHeadersServe && lookups < 2*maxHeadersServe { + lookups++ + // Retrieve the next header satisfying the query + var origin *types.Header + if hashMode { + if first { + first = false + origin = backend.Chain().GetHeaderByHash(query.Origin.Hash) + if origin != nil { + query.Origin.Number = origin.Number.Uint64() + } + } else { + origin = backend.Chain().GetHeader(query.Origin.Hash, query.Origin.Number) + } + } else { + origin = backend.Chain().GetHeaderByNumber(query.Origin.Number) + } + if origin == nil { + break + } + headers = append(headers, origin) + bytes += estHeaderSize + + // Advance to the next header of the query + switch { + case hashMode && query.Reverse: + // Hash based traversal towards the genesis block + ancestor := query.Skip + 1 + if ancestor == 0 { + unknown = true + } else { + query.Origin.Hash, query.Origin.Number = backend.Chain().GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical) + unknown = (query.Origin.Hash == common.Hash{}) + } + case hashMode && !query.Reverse: + // Hash based traversal towards the leaf block + var ( + current = origin.Number.Uint64() + next = current + query.Skip + 1 + ) + if next <= current { + infos, _ := json.MarshalIndent(peer.Peer.Info(), "", " ") + peer.Log().Warn("GetBlockHeaders skip overflow attack", "current", current, "skip", query.Skip, "next", next, "attacker", infos) + unknown = true + } else { + if header := backend.Chain().GetHeaderByNumber(next); header != nil { + nextHash := header.Hash() + expOldHash, _ := backend.Chain().GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical) + if expOldHash == query.Origin.Hash { + query.Origin.Hash, query.Origin.Number = nextHash, next + } else { + unknown = true + } + } else { + unknown = true + } + } + case query.Reverse: + // Number based traversal towards the genesis block + if query.Origin.Number >= query.Skip+1 { + query.Origin.Number -= query.Skip + 1 + } else { + unknown = true + } + + case 
!query.Reverse: + // Number based traversal towards the leaf block + query.Origin.Number += query.Skip + 1 + } + } + return peer.SendBlockHeaders(headers) + + case msg.Code == BlockHeadersMsg: + // A batch of headers arrived to one of our previous requests + res := new(BlockHeadersPacket) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + return backend.Handle(peer, res) + + case msg.Code == GetBlockBodiesMsg: + // Decode the block body retrieval message + var query GetBlockBodiesPacket + if err := msg.Decode(&query); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Gather blocks until the fetch or network limits is reached + var ( + bytes int + bodies []rlp.RawValue + ) + for lookups, hash := range query { + if bytes >= softResponseLimit || len(bodies) >= maxBodiesServe || + lookups >= 2*maxBodiesServe { + break + } + if data := backend.Chain().GetBodyRLP(hash); len(data) != 0 { + bodies = append(bodies, data) + bytes += len(data) + } + } + return peer.SendBlockBodiesRLP(bodies) + + case msg.Code == BlockBodiesMsg: + // A batch of block bodies arrived to one of our previous requests + res := new(BlockBodiesPacket) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + return backend.Handle(peer, res) + + case msg.Code == GetNodeDataMsg: + // Decode the trie node data retrieval message + var query GetNodeDataPacket + if err := msg.Decode(&query); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + bytes int + nodes [][]byte + ) + for lookups, hash := range query { + if bytes >= softResponseLimit || len(nodes) >= maxNodeDataServe || + lookups >= 2*maxNodeDataServe { + break + } + // Retrieve the requested state entry + if bloom := backend.StateBloom(); bloom != nil && !bloom.Contains(hash[:]) { + // Only lookup the trie node if there's chance that we actually have it + continue + } + entry, err := backend.Chain().TrieNode(hash) + if len(entry) == 0 || err != nil { + // Read the contract code with prefix only to save unnecessary lookups. 
+ entry, err = backend.Chain().ContractCodeWithPrefix(hash) + } + if err == nil && len(entry) > 0 { + nodes = append(nodes, entry) + bytes += len(entry) + } + } + return peer.SendNodeData(nodes) + + case msg.Code == NodeDataMsg: + // A batch of node state data arrived to one of our previous requests + res := new(NodeDataPacket) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + return backend.Handle(peer, res) + + case msg.Code == GetReceiptsMsg: + // Decode the block receipts retrieval message + var query GetReceiptsPacket + if err := msg.Decode(&query); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Gather state data until the fetch or network limits is reached + var ( + bytes int + receipts []rlp.RawValue + ) + for lookups, hash := range query { + if bytes >= softResponseLimit || len(receipts) >= maxReceiptsServe || + lookups >= 2*maxReceiptsServe { + break + } + // Retrieve the requested block's receipts + results := backend.Chain().GetReceiptsByHash(hash) + if results == nil { + if header := backend.Chain().GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash { + continue + } + } + // If known, encode and queue for response packet + if encoded, err := rlp.EncodeToBytes(results); err != nil { + log.Error("Failed to encode receipt", "err", err) + } else { + receipts = append(receipts, encoded) + bytes += len(encoded) + } + } + return peer.SendReceiptsRLP(receipts) + + case msg.Code == ReceiptsMsg: + // A batch of receipts arrived to one of our previous requests + res := new(ReceiptsPacket) + if err := msg.Decode(res); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + return backend.Handle(peer, res) + + case msg.Code == NewBlockHashesMsg: + // A batch of new block announcements just arrived + ann := new(NewBlockHashesPacket) + if err := msg.Decode(ann); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Mark the hashes as present at the remote node + for _, block := range *ann { + peer.markBlock(block.Hash) + } + // Deliver them all to the backend for queuing + return backend.Handle(peer, ann) + + case msg.Code == NewBlockMsg: + // Retrieve and decode the propagated block + ann := new(NewBlockPacket) + if err := msg.Decode(ann); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + if hash := types.CalcUncleHash(ann.Block.Uncles()); hash != ann.Block.UncleHash() { + log.Warn("Propagated block has invalid uncles", "have", hash, "exp", ann.Block.UncleHash()) + break // TODO(karalabe): return error eventually, but wait a few releases + } + if hash := types.DeriveSha(ann.Block.Transactions(), trie.NewStackTrie(nil)); hash != ann.Block.TxHash() { + log.Warn("Propagated block has invalid body", "have", hash, "exp", ann.Block.TxHash()) + break // TODO(karalabe): return error eventually, but wait a few releases + } + if err := ann.sanityCheck(); err != nil { + return err + } + ann.Block.ReceivedAt = msg.ReceivedAt + ann.Block.ReceivedFrom = peer + + // Mark the peer as owning the block + peer.markBlock(ann.Block.Hash()) + + return backend.Handle(peer, ann) + + case msg.Code == NewPooledTransactionHashesMsg && peer.version >= ETH65: + // New transaction announcement arrived, make sure we have + // a valid and fresh chain to handle them + if !backend.AcceptTxs() { + break + } + ann := new(NewPooledTransactionHashesPacket) + if err := msg.Decode(ann); err != nil { + return 
fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Schedule all the unknown hashes for retrieval + for _, hash := range *ann { + peer.markTransaction(hash) + } + return backend.Handle(peer, ann) + + case msg.Code == GetPooledTransactionsMsg && peer.version >= ETH65: + // Decode the pooled transactions retrieval message + var query GetPooledTransactionsPacket + if err := msg.Decode(&query); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + // Gather transactions until the fetch or network limits is reached + var ( + bytes int + hashes []common.Hash + txs []rlp.RawValue + ) + for _, hash := range query { + if bytes >= softResponseLimit { + break + } + // Retrieve the requested transaction, skipping if unknown to us + tx := backend.TxPool().Get(hash) + if tx == nil { + continue + } + // If known, encode and queue for response packet + if encoded, err := rlp.EncodeToBytes(tx); err != nil { + log.Error("Failed to encode transaction", "err", err) + } else { + hashes = append(hashes, hash) + txs = append(txs, encoded) + bytes += len(encoded) + } + } + return peer.SendPooledTransactionsRLP(hashes, txs) + + case msg.Code == TransactionsMsg || (msg.Code == PooledTransactionsMsg && peer.version >= ETH65): + // Transactions arrived, make sure we have a valid and fresh chain to handle them + if !backend.AcceptTxs() { + break + } + // Transactions can be processed, parse all of them and deliver to the pool + var txs []*types.Transaction + if err := msg.Decode(&txs); err != nil { + return fmt.Errorf("%w: message %v: %v", errDecode, msg, err) + } + for i, tx := range txs { + // Validate and mark the remote transaction + if tx == nil { + return fmt.Errorf("%w: transaction %d is nil", errDecode, i) + } + peer.markTransaction(tx.Hash()) + } + if msg.Code == PooledTransactionsMsg { + return backend.Handle(peer, (*PooledTransactionsPacket)(&txs)) + } + return backend.Handle(peer, (*TransactionsPacket)(&txs)) + + default: + return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code) + } + return nil +} diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go new file mode 100644 index 000000000000..30beae931b65 --- /dev/null +++ b/eth/protocols/eth/handler_test.go @@ -0,0 +1,519 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
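Before the new test file: throughout handleMessage above, decode and protocol failures are wrapped with `%w`, so callers can classify an error with `errors.Is` while keeping the per-message detail in the string. A minimal sketch of that convention (the sentinel name mirrors the one used in the diff):

```go
package main

import (
	"errors"
	"fmt"
)

var errDecode = errors.New("invalid message")

func main() {
	// Wrap the sentinel with message-specific context, as handleMessage does.
	err := fmt.Errorf("%w: message %v: %v", errDecode, "GetBlockHeadersMsg", "rlp: too few elements")

	fmt.Println(errors.Is(err, errDecode)) // true: category still testable
	fmt.Println(err)                       // full detail preserved
}
```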
+
+package eth
+
+import (
+	"math"
+	"math/big"
+	"math/rand"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus/ethash"
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/state"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/core/vm"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/params"
+	"github.com/ethereum/go-ethereum/trie"
+)
+
+var (
+	// testKey is a private key to use for funding a tester account.
+	testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
+
+	// testAddr is the Ethereum address of the tester account.
+	testAddr = crypto.PubkeyToAddress(testKey.PublicKey)
+)
+
+// testBackend is a mock implementation of the live Ethereum message handler. Its
+// purpose is to allow testing the request/reply workflows and wire serialization
+// in the `eth` protocol without actually doing any data processing.
+type testBackend struct {
+	db     ethdb.Database
+	chain  *core.BlockChain
+	txpool *core.TxPool
+}
+
+// newTestBackend creates an empty chain and wraps it into a mock backend.
+func newTestBackend(blocks int) *testBackend {
+	return newTestBackendWithGenerator(blocks, nil)
+}
+
+// newTestBackendWithGenerator creates a chain with a number of explicitly defined
+// blocks and wraps it into a mock backend.
+func newTestBackendWithGenerator(blocks int, generator func(int, *core.BlockGen)) *testBackend {
+	// Create a database pre-initialized with a genesis block
+	db := rawdb.NewMemoryDatabase()
+	(&core.Genesis{
+		Config: params.TestChainConfig,
+		Alloc:  core.GenesisAlloc{testAddr: {Balance: big.NewInt(1000000)}},
+	}).MustCommit(db)
+
+	chain, _ := core.NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{}, nil, nil)
+
+	bs, _ := core.GenerateChain(params.TestChainConfig, chain.Genesis(), ethash.NewFaker(), db, blocks, generator)
+	if _, err := chain.InsertChain(bs); err != nil {
+		panic(err)
+	}
+	txconfig := core.DefaultTxPoolConfig
+	txconfig.Journal = "" // Don't litter the disk with test journals
+
+	return &testBackend{
+		db:     db,
+		chain:  chain,
+		txpool: core.NewTxPool(txconfig, params.TestChainConfig, chain),
+	}
+}
+
+// close tears down the transaction pool and chain behind the mock backend.
+func (b *testBackend) close() {
+	b.txpool.Stop()
+	b.chain.Stop()
+}
+
+func (b *testBackend) Chain() *core.BlockChain     { return b.chain }
+func (b *testBackend) StateBloom() *trie.SyncBloom { return nil }
+func (b *testBackend) TxPool() TxPool              { return b.txpool }
+
+func (b *testBackend) RunPeer(peer *Peer, handler Handler) error {
+	// Normally the backend would do peer maintenance and handshakes. All that
+	// is omitted here and we just give control back to the handler.
+	return handler(peer)
+}
+func (b *testBackend) PeerInfo(enode.ID) interface{} { panic("not implemented") }
+
+func (b *testBackend) AcceptTxs() bool {
+	panic("data processing tests should be done in the handler package")
+}
+func (b *testBackend) Handle(*Peer, Packet) error {
+	panic("data processing tests should be done in the handler package")
+}
+
+// Tests that block headers can be retrieved from a remote chain based on user queries.
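The query table in the test below leans on the skip-list semantics of GetBlockHeaders: starting from the origin, every (Skip+1)-th block is selected, in either direction. A hypothetical helper showing which block numbers a query resolves to (size limits and canonical-chain checks omitted):

```go
package main

import "fmt"

// headerNumbers sketches the block numbers a GetBlockHeaders query selects.
func headerNumbers(origin, amount, skip uint64, reverse bool) []uint64 {
	var out []uint64
	step := skip + 1
	for n, i := origin, uint64(0); i < amount; i++ {
		out = append(out, n)
		if reverse {
			if n < step {
				break // would run past the genesis block
			}
			n -= step
		} else {
			n += step
		}
	}
	return out
}

func main() {
	fmt.Println(headerNumbers(10, 3, 3, false)) // [10 14 18]
	fmt.Println(headerNumbers(10, 3, 3, true))  // [10 6 2]
}
```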
+func TestGetBlockHeaders64(t *testing.T) { testGetBlockHeaders(t, 64) } +func TestGetBlockHeaders65(t *testing.T) { testGetBlockHeaders(t, 65) } + +func testGetBlockHeaders(t *testing.T, protocol uint) { + t.Parallel() + + backend := newTestBackend(maxHeadersServe + 15) + defer backend.close() + + peer, _ := newTestPeer("peer", protocol, backend) + defer peer.close() + + // Create a "random" unknown hash for testing + var unknown common.Hash + for i := range unknown { + unknown[i] = byte(i) + } + // Create a batch of tests for various scenarios + limit := uint64(maxHeadersServe) + tests := []struct { + query *GetBlockHeadersPacket // The query to execute for header retrieval + expect []common.Hash // The hashes of the block whose headers are expected + }{ + // A single random block should be retrievable by hash and number too + { + &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(limit / 2).Hash()}, Amount: 1}, + []common.Hash{backend.chain.GetBlockByNumber(limit / 2).Hash()}, + }, { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 1}, + []common.Hash{backend.chain.GetBlockByNumber(limit / 2).Hash()}, + }, + // Multiple headers should be retrievable in both directions + { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 3}, + []common.Hash{ + backend.chain.GetBlockByNumber(limit / 2).Hash(), + backend.chain.GetBlockByNumber(limit/2 + 1).Hash(), + backend.chain.GetBlockByNumber(limit/2 + 2).Hash(), + }, + }, { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Amount: 3, Reverse: true}, + []common.Hash{ + backend.chain.GetBlockByNumber(limit / 2).Hash(), + backend.chain.GetBlockByNumber(limit/2 - 1).Hash(), + backend.chain.GetBlockByNumber(limit/2 - 2).Hash(), + }, + }, + // Multiple headers with skip lists should be retrievable + { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3}, + []common.Hash{ + backend.chain.GetBlockByNumber(limit / 2).Hash(), + backend.chain.GetBlockByNumber(limit/2 + 4).Hash(), + backend.chain.GetBlockByNumber(limit/2 + 8).Hash(), + }, + }, { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: limit / 2}, Skip: 3, Amount: 3, Reverse: true}, + []common.Hash{ + backend.chain.GetBlockByNumber(limit / 2).Hash(), + backend.chain.GetBlockByNumber(limit/2 - 4).Hash(), + backend.chain.GetBlockByNumber(limit/2 - 8).Hash(), + }, + }, + // The chain endpoints should be retrievable + { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 0}, Amount: 1}, + []common.Hash{backend.chain.GetBlockByNumber(0).Hash()}, + }, { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64()}, Amount: 1}, + []common.Hash{backend.chain.CurrentBlock().Hash()}, + }, + // Ensure protocol limits are honored + { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 1}, Amount: limit + 10, Reverse: true}, + backend.chain.GetBlockHashesFromHash(backend.chain.CurrentBlock().Hash(), limit), + }, + // Check that requesting more than available is handled gracefully + { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 4}, Skip: 3, Amount: 3}, + []common.Hash{ + backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 4).Hash(), + backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64()).Hash(), + }, + }, { + &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 4}, Skip: 3, Amount: 3, Reverse: true}, + []common.Hash{ 
+				backend.chain.GetBlockByNumber(4).Hash(),
+				backend.chain.GetBlockByNumber(0).Hash(),
+			},
+		},
+		// Check that requesting more than available is handled gracefully, even if mid skip
+		{
+			&GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() - 4}, Skip: 2, Amount: 3},
+			[]common.Hash{
+				backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 4).Hash(),
+				backend.chain.GetBlockByNumber(backend.chain.CurrentBlock().NumberU64() - 1).Hash(),
+			},
+		}, {
+			&GetBlockHeadersPacket{Origin: HashOrNumber{Number: 4}, Skip: 2, Amount: 3, Reverse: true},
+			[]common.Hash{
+				backend.chain.GetBlockByNumber(4).Hash(),
+				backend.chain.GetBlockByNumber(1).Hash(),
+			},
+		},
+		// Check a corner case where requesting more can iterate past the endpoints
+		{
+			&GetBlockHeadersPacket{Origin: HashOrNumber{Number: 2}, Amount: 5, Reverse: true},
+			[]common.Hash{
+				backend.chain.GetBlockByNumber(2).Hash(),
+				backend.chain.GetBlockByNumber(1).Hash(),
+				backend.chain.GetBlockByNumber(0).Hash(),
+			},
+		},
+		// Check a corner case where skipping overflow loops back into the chain start
+		{
+			&GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(3).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64 - 1},
+			[]common.Hash{
+				backend.chain.GetBlockByNumber(3).Hash(),
+			},
+		},
+		// Check a corner case where skipping overflow loops back to the same header
+		{
+			&GetBlockHeadersPacket{Origin: HashOrNumber{Hash: backend.chain.GetBlockByNumber(1).Hash()}, Amount: 2, Reverse: false, Skip: math.MaxUint64},
+			[]common.Hash{
+				backend.chain.GetBlockByNumber(1).Hash(),
+			},
+		},
+		// Check that non-existing headers aren't returned
+		{
+			&GetBlockHeadersPacket{Origin: HashOrNumber{Hash: unknown}, Amount: 1},
+			[]common.Hash{},
+		}, {
+			&GetBlockHeadersPacket{Origin: HashOrNumber{Number: backend.chain.CurrentBlock().NumberU64() + 1}, Amount: 1},
+			[]common.Hash{},
+		},
+	}
+	// Run each of the tests and verify the results against the chain
+	for i, tt := range tests {
+		// Collect the headers to expect in the response
+		var headers []*types.Header
+		for _, hash := range tt.expect {
+			headers = append(headers, backend.chain.GetBlockByHash(hash).Header())
+		}
+		// Send the hash request and verify the response
+		p2p.Send(peer.app, GetBlockHeadersMsg, tt.query)
+		if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, headers); err != nil {
+			t.Errorf("test %d: headers mismatch: %v", i, err)
+		}
+		// If the test used number origins, repeat the checks with hashes too
+		if tt.query.Origin.Hash == (common.Hash{}) {
+			if origin := backend.chain.GetBlockByNumber(tt.query.Origin.Number); origin != nil {
+				tt.query.Origin.Hash, tt.query.Origin.Number = origin.Hash(), 0
+
+				p2p.Send(peer.app, GetBlockHeadersMsg, tt.query)
+				if err := p2p.ExpectMsg(peer.app, BlockHeadersMsg, headers); err != nil {
+					t.Errorf("test %d: headers mismatch: %v", i, err)
+				}
+			}
+		}
+	}
+}
+
+// Tests that block contents can be retrieved from a remote chain based on their hashes.
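These handler tests drive the protocol through an in-memory p2p.MsgPipe: one end Sends a request, the other asserts the reply with p2p.ExpectMsg, which compares message codes and RLP-encoded payloads. A self-contained sketch of that request/assert pattern:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p"
)

func main() {
	// MsgPipe is synchronous, so the send has to run in its own goroutine.
	a, b := p2p.MsgPipe()
	defer a.Close()
	defer b.Close()

	go p2p.Send(a, 0x03, []uint64{314})

	// ExpectMsg fails if the code or the RLP-encoded content differs.
	if err := p2p.ExpectMsg(b, 0x03, []uint64{314}); err != nil {
		fmt.Println("mismatch:", err)
		return
	}
	fmt.Println("message matched")
}
```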
+func TestGetBlockBodies64(t *testing.T) { testGetBlockBodies(t, 64) }
+func TestGetBlockBodies65(t *testing.T) { testGetBlockBodies(t, 65) }
+
+func testGetBlockBodies(t *testing.T, protocol uint) {
+	t.Parallel()
+
+	backend := newTestBackend(maxBodiesServe + 15)
+	defer backend.close()
+
+	peer, _ := newTestPeer("peer", protocol, backend)
+	defer peer.close()
+
+	// Create a batch of tests for various scenarios
+	limit := maxBodiesServe
+	tests := []struct {
+		random    int           // Number of blocks to fetch randomly from the chain
+		explicit  []common.Hash // Explicitly requested blocks
+		available []bool        // Availability of explicitly requested blocks
+		expected  int           // Total number of existing blocks to expect
+	}{
+		{1, nil, nil, 1},             // A single random block should be retrievable
+		{10, nil, nil, 10},           // Multiple random blocks should be retrievable
+		{limit, nil, nil, limit},     // The maximum possible blocks should be retrievable
+		{limit + 1, nil, nil, limit}, // No more than the possible block count should be returned
+		{0, []common.Hash{backend.chain.Genesis().Hash()}, []bool{true}, 1},      // The genesis block should be retrievable
+		{0, []common.Hash{backend.chain.CurrentBlock().Hash()}, []bool{true}, 1}, // The chain's head block should be retrievable
+		{0, []common.Hash{{}}, []bool{false}, 0},                                 // A non-existent block should not be returned
+
+		// Existing and non-existing blocks interleaved should not cause problems
+		{0, []common.Hash{
+			{},
+			backend.chain.GetBlockByNumber(1).Hash(),
+			{},
+			backend.chain.GetBlockByNumber(10).Hash(),
+			{},
+			backend.chain.GetBlockByNumber(100).Hash(),
+			{},
+		}, []bool{false, true, false, true, false, true, false}, 3},
+	}
+	// Run each of the tests and verify the results against the chain
+	for i, tt := range tests {
+		// Collect the hashes to request, and the response to expect
+		var (
+			hashes []common.Hash
+			bodies []*BlockBody
+			seen   = make(map[int64]bool)
+		)
+		for j := 0; j < tt.random; j++ {
+			for {
+				num := rand.Int63n(int64(backend.chain.CurrentBlock().NumberU64()))
+				if !seen[num] {
+					seen[num] = true
+
+					block := backend.chain.GetBlockByNumber(uint64(num))
+					hashes = append(hashes, block.Hash())
+					if len(bodies) < tt.expected {
+						bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
+					}
+					break
+				}
+			}
+		}
+		for j, hash := range tt.explicit {
+			hashes = append(hashes, hash)
+			if tt.available[j] && len(bodies) < tt.expected {
+				block := backend.chain.GetBlockByHash(hash)
+				bodies = append(bodies, &BlockBody{Transactions: block.Transactions(), Uncles: block.Uncles()})
+			}
+		}
+		// Send the hash request and verify the response
+		p2p.Send(peer.app, GetBlockBodiesMsg, hashes)
+		if err := p2p.ExpectMsg(peer.app, BlockBodiesMsg, bodies); err != nil {
+			t.Errorf("test %d: bodies mismatch: %v", i, err)
+		}
+	}
+}
+
+// Tests that the state trie nodes can be retrieved based on hashes.
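GetNodeData is verify-by-construction: a trie node's database key is the Keccak256 hash of its RLP payload, so the test that follows can validate every returned blob simply by re-hashing it. A tiny sketch (the node bytes here are a stand-in, not real trie data):

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	// Pretend this blob came back in a NodeDataMsg response...
	resp := []byte{0xc2, 0x01, 0x02} // stand-in for a raw trie node
	// ...for a request keyed by the node's hash.
	want := crypto.Keccak256Hash([]byte{0xc2, 0x01, 0x02})

	// A response entry is valid iff hashing it reproduces the requested key.
	if have := crypto.Keccak256Hash(resp); have != want {
		fmt.Printf("data hash mismatch: have %x, want %x\n", have, want)
		return
	}
	fmt.Println("node verified against its request key")
}
```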
+func TestGetNodeData64(t *testing.T) { testGetNodeData(t, 64) } +func TestGetNodeData65(t *testing.T) { testGetNodeData(t, 65) } + +func testGetNodeData(t *testing.T, protocol uint) { + t.Parallel() + + // Define three accounts to simulate transactions with + acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a") + acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee") + acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey) + acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey) + + signer := types.HomesteadSigner{} + // Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test) + generator := func(i int, block *core.BlockGen) { + switch i { + case 0: + // In block 1, the test bank sends account #1 some ether. + tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testKey) + block.AddTx(tx) + case 1: + // In block 2, the test bank sends some more ether to account #1. + // acc1Addr passes it on to account #2. + tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testKey) + tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key) + block.AddTx(tx1) + block.AddTx(tx2) + case 2: + // Block 3 is empty but was mined by account #2. + block.SetCoinbase(acc2Addr) + block.SetExtra([]byte("yeehaw")) + case 3: + // Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data). + b2 := block.PrevBlock(1).Header() + b2.Extra = []byte("foo") + block.AddUncle(b2) + b3 := block.PrevBlock(2).Header() + b3.Extra = []byte("foo") + block.AddUncle(b3) + } + } + // Assemble the test environment + backend := newTestBackendWithGenerator(4, generator) + defer backend.close() + + peer, _ := newTestPeer("peer", protocol, backend) + defer peer.close() + + // Fetch for now the entire chain db + var hashes []common.Hash + + it := backend.db.NewIterator(nil, nil) + for it.Next() { + if key := it.Key(); len(key) == common.HashLength { + hashes = append(hashes, common.BytesToHash(key)) + } + } + it.Release() + + p2p.Send(peer.app, GetNodeDataMsg, hashes) + msg, err := peer.app.ReadMsg() + if err != nil { + t.Fatalf("failed to read node data response: %v", err) + } + if msg.Code != NodeDataMsg { + t.Fatalf("response packet code mismatch: have %x, want %x", msg.Code, NodeDataMsg) + } + var data [][]byte + if err := msg.Decode(&data); err != nil { + t.Fatalf("failed to decode response node data: %v", err) + } + // Verify that all hashes correspond to the requested data, and reconstruct a state tree + for i, want := range hashes { + if hash := crypto.Keccak256Hash(data[i]); hash != want { + t.Errorf("data hash mismatch: have %x, want %x", hash, want) + } + } + statedb := rawdb.NewMemoryDatabase() + for i := 0; i < len(data); i++ { + statedb.Put(hashes[i].Bytes(), data[i]) + } + accounts := []common.Address{testAddr, acc1Addr, acc2Addr} + for i := uint64(0); i <= backend.chain.CurrentBlock().NumberU64(); i++ { + trie, _ := state.New(backend.chain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb), nil) + + for j, acc := range accounts { + state, _ := backend.chain.State() + bw := state.GetBalance(acc) + bh := trie.GetBalance(acc) + + if (bw != nil && bh == nil) || (bw == nil && bh != nil) { + t.Errorf("test %d, account %d: balance mismatch: have 
%v, want %v", i, j, bh, bw)
+			}
+			if bw != nil && bh != nil && bw.Cmp(bh) != 0 {
+				t.Errorf("test %d, account %d: balance mismatch: have %v, want %v", i, j, bh, bw)
+			}
+		}
+	}
+}
+
+// Tests that the transaction receipts can be retrieved based on hashes.
+func TestGetBlockReceipts64(t *testing.T) { testGetBlockReceipts(t, 64) }
+func TestGetBlockReceipts65(t *testing.T) { testGetBlockReceipts(t, 65) }
+
+func testGetBlockReceipts(t *testing.T, protocol uint) {
+	t.Parallel()
+
+	// Define three accounts to simulate transactions with
+	acc1Key, _ := crypto.HexToECDSA("8a1f9a8f95be41cd7ccb6168179afb4504aefe388d1e14474d32c45c72ce7b7a")
+	acc2Key, _ := crypto.HexToECDSA("49a7b37aa6f6645917e7b807e9d1c00d4fa71f18343b0d4122a4d2df64dd6fee")
+	acc1Addr := crypto.PubkeyToAddress(acc1Key.PublicKey)
+	acc2Addr := crypto.PubkeyToAddress(acc2Key.PublicKey)
+
+	signer := types.HomesteadSigner{}
+	// Create a chain generator with some simple transactions (blatantly stolen from @fjl/chain_markets_test)
+	generator := func(i int, block *core.BlockGen) {
+		switch i {
+		case 0:
+			// In block 1, the test bank sends account #1 some ether.
+			tx, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(10000), params.TxGas, nil, nil), signer, testKey)
+			block.AddTx(tx)
+		case 1:
+			// In block 2, the test bank sends some more ether to account #1.
+			// acc1Addr passes it on to account #2.
+			tx1, _ := types.SignTx(types.NewTransaction(block.TxNonce(testAddr), acc1Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, testKey)
+			tx2, _ := types.SignTx(types.NewTransaction(block.TxNonce(acc1Addr), acc2Addr, big.NewInt(1000), params.TxGas, nil, nil), signer, acc1Key)
+			block.AddTx(tx1)
+			block.AddTx(tx2)
+		case 2:
+			// Block 3 is empty but was mined by account #2.
+			block.SetCoinbase(acc2Addr)
+			block.SetExtra([]byte("yeehaw"))
+		case 3:
+			// Block 4 includes blocks 2 and 3 as uncle headers (with modified extra data).
+			b2 := block.PrevBlock(1).Header()
+			b2.Extra = []byte("foo")
+			block.AddUncle(b2)
+			b3 := block.PrevBlock(2).Header()
+			b3.Extra = []byte("foo")
+			block.AddUncle(b3)
+		}
+	}
+	// Assemble the test environment
+	backend := newTestBackendWithGenerator(4, generator)
+	defer backend.close()
+
+	peer, _ := newTestPeer("peer", protocol, backend)
+	defer peer.close()
+
+	// Collect the hashes to request, and the response to expect
+	var (
+		hashes   []common.Hash
+		receipts []types.Receipts
+	)
+	for i := uint64(0); i <= backend.chain.CurrentBlock().NumberU64(); i++ {
+		block := backend.chain.GetBlockByNumber(i)
+
+		hashes = append(hashes, block.Hash())
+		receipts = append(receipts, backend.chain.GetReceiptsByHash(block.Hash()))
+	}
+	// Send the hash request and verify the response
+	p2p.Send(peer.app, GetReceiptsMsg, hashes)
+	if err := p2p.ExpectMsg(peer.app, ReceiptsMsg, receipts); err != nil {
+		t.Errorf("receipts mismatch: %v", err)
+	}
+}
diff --git a/eth/protocols/eth/handshake.go b/eth/protocols/eth/handshake.go
new file mode 100644
index 000000000000..57a4e0bc3470
--- /dev/null
+++ b/eth/protocols/eth/handshake.go
@@ -0,0 +1,107 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"fmt"
+	"math/big"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/forkid"
+	"github.com/ethereum/go-ethereum/p2p"
+)
+
+const (
+	// handshakeTimeout is the maximum allowed time for the `eth` handshake to
+	// complete before dropping the connection as malicious.
+	handshakeTimeout = 5 * time.Second
+)
+
+// Handshake executes the eth protocol handshake, negotiating version number,
+// network IDs, difficulties, head and genesis blocks.
+func (p *Peer) Handshake(network uint64, td *big.Int, head common.Hash, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter) error {
+	// Send out own handshake in a new thread
+	errc := make(chan error, 2)
+
+	var status StatusPacket // safe to read after two values have been received from errc
+
+	go func() {
+		errc <- p2p.Send(p.rw, StatusMsg, &StatusPacket{
+			ProtocolVersion: uint32(p.version),
+			NetworkID:       network,
+			TD:              td,
+			Head:            head,
+			Genesis:         genesis,
+			ForkID:          forkID,
+		})
+	}()
+	go func() {
+		errc <- p.readStatus(network, &status, genesis, forkFilter)
+	}()
+	timeout := time.NewTimer(handshakeTimeout)
+	defer timeout.Stop()
+	for i := 0; i < 2; i++ {
+		select {
+		case err := <-errc:
+			if err != nil {
+				return err
+			}
+		case <-timeout.C:
+			return p2p.DiscReadTimeout
+		}
+	}
+	p.td, p.head = status.TD, status.Head
+
+	// TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times
+	// larger, it will still fit within 100 bits
+	if tdlen := p.td.BitLen(); tdlen > 100 {
+		return fmt.Errorf("too large total difficulty: bitlen %d", tdlen)
+	}
+	return nil
+}
+
+// readStatus reads the remote handshake message.
+func (p *Peer) readStatus(network uint64, status *StatusPacket, genesis common.Hash, forkFilter forkid.Filter) error {
+	msg, err := p.rw.ReadMsg()
+	if err != nil {
+		return err
+	}
+	if msg.Code != StatusMsg {
+		return fmt.Errorf("%w: first msg has code %x (!= %x)", errNoStatusMsg, msg.Code, StatusMsg)
+	}
+	if msg.Size > maxMessageSize {
+		return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
+	}
+	// Decode the handshake and make sure everything matches
+	if err := msg.Decode(&status); err != nil {
+		return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+	}
+	if status.NetworkID != network {
+		return fmt.Errorf("%w: %d (!= %d)", errNetworkIDMismatch, status.NetworkID, network)
+	}
+	if uint(status.ProtocolVersion) != p.version {
+		return fmt.Errorf("%w: %d (!= %d)", errProtocolVersionMismatch, status.ProtocolVersion, p.version)
+	}
+	if status.Genesis != genesis {
+		return fmt.Errorf("%w: %x (!= %x)", errGenesisMismatch, status.Genesis, genesis)
+	}
+	if err := forkFilter(status.ForkID); err != nil {
+		return fmt.Errorf("%w: %v", errForkIDRejected, err)
+	}
+	return nil
+}
diff --git a/eth/protocols/eth/handshake_test.go b/eth/protocols/eth/handshake_test.go
new file mode 100644
index 000000000000..65f9a000648d
--- /dev/null
+++ b/eth/protocols/eth/handshake_test.go
@@ -0,0 +1,91 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "errors" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/forkid" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +// Tests that handshake failures are detected and reported correctly. +func TestHandshake64(t *testing.T) { testHandshake(t, 64) } +func TestHandshake65(t *testing.T) { testHandshake(t, 65) } + +func testHandshake(t *testing.T, protocol uint) { + t.Parallel() + + // Create a test backend only to have some valid genesis chain + backend := newTestBackend(3) + defer backend.close() + + var ( + genesis = backend.chain.Genesis() + head = backend.chain.CurrentBlock() + td = backend.chain.GetTd(head.Hash(), head.NumberU64()) + forkID = forkid.NewID(backend.chain.Config(), backend.chain.Genesis().Hash(), backend.chain.CurrentHeader().Number.Uint64()) + ) + tests := []struct { + code uint64 + data interface{} + want error + }{ + { + code: TransactionsMsg, data: []interface{}{}, + want: errNoStatusMsg, + }, + { + code: StatusMsg, data: StatusPacket{10, 1, td, head.Hash(), genesis.Hash(), forkID}, + want: errProtocolVersionMismatch, + }, + { + code: StatusMsg, data: StatusPacket{uint32(protocol), 999, td, head.Hash(), genesis.Hash(), forkID}, + want: errNetworkIDMismatch, + }, + { + code: StatusMsg, data: StatusPacket{uint32(protocol), 1, td, head.Hash(), common.Hash{3}, forkID}, + want: errGenesisMismatch, + }, + { + code: StatusMsg, data: StatusPacket{uint32(protocol), 1, td, head.Hash(), genesis.Hash(), forkid.ID{Hash: [4]byte{0x00, 0x01, 0x02, 0x03}}}, + want: errForkIDRejected, + }, + } + for i, test := range tests { + // Create the two peers to shake with each other + app, net := p2p.MsgPipe() + defer app.Close() + defer net.Close() + + peer := NewPeer(protocol, p2p.NewPeer(enode.ID{}, "peer", nil), net, nil) + defer peer.Close() + + // Send the junk test with one peer, check the handshake failure + go p2p.Send(app, test.code, test.data) + + err := peer.Handshake(1, td, head.Hash(), genesis.Hash(), forkID, forkid.NewFilter(backend.chain)) + if err == nil { + t.Errorf("test %d: protocol returned nil error, want %q", i, test.want) + } else if !errors.Is(err, test.want) { + t.Errorf("test %d: wrong error: got %q, want %q", i, err, test.want) + } + } +} diff --git a/eth/protocols/eth/peer.go b/eth/protocols/eth/peer.go new file mode 100644 index 000000000000..735ef78ce71b --- /dev/null +++ b/eth/protocols/eth/peer.go @@ -0,0 +1,429 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"math/big"
+	"sync"
+
+	mapset "github.com/deckarep/golang-set"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+	// maxKnownTxs is the maximum transactions hashes to keep in the known list
+	// before starting to randomly evict them.
+	maxKnownTxs = 32768
+
+	// maxKnownBlocks is the maximum block hashes to keep in the known list
+	// before starting to randomly evict them.
+	maxKnownBlocks = 1024
+
+	// maxQueuedTxs is the maximum number of transactions to queue up before dropping
+	// older broadcasts.
+	maxQueuedTxs = 4096
+
+	// maxQueuedTxAnns is the maximum number of transaction announcements to queue up
+	// before dropping older announcements.
+	maxQueuedTxAnns = 4096
+
+	// maxQueuedBlocks is the maximum number of block propagations to queue up before
+	// dropping broadcasts. There's not much point in queueing stale blocks, so a few
+	// that might cover uncles should be enough.
+	maxQueuedBlocks = 4
+
+	// maxQueuedBlockAnns is the maximum number of block announcements to queue up before
+	// dropping broadcasts. Similarly to block propagations, there's no point to queue
+	// above some healthy uncle limit, so use that.
+	maxQueuedBlockAnns = 4
+)
+
+// max is a helper function which returns the larger of the two given integers.
+func max(a, b int) int {
+	if a > b {
+		return a
+	}
+	return b
+}
+
+// Peer is a collection of relevant information we have about a `eth` peer.
+type Peer struct {
+	id string // Unique ID for the peer, cached
+
+	*p2p.Peer                   // The embedded P2P package peer
+	rw        p2p.MsgReadWriter // Input/output streams for `eth`
+	version   uint              // Protocol version negotiated
+
+	head common.Hash // Latest advertised head block hash
+	td   *big.Int    // Latest advertised head block total difficulty
+
+	knownBlocks     mapset.Set             // Set of block hashes known to be known by this peer
+	queuedBlocks    chan *blockPropagation // Queue of blocks to broadcast to the peer
+	queuedBlockAnns chan *types.Block      // Queue of blocks to announce to the peer
+
+	txpool      TxPool             // Transaction pool used by the broadcasters for liveness checks
+	knownTxs    mapset.Set         // Set of transaction hashes known to be known by this peer
+	txBroadcast chan []common.Hash // Channel used to queue transaction propagation requests
+	txAnnounce  chan []common.Hash // Channel used to queue transaction announcement requests
+
+	term chan struct{} // Termination channel to stop the broadcasters
+	lock sync.RWMutex  // Mutex protecting the internal fields
+}
+
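The `knownTxs`/`knownBlocks` sets above are bounded: once a cap is hit, an arbitrary entry is popped before a new one is added, trading perfect memory for a fixed footprint. A toy standalone version of that policy (illustrative only; the real code uses `mapset.Set.Pop`):

```go
package main

import "fmt"

// knownSet mimics the peer's bounded known-hash tracking: at capacity, an
// arbitrary entry is evicted before inserting the new one.
type knownSet struct {
	max   int
	items map[string]struct{}
}

func (s *knownSet) add(hash string) {
	for len(s.items) >= s.max {
		for evict := range s.items { // map iteration order is unspecified: "random" eviction
			delete(s.items, evict)
			break
		}
	}
	s.items[hash] = struct{}{}
}

func main() {
	s := &knownSet{max: 2, items: map[string]struct{}{}}
	for _, h := range []string{"a", "b", "c"} {
		s.add(h)
	}
	fmt.Println(len(s.items)) // 2: the set never exceeds its cap
}
```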
+// NewPeer creates a wrapper for a network connection and negotiated protocol
+// version.
+func NewPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter, txpool TxPool) *Peer {
+	peer := &Peer{
+		id:              p.ID().String(),
+		Peer:            p,
+		rw:              rw,
+		version:         version,
+		knownTxs:        mapset.NewSet(),
+		knownBlocks:     mapset.NewSet(),
+		queuedBlocks:    make(chan *blockPropagation, maxQueuedBlocks),
+		queuedBlockAnns: make(chan *types.Block, maxQueuedBlockAnns),
+		txBroadcast:     make(chan []common.Hash),
+		txAnnounce:      make(chan []common.Hash),
+		txpool:          txpool,
+		term:            make(chan struct{}),
+	}
+	// Start up all the broadcasters
+	go peer.broadcastBlocks()
+	go peer.broadcastTransactions()
+	if version >= ETH65 {
+		go peer.announceTransactions()
+	}
+	return peer
+}
+
+// Close signals the broadcast goroutine to terminate. Only ever call this if
+// you created the peer yourself via NewPeer. Otherwise let whoever created it
+// clean it up!
+func (p *Peer) Close() {
+	close(p.term)
+}
+
+// ID retrieves the peer's unique identifier.
+func (p *Peer) ID() string {
+	return p.id
+}
+
+// Version retrieves the peer's negotiated `eth` protocol version.
+func (p *Peer) Version() uint {
+	return p.version
+}
+
+// Head retrieves the current head hash and total difficulty of the peer.
+func (p *Peer) Head() (hash common.Hash, td *big.Int) {
+	p.lock.RLock()
+	defer p.lock.RUnlock()
+
+	copy(hash[:], p.head[:])
+	return hash, new(big.Int).Set(p.td)
+}
+
+// SetHead updates the head hash and total difficulty of the peer.
+func (p *Peer) SetHead(hash common.Hash, td *big.Int) {
+	p.lock.Lock()
+	defer p.lock.Unlock()
+
+	copy(p.head[:], hash[:])
+	p.td.Set(td)
+}
+
+// KnownBlock returns whether peer is known to already have a block.
+func (p *Peer) KnownBlock(hash common.Hash) bool {
+	return p.knownBlocks.Contains(hash)
+}
+
+// KnownTransaction returns whether peer is known to already have a transaction.
+func (p *Peer) KnownTransaction(hash common.Hash) bool {
+	return p.knownTxs.Contains(hash)
+}
+
+// markBlock marks a block as known for the peer, ensuring that the block will
+// never be propagated to this particular peer.
+func (p *Peer) markBlock(hash common.Hash) {
+	// If we reached the memory allowance, drop a previously known block hash
+	for p.knownBlocks.Cardinality() >= maxKnownBlocks {
+		p.knownBlocks.Pop()
+	}
+	p.knownBlocks.Add(hash)
+}
+
+// markTransaction marks a transaction as known for the peer, ensuring that it
+// will never be propagated to this particular peer.
+func (p *Peer) markTransaction(hash common.Hash) {
+	// If we reached the memory allowance, drop a previously known transaction hash
+	for p.knownTxs.Cardinality() >= maxKnownTxs {
+		p.knownTxs.Pop()
+	}
+	p.knownTxs.Add(hash)
+}
+
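Note that `NewPeer` only starts the announcement loop on eth/65 and above. A caller relaying transactions would branch the same way, announcing hashes to newer peers and pushing full bodies to older ones. A hedged sketch using the methods defined in this file (`relayTransactions` is a hypothetical helper, not part of this diff):

```go
// relayTransactions routes transactions based on the negotiated version,
// matching the broadcaster split in NewPeer above.
func relayTransactions(p *Peer, hashes []common.Hash) {
	if p.Version() >= ETH65 {
		p.AsyncSendPooledTransactionHashes(hashes) // announce hashes, let the peer pull bodies
	} else {
		p.AsyncSendTransactions(hashes) // schedule full transaction bodies
	}
}
```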
+// SendTransactions sends transactions to the peer and includes the hashes
+// in its transaction hash set for future reference.
+//
+// This method is a helper used by the async transaction sender. Don't call it
+// directly as the queueing (memory) and transmission (bandwidth) costs should
+// not be managed directly.
+//
+// The reason this is public is to allow packages using this protocol to write
+// tests that directly send messages without having to do the async queueing.
+func (p *Peer) SendTransactions(txs types.Transactions) error {
+	// Mark all the transactions as known, but ensure we don't overflow our limits
+	for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(txs)) {
+		p.knownTxs.Pop()
+	}
+	for _, tx := range txs {
+		p.knownTxs.Add(tx.Hash())
+	}
+	return p2p.Send(p.rw, TransactionsMsg, txs)
+}
+
+// AsyncSendTransactions queues a list of transactions (by hash) to eventually
+// propagate to a remote peer. The number of pending sends is capped (new ones
+// will force old sends to be dropped)
+func (p *Peer) AsyncSendTransactions(hashes []common.Hash) {
+	select {
+	case p.txBroadcast <- hashes:
+		// Mark all the transactions as known, but ensure we don't overflow our limits
+		for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
+			p.knownTxs.Pop()
+		}
+		for _, hash := range hashes {
+			p.knownTxs.Add(hash)
+		}
+	case <-p.term:
+		p.Log().Debug("Dropping transaction propagation", "count", len(hashes))
+	}
+}
+
+// sendPooledTransactionHashes sends transaction hashes to the peer and includes
+// them in its transaction hash set for future reference.
+//
+// This method is a helper used by the async transaction announcer. Don't call it
+// directly as the queueing (memory) and transmission (bandwidth) costs should
+// not be managed directly.
+func (p *Peer) sendPooledTransactionHashes(hashes []common.Hash) error {
+	// Mark all the transactions as known, but ensure we don't overflow our limits
+	for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
+		p.knownTxs.Pop()
+	}
+	for _, hash := range hashes {
+		p.knownTxs.Add(hash)
+	}
+	return p2p.Send(p.rw, NewPooledTransactionHashesMsg, NewPooledTransactionHashesPacket(hashes))
+}
+
+// AsyncSendPooledTransactionHashes queues a list of transaction hashes to eventually
+// announce to a remote peer. The number of pending sends is capped (new ones
+// will force old sends to be dropped)
+func (p *Peer) AsyncSendPooledTransactionHashes(hashes []common.Hash) {
+	select {
+	case p.txAnnounce <- hashes:
+		// Mark all the transactions as known, but ensure we don't overflow our limits
+		for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
+			p.knownTxs.Pop()
+		}
+		for _, hash := range hashes {
+			p.knownTxs.Add(hash)
+		}
+	case <-p.term:
+		p.Log().Debug("Dropping transaction announcement", "count", len(hashes))
+	}
+}
+
+// SendPooledTransactionsRLP sends requested transactions to the peer and adds the
+// hashes in its transaction hash set for future reference.
+//
+// Note, the method assumes the hashes are correct and correspond to the list of
+// transactions being sent.
+func (p *Peer) SendPooledTransactionsRLP(hashes []common.Hash, txs []rlp.RawValue) error {
+	// Mark all the transactions as known, but ensure we don't overflow our limits
+	for p.knownTxs.Cardinality() > max(0, maxKnownTxs-len(hashes)) {
+		p.knownTxs.Pop()
+	}
+	for _, hash := range hashes {
+		p.knownTxs.Add(hash)
+	}
+	return p2p.Send(p.rw, PooledTransactionsMsg, txs) // Not packed into PooledTransactionsPacket to avoid RLP decoding
+}
+
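Each sender above first frees exactly enough room in the known set for every hash it is about to add: it pops until cardinality drops to `max(0, maxKnownTxs-len(hashes))`, then inserts. A small worked model of that eviction arithmetic (illustrative only, not the library API):

```go
package main

import "fmt"

// headroom returns how many known hashes must be evicted so that, after
// adding n new ones, the set stays within capacity. It mirrors the
// `for card > max(0, capacity-n) { pop }` loops above.
func headroom(card, capacity, n int) int {
	limit := capacity - n
	if limit < 0 {
		limit = 0
	}
	if card <= limit {
		return 0
	}
	return card - limit
}

func main() {
	fmt.Println(headroom(32768, 32768, 10)) // 10: a full set evicts one per incoming hash
	fmt.Println(headroom(100, 32768, 10))   // 0: plenty of room, nothing evicted
}
```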
+// SendNewBlockHashes announces the availability of a number of blocks through
+// a hash notification.
+func (p *Peer) SendNewBlockHashes(hashes []common.Hash, numbers []uint64) error {
+	// Mark all the block hashes as known, but ensure we don't overflow our limits
+	for p.knownBlocks.Cardinality() > max(0, maxKnownBlocks-len(hashes)) {
+		p.knownBlocks.Pop()
+	}
+	for _, hash := range hashes {
+		p.knownBlocks.Add(hash)
+	}
+	request := make(NewBlockHashesPacket, len(hashes))
+	for i := 0; i < len(hashes); i++ {
+		request[i].Hash = hashes[i]
+		request[i].Number = numbers[i]
+	}
+	return p2p.Send(p.rw, NewBlockHashesMsg, request)
+}
+
+// AsyncSendNewBlockHash queues the availability of a block for propagation to a
+// remote peer. If the peer's broadcast queue is full, the event is silently
+// dropped.
+func (p *Peer) AsyncSendNewBlockHash(block *types.Block) {
+	select {
+	case p.queuedBlockAnns <- block:
+		// Mark the block hash as known, but ensure we don't overflow our limits
+		for p.knownBlocks.Cardinality() >= maxKnownBlocks {
+			p.knownBlocks.Pop()
+		}
+		p.knownBlocks.Add(block.Hash())
+	default:
+		p.Log().Debug("Dropping block announcement", "number", block.NumberU64(), "hash", block.Hash())
+	}
+}
+
+// SendNewBlock propagates an entire block to a remote peer.
+func (p *Peer) SendNewBlock(block *types.Block, td *big.Int) error {
+	// Mark the block hash as known, but ensure we don't overflow our limits
+	for p.knownBlocks.Cardinality() >= maxKnownBlocks {
+		p.knownBlocks.Pop()
+	}
+	p.knownBlocks.Add(block.Hash())
+	return p2p.Send(p.rw, NewBlockMsg, &NewBlockPacket{block, td})
+}
+
+// AsyncSendNewBlock queues an entire block for propagation to a remote peer. If
+// the peer's broadcast queue is full, the event is silently dropped.
+func (p *Peer) AsyncSendNewBlock(block *types.Block, td *big.Int) {
+	select {
+	case p.queuedBlocks <- &blockPropagation{block: block, td: td}:
+		// Mark the block hash as known, but ensure we don't overflow our limits
+		for p.knownBlocks.Cardinality() >= maxKnownBlocks {
+			p.knownBlocks.Pop()
+		}
+		p.knownBlocks.Add(block.Hash())
+	default:
+		p.Log().Debug("Dropping block propagation", "number", block.NumberU64(), "hash", block.Hash())
+	}
+}
+
+// SendBlockHeaders sends a batch of block headers to the remote peer.
+func (p *Peer) SendBlockHeaders(headers []*types.Header) error {
+	return p2p.Send(p.rw, BlockHeadersMsg, BlockHeadersPacket(headers))
+}
+
+// SendBlockBodies sends a batch of block contents to the remote peer.
+func (p *Peer) SendBlockBodies(bodies []*BlockBody) error {
+	return p2p.Send(p.rw, BlockBodiesMsg, BlockBodiesPacket(bodies))
+}
+
+// SendBlockBodiesRLP sends a batch of block contents to the remote peer from
+// an already RLP encoded format.
+func (p *Peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error {
+	return p2p.Send(p.rw, BlockBodiesMsg, bodies) // Not packed into BlockBodiesPacket to avoid RLP decoding
+}
+
+// SendNodeData sends a batch of arbitrary internal data, corresponding to the
+// hashes requested.
+func (p *Peer) SendNodeData(data [][]byte) error {
+	return p2p.Send(p.rw, NodeDataMsg, NodeDataPacket(data))
+}
+
+// SendReceiptsRLP sends a batch of transaction receipts, corresponding to the
+// ones requested from an already RLP encoded format.
+func (p *Peer) SendReceiptsRLP(receipts []rlp.RawValue) error {
+	return p2p.Send(p.rw, ReceiptsMsg, receipts) // Not packed into ReceiptsPacket to avoid RLP decoding
+}
+
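The `*RLP` senders above exist so that responses read from disk as raw RLP can be forwarded verbatim, skipping a decode/re-encode cycle. A hedged sketch of the idea, assuming it lives in this package; `rawdb.ReadBodyRLP` is geth's raw body accessor, while `answerBodies` and its number-resolution are invented for illustration (the real handler goes through the chain accessors and enforces serving limits):

```go
// answerBodies forwards block bodies straight from the database as raw RLP,
// never materializing the transactions and uncles in memory.
func answerBodies(db ethdb.Reader, p *Peer, hashes []common.Hash, numbers []uint64) error {
	bodies := make([]rlp.RawValue, 0, len(hashes))
	for i, hash := range hashes {
		if data := rawdb.ReadBodyRLP(db, hash, numbers[i]); len(data) > 0 {
			bodies = append(bodies, data) // already RLP, pass through untouched
		}
	}
	return p.SendBlockBodiesRLP(bodies)
}
```

+// RequestOneHeader is a wrapper around the header query functions to fetch a
+// single header. It is used solely by the fetcher.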
+func (p *Peer) RequestOneHeader(hash common.Hash) error { + p.Log().Debug("Fetching single header", "hash", hash) + return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{ + Origin: HashOrNumber{Hash: hash}, + Amount: uint64(1), + Skip: uint64(0), + Reverse: false, + }) +} + +// RequestHeadersByHash fetches a batch of blocks' headers corresponding to the +// specified header query, based on the hash of an origin block. +func (p *Peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error { + p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse) + return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{ + Origin: HashOrNumber{Hash: origin}, + Amount: uint64(amount), + Skip: uint64(skip), + Reverse: reverse, + }) +} + +// RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the +// specified header query, based on the number of an origin block. +func (p *Peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { + p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse) + return p2p.Send(p.rw, GetBlockHeadersMsg, &GetBlockHeadersPacket{ + Origin: HashOrNumber{Number: origin}, + Amount: uint64(amount), + Skip: uint64(skip), + Reverse: reverse, + }) +} + +// ExpectRequestHeadersByNumber is a testing method to mirror the recipient side +// of the RequestHeadersByNumber operation. +func (p *Peer) ExpectRequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { + req := &GetBlockHeadersPacket{ + Origin: HashOrNumber{Number: origin}, + Amount: uint64(amount), + Skip: uint64(skip), + Reverse: reverse, + } + return p2p.ExpectMsg(p.rw, GetBlockHeadersMsg, req) +} + +// RequestBodies fetches a batch of blocks' bodies corresponding to the hashes +// specified. +func (p *Peer) RequestBodies(hashes []common.Hash) error { + p.Log().Debug("Fetching batch of block bodies", "count", len(hashes)) + return p2p.Send(p.rw, GetBlockBodiesMsg, GetBlockBodiesPacket(hashes)) +} + +// RequestNodeData fetches a batch of arbitrary data from a node's known state +// data, corresponding to the specified hashes. +func (p *Peer) RequestNodeData(hashes []common.Hash) error { + p.Log().Debug("Fetching batch of state data", "count", len(hashes)) + return p2p.Send(p.rw, GetNodeDataMsg, GetNodeDataPacket(hashes)) +} + +// RequestReceipts fetches a batch of transaction receipts from a remote node. +func (p *Peer) RequestReceipts(hashes []common.Hash) error { + p.Log().Debug("Fetching batch of receipts", "count", len(hashes)) + return p2p.Send(p.rw, GetReceiptsMsg, GetReceiptsPacket(hashes)) +} + +// RequestTxs fetches a batch of transactions from a remote node. +func (p *Peer) RequestTxs(hashes []common.Hash) error { + p.Log().Debug("Fetching batch of transactions", "count", len(hashes)) + return p2p.Send(p.rw, GetPooledTransactionsMsg, GetPooledTransactionsPacket(hashes)) +} diff --git a/eth/protocols/eth/peer_test.go b/eth/protocols/eth/peer_test.go new file mode 100644 index 000000000000..70e9959f82fd --- /dev/null +++ b/eth/protocols/eth/peer_test.go @@ -0,0 +1,61 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// This file contains some shared testing functionality, common to multiple
+// different files and modules being tested.
+
+package eth
+
+import (
+	"crypto/rand"
+
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+// testPeer is a simulated peer to allow testing direct network calls.
+type testPeer struct {
+	*Peer
+
+	net p2p.MsgReadWriter // Network layer reader/writer to simulate remote messaging
+	app *p2p.MsgPipeRW    // Application layer reader/writer to simulate the local side
+}
+
+// newTestPeer creates a new peer registered at the given data backend.
+func newTestPeer(name string, version uint, backend Backend) (*testPeer, <-chan error) {
+	// Create a message pipe to communicate through
+	app, net := p2p.MsgPipe()
+
+	// Start the peer on a new thread
+	var id enode.ID
+	rand.Read(id[:])
+
+	peer := NewPeer(version, p2p.NewPeer(id, name, nil), net, backend.TxPool())
+	errc := make(chan error, 1)
+	go func() {
+		errc <- backend.RunPeer(peer, func(peer *Peer) error {
+			return Handle(backend, peer)
+		})
+	}()
+	return &testPeer{app: app, net: net, Peer: peer}, errc
+}
+
+// close terminates the local side of the peer, notifying the remote protocol
+// manager of termination.
+func (p *testPeer) close() {
+	p.Peer.Close()
+	p.app.Close()
+}
diff --git a/eth/protocols/eth/protocol.go b/eth/protocols/eth/protocol.go
new file mode 100644
index 000000000000..63d3494ec48a
--- /dev/null
+++ b/eth/protocols/eth/protocol.go
@@ -0,0 +1,279 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"math/big"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/forkid"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// Constants to match up protocol versions and messages
+const (
+	ETH64 = 64
+	ETH65 = 65
+)
+
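The `newTestPeer` helper above relies on `p2p.MsgPipe`, an in-memory pair of connected `MsgReadWriter`s: whatever one end `Send`s, the other end `ReadMsg`s. A minimal self-contained demonstration of that primitive (the pipe is synchronous, hence the goroutine):

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p"
)

func main() {
	app, net := p2p.MsgPipe()
	defer app.Close()
	defer net.Close()

	// Send must run concurrently: the pipe has no buffer.
	go p2p.Send(app, 0x10, []string{"hello"})

	msg, err := net.ReadMsg()
	if err != nil {
		panic(err)
	}
	var payload []string
	if err := msg.Decode(&payload); err != nil {
		panic(err)
	}
	fmt.Println(msg.Code, payload) // 16 [hello]
}
```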
+// protocolName is the official short name of the `eth` protocol used during
+// devp2p capability negotiation.
+const protocolName = "eth"
+
+// protocolVersions are the supported versions of the `eth` protocol (first
+// is primary).
+var protocolVersions = []uint{ETH65, ETH64}
+
+// protocolLengths are the number of implemented messages corresponding to
+// different protocol versions.
+var protocolLengths = map[uint]uint64{ETH65: 17, ETH64: 17}
+
+// maxMessageSize is the maximum cap on the size of a protocol message.
+const maxMessageSize = 10 * 1024 * 1024
+
+const (
+	// Protocol messages in eth/64
+	StatusMsg          = 0x00
+	NewBlockHashesMsg  = 0x01
+	TransactionsMsg    = 0x02
+	GetBlockHeadersMsg = 0x03
+	BlockHeadersMsg    = 0x04
+	GetBlockBodiesMsg  = 0x05
+	BlockBodiesMsg     = 0x06
+	NewBlockMsg        = 0x07
+	GetNodeDataMsg     = 0x0d
+	NodeDataMsg        = 0x0e
+	GetReceiptsMsg     = 0x0f
+	ReceiptsMsg        = 0x10
+
+	// Protocol messages overloaded in eth/65
+	NewPooledTransactionHashesMsg = 0x08
+	GetPooledTransactionsMsg      = 0x09
+	PooledTransactionsMsg         = 0x0a
+)
+
+var (
+	errNoStatusMsg             = errors.New("no status message")
+	errMsgTooLarge             = errors.New("message too long")
+	errDecode                  = errors.New("invalid message")
+	errInvalidMsgCode          = errors.New("invalid message code")
+	errProtocolVersionMismatch = errors.New("protocol version mismatch")
+	errNetworkIDMismatch       = errors.New("network ID mismatch")
+	errGenesisMismatch         = errors.New("genesis mismatch")
+	errForkIDRejected          = errors.New("fork ID rejected")
+	errExtraStatusMsg          = errors.New("extra status message")
+)
+
+// Packet represents a p2p message in the `eth` protocol.
+type Packet interface {
+	Name() string // Name returns a string corresponding to the message type.
+	Kind() byte   // Kind returns the message type.
+}
+
+// StatusPacket is the network packet for the status message for eth/64 and later.
+type StatusPacket struct {
+	ProtocolVersion uint32
+	NetworkID       uint64
+	TD              *big.Int
+	Head            common.Hash
+	Genesis         common.Hash
+	ForkID          forkid.ID
+}
+
+// NewBlockHashesPacket is the network packet for the block announcements.
+type NewBlockHashesPacket []struct {
+	Hash   common.Hash // Hash of one particular block being announced
+	Number uint64      // Number of one particular block being announced
+}
+
+// Unpack retrieves the block hashes and numbers from the announcement packet
+// and returns them in a split flat format that's more consistent with the
+// internal data structures.
+func (p *NewBlockHashesPacket) Unpack() ([]common.Hash, []uint64) {
+	var (
+		hashes  = make([]common.Hash, len(*p))
+		numbers = make([]uint64, len(*p))
+	)
+	for i, body := range *p {
+		hashes[i], numbers[i] = body.Hash, body.Number
+	}
+	return hashes, numbers
+}
+
+// TransactionsPacket is the network packet for broadcasting new transactions.
+type TransactionsPacket []*types.Transaction
+
+// GetBlockHeadersPacket represents a block header query.
+type GetBlockHeadersPacket struct {
+	Origin  HashOrNumber // Block from which to retrieve headers
+	Amount  uint64       // Maximum number of headers to retrieve
+	Skip    uint64       // Blocks to skip between consecutive headers
+	Reverse bool         // Query direction (false = rising towards latest, true = falling towards genesis)
+}
+
+// HashOrNumber is a combined field for specifying an origin block.
+type HashOrNumber struct {
+	Hash   common.Hash // Block hash from which to retrieve headers (excludes Number)
+	Number uint64      // Block number from which to retrieve headers (excludes Hash)
+}
+
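The `HashOrNumber` union works on the wire because its two variants are always distinguishable by size: a hash encodes as a 32-byte RLP string, a number as an integer of at most 8 bytes, which is exactly what the decoder below dispatches on. A quick standalone look at the two encodings:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/rlp"
)

func main() {
	num, _ := rlp.EncodeToBytes(uint64(314))     // small integer encoding
	hash, _ := rlp.EncodeToBytes([32]byte{1, 2}) // fixed 32-byte string

	fmt.Printf("number: %x (len %d)\n", num, len(num))   // 82013a (len 3)
	fmt.Printf("hash:   %x (len %d)\n", hash, len(hash)) // a00102...00 (len 33)
}
```

+// EncodeRLP is a specialized encoder for HashOrNumber to encode only one of the
+// two contained union fields.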
+func (hn *HashOrNumber) EncodeRLP(w io.Writer) error { + if hn.Hash == (common.Hash{}) { + return rlp.Encode(w, hn.Number) + } + if hn.Number != 0 { + return fmt.Errorf("both origin hash (%x) and number (%d) provided", hn.Hash, hn.Number) + } + return rlp.Encode(w, hn.Hash) +} + +// DecodeRLP is a specialized decoder for HashOrNumber to decode the contents +// into either a block hash or a block number. +func (hn *HashOrNumber) DecodeRLP(s *rlp.Stream) error { + _, size, _ := s.Kind() + origin, err := s.Raw() + if err == nil { + switch { + case size == 32: + err = rlp.DecodeBytes(origin, &hn.Hash) + case size <= 8: + err = rlp.DecodeBytes(origin, &hn.Number) + default: + err = fmt.Errorf("invalid input size %d for origin", size) + } + } + return err +} + +// BlockHeadersPacket represents a block header response. +type BlockHeadersPacket []*types.Header + +// NewBlockPacket is the network packet for the block propagation message. +type NewBlockPacket struct { + Block *types.Block + TD *big.Int +} + +// sanityCheck verifies that the values are reasonable, as a DoS protection +func (request *NewBlockPacket) sanityCheck() error { + if err := request.Block.SanityCheck(); err != nil { + return err + } + //TD at mainnet block #7753254 is 76 bits. If it becomes 100 million times + // larger, it will still fit within 100 bits + if tdlen := request.TD.BitLen(); tdlen > 100 { + return fmt.Errorf("too large block TD: bitlen %d", tdlen) + } + return nil +} + +// GetBlockBodiesPacket represents a block body query. +type GetBlockBodiesPacket []common.Hash + +// BlockBodiesPacket is the network packet for block content distribution. +type BlockBodiesPacket []*BlockBody + +// BlockBody represents the data content of a single block. +type BlockBody struct { + Transactions []*types.Transaction // Transactions contained within a block + Uncles []*types.Header // Uncles contained within a block +} + +// Unpack retrieves the transactions and uncles from the range packet and returns +// them in a split flat format that's more consistent with the internal data structures. +func (p *BlockBodiesPacket) Unpack() ([][]*types.Transaction, [][]*types.Header) { + var ( + txset = make([][]*types.Transaction, len(*p)) + uncleset = make([][]*types.Header, len(*p)) + ) + for i, body := range *p { + txset[i], uncleset[i] = body.Transactions, body.Uncles + } + return txset, uncleset +} + +// GetNodeDataPacket represents a trie node data query. +type GetNodeDataPacket []common.Hash + +// NodeDataPacket is the network packet for trie node data distribution. +type NodeDataPacket [][]byte + +// GetReceiptsPacket represents a block receipts query. +type GetReceiptsPacket []common.Hash + +// ReceiptsPacket is the network packet for block receipts distribution. +type ReceiptsPacket [][]*types.Receipt + +// NewPooledTransactionHashesPacket represents a transaction announcement packet. +type NewPooledTransactionHashesPacket []common.Hash + +// GetPooledTransactionsPacket represents a transaction query. +type GetPooledTransactionsPacket []common.Hash + +// PooledTransactionsPacket is the network packet for transaction distribution. 
+type PooledTransactionsPacket []*types.Transaction
+
+func (*StatusPacket) Name() string { return "Status" }
+func (*StatusPacket) Kind() byte   { return StatusMsg }
+
+func (*NewBlockHashesPacket) Name() string { return "NewBlockHashes" }
+func (*NewBlockHashesPacket) Kind() byte   { return NewBlockHashesMsg }
+
+func (*TransactionsPacket) Name() string { return "Transactions" }
+func (*TransactionsPacket) Kind() byte   { return TransactionsMsg }
+
+func (*GetBlockHeadersPacket) Name() string { return "GetBlockHeaders" }
+func (*GetBlockHeadersPacket) Kind() byte   { return GetBlockHeadersMsg }
+
+func (*BlockHeadersPacket) Name() string { return "BlockHeaders" }
+func (*BlockHeadersPacket) Kind() byte   { return BlockHeadersMsg }
+
+func (*GetBlockBodiesPacket) Name() string { return "GetBlockBodies" }
+func (*GetBlockBodiesPacket) Kind() byte   { return GetBlockBodiesMsg }
+
+func (*BlockBodiesPacket) Name() string { return "BlockBodies" }
+func (*BlockBodiesPacket) Kind() byte   { return BlockBodiesMsg }
+
+func (*NewBlockPacket) Name() string { return "NewBlock" }
+func (*NewBlockPacket) Kind() byte   { return NewBlockMsg }
+
+func (*GetNodeDataPacket) Name() string { return "GetNodeData" }
+func (*GetNodeDataPacket) Kind() byte   { return GetNodeDataMsg }
+
+func (*NodeDataPacket) Name() string { return "NodeData" }
+func (*NodeDataPacket) Kind() byte   { return NodeDataMsg }
+
+func (*GetReceiptsPacket) Name() string { return "GetReceipts" }
+func (*GetReceiptsPacket) Kind() byte   { return GetReceiptsMsg }
+
+func (*ReceiptsPacket) Name() string { return "Receipts" }
+func (*ReceiptsPacket) Kind() byte   { return ReceiptsMsg }
+
+func (*NewPooledTransactionHashesPacket) Name() string { return "NewPooledTransactionHashes" }
+func (*NewPooledTransactionHashesPacket) Kind() byte   { return NewPooledTransactionHashesMsg }
+
+func (*GetPooledTransactionsPacket) Name() string { return "GetPooledTransactions" }
+func (*GetPooledTransactionsPacket) Kind() byte   { return GetPooledTransactionsMsg }
+
+func (*PooledTransactionsPacket) Name() string { return "PooledTransactions" }
+func (*PooledTransactionsPacket) Kind() byte   { return PooledTransactionsMsg }
diff --git a/eth/protocols/eth/protocol_test.go b/eth/protocols/eth/protocol_test.go
new file mode 100644
index 000000000000..056ea56480cc
--- /dev/null
+++ b/eth/protocols/eth/protocol_test.go
@@ -0,0 +1,68 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package eth
+
+import (
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
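The `Name`/`Kind` table above gives every packet a uniform identity, so generic middleware can log or meter any of them without knowing the concrete type. A small sketch (`tracePacket` is a hypothetical helper, not part of this diff):

```go
// tracePacket renders any eth packet for logging or metrics labelling.
func tracePacket(p Packet) string {
	return fmt.Sprintf("%s (0x%02x)", p.Name(), p.Kind())
}

// tracePacket(new(GetBlockHeadersPacket)) yields "GetBlockHeaders (0x03)".
```

+// Tests that the custom union field encoder and decoder works correctly.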
+func TestGetBlockHeadersDataEncodeDecode(t *testing.T) { + // Create a "random" hash for testing + var hash common.Hash + for i := range hash { + hash[i] = byte(i) + } + // Assemble some table driven tests + tests := []struct { + packet *GetBlockHeadersPacket + fail bool + }{ + // Providing the origin as either a hash or a number should both work + {fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 314}}}, + {fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash}}}, + + // Providing arbitrary query field should also work + {fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Number: 314}, Amount: 314, Skip: 1, Reverse: true}}, + {fail: false, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash}, Amount: 314, Skip: 1, Reverse: true}}, + + // Providing both the origin hash and origin number must fail + {fail: true, packet: &GetBlockHeadersPacket{Origin: HashOrNumber{Hash: hash, Number: 314}}}, + } + // Iterate over each of the tests and try to encode and then decode + for i, tt := range tests { + bytes, err := rlp.EncodeToBytes(tt.packet) + if err != nil && !tt.fail { + t.Fatalf("test %d: failed to encode packet: %v", i, err) + } else if err == nil && tt.fail { + t.Fatalf("test %d: encode should have failed", i) + } + if !tt.fail { + packet := new(GetBlockHeadersPacket) + if err := rlp.DecodeBytes(bytes, packet); err != nil { + t.Fatalf("test %d: failed to decode packet: %v", i, err) + } + if packet.Origin.Hash != tt.packet.Origin.Hash || packet.Origin.Number != tt.packet.Origin.Number || packet.Amount != tt.packet.Amount || + packet.Skip != tt.packet.Skip || packet.Reverse != tt.packet.Reverse { + t.Fatalf("test %d: encode decode mismatch: have %+v, want %+v", i, packet, tt.packet) + } + } + } +} diff --git a/eth/protocols/snap/discovery.go b/eth/protocols/snap/discovery.go new file mode 100644 index 000000000000..684ec7e632da --- /dev/null +++ b/eth/protocols/snap/discovery.go @@ -0,0 +1,32 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package snap + +import ( + "github.com/ethereum/go-ethereum/rlp" +) + +// enrEntry is the ENR entry which advertises `snap` protocol on the discovery. +type enrEntry struct { + // Ignore additional fields (for forward compatibility). + Rest []rlp.RawValue `rlp:"tail"` +} + +// ENRKey implements enr.Entry. +func (e enrEntry) ENRKey() string { + return "snap" +} diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go new file mode 100644 index 000000000000..36322e648bc9 --- /dev/null +++ b/eth/protocols/snap/handler.go @@ -0,0 +1,490 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+	"bytes"
+	"fmt"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/state"
+	"github.com/ethereum/go-ethereum/light"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/rlp"
+	"github.com/ethereum/go-ethereum/trie"
+)
+
+const (
+	// softResponseLimit is the target maximum size of replies to data retrievals.
+	softResponseLimit = 2 * 1024 * 1024
+
+	// maxCodeLookups is the maximum number of bytecodes to serve. This number is
+	// there to limit the number of disk lookups.
+	maxCodeLookups = 1024
+
+	// stateLookupSlack defines the ratio by how much a state response can exceed
+	// the requested limit in order to try and avoid breaking up contracts into
+	// multiple packages and proving them.
+	stateLookupSlack = 0.1
+
+	// maxTrieNodeLookups is the maximum number of state trie nodes to serve. This
+	// number is there to limit the number of disk lookups.
+	maxTrieNodeLookups = 1024
+)
+
+// Handler is a callback to invoke from an outside runner after the boilerplate
+// exchanges have passed.
+type Handler func(peer *Peer) error
+
+// Backend defines the data retrieval methods to serve remote requests and the
+// callback methods to invoke on remote deliveries.
+type Backend interface {
+	// Chain retrieves the blockchain object to serve data.
+	Chain() *core.BlockChain
+
+	// RunPeer is invoked when a peer joins on the `snap` protocol. The handler
+	// should do any peer maintenance work, handshakes and validations. If all
+	// is passed, control should be given back to the `handler` to process the
+	// inbound messages going forward.
+	RunPeer(peer *Peer, handler Handler) error
+
+	// PeerInfo retrieves all known `snap` information about a peer.
+	PeerInfo(id enode.ID) interface{}
+
+	// Handle is a callback to be invoked when a data packet is received from
+	// the remote peer. Only packets not consumed by the protocol handler will
+	// be forwarded to the backend.
+	Handle(peer *Peer, packet Packet) error
+}
+
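`stateLookupSlack` lets a response overshoot the requested byte budget by 10% before a contract gets split mid-trie (splitting forces extra Merkle proofs on both sides). A quick standalone check of the resulting hard limit, with the constants copied from above:

```go
package main

import "fmt"

const (
	softResponseLimit = 2 * 1024 * 1024
	stateLookupSlack  = 0.1
)

func main() {
	req := uint64(softResponseLimit)
	hard := uint64(float64(req) * (1 + stateLookupSlack))
	fmt.Println(req, hard) // 2097152 2306867: ~200KiB of slack before splitting
}
```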
+// MakeProtocols constructs the P2P protocol definitions for `snap`.
+func MakeProtocols(backend Backend, dnsdisc enode.Iterator) []p2p.Protocol {
+	protocols := make([]p2p.Protocol, len(protocolVersions))
+	for i, version := range protocolVersions {
+		version := version // Closure
+
+		protocols[i] = p2p.Protocol{
+			Name:    protocolName,
+			Version: version,
+			Length:  protocolLengths[version],
+			Run: func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
+				return backend.RunPeer(newPeer(version, p, rw), func(peer *Peer) error {
+					return handle(backend, peer)
+				})
+			},
+			NodeInfo: func() interface{} {
+				return nodeInfo(backend.Chain())
+			},
+			PeerInfo: func(id enode.ID) interface{} {
+				return backend.PeerInfo(id)
+			},
+			Attributes:     []enr.Entry{&enrEntry{}},
+			DialCandidates: dnsdisc,
+		}
+	}
+	return protocols
+}
+
+// handle is the callback invoked to manage the life cycle of a `snap` peer.
+// When this function terminates, the peer is disconnected.
+func handle(backend Backend, peer *Peer) error {
+	for {
+		if err := handleMessage(backend, peer); err != nil {
+			peer.Log().Debug("Message handling failed in `snap`", "err", err)
+			return err
+		}
+	}
+}
+
+// handleMessage is invoked whenever an inbound message is received from a
+// remote peer on the `snap` protocol. The remote connection is torn down upon
+// returning any error.
+func handleMessage(backend Backend, peer *Peer) error {
+	// Read the next message from the remote peer, and ensure it's fully consumed
+	msg, err := peer.rw.ReadMsg()
+	if err != nil {
+		return err
+	}
+	if msg.Size > maxMessageSize {
+		return fmt.Errorf("%w: %v > %v", errMsgTooLarge, msg.Size, maxMessageSize)
+	}
+	defer msg.Discard()
+
+	// Handle the message depending on its contents
+	switch {
+	case msg.Code == GetAccountRangeMsg:
+		// Decode the account retrieval request
+		var req GetAccountRangePacket
+		if err := msg.Decode(&req); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		if req.Bytes > softResponseLimit {
+			req.Bytes = softResponseLimit
+		}
+		// Retrieve the requested state and bail out if non existent
+		tr, err := trie.New(req.Root, backend.Chain().StateCache().TrieDB())
+		if err != nil {
+			return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
+		}
+		it, err := backend.Chain().Snapshots().AccountIterator(req.Root, req.Origin)
+		if err != nil {
+			return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
+		}
+		// Iterate over the requested range and pile accounts up
+		var (
+			accounts []*AccountData
+			size     uint64
+			last     common.Hash
+		)
+		for it.Next() && size < req.Bytes {
+			hash, account := it.Hash(), common.CopyBytes(it.Account())
+
+			// Track the returned interval for the Merkle proofs
+			last = hash
+
+			// Assemble the reply item
+			size += uint64(common.HashLength + len(account))
+			accounts = append(accounts, &AccountData{
+				Hash: hash,
+				Body: account,
+			})
+			// If we've exceeded the request threshold, abort
+			if bytes.Compare(hash[:], req.Limit[:]) >= 0 {
+				break
+			}
+		}
+		it.Release()
+
+		// Generate the Merkle proofs for the first and last account
+		proof := light.NewNodeSet()
+		if err := tr.Prove(req.Origin[:], 0, proof); err != nil {
+			log.Warn("Failed to prove account range", "origin", req.Origin, "err", err)
+			return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
+		}
+		if last != (common.Hash{}) {
+			if err := tr.Prove(last[:], 0, proof); err != nil {
+				log.Warn("Failed to prove account range", "last", last, "err", err)
+				return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{ID: req.ID})
+			}
+		}
+		var proofs [][]byte
+		for _, blob := range proof.NodeList() {
+			proofs = append(proofs, blob)
+		}
+		// Send back anything accumulated
+		return p2p.Send(peer.rw, AccountRangeMsg, &AccountRangePacket{
+			ID:       req.ID,
+			Accounts: accounts,
+			Proof:    proofs,
+		})
+
+	case msg.Code == AccountRangeMsg:
+		// A range of accounts arrived to one of our previous requests
+		res := new(AccountRangePacket)
+		if err := msg.Decode(res); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		// Ensure the range is monotonically increasing
+		for i := 1; i < len(res.Accounts); i++ {
+			if bytes.Compare(res.Accounts[i-1].Hash[:], res.Accounts[i].Hash[:]) >= 0 {
+				return fmt.Errorf("accounts not monotonically increasing: #%d [%x] vs #%d [%x]", i-1, res.Accounts[i-1].Hash[:], i, res.Accounts[i].Hash[:])
+			}
+		}
+		return backend.Handle(peer, res)
+
+	case msg.Code == GetStorageRangesMsg:
+		// Decode the storage retrieval request
+		var req GetStorageRangesPacket
+		if err := msg.Decode(&req); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		if req.Bytes > softResponseLimit {
+			req.Bytes = softResponseLimit
+		}
+		// TODO(karalabe): Do we want to enforce > 0 accounts and 1 account if origin is set?
+		// TODO(karalabe): - Logging locally is not ideal as remote faults annoy the local user
+		// TODO(karalabe): - Dropping the remote peer is less flexible wrt client bugs (slow is better than non-functional)
+
+		// Calculate the hard limit at which to abort, even if mid storage trie
+		hardLimit := uint64(float64(req.Bytes) * (1 + stateLookupSlack))
+
+		// Retrieve storage ranges until the packet limit is reached
+		var (
+			slots  [][]*StorageData
+			proofs [][]byte
+			size   uint64
+		)
+		for _, account := range req.Accounts {
+			// If we've exceeded the requested data limit, abort without opening
+			// a new storage range (that we'd need to prove due to exceeded size)
+			if size >= req.Bytes {
+				break
+			}
+			// The first account might start from a different origin and end sooner
+			var origin common.Hash
+			if len(req.Origin) > 0 {
+				origin, req.Origin = common.BytesToHash(req.Origin), nil
+			}
+			var limit = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+			if len(req.Limit) > 0 {
+				limit, req.Limit = common.BytesToHash(req.Limit), nil
+			}
+			// Retrieve the requested state and bail out if non existent
+			it, err := backend.Chain().Snapshots().StorageIterator(req.Root, account, origin)
+			if err != nil {
+				return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+			}
+			// Iterate over the requested range and pile slots up
+			var (
+				storage []*StorageData
+				last    common.Hash
+			)
+			for it.Next() && size < hardLimit {
+				hash, slot := it.Hash(), common.CopyBytes(it.Slot())
+
+				// Track the returned interval for the Merkle proofs
+				last = hash
+
+				// Assemble the reply item
+				size += uint64(common.HashLength + len(slot))
+				storage = append(storage, &StorageData{
+					Hash: hash,
+					Body: slot,
+				})
+				// If we've exceeded the request threshold, abort
+				if bytes.Compare(hash[:], limit[:]) >= 0 {
+					break
+				}
+			}
+			slots = append(slots, storage)
+			it.Release()
+
+			// Generate the Merkle proofs for the first and last storage slot, but
+			// only if the response was capped. If the entire storage trie is included
+			// in the response, no need for any proofs.
+			if origin != (common.Hash{}) || size >= hardLimit {
+				// Request started at a non-zero hash or was capped prematurely, add
+				// the endpoint Merkle proofs
+				accTrie, err := trie.New(req.Root, backend.Chain().StateCache().TrieDB())
+				if err != nil {
+					return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+				}
+				var acc state.Account
+				if err := rlp.DecodeBytes(accTrie.Get(account[:]), &acc); err != nil {
+					return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+				}
+				stTrie, err := trie.New(acc.Root, backend.Chain().StateCache().TrieDB())
+				if err != nil {
+					return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+				}
+				proof := light.NewNodeSet()
+				if err := stTrie.Prove(origin[:], 0, proof); err != nil {
+					log.Warn("Failed to prove storage range", "origin", req.Origin, "err", err)
+					return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+				}
+				if last != (common.Hash{}) {
+					if err := stTrie.Prove(last[:], 0, proof); err != nil {
+						log.Warn("Failed to prove storage range", "last", last, "err", err)
+						return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{ID: req.ID})
+					}
+				}
+				for _, blob := range proof.NodeList() {
+					proofs = append(proofs, blob)
+				}
+				// A proof terminates the reply, since proofs are only added when a
+				// node refuses to serve more data (except when a contract fetch is
+				// just finishing).
+				break
+			}
+		}
+		// Send back anything accumulated
+		return p2p.Send(peer.rw, StorageRangesMsg, &StorageRangesPacket{
+			ID:    req.ID,
+			Slots: slots,
+			Proof: proofs,
+		})
+
+	case msg.Code == StorageRangesMsg:
+		// A range of storage slots arrived in response to one of our previous requests
+		res := new(StorageRangesPacket)
+		if err := msg.Decode(res); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		// Ensure the ranges are monotonically increasing
+		for i, slots := range res.Slots {
+			for j := 1; j < len(slots); j++ {
+				if bytes.Compare(slots[j-1].Hash[:], slots[j].Hash[:]) >= 0 {
+					return fmt.Errorf("storage slots not monotonically increasing for account #%d: #%d [%x] vs #%d [%x]", i, j-1, slots[j-1].Hash[:], j, slots[j].Hash[:])
+				}
+			}
+		}
+		return backend.Handle(peer, res)
+
+	case msg.Code == GetByteCodesMsg:
+		// Decode bytecode retrieval request
+		var req GetByteCodesPacket
+		if err := msg.Decode(&req); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		if req.Bytes > softResponseLimit {
+			req.Bytes = softResponseLimit
+		}
+		if len(req.Hashes) > maxCodeLookups {
+			req.Hashes = req.Hashes[:maxCodeLookups]
+		}
+		// Retrieve bytecodes until the packet size limit is reached
+		var (
+			codes [][]byte
+			bytes uint64
+		)
+		for _, hash := range req.Hashes {
+			if hash == emptyCode {
+				// Peers should not request the empty code, but if they do, at
+				// least send them back a correct response without db lookups
+				codes = append(codes, []byte{})
+			} else if blob, err := backend.Chain().ContractCode(hash); err == nil {
+				codes = append(codes, blob)
+				bytes += uint64(len(blob))
+			}
+			if bytes > req.Bytes {
+				break
+			}
+		}
+		// Send back anything accumulated
+		return p2p.Send(peer.rw, ByteCodesMsg, &ByteCodesPacket{
+			ID:    req.ID,
+			Codes: codes,
+		})
+
+	case msg.Code == ByteCodesMsg:
+		// A batch of byte codes arrived in response to one of our previous requests
+		res := new(ByteCodesPacket)
+		if err := msg.Decode(res); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		return backend.Handle(peer, res)
+
+	case msg.Code == GetTrieNodesMsg:
+		// Decode trie node retrieval request
+		var req GetTrieNodesPacket
+		if err := msg.Decode(&req); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		if req.Bytes > softResponseLimit {
+			req.Bytes = softResponseLimit
+		}
+		// Make sure we have the state associated with the request
+		triedb := backend.Chain().StateCache().TrieDB()
+
+		accTrie, err := trie.NewSecure(req.Root, triedb)
+		if err != nil {
+			// We don't have the requested state available, bail out
+			return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ID: req.ID})
+		}
+		snap := backend.Chain().Snapshots().Snapshot(req.Root)
+		if snap == nil {
+			// We don't have the requested state snapshotted yet, bail out.
+			// In reality we could still serve using the account and storage
+			// tries only, but let's protect the node a bit while it's doing
+			// snapshot generation.
+			return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{ID: req.ID})
+		}
+		// Retrieve trie nodes until the packet size limit is reached
+		var (
+			nodes [][]byte
+			bytes uint64
+			loads int // Trie hash expansions to count database reads
+		)
+		for _, pathset := range req.Paths {
+			switch len(pathset) {
+			case 0:
+				// Ensure we penalize invalid requests
+				return fmt.Errorf("%w: zero-item pathset requested", errBadRequest)
+
+			case 1:
+				// If we're only retrieving an account trie node, fetch it directly
+				blob, resolved, err := accTrie.TryGetNode(pathset[0])
+				loads += resolved // always account for database reads, even for failures
+				if err != nil {
+					break
+				}
+				nodes = append(nodes, blob)
+				bytes += uint64(len(blob))
+
+			default:
+				// Storage slots requested, open the storage trie and retrieve from there
+				account, err := snap.Account(common.BytesToHash(pathset[0]))
+				loads++ // always account for database reads, even for failures
+				if err != nil {
+					break
+				}
+				stTrie, err := trie.NewSecure(common.BytesToHash(account.Root), triedb)
+				loads++ // always account for database reads, even for failures
+				if err != nil {
+					break
+				}
+				for _, path := range pathset[1:] {
+					blob, resolved, err := stTrie.TryGetNode(path)
+					loads += resolved // always account for database reads, even for failures
+					if err != nil {
+						break
+					}
+					nodes = append(nodes, blob)
+					bytes += uint64(len(blob))
+
+					// Sanity check limits to avoid DoS on the storage trie loads
+					if bytes > req.Bytes || loads > maxTrieNodeLookups {
+						break
+					}
+				}
+			}
+			// Abort request processing if we've exceeded our limits
+			if bytes > req.Bytes || loads > maxTrieNodeLookups {
+				break
+			}
+		}
+		// Send back anything accumulated
+		return p2p.Send(peer.rw, TrieNodesMsg, &TrieNodesPacket{
+			ID:    req.ID,
+			Nodes: nodes,
+		})
+
+	case msg.Code == TrieNodesMsg:
+		// A batch of trie nodes arrived in response to one of our previous requests
+		res := new(TrieNodesPacket)
+		if err := msg.Decode(res); err != nil {
+			return fmt.Errorf("%w: message %v: %v", errDecode, msg, err)
+		}
+		return backend.Handle(peer, res)
+
+	default:
+		return fmt.Errorf("%w: %v", errInvalidMsgCode, msg.Code)
+	}
+}
+
+// NodeInfo represents a short summary of the `snap` sub-protocol metadata
+// known about the host peer.
+type NodeInfo struct{}
+
+// nodeInfo retrieves some `snap` protocol metadata about the running host node.
+func nodeInfo(chain *core.BlockChain) *NodeInfo {
+	return &NodeInfo{}
+}
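The `Run` callback above hands every negotiated connection to `handle`, so wiring the protocol into a node is just a matter of exposing these descriptors. As a hedged orientation sketch, not part of this change (`newSnapServer` is hypothetical, and a concrete `Backend` implementation is assumed to exist):

```go
import (
	"crypto/ecdsa"

	"github.com/ethereum/go-ethereum/eth/protocols/snap"
	"github.com/ethereum/go-ethereum/p2p"
)

// newSnapServer is hypothetical glue: it runs the `snap` protocol on a bare
// devp2p server. A real node would merge these with its other protocols.
func newSnapServer(backend snap.Backend, key *ecdsa.PrivateKey) *p2p.Server {
	return &p2p.Server{Config: p2p.Config{
		PrivateKey: key,
		MaxPeers:   50,
		Protocols:  snap.MakeProtocols(backend, nil), // nil: no DNS-discovered dial candidates
	}}
}
```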
diff --git a/eth/protocols/snap/peer.go b/eth/protocols/snap/peer.go
new file mode 100644
index 000000000000..73eaaadd0929
--- /dev/null
+++ b/eth/protocols/snap/peer.go
@@ -0,0 +1,111 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/p2p"
+)
+
+// Peer is a collection of relevant information we have about a `snap` peer.
+type Peer struct {
+	id string // Unique ID for the peer, cached
+
+	*p2p.Peer                   // The embedded P2P package peer
+	rw        p2p.MsgReadWriter // Input/output streams for snap
+	version   uint              // Protocol version negotiated
+
+	logger log.Logger // Contextual logger with the peer id injected
+}
+
+// newPeer creates a wrapper for a network connection and negotiated protocol
+// version.
+func newPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) *Peer {
+	id := p.ID().String()
+	return &Peer{
+		id:      id,
+		Peer:    p,
+		rw:      rw,
+		version: version,
+		logger:  log.New("peer", id[:8]),
+	}
+}
+
+// ID retrieves the peer's unique identifier.
+func (p *Peer) ID() string {
+	return p.id
+}
+
+// Version retrieves the peer's negotiated `snap` protocol version.
+func (p *Peer) Version() uint {
+	return p.version
+}
+
+// RequestAccountRange fetches a batch of accounts rooted in a specific account
+// trie, starting with the origin.
+func (p *Peer) RequestAccountRange(id uint64, root common.Hash, origin, limit common.Hash, bytes uint64) error {
+	p.logger.Trace("Fetching range of accounts", "reqid", id, "root", root, "origin", origin, "limit", limit, "bytes", common.StorageSize(bytes))
+	return p2p.Send(p.rw, GetAccountRangeMsg, &GetAccountRangePacket{
+		ID:     id,
+		Root:   root,
+		Origin: origin,
+		Limit:  limit,
+		Bytes:  bytes,
+	})
+}
+
+// RequestStorageRanges fetches a batch of storage slots belonging to one or more
+// accounts. If slots for only one account are requested, an origin marker may also
+// be used to retrieve from there.
+func (p *Peer) RequestStorageRanges(id uint64, root common.Hash, accounts []common.Hash, origin, limit []byte, bytes uint64) error {
+	if len(accounts) == 1 && origin != nil {
+		p.logger.Trace("Fetching range of large storage slots", "reqid", id, "root", root, "account", accounts[0], "origin", common.BytesToHash(origin), "limit", common.BytesToHash(limit), "bytes", common.StorageSize(bytes))
+	} else {
+		p.logger.Trace("Fetching ranges of small storage slots", "reqid", id, "root", root, "accounts", len(accounts), "first", accounts[0], "bytes", common.StorageSize(bytes))
+	}
+	return p2p.Send(p.rw, GetStorageRangesMsg, &GetStorageRangesPacket{
+		ID:       id,
+		Root:     root,
+		Accounts: accounts,
+		Origin:   origin,
+		Limit:    limit,
+		Bytes:    bytes,
+	})
+}
+
+// RequestByteCodes fetches a batch of bytecodes by hash.
+func (p *Peer) RequestByteCodes(id uint64, hashes []common.Hash, bytes uint64) error {
+	p.logger.Trace("Fetching set of byte codes", "reqid", id, "hashes", len(hashes), "bytes", common.StorageSize(bytes))
+	return p2p.Send(p.rw, GetByteCodesMsg, &GetByteCodesPacket{
+		ID:     id,
+		Hashes: hashes,
+		Bytes:  bytes,
+	})
+}
+
+// RequestTrieNodes fetches a batch of account or storage trie nodes rooted in
+// a specific state trie.
+func (p *Peer) RequestTrieNodes(id uint64, root common.Hash, paths []TrieNodePathSet, bytes uint64) error {
+	p.logger.Trace("Fetching set of trie nodes", "reqid", id, "root", root, "pathsets", len(paths), "bytes", common.StorageSize(bytes))
+	return p2p.Send(p.rw, GetTrieNodesMsg, &GetTrieNodesPacket{
+		ID:    id,
+		Root:  root,
+		Paths: paths,
+		Bytes: bytes,
+	})
+}
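To illustrate the request half of this API, here is a hedged sketch of issuing an account range query; `peer`, `stateRoot` and the `math/rand`/`common` imports are assumed to be in scope, and the reply arrives asynchronously as an `*AccountRangePacket` carrying the same ID, surfaced through `Backend.Handle`:

```go
// Ask a connected peer for the first chunk of the account trie. The request
// ID is caller-chosen; responses are matched back to it by the runloop.
reqID := uint64(rand.Int63())
if err := peer.RequestAccountRange(
	reqID,
	stateRoot,     // root of the account trie to serve (assumed in scope)
	common.Hash{}, // origin: start of the account keyspace
	common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), // limit: end of the keyspace
	512*1024, // soft cap on the response size, in bytes
); err != nil {
	peer.Log().Debug("Failed to request account range", "err", err)
}
```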
diff --git a/eth/protocols/snap/protocol.go b/eth/protocols/snap/protocol.go
new file mode 100644
index 000000000000..a1e4349691b4
--- /dev/null
+++ b/eth/protocols/snap/protocol.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/state/snapshot"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// Constants to match up protocol versions and messages
+const (
+	snap1 = 1
+)
+
+// protocolName is the official short name of the `snap` protocol used during
+// devp2p capability negotiation.
+const protocolName = "snap"
+
+// protocolVersions are the supported versions of the `snap` protocol (first
+// is primary).
+var protocolVersions = []uint{snap1}
+
+// protocolLengths are the number of implemented messages corresponding to
+// different protocol versions.
+var protocolLengths = map[uint]uint64{snap1: 8}
+
+// maxMessageSize is the maximum cap on the size of a protocol message.
+const maxMessageSize = 10 * 1024 * 1024
+
+const (
+	GetAccountRangeMsg  = 0x00
+	AccountRangeMsg     = 0x01
+	GetStorageRangesMsg = 0x02
+	StorageRangesMsg    = 0x03
+	GetByteCodesMsg     = 0x04
+	ByteCodesMsg        = 0x05
+	GetTrieNodesMsg     = 0x06
+	TrieNodesMsg        = 0x07
+)
+
+var (
+	errMsgTooLarge    = errors.New("message too long")
+	errDecode         = errors.New("invalid message")
+	errInvalidMsgCode = errors.New("invalid message code")
+	errBadRequest     = errors.New("bad request")
+)
+
+// Packet represents a p2p message in the `snap` protocol.
+type Packet interface {
+	Name() string // Name returns a string corresponding to the message type.
+	Kind() byte   // Kind returns the message type.
+}
+
+// GetAccountRangePacket represents an account query.
+type GetAccountRangePacket struct {
+	ID     uint64      // Request ID to match up responses with
+	Root   common.Hash // Root hash of the account trie to serve
+	Origin common.Hash // Hash of the first account to retrieve
+	Limit  common.Hash // Hash of the last account to retrieve
+	Bytes  uint64      // Soft limit at which to stop returning data
+}
+
+// AccountRangePacket represents an account query response.
+type AccountRangePacket struct {
+	ID       uint64         // ID of the request this is a response for
+	Accounts []*AccountData // List of consecutive accounts from the trie
+	Proof    [][]byte       // List of trie nodes proving the account range
+}
+
+// AccountData represents a single account in a query response.
+type AccountData struct {
+	Hash common.Hash  // Hash of the account
+	Body rlp.RawValue // Account body in slim format
+}
+
+// Unpack retrieves the accounts from the range packet and converts from slim
+// wire representation to consensus format. The returned data is RLP encoded
+// since it's expected to be serialized to disk without further interpretation.
+//
+// Note, this method does a round of RLP decoding and reencoding, so only use it
+// once and cache the results if need be. Ideally discard the packet afterwards
+// to not double the memory use.
+func (p *AccountRangePacket) Unpack() ([]common.Hash, [][]byte, error) {
+	var (
+		hashes   = make([]common.Hash, len(p.Accounts))
+		accounts = make([][]byte, len(p.Accounts))
+	)
+	for i, acc := range p.Accounts {
+		val, err := snapshot.FullAccountRLP(acc.Body)
+		if err != nil {
+			return nil, nil, fmt.Errorf("invalid account %x: %v", acc.Body, err)
+		}
+		hashes[i], accounts[i] = acc.Hash, val
+	}
+	return hashes, accounts, nil
+}
+
+// GetStorageRangesPacket represents a storage slot query.
+type GetStorageRangesPacket struct {
+	ID       uint64        // Request ID to match up responses with
+	Root     common.Hash   // Root hash of the account trie to serve
+	Accounts []common.Hash // Account hashes of the storage tries to serve
+	Origin   []byte        // Hash of the first storage slot to retrieve (large contract mode)
+	Limit    []byte        // Hash of the last storage slot to retrieve (large contract mode)
+	Bytes    uint64        // Soft limit at which to stop returning data
+}
+
+// StorageRangesPacket represents a storage slot query response.
+type StorageRangesPacket struct {
+	ID    uint64           // ID of the request this is a response for
+	Slots [][]*StorageData // Lists of consecutive storage slots for the requested accounts
+	Proof [][]byte         // Merkle proofs for the *last* slot range, if it's incomplete
+}
+
+// StorageData represents a single storage slot in a query response.
+type StorageData struct { + Hash common.Hash // Hash of the storage slot + Body []byte // Data content of the slot +} + +// Unpack retrieves the storage slots from the range packet and returns them in +// a split flat format that's more consistent with the internal data structures. +func (p *StorageRangesPacket) Unpack() ([][]common.Hash, [][][]byte) { + var ( + hashset = make([][]common.Hash, len(p.Slots)) + slotset = make([][][]byte, len(p.Slots)) + ) + for i, slots := range p.Slots { + hashset[i] = make([]common.Hash, len(slots)) + slotset[i] = make([][]byte, len(slots)) + for j, slot := range slots { + hashset[i][j] = slot.Hash + slotset[i][j] = slot.Body + } + } + return hashset, slotset +} + +// GetByteCodesPacket represents a contract bytecode query. +type GetByteCodesPacket struct { + ID uint64 // Request ID to match up responses with + Hashes []common.Hash // Code hashes to retrieve the code for + Bytes uint64 // Soft limit at which to stop returning data +} + +// ByteCodesPacket represents a contract bytecode query response. +type ByteCodesPacket struct { + ID uint64 // ID of the request this is a response for + Codes [][]byte // Requested contract bytecodes +} + +// GetTrieNodesPacket represents a state trie node query. +type GetTrieNodesPacket struct { + ID uint64 // Request ID to match up responses with + Root common.Hash // Root hash of the account trie to serve + Paths []TrieNodePathSet // Trie node hashes to retrieve the nodes for + Bytes uint64 // Soft limit at which to stop returning data +} + +// TrieNodePathSet is a list of trie node paths to retrieve. A naive way to +// represent trie nodes would be a simple list of `account || storage` path +// segments concatenated, but that would be very wasteful on the network. +// +// Instead, this array special cases the first element as the path in the +// account trie and the remaining elements as paths in the storage trie. To +// address an account node, the slice should have a length of 1 consisting +// of only the account path. There's no need to be able to address both an +// account node and a storage node in the same request as it cannot happen +// that a slot is accessed before the account path is fully expanded. +type TrieNodePathSet [][]byte + +// TrieNodesPacket represents a state trie node query response. 
+type TrieNodesPacket struct {
+	ID    uint64   // ID of the request this is a response for
+	Nodes [][]byte // Requested state trie nodes
+}
+
+func (*GetAccountRangePacket) Name() string { return "GetAccountRange" }
+func (*GetAccountRangePacket) Kind() byte   { return GetAccountRangeMsg }
+
+func (*AccountRangePacket) Name() string { return "AccountRange" }
+func (*AccountRangePacket) Kind() byte   { return AccountRangeMsg }
+
+func (*GetStorageRangesPacket) Name() string { return "GetStorageRanges" }
+func (*GetStorageRangesPacket) Kind() byte   { return GetStorageRangesMsg }
+
+func (*StorageRangesPacket) Name() string { return "StorageRanges" }
+func (*StorageRangesPacket) Kind() byte   { return StorageRangesMsg }
+
+func (*GetByteCodesPacket) Name() string { return "GetByteCodes" }
+func (*GetByteCodesPacket) Kind() byte   { return GetByteCodesMsg }
+
+func (*ByteCodesPacket) Name() string { return "ByteCodes" }
+func (*ByteCodesPacket) Kind() byte   { return ByteCodesMsg }
+
+func (*GetTrieNodesPacket) Name() string { return "GetTrieNodes" }
+func (*GetTrieNodesPacket) Kind() byte   { return GetTrieNodesMsg }
+
+func (*TrieNodesPacket) Name() string { return "TrieNodes" }
+func (*TrieNodesPacket) Kind() byte   { return TrieNodesMsg }
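The TrieNodePathSet encoding described above is compact but easy to get wrong, so here is a hedged sketch of assembling one (all values illustrative; `root` and `maxRequestSize` are assumed in scope within the package):

```go
// All values illustrative; paths are compact-encoded trie paths.
var (
	accPath []byte // path of a node in the account trie
	slotA   []byte // path of a node in one account's storage trie
	slotB   []byte // another storage trie path under the same account
)
accountNode := TrieNodePathSet{accPath}                // len == 1: addresses an account trie node
storageNodes := TrieNodePathSet{accPath, slotA, slotB} // len > 1: addresses storage trie nodes

req := &GetTrieNodesPacket{
	ID:    1,
	Root:  root, // state root being served, assumed in scope
	Paths: []TrieNodePathSet{accountNode, storageNodes},
	Bytes: 512 * 1024,
}
_ = req
```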
diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go
new file mode 100644
index 000000000000..679b328283b0
--- /dev/null
+++ b/eth/protocols/snap/sync.go
@@ -0,0 +1,2481 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snap
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"math/big"
+	"math/rand"
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/core/state"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/event"
+	"github.com/ethereum/go-ethereum/light"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/rlp"
+	"github.com/ethereum/go-ethereum/trie"
+	"golang.org/x/crypto/sha3"
+)
+
+var (
+	// emptyRoot is the known root hash of an empty trie.
+	emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
+
+	// emptyCode is the known hash of the empty EVM bytecode.
+	emptyCode = crypto.Keccak256Hash(nil)
+)
+
+const (
+	// maxRequestSize is the maximum number of bytes to request from a remote peer.
+	maxRequestSize = 512 * 1024
+
+	// maxStorageSetRequestCount is the maximum number of contracts to request the
+	// storage of in a single query. If this number is too low, we're not filling
+	// responses fully and waste round trip times. If it's too high, we're capping
+	// responses and waste bandwidth.
+	maxStorageSetRequestCount = maxRequestSize / 1024
+
+	// maxCodeRequestCount is the maximum number of bytecode blobs to request in a
+	// single query. If this number is too low, we're not filling responses fully
+	// and waste round trip times. If it's too high, we're capping responses and
+	// waste bandwidth.
+	//
+	// Deployed bytecodes are currently capped at 24KB, so the minimum request
+	// size should be maxRequestSize / 24K. Assuming that most contracts do not
+	// come close to that, requesting 4x should be a good approximation.
+	maxCodeRequestCount = maxRequestSize / (24 * 1024) * 4
+
+	// maxTrieRequestCount is the maximum number of trie node blobs to request in
+	// a single query. If this number is too low, we're not filling responses fully
+	// and waste round trip times. If it's too high, we're capping responses and
+	// waste bandwidth.
+	maxTrieRequestCount = 512
+
+	// requestTimeout is the maximum time a peer is allowed to spend on serving
+	// a single network request.
+	requestTimeout = 10 * time.Second // TODO(karalabe): Make it dynamic ala fast-sync?
+
+	// accountConcurrency is the number of chunks to split the account trie into
+	// to allow concurrent retrievals.
+	accountConcurrency = 16
+
+	// storageConcurrency is the number of chunks to split a large contract
+	// storage trie into to allow concurrent retrievals.
+	storageConcurrency = 16
+)
+
+// accountRequest tracks a pending account range request to ensure responses are
+// to actual requests and to validate any security constraints.
+//
+// Concurrency note: account requests and responses are handled concurrently from
+// the main runloop to allow Merkle proof verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type accountRequest struct {
+	peer string // Peer to which this request is assigned
+	id   uint64 // Request ID of this request
+
+	cancel  chan struct{} // Channel to track sync cancellation
+	timeout *time.Timer   // Timer to track delivery timeout
+	stale   chan struct{} // Channel to signal the request was dropped
+
+	origin common.Hash // First account requested to allow continuation checks
+	limit  common.Hash // Last account requested to allow non-overlapping chunking
+
+	task *accountTask // Task which this request is filling (only access fields through the runloop!!)
+}
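As a preview of how the `timeout` and failure channels are used by the task assigners further below, a minimal sketch of the timeout path (names as defined in this file):

```go
req.timeout = time.AfterFunc(requestTimeout, func() {
	// Non-blocking send: if a response won the race, the runloop may no
	// longer be draining this channel, and the timer must not wedge it.
	select {
	case s.accountReqFails <- req:
	default:
	}
})
```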
+
+// accountResponse is an already Merkle-verified remote response to an account
+// range request. It contains the subtrie for the requested account range and
+// the database that's going to be filled with the internal nodes on commit.
+type accountResponse struct {
+	task *accountTask // Task which this request is filling
+
+	hashes   []common.Hash    // Account hashes in the returned range
+	accounts []*state.Account // Expanded accounts in the returned range
+
+	nodes ethdb.KeyValueStore // Database containing the reconstructed trie nodes
+	trie  *trie.Trie          // Reconstructed trie to reject incomplete account paths
+
+	bounds   map[common.Hash]struct{} // Boundary nodes to avoid persisting incomplete accounts
+	overflow *light.NodeSet           // Overflow nodes to avoid persisting across chunk boundaries
+
+	cont bool // Whether the account range has a continuation
+}
+
+// bytecodeRequest tracks a pending bytecode request to ensure responses are to
+// actual requests and to validate any security constraints.
+//
+// Concurrency note: bytecode requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type bytecodeRequest struct {
+	peer string // Peer to which this request is assigned
+	id   uint64 // Request ID of this request
+
+	cancel  chan struct{} // Channel to track sync cancellation
+	timeout *time.Timer   // Timer to track delivery timeout
+	stale   chan struct{} // Channel to signal the request was dropped
+
+	hashes []common.Hash // Bytecode hashes to validate responses
+	task   *accountTask  // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// bytecodeResponse is an already verified remote response to a bytecode request.
+type bytecodeResponse struct {
+	task *accountTask // Task which this request is filling
+
+	hashes []common.Hash // Hashes of the bytecode to avoid double hashing
+	codes  [][]byte      // Actual bytecodes to store into the database (nil = missing)
+}
+
+// storageRequest tracks a pending storage ranges request to ensure responses are
+// to actual requests and to validate any security constraints.
+//
+// Concurrency note: storage requests and responses are handled concurrently from
+// the main runloop to allow Merkle proof verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. tasks). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type storageRequest struct {
+	peer string // Peer to which this request is assigned
+	id   uint64 // Request ID of this request
+
+	cancel  chan struct{} // Channel to track sync cancellation
+	timeout *time.Timer   // Timer to track delivery timeout
+	stale   chan struct{} // Channel to signal the request was dropped
+
+	accounts []common.Hash // Account hashes to validate responses
+	roots    []common.Hash // Storage roots to validate responses
+
+	origin common.Hash // First storage slot requested to allow continuation checks
+	limit  common.Hash // Last storage slot requested to allow non-overlapping chunking
+
+	mainTask *accountTask // Task which this response belongs to (only access fields through the runloop!!)
+	subTask  *storageTask // Task which this response is filling (only access fields through the runloop!!)
+}
+
+// storageResponse is an already Merkle-verified remote response to a storage
+// range request. It contains the subtries for the requested storage ranges and
+// the databases that are going to be filled with the internal nodes on commit.
+type storageResponse struct {
+	mainTask *accountTask // Task which this response belongs to
+	subTask  *storageTask // Task which this response is filling
+
+	accounts []common.Hash // Account hashes requested, may be only partially filled
+	roots    []common.Hash // Storage roots requested, may be only partially filled
+
+	hashes [][]common.Hash       // Storage slot hashes in the returned range
+	slots  [][][]byte            // Storage slot values in the returned range
+	nodes  []ethdb.KeyValueStore // Database containing the reconstructed trie nodes
+	tries  []*trie.Trie          // Reconstructed tries to reject overflowed slots
+
+	// Fields relevant for the last account only
+	bounds   map[common.Hash]struct{} // Boundary nodes to avoid persisting (incomplete)
+	overflow *light.NodeSet           // Overflow nodes to avoid persisting across chunk boundaries
+	cont     bool                     // Whether the last storage range has a continuation
+}
+
+// trienodeHealRequest tracks a pending state trie request to ensure responses
+// are to actual requests and to validate any security constraints.
+//
+// Concurrency note: trie node requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type trienodeHealRequest struct {
+	peer string // Peer to which this request is assigned
+	id   uint64 // Request ID of this request
+
+	cancel  chan struct{} // Channel to track sync cancellation
+	timeout *time.Timer   // Timer to track delivery timeout
+	stale   chan struct{} // Channel to signal the request was dropped
+
+	hashes []common.Hash   // Trie node hashes to validate responses
+	paths  []trie.SyncPath // Trie node paths requested for rescheduling
+
+	task *healTask // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// trienodeHealResponse is an already verified remote response to a trie node request.
+type trienodeHealResponse struct {
+	task *healTask // Task which this request is filling
+
+	hashes []common.Hash   // Hashes of the trie nodes to avoid double hashing
+	paths  []trie.SyncPath // Trie node paths requested for rescheduling missing ones
+	nodes  [][]byte        // Actual trie nodes to store into the database (nil = missing)
+}
+
+// bytecodeHealRequest tracks a pending bytecode request to ensure responses are to
+// actual requests and to validate any security constraints.
+//
+// Concurrency note: bytecode requests and responses are handled concurrently from
+// the main runloop to allow Keccak256 hash verifications on the peer's thread and
+// to drop on invalid response. The request struct must contain all the data to
+// construct the response without accessing runloop internals (i.e. task). That
+// is only included to allow the runloop to match a response to the task being
+// synced without having yet another set of maps.
+type bytecodeHealRequest struct {
+	peer string // Peer to which this request is assigned
+	id   uint64 // Request ID of this request
+
+	cancel  chan struct{} // Channel to track sync cancellation
+	timeout *time.Timer   // Timer to track delivery timeout
+	stale   chan struct{} // Channel to signal the request was dropped
+
+	hashes []common.Hash // Bytecode hashes to validate responses
+	task   *healTask     // Task which this request is filling (only access fields through the runloop!!)
+}
+
+// bytecodeHealResponse is an already verified remote response to a bytecode request.
+type bytecodeHealResponse struct {
+	task *healTask // Task which this request is filling
+
+	hashes []common.Hash // Hashes of the bytecode to avoid double hashing
+	codes  [][]byte      // Actual bytecodes to store into the database (nil = missing)
+}
+
+// accountTask represents the sync task for a chunk of the account snapshot.
+type accountTask struct {
+	// These fields get serialized to leveldb on shutdown
+	Next     common.Hash                    // Next account to sync in this interval
+	Last     common.Hash                    // Last account to sync in this interval
+	SubTasks map[common.Hash][]*storageTask // Storage intervals needing fetching for large contracts
+
+	// These fields are internals used during runtime
+	req  *accountRequest  // Pending request to fill this task
+	res  *accountResponse // Validated response filling this task
+	pend int              // Number of pending subtasks for this round
+
+	needCode  []bool // Flags whether the filling accounts need code retrieval
+	needState []bool // Flags whether the filling accounts need storage retrieval
+	needHeal  []bool // Flags whether the filling accounts' state was chunked and needs healing
+
+	codeTasks  map[common.Hash]struct{}    // Code hashes that need retrieval
+	stateTasks map[common.Hash]common.Hash // Account hashes->roots that need full state retrieval
+
+	done bool // Flag whether the task can be removed
+}
+
+// storageTask represents the sync task for a chunk of the storage snapshot.
+type storageTask struct {
+	Next common.Hash // Next account to sync in this interval
+	Last common.Hash // Last account to sync in this interval
+
+	// These fields are internals used during runtime
+	root common.Hash     // Storage root hash for this instance
+	req  *storageRequest // Pending request to fill this task
+	done bool            // Flag whether the task can be removed
+}
+
+// healTask represents the sync task for healing the snap-synced chunk boundaries.
+type healTask struct {
+	scheduler *trie.Sync // State trie sync scheduler defining the tasks
+
+	trieTasks map[common.Hash]trie.SyncPath // Set of trie node tasks currently queued for retrieval
+	codeTasks map[common.Hash]struct{}      // Set of byte code tasks currently queued for retrieval
+}
+
+// syncProgress is a database entry to allow suspending and resuming a snapshot state
+// sync. As opposed to full and fast sync, there is no way to restart a suspended
+// snap sync without prior knowledge of the suspension point.
+type syncProgress struct {
+	Tasks []*accountTask // The suspended account tasks (contract tasks within)
+
+	// Status report during syncing phase
+	AccountSynced  uint64             // Number of accounts downloaded
+	AccountBytes   common.StorageSize // Number of account trie bytes persisted to disk
+	BytecodeSynced uint64             // Number of bytecodes downloaded
+	BytecodeBytes  common.StorageSize // Number of bytecode bytes downloaded
+	StorageSynced  uint64             // Number of storage slots downloaded
+	StorageBytes   common.StorageSize // Number of storage trie bytes persisted to disk
+
+	// Status report during healing phase
+	TrienodeHealSynced uint64             // Number of state trie nodes downloaded
+	TrienodeHealBytes  common.StorageSize // Number of state trie bytes persisted to disk
+	TrienodeHealDups   uint64             // Number of state trie nodes already processed
+	TrienodeHealNops   uint64             // Number of state trie nodes not requested
+	BytecodeHealSynced uint64             // Number of bytecodes downloaded
+	BytecodeHealBytes  common.StorageSize // Number of bytecodes persisted to disk
+	BytecodeHealDups   uint64             // Number of bytecodes already processed
+	BytecodeHealNops   uint64             // Number of bytecodes not requested
+}
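A minimal sketch of the suspend/resume round-trip this struct enables (the real flow lives in loadSyncStatus/saveSyncStatus below; the field value is illustrative):

```go
blob, err := json.Marshal(&syncProgress{AccountSynced: 1024})
if err != nil {
	panic(err) // marshaling a plain struct should not fail
}
var prog syncProgress
if err := json.Unmarshal(blob, &prog); err != nil {
	log.Error("Failed to decode snap sync status", "err", err)
}
```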
+
+// Syncer is an Ethereum account and storage trie syncer based on snapshots and
+// the snap protocol. Its purpose is to download all the accounts and storage
+// slots from remote peers and reassemble chunks of the state trie, on top of
+// which a state sync can be run to fix any gaps / overlaps.
+//
+// Every network request has a variety of failure events:
+//   - The peer disconnects after task assignment, failing to send the request
+//   - The peer disconnects after sending the request, before delivering on it
+//   - The peer remains connected, but does not deliver a response in time
+//   - The peer delivers a stale response after a previous timeout
+//   - The peer delivers a refusal to serve the requested state
+type Syncer struct {
+	db    ethdb.KeyValueStore // Database to store the trie nodes into (and dedup)
+	bloom *trie.SyncBloom     // Bloom filter to deduplicate nodes for state fixup
+
+	root   common.Hash    // Current state trie root being synced
+	tasks  []*accountTask // Current account task set being synced
+	healer *healTask      // Current state healing task being executed
+	update chan struct{}  // Notification channel for possible sync progression
+
+	peers    map[string]*Peer // Currently active peers to download from
+	peerJoin *event.Feed      // Event feed to react to peers joining
+	peerDrop *event.Feed      // Event feed to react to peers dropping
+
+	// Request tracking during syncing phase
+	statelessPeers map[string]struct{} // Peers that failed to deliver state data
+	accountIdlers  map[string]struct{} // Peers that aren't serving account requests
+	bytecodeIdlers map[string]struct{} // Peers that aren't serving bytecode requests
+	storageIdlers  map[string]struct{} // Peers that aren't serving storage requests
+
+	accountReqs  map[uint64]*accountRequest  // Account requests currently running
+	bytecodeReqs map[uint64]*bytecodeRequest // Bytecode requests currently running
+	storageReqs  map[uint64]*storageRequest  // Storage requests currently running
+
+	accountReqFails  chan *accountRequest  // Failed account range requests to revert
+	bytecodeReqFails chan *bytecodeRequest // Failed bytecode requests to revert
+	storageReqFails  chan *storageRequest  // Failed storage requests to revert
+
+	accountResps  chan *accountResponse  // Account sub-tries to integrate into the database
+	bytecodeResps chan *bytecodeResponse // Bytecodes to integrate into the database
+	storageResps  chan *storageResponse  // Storage sub-tries to integrate into the database
+
+	accountSynced  uint64             // Number of accounts downloaded
+	accountBytes   common.StorageSize // Number of account trie bytes persisted to disk
+	bytecodeSynced uint64             // Number of bytecodes downloaded
+	bytecodeBytes  common.StorageSize // Number of bytecode bytes downloaded
+	storageSynced  uint64             // Number of storage slots downloaded
+	storageBytes   common.StorageSize // Number of storage trie bytes persisted to disk
+
+	// Request tracking during healing phase
+	trienodeHealIdlers map[string]struct{} // Peers that aren't serving trie node requests
+	bytecodeHealIdlers map[string]struct{} // Peers that aren't serving bytecode requests
+
+	trienodeHealReqs map[uint64]*trienodeHealRequest // Trie node requests currently running
+	bytecodeHealReqs map[uint64]*bytecodeHealRequest // Bytecode requests currently running
+
+	trienodeHealReqFails chan *trienodeHealRequest // Failed trienode requests to revert
+	bytecodeHealReqFails chan *bytecodeHealRequest // Failed bytecode requests to revert
+
+	trienodeHealResps chan *trienodeHealResponse // Trie nodes to integrate into the database
+	bytecodeHealResps chan *bytecodeHealResponse // Bytecodes to integrate into the database
+
+	trienodeHealSynced uint64             // Number of state trie nodes downloaded
+	trienodeHealBytes  common.StorageSize // Number of state trie bytes persisted to disk
+	trienodeHealDups   uint64             // Number of state trie nodes already processed
+	trienodeHealNops   uint64             // Number of state trie nodes not requested
+	bytecodeHealSynced uint64             // Number of bytecodes downloaded
+	bytecodeHealBytes  common.StorageSize // Number of bytecodes persisted to disk
+	bytecodeHealDups   uint64             // Number of bytecodes already processed
+	bytecodeHealNops   uint64             // Number of bytecodes not requested
+
+	startTime time.Time   // Time instance when snapshot sync started
+	startAcc  common.Hash // Account hash where sync started from
+	logTime   time.Time   // Time instance when status was last reported
+
+	pend sync.WaitGroup // Tracks network request goroutines for graceful shutdown
+	lock sync.RWMutex   // Protects fields that can change outside of sync (peers, reqs, root)
+}
+
+func NewSyncer(db ethdb.KeyValueStore, bloom *trie.SyncBloom) *Syncer {
+	return &Syncer{
+		db:    db,
+		bloom: bloom,
+
+		peers:    make(map[string]*Peer),
+		peerJoin: new(event.Feed),
+		peerDrop: new(event.Feed),
+		update:   make(chan struct{}, 1),
+
+		accountIdlers:  make(map[string]struct{}),
+		storageIdlers:  make(map[string]struct{}),
+		bytecodeIdlers: make(map[string]struct{}),
+
+		accountReqs:      make(map[uint64]*accountRequest),
+		storageReqs:      make(map[uint64]*storageRequest),
+		bytecodeReqs:     make(map[uint64]*bytecodeRequest),
+		accountReqFails:  make(chan *accountRequest),
+		storageReqFails:  make(chan *storageRequest),
+		bytecodeReqFails: make(chan *bytecodeRequest),
+		accountResps:     make(chan *accountResponse),
+		storageResps:     make(chan *storageResponse),
+		bytecodeResps:    make(chan *bytecodeResponse),
+
+		trienodeHealIdlers: make(map[string]struct{}),
+		bytecodeHealIdlers: make(map[string]struct{}),
+
+		trienodeHealReqs:     make(map[uint64]*trienodeHealRequest),
+		bytecodeHealReqs:     make(map[uint64]*bytecodeHealRequest),
+		trienodeHealReqFails: make(chan *trienodeHealRequest),
+		bytecodeHealReqFails: make(chan *bytecodeHealRequest),
+		trienodeHealResps:    make(chan *trienodeHealResponse),
+		bytecodeHealResps:    make(chan *bytecodeHealResponse),
+	}
+}
+
+// Register injects a new data source into the syncer's peerset.
+func (s *Syncer) Register(peer *Peer) error {
+	// Make sure the peer is not registered yet
+	s.lock.Lock()
+	if _, ok := s.peers[peer.id]; ok {
+		log.Error("Snap peer already registered", "id", peer.id)
+
+		s.lock.Unlock()
+		return errors.New("already registered")
+	}
+	s.peers[peer.id] = peer
+
+	// Mark the peer as idle, even if no sync is running
+	s.accountIdlers[peer.id] = struct{}{}
+	s.storageIdlers[peer.id] = struct{}{}
+	s.bytecodeIdlers[peer.id] = struct{}{}
+	s.trienodeHealIdlers[peer.id] = struct{}{}
+	s.bytecodeHealIdlers[peer.id] = struct{}{}
+	s.lock.Unlock()
+
+	// Notify any active syncs that a new peer can be assigned data
+	s.peerJoin.Send(peer.id)
+	return nil
+}
+
+// Unregister removes a data source from the syncer's peerset.
+func (s *Syncer) Unregister(id string) error {
+	// Remove all traces of the peer from the registry
+	s.lock.Lock()
+	if _, ok := s.peers[id]; !ok {
+		log.Error("Snap peer not registered", "id", id)
+
+		s.lock.Unlock()
+		return errors.New("not registered")
+	}
+	delete(s.peers, id)
+
+	// Remove status markers, even if no sync is running
+	delete(s.statelessPeers, id)
+
+	delete(s.accountIdlers, id)
+	delete(s.storageIdlers, id)
+	delete(s.bytecodeIdlers, id)
+	delete(s.trienodeHealIdlers, id)
+	delete(s.bytecodeHealIdlers, id)
+	s.lock.Unlock()
+
+	// Notify any active syncs that pending requests need to be reverted
+	s.peerDrop.Send(id)
+	return nil
+}
+
+// Sync starts (or resumes a previous) sync cycle to iterate over a state trie
+// with the given root and reconstruct the nodes based on the snapshot leaves.
+// Previously downloaded segments will not be redownloaded or fixed, rather any
+// errors will be healed after the leaves are fully accumulated.
+func (s *Syncer) Sync(root common.Hash, cancel chan struct{}) error {
+	// Move the trie root from any previous value, revert stateless markers for
+	// any peers and initialize the syncer if it was not yet run
+	s.lock.Lock()
+	s.root = root
+	s.healer = &healTask{
+		scheduler: state.NewStateSync(root, s.db, s.bloom),
+		trieTasks: make(map[common.Hash]trie.SyncPath),
+		codeTasks: make(map[common.Hash]struct{}),
+	}
+	s.statelessPeers = make(map[string]struct{})
+	s.lock.Unlock()
+
+	if s.startTime == (time.Time{}) {
+		s.startTime = time.Now()
+	}
+	// Retrieve the previous sync status from LevelDB and abort if already synced
+	s.loadSyncStatus()
+	if len(s.tasks) == 0 && s.healer.scheduler.Pending() == 0 {
+		log.Debug("Snapshot sync already completed")
+		return nil
+	}
+	defer func() { // Persist any progress, independent of failure
+		for _, task := range s.tasks {
+			s.forwardAccountTask(task)
+		}
+		s.cleanAccountTasks()
+		s.saveSyncStatus()
+	}()
+
+	log.Debug("Starting snapshot sync cycle", "root", root)
+	defer s.report(true)
+
+	// Whether sync completed or not, disregard any future packets
+	defer func() {
+		log.Debug("Terminating snapshot sync cycle", "root", root)
+		s.lock.Lock()
+		s.accountReqs = make(map[uint64]*accountRequest)
+		s.storageReqs = make(map[uint64]*storageRequest)
+		s.bytecodeReqs = make(map[uint64]*bytecodeRequest)
+		s.trienodeHealReqs = make(map[uint64]*trienodeHealRequest)
+		s.bytecodeHealReqs = make(map[uint64]*bytecodeHealRequest)
+		s.lock.Unlock()
+	}()
+	// Keep scheduling sync tasks
+	peerJoin := make(chan string, 16)
+	peerJoinSub := s.peerJoin.Subscribe(peerJoin)
+	defer peerJoinSub.Unsubscribe()
+
+	peerDrop := make(chan string, 16)
+	peerDropSub := s.peerDrop.Subscribe(peerDrop)
+	defer peerDropSub.Unsubscribe()
+
+	for {
+		// Remove all completed tasks and terminate sync if everything's done
+		s.cleanStorageTasks()
+		s.cleanAccountTasks()
+		if len(s.tasks) == 0 && s.healer.scheduler.Pending() == 0 {
+			return nil
+		}
+		// Assign all the data retrieval tasks to any free peers
+		s.assignAccountTasks(cancel)
+		s.assignBytecodeTasks(cancel)
+		s.assignStorageTasks(cancel)
+		if len(s.tasks) == 0 {
+			// Sync phase done, run heal phase
+			s.assignTrienodeHealTasks(cancel)
+			s.assignBytecodeHealTasks(cancel)
+		}
+		// Wait for something to happen
+		select {
+		case <-s.update:
+			// Something happened (new peer, delivery, timeout), recheck tasks
+		case <-peerJoin:
+			// A new peer joined, try to schedule it new tasks
+		case id := <-peerDrop:
+			s.revertRequests(id)
+		case <-cancel:
+			return nil
+
+		case req := <-s.accountReqFails:
+			s.revertAccountRequest(req)
+		case req := <-s.bytecodeReqFails:
+			s.revertBytecodeRequest(req)
+		case req := <-s.storageReqFails:
+			s.revertStorageRequest(req)
+		case req := <-s.trienodeHealReqFails:
+			s.revertTrienodeHealRequest(req)
+		case req := <-s.bytecodeHealReqFails:
+			s.revertBytecodeHealRequest(req)
+
+		case res := <-s.accountResps:
+			s.processAccountResponse(res)
+		case res := <-s.bytecodeResps:
+			s.processBytecodeResponse(res)
+		case res := <-s.storageResps:
+			s.processStorageResponse(res)
+		case res := <-s.trienodeHealResps:
+			s.processTrienodeHealResponse(res)
+		case res := <-s.bytecodeHealResps:
+			s.processBytecodeHealResponse(res)
+		}
+		// Report stats if something meaningful happened
+		s.report(false)
+	}
+}
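A hedged sketch of driving the syncer end to end; `db`, `bloom`, `p` (a *Peer) and the pivot `root` are assumed to exist in the caller:

```go
s := NewSyncer(db, bloom)
if err := s.Register(p); err != nil {
	log.Error("Failed to register snap peer", "err", err)
}
cancel := make(chan struct{}) // close to abort the cycle early
if err := s.Sync(root, cancel); err != nil {
	log.Error("Snap sync cycle failed", "err", err)
}
```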
+
+// loadSyncStatus retrieves a previously aborted sync status from the database,
+// or generates a fresh one if none is available.
+func (s *Syncer) loadSyncStatus() {
+	var progress syncProgress
+
+	if status := rawdb.ReadSnapshotSyncStatus(s.db); status != nil {
+		if err := json.Unmarshal(status, &progress); err != nil {
+			log.Error("Failed to decode snap sync status", "err", err)
+		} else {
+			for _, task := range progress.Tasks {
+				log.Debug("Scheduled account sync task", "from", task.Next, "last", task.Last)
+			}
+			s.tasks = progress.Tasks
+
+			s.accountSynced = progress.AccountSynced
+			s.accountBytes = progress.AccountBytes
+			s.bytecodeSynced = progress.BytecodeSynced
+			s.bytecodeBytes = progress.BytecodeBytes
+			s.storageSynced = progress.StorageSynced
+			s.storageBytes = progress.StorageBytes
+
+			s.trienodeHealSynced = progress.TrienodeHealSynced
+			s.trienodeHealBytes = progress.TrienodeHealBytes
+			s.bytecodeHealSynced = progress.BytecodeHealSynced
+			s.bytecodeHealBytes = progress.BytecodeHealBytes
+			return
+		}
+	}
+	// Either we've failed to decode the previous state, or there was none.
+	// Start a fresh sync by chunking up the account range and scheduling
+	// them for retrieval.
+ s.tasks = nil + s.accountSynced, s.accountBytes = 0, 0 + s.bytecodeSynced, s.bytecodeBytes = 0, 0 + s.storageSynced, s.storageBytes = 0, 0 + s.trienodeHealSynced, s.trienodeHealBytes = 0, 0 + s.bytecodeHealSynced, s.bytecodeHealBytes = 0, 0 + + var next common.Hash + step := new(big.Int).Sub( + new(big.Int).Div( + new(big.Int).Exp(common.Big2, common.Big256, nil), + big.NewInt(accountConcurrency), + ), common.Big1, + ) + for i := 0; i < accountConcurrency; i++ { + last := common.BigToHash(new(big.Int).Add(next.Big(), step)) + if i == accountConcurrency-1 { + // Make sure we don't overflow if the step is not a proper divisor + last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + } + s.tasks = append(s.tasks, &accountTask{ + Next: next, + Last: last, + SubTasks: make(map[common.Hash][]*storageTask), + }) + log.Debug("Created account sync task", "from", next, "last", last) + next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1)) + } +} + +// saveSyncStatus marshals the remaining sync tasks into leveldb. +func (s *Syncer) saveSyncStatus() { + progress := &syncProgress{ + Tasks: s.tasks, + AccountSynced: s.accountSynced, + AccountBytes: s.accountBytes, + BytecodeSynced: s.bytecodeSynced, + BytecodeBytes: s.bytecodeBytes, + StorageSynced: s.storageSynced, + StorageBytes: s.storageBytes, + TrienodeHealSynced: s.trienodeHealSynced, + TrienodeHealBytes: s.trienodeHealBytes, + BytecodeHealSynced: s.bytecodeHealSynced, + BytecodeHealBytes: s.bytecodeHealBytes, + } + status, err := json.Marshal(progress) + if err != nil { + panic(err) // This can only fail during implementation + } + rawdb.WriteSnapshotSyncStatus(s.db, status) +} + +// cleanAccountTasks removes account range retrieval tasks that have already been +// completed. +func (s *Syncer) cleanAccountTasks() { + for i := 0; i < len(s.tasks); i++ { + if s.tasks[i].done { + s.tasks = append(s.tasks[:i], s.tasks[i+1:]...) + i-- + } + } +} + +// cleanStorageTasks iterates over all the account tasks and storage sub-tasks +// within, cleaning any that have been completed. +func (s *Syncer) cleanStorageTasks() { + for _, task := range s.tasks { + for account, subtasks := range task.SubTasks { + // Remove storage range retrieval tasks that completed + for j := 0; j < len(subtasks); j++ { + if subtasks[j].done { + subtasks = append(subtasks[:j], subtasks[j+1:]...) + j-- + } + } + if len(subtasks) > 0 { + task.SubTasks[account] = subtasks + continue + } + // If all storage chunks are done, mark the account as done too + for j, hash := range task.res.hashes { + if hash == account { + task.needState[j] = false + } + } + delete(task.SubTasks, account) + task.pend-- + + // If this was the last pending task, forward the account task + if task.pend == 0 { + s.forwardAccountTask(task) + } + } + } +} + +// assignAccountTasks attempts to match idle peers to pending account range +// retrievals. +func (s *Syncer) assignAccountTasks(cancel chan struct{}) { + s.lock.Lock() + defer s.lock.Unlock() + + // If there are no idle peers, short circuit assignment + if len(s.accountIdlers) == 0 { + return + } + // Iterate over all the tasks and try to find a pending one + for _, task := range s.tasks { + // Skip any tasks already filling + if task.req != nil || task.res != nil { + continue + } + // Task pending retrieval, try to find an idle peer. If no such peer + // exists, we probably assigned tasks for all (or they are stateless). + // Abort the entire assignment mechanism. 
+		var idle string
+		for id := range s.accountIdlers {
+			// If the peer rejected a query in this sync cycle, don't bother asking
+			// again for anything, it's either out of sync or already pruned
+			if _, ok := s.statelessPeers[id]; ok {
+				continue
+			}
+			idle = id
+			break
+		}
+		if idle == "" {
+			return
+		}
+		// Matched a pending task to an idle peer, allocate a unique request id
+		var reqid uint64
+		for {
+			reqid = uint64(rand.Int63())
+			if reqid == 0 {
+				continue
+			}
+			if _, ok := s.accountReqs[reqid]; ok {
+				continue
+			}
+			break
+		}
+		// Generate the network query and send it to the peer
+		req := &accountRequest{
+			peer:   idle,
+			id:     reqid,
+			cancel: cancel,
+			stale:  make(chan struct{}),
+			origin: task.Next,
+			limit:  task.Last,
+			task:   task,
+		}
+		req.timeout = time.AfterFunc(requestTimeout, func() {
+			log.Debug("Account range request timed out")
+			select {
+			case s.accountReqFails <- req:
+			default:
+			}
+		})
+		s.accountReqs[reqid] = req
+		delete(s.accountIdlers, idle)
+
+		s.pend.Add(1)
+		go func(peer *Peer, root common.Hash) {
+			defer s.pend.Done()
+
+			// Attempt to send the remote request and revert if it fails
+			if err := peer.RequestAccountRange(reqid, root, req.origin, req.limit, maxRequestSize); err != nil {
+				peer.Log().Debug("Failed to request account range", "err", err)
+				select {
+				case s.accountReqFails <- req:
+				default:
+				}
+			}
+		}(s.peers[idle], s.root) // We're in the lock, peers[id] surely exists
+
+		// Inject the request into the task to block further assignments
+		task.req = req
+	}
+}
+
+// assignBytecodeTasks attempts to match idle peers to pending code retrievals.
+func (s *Syncer) assignBytecodeTasks(cancel chan struct{}) {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	// If there are no idle peers, short circuit assignment
+	if len(s.bytecodeIdlers) == 0 {
+		return
+	}
+	// Iterate over all the tasks and try to find a pending one
+	for _, task := range s.tasks {
+		// Skip any tasks not in the bytecode retrieval phase
+		if task.res == nil {
+			continue
+		}
+		// Skip tasks that are already retrieving (or done with) all codes
+		if len(task.codeTasks) == 0 {
+			continue
+		}
+		// Task pending retrieval, try to find an idle peer. If no such peer
+		// exists, we probably assigned tasks for all (or they are stateless).
+		// Abort the entire assignment mechanism.
+		var idle string
+		for id := range s.bytecodeIdlers {
+			// If the peer rejected a query in this sync cycle, don't bother asking
+			// again for anything, it's either out of sync or already pruned
+			if _, ok := s.statelessPeers[id]; ok {
+				continue
+			}
+			idle = id
+			break
+		}
+		if idle == "" {
+			return
+		}
+		// Matched a pending task to an idle peer, allocate a unique request id
+		var reqid uint64
+		for {
+			reqid = uint64(rand.Int63())
+			if reqid == 0 {
+				continue
+			}
+			if _, ok := s.bytecodeReqs[reqid]; ok {
+				continue
+			}
+			break
+		}
+		// Generate the network query and send it to the peer
+		hashes := make([]common.Hash, 0, maxCodeRequestCount)
+		for hash := range task.codeTasks {
+			delete(task.codeTasks, hash)
+			hashes = append(hashes, hash)
+			if len(hashes) >= maxCodeRequestCount {
+				break
+			}
+		}
+		req := &bytecodeRequest{
+			peer:   idle,
+			id:     reqid,
+			cancel: cancel,
+			stale:  make(chan struct{}),
+			hashes: hashes,
+			task:   task,
+		}
+		req.timeout = time.AfterFunc(requestTimeout, func() {
+			log.Debug("Bytecode request timed out")
+			select {
+			case s.bytecodeReqFails <- req:
+			default:
+			}
+		})
+		s.bytecodeReqs[reqid] = req
+		delete(s.bytecodeIdlers, idle)
+
+		s.pend.Add(1)
+		go func(peer *Peer) {
+			defer s.pend.Done()
+
+			// Attempt to send the remote request and revert if it fails
+			if err := peer.RequestByteCodes(reqid, hashes, maxRequestSize); err != nil {
+				log.Debug("Failed to request bytecodes", "err", err)
+				select {
+				case s.bytecodeReqFails <- req:
+				default:
+				}
+			}
+		}(s.peers[idle]) // We're in the lock, peers[id] surely exists
+	}
+}
+
+// assignStorageTasks attempts to match idle peers to pending storage range
+// retrievals.
+func (s *Syncer) assignStorageTasks(cancel chan struct{}) {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	// If there are no idle peers, short circuit assignment
+	if len(s.storageIdlers) == 0 {
+		return
+	}
+	// Iterate over all the tasks and try to find a pending one
+	for _, task := range s.tasks {
+		// Skip any tasks not in the storage retrieval phase
+		if task.res == nil {
+			continue
+		}
+		// Skip tasks that are already retrieving (or done with) all small states
+		if len(task.SubTasks) == 0 && len(task.stateTasks) == 0 {
+			continue
+		}
+		// Task pending retrieval, try to find an idle peer. If no such peer
+		// exists, we probably assigned tasks for all (or they are stateless).
+		// Abort the entire assignment mechanism.
+		var idle string
+		for id := range s.storageIdlers {
+			// If the peer rejected a query in this sync cycle, don't bother asking
+			// again for anything, it's either out of sync or already pruned
+			if _, ok := s.statelessPeers[id]; ok {
+				continue
+			}
+			idle = id
+			break
+		}
+		if idle == "" {
+			return
+		}
+		// Matched a pending task to an idle peer, allocate a unique request id
+		var reqid uint64
+		for {
+			reqid = uint64(rand.Int63())
+			if reqid == 0 {
+				continue
+			}
+			if _, ok := s.storageReqs[reqid]; ok {
+				continue
+			}
+			break
+		}
+		// Generate the network query and send it to the peer. If there are
+		// large contract tasks pending, complete those before diving into
+		// even more new contracts.
+		var (
+			accounts = make([]common.Hash, 0, maxStorageSetRequestCount)
+			roots    = make([]common.Hash, 0, maxStorageSetRequestCount)
+			subtask  *storageTask
+		)
+		for account, subtasks := range task.SubTasks {
+			for _, st := range subtasks {
+				// Skip any subtasks already filling
+				if st.req != nil {
+					continue
+				}
+				// Found an incomplete storage chunk, schedule it
+				accounts = append(accounts, account)
+				roots = append(roots, st.root)
+
+				subtask = st
+				break // Large contract chunks are downloaded individually
+			}
+			if subtask != nil {
+				break // Large contract chunks are downloaded individually
+			}
+		}
+		if subtask == nil {
+			// No large contract required retrieval, but small ones available
+			for account, root := range task.stateTasks {
+				delete(task.stateTasks, account)
+
+				accounts = append(accounts, account)
+				roots = append(roots, root)
+
+				if len(accounts) >= maxStorageSetRequestCount {
+					break
+				}
+			}
+		}
+		// If nothing was found, it means this task is actually already fully
+		// retrieving, but large contracts are hard to detect. Skip to the next.
+		if len(accounts) == 0 {
+			continue
+		}
+		req := &storageRequest{
+			peer:     idle,
+			id:       reqid,
+			cancel:   cancel,
+			stale:    make(chan struct{}),
+			accounts: accounts,
+			roots:    roots,
+			mainTask: task,
+			subTask:  subtask,
+		}
+		if subtask != nil {
+			req.origin = subtask.Next
+			req.limit = subtask.Last
+		}
+		req.timeout = time.AfterFunc(requestTimeout, func() {
+			log.Debug("Storage request timed out")
+			select {
+			case s.storageReqFails <- req:
+			default:
+			}
+		})
+		s.storageReqs[reqid] = req
+		delete(s.storageIdlers, idle)
+
+		s.pend.Add(1)
+		go func(peer *Peer, root common.Hash) {
+			defer s.pend.Done()
+
+			// Attempt to send the remote request and revert if it fails
+			var origin, limit []byte
+			if subtask != nil {
+				origin, limit = req.origin[:], req.limit[:]
+			}
+			if err := peer.RequestStorageRanges(reqid, root, accounts, origin, limit, maxRequestSize); err != nil {
+				log.Debug("Failed to request storage", "err", err)
+				select {
+				case s.storageReqFails <- req:
+				default:
+				}
+			}
+		}(s.peers[idle], s.root) // We're in the lock, peers[id] surely exists
+
+		// Inject the request into the subtask to block further assignments
+		if subtask != nil {
+			subtask.req = req
+		}
+	}
+}
+
+// assignTrienodeHealTasks attempts to match idle peers to trie node requests to
+// heal any trie errors caused by the snap sync's chunked retrieval model.
+func (s *Syncer) assignTrienodeHealTasks(cancel chan struct{}) {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+
+	// If there are no idle peers, short circuit assignment
+	if len(s.trienodeHealIdlers) == 0 {
+		return
+	}
+	// Iterate over pending tasks and try to find a peer to retrieve with
+	for len(s.healer.trieTasks) > 0 || s.healer.scheduler.Pending() > 0 {
+		// If there are not enough trie tasks queued to fully assign, fill the
+		// queue from the state sync scheduler. The trie syncer schedules these
+		// together with bytecodes, so we need to queue them combined.
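Editor's aside: the small-contract branch above (and the bytecode assigner earlier) both drain a pending map into a size-bounded batch, deleting as they go so a later revert re-adds only what actually failed. A standalone sketch of that pattern, with hypothetical names:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/common"
)

// drainTasks pops up to max account->root pairs from the pending set, the
// same way the stateTasks loop above builds one network request.
func drainTasks(pending map[common.Hash]common.Hash, max int) (accounts, roots []common.Hash) {
	for account, root := range pending {
		delete(pending, account)

		accounts = append(accounts, account)
		roots = append(roots, root)
		if len(accounts) >= max {
			break
		}
	}
	return accounts, roots
}

func main() {
	pending := map[common.Hash]common.Hash{
		common.HexToHash("0x01"): common.HexToHash("0xaa"),
		common.HexToHash("0x02"): common.HexToHash("0xbb"),
	}
	accounts, _ := drainTasks(pending, 1)
	fmt.Println(len(accounts), len(pending)) // 1 1
}
```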
+ var ( + have = len(s.healer.trieTasks) + len(s.healer.codeTasks) + want = maxTrieRequestCount + maxCodeRequestCount + ) + if have < want { + nodes, paths, codes := s.healer.scheduler.Missing(want - have) + for i, hash := range nodes { + s.healer.trieTasks[hash] = paths[i] + } + for _, hash := range codes { + s.healer.codeTasks[hash] = struct{}{} + } + } + // If all the heal tasks are bytecodes or already downloading, bail + if len(s.healer.trieTasks) == 0 { + return + } + // Task pending retrieval, try to find an idle peer. If no such peer + // exists, we probably assigned tasks for all (or they are stateless). + // Abort the entire assignment mechanism. + var idle string + for id := range s.trienodeHealIdlers { + // If the peer rejected a query in this sync cycle, don't bother asking + // again for anything, it's either out of sync or already pruned + if _, ok := s.statelessPeers[id]; ok { + continue + } + idle = id + break + } + if idle == "" { + return + } + // Matched a pending task to an idle peer, allocate a unique request id + var reqid uint64 + for { + reqid = uint64(rand.Int63()) + if reqid == 0 { + continue + } + if _, ok := s.trienodeHealReqs[reqid]; ok { + continue + } + break + } + // Generate the network query and send it to the peer + var ( + hashes = make([]common.Hash, 0, maxTrieRequestCount) + paths = make([]trie.SyncPath, 0, maxTrieRequestCount) + pathsets = make([]TrieNodePathSet, 0, maxTrieRequestCount) + ) + for hash, pathset := range s.healer.trieTasks { + delete(s.healer.trieTasks, hash) + + hashes = append(hashes, hash) + paths = append(paths, pathset) + pathsets = append(pathsets, [][]byte(pathset)) // TODO(karalabe): group requests by account hash + + if len(hashes) >= maxTrieRequestCount { + break + } + } + req := &trienodeHealRequest{ + peer: idle, + id: reqid, + cancel: cancel, + stale: make(chan struct{}), + hashes: hashes, + paths: paths, + task: s.healer, + } + req.timeout = time.AfterFunc(requestTimeout, func() { + log.Debug("Trienode heal request timed out") + select { + case s.trienodeHealReqFails <- req: + default: + } + }) + s.trienodeHealReqs[reqid] = req + delete(s.trienodeHealIdlers, idle) + + s.pend.Add(1) + go func(peer *Peer, root common.Hash) { + defer s.pend.Done() + + // Attempt to send the remote request and revert if it fails + if err := peer.RequestTrieNodes(reqid, root, pathsets, maxRequestSize); err != nil { + log.Debug("Failed to request trienode healers", "err", err) + select { + case s.trienodeHealReqFails <- req: + default: + } + } + // Request successfully sent, start a + }(s.peers[idle], s.root) // We're in the lock, peers[id] surely exists + } +} + +// assignBytecodeHealTasks attempts to match idle peers to bytecode requests to +// heal any trie errors caused by the snap sync's chunked retrieval model. +func (s *Syncer) assignBytecodeHealTasks(cancel chan struct{}) { + s.lock.Lock() + defer s.lock.Unlock() + + // If there are no idle peers, short circuit assignment + if len(s.bytecodeHealIdlers) == 0 { + return + } + // Iterate over pending tasks and try to find a peer to retrieve with + for len(s.healer.codeTasks) > 0 || s.healer.scheduler.Pending() > 0 { + // If there are not enough trie tasks queued to fully assign, fill the + // queue from the state sync scheduler. The trie synced schedules these + // together with trie nodes, so we need to queue them combined. 
+		var (
+			have = len(s.healer.trieTasks) + len(s.healer.codeTasks)
+			want = maxTrieRequestCount + maxCodeRequestCount
+		)
+		if have < want {
+			nodes, paths, codes := s.healer.scheduler.Missing(want - have)
+			for i, hash := range nodes {
+				s.healer.trieTasks[hash] = paths[i]
+			}
+			for _, hash := range codes {
+				s.healer.codeTasks[hash] = struct{}{}
+			}
+		}
+		// If all the heal tasks are trienodes or already downloading, bail
+		if len(s.healer.codeTasks) == 0 {
+			return
+		}
+		// Task pending retrieval, try to find an idle peer. If no such peer
+		// exists, we probably assigned tasks for all (or they are stateless).
+		// Abort the entire assignment mechanism.
+		var idle string
+		for id := range s.bytecodeHealIdlers {
+			// If the peer rejected a query in this sync cycle, don't bother asking
+			// again for anything, it's either out of sync or already pruned
+			if _, ok := s.statelessPeers[id]; ok {
+				continue
+			}
+			idle = id
+			break
+		}
+		if idle == "" {
+			return
+		}
+		// Matched a pending task to an idle peer, allocate a unique request id
+		var reqid uint64
+		for {
+			reqid = uint64(rand.Int63())
+			if reqid == 0 {
+				continue
+			}
+			if _, ok := s.bytecodeHealReqs[reqid]; ok {
+				continue
+			}
+			break
+		}
+		// Generate the network query and send it to the peer
+		hashes := make([]common.Hash, 0, maxCodeRequestCount)
+		for hash := range s.healer.codeTasks {
+			delete(s.healer.codeTasks, hash)
+
+			hashes = append(hashes, hash)
+			if len(hashes) >= maxCodeRequestCount {
+				break
+			}
+		}
+		req := &bytecodeHealRequest{
+			peer:   idle,
+			id:     reqid,
+			cancel: cancel,
+			stale:  make(chan struct{}),
+			hashes: hashes,
+			task:   s.healer,
+		}
+		req.timeout = time.AfterFunc(requestTimeout, func() {
+			log.Debug("Bytecode heal request timed out")
+			select {
+			case s.bytecodeHealReqFails <- req:
+			default:
+			}
+		})
+		s.bytecodeHealReqs[reqid] = req
+		delete(s.bytecodeHealIdlers, idle)
+
+		s.pend.Add(1)
+		go func(peer *Peer) {
+			defer s.pend.Done()
+
+			// Attempt to send the remote request and revert if it fails
+			if err := peer.RequestByteCodes(reqid, hashes, maxRequestSize); err != nil {
+				log.Debug("Failed to request bytecode healers", "err", err)
+				select {
+				case s.bytecodeHealReqFails <- req:
+				default:
+				}
+			}
+		}(s.peers[idle]) // We're in the lock, peers[id] surely exists
+	}
+}
+
+// revertRequests locates all the currently pending requests from a particular
+// peer and reverts them, rescheduling for others to fulfill.
+func (s *Syncer) revertRequests(peer string) {
+	// Gather the requests first, reverting them needs the lock too
+	s.lock.Lock()
+	var accountReqs []*accountRequest
+	for _, req := range s.accountReqs {
+		if req.peer == peer {
+			accountReqs = append(accountReqs, req)
+		}
+	}
+	var bytecodeReqs []*bytecodeRequest
+	for _, req := range s.bytecodeReqs {
+		if req.peer == peer {
+			bytecodeReqs = append(bytecodeReqs, req)
+		}
+	}
+	var storageReqs []*storageRequest
+	for _, req := range s.storageReqs {
+		if req.peer == peer {
+			storageReqs = append(storageReqs, req)
+		}
+	}
+	var trienodeHealReqs []*trienodeHealRequest
+	for _, req := range s.trienodeHealReqs {
+		if req.peer == peer {
+			trienodeHealReqs = append(trienodeHealReqs, req)
+		}
+	}
+	var bytecodeHealReqs []*bytecodeHealRequest
+	for _, req := range s.bytecodeHealReqs {
+		if req.peer == peer {
+			bytecodeHealReqs = append(bytecodeHealReqs, req)
+		}
+	}
+	s.lock.Unlock()
+
+	// Revert all the requests matching the peer
+	for _, req := range accountReqs {
+		s.revertAccountRequest(req)
+	}
+	for _, req := range bytecodeReqs {
+		s.revertBytecodeRequest(req)
+	}
+	for _, req := range storageReqs {
+		s.revertStorageRequest(req)
+	}
+	for _, req := range trienodeHealReqs {
+		s.revertTrienodeHealRequest(req)
+	}
+	for _, req := range bytecodeHealReqs {
+		s.revertBytecodeHealRequest(req)
+	}
+}
+
+// revertAccountRequest cleans up an account range request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertAccountRequest(req *accountRequest) {
+	log.Trace("Reverting account request", "peer", req.peer, "reqid", req.id)
+	select {
+	case <-req.stale:
+		log.Trace("Account request already reverted", "peer", req.peer, "reqid", req.id)
+		return
+	default:
+	}
+	close(req.stale)
+
+	// Remove the request from the tracked set
+	s.lock.Lock()
+	delete(s.accountReqs, req.id)
+	s.lock.Unlock()
+
+	// If there's a timeout timer still running, abort it and mark the account
+	// task as not-pending, ready for rescheduling
+	req.timeout.Stop()
+	if req.task.req == req {
+		req.task.req = nil
+	}
+}
+
+// revertBytecodeRequest cleans up a bytecode request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertBytecodeRequest(req *bytecodeRequest) {
+	log.Trace("Reverting bytecode request", "peer", req.peer)
+	select {
+	case <-req.stale:
+		log.Trace("Bytecode request already reverted", "peer", req.peer, "reqid", req.id)
+		return
+	default:
+	}
+	close(req.stale)
+
+	// Remove the request from the tracked set
+	s.lock.Lock()
+	delete(s.bytecodeReqs, req.id)
+	s.lock.Unlock()
+
+	// If there's a timeout timer still running, abort it and mark the code
+	// retrievals as not-pending, ready for rescheduling
+	req.timeout.Stop()
+	for _, hash := range req.hashes {
+		req.task.codeTasks[hash] = struct{}{}
+	}
+}
+
+// revertStorageRequest cleans up a storage range request and returns all failed
+// retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertStorageRequest(req *storageRequest) {
+	log.Trace("Reverting storage request", "peer", req.peer)
+	select {
+	case <-req.stale:
+		log.Trace("Storage request already reverted", "peer", req.peer, "reqid", req.id)
+		return
+	default:
+	}
+	close(req.stale)
+
+	// Remove the request from the tracked set
+	s.lock.Lock()
+	delete(s.storageReqs, req.id)
+	s.lock.Unlock()
+
+	// If there's a timeout timer still running, abort it and mark the storage
+	// task as not-pending, ready for rescheduling
+	req.timeout.Stop()
+	if req.subTask != nil {
+		req.subTask.req = nil
+	} else {
+		for i, account := range req.accounts {
+			req.mainTask.stateTasks[account] = req.roots[i]
+		}
+	}
+}
+
+// revertTrienodeHealRequest cleans up a trienode heal request and returns all
+// failed retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertTrienodeHealRequest(req *trienodeHealRequest) {
+	log.Trace("Reverting trienode heal request", "peer", req.peer)
+	select {
+	case <-req.stale:
+		log.Trace("Trienode heal request already reverted", "peer", req.peer, "reqid", req.id)
+		return
+	default:
+	}
+	close(req.stale)
+
+	// Remove the request from the tracked set
+	s.lock.Lock()
+	delete(s.trienodeHealReqs, req.id)
+	s.lock.Unlock()
+
+	// If there's a timeout timer still running, abort it and mark the trie node
+	// retrievals as not-pending, ready for rescheduling
+	req.timeout.Stop()
+	for i, hash := range req.hashes {
+		req.task.trieTasks[hash] = [][]byte(req.paths[i])
+	}
+}
+
+// revertBytecodeHealRequest cleans up a bytecode heal request and returns all
+// failed retrieval tasks to the scheduler for reassignment.
+func (s *Syncer) revertBytecodeHealRequest(req *bytecodeHealRequest) {
+	log.Trace("Reverting bytecode heal request", "peer", req.peer)
+	select {
+	case <-req.stale:
+		log.Trace("Bytecode heal request already reverted", "peer", req.peer, "reqid", req.id)
+		return
+	default:
+	}
+	close(req.stale)
+
+	// Remove the request from the tracked set
+	s.lock.Lock()
+	delete(s.bytecodeHealReqs, req.id)
+	s.lock.Unlock()
+
+	// If there's a timeout timer still running, abort it and mark the code
+	// retrievals as not-pending, ready for rescheduling
+	req.timeout.Stop()
+	for _, hash := range req.hashes {
+		req.task.codeTasks[hash] = struct{}{}
+	}
+}
+
+// processAccountResponse integrates an already validated account range response
+// into the account tasks.
+func (s *Syncer) processAccountResponse(res *accountResponse) {
+	// Switch the task from pending to filling
+	res.task.req = nil
+	res.task.res = res
+
+	// Ensure that the response doesn't overflow into the subsequent task
+	last := res.task.Last.Big()
+	for i, hash := range res.hashes {
+		if hash.Big().Cmp(last) > 0 {
+			// Chunk overflown, cut off excess, but also update the boundary nodes
+			for j := i; j < len(res.hashes); j++ {
+				if err := res.trie.Prove(res.hashes[j][:], 0, res.overflow); err != nil {
+					panic(err) // Account range was already proven, what happened
+				}
+			}
+			res.hashes = res.hashes[:i]
+			res.accounts = res.accounts[:i]
+			res.cont = false // Mark range completed
+			break
+		}
+	}
+	// Iterate over all the accounts and assemble which ones need further sub-
+	// filling before the entire account range can be persisted.
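Editor's aside: all five revert* helpers above share one idempotency guard. A request can be flagged for revert from several paths (timeout timer, failed send, peer drop), and only the first revert may return the tasks to the queues; the closed stale channel marks a request as already handled. A minimal sketch of just that guard (struct and helper names are illustrative):

```go
package main

import "fmt"

type request struct {
	stale chan struct{}
}

// revertOnce mirrors the guard in the revert* methods: the first call closes
// the stale channel and proceeds; any later call observes the closed channel
// and bails out without touching the task queues again.
func revertOnce(req *request) bool {
	select {
	case <-req.stale:
		return false // already reverted
	default:
	}
	close(req.stale)
	return true
}

func main() {
	req := &request{stale: make(chan struct{})}
	fmt.Println(revertOnce(req), revertOnce(req)) // true false
}
```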
+	res.task.needCode = make([]bool, len(res.accounts))
+	res.task.needState = make([]bool, len(res.accounts))
+	res.task.needHeal = make([]bool, len(res.accounts))
+
+	res.task.codeTasks = make(map[common.Hash]struct{})
+	res.task.stateTasks = make(map[common.Hash]common.Hash)
+
+	resumed := make(map[common.Hash]struct{})
+
+	res.task.pend = 0
+	for i, account := range res.accounts {
+		// Check if the account is a contract with an unknown code
+		if !bytes.Equal(account.CodeHash, emptyCode[:]) {
+			if code := rawdb.ReadCodeWithPrefix(s.db, common.BytesToHash(account.CodeHash)); code == nil {
+				res.task.codeTasks[common.BytesToHash(account.CodeHash)] = struct{}{}
+				res.task.needCode[i] = true
+				res.task.pend++
+			}
+		}
+		// Check if the account is a contract with an unknown storage trie
+		if account.Root != emptyRoot {
+			if node, err := s.db.Get(account.Root[:]); err != nil || node == nil {
+				// If there was a previous large state retrieval in progress,
+				// don't restart it from scratch. This happens if a sync cycle
+				// is interrupted and resumed later. However, *do* update the
+				// previous root hash.
+				if subtasks, ok := res.task.SubTasks[res.hashes[i]]; ok {
+					log.Error("Resuming large storage retrieval", "account", res.hashes[i], "root", account.Root)
+					for _, subtask := range subtasks {
+						subtask.root = account.Root
+					}
+					res.task.needHeal[i] = true
+					resumed[res.hashes[i]] = struct{}{}
+				} else {
+					res.task.stateTasks[res.hashes[i]] = account.Root
+				}
+				res.task.needState[i] = true
+				res.task.pend++
+			}
+		}
+	}
+	// Delete any subtasks that have been aborted but not resumed. This may undo
+	// some progress if a new peer gives us fewer accounts than an old one, but
+	// for now we have to live with that.
+	for hash := range res.task.SubTasks {
+		if _, ok := resumed[hash]; !ok {
+			log.Error("Aborting suspended storage retrieval", "account", hash)
+			delete(res.task.SubTasks, hash)
+		}
+	}
+	// If the account range contained no contracts, or all have been fully filled
+	// beforehand, short circuit storage filling and forward to the next task
+	if res.task.pend == 0 {
+		s.forwardAccountTask(res.task)
+		return
+	}
+	// Some accounts are incomplete, leave as is for the storage and contract
+	// task assigners to pick up and fill.
+}
+
+// processBytecodeResponse integrates an already validated bytecode response
+// into the account tasks.
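Editor's aside: the process* methods below all persist deliveries through a database batch that is flushed once, so a delivery lands on disk all-or-nothing. A hedged sketch of that pattern against the real rawdb/ethdb APIs (the delivered map and its contents are made up for the example):

```go
package main

import (
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/log"
)

func main() {
	db := rawdb.NewMemoryDatabase()
	delivered := map[common.Hash][]byte{
		crypto.Keccak256Hash([]byte{0x60, 0x00}): {0x60, 0x00},
	}
	// Stage all writes in one batch and flush once, mirroring the shape of
	// processBytecodeResponse below.
	batch := db.NewBatch()
	for hash, code := range delivered {
		rawdb.WriteCode(batch, hash, code)
	}
	if err := batch.Write(); err != nil {
		log.Crit("Failed to persist bytecodes", "err", err) // unrecoverable
	}
}
```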
+func (s *Syncer) processBytecodeResponse(res *bytecodeResponse) {
+	batch := s.db.NewBatch()
+
+	var (
+		codes uint64
+		bytes common.StorageSize
+	)
+	for i, hash := range res.hashes {
+		code := res.codes[i]
+
+		// If the bytecode was not delivered, reschedule it
+		if code == nil {
+			res.task.codeTasks[hash] = struct{}{}
+			continue
+		}
+		// Code was delivered, mark it not needed any more
+		for j, account := range res.task.res.accounts {
+			if res.task.needCode[j] && hash == common.BytesToHash(account.CodeHash) {
+				res.task.needCode[j] = false
+				res.task.pend--
+			}
+		}
+		// Push the bytecode into a database batch
+		s.bytecodeSynced++
+		s.bytecodeBytes += common.StorageSize(len(code))
+
+		codes++
+		bytes += common.StorageSize(len(code))
+
+		rawdb.WriteCode(batch, hash, code)
+		s.bloom.Add(hash[:])
+	}
+	if err := batch.Write(); err != nil {
+		log.Crit("Failed to persist bytecodes", "err", err)
+	}
+	log.Debug("Persisted set of bytecodes", "count", codes, "bytes", bytes)
+
+	// If this delivery completed the last pending task, forward the account task
+	// to the next chunk
+	if res.task.pend == 0 {
+		s.forwardAccountTask(res.task)
+		return
+	}
+	// Some accounts are still incomplete, leave as is for the storage and contract
+	// task assigners to pick up and fill.
+}
+
+// processStorageResponse integrates an already validated storage response
+// into the account tasks.
+func (s *Syncer) processStorageResponse(res *storageResponse) {
+	// Switch the subtask from pending to idle
+	if res.subTask != nil {
+		res.subTask.req = nil
+	}
+	batch := s.db.NewBatch()
+
+	var (
+		slots   int
+		nodes   int
+		skipped int
+		bytes   common.StorageSize
+	)
+	// Iterate over all the accounts and reconstruct their storage tries from the
+	// delivered slots
+	delivered := make(map[common.Hash]bool)
+	for i := 0; i < len(res.hashes); i++ {
+		delivered[res.roots[i]] = true
+	}
+	for i, account := range res.accounts {
+		// If the account was not delivered, reschedule it
+		if i >= len(res.hashes) {
+			if !delivered[res.roots[i]] {
+				res.mainTask.stateTasks[account] = res.roots[i]
+			}
+			continue
+		}
+		// State was delivered, if complete mark as not needed any more, otherwise
+		// mark the account as needing healing
+		for j, acc := range res.mainTask.res.accounts {
+			if res.roots[i] == acc.Root {
+				// If the packet contains multiple contract storage slots, all
+				// but the last are surely complete. The last contract may be
+				// chunked, so check its continuation flag.
+				if res.subTask == nil && res.mainTask.needState[j] && (i < len(res.hashes)-1 || !res.cont) {
+					res.mainTask.needState[j] = false
+					res.mainTask.pend--
+				}
+				// If the last contract was chunked, mark it as needing healing
+				// to avoid writing it out to disk prematurely.
+ if res.subTask == nil && !res.mainTask.needHeal[j] && i == len(res.hashes)-1 && res.cont { + res.mainTask.needHeal[j] = true + } + // If the last contract was chunked, we need to switch to large + // contract handling mode + if res.subTask == nil && i == len(res.hashes)-1 && res.cont { + // If we haven't yet started a large-contract retrieval, create + // the subtasks for it within the main account task + if tasks, ok := res.mainTask.SubTasks[account]; !ok { + var ( + next common.Hash + ) + step := new(big.Int).Sub( + new(big.Int).Div( + new(big.Int).Exp(common.Big2, common.Big256, nil), + big.NewInt(storageConcurrency), + ), common.Big1, + ) + for k := 0; k < storageConcurrency; k++ { + last := common.BigToHash(new(big.Int).Add(next.Big(), step)) + if k == storageConcurrency-1 { + // Make sure we don't overflow if the step is not a proper divisor + last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + } + tasks = append(tasks, &storageTask{ + Next: next, + Last: last, + root: acc.Root, + }) + log.Debug("Created storage sync task", "account", account, "root", acc.Root, "from", next, "last", last) + next = common.BigToHash(new(big.Int).Add(last.Big(), common.Big1)) + } + res.mainTask.SubTasks[account] = tasks + + // Since we've just created the sub-tasks, this response + // is surely for the first one (zero origin) + res.subTask = tasks[0] + } + } + // If we're in large contract delivery mode, forward the subtask + if res.subTask != nil { + // Ensure the response doesn't overflow into the subsequent task + last := res.subTask.Last.Big() + for k, hash := range res.hashes[i] { + if hash.Big().Cmp(last) > 0 { + // Chunk overflown, cut off excess, but also update the boundary + for l := k; l < len(res.hashes[i]); l++ { + if err := res.tries[i].Prove(res.hashes[i][l][:], 0, res.overflow); err != nil { + panic(err) // Account range was already proven, what happened + } + } + res.hashes[i] = res.hashes[i][:k] + res.slots[i] = res.slots[i][:k] + res.cont = false // Mark range completed + break + } + } + // Forward the relevant storage chunk (even if created just now) + if res.cont { + res.subTask.Next = common.BigToHash(new(big.Int).Add(res.hashes[i][len(res.hashes[i])-1].Big(), big.NewInt(1))) + } else { + res.subTask.done = true + } + } + } + } + // Iterate over all the reconstructed trie nodes and push them to disk + slots += len(res.hashes[i]) + + it := res.nodes[i].NewIterator(nil, nil) + for it.Next() { + // Boundary nodes are not written for the last result, since they are incomplete + if i == len(res.hashes)-1 { + if _, ok := res.bounds[common.BytesToHash(it.Key())]; ok { + skipped++ + continue + } + } + // Node is not a boundary, persist to disk + batch.Put(it.Key(), it.Value()) + s.bloom.Add(it.Key()) + + bytes += common.StorageSize(common.HashLength + len(it.Value())) + nodes++ + } + it.Release() + } + if err := batch.Write(); err != nil { + log.Crit("Failed to persist storage slots", "err", err) + } + s.storageSynced += uint64(slots) + s.storageBytes += bytes + + log.Debug("Persisted set of storage slots", "accounts", len(res.hashes), "slots", slots, "nodes", nodes, "skipped", skipped, "bytes", bytes) + + // If this delivery completed the last pending task, forward the account task + // to the next chunk + if res.mainTask.pend == 0 { + s.forwardAccountTask(res.mainTask) + return + } + // Some accounts are still incomplete, leave as is for the storage and contract + // task assigners to pick up and fill. 
+} + +// processTrienodeHealResponse integrates an already validated trienode response +// into the healer tasks. +func (s *Syncer) processTrienodeHealResponse(res *trienodeHealResponse) { + for i, hash := range res.hashes { + node := res.nodes[i] + + // If the trie node was not delivered, reschedule it + if node == nil { + res.task.trieTasks[hash] = res.paths[i] + continue + } + // Push the trie node into the state syncer + s.trienodeHealSynced++ + s.trienodeHealBytes += common.StorageSize(len(node)) + + err := s.healer.scheduler.Process(trie.SyncResult{Hash: hash, Data: node}) + switch err { + case nil: + case trie.ErrAlreadyProcessed: + s.trienodeHealDups++ + case trie.ErrNotRequested: + s.trienodeHealNops++ + default: + log.Error("Invalid trienode processed", "hash", hash, "err", err) + } + } + batch := s.db.NewBatch() + if err := s.healer.scheduler.Commit(batch); err != nil { + log.Error("Failed to commit healing data", "err", err) + } + if err := batch.Write(); err != nil { + log.Crit("Failed to persist healing data", "err", err) + } + log.Debug("Persisted set of healing data", "bytes", common.StorageSize(batch.ValueSize())) +} + +// processBytecodeHealResponse integrates an already validated bytecode response +// into the healer tasks. +func (s *Syncer) processBytecodeHealResponse(res *bytecodeHealResponse) { + for i, hash := range res.hashes { + node := res.codes[i] + + // If the trie node was not delivered, reschedule it + if node == nil { + res.task.codeTasks[hash] = struct{}{} + continue + } + // Push the trie node into the state syncer + s.bytecodeHealSynced++ + s.bytecodeHealBytes += common.StorageSize(len(node)) + + err := s.healer.scheduler.Process(trie.SyncResult{Hash: hash, Data: node}) + switch err { + case nil: + case trie.ErrAlreadyProcessed: + s.bytecodeHealDups++ + case trie.ErrNotRequested: + s.bytecodeHealNops++ + default: + log.Error("Invalid bytecode processed", "hash", hash, "err", err) + } + } + batch := s.db.NewBatch() + if err := s.healer.scheduler.Commit(batch); err != nil { + log.Error("Failed to commit healing data", "err", err) + } + if err := batch.Write(); err != nil { + log.Crit("Failed to persist healing data", "err", err) + } + log.Debug("Persisted set of healing data", "bytes", common.StorageSize(batch.ValueSize())) +} + +// forwardAccountTask takes a filled account task and persists anything available +// into the database, after which it forwards the next account marker so that the +// task's next chunk may be filled. +func (s *Syncer) forwardAccountTask(task *accountTask) { + // Remove any pending delivery + res := task.res + if res == nil { + return // nothing to forward + } + task.res = nil + + // Iterate over all the accounts and gather all the incomplete trie nodes. A + // node is incomplete if we haven't yet filled it (sync was interrupted), or + // if we filled it in multiple chunks (storage trie), in which case the few + // nodes on the chunk boundaries are missing. 
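Editor's aside: when processStorageResponse (two hunks above) detects a freshly chunked large contract, it cuts the 2^256 storage keyspace into storageConcurrency roughly equal ranges, clamping the last one so a non-dividing step cannot overflow. A standalone, runnable sketch of that arithmetic; the concurrency value of 16 is assumed from the constant defined elsewhere in this PR:

```go
package main

import (
	"fmt"
	"math/big"
)

const storageConcurrency = 16 // assumed value of the PR's constant

func main() {
	// step = (2^256 / concurrency) - 1, the same formula as in the PR
	max := new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil)
	step := new(big.Int).Sub(new(big.Int).Div(max, big.NewInt(storageConcurrency)), big.NewInt(1))

	next := new(big.Int)
	for k := 0; k < storageConcurrency; k++ {
		last := new(big.Int).Add(next, step)
		if k == storageConcurrency-1 {
			// Clamp the final chunk to 0xff..ff so rounding cannot overflow
			last = new(big.Int).Sub(max, big.NewInt(1))
		}
		fmt.Printf("chunk %2d: %064x..%064x\n", k, next, last)
		next = new(big.Int).Add(last, big.NewInt(1))
	}
}
```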
+ incompletes := light.NewNodeSet() + for i := range res.accounts { + // If the filling was interrupted, mark everything after as incomplete + if task.needCode[i] || task.needState[i] { + for j := i; j < len(res.accounts); j++ { + if err := res.trie.Prove(res.hashes[j][:], 0, incompletes); err != nil { + panic(err) // Account range was already proven, what happened + } + } + break + } + // Filling not interrupted until this point, mark incomplete if needs healing + if task.needHeal[i] { + if err := res.trie.Prove(res.hashes[i][:], 0, incompletes); err != nil { + panic(err) // Account range was already proven, what happened + } + } + } + // Persist every finalized trie node that's not on the boundary + batch := s.db.NewBatch() + + var ( + nodes int + skipped int + bytes common.StorageSize + ) + it := res.nodes.NewIterator(nil, nil) + for it.Next() { + // Boundary nodes are not written, since they are incomplete + if _, ok := res.bounds[common.BytesToHash(it.Key())]; ok { + skipped++ + continue + } + // Overflow nodes are not written, since they mess with another task + if _, err := res.overflow.Get(it.Key()); err == nil { + skipped++ + continue + } + // Accounts with split storage requests are incomplete + if _, err := incompletes.Get(it.Key()); err == nil { + skipped++ + continue + } + // Node is neither a boundary, not an incomplete account, persist to disk + batch.Put(it.Key(), it.Value()) + s.bloom.Add(it.Key()) + + bytes += common.StorageSize(common.HashLength + len(it.Value())) + nodes++ + } + it.Release() + + if err := batch.Write(); err != nil { + log.Crit("Failed to persist accounts", "err", err) + } + s.accountBytes += bytes + s.accountSynced += uint64(len(res.accounts)) + + log.Debug("Persisted range of accounts", "accounts", len(res.accounts), "nodes", nodes, "skipped", skipped, "bytes", bytes) + + // Task filling persisted, push it the chunk marker forward to the first + // account still missing data. + for i, hash := range res.hashes { + if task.needCode[i] || task.needState[i] { + return + } + task.Next = common.BigToHash(new(big.Int).Add(hash.Big(), big.NewInt(1))) + } + // All accounts marked as complete, track if the entire task is done + task.done = !res.cont +} + +// OnAccounts is a callback method to invoke when a range of accounts are +// received from a remote peer. +func (s *Syncer) OnAccounts(peer *Peer, id uint64, hashes []common.Hash, accounts [][]byte, proof [][]byte) error { + size := common.StorageSize(len(hashes) * common.HashLength) + for _, account := range accounts { + size += common.StorageSize(len(account)) + } + for _, node := range proof { + size += common.StorageSize(len(node)) + } + logger := peer.logger.New("reqid", id) + logger.Trace("Delivering range of accounts", "hashes", len(hashes), "accounts", len(accounts), "proofs", len(proof), "bytes", size) + + // Whether or not the response is valid, we can mark the peer as idle and + // notify the scheduler to assign a new task. If the response is invalid, + // we'll drop the peer in a bit. 
+ s.lock.Lock() + if _, ok := s.peers[peer.id]; ok { + s.accountIdlers[peer.id] = struct{}{} + } + select { + case s.update <- struct{}{}: + default: + } + // Ensure the response is for a valid request + req, ok := s.accountReqs[id] + if !ok { + // Request stale, perhaps the peer timed out but came through in the end + logger.Warn("Unexpected account range packet") + s.lock.Unlock() + return nil + } + delete(s.accountReqs, id) + + // Clean up the request timeout timer, we'll see how to proceed further based + // on the actual delivered content + req.timeout.Stop() + + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For account range queries that means the state being + // retrieved was either already pruned remotely, or the peer is not yet + // synced to our head. + if len(hashes) == 0 && len(accounts) == 0 && len(proof) == 0 { + logger.Debug("Peer rejected account range request", "root", s.root) + s.statelessPeers[peer.id] = struct{}{} + s.lock.Unlock() + return nil + } + root := s.root + s.lock.Unlock() + + // Reconstruct a partial trie from the response and verify it + keys := make([][]byte, len(hashes)) + for i, key := range hashes { + keys[i] = common.CopyBytes(key[:]) + } + nodes := make(light.NodeList, len(proof)) + for i, node := range proof { + nodes[i] = node + } + proofdb := nodes.NodeSet() + + var end []byte + if len(keys) > 0 { + end = keys[len(keys)-1] + } + db, tr, notary, cont, err := trie.VerifyRangeProof(root, req.origin[:], end, keys, accounts, proofdb) + if err != nil { + logger.Warn("Account range failed proof", "err", err) + return err + } + // Partial trie reconstructed, send it to the scheduler for storage filling + bounds := make(map[common.Hash]struct{}) + + it := notary.Accessed().NewIterator(nil, nil) + for it.Next() { + bounds[common.BytesToHash(it.Key())] = struct{}{} + } + it.Release() + + accs := make([]*state.Account, len(accounts)) + for i, account := range accounts { + acc := new(state.Account) + if err := rlp.DecodeBytes(account, acc); err != nil { + panic(err) // We created these blobs, we must be able to decode them + } + accs[i] = acc + } + response := &accountResponse{ + task: req.task, + hashes: hashes, + accounts: accs, + nodes: db, + trie: tr, + bounds: bounds, + overflow: light.NewNodeSet(), + cont: cont, + } + select { + case s.accountResps <- response: + case <-req.cancel: + case <-req.stale: + } + return nil +} + +// OnByteCodes is a callback method to invoke when a batch of contract +// bytes codes are received from a remote peer. +func (s *Syncer) OnByteCodes(peer *Peer, id uint64, bytecodes [][]byte) error { + s.lock.RLock() + syncing := len(s.tasks) > 0 + s.lock.RUnlock() + + if syncing { + return s.onByteCodes(peer, id, bytecodes) + } + return s.onHealByteCodes(peer, id, bytecodes) +} + +// onByteCodes is a callback method to invoke when a batch of contract +// bytes codes are received from a remote peer in the syncing phase. +func (s *Syncer) onByteCodes(peer *Peer, id uint64, bytecodes [][]byte) error { + var size common.StorageSize + for _, code := range bytecodes { + size += common.StorageSize(len(code)) + } + logger := peer.logger.New("reqid", id) + logger.Trace("Delivering set of bytecodes", "bytecodes", len(bytecodes), "bytes", size) + + // Whether or not the response is valid, we can mark the peer as idle and + // notify the scheduler to assign a new task. If the response is invalid, + // we'll drop the peer in a bit. 
+ s.lock.Lock() + if _, ok := s.peers[peer.id]; ok { + s.bytecodeIdlers[peer.id] = struct{}{} + } + select { + case s.update <- struct{}{}: + default: + } + // Ensure the response is for a valid request + req, ok := s.bytecodeReqs[id] + if !ok { + // Request stale, perhaps the peer timed out but came through in the end + logger.Warn("Unexpected bytecode packet") + s.lock.Unlock() + return nil + } + delete(s.bytecodeReqs, id) + + // Clean up the request timeout timer, we'll see how to proceed further based + // on the actual delivered content + req.timeout.Stop() + + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For bytecode range queries that means the peer is not + // yet synced. + if len(bytecodes) == 0 { + logger.Debug("Peer rejected bytecode request") + s.statelessPeers[peer.id] = struct{}{} + s.lock.Unlock() + return nil + } + s.lock.Unlock() + + // Cross reference the requested bytecodes with the response to find gaps + // that the serving node is missing + hasher := sha3.NewLegacyKeccak256() + + codes := make([][]byte, len(req.hashes)) + for i, j := 0, 0; i < len(bytecodes); i++ { + // Find the next hash that we've been served, leaving misses with nils + hasher.Reset() + hasher.Write(bytecodes[i]) + hash := hasher.Sum(nil) + + for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) { + j++ + } + if j < len(req.hashes) { + codes[j] = bytecodes[i] + j++ + continue + } + // We've either ran out of hashes, or got unrequested data + logger.Warn("Unexpected bytecodes", "count", len(bytecodes)-i) + return errors.New("unexpected bytecode") + } + // Response validated, send it to the scheduler for filling + response := &bytecodeResponse{ + task: req.task, + hashes: req.hashes, + codes: codes, + } + select { + case s.bytecodeResps <- response: + case <-req.cancel: + case <-req.stale: + } + return nil +} + +// OnStorage is a callback method to invoke when ranges of storage slots +// are received from a remote peer. +func (s *Syncer) OnStorage(peer *Peer, id uint64, hashes [][]common.Hash, slots [][][]byte, proof [][]byte) error { + // Gather some trace stats to aid in debugging issues + var ( + hashCount int + slotCount int + size common.StorageSize + ) + for _, hashset := range hashes { + size += common.StorageSize(common.HashLength * len(hashset)) + hashCount += len(hashset) + } + for _, slotset := range slots { + for _, slot := range slotset { + size += common.StorageSize(len(slot)) + } + slotCount += len(slotset) + } + for _, node := range proof { + size += common.StorageSize(len(node)) + } + logger := peer.logger.New("reqid", id) + logger.Trace("Delivering ranges of storage slots", "accounts", len(hashes), "hashes", hashCount, "slots", slotCount, "proofs", len(proof), "size", size) + + // Whether or not the response is valid, we can mark the peer as idle and + // notify the scheduler to assign a new task. If the response is invalid, + // we'll drop the peer in a bit. 
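Editor's aside: onByteCodes above matches deliveries back to requests by hashing each blob, relying on responses preserving request order while allowing omissions. A self-contained sketch of that two-pointer walk (the helper name is illustrative):

```go
package main

import (
	"bytes"
	"fmt"

	"golang.org/x/crypto/sha3"
)

// matchDeliveries hashes each delivered blob and advances through the
// requested hashes, leaving nil gaps for anything the peer omitted; a blob
// matching no remaining request is flagged as unrequested data.
func matchDeliveries(requested [][]byte, delivered [][]byte) ([][]byte, error) {
	hasher := sha3.NewLegacyKeccak256()
	out := make([][]byte, len(requested))
	for i, j := 0, 0; i < len(delivered); i++ {
		hasher.Reset()
		hasher.Write(delivered[i])
		hash := hasher.Sum(nil)

		for j < len(requested) && !bytes.Equal(hash, requested[j]) {
			j++ // requested item skipped by the peer: stays nil for reschedule
		}
		if j >= len(requested) {
			return nil, fmt.Errorf("unexpected blob at index %d", i)
		}
		out[j] = delivered[i]
		j++
	}
	return out, nil
}

func main() {
	hasher := sha3.NewLegacyKeccak256()
	blob := []byte{0x60, 0x00}
	hasher.Write(blob)
	want := hasher.Sum(nil)

	codes, err := matchDeliveries([][]byte{want}, [][]byte{blob})
	fmt.Println(len(codes), err) // 1 <nil>
}
```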
+ s.lock.Lock() + if _, ok := s.peers[peer.id]; ok { + s.storageIdlers[peer.id] = struct{}{} + } + select { + case s.update <- struct{}{}: + default: + } + // Ensure the response is for a valid request + req, ok := s.storageReqs[id] + if !ok { + // Request stale, perhaps the peer timed out but came through in the end + logger.Warn("Unexpected storage ranges packet") + s.lock.Unlock() + return nil + } + delete(s.storageReqs, id) + + // Clean up the request timeout timer, we'll see how to proceed further based + // on the actual delivered content + req.timeout.Stop() + + // Reject the response if the hash sets and slot sets don't match, or if the + // peer sent more data than requested. + if len(hashes) != len(slots) { + s.lock.Unlock() + logger.Warn("Hash and slot set size mismatch", "hashset", len(hashes), "slotset", len(slots)) + return errors.New("hash and slot set size mismatch") + } + if len(hashes) > len(req.accounts) { + s.lock.Unlock() + logger.Warn("Hash set larger than requested", "hashset", len(hashes), "requested", len(req.accounts)) + return errors.New("hash set larger than requested") + } + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For storage range queries that means the state being + // retrieved was either already pruned remotely, or the peer is not yet + // synced to our head. + if len(hashes) == 0 { + logger.Debug("Peer rejected storage request") + s.statelessPeers[peer.id] = struct{}{} + s.lock.Unlock() + return nil + } + s.lock.Unlock() + + // Reconstruct the partial tries from the response and verify them + var ( + dbs = make([]ethdb.KeyValueStore, len(hashes)) + tries = make([]*trie.Trie, len(hashes)) + notary *trie.KeyValueNotary + cont bool + ) + for i := 0; i < len(hashes); i++ { + // Convert the keys and proofs into an internal format + keys := make([][]byte, len(hashes[i])) + for j, key := range hashes[i] { + keys[j] = common.CopyBytes(key[:]) + } + nodes := make(light.NodeList, 0, len(proof)) + if i == len(hashes)-1 { + for _, node := range proof { + nodes = append(nodes, node) + } + } + var err error + if len(nodes) == 0 { + // No proof has been attached, the response must cover the entire key + // space and hash to the origin root. 
+ dbs[i], tries[i], _, _, err = trie.VerifyRangeProof(req.roots[i], nil, nil, keys, slots[i], nil) + if err != nil { + logger.Warn("Storage slots failed proof", "err", err) + return err + } + } else { + // A proof was attached, the response is only partial, check that the + // returned data is indeed part of the storage trie + proofdb := nodes.NodeSet() + + var end []byte + if len(keys) > 0 { + end = keys[len(keys)-1] + } + dbs[i], tries[i], notary, cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], end, keys, slots[i], proofdb) + if err != nil { + logger.Warn("Storage range failed proof", "err", err) + return err + } + } + } + // Partial tries reconstructed, send them to the scheduler for storage filling + bounds := make(map[common.Hash]struct{}) + + if notary != nil { // if all contract storages are delivered in full, no notary will be created + it := notary.Accessed().NewIterator(nil, nil) + for it.Next() { + bounds[common.BytesToHash(it.Key())] = struct{}{} + } + it.Release() + } + response := &storageResponse{ + mainTask: req.mainTask, + subTask: req.subTask, + accounts: req.accounts, + roots: req.roots, + hashes: hashes, + slots: slots, + nodes: dbs, + tries: tries, + bounds: bounds, + overflow: light.NewNodeSet(), + cont: cont, + } + select { + case s.storageResps <- response: + case <-req.cancel: + case <-req.stale: + } + return nil +} + +// OnTrieNodes is a callback method to invoke when a batch of trie nodes +// are received from a remote peer. +func (s *Syncer) OnTrieNodes(peer *Peer, id uint64, trienodes [][]byte) error { + var size common.StorageSize + for _, node := range trienodes { + size += common.StorageSize(len(node)) + } + logger := peer.logger.New("reqid", id) + logger.Trace("Delivering set of healing trienodes", "trienodes", len(trienodes), "bytes", size) + + // Whether or not the response is valid, we can mark the peer as idle and + // notify the scheduler to assign a new task. If the response is invalid, + // we'll drop the peer in a bit. + s.lock.Lock() + if _, ok := s.peers[peer.id]; ok { + s.trienodeHealIdlers[peer.id] = struct{}{} + } + select { + case s.update <- struct{}{}: + default: + } + // Ensure the response is for a valid request + req, ok := s.trienodeHealReqs[id] + if !ok { + // Request stale, perhaps the peer timed out but came through in the end + logger.Warn("Unexpected trienode heal packet") + s.lock.Unlock() + return nil + } + delete(s.trienodeHealReqs, id) + + // Clean up the request timeout timer, we'll see how to proceed further based + // on the actual delivered content + req.timeout.Stop() + + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For bytecode range queries that means the peer is not + // yet synced. 
+ if len(trienodes) == 0 { + logger.Debug("Peer rejected trienode heal request") + s.statelessPeers[peer.id] = struct{}{} + s.lock.Unlock() + return nil + } + s.lock.Unlock() + + // Cross reference the requested trienodes with the response to find gaps + // that the serving node is missing + hasher := sha3.NewLegacyKeccak256() + + nodes := make([][]byte, len(req.hashes)) + for i, j := 0, 0; i < len(trienodes); i++ { + // Find the next hash that we've been served, leaving misses with nils + hasher.Reset() + hasher.Write(trienodes[i]) + hash := hasher.Sum(nil) + + for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) { + j++ + } + if j < len(req.hashes) { + nodes[j] = trienodes[i] + j++ + continue + } + // We've either ran out of hashes, or got unrequested data + logger.Warn("Unexpected healing trienodes", "count", len(trienodes)-i) + return errors.New("unexpected healing trienode") + } + // Response validated, send it to the scheduler for filling + response := &trienodeHealResponse{ + task: req.task, + hashes: req.hashes, + paths: req.paths, + nodes: nodes, + } + select { + case s.trienodeHealResps <- response: + case <-req.cancel: + case <-req.stale: + } + return nil +} + +// onHealByteCodes is a callback method to invoke when a batch of contract +// bytes codes are received from a remote peer in the healing phase. +func (s *Syncer) onHealByteCodes(peer *Peer, id uint64, bytecodes [][]byte) error { + var size common.StorageSize + for _, code := range bytecodes { + size += common.StorageSize(len(code)) + } + logger := peer.logger.New("reqid", id) + logger.Trace("Delivering set of healing bytecodes", "bytecodes", len(bytecodes), "bytes", size) + + // Whether or not the response is valid, we can mark the peer as idle and + // notify the scheduler to assign a new task. If the response is invalid, + // we'll drop the peer in a bit. + s.lock.Lock() + if _, ok := s.peers[peer.id]; ok { + s.bytecodeHealIdlers[peer.id] = struct{}{} + } + select { + case s.update <- struct{}{}: + default: + } + // Ensure the response is for a valid request + req, ok := s.bytecodeHealReqs[id] + if !ok { + // Request stale, perhaps the peer timed out but came through in the end + logger.Warn("Unexpected bytecode heal packet") + s.lock.Unlock() + return nil + } + delete(s.bytecodeHealReqs, id) + + // Clean up the request timeout timer, we'll see how to proceed further based + // on the actual delivered content + req.timeout.Stop() + + // Response is valid, but check if peer is signalling that it does not have + // the requested data. For bytecode range queries that means the peer is not + // yet synced. 
+ if len(bytecodes) == 0 { + logger.Debug("Peer rejected bytecode heal request") + s.statelessPeers[peer.id] = struct{}{} + s.lock.Unlock() + return nil + } + s.lock.Unlock() + + // Cross reference the requested bytecodes with the response to find gaps + // that the serving node is missing + hasher := sha3.NewLegacyKeccak256() + + codes := make([][]byte, len(req.hashes)) + for i, j := 0, 0; i < len(bytecodes); i++ { + // Find the next hash that we've been served, leaving misses with nils + hasher.Reset() + hasher.Write(bytecodes[i]) + hash := hasher.Sum(nil) + + for j < len(req.hashes) && !bytes.Equal(hash, req.hashes[j][:]) { + j++ + } + if j < len(req.hashes) { + codes[j] = bytecodes[i] + j++ + continue + } + // We've either ran out of hashes, or got unrequested data + logger.Warn("Unexpected healing bytecodes", "count", len(bytecodes)-i) + return errors.New("unexpected healing bytecode") + } + // Response validated, send it to the scheduler for filling + response := &bytecodeHealResponse{ + task: req.task, + hashes: req.hashes, + codes: codes, + } + select { + case s.bytecodeHealResps <- response: + case <-req.cancel: + case <-req.stale: + } + return nil +} + +// hashSpace is the total size of the 256 bit hash space for accounts. +var hashSpace = new(big.Int).Exp(common.Big2, common.Big256, nil) + +// report calculates various status reports and provides it to the user. +func (s *Syncer) report(force bool) { + if len(s.tasks) > 0 { + s.reportSyncProgress(force) + return + } + s.reportHealProgress(force) +} + +// reportSyncProgress calculates various status reports and provides it to the user. +func (s *Syncer) reportSyncProgress(force bool) { + // Don't report all the events, just occasionally + if !force && time.Since(s.logTime) < 3*time.Second { + return + } + // Don't report anything until we have a meaningful progress + synced := s.accountBytes + s.bytecodeBytes + s.storageBytes + if synced == 0 { + return + } + accountGaps := new(big.Int) + for _, task := range s.tasks { + accountGaps.Add(accountGaps, new(big.Int).Sub(task.Last.Big(), task.Next.Big())) + } + accountFills := new(big.Int).Sub(hashSpace, accountGaps) + if accountFills.BitLen() == 0 { + return + } + s.logTime = time.Now() + estBytes := float64(new(big.Int).Div( + new(big.Int).Mul(new(big.Int).SetUint64(uint64(synced)), hashSpace), + accountFills, + ).Uint64()) + + elapsed := time.Since(s.startTime) + estTime := elapsed / time.Duration(synced) * time.Duration(estBytes) + + // Create a mega progress report + var ( + progress = fmt.Sprintf("%.2f%%", float64(synced)*100/estBytes) + accounts = fmt.Sprintf("%d@%v", s.accountSynced, s.accountBytes.TerminalString()) + storage = fmt.Sprintf("%d@%v", s.storageSynced, s.storageBytes.TerminalString()) + bytecode = fmt.Sprintf("%d@%v", s.bytecodeSynced, s.bytecodeBytes.TerminalString()) + ) + log.Info("State sync in progress", "synced", progress, "state", synced, + "accounts", accounts, "slots", storage, "codes", bytecode, "eta", common.PrettyDuration(estTime-elapsed)) +} + +// reportHealProgress calculates various status reports and provides it to the user. 
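Editor's aside: reportSyncProgress above extrapolates a total from the covered slice of the account hash space: if syncedBytes were downloaded while filling fraction f of the 2^256 keyspace, the estimated total is roughly syncedBytes / f, and the ETA scales elapsed time the same way. A worked sketch with made-up numbers:

```go
package main

import (
	"fmt"
	"math/big"
	"time"
)

func main() {
	hashSpace := new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil)

	synced := uint64(12 << 30)               // say 12 GiB downloaded so far
	filled := new(big.Int).Rsh(hashSpace, 2) // and 1/4 of the keyspace covered
	elapsed := 2 * time.Hour

	// estBytes = synced * hashSpace / filled, as in reportSyncProgress
	estBytes := new(big.Int).Div(new(big.Int).Mul(new(big.Int).SetUint64(synced), hashSpace), filled).Uint64()
	estTime := time.Duration(float64(elapsed) * float64(estBytes) / float64(synced))

	fmt.Printf("synced %.2f%%, eta %v\n", float64(synced)*100/float64(estBytes), estTime-elapsed)
	// synced 25.00%, eta 6h0m0s
}
```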
+func (s *Syncer) reportHealProgress(force bool) { + // Don't report all the events, just occasionally + if !force && time.Since(s.logTime) < 3*time.Second { + return + } + s.logTime = time.Now() + + // Create a mega progress report + var ( + trienode = fmt.Sprintf("%d@%v", s.trienodeHealSynced, s.trienodeHealBytes.TerminalString()) + bytecode = fmt.Sprintf("%d@%v", s.bytecodeHealSynced, s.bytecodeHealBytes.TerminalString()) + ) + log.Info("State heal in progress", "nodes", trienode, "codes", bytecode, + "pending", s.healer.scheduler.Pending()) +} diff --git a/eth/sync.go b/eth/sync.go index 26badd1e21c2..03a5165245b2 100644 --- a/eth/sync.go +++ b/eth/sync.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" ) @@ -40,12 +41,12 @@ const ( ) type txsync struct { - p *peer + p *eth.Peer txs []*types.Transaction } // syncTransactions starts sending all currently pending transactions to the given peer. -func (pm *ProtocolManager) syncTransactions(p *peer) { +func (h *handler) syncTransactions(p *eth.Peer) { // Assemble the set of transaction to broadcast or announce to the remote // peer. Fun fact, this is quite an expensive operation as it needs to sort // the transactions if the sorting is not cached yet. However, with a random @@ -53,7 +54,7 @@ func (pm *ProtocolManager) syncTransactions(p *peer) { // // TODO(karalabe): Figure out if we could get away with random order somehow var txs types.Transactions - pending, _ := pm.txpool.Pending() + pending, _ := h.txpool.Pending() for _, batch := range pending { txs = append(txs, batch...) } @@ -63,7 +64,7 @@ func (pm *ProtocolManager) syncTransactions(p *peer) { // The eth/65 protocol introduces proper transaction announcements, so instead // of dripping transactions across multiple peers, just send the entire list as // an announcement and let the remote side decide what they need (likely nothing). - if p.version >= eth65 { + if p.Version() >= eth.ETH65 { hashes := make([]common.Hash, len(txs)) for i, tx := range txs { hashes[i] = tx.Hash() @@ -73,8 +74,8 @@ func (pm *ProtocolManager) syncTransactions(p *peer) { } // Out of luck, peer is running legacy protocols, drop the txs over select { - case pm.txsyncCh <- &txsync{p: p, txs: txs}: - case <-pm.quitSync: + case h.txsyncCh <- &txsync{p: p, txs: txs}: + case <-h.quitSync: } } @@ -82,8 +83,8 @@ func (pm *ProtocolManager) syncTransactions(p *peer) { // connection. When a new peer appears, we relay all currently pending // transactions. In order to minimise egress bandwidth usage, we send // the transactions in small packs to one peer at a time. -func (pm *ProtocolManager) txsyncLoop64() { - defer pm.wg.Done() +func (h *handler) txsyncLoop64() { + defer h.wg.Done() var ( pending = make(map[enode.ID]*txsync) @@ -94,7 +95,7 @@ func (pm *ProtocolManager) txsyncLoop64() { // send starts a sending a pack of transactions from the sync. send := func(s *txsync) { - if s.p.version >= eth65 { + if s.p.Version() >= eth.ETH65 { panic("initial transaction syncer running on eth/65+") } // Fill pack with transactions up to the target size. @@ -108,14 +109,13 @@ func (pm *ProtocolManager) txsyncLoop64() { // Remove the transactions that will be sent. 
s.txs = s.txs[:copy(s.txs, s.txs[len(pack.txs):])] if len(s.txs) == 0 { - delete(pending, s.p.ID()) + delete(pending, s.p.Peer.ID()) } // Send the pack in the background. s.p.Log().Trace("Sending batch of transactions", "count", len(pack.txs), "bytes", size) sending = true - go func() { done <- pack.p.SendTransactions64(pack.txs) }() + go func() { done <- pack.p.SendTransactions(pack.txs) }() } - // pick chooses the next pending sync. pick := func() *txsync { if len(pending) == 0 { @@ -132,8 +132,8 @@ func (pm *ProtocolManager) txsyncLoop64() { for { select { - case s := <-pm.txsyncCh: - pending[s.p.ID()] = s + case s := <-h.txsyncCh: + pending[s.p.Peer.ID()] = s if !sending { send(s) } @@ -142,13 +142,13 @@ func (pm *ProtocolManager) txsyncLoop64() { // Stop tracking peers that cause send failures. if err != nil { pack.p.Log().Debug("Transaction send failed", "err", err) - delete(pending, pack.p.ID()) + delete(pending, pack.p.Peer.ID()) } // Schedule the next send. if s := pick(); s != nil { send(s) } - case <-pm.quitSync: + case <-h.quitSync: return } } @@ -156,7 +156,7 @@ func (pm *ProtocolManager) txsyncLoop64() { // chainSyncer coordinates blockchain sync components. type chainSyncer struct { - pm *ProtocolManager + handler *handler force *time.Timer forced bool // true when force timer fired peerEventCh chan struct{} @@ -166,15 +166,15 @@ type chainSyncer struct { // chainSyncOp is a scheduled sync operation. type chainSyncOp struct { mode downloader.SyncMode - peer *peer + peer *eth.Peer td *big.Int head common.Hash } // newChainSyncer creates a chainSyncer. -func newChainSyncer(pm *ProtocolManager) *chainSyncer { +func newChainSyncer(handler *handler) *chainSyncer { return &chainSyncer{ - pm: pm, + handler: handler, peerEventCh: make(chan struct{}), } } @@ -182,23 +182,24 @@ func newChainSyncer(pm *ProtocolManager) *chainSyncer { // handlePeerEvent notifies the syncer about a change in the peer set. // This is called for new peers and every time a peer announces a new // chain head. -func (cs *chainSyncer) handlePeerEvent(p *peer) bool { +func (cs *chainSyncer) handlePeerEvent(peer *eth.Peer) bool { select { case cs.peerEventCh <- struct{}{}: return true - case <-cs.pm.quitSync: + case <-cs.handler.quitSync: return false } } // loop runs in its own goroutine and launches the sync when necessary. func (cs *chainSyncer) loop() { - defer cs.pm.wg.Done() + defer cs.handler.wg.Done() - cs.pm.blockFetcher.Start() - cs.pm.txFetcher.Start() - defer cs.pm.blockFetcher.Stop() - defer cs.pm.txFetcher.Stop() + cs.handler.blockFetcher.Start() + cs.handler.txFetcher.Start() + defer cs.handler.blockFetcher.Stop() + defer cs.handler.txFetcher.Stop() + defer cs.handler.downloader.Terminate() // The force timer lowers the peer count threshold down to one when it fires. // This ensures we'll always start sync even if there aren't enough peers. @@ -209,7 +210,6 @@ func (cs *chainSyncer) loop() { if op := cs.nextSyncOp(); op != nil { cs.startSync(op) } - select { case <-cs.peerEventCh: // Peer information changed, recheck. @@ -220,14 +220,13 @@ func (cs *chainSyncer) loop() { case <-cs.force.C: cs.forced = true - case <-cs.pm.quitSync: + case <-cs.handler.quitSync: // Disable all insertion on the blockchain. This needs to happen before // terminating the downloader because the downloader waits for blockchain // inserts, and these can take a long time to finish. 
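Editor's aside: the send closure in txsyncLoop64 above fills each pack up to a soft byte budget before shipping it; the fill loop itself sits in this diff's elided context lines. A hedged reconstruction under the assumption that geth's txsyncPackSize constant is 100KB and that Transaction.Size() returned common.StorageSize at the time of this PR:

```go
package main

import (
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
)

const txsyncPackSize = 100 * 1024 // assumed soft limit per pack

// fillPack takes transactions from the queue until the byte budget is hit,
// mirroring the (elided) fill loop inside send().
func fillPack(txs types.Transactions) (pack types.Transactions, size common.StorageSize) {
	for i := 0; i < len(txs) && size < txsyncPackSize; i++ {
		pack = append(pack, txs[i])
		size += txs[i].Size()
	}
	return pack, size
}

func main() {
	_, _ = fillPack(nil) // empty queue yields an empty pack
}
```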
- cs.pm.blockchain.StopInsert() - cs.pm.downloader.Terminate() + cs.handler.chain.StopInsert() + cs.handler.downloader.Terminate() if cs.doneCh != nil { - // Wait for the current sync to end. <-cs.doneCh } return @@ -245,19 +244,22 @@ func (cs *chainSyncer) nextSyncOp() *chainSyncOp { minPeers := defaultMinSyncPeers if cs.forced { minPeers = 1 - } else if minPeers > cs.pm.maxPeers { - minPeers = cs.pm.maxPeers + } else if minPeers > cs.handler.maxPeers { + minPeers = cs.handler.maxPeers } - if cs.pm.peers.Len() < minPeers { + if cs.handler.peers.Len() < minPeers { return nil } - - // We have enough peers, check TD. - peer := cs.pm.peers.BestPeer() + // We have enough peers, check TD + peer := cs.handler.peers.ethPeerWithHighestTD() if peer == nil { return nil } mode, ourTD := cs.modeAndLocalHead() + if mode == downloader.FastSync && atomic.LoadUint32(&cs.handler.snapSync) == 1 { + // Fast sync via the snap protocol + mode = downloader.SnapSync + } op := peerToSyncOp(mode, peer) if op.td.Cmp(ourTD) <= 0 { return nil // We're in sync. @@ -265,42 +267,42 @@ func (cs *chainSyncer) nextSyncOp() *chainSyncOp { return op } -func peerToSyncOp(mode downloader.SyncMode, p *peer) *chainSyncOp { +func peerToSyncOp(mode downloader.SyncMode, p *eth.Peer) *chainSyncOp { peerHead, peerTD := p.Head() return &chainSyncOp{mode: mode, peer: p, td: peerTD, head: peerHead} } func (cs *chainSyncer) modeAndLocalHead() (downloader.SyncMode, *big.Int) { // If we're in fast sync mode, return that directly - if atomic.LoadUint32(&cs.pm.fastSync) == 1 { - block := cs.pm.blockchain.CurrentFastBlock() - td := cs.pm.blockchain.GetTdByHash(block.Hash()) + if atomic.LoadUint32(&cs.handler.fastSync) == 1 { + block := cs.handler.chain.CurrentFastBlock() + td := cs.handler.chain.GetTdByHash(block.Hash()) return downloader.FastSync, td } // We are probably in full sync, but we might have rewound to before the // fast sync pivot, check if we should reenable - if pivot := rawdb.ReadLastPivotNumber(cs.pm.chaindb); pivot != nil { - if head := cs.pm.blockchain.CurrentBlock(); head.NumberU64() < *pivot { - block := cs.pm.blockchain.CurrentFastBlock() - td := cs.pm.blockchain.GetTdByHash(block.Hash()) + if pivot := rawdb.ReadLastPivotNumber(cs.handler.database); pivot != nil { + if head := cs.handler.chain.CurrentBlock(); head.NumberU64() < *pivot { + block := cs.handler.chain.CurrentFastBlock() + td := cs.handler.chain.GetTdByHash(block.Hash()) return downloader.FastSync, td } } // Nope, we're really full syncing - head := cs.pm.blockchain.CurrentHeader() - td := cs.pm.blockchain.GetTd(head.Hash(), head.Number.Uint64()) + head := cs.handler.chain.CurrentHeader() + td := cs.handler.chain.GetTd(head.Hash(), head.Number.Uint64()) return downloader.FullSync, td } // startSync launches doSync in a new goroutine. func (cs *chainSyncer) startSync(op *chainSyncOp) { cs.doneCh = make(chan error, 1) - go func() { cs.doneCh <- cs.pm.doSync(op) }() + go func() { cs.doneCh <- cs.handler.doSync(op) }() } // doSync synchronizes the local blockchain with a remote peer. -func (pm *ProtocolManager) doSync(op *chainSyncOp) error { - if op.mode == downloader.FastSync { +func (h *handler) doSync(op *chainSyncOp) error { + if op.mode == downloader.FastSync || op.mode == downloader.SnapSync { // Before launch the fast sync, we have to ensure user uses the same // txlookup limit. // The main concern here is: during the fast sync Geth won't index the @@ -310,35 +312,33 @@ func (pm *ProtocolManager) doSync(op *chainSyncOp) error { // has been indexed. 
So here for the user-experience wise, it's non-optimal // that user can't change limit during the fast sync. If changed, Geth // will just blindly use the original one. - limit := pm.blockchain.TxLookupLimit() - if stored := rawdb.ReadFastTxLookupLimit(pm.chaindb); stored == nil { - rawdb.WriteFastTxLookupLimit(pm.chaindb, limit) + limit := h.chain.TxLookupLimit() + if stored := rawdb.ReadFastTxLookupLimit(h.database); stored == nil { + rawdb.WriteFastTxLookupLimit(h.database, limit) } else if *stored != limit { - pm.blockchain.SetTxLookupLimit(*stored) + h.chain.SetTxLookupLimit(*stored) log.Warn("Update txLookup limit", "provided", limit, "updated", *stored) } } // Run the sync cycle, and disable fast sync if we're past the pivot block - err := pm.downloader.Synchronise(op.peer.id, op.head, op.td, op.mode) + err := h.downloader.Synchronise(op.peer.ID(), op.head, op.td, op.mode) if err != nil { return err } - if atomic.LoadUint32(&pm.fastSync) == 1 { + if atomic.LoadUint32(&h.fastSync) == 1 { log.Info("Fast sync complete, auto disabling") - atomic.StoreUint32(&pm.fastSync, 0) + atomic.StoreUint32(&h.fastSync, 0) } - // If we've successfully finished a sync cycle and passed any required checkpoint, // enable accepting transactions from the network. - head := pm.blockchain.CurrentBlock() - if head.NumberU64() >= pm.checkpointNumber { + head := h.chain.CurrentBlock() + if head.NumberU64() >= h.checkpointNumber { // Checkpoint passed, sanity check the timestamp to have a fallback mechanism // for non-checkpointed (number = 0) private networks. if head.Time() >= uint64(time.Now().AddDate(0, -1, 0).Unix()) { - atomic.StoreUint32(&pm.acceptTxs, 1) + atomic.StoreUint32(&h.acceptTxs, 1) } } - if head.NumberU64() > 0 { // We've completed a sync cycle, notify all peers of new state. This path is // essential in star-topology networks where a gateway node needs to notify @@ -346,8 +346,7 @@ func (pm *ProtocolManager) doSync(op *chainSyncOp) error { // scenario will most often crop up in private and hackathon networks with // degenerate connectivity, but it should be healthy for the mainnet too to // more reliably update peers or the local TD state. - pm.BroadcastBlock(head, false) + h.BroadcastBlock(head, false) } - return nil } diff --git a/eth/sync_test.go b/eth/sync_test.go index ac1e5fad1bf1..473e19518fd0 100644 --- a/eth/sync_test.go +++ b/eth/sync_test.go @@ -22,43 +22,59 @@ import ( "time" "github.com/ethereum/go-ethereum/eth/downloader" + "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" ) -func TestFastSyncDisabling63(t *testing.T) { testFastSyncDisabling(t, 63) } +// Tests that fast sync is disabled after a successful sync cycle. func TestFastSyncDisabling64(t *testing.T) { testFastSyncDisabling(t, 64) } func TestFastSyncDisabling65(t *testing.T) { testFastSyncDisabling(t, 65) } // Tests that fast sync gets disabled as soon as a real block is successfully // imported into the blockchain. 
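As an aside on the nextSyncOp logic above: a sync operation is only scheduled against a peer whose total difficulty strictly exceeds our own (the `op.td.Cmp(ourTD) <= 0` branch returns nil). A minimal, self-contained sketch of that comparison; the function and variable names are illustrative, not geth's:

```go
package main

import (
	"fmt"
	"math/big"
)

// worthSyncing mirrors the op.td.Cmp(ourTD) <= 0 check in nextSyncOp above:
// a sync is only started when the remote chain's total difficulty is
// strictly higher than ours.
func worthSyncing(peerTD, ourTD *big.Int) bool {
	return peerTD.Cmp(ourTD) > 0
}

func main() {
	ours := big.NewInt(1000000)
	fmt.Println(worthSyncing(big.NewInt(999999), ours))  // false: we are ahead
	fmt.Println(worthSyncing(big.NewInt(1000000), ours)) // false: equal TD, already in sync
	fmt.Println(worthSyncing(big.NewInt(1000001), ours)) // true: remote chain is heavier
}
```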
-func testFastSyncDisabling(t *testing.T, protocol int) { +func testFastSyncDisabling(t *testing.T, protocol uint) { t.Parallel() - // Create a pristine protocol manager, check that fast sync is left enabled - pmEmpty, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil) - if atomic.LoadUint32(&pmEmpty.fastSync) == 0 { + // Create an empty handler and ensure it's in fast sync mode + empty := newTestHandler() + if atomic.LoadUint32(&empty.handler.fastSync) == 0 { t.Fatalf("fast sync disabled on pristine blockchain") } - // Create a full protocol manager, check that fast sync gets disabled - pmFull, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil) - if atomic.LoadUint32(&pmFull.fastSync) == 1 { + defer empty.close() + + // Create a full handler and ensure fast sync ends up disabled + full := newTestHandlerWithBlocks(1024) + if atomic.LoadUint32(&full.handler.fastSync) == 1 { t.Fatalf("fast sync not disabled on non-empty blockchain") } + defer full.close() + + // Sync up the two handlers + emptyPipe, fullPipe := p2p.MsgPipe() + defer emptyPipe.Close() + defer fullPipe.Close() - // Sync up the two peers - io1, io2 := p2p.MsgPipe() - go pmFull.handle(pmFull.newPeer(protocol, p2p.NewPeer(enode.ID{}, "empty", nil), io2, pmFull.txpool.Get)) - go pmEmpty.handle(pmEmpty.newPeer(protocol, p2p.NewPeer(enode.ID{}, "full", nil), io1, pmEmpty.txpool.Get)) + emptyPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{1}, "", nil), emptyPipe, empty.txpool) + fullPeer := eth.NewPeer(protocol, p2p.NewPeer(enode.ID{2}, "", nil), fullPipe, full.txpool) + defer emptyPeer.Close() + defer fullPeer.Close() + go empty.handler.runEthPeer(emptyPeer, func(peer *eth.Peer) error { + return eth.Handle((*ethHandler)(empty.handler), peer) + }) + go full.handler.runEthPeer(fullPeer, func(peer *eth.Peer) error { + return eth.Handle((*ethHandler)(full.handler), peer) + }) + // Wait a bit for the above handlers to start time.Sleep(250 * time.Millisecond) - op := peerToSyncOp(downloader.FastSync, pmEmpty.peers.BestPeer()) - if err := pmEmpty.doSync(op); err != nil { - t.Fatal("sync failed:", err) - } // Check that fast sync was disabled - if atomic.LoadUint32(&pmEmpty.fastSync) == 1 { + op := peerToSyncOp(downloader.FastSync, empty.handler.peers.ethPeerWithHighestTD()) + if err := empty.handler.doSync(op); err != nil { + t.Fatal("sync failed:", err) + } + if atomic.LoadUint32(&empty.handler.fastSync) == 1 { t.Fatalf("fast sync not disabled after successful synchronisation") } } diff --git a/eth/tracers/internal/tracers/assets.go b/eth/tracers/internal/tracers/assets.go index 432398ebb58e..8b89ad418203 100644 --- a/eth/tracers/internal/tracers/assets.go +++ b/eth/tracers/internal/tracers/assets.go @@ -6,7 +6,7 @@ // evmdis_tracer.js (4.195kB) // noop_tracer.js (1.271kB) // opcount_tracer.js (1.372kB) -// prestate_tracer.js (4.234kB) +// prestate_tracer.js (4.287kB) // trigram_tracer.js (1.788kB) // unigram_tracer.js (1.51kB) @@ -197,7 +197,7 @@ func opcount_tracerJs() (*asset, error) { return a, nil } -var _prestate_tracerJs = 
[]byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x9c\x57\xdd\x6f\xdb\x38\x12\x7f\x96\xfe\x8a\x41\x5f\x6c\xa3\xae\xdc\x64\x81\x3d\xc0\xb9\x1c\xa0\xba\x6e\x1b\x20\x9b\x04\xb6\x7b\xb9\xdc\x62\x1f\x28\x72\x24\x73\x4d\x93\x02\x49\xd9\xf1\x15\xf9\xdf\x0f\x43\x7d\xf8\xa3\x49\xd3\xdd\x37\x9b\x1c\xfe\xe6\xfb\x37\xa3\xd1\x08\x26\xa6\xdc\x59\x59\x2c\x3d\x9c\xbf\x3f\xfb\x07\x2c\x96\x08\x85\x79\x87\x7e\x89\x16\xab\x35\xa4\x95\x5f\x1a\xeb\xe2\xd1\x08\x16\x4b\xe9\x20\x97\x0a\x41\x3a\x28\x99\xf5\x60\x72\xf0\x27\xf2\x4a\x66\x96\xd9\x5d\x12\x8f\x46\xf5\x9b\x67\xaf\x09\x21\xb7\x88\xe0\x4c\xee\xb7\xcc\xe2\x18\x76\xa6\x02\xce\x34\x58\x14\xd2\x79\x2b\xb3\xca\x23\x48\x0f\x4c\x8b\x91\xb1\xb0\x36\x42\xe6\x3b\x82\x94\x1e\x2a\x2d\xd0\x06\xd5\x1e\xed\xda\xb5\x76\x7c\xbe\xf9\x0a\xd7\xe8\x1c\x5a\xf8\x8c\x1a\x2d\x53\x70\x57\x65\x4a\x72\xb8\x96\x1c\xb5\x43\x60\x0e\x4a\x3a\x71\x4b\x14\x90\x05\x38\x7a\xf8\x89\x4c\x99\x37\xa6\xc0\x27\x53\x69\xc1\xbc\x34\x7a\x08\x28\xc9\x72\xd8\xa0\x75\xd2\x68\xf8\xa5\x55\xd5\x00\x0e\xc1\x58\x02\xe9\x33\x4f\x0e\x58\x30\x25\xbd\x1b\x00\xd3\x3b\x50\xcc\xef\x9f\xfe\x44\x40\xf6\x7e\x0b\x90\x3a\xa8\x59\x9a\x12\xc1\x2f\x99\x27\xaf\xb7\x52\x29\xc8\x10\x2a\x87\x79\xa5\x86\x84\x96\x55\x1e\xee\xaf\x16\x5f\x6e\xbf\x2e\x20\xbd\x79\x80\xfb\x74\x36\x4b\x6f\x16\x0f\x17\xb0\x95\x7e\x69\x2a\x0f\xb8\xc1\x1a\x4a\xae\x4b\x25\x51\xc0\x96\x59\xcb\xb4\xdf\x81\xc9\x09\xe1\xb7\xe9\x6c\xf2\x25\xbd\x59\xa4\x1f\xae\xae\xaf\x16\x0f\x60\x2c\x7c\xba\x5a\xdc\x4c\xe7\x73\xf8\x74\x3b\x83\x14\xee\xd2\xd9\xe2\x6a\xf2\xf5\x3a\x9d\xc1\xdd\xd7\xd9\xdd\xed\x7c\x9a\xc0\x1c\xc9\x2a\xa4\xf7\xaf\xc7\x3c\x0f\xd9\xb3\x08\x02\x3d\x93\xca\xb5\x91\x78\x30\x15\xb8\xa5\xa9\x94\x80\x25\xdb\x20\x58\xe4\x28\x37\x28\x80\x01\x37\xe5\xee\xa7\x93\x4a\x58\x4c\x19\x5d\x04\x9f\x5f\x2c\x48\xb8\xca\x41\x1b\x3f\x04\x87\x08\xff\x5c\x7a\x5f\x8e\x47\xa3\xed\x76\x9b\x14\xba\x4a\x8c\x2d\x46\xaa\x86\x73\xa3\x7f\x25\x31\x61\x96\x16\x9d\x67\x1e\x17\x96\x71\xb4\x60\x2a\x5f\x56\xde\x81\xab\xf2\x5c\x72\x89\xda\x83\xd4\xb9\xb1\xeb\x50\x29\xe0\x0d\x70\x8b\xcc\x23\x30\x50\x86\x33\x05\xf8\x88\xbc\x0a\x77\x75\xa4\x43\xb9\x5a\xa6\x1d\xe3\xe1\x34\xb7\x66\x4d\xbe\x56\xce\xd3\x0f\xe7\x70\x9d\x29\x14\x50\xa0\x46\x27\x1d\x64\xca\xf0\x55\x12\x7f\x8b\xa3\x03\x63\xa8\x4e\x82\x87\x8d\x50\xa8\x8d\x2d\xf6\x2c\x42\x56\x49\x25\xa4\x2e\x92\x38\x6a\xa5\xc7\xa0\x2b\xa5\x86\x71\x80\x50\xc6\xac\xaa\x32\xe5\xdc\x54\xc1\xf6\x3f\x91\xfb\x1a\xcc\x95\xc8\x65\x4e\xc5\xc1\xba\x5b\x6f\xc2\x55\xa7\xd7\x64\x24\x9f\xc4\xd1\x11\xcc\x18\xf2\x4a\x07\x77\xfa\x4c\x08\x3b\x04\x91\x0d\xbe\xc5\x51\xb4\x61\x96\xb0\xe0\x12\xbc\xf9\x82\x8f\xe1\x72\x70\x11\x47\x91\xcc\xa1\xef\x97\xd2\x25\x2d\xf0\xef\x8c\xf3\x3f\xe0\xf2\xf2\x32\x34\x75\x2e\x35\x8a\x01\x10\x44\xf4\x9c\x58\x7d\x13\x65\x4c\x31\xcd\x71\x0c\xbd\xf7\x8f\x3d\x78\x0b\x22\x4b\x0a\xf4\x1f\xea\xd3\x5a\x59\xe2\xcd\xdc\x5b\xa9\x8b\xfe\xd9\xaf\x83\x61\x78\xa5\x4d\x78\x03\x8d\xf8\x8d\xe9\x84\xeb\x7b\x6e\x44\xb8\x6e\x6c\xae\xa5\x26\x46\x34\x42\x8d\x94\xf3\xc6\xb2\x02\xc7\xf0\xed\x89\xfe\x3f\x91\x57\x4f\x71\xf4\x74\x14\xe5\x79\x2d\xf4\x42\x94\x1b\x08\x40\xed\x6d\x57\xe7\x85\xa4\x4e\x3d\x4c\x40\xc0\xfb\x51\x12\xe6\xad\x29\x27\x49\x58\xe1\xee\xf5\x4c\xd0\x85\x14\x8f\xdd\xc5\x0a\x77\x83\x8b\xf8\xc5\x14\x25\x8d\xd1\xbf\x4b\xf1\xf8\xb3\xf9\x3a\x79\x73\x14\xd7\x39\x49\xed\xed\x1d\x0c\x4e\xe2\x68\xd1\x55\xca\x53\xb9\x4b\xbd\x31\x2b\x22\xae\x25\xc5\x47\xa9\x10\x12\x53\x52\xb6\x5c\xcd\x1c\x19\xa2\x06\xe9\xd1\x32\xa2\x4e\xb3\x41\x4b\x53\x03\x2c\xfa\xca\x6a\xd7\x85\x31\x97\x9a\xa9\x16\xb8\x89\xba\xb7\x8c\xd7\x3d\x53\x9f\x1f\xc4\x92\xfb\xc7\x10\xc5\xe0\xdd\x68\x04\xa9\x07\x7
2\x11\x4a\x23\xb5\x1f\xc2\x16\x41\x23\x0a\x6a\x7c\x81\xa2\xe2\x3e\xe0\xf5\x36\x4c\x55\xd8\xab\x9b\x9b\x28\x32\x3c\x35\x15\x4d\x82\x83\xe6\x1f\x06\x03\xd7\x66\x13\x46\x5c\xc6\xf8\x0a\x9a\x86\x33\x56\x16\x52\xc7\x4d\x38\x8f\x9a\x8d\x2c\x4a\x08\x38\x98\x15\x72\x45\x49\xa4\x93\x0f\x4c\xc1\x25\x64\xb2\xb8\xd2\xfe\x24\x79\x75\xd0\xdb\xa7\x83\x3f\x92\xa6\x79\x12\x47\x84\xd7\x3f\x1f\x0c\xe1\xec\xd7\xae\x22\xbc\x21\x28\x78\x1d\xcc\x9b\x97\xa1\xe2\xd3\x62\x78\xfe\x59\x50\x43\x1d\xfc\x36\x68\x4d\x5c\x95\x51\x3a\x6a\x3f\x43\x1c\x8f\xbb\xf8\xe2\x07\xb8\xc7\xbe\xb5\xb8\x4d\x68\x12\x26\xc4\xcb\xa0\x75\x8a\x3e\x22\xb7\xb8\x26\x56\xa7\x2c\x70\xa6\x14\xda\x9e\x83\xc0\x19\xc3\xa6\x9c\x42\xbe\x70\x5d\xfa\x5d\xcb\xf5\x9e\xd9\x02\xbd\x7b\xdd\xb0\x80\xf3\xee\x5d\x4b\x81\x21\x14\xbb\x12\xe1\xf2\x12\x7a\x93\xd9\x34\x5d\x4c\x7b\x4d\x1b\x8d\x46\x70\x8f\x61\x13\xca\x94\xcc\x84\xda\x81\x40\x85\x1e\x6b\xbb\x8c\x0e\x21\xea\x28\x61\x48\x2b\x0d\x2d\x1b\xf8\x28\x9d\x97\xba\x80\x9a\x29\xb6\x34\x57\x1b\xb8\xd0\x23\x9c\x55\x8e\xaa\xf5\x64\x08\x79\x43\x1b\x85\x45\xe2\x15\xe2\xff\xd0\x6e\x4c\xc9\x6e\x03\xc9\xa5\x75\x1e\x4a\xc5\x38\x26\x84\xd7\x19\xf3\x72\x7e\x9b\x4e\x26\xd5\xb3\xd0\x82\x01\x68\x3f\xe0\x98\xa2\x01\x49\xea\x1d\xf4\x5b\x8c\x41\x1c\x45\xb6\x95\x3e\xc0\xbe\xd8\x53\x82\xf3\x58\x1e\x12\x02\x2d\x16\xb8\x41\xa2\xd0\xc0\x06\xf5\x30\x24\x5d\xff\xfe\xad\x99\xbe\xe8\x92\x38\xa2\x77\x07\x7d\xad\x4c\x71\xdc\xd7\xa2\x0e\x0b\xaf\xac\xa5\xfc\x77\x14\x9c\x53\x8f\xff\x59\x39\x4f\x31\xb5\x14\x9e\x86\x2d\x9e\x23\xc9\x40\x89\x34\x6d\x07\xdf\x93\x21\xcd\xad\x30\x27\x48\x5d\x33\xa5\xea\x6d\xae\x34\x1e\xb5\x97\x4c\xa9\x1d\xe5\x61\x6b\x69\x8d\xa1\xc5\x65\x08\x4e\x92\x54\x60\x9c\x20\x2a\x35\x57\x95\xa8\xcb\x20\xd4\x71\x83\xe7\x82\xcd\xc7\xfb\xcf\x1a\x9d\x63\x05\x26\x54\x49\xb9\x7c\x6c\x36\x48\x0d\xbd\x9a\xe4\xfa\x83\x5e\xd2\x19\x79\x4c\x31\xca\x14\x49\x5b\x64\x44\xd3\xa9\x10\x16\x9d\xeb\x0f\x1a\xce\xe9\x32\x7b\xbf\x44\x4d\xc1\x07\x8d\x5b\xe8\x56\x13\xc6\x39\xad\x6a\x62\x08\x4c\x08\xa2\xb6\x93\x35\x22\x8e\x22\xb7\x95\x9e\x2f\x21\x68\x32\xe5\xbe\x17\x07\x4d\xfd\x73\xe6\x10\xde\x4c\xff\xb3\x98\xdc\x7e\x9c\x4e\x6e\xef\x1e\xde\x8c\xe1\xe8\x6c\x7e\xf5\xdf\x69\x77\xf6\x21\xbd\x4e\x6f\x26\xd3\x37\xe3\x30\x9b\x9f\x71\xc8\x9b\xd6\x05\x52\xe8\x3c\xe3\xab\xa4\x44\x5c\xf5\xdf\x1f\xf3\xc0\xde\xc1\x28\xca\x2c\xb2\xd5\xc5\xde\x98\xba\x41\x1b\x1d\x2d\xe5\xc2\x25\xbc\x18\xac\x8b\x97\xad\x99\x34\xf2\xfd\x96\xc8\xf7\xab\x48\xa0\x8a\xd7\xed\x38\xff\xcb\x86\x84\xde\x61\x7c\x35\x06\xc7\x14\x6d\xc0\xf2\x7f\xf4\xe5\x92\xe7\x0e\xfd\x10\x50\x0b\xb3\x25\xe6\xeb\x50\xeb\x9b\x06\xf7\x20\x64\x67\x83\x9a\x41\x6f\xf3\xfe\xa0\x13\x26\xb0\xef\x45\xcf\x9f\x13\x45\x2d\xe0\xb2\x45\x7f\x1b\x5e\xbe\x1e\xa8\xf3\x26\x52\x27\x0a\x7e\x39\xd9\xf0\xc2\xfd\x1a\xd7\xc6\xee\x9a\x71\x74\xe0\xdf\x8f\xa3\x9a\x5e\x5f\x77\xf5\x44\x7f\xa8\xc8\xba\x83\x8f\xd3\xeb\xe9\xe7\x74\x31\x3d\x92\x9a\x2f\xd2\xc5\xd5\xa4\x3e\xfa\xcb\x85\x77\xf6\xd3\x85\xd7\x9b\xcf\x17\xb7\xb3\x69\x6f\xdc\xfc\xbb\xbe\x4d\x3f\xf6\xbe\x53\xd8\x6c\x81\x3f\x6a\x5d\x6f\xee\x8d\x15\x7f\xa7\x03\x0e\x36\xb2\x9c\x3d\xb7\x90\x05\x6a\xe7\xbe\x3a\xf9\xe0\x01\xa6\x5b\x56\xce\xeb\x8f\xbe\x28\xbc\x7f\x96\x87\x9f\xe2\xa7\xf8\xff\x01\x00\x00\xff\xff\xb1\x28\x85\x2a\x8a\x10\x00\x00") +var _prestate_tracerJs = 
[]byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x9c\x57\xdd\x6f\xdb\x38\x12\x7f\xb6\xfe\x8a\x41\x5f\x6c\x5d\x5d\xb9\xcd\x02\x7b\x80\x73\x39\x40\x75\xdd\x36\x40\x36\x09\x6c\xe7\x72\xb9\xc5\x3e\x50\xe4\x48\xe6\x9a\x26\x05\x92\xb2\xe3\x2b\xf2\xbf\x1f\x86\xfa\xf0\x47\x93\xa6\x7b\x6f\x16\x39\xfc\xcd\xf7\x6f\xc6\xa3\x11\x4c\x4c\xb9\xb3\xb2\x58\x7a\x38\x7b\xff\xe1\xef\xb0\x58\x22\x14\xe6\x1d\xfa\x25\x5a\xac\xd6\x90\x56\x7e\x69\xac\x8b\x46\x23\x58\x2c\xa5\x83\x5c\x2a\x04\xe9\xa0\x64\xd6\x83\xc9\xc1\x9f\xc8\x2b\x99\x59\x66\x77\x49\x34\x1a\xd5\x6f\x9e\xbd\x26\x84\xdc\x22\x82\x33\xb9\xdf\x32\x8b\x63\xd8\x99\x0a\x38\xd3\x60\x51\x48\xe7\xad\xcc\x2a\x8f\x20\x3d\x30\x2d\x46\xc6\xc2\xda\x08\x99\xef\x08\x52\x7a\xa8\xb4\x40\x1b\x54\x7b\xb4\x6b\xd7\xda\xf1\xe5\xfa\x0e\xae\xd0\x39\xb4\xf0\x05\x35\x5a\xa6\xe0\xb6\xca\x94\xe4\x70\x25\x39\x6a\x87\xc0\x1c\x94\x74\xe2\x96\x28\x20\x0b\x70\xf4\xf0\x33\x99\x32\x6f\x4c\x81\xcf\xa6\xd2\x82\x79\x69\xf4\x10\x50\x92\xe5\xb0\x41\xeb\xa4\xd1\xf0\x4b\xab\xaa\x01\x1c\x82\xb1\x04\x32\x60\x9e\x1c\xb0\x60\x4a\x7a\x17\x03\xd3\x3b\x50\xcc\xef\x9f\xfe\x44\x40\xf6\x7e\x0b\x90\x3a\xa8\x59\x9a\x12\xc1\x2f\x99\x27\xaf\xb7\x52\x29\xc8\x10\x2a\x87\x79\xa5\x86\x84\x96\x55\x1e\xee\x2f\x17\x5f\x6f\xee\x16\x90\x5e\x3f\xc0\x7d\x3a\x9b\xa5\xd7\x8b\x87\x73\xd8\x4a\xbf\x34\x95\x07\xdc\x60\x0d\x25\xd7\xa5\x92\x28\x60\xcb\xac\x65\xda\xef\xc0\xe4\x84\xf0\xdb\x74\x36\xf9\x9a\x5e\x2f\xd2\x8f\x97\x57\x97\x8b\x07\x30\x16\x3e\x5f\x2e\xae\xa7\xf3\x39\x7c\xbe\x99\x41\x0a\xb7\xe9\x6c\x71\x39\xb9\xbb\x4a\x67\x70\x7b\x37\xbb\xbd\x99\x4f\x13\x98\x23\x59\x85\xf4\xfe\xf5\x98\xe7\x21\x7b\x16\x41\xa0\x67\x52\xb9\x36\x12\x0f\xa6\x02\xb7\x34\x95\x12\xb0\x64\x1b\x04\x8b\x1c\xe5\x06\x05\x30\xe0\xa6\xdc\xfd\x74\x52\x09\x8b\x29\xa3\x8b\xe0\xf3\x8b\x05\x09\x97\x39\x68\xe3\x87\xe0\x10\xe1\x1f\x4b\xef\xcb\xf1\x68\xb4\xdd\x6e\x93\x42\x57\x89\xb1\xc5\x48\xd5\x70\x6e\xf4\xcf\x24\x22\xcc\xd2\xa2\xf3\xcc\xe3\xc2\x32\x8e\x16\x4c\xe5\xcb\xca\x3b\x70\x55\x9e\x4b\x2e\x51\x7b\x90\x3a\x37\x76\x1d\x2a\x05\xbc\x01\x6e\x91\x79\x04\x06\xca\x70\xa6\x00\x1f\x91\x57\xe1\xae\x8e\x74\x28\x57\xcb\xb4\x63\x3c\x9c\xe6\xd6\xac\xc9\xd7\xca\x79\xfa\xe1\x1c\xae\x33\x85\x02\x0a\xd4\xe8\xa4\x83\x4c\x19\xbe\x4a\xa2\x6f\x51\xef\xc0\x18\xaa\x93\xe0\x61\x23\x14\x6a\x63\x8b\x7d\x8b\x90\x55\x52\x09\xa9\x8b\x24\xea\xb5\xd2\x63\xd0\x95\x52\xc3\x28\x40\x28\x63\x56\x55\x99\x72\x6e\xaa\x60\xfb\x9f\xc8\x7d\x0d\xe6\x4a\xe4\x32\xa7\xe2\x60\xdd\xad\x37\xe1\xaa\xd3\x6b\x32\x92\x4f\xa2\xde\x11\xcc\x18\xf2\x4a\x07\x77\x06\x4c\x08\x3b\x04\x91\xc5\xdf\xa2\x5e\x6f\xc3\x2c\x61\xc1\x05\x78\xf3\x15\x1f\xc3\x65\x7c\x1e\xf5\x7a\x32\x87\x81\x5f\x4a\x97\xb4\xc0\xbf\x33\xce\xff\x80\x8b\x8b\x8b\xd0\xd4\xb9\xd4\x28\x62\x20\x88\xde\x73\x62\xf5\x4d\x2f\x63\x8a\x69\x8e\x63\xe8\xbf\x7f\xec\xc3\x5b\x10\x59\x52\xa0\xff\x58\x9f\xd6\xca\x12\x6f\xe6\xde\x4a\x5d\x0c\x3e\xfc\x1a\x0f\xc3\x2b\x6d\xc2\x1b\x68\xc4\xaf\x4d\x27\x5c\xdf\x73\x23\xc2\x75\x63\x73\x2d\x35\x31\xa2\x11\x6a\xa4\x9c\x37\x96\x15\x38\x86\x6f\x4f\xf4\xfd\x44\x5e\x3d\x45\xbd\xa7\xa3\x28\xcf\x6b\xa1\x17\xa2\xdc\x40\x00\x6a\x6f\xbb\x3a\x2f\x24\x75\xea\x61\x02\x02\xde\x8f\x92\x30\x6f\x4d\x39\x49\xc2\x0a\x77\xaf\x67\x82\x2e\xa4\x78\xec\x2e\x56\xb8\x8b\xcf\xa3\x17\x53\x94\x34\x46\xff\x2e\xc5\xe3\xcf\xe6\xeb\xe4\xcd\x51\x5c\xe7\x24\xb5\xb7\x37\x8e\x4f\xe2\x68\xd1\x55\xca\x53\xb9\x4b\xbd\x31\x2b\x22\xae\x25\xc5\x47\xa9\x10\x12\x53\x52\xb6\x5c\xcd\x1c\x19\xa2\x06\xe9\xd1\x32\xa2\x4e\xb3\x41\x4b\x53\x03\x2c\xfa\xca\x6a\xd7\x85\x31\x97\x9a\xa9\x16\xb8\x89\xba\xb7\x8c\xd7\x3d\x53\x9f\x1f\xc4\x92\xfb\xc7\x10\xc5\xe0\xdd\x68\x04\xa9\x07\x7
2\x11\x4a\x23\xb5\x1f\xc2\x16\x41\x23\x0a\x6a\x7c\x81\xa2\xe2\x3e\xe0\xf5\x37\x4c\x55\xd8\xaf\x9b\x9b\x28\x32\x3c\x35\x15\x4d\x82\x83\xe6\x1f\x06\x03\xd7\x66\x13\x46\x5c\xc6\xf8\x0a\x9a\x86\x33\x56\x16\x52\x47\x4d\x38\x8f\x9a\x8d\x2c\x4a\x08\x38\x98\x15\x72\x45\x49\xa4\x93\x8f\x4c\xc1\x05\x64\xb2\xb8\xd4\xfe\x24\x79\x75\xd0\xdb\xa7\xf1\x1f\x49\xd3\x3c\x89\x23\xc2\x1b\x9c\xc5\x43\xf8\xf0\x6b\x57\x11\xde\x10\x14\xbc\x0e\xe6\xcd\xcb\x50\xd1\x69\x31\x3c\xff\x2c\xa8\xa1\x0e\x7e\x1b\xb4\x26\xae\xca\x28\x1d\xb5\x9f\x21\x8e\xc7\x5d\x7c\xfe\x03\xdc\x63\xdf\x5a\xdc\x26\x34\x09\x13\xe2\x10\x94\x3e\xc3\x77\xc1\xdc\x9d\x43\x01\x6f\x81\xbe\xa4\x26\x55\x4e\xf2\x2f\xcc\xc5\xf0\x37\x68\x24\x6e\xad\xe4\xdf\x59\x52\xe7\xf5\x13\x72\x8b\x6b\x1a\x05\x94\x3a\xce\x94\x42\xdb\x77\x10\x88\x66\xd8\xd4\x60\x48\x32\xae\x4b\xbf\x6b\x07\x84\x67\xb6\x40\xef\x5e\xf7\x26\xe0\xbc\x7b\xd7\xf2\x66\x88\xdf\xae\x44\xb8\xb8\x80\xfe\x64\x36\x4d\x17\xd3\x7e\xd3\x7b\xa3\x11\xdc\x63\x58\x9f\x32\x25\x33\xa1\x76\x20\x50\xa1\xc7\xda\x2e\xa3\x43\x5c\x3b\x1e\x19\xd2\x1e\x44\x1b\x0a\x3e\x4a\xe7\xa5\x2e\xa0\xa6\x97\x2d\x0d\xe3\x06\x2e\x34\x16\x67\x15\x85\xe7\x74\x72\x79\x43\x6b\x88\x45\x22\x23\x1a\x1a\xa1\x47\x99\x92\xdd\xda\x92\x4b\xeb\x3c\x94\x8a\x71\x4c\x08\xaf\x33\xe6\xe5\xa2\x68\xda\x9f\x54\xcf\x42\xdf\x06\xa0\xfd\x54\x64\x8a\xa6\x2a\xa9\x77\x30\x68\x31\xe2\xa8\xd7\xb3\xad\xf4\x01\xf6\xf9\x9e\x47\x9c\xc7\xf2\x90\x45\x68\x1b\xc1\x0d\x12\xef\x06\x0a\xa9\x27\x28\xe9\xfa\xd7\x6f\xcd\xc8\x46\x97\x44\x3d\x7a\x77\x40\x06\xca\x14\xc7\x64\x20\xea\xb0\xf0\xca\x5a\xca\x7f\xc7\xdb\x39\x11\xc3\x9f\x95\xf3\x14\x53\x4b\xe1\x69\x28\xe6\x39\x66\x0d\x3c\x4a\x23\x3a\xfe\x9e\x41\x69\xd8\x85\xe1\x42\xea\x9a\xd1\x56\xaf\x80\xa5\xf1\xa8\xbd\x64\x4a\xed\x28\x0f\x5b\x4b\xbb\x0f\x6d\x3b\x43\x70\x92\xa4\x02\x4d\x05\x51\xa9\xb9\xaa\x44\x5d\x06\xa1\xf8\x1b\x3c\x17\x6c\x3e\x5e\x9a\xd6\xe8\x1c\x2b\x30\xa1\x4a\xca\xe5\x63\xb3\x76\x6a\xe8\xd7\xcc\x38\x88\xfb\x49\x67\xe4\x31\x2f\x29\x53\x24\x6d\x91\x11\xb7\xa7\x42\x58\x74\x6e\x10\x37\x44\xd5\x65\xf6\x7e\x89\x9a\x82\x0f\x1a\xb7\xd0\xed\x33\x8c\x73\xda\xef\xc4\x10\x98\x10\xc4\x87\x27\xbb\x47\xd4\xeb\xb9\xad\xf4\x7c\x09\x41\x93\x29\xf7\xbd\x18\x37\xf5\xcf\x99\x43\x78\x33\xfd\xf7\x62\x72\xf3\x69\x3a\xb9\xb9\x7d\x78\x33\x86\xa3\xb3\xf9\xe5\x7f\xa6\xdd\xd9\xc7\xf4\x2a\xbd\x9e\x4c\xdf\x8c\xc3\x40\x7f\xc6\x21\x6f\x5a\x17\x48\xa1\xf3\x8c\xaf\x92\x12\x71\x35\x78\x7f\xcc\x03\x7b\x07\x7b\xbd\xcc\x22\x5b\x9d\xef\x8d\xa9\x1b\xb4\xd1\xd1\xf2\x34\x5c\xc0\x8b\xc1\x3a\x7f\xd9\x9a\x49\x23\x3f\x68\xd9\x7f\xbf\xbf\x04\xaa\x78\xdd\x8e\xb3\xbf\x6c\x48\xe8\x1d\xc6\x57\x63\x70\x4c\xd1\xda\x2c\xff\x4b\x7f\x77\xf2\xdc\xa1\x1f\x02\x6a\x61\xb6\xc4\x7c\x1d\x6a\x7d\xd3\xe0\x1e\x84\xec\x43\x5c\xd3\xee\x4d\x3e\x88\x3b\x61\x02\xfb\x5e\xf4\xec\x39\x51\xd4\x02\x2e\x5a\xf4\xb7\xe1\xe5\xeb\x81\x3a\x6b\x22\x75\xa2\xe0\x97\x93\xb5\x30\xdc\xaf\x71\x6d\xec\xae\x99\x61\x07\xfe\xfd\x38\xaa\xe9\xd5\x55\x57\x4f\xf4\x41\x45\xd6\x1d\x7c\x9a\x5e\x4d\xbf\xa4\x8b\xe9\x91\xd4\x7c\x91\x2e\x2e\x27\xf5\xd1\x5f\x2e\xbc\x0f\x3f\x5d\x78\xfd\xf9\x7c\x71\x33\x9b\xf6\xc7\xcd\xd7\xd5\x4d\xfa\xa9\xff\x9d\xc2\x66\x75\xfc\x51\xeb\x7a\x73\x6f\xac\xf8\x7f\x3a\xe0\x60\x8d\xcb\xd9\x73\x5b\x5c\xa0\x76\xee\xab\x93\x7f\x49\xc0\x74\xcb\xca\x79\xfd\x4f\xb1\x17\xde\x3f\xcb\xc3\x4f\xd1\x53\xf4\xbf\x00\x00\x00\xff\xff\x3a\xb7\x37\x41\xbf\x10\x00\x00") func prestate_tracerJsBytes() ([]byte, error) { return bindataRead( @@ -213,7 +213,7 @@ func prestate_tracerJs() (*asset, error) { } info := bindataFileInfo{name: "prestate_tracer.js", size: 0, mode: os.FileMode(0), modTime: time.Unix(0, 0)} - a 
:= &asset{bytes: bytes, info: info, digest: [32]uint8{0xe9, 0x79, 0x70, 0x4f, 0xc5, 0x78, 0x57, 0x63, 0x6f, 0x5, 0x31, 0xce, 0x3e, 0x5d, 0xbd, 0x71, 0x4, 0x46, 0x78, 0xcd, 0x1d, 0xcd, 0xb9, 0xd8, 0x10, 0xff, 0xe6, 0xc5, 0x59, 0xb9, 0x25, 0x6e}} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xd4, 0x9, 0xf9, 0x44, 0x13, 0x31, 0x89, 0xf7, 0x35, 0x9a, 0xc6, 0xf0, 0x86, 0x9d, 0xb2, 0xe3, 0x57, 0xe2, 0xc0, 0xde, 0xc9, 0x3a, 0x4c, 0x4a, 0x94, 0x90, 0xa5, 0x92, 0x2f, 0xbf, 0xc0, 0xb8}} return a, nil } diff --git a/eth/tracers/internal/tracers/prestate_tracer.js b/eth/tracers/internal/tracers/prestate_tracer.js index e0a22bf157d3..084c04ec46b8 100644 --- a/eth/tracers/internal/tracers/prestate_tracer.js +++ b/eth/tracers/internal/tracers/prestate_tracer.js @@ -55,7 +55,7 @@ var toBal = bigInt(this.prestate[toHex(ctx.to)].balance.slice(2), 16); this.prestate[toHex(ctx.to)].balance = '0x'+toBal.subtract(ctx.value).toString(16); - this.prestate[toHex(ctx.from)].balance = '0x'+fromBal.add(ctx.value).toString(16); + this.prestate[toHex(ctx.from)].balance = '0x'+fromBal.add(ctx.value).add((ctx.gasUsed + ctx.intrinsicGas) * ctx.gasPrice).toString(16); // Decrement the caller's nonce, and remove empty create targets this.prestate[toHex(ctx.from)].nonce--; diff --git a/eth/tracers/tracer.go b/eth/tracers/tracer.go index 050fb05159f1..c9f00d7371ed 100644 --- a/eth/tracers/tracer.go +++ b/eth/tracers/tracer.go @@ -27,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" @@ -316,7 +317,7 @@ type Tracer struct { // New instantiates a new tracer instance. code specifies a Javascript snippet, // which must evaluate to an expression returning an object with 'step', 'fault' // and 'result' functions. -func New(code string) (*Tracer, error) { +func New(code string, txCtx vm.TxContext) (*Tracer, error) { // Resolve any tracers by name and assemble the tracer object if tracer, ok := tracer(code); ok { code = tracer @@ -335,6 +336,8 @@ func New(code string) (*Tracer, error) { depthValue: new(uint), refundValue: new(uint), } + tracer.ctx["gasPrice"] = txCtx.GasPrice + // Set up builtins for this environment tracer.vm.PushGlobalGoFunction("toHex", func(ctx *duktape.Context) int { ctx.PushString(hexutil.Encode(popSlice(ctx))) @@ -545,7 +548,19 @@ func (jst *Tracer) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost if jst.err == nil { // Initialize the context if it wasn't done yet if !jst.inited { - jst.ctx["block"] = env.BlockNumber.Uint64() + jst.ctx["block"] = env.Context.BlockNumber.Uint64() + // Compute intrinsic gas + isHomestead := env.ChainConfig().IsHomestead(env.Context.BlockNumber) + isIstanbul := env.ChainConfig().IsIstanbul(env.Context.BlockNumber) + var input []byte + if data, ok := jst.ctx["input"].([]byte); ok { + input = data + } + intrinsicGas, err := core.IntrinsicGas(input, jst.ctx["type"] == "CREATE", isHomestead, isIstanbul) + if err != nil { + return err + } + jst.ctx["intrinsicGas"] = intrinsicGas jst.inited = true } // If tracing was interrupted, set the error and stop @@ -597,8 +612,8 @@ func (jst *Tracer) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost // CaptureEnd is called after the call finishes to finalize the tracing. 
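The prestate_tracer.js change above restores the sender's pre-transaction balance by adding back both the transferred value and the full gas fee, and the intrinsicGas field plumbed through tracer.go is what makes that fee complete (ctx.gasUsed alone covers only execution gas). A minimal sketch of the same arithmetic in Go; all concrete numbers here are hypothetical:

```go
package main

import (
	"fmt"
	"math/big"
)

// senderPrestateBalance mirrors the prestate_tracer.js arithmetic above:
// to recover the sender's balance before the transaction ran, add back the
// transferred value plus the full gas fee, where the fee must cover both
// the gas used during execution and the intrinsic gas charged up front.
func senderPrestateBalance(post, value *big.Int, gasUsed, intrinsicGas uint64, gasPrice *big.Int) *big.Int {
	fee := new(big.Int).Mul(gasPrice, new(big.Int).SetUint64(gasUsed+intrinsicGas))
	return new(big.Int).Add(new(big.Int).Add(post, value), fee)
}

func main() {
	// Hypothetical numbers: balance after the tx is 977994 wei, 1000 wei were
	// sent, execution used 6 gas, intrinsic gas is 21000, gas price is 1 wei.
	pre := senderPrestateBalance(big.NewInt(977994), big.NewInt(1000), 6, 21000, big.NewInt(1))
	fmt.Println(pre) // 1000000 = 977994 + 1000 + 1*(6+21000)
}
```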
func (jst *Tracer) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { jst.ctx["output"] = output - jst.ctx["gasUsed"] = gasUsed jst.ctx["time"] = t.String() + jst.ctx["gasUsed"] = gasUsed if err != nil { jst.ctx["error"] = err.Error() diff --git a/eth/tracers/tracer_test.go b/eth/tracers/tracer_test.go index b4de99865153..f28e14864bd1 100644 --- a/eth/tracers/tracer_test.go +++ b/eth/tracers/tracer_test.go @@ -17,7 +17,6 @@ package tracers import ( - "bytes" "encoding/json" "errors" "math/big" @@ -50,94 +49,77 @@ type dummyStatedb struct { func (*dummyStatedb) GetRefund() uint64 { return 1337 } -func runTrace(tracer *Tracer) (json.RawMessage, error) { - env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) - - contract := vm.NewContract(account{}, account{}, big.NewInt(0), 10000) - contract.Code = []byte{byte(vm.PUSH1), 0x1, byte(vm.PUSH1), 0x1, 0x0} - - _, err := env.Interpreter().Run(contract, []byte{}, false) - if err != nil { - return nil, err - } - return tracer.GetResult() -} - -// TestRegressionPanicSlice tests that we don't panic on bad arguments to memory access -func TestRegressionPanicSlice(t *testing.T) { - tracer, err := New("{depths: [], step: function(log) { this.depths.push(log.memory.slice(-1,-2)); }, fault: function() {}, result: function() { return this.depths; }}") - if err != nil { - t.Fatal(err) - } - if _, err = runTrace(tracer); err != nil { - t.Fatal(err) - } -} - -// TestRegressionPanicSlice tests that we don't panic on bad arguments to stack peeks -func TestRegressionPanicPeek(t *testing.T) { - tracer, err := New("{depths: [], step: function(log) { this.depths.push(log.stack.peek(-1)); }, fault: function() {}, result: function() { return this.depths; }}") - if err != nil { - t.Fatal(err) - } - if _, err = runTrace(tracer); err != nil { - t.Fatal(err) - } -} - -// TestRegressionPanicSlice tests that we don't panic on bad arguments to memory getUint -func TestRegressionPanicGetUint(t *testing.T) { - tracer, err := New("{ depths: [], step: function(log, db) { this.depths.push(log.memory.getUint(-64));}, fault: function() {}, result: function() { return this.depths; }}") - if err != nil { - t.Fatal(err) - } - if _, err = runTrace(tracer); err != nil { - t.Fatal(err) - } +type vmContext struct { + blockCtx vm.BlockContext + txCtx vm.TxContext } -func TestTracing(t *testing.T) { - tracer, err := New("{count: 0, step: function() { this.count += 1; }, fault: function() {}, result: function() { return this.count; }}") - if err != nil { - t.Fatal(err) - } - - ret, err := runTrace(tracer) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(ret, []byte("3")) { - t.Errorf("Expected return value to be 3, got %s", string(ret)) - } +func testCtx() *vmContext { + return &vmContext{blockCtx: vm.BlockContext{BlockNumber: big.NewInt(1)}, txCtx: vm.TxContext{GasPrice: big.NewInt(100000)}} } -func TestStack(t *testing.T) { - tracer, err := New("{depths: [], step: function(log) { this.depths.push(log.stack.length()); }, fault: function() {}, result: function() { return this.depths; }}") - if err != nil { - t.Fatal(err) - } +func runTrace(tracer *Tracer, vmctx *vmContext) (json.RawMessage, error) { + env := vm.NewEVM(vmctx.blockCtx, vmctx.txCtx, &dummyStatedb{}, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) + var ( + startGas uint64 = 10000 + value = big.NewInt(0) + ) + contract := vm.NewContract(account{}, account{}, value, startGas) + contract.Code = 
[]byte{byte(vm.PUSH1), 0x1, byte(vm.PUSH1), 0x1, 0x0} - ret, err := runTrace(tracer) + tracer.CaptureStart(contract.Caller(), contract.Address(), false, []byte{}, startGas, value) + ret, err := env.Interpreter().Run(contract, []byte{}, false) + tracer.CaptureEnd(ret, startGas-contract.Gas, 1, err) if err != nil { - t.Fatal(err) - } - if !bytes.Equal(ret, []byte("[0,1,2]")) { - t.Errorf("Expected return value to be [0,1,2], got %s", string(ret)) + return nil, err } + return tracer.GetResult() } -func TestOpcodes(t *testing.T) { - tracer, err := New("{opcodes: [], step: function(log) { this.opcodes.push(log.op.toString()); }, fault: function() {}, result: function() { return this.opcodes; }}") - if err != nil { - t.Fatal(err) - } - - ret, err := runTrace(tracer) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(ret, []byte("[\"PUSH1\",\"PUSH1\",\"STOP\"]")) { - t.Errorf("Expected return value to be [\"PUSH1\",\"PUSH1\",\"STOP\"], got %s", string(ret)) +func TestTracer(t *testing.T) { + execTracer := func(code string) []byte { + t.Helper() + ctx := &vmContext{blockCtx: vm.BlockContext{BlockNumber: big.NewInt(1)}, txCtx: vm.TxContext{GasPrice: big.NewInt(100000)}} + tracer, err := New(code, ctx.txCtx) + if err != nil { + t.Fatal(err) + } + ret, err := runTrace(tracer, ctx) + if err != nil { + t.Fatal(err) + } + return ret + } + for i, tt := range []struct { + code string + want string + }{ + { // tests that we don't panic on bad arguments to memory access + code: "{depths: [], step: function(log) { this.depths.push(log.memory.slice(-1,-2)); }, fault: function() {}, result: function() { return this.depths; }}", + want: `[{},{},{}]`, + }, { // tests that we don't panic on bad arguments to stack peeks + code: "{depths: [], step: function(log) { this.depths.push(log.stack.peek(-1)); }, fault: function() {}, result: function() { return this.depths; }}", + want: `["0","0","0"]`, + }, { // tests that we don't panic on bad arguments to memory getUint + code: "{ depths: [], step: function(log, db) { this.depths.push(log.memory.getUint(-64));}, fault: function() {}, result: function() { return this.depths; }}", + want: `["0","0","0"]`, + }, { // tests some general counting + code: "{count: 0, step: function() { this.count += 1; }, fault: function() {}, result: function() { return this.count; }}", + want: `3`, + }, { // tests that depth is reported correctly + code: "{depths: [], step: function(log) { this.depths.push(log.stack.length()); }, fault: function() {}, result: function() { return this.depths; }}", + want: `[0,1,2]`, + }, { // tests to-string of opcodes + code: "{opcodes: [], step: function(log) { this.opcodes.push(log.op.toString()); }, fault: function() {}, result: function() { return this.opcodes; }}", + want: `["PUSH1","PUSH1","STOP"]`, + }, { // tests intrinsic gas + code: "{depths: [], step: function() {}, fault: function() {}, result: function(ctx) { return ctx.gasPrice+'.'+ctx.gasUsed+'.'+ctx.intrinsicGas; }}", + want: `"100000.6.21000"`, + }, + } { + if have := execTracer(tt.code); tt.want != string(have) { + t.Errorf("testcase %d: expected return value to be %s got %s\n\tcode: %v", i, tt.want, string(have), tt.code) + } } } @@ -145,7 +127,8 @@ func TestHalt(t *testing.T) { t.Skip("duktape doesn't support abortion") timeout := errors.New("stahp") - tracer, err := New("{step: function() { while(1); }, result: function() { return null; }}") + vmctx := testCtx() + tracer, err := New("{step: function() { while(1); }, result: function() { return null; }}", vmctx.txCtx) if err != nil { 
t.Fatal(err) } @@ -155,18 +138,18 @@ func TestHalt(t *testing.T) { tracer.Stop(timeout) }() - if _, err = runTrace(tracer); err.Error() != "stahp in server-side tracer function 'step'" { + if _, err = runTrace(tracer, vmctx); err.Error() != "stahp in server-side tracer function 'step'" { t.Errorf("Expected timeout error, got %v", err) } } func TestHaltBetweenSteps(t *testing.T) { - tracer, err := New("{step: function() {}, fault: function() {}, result: function() { return null; }}") + vmctx := testCtx() + tracer, err := New("{step: function() {}, fault: function() {}, result: function() { return null; }}", vmctx.txCtx) if err != nil { t.Fatal(err) } - - env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) + env := vm.NewEVM(vm.BlockContext{BlockNumber: big.NewInt(1)}, vm.TxContext{}, &dummyStatedb{}, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) contract := vm.NewContract(&account{}, &account{}, big.NewInt(0), 0) tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, nil, nil, contract, 0, nil) diff --git a/eth/tracers/tracers_test.go b/eth/tracers/tracers_test.go index 18f8eb12aa55..9dc4c696311c 100644 --- a/eth/tracers/tracers_test.go +++ b/eth/tracers/tracers_test.go @@ -143,16 +143,18 @@ func TestPrestateTracerCreate2(t *testing.T) { result: 0x60f3f640a8508fC6a86d45DF051962668E1e8AC7 */ origin, _ := signer.Sender(tx) - context := vm.Context{ + txContext := vm.TxContext{ + Origin: origin, + GasPrice: big.NewInt(1), + } + context := vm.BlockContext{ CanTransfer: core.CanTransfer, Transfer: core.Transfer, - Origin: origin, Coinbase: common.Address{}, BlockNumber: new(big.Int).SetUint64(8000000), Time: new(big.Int).SetUint64(5), Difficulty: big.NewInt(0x30000), GasLimit: uint64(6000000), - GasPrice: big.NewInt(1), } alloc := core.GenesisAlloc{} @@ -171,11 +173,11 @@ func TestPrestateTracerCreate2(t *testing.T) { _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), alloc, false) // Create the tracer, the EVM environment and run it - tracer, err := New("prestateTracer") + tracer, err := New("prestateTracer", txContext) if err != nil { t.Fatalf("failed to create call tracer: %v", err) } - evm := vm.NewEVM(context, statedb, params.MainnetChainConfig, vm.Config{Debug: true, Tracer: tracer}) + evm := vm.NewEVM(context, txContext, statedb, params.MainnetChainConfig, vm.Config{Debug: true, Tracer: tracer}) msg, err := tx.AsMessage(signer) if err != nil { @@ -230,26 +232,27 @@ func TestCallTracer(t *testing.T) { } signer := types.MakeSigner(test.Genesis.Config, new(big.Int).SetUint64(uint64(test.Context.Number))) origin, _ := signer.Sender(tx) - - context := vm.Context{ + txContext := vm.TxContext{ + Origin: origin, + GasPrice: tx.GasPrice(), + } + context := vm.BlockContext{ CanTransfer: core.CanTransfer, Transfer: core.Transfer, - Origin: origin, Coinbase: test.Context.Miner, BlockNumber: new(big.Int).SetUint64(uint64(test.Context.Number)), Time: new(big.Int).SetUint64(uint64(test.Context.Time)), Difficulty: (*big.Int)(test.Context.Difficulty), GasLimit: uint64(test.Context.GasLimit), - GasPrice: tx.GasPrice(), } _, statedb := tests.MakePreState(rawdb.NewMemoryDatabase(), test.Genesis.Alloc, false) // Create the tracer, the EVM environment and run it - tracer, err := New("callTracer") + tracer, err := New("callTracer", txContext) if err != nil { t.Fatalf("failed to create call tracer: %v", err) } - evm := vm.NewEVM(context, statedb, test.Genesis.Config, vm.Config{Debug: true, Tracer: tracer}) + 
evm := vm.NewEVM(context, txContext, statedb, test.Genesis.Config, vm.Config{Debug: true, Tracer: tracer}) msg, err := tx.AsMessage(signer) if err != nil { diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go index 1828ad70fbc5..e0f4f95ff3b2 100644 --- a/ethstats/ethstats.go +++ b/ethstats/ethstats.go @@ -36,8 +36,8 @@ import ( "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/eth/downloader" + ethproto "github.com/ethereum/go-ethereum/eth/protocols/eth" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/les" "github.com/ethereum/go-ethereum/log" @@ -444,13 +444,15 @@ func (s *Service) login(conn *connWrapper) error { // Construct and send the login authentication infos := s.server.NodeInfo() - var network, protocol string + var protocols []string + for _, proto := range s.server.Protocols { + protocols = append(protocols, fmt.Sprintf("%s/%d", proto.Name, proto.Version)) + } + var network string if info := infos.Protocols["eth"]; info != nil { - network = fmt.Sprintf("%d", info.(*eth.NodeInfo).Network) - protocol = fmt.Sprintf("eth/%d", eth.ProtocolVersions[0]) + network = fmt.Sprintf("%d", info.(*ethproto.NodeInfo).Network) } else { network = fmt.Sprintf("%d", infos.Protocols["les"].(*les.NodeInfo).Network) - protocol = fmt.Sprintf("les/%d", les.ClientProtocolVersions[0]) } auth := &authMsg{ ID: s.node, @@ -459,7 +461,7 @@ func (s *Service) login(conn *connWrapper) error { Node: infos.Name, Port: infos.Ports.Listener, Network: network, - Protocol: protocol, + Protocol: strings.Join(protocols, ", "), API: "No", Os: runtime.GOOS, OsVer: runtime.GOARCH, diff --git a/event/event_test.go b/event/event_test.go index cc9fa5d7c828..bdad11f13d6c 100644 --- a/event/event_test.go +++ b/event/event_test.go @@ -29,7 +29,7 @@ func TestSubCloseUnsub(t *testing.T) { // the point of this test is **not** to panic var mux TypeMux mux.Stop() - sub := mux.Subscribe(int(0)) + sub := mux.Subscribe(0) sub.Unsubscribe() } diff --git a/event/feed_test.go b/event/feed_test.go index 93badaaee6d8..cdf29fdd73ca 100644 --- a/event/feed_test.go +++ b/event/feed_test.go @@ -27,8 +27,8 @@ import ( func TestFeedPanics(t *testing.T) { { var f Feed - f.Send(int(2)) - want := feedTypeError{op: "Send", got: reflect.TypeOf(uint64(0)), want: reflect.TypeOf(int(0))} + f.Send(2) + want := feedTypeError{op: "Send", got: reflect.TypeOf(uint64(0)), want: reflect.TypeOf(0)} if err := checkPanic(want, func() { f.Send(uint64(2)) }); err != nil { t.Error(err) } @@ -37,14 +37,14 @@ func TestFeedPanics(t *testing.T) { var f Feed ch := make(chan int) f.Subscribe(ch) - want := feedTypeError{op: "Send", got: reflect.TypeOf(uint64(0)), want: reflect.TypeOf(int(0))} + want := feedTypeError{op: "Send", got: reflect.TypeOf(uint64(0)), want: reflect.TypeOf(0)} if err := checkPanic(want, func() { f.Send(uint64(2)) }); err != nil { t.Error(err) } } { var f Feed - f.Send(int(2)) + f.Send(2) want := feedTypeError{op: "Subscribe", got: reflect.TypeOf(make(chan uint64)), want: reflect.TypeOf(make(chan<- int))} if err := checkPanic(want, func() { f.Subscribe(make(chan uint64)) }); err != nil { t.Error(err) @@ -58,7 +58,7 @@ func TestFeedPanics(t *testing.T) { } { var f Feed - if err := checkPanic(errBadChannel, func() { f.Subscribe(int(0)) }); err != nil { + if err := checkPanic(errBadChannel, func() { f.Subscribe(0) }); err != nil { t.Error(err) } } diff --git a/fuzzbuzz.yaml 
b/fuzzbuzz.yaml deleted file mode 100644 index 2a4f0c296fe7..000000000000 --- a/fuzzbuzz.yaml +++ /dev/null @@ -1,44 +0,0 @@ -# bmt keystore rlp trie whisperv6 - -base: ubuntu:16.04 -targets: - - name: rlp - language: go - version: "1.13" - corpus: ./fuzzers/rlp/corpus - harness: - function: Fuzz - package: github.com/ethereum/go-ethereum/tests/fuzzers/rlp - checkout: github.com/ethereum/go-ethereum/ - - name: keystore - language: go - version: "1.13" - corpus: ./fuzzers/keystore/corpus - harness: - function: Fuzz - package: github.com/ethereum/go-ethereum/tests/fuzzers/keystore - checkout: github.com/ethereum/go-ethereum/ - - name: trie - language: go - version: "1.13" - corpus: ./fuzzers/trie/corpus - harness: - function: Fuzz - package: github.com/ethereum/go-ethereum/tests/fuzzers/trie - checkout: github.com/ethereum/go-ethereum/ - - name: txfetcher - language: go - version: "1.13" - corpus: ./fuzzers/txfetcher/corpus - harness: - function: Fuzz - package: github.com/ethereum/go-ethereum/tests/fuzzers/txfetcher - checkout: github.com/ethereum/go-ethereum/ - - name: whisperv6 - language: go - version: "1.13" - corpus: ./fuzzers/whisperv6/corpus - harness: - function: Fuzz - package: github.com/ethereum/go-ethereum/tests/fuzzers/whisperv6 - checkout: github.com/ethereum/go-ethereum/ diff --git a/go.mod b/go.mod old mode 100755 new mode 100644 index 77e1dda6887f..0f47809ef6bd --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/go-sourcemap/sourcemap v2.1.2+incompatible // indirect github.com/go-stack/stack v1.8.0 github.com/golang/protobuf v1.4.2 - github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26 + github.com/golang/snappy v0.0.3-0.20201103224600-674baa8c7fc3 github.com/google/gofuzz v1.1.1-0.20200604201612-c04b05f3adfa github.com/gorilla/websocket v1.4.1-0.20190629185528-ae1634f6a989 github.com/graph-gophers/graphql-go v0.0.0-20191115155744-f33e81362277 @@ -36,6 +36,7 @@ require ( github.com/huin/goupnp v1.0.0 github.com/influxdata/influxdb v1.2.3-0.20180221223340-01288bdb0883 github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458 + github.com/jedisct1/go-minisign v0.0.0-20190909160543-45766022959e github.com/julienschmidt/httprouter v1.1.1-0.20170430222011-975b5c4c7c21 github.com/karalabe/usb v0.0.0-20190919080040-51dc0efba356 github.com/kr/pretty v0.1.0 // indirect @@ -60,6 +61,7 @@ require ( github.com/tyler-smith/go-bip39 v1.0.1-0.20181017060643-dbb3b84ba2ef github.com/wsddn/go-ecdh v0.0.0-20161211032359-48726bab9208 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/mobile v0.0.0-20200801112145-973feb4309de // indirect golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 golang.org/x/text v0.3.3 @@ -67,5 +69,5 @@ require ( gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce gopkg.in/olebedev/go-duktape.v3 v3.0.0-20200619000410-60c24ae608a6 gopkg.in/urfave/cli.v1 v1.20.0 - gotest.tools v2.2.0+incompatible // indirect + gotest.tools v2.2.0+incompatible ) diff --git a/go.sum b/go.sum old mode 100755 new mode 100644 index 949f7a21b4f4..d1bd6085856a --- a/go.sum +++ b/go.sum @@ -20,6 +20,7 @@ github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6L github.com/Azure/go-autorest/tracing v0.5.0 h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k= github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 
+github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 h1:fLjPD/aNc3UIOA6tDi6QXUemppXK3P9BI7mr2hd6gx8= github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= @@ -96,6 +97,8 @@ github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26 h1:lMm2hD9Fy0ynom5+85/pbdkiYcBqM1JWmhpAXLmy0fw= github.com/golang/snappy v0.0.2-0.20200707131729-196ae77b8a26/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3-0.20201103224600-674baa8c7fc3 h1:ur2rms48b3Ep1dxh7aUV2FZEQ8jEVO2F6ILKx8ofkAg= +github.com/golang/snappy v0.0.3-0.20201103224600-674baa8c7fc3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -122,6 +125,8 @@ github.com/influxdata/influxdb v1.2.3-0.20180221223340-01288bdb0883 h1:FSeK4fZCo github.com/influxdata/influxdb v1.2.3-0.20180221223340-01288bdb0883/go.mod h1:qZna6X/4elxqT3yI9iZYdZrWWdeFOOprn86kgg4+IzY= github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458 h1:6OvNmYgJyexcZ3pYbTI9jWx5tHo1Dee/tWbLMfPe2TA= github.com/jackpal/go-nat-pmp v1.0.2-0.20160603034137-1fa385a6f458/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= +github.com/jedisct1/go-minisign v0.0.0-20190909160543-45766022959e h1:UvSe12bq+Uj2hWd8aOlwPmoZ+CITRFrdit+sDGfAg8U= +github.com/jedisct1/go-minisign v0.0.0-20190909160543-45766022959e/go.mod h1:G1CVv03EnqU1wYL2dFwXxW2An0az9JTl/ZsqXQeBlkU= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/julienschmidt/httprouter v1.1.1-0.20170430222011-975b5c4c7c21 h1:F/iKcka0K2LgnKy/fgSBf235AETtm1n1TvBzqu40LE0= @@ -212,11 +217,26 @@ github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtX github.com/wsddn/go-ecdh v0.0.0-20161211032359-48726bab9208 h1:1cngl9mPEoITZG8s8cVcUy5CeIBYhEESkOB7m6Gmkrk= github.com/wsddn/go-ecdh v0.0.0-20161211032359-48726bab9208/go.mod h1:IotVbo4F+mw0EzQ08zFqg7pK3FebNXpaMsRy2RT+Ees= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190909091759-094676da4a83/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190731235908-ec7cb31e5a56/go.mod h1:JhuoJpWY28nO4Vef9tZUw9qufEGTyX1+7lmHxV5q5G4= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod 
h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= +golang.org/x/mobile v0.0.0-20200801112145-973feb4309de h1:OVJ6QQUBAesB8CZijKDSsXX7xYVtUhrkY0gwMfbi4p4= +golang.org/x/mobile v0.0.0-20200801112145-973feb4309de/go.mod h1:skQtrUTUwhdJvXM/2KKJzY8pDgNr9I/FOMqDVRPBUS4= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.1.1-0.20191209134235-331c550502dd h1:ePuNC7PZ6O5BzgPn9bZayERXBdfZjUYoXEf5BTfDfh8= +golang.org/x/mod v0.1.1-0.20191209134235-331c550502dd/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181011144130-49bb7cea24b1/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= @@ -224,6 +244,7 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -245,6 +266,12 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69 h1:yBHHx+XZqXJBm6Exke3N7V9gnlsyXxoCPEb1yVenjfk= +golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= diff --git a/graphql/graphql.go b/graphql/graphql.go index 559da8aaaa7a..22cfcf6637d2 100644 --- a/graphql/graphql.go +++ b/graphql/graphql.go @@ -77,7 +77,7 @@ func (a *Account) Code(ctx context.Context) (hexutil.Bytes, error) { if err != nil { return hexutil.Bytes{}, err } - return hexutil.Bytes(state.GetCode(a.address)), nil + return state.GetCode(a.address), nil } func (a *Account) Storage(ctx context.Context, args struct{ Slot common.Hash }) (common.Hash, error) { @@ -116,7 +116,7 @@ func (l *Log) Topics(ctx context.Context) []common.Hash { } func (l *Log) Data(ctx context.Context) hexutil.Bytes { - return hexutil.Bytes(l.log.Data) + return l.log.Data } // Transaction represents an Ethereum transaction. @@ -157,7 +157,7 @@ func (t *Transaction) InputData(ctx context.Context) (hexutil.Bytes, error) { if err != nil || tx == nil { return hexutil.Bytes{}, err } - return hexutil.Bytes(tx.Data()), nil + return tx.Data(), nil } func (t *Transaction) Gas(ctx context.Context) (hexutil.Uint64, error) { @@ -410,7 +410,7 @@ func (b *Block) resolveReceipts(ctx context.Context) ([]*types.Receipt, error) { if err != nil { return nil, err } - b.receipts = []*types.Receipt(receipts) + b.receipts = receipts } return b.receipts, nil } @@ -490,7 +490,7 @@ func (b *Block) Nonce(ctx context.Context) (hexutil.Bytes, error) { if err != nil { return hexutil.Bytes{}, err } - return hexutil.Bytes(header.Nonce[:]), nil + return header.Nonce[:], nil } func (b *Block) MixHash(ctx context.Context) (common.Hash, error) { @@ -564,7 +564,7 @@ func (b *Block) ExtraData(ctx context.Context) (hexutil.Bytes, error) { if err != nil { return hexutil.Bytes{}, err } - return hexutil.Bytes(header.Extra), nil + return header.Extra, nil } func (b *Block) LogsBloom(ctx context.Context) (hexutil.Bytes, error) { @@ -572,7 +572,7 @@ func (b *Block) LogsBloom(ctx context.Context) (hexutil.Bytes, error) { if err != nil { return hexutil.Bytes{}, err } - return hexutil.Bytes(header.Bloom.Bytes()), nil + return header.Bloom.Bytes(), nil } func (b *Block) TotalDifficulty(ctx context.Context) (hexutil.Big, error) { @@ -907,7 +907,7 @@ func (r *Resolver) Block(ctx context.Context, args struct { }) (*Block, error) { var block *Block if args.Number != nil { - number := rpc.BlockNumber(uint64(*args.Number)) + number := rpc.BlockNumber(*args.Number) numberOrHash := rpc.BlockNumberOrHashWithNumber(number) block = &Block{ backend: r.backend, @@ -1040,10 +1040,6 @@ func (r *Resolver) GasPrice(ctx context.Context) (hexutil.Big, error) { return hexutil.Big(*price), err } -func (r *Resolver) ProtocolVersion(ctx context.Context) (int32, error) { - return int32(r.backend.ProtocolVersion()), nil -} - func (r *Resolver) ChainID(ctx context.Context) (hexutil.Big, error) { return hexutil.Big(*r.backend.ChainConfig().ChainID), nil } diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go index 5ba9c955375d..98c8622ee624 100644 --- a/graphql/graphql_test.go +++ b/graphql/graphql_test.go @@ -22,8 +22,13 @@ import ( "net/http" "strings" "testing" + "time" + "github.com/ethereum/go-ethereum/common" + 
"github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/eth" + "github.com/ethereum/go-ethereum/miner" "github.com/ethereum/go-ethereum/node" "github.com/stretchr/testify/assert" ) @@ -61,6 +66,7 @@ func TestGraphQLHTTPOnSamePort_GQLRequest_Successful(t *testing.T) { t.Fatalf("could not read from response body: %v", err) } expected := "{\"data\":{\"block\":{\"number\":\"0x0\"}}}" + assert.Equal(t, 200, resp.StatusCode) assert.Equal(t, expected, string(bodyBytes)) } @@ -90,6 +96,32 @@ func TestGraphQLHTTPOnSamePort_GQLRequest_Unsuccessful(t *testing.T) { assert.Equal(t, "404 page not found\n", string(bodyBytes)) } +// Tests that 400 is returned when an invalid RPC request is made. +func TestGraphQL_BadRequest(t *testing.T) { + stack := createNode(t, true) + defer stack.Close() + // start node + if err := stack.Start(); err != nil { + t.Fatalf("could not start node: %v", err) + } + // create http request + body := strings.NewReader("{\"query\": \"{bleh{number}}\",\"variables\": null}") + gqlReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s/graphql", "127.0.0.1:9393"), body) + if err != nil { + t.Error("could not issue new http request ", err) + } + gqlReq.Header.Set("Content-Type", "application/json") + // read from response + resp := doHTTPRequest(t, gqlReq) + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("could not read from response body: %v", err) + } + expected := "{\"errors\":[{\"message\":\"Cannot query field \\\"bleh\\\" on type \\\"Query\\\".\",\"locations\":[{\"line\":1,\"column\":2}]}]}" + assert.Equal(t, expected, string(bodyBytes)) + assert.Equal(t, 400, resp.StatusCode) +} + func createNode(t *testing.T, gqlEnabled bool) *node.Node { stack, err := node.New(&node.Config{ HTTPHost: "127.0.0.1", @@ -110,8 +142,24 @@ func createNode(t *testing.T, gqlEnabled bool) *node.Node { } func createGQLService(t *testing.T, stack *node.Node, endpoint string) { - // create backend - ethBackend, err := eth.New(stack, ð.DefaultConfig) + // create backend (use a config which is light on mem consumption) + ethConf := ð.Config{ + Genesis: core.DeveloperGenesisBlock(15, common.Address{}), + Miner: miner.Config{ + Etherbase: common.HexToAddress("0xaabb"), + }, + Ethash: ethash.Config{ + PowMode: ethash.ModeTest, + }, + NetworkId: 1337, + TrieCleanCache: 5, + TrieCleanCacheJournal: "triecache", + TrieCleanCacheRejournal: 60 * time.Minute, + TrieDirtyCache: 5, + TrieTimeout: 60 * time.Minute, + SnapshotCache: 5, + } + ethBackend, err := eth.New(stack, ethConf) if err != nil { t.Fatalf("could not create eth backend: %v", err) } diff --git a/graphql/schema.go b/graphql/schema.go index d7b253f22703..1fdc370040b4 100644 --- a/graphql/schema.go +++ b/graphql/schema.go @@ -310,8 +310,6 @@ const schema string = ` # GasPrice returns the node's estimate of a gas price sufficient to # ensure a transaction is mined in a timely fashion. gasPrice: BigInt! - # ProtocolVersion returns the current wire protocol version number. - protocolVersion: Int! # Syncing returns information on the current synchronisation state. syncing: SyncState # ChainID returns the current chain ID for transaction replay protection. 
diff --git a/graphql/service.go b/graphql/service.go index ae962e5b365a..bcb0a4990d64 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -17,12 +17,44 @@ package graphql import ( + "encoding/json" + "net/http" + "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/node" "github.com/graph-gophers/graphql-go" - "github.com/graph-gophers/graphql-go/relay" ) +type handler struct { + Schema *graphql.Schema +} + +func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var params struct { + Query string `json:"query"` + OperationName string `json:"operationName"` + Variables map[string]interface{} `json:"variables"` + } + if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + response := h.Schema.Exec(r.Context(), params.Query, params.OperationName, params.Variables) + responseJSON, err := json.Marshal(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(response.Errors) > 0 { + w.WriteHeader(http.StatusBadRequest) + } + + w.Header().Set("Content-Type", "application/json") + w.Write(responseJSON) + +} + // New constructs a new GraphQL service instance. func New(stack *node.Node, backend ethapi.Backend, cors, vhosts []string) error { if backend == nil { @@ -41,7 +73,7 @@ func newHandler(stack *node.Node, backend ethapi.Backend, cors, vhosts []string) if err != nil { return err } - h := &relay.Handler{Schema: s} + h := handler{Schema: s} handler := node.NewHTTPHandlerStack(h, cors, vhosts) stack.RegisterHandler("GraphQL UI", "/graphql/ui", GraphiQL{}) diff --git a/internal/build/archive.go b/internal/build/archive.go index a00258d99903..8b3ac23d1d89 100644 --- a/internal/build/archive.go +++ b/internal/build/archive.go @@ -184,24 +184,35 @@ func (a *TarballArchive) Close() error { return a.file.Close() } -func ExtractTarballArchive(archive string, dest string) error { - // We're only interested in gzipped archives, wrap the reader now +// ExtractArchive unpacks a .zip or .tar.gz archive to the destination directory. +func ExtractArchive(archive string, dest string) error { ar, err := os.Open(archive) if err != nil { return err } defer ar.Close() + switch { + case strings.HasSuffix(archive, ".tar.gz"): + return extractTarball(ar, dest) + case strings.HasSuffix(archive, ".zip"): + return extractZip(ar, dest) + default: + return fmt.Errorf("unhandled archive type %s", archive) + } +} + +// extractTarball unpacks a .tar.gz file. +func extractTarball(ar io.Reader, dest string) error { gzr, err := gzip.NewReader(ar) if err != nil { return err } defer gzr.Close() - // Iterate over all the files in the tarball tr := tar.NewReader(gzr) for { - // Fetch the next tarball header and abort if needed + // Move to the next file header. header, err := tr.Next() if err != nil { if err == io.EOF { @@ -209,22 +220,69 @@ func ExtractTarballArchive(archive string, dest string) error { } return err } - // Figure out the target and create it - target := filepath.Join(dest, header.Name) - - switch header.Typeflag { - case tar.TypeReg: - if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { - return err - } - file, err := os.OpenFile(target, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode)) + // We only care about regular files, directory modes + // and special file types are not supported. 
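Both extractTarball above and extractZip funnel every entry through a shared extractFile helper (shown just below), whose first step is a "zip slip" guard: a crafted entry name such as ../../etc/passwd must not be allowed to escape the destination directory. A standalone sketch of that check, with hypothetical paths:

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
)

// safeTarget applies the same guard as extractFile below: after joining,
// the cleaned target must still live inside the destination directory.
func safeTarget(dest, name string) (string, error) {
	target := filepath.Join(dest, filepath.FromSlash(name))
	if !strings.HasPrefix(target, filepath.Clean(dest)+string(os.PathSeparator)) {
		return "", fmt.Errorf("path %q escapes archive destination", target)
	}
	return target, nil
}

func main() {
	fmt.Println(safeTarget("/tmp/out", "bin/geth"))         // /tmp/out/bin/geth <nil>
	fmt.Println(safeTarget("/tmp/out", "../../etc/passwd")) // error: path escapes destination
}
```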
+ if header.Typeflag == tar.TypeReg { + armode := header.FileInfo().Mode() + err := extractFile(header.Name, armode, tr, dest) if err != nil { - return err - } - if _, err := io.Copy(file, tr); err != nil { - return err + return fmt.Errorf("extract %s: %v", header.Name, err) } - file.Close() } } } + +// extractZip unpacks the given .zip file. +func extractZip(ar *os.File, dest string) error { + info, err := ar.Stat() + if err != nil { + return err + } + zr, err := zip.NewReader(ar, info.Size()) + if err != nil { + return err + } + + for _, zf := range zr.File { + if !zf.Mode().IsRegular() { + continue + } + + data, err := zf.Open() + if err != nil { + return err + } + err = extractFile(zf.Name, zf.Mode(), data, dest) + data.Close() + if err != nil { + return fmt.Errorf("extract %s: %v", zf.Name, err) + } + } + return nil +} + +// extractFile extracts a single file from an archive. +func extractFile(arpath string, armode os.FileMode, data io.Reader, dest string) error { + // Check that path is inside destination directory. + target := filepath.Join(dest, filepath.FromSlash(arpath)) + if !strings.HasPrefix(target, filepath.Clean(dest)+string(os.PathSeparator)) { + return fmt.Errorf("path %q escapes archive destination", target) + } + + // Ensure the destination directory exists. + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return err + } + + // Copy file data. + file, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, armode) + if err != nil { + return err + } + if _, err := io.Copy(file, data); err != nil { + file.Close() + os.Remove(target) + return err + } + return file.Close() +} diff --git a/internal/build/util.go b/internal/build/util.go index fc559760b26c..91149926f790 100644 --- a/internal/build/util.go +++ b/internal/build/util.go @@ -20,6 +20,8 @@ import ( "bytes" "flag" "fmt" + "go/parser" + "go/token" "io" "io/ioutil" "log" @@ -152,3 +154,28 @@ func UploadSFTP(identityFile, host, dir string, files []string) error { stdin.Close() return sftp.Wait() } + +// FindMainPackages finds all 'main' packages in the given directory and returns their +// package paths. +func FindMainPackages(dir string) []string { + var commands []string + cmds, err := ioutil.ReadDir(dir) + if err != nil { + log.Fatal(err) + } + for _, cmd := range cmds { + pkgdir := filepath.Join(dir, cmd.Name()) + pkgs, err := parser.ParseDir(token.NewFileSet(), pkgdir, nil, parser.PackageClauseOnly) + if err != nil { + log.Fatal(err) + } + for name := range pkgs { + if name == "main" { + path := "./" + filepath.ToSlash(pkgdir) + commands = append(commands, path) + break + } + } + } + return commands +} diff --git a/internal/cmdtest/test_cmd.go b/internal/cmdtest/test_cmd.go index 0edfccec5a16..82ad9c15b611 100644 --- a/internal/cmdtest/test_cmd.go +++ b/internal/cmdtest/test_cmd.go @@ -27,6 +27,7 @@ import ( "regexp" "strings" "sync" + "sync/atomic" "syscall" "testing" "text/template" @@ -55,10 +56,13 @@ type TestCmd struct { Err error } +var id int32 + // Run exec's the current binary using name as argv[0] which will trigger the // reexec init function for that name (e.g. 
"geth-test" in cmd/geth/run_test.go) func (tt *TestCmd) Run(name string, args ...string) { - tt.stderr = &testlogger{t: tt.T} + id := atomic.AddInt32(&id, 1) + tt.stderr = &testlogger{t: tt.T, name: fmt.Sprintf("%d", id)} tt.cmd = &exec.Cmd{ Path: reexec.Self(), Args: append([]string{name}, args...), @@ -238,16 +242,17 @@ func (tt *TestCmd) withKillTimeout(fn func()) { // testlogger logs all written lines via t.Log and also // collects them for later inspection. type testlogger struct { - t *testing.T - mu sync.Mutex - buf bytes.Buffer + t *testing.T + mu sync.Mutex + buf bytes.Buffer + name string } func (tl *testlogger) Write(b []byte) (n int, err error) { lines := bytes.Split(b, []byte("\n")) for _, line := range lines { if len(line) > 0 { - tl.t.Logf("(stderr) %s", line) + tl.t.Logf("(stderr:%v) %s", tl.name, line) } } tl.mu.Lock() diff --git a/internal/debug/api.go b/internal/debug/api.go index 86a4218f6a57..efd8626776a3 100644 --- a/internal/debug/api.go +++ b/internal/debug/api.go @@ -196,7 +196,7 @@ func (*HandlerT) Stacks() string { return buf.String() } -// FreeOSMemory returns unused memory to the OS. +// FreeOSMemory forces a garbage collection. func (*HandlerT) FreeOSMemory() { debug.FreeOSMemory() } diff --git a/internal/debug/flags.go b/internal/debug/flags.go index 3b077b6e0884..e3a01378ca39 100644 --- a/internal/debug/flags.go +++ b/internal/debug/flags.go @@ -28,7 +28,7 @@ import ( "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics/exp" "github.com/fjl/memsize/memsizeui" - colorable "github.com/mattn/go-colorable" + "github.com/mattn/go-colorable" "github.com/mattn/go-isatty" "gopkg.in/urfave/cli.v1" ) diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index c035b7a9edf1..9ff1781e4a3d 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -64,11 +64,6 @@ func (s *PublicEthereumAPI) GasPrice(ctx context.Context) (*hexutil.Big, error) return (*hexutil.Big)(price), err } -// ProtocolVersion returns the current Ethereum protocol version this node supports -func (s *PublicEthereumAPI) ProtocolVersion() hexutil.Uint { - return hexutil.Uint(s.b.ProtocolVersion()) -} - // Syncing returns false in case the node is currently not syncing with the network. It can be up to date or has not // yet received the latest block headers from its pears. 
In case it is synchronizing: // - startingBlock: block number this node started to synchronise from @@ -599,7 +594,7 @@ func (s *PublicBlockChainAPI) GetProof(ctx context.Context, address common.Addre if storageError != nil { return nil, storageError } - storageProof[i] = StorageResult{key, (*hexutil.Big)(state.GetState(address, common.HexToHash(key)).Big()), common.ToHexArray(proof)} + storageProof[i] = StorageResult{key, (*hexutil.Big)(state.GetState(address, common.HexToHash(key)).Big()), toHexSlice(proof)} } else { storageProof[i] = StorageResult{key, &hexutil.Big{}, []string{}} } @@ -613,7 +608,7 @@ func (s *PublicBlockChainAPI) GetProof(ctx context.Context, address common.Addre return &AccountResult{ Address: address, - AccountProof: common.ToHexArray(accountProof), + AccountProof: toHexSlice(accountProof), Balance: (*hexutil.Big)(state.GetBalance(address)), CodeHash: codeHash, Nonce: hexutil.Uint64(state.GetNonce(address)), @@ -793,7 +788,7 @@ func (args *CallArgs) ToMessage(globalGasCap uint64) types.Message { var data []byte if args.Data != nil { - data = []byte(*args.Data) + data = *args.Data } msg := types.NewMessage(addr, args.To, 0, value, gas, gasPrice, data, false) @@ -963,6 +958,9 @@ func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash if err != nil { return 0, err } + if block == nil { + return 0, errors.New("block not found") + } hi = block.GasLimit() } // Recap the highest gas limit with account's available balance. @@ -1051,9 +1049,12 @@ func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash // EstimateGas returns an estimate of the amount of gas needed to execute the // given transaction against the current pending block. -func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs) (hexutil.Uint64, error) { - blockNrOrHash := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) - return DoEstimateGas(ctx, s.b, args, blockNrOrHash, s.b.RPCGasCap()) +func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumberOrHash) (hexutil.Uint64, error) { + bNrOrHash := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + if blockNrOrHash != nil { + bNrOrHash = *blockNrOrHash + } + return DoEstimateGas(ctx, s.b, args, bNrOrHash, s.b.RPCGasCap()) } // ExecutionResult groups all structured logs emitted by the EVM @@ -1899,13 +1900,12 @@ func (api *PrivateDebugAPI) SetHead(number hexutil.Uint64) { // PublicNetAPI offers network related RPC methods type PublicNetAPI struct { - net *p2p.Server - networkVersion uint64 + net *p2p.Server } // NewPublicNetAPI creates a new net API instance. -func NewPublicNetAPI(net *p2p.Server, networkVersion uint64) *PublicNetAPI { - return &PublicNetAPI{net, networkVersion} +func NewPublicNetAPI(net *p2p.Server) *PublicNetAPI { + return &PublicNetAPI{net} } // Listening returns an indication if the node is listening for network connections. @@ -1918,11 +1918,6 @@ func (s *PublicNetAPI) PeerCount() hexutil.Uint { return hexutil.Uint(s.net.PeerCount()) } -// Version returns the current ethereum protocol version. -func (s *PublicNetAPI) Version() string { - return fmt.Sprintf("%d", s.networkVersion) -} - // checkTxFee is an internal function used to check whether the fee of // the given transaction is _reasonable_(under the cap). 
func checkTxFee(gasPrice *big.Int, gas uint64, cap float64) error { @@ -1937,3 +1932,12 @@ func checkTxFee(gasPrice *big.Int, gas uint64, cap float64) error { } return nil } + +// toHexSlice creates a slice of hex-strings based on []byte. +func toHexSlice(b [][]byte) []string { + r := make([]string, len(b)) + for i := range b { + r[i] = hexutil.Encode(b[i]) + } + return r +} diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 10e716bf200d..f0a4c0493c86 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -41,7 +41,6 @@ import ( type Backend interface { // General Ethereum API Downloader() *downloader.Downloader - ProtocolVersion() int SuggestPrice(ctx context.Context) (*big.Int, error) ChainDb() ethdb.Database AccountManager() *accounts.Manager diff --git a/internal/flags/helpers.go b/internal/flags/helpers.go index 900fec454c32..f61ed4b68987 100644 --- a/internal/flags/helpers.go +++ b/internal/flags/helpers.go @@ -21,7 +21,7 @@ import ( "path/filepath" "github.com/ethereum/go-ethereum/params" - cli "gopkg.in/urfave/cli.v1" + "gopkg.in/urfave/cli.v1" ) var ( diff --git a/internal/utesting/utesting.go b/internal/utesting/utesting.go index 23c748cae9b4..ef05a90e4c57 100644 --- a/internal/utesting/utesting.go +++ b/internal/utesting/utesting.go @@ -25,6 +25,7 @@ import ( "bytes" "fmt" "io" + "io/ioutil" "regexp" "runtime" "sync" @@ -63,26 +64,165 @@ func MatchTests(tests []Test, expr string) []Test { // RunTests executes all given tests in order and returns their results. // If the report writer is non-nil, a test report is written to it in real time. func RunTests(tests []Test, report io.Writer) []Result { - results := make([]Result, len(tests)) + if report == nil { + report = ioutil.Discard + } + results := run(tests, newConsoleOutput(report)) + fails := CountFailures(results) + fmt.Fprintf(report, "%v/%v tests passed.\n", len(tests)-fails, len(tests)) + return results +} + +// RunTAP runs the given tests and writes Test Anything Protocol output +// to the report writer. +func RunTAP(tests []Test, report io.Writer) []Result { + return run(tests, newTAP(report, len(tests))) +} + +func run(tests []Test, output testOutput) []Result { + var results = make([]Result, len(tests)) for i, test := range tests { + buffer := new(bytes.Buffer) + logOutput := io.MultiWriter(buffer, output) + + output.testStart(test.Name) start := time.Now() results[i].Name = test.Name - results[i].Failed, results[i].Output = Run(test) + results[i].Failed = runTest(test, logOutput) results[i].Duration = time.Since(start) - if report != nil { - printResult(results[i], report) - } + results[i].Output = buffer.String() + output.testResult(results[i]) } return results } -func printResult(r Result, w io.Writer) { +// testOutput is implemented by output formats. +type testOutput interface { + testStart(name string) + Write([]byte) (int, error) + testResult(Result) +} + +// consoleOutput prints test results similarly to go test. +type consoleOutput struct { + out io.Writer + indented *indentWriter + curTest string + wroteHeader bool +} + +func newConsoleOutput(w io.Writer) *consoleOutput { + return &consoleOutput{ + out: w, + indented: newIndentWriter(" ", w), + } +} + +// testStart signals the start of a new test. +func (c *consoleOutput) testStart(name string) { + c.curTest = name + c.wroteHeader = false +} + +// Write handles test log output. +func (c *consoleOutput) Write(b []byte) (int, error) { + if !c.wroteHeader { + // This is the first output line from the test. 
Print a "-- RUN" header.
+		fmt.Fprintln(c.out, "-- RUN", c.curTest)
+		c.wroteHeader = true
+	}
+	return c.indented.Write(b)
+}
+
+// testResult prints the final test result line.
+func (c *consoleOutput) testResult(r Result) {
+	c.indented.flush()
 	pd := r.Duration.Truncate(100 * time.Microsecond)
 	if r.Failed {
-		fmt.Fprintf(w, "-- FAIL %s (%v)\n", r.Name, pd)
-		fmt.Fprintln(w, r.Output)
+		fmt.Fprintf(c.out, "-- FAIL %s (%v)\n", r.Name, pd)
 	} else {
-		fmt.Fprintf(w, "-- OK %s (%v)\n", r.Name, pd)
+		fmt.Fprintf(c.out, "-- OK %s (%v)\n", r.Name, pd)
+	}
+}
+
+// tapOutput produces Test Anything Protocol v13 output.
+type tapOutput struct {
+	out      io.Writer
+	indented *indentWriter
+	counter  int
+}
+
+func newTAP(out io.Writer, numTests int) *tapOutput {
+	fmt.Fprintf(out, "1..%d\n", numTests)
+	return &tapOutput{
+		out:      out,
+		indented: newIndentWriter("# ", out),
+	}
+}
+
+func (t *tapOutput) testStart(name string) {
+	t.counter++
+}
+
+// Write does nothing for TAP because there is no real-time output of test logs.
+func (t *tapOutput) Write(b []byte) (int, error) {
+	return len(b), nil
+}
+
+func (t *tapOutput) testResult(r Result) {
+	status := "ok"
+	if r.Failed {
+		status = "not ok"
+	}
+	fmt.Fprintln(t.out, status, t.counter, r.Name)
+	t.indented.Write([]byte(r.Output))
+	t.indented.flush()
+}
+
+// indentWriter indents all written text.
+type indentWriter struct {
+	out    io.Writer
+	indent string
+	inLine bool
+}
+
+func newIndentWriter(indent string, out io.Writer) *indentWriter {
+	return &indentWriter{out: out, indent: indent}
+}
+
+func (w *indentWriter) Write(b []byte) (n int, err error) {
+	for len(b) > 0 {
+		if !w.inLine {
+			if _, err = io.WriteString(w.out, w.indent); err != nil {
+				return n, err
+			}
+			w.inLine = true
+		}
+
+		end := bytes.IndexByte(b, '\n')
+		if end == -1 {
+			nn, err := w.out.Write(b)
+			n += nn
+			return n, err
+		}
+
+		line := b[:end+1]
+		nn, err := w.out.Write(line)
+		n += nn
+		if err != nil {
+			return n, err
+		}
+		b = b[end+1:]
+		w.inLine = false
+	}
+	return n, err
+}
+
+// flush ensures the current line is terminated.
+func (w *indentWriter) flush() {
+	if w.inLine {
+		fmt.Fprintln(w.out)
+		w.inLine = false
+	}
 }
 
@@ -99,7 +239,13 @@ func CountFailures(rr []Result) int {
 
 // Run executes a single test.
 func Run(test Test) (bool, string) {
-	t := new(T)
+	output := new(bytes.Buffer)
+	failed := runTest(test, output)
+	return failed, output.String()
+}
+
+func runTest(test Test, output io.Writer) bool {
+	t := &T{output: output}
 	done := make(chan struct{})
 	go func() {
 		defer close(done)
@@ -114,7 +260,7 @@ func Run(test Test) (bool, string) {
 		test.Fn(t)
 	}()
 	<-done
-	return t.failed, t.output.String()
+	return t.failed
 }
 
 // T is the value given to the test function. The test can signal failures
@@ -122,9 +268,12 @@ func Run(test Test) (bool, string) {
 type T struct {
 	mu     sync.Mutex
 	failed bool
-	output bytes.Buffer
+	output io.Writer
 }
 
+// Helper exists for compatibility with testing.T.
+func (t *T) Helper() {}
+
 // FailNow marks the test as having failed and stops its execution by calling
 // runtime.Goexit (which then runs all deferred calls in the current goroutine).
 func (t *T) FailNow() {
@@ -151,7 +300,7 @@ func (t *T) Failed() bool {
 func (t *T) Log(vs ...interface{}) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
-	fmt.Fprintln(&t.output, vs...)
+	fmt.Fprintln(t.output, vs...)
} // Logf formats its arguments according to the format, analogous to Printf, and records @@ -162,7 +311,7 @@ func (t *T) Logf(format string, vs ...interface{}) { if len(format) == 0 || format[len(format)-1] != '\n' { format += "\n" } - fmt.Fprintf(&t.output, format, vs...) + fmt.Fprintf(t.output, format, vs...) } // Error is equivalent to Log followed by Fail. diff --git a/internal/utesting/utesting_test.go b/internal/utesting/utesting_test.go index 1403a5c8f735..31c7911c52f0 100644 --- a/internal/utesting/utesting_test.go +++ b/internal/utesting/utesting_test.go @@ -17,6 +17,8 @@ package utesting import ( + "bytes" + "regexp" "strings" "testing" ) @@ -53,3 +55,85 @@ func TestTest(t *testing.T) { t.Fatalf("wrong result for panicking test: %#v", results[2]) } } + +var outputTests = []Test{ + { + Name: "TestWithLogs", + Fn: func(t *T) { + t.Log("output line 1") + t.Log("output line 2\noutput line 3") + }, + }, + { + Name: "TestNoLogs", + Fn: func(t *T) {}, + }, + { + Name: "FailWithLogs", + Fn: func(t *T) { + t.Log("output line 1") + t.Error("failed 1") + }, + }, + { + Name: "FailMessage", + Fn: func(t *T) { + t.Error("failed 2") + }, + }, + { + Name: "FailNoOutput", + Fn: func(t *T) { + t.Fail() + }, + }, +} + +func TestOutput(t *testing.T) { + var buf bytes.Buffer + RunTests(outputTests, &buf) + + want := regexp.MustCompile(` +^-- RUN TestWithLogs + output line 1 + output line 2 + output line 3 +-- OK TestWithLogs \([^)]+\) +-- OK TestNoLogs \([^)]+\) +-- RUN FailWithLogs + output line 1 + failed 1 +-- FAIL FailWithLogs \([^)]+\) +-- RUN FailMessage + failed 2 +-- FAIL FailMessage \([^)]+\) +-- FAIL FailNoOutput \([^)]+\) +2/5 tests passed. +$`[1:]) + if !want.MatchString(buf.String()) { + t.Fatalf("output does not match: %q", buf.String()) + } +} + +func TestOutputTAP(t *testing.T) { + var buf bytes.Buffer + RunTAP(outputTests, &buf) + + want := ` +1..5 +ok 1 TestWithLogs +# output line 1 +# output line 2 +# output line 3 +ok 2 TestNoLogs +not ok 3 FailWithLogs +# output line 1 +# failed 1 +not ok 4 FailMessage +# failed 2 +not ok 5 FailNoOutput +` + if buf.String() != want[1:] { + t.Fatalf("output does not match: %q", buf.String()) + } +} diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index 300c2b054f38..77954bbbf001 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -237,7 +237,8 @@ web3._extend({ new web3._extend.Method({ name: 'printBlock', call: 'debug_printBlock', - params: 1 + params: 1, + outputFormatter: console.log }), new web3._extend.Method({ name: 'getBlockRlp', @@ -248,7 +249,7 @@ web3._extend({ name: 'testSignCliqueBlock', call: 'debug_testSignCliqueBlock', params: 2, - inputFormatters: [web3._extend.formatters.inputAddressFormatter, null], + inputFormatter: [web3._extend.formatters.inputAddressFormatter, null], }), new web3._extend.Method({ name: 'setHead', @@ -263,7 +264,8 @@ web3._extend({ new web3._extend.Method({ name: 'dumpBlock', call: 'debug_dumpBlock', - params: 1 + params: 1, + inputFormatter: [web3._extend.formatters.inputBlockNumberFormatter] }), new web3._extend.Method({ name: 'chaindbProperty', @@ -415,7 +417,7 @@ web3._extend({ name: 'traceBlockByNumber', call: 'debug_traceBlockByNumber', params: 2, - inputFormatter: [null, null] + inputFormatter: [web3._extend.formatters.inputBlockNumberFormatter, null] }), new web3._extend.Method({ name: 'traceBlockByHash', @@ -500,6 +502,13 @@ web3._extend({ params: 1, inputFormatter: [web3._extend.formatters.inputTransactionFormatter] }), + new web3._extend.Method({ + 
name: 'estimateGas', + call: 'eth_estimateGas', + params: 2, + inputFormatter: [web3._extend.formatters.inputCallFormatter, web3._extend.formatters.inputBlockNumberFormatter], + outputFormatter: web3._extend.utils.toDecimal + }), new web3._extend.Method({ name: 'submitTransaction', call: 'eth_submitTransaction', @@ -515,7 +524,8 @@ web3._extend({ new web3._extend.Method({ name: 'getHeaderByNumber', call: 'eth_getHeaderByNumber', - params: 1 + params: 1, + inputFormatter: [web3._extend.formatters.inputBlockNumberFormatter] }), new web3._extend.Method({ name: 'getHeaderByHash', @@ -525,12 +535,14 @@ web3._extend({ new web3._extend.Method({ name: 'getBlockByNumber', call: 'eth_getBlockByNumber', - params: 2 + params: 2, + inputFormatter: [web3._extend.formatters.inputBlockNumberFormatter, function (val) { return !!val; }] }), new web3._extend.Method({ name: 'getBlockByHash', call: 'eth_getBlockByHash', - params: 2 + params: 2, + inputFormatter: [null, function (val) { return !!val; }] }), new web3._extend.Method({ name: 'getRawTransaction', diff --git a/les/api_backend.go b/les/api_backend.go index 75bea56da67a..9fbc21f4594e 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -171,8 +171,9 @@ func (b *LesApiBackend) GetTd(ctx context.Context, hash common.Hash) *big.Int { } func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header) (*vm.EVM, func() error, error) { - context := core.NewEVMContext(msg, header, b.eth.blockchain, nil) - return vm.NewEVM(context, state, b.eth.chainConfig, vm.Config{}), state.Error, nil + txContext := core.NewEVMTxContext(msg) + context := core.NewEVMBlockContext(header, b.eth.blockchain, nil) + return vm.NewEVM(context, txContext, state, b.eth.chainConfig, vm.Config{}), state.Error, nil } func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { diff --git a/les/benchmark.go b/les/benchmark.go index a146de2fed0b..312d4533dfd9 100644 --- a/les/benchmark.go +++ b/les/benchmark.go @@ -75,9 +75,8 @@ func (b *benchmarkBlockHeaders) init(h *serverHandler, count int) error { func (b *benchmarkBlockHeaders) request(peer *serverPeer, index int) error { if b.byHash { return peer.requestHeadersByHash(0, b.hashes[index], b.amount, b.skip, b.reverse) - } else { - return peer.requestHeadersByNumber(0, uint64(b.offset+rand.Int63n(b.randMax)), b.amount, b.skip, b.reverse) } + return peer.requestHeadersByNumber(0, uint64(b.offset+rand.Int63n(b.randMax)), b.amount, b.skip, b.reverse) } // benchmarkBodiesOrReceipts implements requestBenchmark @@ -98,9 +97,8 @@ func (b *benchmarkBodiesOrReceipts) init(h *serverHandler, count int) error { func (b *benchmarkBodiesOrReceipts) request(peer *serverPeer, index int) error { if b.receipts { return peer.requestReceipts(0, []common.Hash{b.hashes[index]}) - } else { - return peer.requestBodies(0, []common.Hash{b.hashes[index]}) } + return peer.requestBodies(0, []common.Hash{b.hashes[index]}) } // benchmarkProofsOrCode implements requestBenchmark @@ -119,9 +117,8 @@ func (b *benchmarkProofsOrCode) request(peer *serverPeer, index int) error { rand.Read(key) if b.code { return peer.requestCode(0, []CodeReq{{BHash: b.headHash, AccKey: key}}) - } else { - return peer.requestProofs(0, []ProofReq{{BHash: b.headHash, Key: key}}) } + return peer.requestProofs(0, []ProofReq{{BHash: b.headHash, Key: key}}) } // benchmarkHelperTrie implements requestBenchmark diff --git a/les/client.go b/les/client.go index a2f7c56dfd57..198255dc54f1 100644 --- a/les/client.go +++ 
b/les/client.go @@ -112,7 +112,7 @@ func New(stack *node.Node, config *eth.Config) (*LightEthereum, error) { } peers.subscribe((*vtSubscription)(leth.valueTracker)) - dnsdisc, err := leth.setupDiscovery(&stack.Config().P2P) + dnsdisc, err := leth.setupDiscovery() if err != nil { return nil, err } @@ -171,13 +171,26 @@ func New(stack *node.Node, config *eth.Config) (*LightEthereum, error) { leth.blockchain.DisableCheckFreq() } - leth.netRPCService = ethapi.NewPublicNetAPI(leth.p2pServer, leth.config.NetworkId) + leth.netRPCService = ethapi.NewPublicNetAPI(leth.p2pServer) // Register the backend on the node stack.RegisterAPIs(leth.APIs()) stack.RegisterProtocols(leth.Protocols()) stack.RegisterLifecycle(leth) + // Check for unclean shutdown + if uncleanShutdowns, discards, err := rawdb.PushUncleanShutdownMarker(chainDb); err != nil { + log.Error("Could not update unclean-shutdown-marker list", "error", err) + } else { + if discards > 0 { + log.Warn("Old unclean shutdowns found", "count", discards) + } + for _, tstamp := range uncleanShutdowns { + t := time.Unix(int64(tstamp), 0) + log.Warn("Unclean shutdown detected", "booted", t, + "age", common.PrettyAge(t)) + } + } return leth, nil } @@ -313,6 +326,7 @@ func (s *LightEthereum) Stop() error { s.engine.Close() s.pruner.close() s.eventMux.Stop() + rawdb.PopUncleanShutdownMarker(s.chainDb) s.chainDb.Close() s.wg.Wait() log.Info("Light ethereum stopped") diff --git a/les/client_handler.go b/les/client_handler.go index cfeec7a03cdb..6de57669615d 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -17,6 +17,7 @@ package les import ( + "context" "math/big" "sync" "sync/atomic" @@ -24,6 +25,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/light" @@ -36,6 +38,7 @@ import ( // responses. 
type clientHandler struct { ulc *ulc + forkFilter forkid.Filter checkpoint *params.TrustedCheckpoint fetcher *lightFetcher downloader *downloader.Downloader @@ -48,6 +51,7 @@ type clientHandler struct { func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.TrustedCheckpoint, backend *LightEthereum) *clientHandler { handler := &clientHandler{ + forkFilter: forkid.NewFilter(backend.blockchain), checkpoint: checkpoint, backend: backend, closeCh: make(chan struct{}), @@ -102,13 +106,8 @@ func (h *clientHandler) handle(p *serverPeer) error { p.Log().Debug("Light Ethereum peer connected", "name", p.Name()) // Execute the LES handshake - var ( - head = h.backend.blockchain.CurrentHeader() - hash = head.Hash() - number = head.Number.Uint64() - td = h.backend.blockchain.GetTd(hash, number) - ) - if err := p.Handshake(td, hash, number, h.backend.blockchain.Genesis().Hash(), nil); err != nil { + forkid := forkid.NewID(h.backend.blockchain.Config(), h.backend.genesis, h.backend.blockchain.CurrentHeader().Number.Uint64()) + if err := p.Handshake(h.backend.blockchain.Genesis().Hash(), forkid, h.forkFilter); err != nil { p.Log().Debug("Light Ethereum handshake failed", "err", err) return err } @@ -159,8 +158,8 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { var deliverMsg *Msg // Handle the message depending on its contents - switch msg.Code { - case AnnounceMsg: + switch { + case msg.Code == AnnounceMsg: p.Log().Trace("Received announce message") var req announceData if err := msg.Decode(&req); err != nil { @@ -193,7 +192,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { p.updateHead(req.Hash, req.Number, req.Td) h.fetcher.announce(p, &req) } - case BlockHeadersMsg: + case msg.Code == BlockHeadersMsg: p.Log().Trace("Received block header response message") var resp struct { ReqID, BV uint64 @@ -206,17 +205,26 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { p.fcServer.ReceivedReply(resp.ReqID, resp.BV) p.answeredRequest(resp.ReqID) - // Filter out any explicitly requested headers, deliver the rest to the downloader - filter := len(headers) == 1 - if filter { - headers = h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers) - } - if len(headers) != 0 || !filter { - if err := h.downloader.DeliverHeaders(p.id, headers); err != nil { - log.Debug("Failed to deliver headers", "err", err) + // Filter out the explicitly requested header by the retriever + if h.backend.retriever.requested(resp.ReqID) { + deliverMsg = &Msg{ + MsgType: MsgBlockHeaders, + ReqID: resp.ReqID, + Obj: resp.Headers, + } + } else { + // Filter out any explicitly requested headers, deliver the rest to the downloader + filter := len(headers) == 1 + if filter { + headers = h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers) + } + if len(headers) != 0 || !filter { + if err := h.downloader.DeliverHeaders(p.id, headers); err != nil { + log.Debug("Failed to deliver headers", "err", err) + } } } - case BlockBodiesMsg: + case msg.Code == BlockBodiesMsg: p.Log().Trace("Received block bodies response") var resp struct { ReqID, BV uint64 @@ -232,7 +240,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case CodeMsg: + case msg.Code == CodeMsg: p.Log().Trace("Received code response") var resp struct { ReqID, BV uint64 @@ -248,7 +256,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case ReceiptsMsg: + case msg.Code == ReceiptsMsg: p.Log().Trace("Received receipts response") var resp struct 
{ ReqID, BV uint64 @@ -264,7 +272,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Receipts, } - case ProofsV2Msg: + case msg.Code == ProofsV2Msg: p.Log().Trace("Received les/2 proofs response") var resp struct { ReqID, BV uint64 @@ -280,7 +288,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case HelperTrieProofsMsg: + case msg.Code == HelperTrieProofsMsg: p.Log().Trace("Received helper trie proof response") var resp struct { ReqID, BV uint64 @@ -296,7 +304,7 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Data, } - case TxStatusMsg: + case msg.Code == TxStatusMsg: p.Log().Trace("Received tx status response") var resp struct { ReqID, BV uint64 @@ -312,11 +320,11 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { ReqID: resp.ReqID, Obj: resp.Status, } - case StopMsg: + case msg.Code == StopMsg && p.version >= lpv3: p.freeze() h.backend.retriever.frozen(p) p.Log().Debug("Service stopped") - case ResumeMsg: + case msg.Code == ResumeMsg && p.version >= lpv3: var bv uint64 if err := msg.Decode(&bv); err != nil { return errResp(ErrDecode, "msg %v: %v", msg, err) @@ -400,6 +408,42 @@ func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip return nil } +// RetrieveSingleHeaderByNumber requests a single header by the specified block +// number. This function will wait the response until it's timeout or delivered. +func (pc *peerConnection) RetrieveSingleHeaderByNumber(context context.Context, number uint64) (*types.Header, error) { + reqID := genReqID() + rq := &distReq{ + getCost: func(dp distPeer) uint64 { + peer := dp.(*serverPeer) + return peer.getRequestCost(GetBlockHeadersMsg, 1) + }, + canSend: func(dp distPeer) bool { + return dp.(*serverPeer) == pc.peer + }, + request: func(dp distPeer) func() { + peer := dp.(*serverPeer) + cost := peer.getRequestCost(GetBlockHeadersMsg, 1) + peer.fcServer.QueuedRequest(reqID, cost) + return func() { peer.requestHeadersByNumber(reqID, number, 1, 0, false) } + }, + } + var header *types.Header + if err := pc.handler.backend.retriever.retrieve(context, reqID, rq, func(peer distPeer, msg *Msg) error { + if msg.MsgType != MsgBlockHeaders { + return errInvalidMessageType + } + headers := msg.Obj.([]*types.Header) + if len(headers) != 1 { + return errInvalidEntryCount + } + header = headers[0] + return nil + }, nil); err != nil { + return nil, err + } + return header, nil +} + // downloaderPeerNotify implements peerSetNotify type downloaderPeerNotify clientHandler diff --git a/les/clientpool.go b/les/clientpool.go index 4f6e3fafe01d..da0db6e62295 100644 --- a/les/clientpool.go +++ b/les/clientpool.go @@ -18,7 +18,6 @@ package les import ( "fmt" - "reflect" "sync" "time" @@ -46,19 +45,6 @@ const ( inactiveTimeout = time.Second * 10 ) -var ( - clientPoolSetup = &nodestate.Setup{} - clientField = clientPoolSetup.NewField("clientInfo", reflect.TypeOf(&clientInfo{})) - connAddressField = clientPoolSetup.NewField("connAddr", reflect.TypeOf("")) - balanceTrackerSetup = lps.NewBalanceTrackerSetup(clientPoolSetup) - priorityPoolSetup = lps.NewPriorityPoolSetup(clientPoolSetup) -) - -func init() { - balanceTrackerSetup.Connect(connAddressField, priorityPoolSetup.CapacityField) - priorityPoolSetup.Connect(balanceTrackerSetup.BalanceField, balanceTrackerSetup.UpdateFlag) // NodeBalance implements nodePriority -} - // clientPool implements a client database that assigns a priority to each client // based on 
a positive and negative balance. Positive balance is externally assigned // to prioritized clients and is decreased with connection time and processed @@ -119,8 +105,7 @@ type clientInfo struct { } // newClientPool creates a new client pool -func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Duration, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { - ns := nodestate.NewNodeStateMachine(nil, nil, clock, clientPoolSetup) +func newClientPool(ns *nodestate.NodeStateMachine, lespayDb ethdb.Database, minCap uint64, connectedBias time.Duration, clock mclock.Clock, removePeer func(enode.ID)) *clientPool { pool := &clientPool{ ns: ns, BalanceTrackerSetup: balanceTrackerSetup, @@ -147,7 +132,7 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du }) ns.SubscribeState(pool.ActiveFlag.Or(pool.PriorityFlag), func(node *enode.Node, oldState, newState nodestate.Flags) { - c, _ := ns.GetField(node, clientField).(*clientInfo) + c, _ := ns.GetField(node, clientInfoField).(*clientInfo) if c == nil { return } @@ -172,7 +157,7 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du if oldState.Equals(pool.ActiveFlag) && newState.Equals(pool.InactiveFlag) { clientDeactivatedMeter.Mark(1) log.Debug("Client deactivated", "id", node.ID()) - c, _ := ns.GetField(node, clientField).(*clientInfo) + c, _ := ns.GetField(node, clientInfoField).(*clientInfo) if c == nil || !c.peer.allowInactive() { pool.removePeer(node.ID()) } @@ -190,13 +175,11 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du newCap, _ := newValue.(uint64) totalConnected += newCap - oldCap totalConnectedGauge.Update(int64(totalConnected)) - c, _ := ns.GetField(node, clientField).(*clientInfo) + c, _ := ns.GetField(node, clientInfoField).(*clientInfo) if c != nil { c.peer.updateCapacity(newCap) } }) - - ns.Start() return pool } @@ -210,7 +193,6 @@ func (f *clientPool) stop() { f.disconnectNode(node) }) f.bt.Stop() - f.ns.Stop() } // connect should be called after a successful handshake. If the connection was @@ -225,7 +207,7 @@ func (f *clientPool) connect(peer clientPoolPeer) (uint64, error) { } // Dedup connected peers. 
node, freeID := peer.Node(), peer.freeClientId() - if f.ns.GetField(node, clientField) != nil { + if f.ns.GetField(node, clientInfoField) != nil { log.Debug("Client already connected", "address", freeID, "id", node.ID().String()) return 0, fmt.Errorf("Client already connected address=%s id=%s", freeID, node.ID().String()) } @@ -237,7 +219,7 @@ func (f *clientPool) connect(peer clientPoolPeer) (uint64, error) { connected: true, connectedAt: now, } - f.ns.SetField(node, clientField, c) + f.ns.SetField(node, clientInfoField, c) f.ns.SetField(node, connAddressField, freeID) if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance == nil { f.disconnect(peer) @@ -280,7 +262,7 @@ func (f *clientPool) disconnect(p clientPoolPeer) { // disconnectNode removes node fields and flags related to connected status func (f *clientPool) disconnectNode(node *enode.Node) { f.ns.SetField(node, connAddressField, nil) - f.ns.SetField(node, clientField, nil) + f.ns.SetField(node, clientInfoField, nil) } // setDefaultFactors sets the default price factors applied to subsequently connected clients @@ -299,7 +281,8 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) { defer f.lock.Unlock() // total priority active cap will be supported when the token issuer module is added - return f.capLimit, f.pp.ActiveCapacity(), 0 + _, activeCap := f.pp.Active() + return f.capLimit, activeCap, 0 } // setLimits sets the maximum number and total capacity of connected clients, @@ -314,13 +297,13 @@ func (f *clientPool) setLimits(totalConn int, totalCap uint64) { // setCapacity sets the assigned capacity of a connected client func (f *clientPool) setCapacity(node *enode.Node, freeID string, capacity uint64, bias time.Duration, setCap bool) (uint64, error) { - c, _ := f.ns.GetField(node, clientField).(*clientInfo) + c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo) if c == nil { if setCap { return 0, fmt.Errorf("client %064x is not connected", node.ID()) } c = &clientInfo{node: node} - f.ns.SetField(node, clientField, c) + f.ns.SetField(node, clientInfoField, c) f.ns.SetField(node, connAddressField, freeID) if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance == nil { log.Error("BalanceField is missing", "node", node.ID()) @@ -328,7 +311,7 @@ func (f *clientPool) setCapacity(node *enode.Node, freeID string, capacity uint6 } defer func() { f.ns.SetField(node, connAddressField, nil) - f.ns.SetField(node, clientField, nil) + f.ns.SetField(node, clientInfoField, nil) }() } var ( @@ -370,7 +353,7 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client *clientInfo)) { if len(ids) == 0 { f.ns.ForEach(nodestate.Flags{}, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { - c, _ := f.ns.GetField(node, clientField).(*clientInfo) + c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo) if c != nil { cb(c) } @@ -381,12 +364,12 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client *clientInfo)) { if node == nil { node = enode.SignNull(&enr.Record{}, id) } - c, _ := f.ns.GetField(node, clientField).(*clientInfo) + c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo) if c != nil { cb(c) } else { c = &clientInfo{node: node} - f.ns.SetField(node, clientField, c) + f.ns.SetField(node, clientInfoField, c) f.ns.SetField(node, connAddressField, "") if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance != nil { cb(c) @@ -394,7 +377,7 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client 
*clientInfo)) { log.Error("BalanceField is missing") } f.ns.SetField(node, connAddressField, nil) - f.ns.SetField(node, clientField, nil) + f.ns.SetField(node, clientInfoField, nil) } } } diff --git a/les/clientpool_test.go b/les/clientpool_test.go index cfd1486b437e..b1c38d374c87 100644 --- a/les/clientpool_test.go +++ b/les/clientpool_test.go @@ -64,6 +64,11 @@ type poolTestPeer struct { inactiveAllowed bool } +func testStateMachine() *nodestate.NodeStateMachine { + return nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup) + +} + func newPoolTestPeer(i int, disconnCh chan int) *poolTestPeer { return &poolTestPeer{ index: i, @@ -91,7 +96,7 @@ func (i *poolTestPeer) allowInactive() bool { } func getBalance(pool *clientPool, p *poolTestPeer) (pos, neg uint64) { - temp := pool.ns.GetField(p.node, clientField) == nil + temp := pool.ns.GetField(p.node, clientInfoField) == nil if temp { pool.ns.SetField(p.node, connAddressField, p.freeClientId()) } @@ -128,8 +133,9 @@ func testClientPool(t *testing.T, activeLimit, clientCount, paidCount int, rando disconnFn = func(id enode.ID) { disconnCh <- int(id[0]) + int(id[1])<<8 } - pool = newClientPool(db, 1, 0, &clock, disconnFn) + pool = newClientPool(testStateMachine(), db, 1, 0, &clock, disconnFn) ) + pool.ns.Start() pool.setLimits(activeLimit, uint64(activeLimit)) pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -233,7 +239,8 @@ func TestConnectPaidClient(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -248,7 +255,8 @@ func TestConnectPaidClientToSmallPool(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -266,7 +274,8 @@ func TestConnectPaidClientToFullPool(t *testing.T) { db = rawdb.NewMemoryDatabase() ) removeFn := func(enode.ID) {} // Noop - pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -295,7 +304,8 @@ func TestPaidClientKickedOut(t *testing.T) { removeFn := func(id enode.ID) { kickedCh <- int(id[0]) } - pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn) + pool.ns.Start() pool.bt.SetExpirationTCs(0, 0) defer pool.stop() pool.setLimits(10, 
uint64(10)) // Total capacity limit is 10 @@ -325,7 +335,8 @@ func TestConnectFreeClient(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -341,7 +352,8 @@ func TestConnectFreeClientToFullPool(t *testing.T) { db = rawdb.NewMemoryDatabase() ) removeFn := func(enode.ID) {} // Noop - pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -370,7 +382,8 @@ func TestFreeClientKickedOut(t *testing.T) { kicked = make(chan int, 100) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } - pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -411,7 +424,8 @@ func TestPositiveBalanceCalculation(t *testing.T) { kicked = make(chan int, 10) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop - pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -434,7 +448,8 @@ func TestDowngradePriorityClient(t *testing.T) { kicked = make(chan int, 10) ) removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop - pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}) @@ -468,7 +483,8 @@ func TestNegativeBalanceCalculation(t *testing.T) { clock mclock.Simulated db = rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool.ns.Start() defer pool.stop() pool.setLimits(10, uint64(10)) // Total capacity limit is 10 pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1e-3, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1e-3, CapacityFactor: 0, RequestFactor: 1}) @@ -503,7 +519,8 @@ func TestInactiveClient(t *testing.T) { clock mclock.Simulated db = 
rawdb.NewMemoryDatabase() ) - pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {}) + pool.ns.Start() defer pool.stop() pool.setLimits(2, uint64(2)) diff --git a/les/enr_entry.go b/les/enr_entry.go index 65d0d1fdb412..a357f689df6e 100644 --- a/les/enr_entry.go +++ b/les/enr_entry.go @@ -17,7 +17,6 @@ package les import ( - "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/dnsdisc" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rlp" @@ -35,10 +34,10 @@ func (e lesEntry) ENRKey() string { } // setupDiscovery creates the node discovery source for the eth protocol. -func (eth *LightEthereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) { - if /*cfg.NoDiscovery || */ len(eth.config.DiscoveryURLs) == 0 { +func (eth *LightEthereum) setupDiscovery() (enode.Iterator, error) { + if len(eth.config.EthDiscoveryURLs) == 0 { return nil, nil } client := dnsdisc.NewClient(dnsdisc.Config{}) - return client.NewIterator(eth.config.DiscoveryURLs...) + return client.NewIterator(eth.config.EthDiscoveryURLs...) } diff --git a/les/flowcontrol/control.go b/les/flowcontrol/control.go index 490013677c63..4f0de8231835 100644 --- a/les/flowcontrol/control.go +++ b/les/flowcontrol/control.go @@ -19,6 +19,7 @@ package flowcontrol import ( "fmt" + "math" "sync" "time" @@ -316,6 +317,9 @@ func (node *ServerNode) CanSend(maxCost uint64) (time.Duration, float64) { node.lock.RLock() defer node.lock.RUnlock() + if node.params.BufLimit == 0 { + return time.Duration(math.MaxInt64), 0 + } now := node.clock.Now() node.recalcBLE(now) maxCost += uint64(safetyMargin) * node.params.MinRecharge / uint64(fcTimeConst) diff --git a/les/handler_test.go b/les/handler_test.go index 1612caf42769..04277f661b55 100644 --- a/les/handler_test.go +++ b/les/handler_test.go @@ -51,7 +51,7 @@ func TestGetBlockHeadersLes2(t *testing.T) { testGetBlockHeaders(t, 2) } func TestGetBlockHeadersLes3(t *testing.T) { testGetBlockHeaders(t, 3) } func testGetBlockHeaders(t *testing.T, protocol int) { - server, tearDown := newServerEnv(t, downloader.MaxHashFetch+15, protocol, nil, false, true, 0) + server, tearDown := newServerEnv(t, downloader.MaxHeaderFetch+15, protocol, nil, false, true, 0) defer tearDown() bc := server.handler.blockchain diff --git a/les/lespay/server/balance.go b/les/lespay/server/balance.go index f820a4ad050d..f5073d0db1a6 100644 --- a/les/lespay/server/balance.go +++ b/les/lespay/server/balance.go @@ -588,9 +588,8 @@ func (n *NodeBalance) timeUntil(priority int64) (time.Duration, bool) { } dt = float64(posBalance-newBalance) / timePrice return time.Duration(dt), true - } else { - dt = float64(posBalance) / timePrice } + dt = float64(posBalance) / timePrice } else { if priority > 0 { return 0, false diff --git a/les/lespay/server/balance_tracker.go b/les/lespay/server/balance_tracker.go index c1ea3c649671..edd4b288d938 100644 --- a/les/lespay/server/balance_tracker.go +++ b/les/lespay/server/balance_tracker.go @@ -261,9 +261,8 @@ func (bt *BalanceTracker) storeBalance(id []byte, neg bool, value utils.ExpiredV func (bt *BalanceTracker) canDropBalance(now mclock.AbsTime, neg bool, b utils.ExpiredValue) bool { if neg { return b.Value(bt.negExp.LogOffset(now)) <= negThreshold - } else { - return b.Value(bt.posExp.LogOffset(now)) <= posThreshold } + return b.Value(bt.posExp.LogOffset(now)) <= posThreshold } // updateTotalBalance adjusts the total 
balance after executing given callback. diff --git a/les/lespay/server/prioritypool.go b/les/lespay/server/prioritypool.go index 52224e093e93..e3327aba7524 100644 --- a/les/lespay/server/prioritypool.go +++ b/les/lespay/server/prioritypool.go @@ -253,12 +253,12 @@ func (pp *PriorityPool) SetActiveBias(bias time.Duration) { pp.tryActivate() } -// ActiveCapacity returns the total capacity of currently active nodes -func (pp *PriorityPool) ActiveCapacity() uint64 { +// Active returns the number and total capacity of currently active nodes +func (pp *PriorityPool) Active() (uint64, uint64) { pp.lock.Lock() defer pp.lock.Unlock() - return pp.activeCap + return pp.activeCount, pp.activeCap } // inactiveSetIndex callback updates ppNodeInfo item index in inactiveQueue @@ -288,9 +288,8 @@ func activePriority(a interface{}, now mclock.AbsTime) int64 { } if c.bias == 0 { return invertPriority(c.nodePriority.Priority(now, c.capacity)) - } else { - return invertPriority(c.nodePriority.EstMinPriority(now+mclock.AbsTime(c.bias), c.capacity, true)) } + return invertPriority(c.nodePriority.EstMinPriority(now+mclock.AbsTime(c.bias), c.capacity, true)) } // activeMaxPriority callback returns estimated maximum priority of ppNodeInfo item in activeQueue diff --git a/les/odr.go b/les/odr.go index f8469cc103b8..2c36f512de7c 100644 --- a/les/odr.go +++ b/les/odr.go @@ -24,7 +24,6 @@ import ( "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" - "github.com/ethereum/go-ethereum/log" ) // LesOdr implements light.OdrBackend @@ -83,7 +82,8 @@ func (odr *LesOdr) IndexerConfig() *light.IndexerConfig { } const ( - MsgBlockBodies = iota + MsgBlockHeaders = iota + MsgBlockBodies MsgCode MsgReceipts MsgProofsV2 @@ -122,13 +122,17 @@ func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err erro return func() { lreq.Request(reqID, p) } }, } - sent := mclock.Now() - if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil { - // retrieved from network, store in db - req.StoreResult(odr.db) + + defer func(sent mclock.AbsTime) { + if err != nil { + return + } requestRTT.Update(time.Duration(mclock.Now() - sent)) - } else { - log.Debug("Failed to retrieve data from network", "err", err) + }(mclock.Now()) + + if err := odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err != nil { + return err } - return + req.StoreResult(odr.db) + return nil } diff --git a/les/odr_requests.go b/les/odr_requests.go index 3cc55c98d8f1..a8cf8f50a97a 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -327,11 +327,7 @@ func (r *ChtRequest) CanSend(peer *serverPeer) bool { peer.lock.RLock() defer peer.lock.RUnlock() - if r.Untrusted { - return peer.headInfo.Number >= r.BlockNum && peer.id == r.PeerId - } else { - return peer.headInfo.Number >= r.Config.ChtConfirms && r.ChtNum <= (peer.headInfo.Number-r.Config.ChtConfirms)/r.Config.ChtSize - } + return peer.headInfo.Number >= r.Config.ChtConfirms && r.ChtNum <= (peer.headInfo.Number-r.Config.ChtConfirms)/r.Config.ChtSize } // Request sends an ODR request to the LES network (implementation of LesOdrRequest) @@ -370,39 +366,34 @@ func (r *ChtRequest) Validate(db ethdb.Database, msg *Msg) error { if err := rlp.DecodeBytes(headerEnc, header); err != nil { return errHeaderUnavailable } - // Verify the CHT - // Note: For untrusted CHT request, there is no 
proof response but - // header data. - var node light.ChtNode - if !r.Untrusted { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) - - reads := &readTraceDB{db: nodeSet} - value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) - if err != nil { - return fmt.Errorf("merkle proof verification failed: %v", err) - } - if len(reads.reads) != nodeSet.KeyCount() { - return errUselessNodes - } + var ( + node light.ChtNode + encNumber [8]byte + ) + binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) - if err := rlp.DecodeBytes(value, &node); err != nil { - return err - } - if node.Hash != header.Hash() { - return errCHTHashMismatch - } - if r.BlockNum != header.Number.Uint64() { - return errCHTNumberMismatch - } + reads := &readTraceDB{db: nodeSet} + value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) + if err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + if err := rlp.DecodeBytes(value, &node); err != nil { + return err + } + if node.Hash != header.Hash() { + return errCHTHashMismatch + } + if r.BlockNum != header.Number.Uint64() { + return errCHTNumberMismatch } // Verifications passed, store and return r.Header = header r.Proof = nodeSet - r.Td = node.Td // For untrusted request, td here is nil, todo improve the les/2 protocol - + r.Td = node.Td return nil } @@ -497,7 +488,7 @@ func (r *TxStatusRequest) GetCost(peer *serverPeer) uint64 { // CanSend tells if a certain peer is suitable for serving the given request func (r *TxStatusRequest) CanSend(peer *serverPeer) bool { - return peer.version >= lpv2 + return peer.serveTxLookup } // Request sends an ODR request to the LES network (implementation of LesOdrRequest) diff --git a/les/odr_test.go b/les/odr_test.go index ccd220d692ef..c157507dd2b7 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -130,8 +130,9 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, bc, nil) - vmenv := vm.NewEVM(context, statedb, config, vm.Config{}) + context := core.NewEVMBlockContext(header, bc, nil) + txContext := core.NewEVMTxContext(msg) + vmenv := vm.NewEVM(context, txContext, statedb, config, vm.Config{}) //vmenv := core.NewEnv(statedb, config, bc, msg, header, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) @@ -143,8 +144,9 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai state := light.NewState(ctx, header, lc.Odr()) state.SetBalance(bankAddr, math.MaxBig256) msg := callmsg{types.NewMessage(bankAddr, &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, lc, nil) - vmenv := vm.NewEVM(context, state, config, vm.Config{}) + context := core.NewEVMBlockContext(header, lc, nil) + txContext := core.NewEVMTxContext(msg) + vmenv := vm.NewEVM(context, txContext, state, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) result, _ := core.ApplyMessage(vmenv, msg, gp) if state.Error() == nil { diff --git a/les/peer.go b/les/peer.go index 0549daf9a648..0e2ed52c124b 100644 --- a/les/peer.go +++ b/les/peer.go @@ -29,8 +29,8 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core" + 
"github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/les/flowcontrol" lpc "github.com/ethereum/go-ethereum/les/lespay/client" lps "github.com/ethereum/go-ethereum/les/lespay/server" @@ -126,7 +126,7 @@ type peerCommons struct { frozen uint32 // Flag whether the peer is frozen. announceType uint64 // New block announcement type. serving uint32 // The status indicates the peer is served. - headInfo blockInfo // Latest block information. + headInfo blockInfo // Last announced block information. // Background task queue for caching peer tasks and executing in order. sendQueue *utils.ExecQueue @@ -161,9 +161,17 @@ func (p *peerCommons) String() string { return fmt.Sprintf("Peer %s [%s]", p.id, fmt.Sprintf("les/%d", p.version)) } +// PeerInfo represents a short summary of the `eth` sub-protocol metadata known +// about a connected peer. +type PeerInfo struct { + Version int `json:"version"` // Ethereum protocol version negotiated + Difficulty *big.Int `json:"difficulty"` // Total difficulty of the peer's blockchain + Head string `json:"head"` // SHA3 hash of the peer's best owned block +} + // Info gathers and returns a collection of metadata known about a peer. -func (p *peerCommons) Info() *eth.PeerInfo { - return ð.PeerInfo{ +func (p *peerCommons) Info() *PeerInfo { + return &PeerInfo{ Version: p.version, Difficulty: p.Td(), Head: fmt.Sprintf("%x", p.Head()), @@ -246,7 +254,7 @@ func (p *peerCommons) sendReceiveHandshake(sendList keyValueList) (keyValueList, // network IDs, difficulties, head and genesis blocks. Besides the basic handshake // fields, server and client can exchange and resolve some specified fields through // two callback functions. -func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, sendCallback func(*keyValueList), recvCallback func(keyValueMap) error) error { +func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter, sendCallback func(*keyValueList), recvCallback func(keyValueMap) error) error { p.lock.Lock() defer p.lock.Unlock() @@ -255,11 +263,19 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g // Add some basic handshake fields send = send.add("protocolVersion", uint64(p.version)) send = send.add("networkId", p.network) + // Note: the head info announced at handshake is only used in case of server peers + // but dummy values are still announced by clients for compatibility with older servers send = send.add("headTd", td) send = send.add("headHash", head) send = send.add("headNum", headNum) send = send.add("genesisHash", genesis) + // If the protocol version is beyond les4, then pass the forkID + // as well. Check http://eips.ethereum.org/EIPS/eip-2124 for more + // spec detail. 
+ if p.version >= lpv4 { + send = send.add("forkID", forkID) + } // Add client-specified or server-specified fields if sendCallback != nil { sendCallback(&send) @@ -273,24 +289,14 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g if size > allowedUpdateBytes { return errResp(ErrRequestRejected, "") } - var rGenesis, rHash common.Hash - var rVersion, rNetwork, rNum uint64 - var rTd *big.Int + var rGenesis common.Hash + var rVersion, rNetwork uint64 if err := recv.get("protocolVersion", &rVersion); err != nil { return err } if err := recv.get("networkId", &rNetwork); err != nil { return err } - if err := recv.get("headTd", &rTd); err != nil { - return err - } - if err := recv.get("headHash", &rHash); err != nil { - return err - } - if err := recv.get("headNum", &rNum); err != nil { - return err - } if err := recv.get("genesisHash", &rGenesis); err != nil { return err } @@ -303,7 +309,16 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g if int(rVersion) != p.version { return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", rVersion, p.version) } - p.headInfo = blockInfo{Hash: rHash, Number: rNum, Td: rTd} + // Check forkID if the protocol version is beyond the les4 + if p.version >= lpv4 { + var forkID forkid.ID + if err := recv.get("forkID", &forkID); err != nil { + return err + } + if err := forkFilter(forkID); err != nil { + return errResp(ErrForkIDRejected, "%v", err) + } + } if recvCallback != nil { return recvCallback(recv) } @@ -326,6 +341,7 @@ type serverPeer struct { onlyAnnounce bool // The flag whether the server sends announcement only. chainSince, chainRecent uint64 // The range of chain server peer can serve. stateSince, stateRecent uint64 // The range of state server peer can serve. + serveTxLookup bool // The server peer can serve tx lookups. // Advertised checkpoint fields checkpointNumber uint64 // The block height which the checkpoint is registered. @@ -569,9 +585,11 @@ func (p *serverPeer) updateHead(hash common.Hash, number uint64, td *big.Int) { } // Handshake executes the les protocol handshake, negotiating version number, -// network IDs, difficulties, head and genesis blocks. -func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error { - return p.handshake(td, head, headNum, genesis, func(lists *keyValueList) { +// network IDs and genesis blocks. +func (p *serverPeer) Handshake(genesis common.Hash, forkid forkid.ID, forkFilter forkid.Filter) error { + // Note: there is no need to share local head with a server but older servers still + // require these fields so we announce zero values. + return p.handshake(common.Big0, common.Hash{}, 0, genesis, forkid, forkFilter, func(lists *keyValueList) { // Add some client-specific handshake fields // // Enable signed announcement randomly even the server is not trusted. 
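
For context on the forkID exchange added above: les/4 reuses the public core/forkid package (EIP-2124), the same calls this diff exercises in les/peer_test.go further down. A minimal, self-contained sketch of the send/verify round trip; the stub chain and head number are illustrative, not part of this patch:

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/forkid"
	"github.com/ethereum/go-ethereum/core/rawdb"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/params"
)

// chain is a minimal stand-in satisfying forkid's Blockchain interface,
// mirroring the fakeChain helper added in les/peer_test.go below.
type chain struct{}

func (c chain) Config() *params.ChainConfig { return params.MainnetChainConfig }
func (c chain) Genesis() *types.Block {
	return core.DefaultGenesisBlock().ToBlock(rawdb.NewMemoryDatabase())
}
func (c chain) CurrentHeader() *types.Header {
	return &types.Header{Number: big.NewInt(10000000)}
}

func main() {
	local := chain{}

	// The ID a les/4 peer would put into the "forkID" handshake field.
	id := forkid.NewID(local.Config(), local.Genesis().Hash(), local.CurrentHeader().Number.Uint64())

	// The receiving side runs the ID through its filter; a non-nil error is
	// what the handshake above maps to ErrForkIDRejected.
	if err := forkid.NewFilter(local)(id); err != nil {
		fmt.Println("rejected:", err)
		return
	}
	fmt.Printf("forkID accepted: %x (next fork: %d)\n", id.Hash, id.Next)
}
```
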
@@ -581,6 +599,21 @@ func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge } *lists = (*lists).add("announceType", p.announceType) }, func(recv keyValueMap) error { + var ( + rHash common.Hash + rNum uint64 + rTd *big.Int + ) + if err := recv.get("headTd", &rTd); err != nil { + return err + } + if err := recv.get("headHash", &rHash); err != nil { + return err + } + if err := recv.get("headNum", &rNum); err != nil { + return err + } + p.headInfo = blockInfo{Hash: rHash, Number: rNum, Td: rTd} if recv.get("serveChainSince", &p.chainSince) != nil { p.onlyAnnounce = true } @@ -596,6 +629,18 @@ func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge if recv.get("txRelay", nil) != nil { p.onlyAnnounce = true } + if p.version >= lpv4 { + var recentTx uint + if err := recv.get("recentTxLookup", &recentTx); err != nil { + return err + } + // Note: in the current version we only consider the tx index service useful + // if it is unlimited. This can be made configurable in the future. + p.serveTxLookup = recentTx == txIndexUnlimited + } else { + p.serveTxLookup = true + } + if p.onlyAnnounce && !p.trusted { return errResp(ErrUselessPeer, "peer cannot serve requests") } @@ -936,8 +981,25 @@ func (p *clientPeer) freezeClient() { // Handshake executes the les protocol handshake, negotiating version number, // network IDs, difficulties, head and genesis blocks. -func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error { - return p.handshake(td, head, headNum, genesis, func(lists *keyValueList) { +func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, forkID forkid.ID, forkFilter forkid.Filter, server *LesServer) error { + recentTx := server.handler.blockchain.TxLookupLimit() + if recentTx != txIndexUnlimited { + if recentTx < blockSafetyMargin { + recentTx = txIndexDisabled + } else { + recentTx -= blockSafetyMargin - txIndexRecentOffset + } + } + if server.config.UltraLightOnlyAnnounce { + recentTx = txIndexDisabled + } + if recentTx != txIndexUnlimited && p.version < lpv4 { + return errors.New("Cannot serve old clients without a complete tx index") + } + // Note: clientPeer.headInfo should contain the last head announced to the client by us. + // The values announced in the handshake are dummy values for compatibility reasons and should be ignored. + p.headInfo = blockInfo{Hash: head, Number: headNum, Td: td} + return p.handshake(td, head, headNum, genesis, forkID, forkFilter, func(lists *keyValueList) { // Add some information which services server can offer. if !server.config.UltraLightOnlyAnnounce { *lists = (*lists).add("serveHeaders", nil) @@ -946,13 +1008,16 @@ func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge // If local ethereum node is running in archive mode, advertise ourselves we have // all version state data. Otherwise only recent state is available. 
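
Before the serveRecentState change below, an aside on the recentTx value derived above: the "recentTxLookup" field packs three cases into one integer. A self-contained sketch of that mapping, with the constant values copied from les/protocol.go later in this diff (the helper name is illustrative):

```go
package main

import "fmt"

// Handshake constants as defined in les/protocol.go in this diff.
const (
	blockSafetyMargin   = 4 // safety margin applied to ranges given relative to head
	txIndexUnlimited    = 0 // the entire tx index history is served
	txIndexDisabled     = 1 // the tx index is not served at all
	txIndexRecentOffset = 1 // txIndexRecentOffset + N: the last N blocks are indexed
)

// recentTxLookup derives the advertised handshake value from the server's
// local tx index retention limit (0 = keep everything), following the same
// branches as clientPeer.Handshake above.
func recentTxLookup(txLookupLimit uint64, announceOnly bool) uint64 {
	recentTx := txLookupLimit
	if recentTx != txIndexUnlimited {
		if recentTx < blockSafetyMargin {
			recentTx = txIndexDisabled // too little history to serve reliably
		} else {
			recentTx -= blockSafetyMargin - txIndexRecentOffset
		}
	}
	if announceOnly {
		recentTx = txIndexDisabled
	}
	return recentTx
}

func main() {
	for _, limit := range []uint64{0, 2, 128} {
		fmt.Printf("limit=%3d -> recentTxLookup=%d\n", limit, recentTxLookup(limit, false))
	}
	// limit=  0 -> 0 (unlimited), limit=  2 -> 1 (disabled), limit=128 -> 125
}
```
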
- stateRecent := uint64(core.TriesInMemory - 4) + stateRecent := uint64(core.TriesInMemory - blockSafetyMargin) if server.archiveMode { stateRecent = 0 } *lists = (*lists).add("serveRecentState", stateRecent) *lists = (*lists).add("txRelay", nil) } + if p.version >= lpv4 { + *lists = (*lists).add("recentTxLookup", recentTx) + } *lists = (*lists).add("flowControl/BL", server.defParams.BufLimit) *lists = (*lists).add("flowControl/MRR", server.defParams.MinRecharge) @@ -1009,145 +1074,6 @@ type serverPeerSubscriber interface { unregisterPeer(*serverPeer) } -// clientPeerSubscriber is an interface to notify services about added or -// removed client peers -type clientPeerSubscriber interface { - registerPeer(*clientPeer) - unregisterPeer(*clientPeer) -} - -// clientPeerSet represents the set of active client peers currently -// participating in the Light Ethereum sub-protocol. -type clientPeerSet struct { - peers map[string]*clientPeer - // subscribers is a batch of subscribers and peerset will notify - // these subscribers when the peerset changes(new client peer is - // added or removed) - subscribers []clientPeerSubscriber - closed bool - lock sync.RWMutex -} - -// newClientPeerSet creates a new peer set to track the client peers. -func newClientPeerSet() *clientPeerSet { - return &clientPeerSet{peers: make(map[string]*clientPeer)} -} - -// subscribe adds a service to be notified about added or removed -// peers and also register all active peers into the given service. -func (ps *clientPeerSet) subscribe(sub clientPeerSubscriber) { - ps.lock.Lock() - defer ps.lock.Unlock() - - ps.subscribers = append(ps.subscribers, sub) - for _, p := range ps.peers { - sub.registerPeer(p) - } -} - -// unSubscribe removes the specified service from the subscriber pool. -func (ps *clientPeerSet) unSubscribe(sub clientPeerSubscriber) { - ps.lock.Lock() - defer ps.lock.Unlock() - - for i, s := range ps.subscribers { - if s == sub { - ps.subscribers = append(ps.subscribers[:i], ps.subscribers[i+1:]...) - return - } - } -} - -// register adds a new peer into the peer set, or returns an error if the -// peer is already known. -func (ps *clientPeerSet) register(peer *clientPeer) error { - ps.lock.Lock() - defer ps.lock.Unlock() - - if ps.closed { - return errClosed - } - if _, exist := ps.peers[peer.id]; exist { - return errAlreadyRegistered - } - ps.peers[peer.id] = peer - for _, sub := range ps.subscribers { - sub.registerPeer(peer) - } - return nil -} - -// unregister removes a remote peer from the peer set, disabling any further -// actions to/from that particular entity. It also initiates disconnection -// at the networking layer. -func (ps *clientPeerSet) unregister(id string) error { - ps.lock.Lock() - defer ps.lock.Unlock() - - p, ok := ps.peers[id] - if !ok { - return errNotRegistered - } - delete(ps.peers, id) - for _, sub := range ps.subscribers { - sub.unregisterPeer(p) - } - p.Peer.Disconnect(p2p.DiscRequested) - return nil -} - -// ids returns a list of all registered peer IDs -func (ps *clientPeerSet) ids() []string { - ps.lock.RLock() - defer ps.lock.RUnlock() - - var ids []string - for id := range ps.peers { - ids = append(ids, id) - } - return ids -} - -// peer retrieves the registered peer with the given id. -func (ps *clientPeerSet) peer(id string) *clientPeer { - ps.lock.RLock() - defer ps.lock.RUnlock() - - return ps.peers[id] -} - -// len returns if the current number of peers in the set. 
-func (ps *clientPeerSet) len() int { - ps.lock.RLock() - defer ps.lock.RUnlock() - - return len(ps.peers) -} - -// allClientPeers returns all client peers in a list. -func (ps *clientPeerSet) allPeers() []*clientPeer { - ps.lock.RLock() - defer ps.lock.RUnlock() - - list := make([]*clientPeer, 0, len(ps.peers)) - for _, p := range ps.peers { - list = append(list, p) - } - return list -} - -// close disconnects all peers. No new peers can be registered -// after close has returned. -func (ps *clientPeerSet) close() { - ps.lock.Lock() - defer ps.lock.Unlock() - - for _, p := range ps.peers { - p.Disconnect(p2p.DiscQuitting) - } - ps.closed = true -} - // serverPeerSet represents the set of active server peers currently // participating in the Light Ethereum sub-protocol. type serverPeerSet struct { @@ -1298,42 +1224,3 @@ func (ps *serverPeerSet) close() { } ps.closed = true } - -// serverSet is a special set which contains all connected les servers. -// Les servers will also be discovered by discovery protocol because they -// also run the LES protocol. We can't drop them although they are useless -// for us(server) but for other protocols(e.g. ETH) upon the devp2p they -// may be useful. -type serverSet struct { - lock sync.Mutex - set map[string]*clientPeer - closed bool -} - -func newServerSet() *serverSet { - return &serverSet{set: make(map[string]*clientPeer)} -} - -func (s *serverSet) register(peer *clientPeer) error { - s.lock.Lock() - defer s.lock.Unlock() - - if s.closed { - return errClosed - } - if _, exist := s.set[peer.id]; exist { - return errAlreadyRegistered - } - s.set[peer.id] = peer - return nil -} - -func (s *serverSet) close() { - s.lock.Lock() - defer s.lock.Unlock() - - for _, p := range s.set { - p.Disconnect(p2p.DiscQuitting) - } - s.closed = true -} diff --git a/les/peer_test.go b/les/peer_test.go index 6d3c7f97555e..d6551ce6b639 100644 --- a/les/peer_test.go +++ b/les/peer_test.go @@ -26,8 +26,13 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/params" ) type testServerPeerSub struct { @@ -91,6 +96,14 @@ func TestPeerSubscription(t *testing.T) { checkPeers(sub.unregCh) } +type fakeChain struct{} + +func (f *fakeChain) Config() *params.ChainConfig { return params.MainnetChainConfig } +func (f *fakeChain) Genesis() *types.Block { + return core.DefaultGenesisBlock().ToBlock(rawdb.NewMemoryDatabase()) +} +func (f *fakeChain) CurrentHeader() *types.Header { return &types.Header{Number: big.NewInt(10000000)} } + func TestHandshake(t *testing.T) { // Create a message pipe to communicate through app, net := p2p.MsgPipe() @@ -110,15 +123,21 @@ func TestHandshake(t *testing.T) { head = common.HexToHash("deadbeef") headNum = uint64(10) genesis = common.HexToHash("cafebabe") + + chain1, chain2 = &fakeChain{}, &fakeChain{} + forkID1 = forkid.NewID(chain1.Config(), chain1.Genesis().Hash(), chain1.CurrentHeader().Number.Uint64()) + forkID2 = forkid.NewID(chain2.Config(), chain2.Genesis().Hash(), chain2.CurrentHeader().Number.Uint64()) + filter1, filter2 = forkid.NewFilter(chain1), forkid.NewFilter(chain2) ) + go func() { - errCh1 <- peer1.handshake(td, head, headNum, genesis, func(list *keyValueList) { + errCh1 <- peer1.handshake(td, head, headNum, genesis, forkID1, 
filter1, func(list *keyValueList) {
 			var announceType uint64 = announceTypeSigned
 			*list = (*list).add("announceType", announceType)
 		}, nil)
 	}()
 	go func() {
-		errCh2 <- peer2.handshake(td, head, headNum, genesis, nil, func(recv keyValueMap) error {
+		errCh2 <- peer2.handshake(td, head, headNum, genesis, forkID2, filter2, nil, func(recv keyValueMap) error {
 			var reqType uint64
 			err := recv.get("announceType", &reqType)
 			if err != nil {
diff --git a/les/protocol.go b/les/protocol.go
index 4fd19f9beca1..39d9f5152fd8 100644
--- a/les/protocol.go
+++ b/les/protocol.go
@@ -34,6 +34,7 @@ import (
 const (
 	lpv2 = 2
 	lpv3 = 3
+	lpv4 = 4
 )
 
 // Supported versions of the les protocol (first is primary)
@@ -44,11 +45,16 @@ var (
 )
 
 // Number of implemented message corresponding to different protocol versions.
-var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24}
+var ProtocolLengths = map[uint]uint64{lpv2: 22, lpv3: 24, lpv4: 24}
 
 const (
 	NetworkId          = 1
 	ProtocolMaxMsgSize = 10 * 1024 * 1024 // Maximum cap on the size of a protocol message
+	blockSafetyMargin  = 4                // safety margin applied to block ranges specified relative to head block
+
+	txIndexUnlimited    = 0 // this value in the "recentTxLookup" handshake field means the entire tx index history is served
+	txIndexDisabled     = 1 // this value means tx index is not served at all
+	txIndexRecentOffset = 1 // txIndexRecentOffset + N in the handshake field means the tx index of the last N blocks is supported
 )
 
 // les protocol message codes
@@ -150,6 +156,7 @@ const (
 	ErrInvalidResponse
 	ErrTooManyTimeouts
 	ErrMissingKey
+	ErrForkIDRejected
 )
 
 func (e errCode) String() string {
@@ -172,12 +179,7 @@ var errorToString = map[int]string{
 	ErrInvalidResponse: "Invalid response",
 	ErrTooManyTimeouts: "Too many request timeouts",
 	ErrMissingKey:      "Key missing from list",
-}
-
-type announceBlock struct {
-	Hash   common.Hash // Hash of one particular block being announced
-	Number uint64      // Number of one particular block being announced
-	Td     *big.Int    // Total difficulty of one particular block being announced
+	ErrForkIDRejected:  "ForkID rejected",
 }
 
 // announceData is the network packet for the block announcements.
@@ -199,7 +201,7 @@ func (a *announceData) sanityCheck() error {
 
 // sign adds a signature to the block announcement by the given privKey
 func (a *announceData) sign(privKey *ecdsa.PrivateKey) {
-	rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td})
+	rlp, _ := rlp.EncodeToBytes(blockInfo{a.Hash, a.Number, a.Td})
 	sig, _ := crypto.Sign(crypto.Keccak256(rlp), privKey)
 	a.Update = a.Update.add("sign", sig)
 }
@@ -210,7 +212,7 @@ func (a *announceData) checkSignature(id enode.ID, update keyValueMap) error {
 	if err := update.get("sign", &sig); err != nil {
 		return err
 	}
-	rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td})
+	rlp, _ := rlp.EncodeToBytes(blockInfo{a.Hash, a.Number, a.Td})
 	recPubkey, err := crypto.SigToPub(crypto.Keccak256(rlp), sig)
 	if err != nil {
 		return err
diff --git a/les/retrieve.go b/les/retrieve.go
index 4f77004f20cd..ca4f867ea80f 100644
--- a/les/retrieve.go
+++ b/les/retrieve.go
@@ -155,6 +155,15 @@ func (rm *retrieveManager) sendReq(reqID uint64, req *distReq, val validatorFunc
 	return r
 }
 
+// requested reports whether the request with the given reqid was sent by the retriever.
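
Before the new requested helper below, a brief aside on the announcement signing change above: with announceBlock folded into blockInfo, a signed announcement is just an RLP-encoded (hash, number, td) triple signed with the server key. A runnable sketch of the sign/recover round trip performed by announceData.sign and checkSignature; the local blockInfo mirror copies the field layout used in this diff:

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/rlp"
)

// blockInfo mirrors the (hash, number, td) triple les encodes for signatures.
type blockInfo struct {
	Hash   common.Hash
	Number uint64
	Td     *big.Int
}

func main() {
	key, _ := crypto.GenerateKey()

	// Server side: sign keccak256(rlp(hash, number, td)), as announceData.sign does.
	payload, _ := rlp.EncodeToBytes(blockInfo{common.HexToHash("deadbeef"), 10, big.NewInt(100)})
	sig, _ := crypto.Sign(crypto.Keccak256(payload), key)

	// Client side: recover the public key and compare it against the sender's
	// node ID, as announceData.checkSignature does.
	recovered, err := crypto.SigToPub(crypto.Keccak256(payload), sig)
	if err != nil {
		panic(err)
	}
	fmt.Println("signer ok:", enode.PubkeyToIDV4(recovered) == enode.PubkeyToIDV4(&key.PublicKey))
}
```
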
+func (rm *retrieveManager) requested(reqId uint64) bool { + rm.lock.RLock() + defer rm.lock.RUnlock() + + _, ok := rm.sentReqs[reqId] + return ok +} + // deliver is called by the LES protocol manager to deliver reply messages to waiting requests func (rm *retrieveManager) deliver(peer distPeer, msg *Msg) error { rm.lock.RLock() diff --git a/les/server.go b/les/server.go index 225a7ad1f033..cbedce136c35 100644 --- a/les/server.go +++ b/les/server.go @@ -18,6 +18,7 @@ package les import ( "crypto/ecdsa" + "reflect" "time" "github.com/ethereum/go-ethereum/common/mclock" @@ -31,17 +32,32 @@ import ( "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/nodestate" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" ) +var ( + serverSetup = &nodestate.Setup{} + clientPeerField = serverSetup.NewField("clientPeer", reflect.TypeOf(&clientPeer{})) + clientInfoField = serverSetup.NewField("clientInfo", reflect.TypeOf(&clientInfo{})) + connAddressField = serverSetup.NewField("connAddr", reflect.TypeOf("")) + balanceTrackerSetup = lps.NewBalanceTrackerSetup(serverSetup) + priorityPoolSetup = lps.NewPriorityPoolSetup(serverSetup) +) + +func init() { + balanceTrackerSetup.Connect(connAddressField, priorityPoolSetup.CapacityField) + priorityPoolSetup.Connect(balanceTrackerSetup.BalanceField, balanceTrackerSetup.UpdateFlag) // NodeBalance implements nodePriority +} + type LesServer struct { lesCommons + ns *nodestate.NodeStateMachine archiveMode bool // Flag whether the ethereum node runs in archive mode. - peers *clientPeerSet - serverset *serverSet handler *serverHandler + broadcaster *broadcaster lesTopics []discv5.Topic privateKey *ecdsa.PrivateKey @@ -60,6 +76,7 @@ type LesServer struct { } func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesServer, error) { + ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup) // Collect les protocol version information supported by local node. 
lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions)) for i, pv := range AdvertiseProtocolVersions { @@ -83,9 +100,9 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency, true), closeCh: make(chan struct{}), }, + ns: ns, archiveMode: e.ArchiveMode(), - peers: newClientPeerSet(), - serverset: newServerSet(), + broadcaster: newBroadcaster(ns), lesTopics: lesTopics, fcManager: flowcontrol.NewClientManager(nil, &mclock.System{}), servingQueue: newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100), @@ -116,7 +133,7 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer srv.maxCapacity = totalRecharge } srv.fcManager.SetCapacityLimits(srv.minCapacity, srv.maxCapacity, srv.minCapacity*2) - srv.clientPool = newClientPool(srv.chainDb, srv.minCapacity, defaultConnectedBias, mclock.System{}, func(id enode.ID) { go srv.peers.unregister(id.String()) }) + srv.clientPool = newClientPool(ns, srv.chainDb, srv.minCapacity, defaultConnectedBias, mclock.System{}, srv.dropClient) srv.clientPool.setDefaultFactors(lps.PriceFactors{TimeFactor: 0, CapacityFactor: 1, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 0, CapacityFactor: 1, RequestFactor: 1}) checkpoint := srv.latestLocalCheckpoint() @@ -130,6 +147,13 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer node.RegisterAPIs(srv.APIs()) node.RegisterLifecycle(srv) + // disconnect all peers at nsm shutdown + ns.SubscribeField(clientPeerField, func(node *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) { + if state.Equals(serverSetup.OfflineFlag()) && oldValue != nil { + oldValue.(*clientPeer).Peer.Disconnect(p2p.DiscRequested) + } + }) + ns.Start() return srv, nil } @@ -158,7 +182,7 @@ func (s *LesServer) APIs() []rpc.API { func (s *LesServer) Protocols() []p2p.Protocol { ps := s.makeProtocols(ServerProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} { - if p := s.peers.peer(id.String()); p != nil { + if p := s.getClient(id); p != nil { return p.Info() } return nil @@ -173,6 +197,7 @@ func (s *LesServer) Protocols() []p2p.Protocol { // Start starts the LES server func (s *LesServer) Start() error { s.privateKey = s.p2pSrv.PrivateKey + s.broadcaster.setSignerKey(s.privateKey) s.handler.start() s.wg.Add(1) @@ -198,19 +223,11 @@ func (s *LesServer) Start() error { func (s *LesServer) Stop() error { close(s.closeCh) - // Disconnect existing connections with other LES servers. - s.serverset.close() - - // Disconnect existing sessions. - // This also closes the gate for any new registrations on the peer set. - // sessions which are already established but not added to pm.peers yet - // will exit when they try to register. - s.peers.close() - + s.clientPool.stop() + s.ns.Stop() s.fcManager.Stop() s.costTracker.stop() s.handler.stop() - s.clientPool.stop() // client pool should be closed after handler. s.servingQueue.stop() // Note, bloom trie indexer is closed by parent bloombits indexer. 
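
With the clientPeerSet gone, a connected peer now lives in a nodestate field, and both getClient above and the shutdown path rely on that. A minimal sketch of the same pattern with the public p2p/nodestate package; the flag, field and peer names are illustrative, and the nil arguments to NewNodeStateMachine mean no persistence, as in NewLesServer:

```go
package main

import (
	"fmt"
	"reflect"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

var (
	setup     = &nodestate.Setup{}
	connFlag  = setup.NewFlag("connected")
	peerField = setup.NewField("peer", reflect.TypeOf(""))
)

// getPeer mirrors LesServer.getClient: node ID -> tracked node -> field value.
func getPeer(ns *nodestate.NodeStateMachine, id enode.ID) string {
	if node := ns.GetNode(id); node != nil {
		if p, ok := ns.GetField(node, peerField).(string); ok {
			return p
		}
	}
	return ""
}

func main() {
	ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, setup)

	// Field subscriptions also fire when fields are torn down at ns.Stop(),
	// which is the hook the server uses above to disconnect peers at shutdown.
	ns.SubscribeField(peerField, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
		if oldValue != nil && newValue == nil {
			fmt.Printf("dropping %v\n", oldValue)
		}
	})
	ns.Start()

	node := enode.SignNull(&enr.Record{}, enode.ID{42})
	ns.SetState(node, connFlag, nodestate.Flags{}, 0) // track the node
	ns.SetField(node, peerField, "peer-42")
	fmt.Println("lookup:", getPeer(ns, node.ID()))

	ns.Stop() // resets the field and fires the subscription above
}
```
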
@@ -279,3 +296,18 @@ func (s *LesServer) capacityManagement() { } } } + +func (s *LesServer) getClient(id enode.ID) *clientPeer { + if node := s.ns.GetNode(id); node != nil { + if p, ok := s.ns.GetField(node, clientPeerField).(*clientPeer); ok { + return p + } + } + return nil +} + +func (s *LesServer) dropClient(id enode.ID) { + if p := s.getClient(id); p != nil { + p.Peer.Disconnect(p2p.DiscRequested) + } +} diff --git a/les/server_handler.go b/les/server_handler.go index 583df960080c..2316c9c5a4c0 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -17,6 +17,7 @@ package les import ( + "crypto/ecdsa" "encoding/binary" "encoding/json" "errors" @@ -27,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/core/forkid" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" @@ -36,6 +38,8 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/nodestate" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" ) @@ -43,7 +47,7 @@ import ( const ( softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data. estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header - ethVersion = 63 // equivalent eth version for the downloader + ethVersion = 64 // equivalent eth version for the downloader MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request MaxBodyFetch = 32 // Amount of block bodies to be fetched per retrieval request @@ -63,6 +67,7 @@ var ( // serverHandler is responsible for serving light client and process // all incoming light requests. type serverHandler struct { + forkFilter forkid.Filter blockchain *core.BlockChain chainDb ethdb.Database txpool *core.TxPool @@ -78,6 +83,7 @@ type serverHandler struct { func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb ethdb.Database, txpool *core.TxPool, synced func() bool) *serverHandler { handler := &serverHandler{ + forkFilter: forkid.NewFilter(blockchain), server: server, blockchain: blockchain, chainDb: chainDb, @@ -91,7 +97,7 @@ func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb et // start starts the server handler. func (h *serverHandler) start() { h.wg.Add(1) - go h.broadcastHeaders() + go h.broadcastLoop() } // stop stops the server handler. @@ -118,52 +124,67 @@ func (h *serverHandler) handle(p *clientPeer) error { hash = head.Hash() number = head.Number.Uint64() td = h.blockchain.GetTd(hash, number) + forkID = forkid.NewID(h.blockchain.Config(), h.blockchain.Genesis().Hash(), h.blockchain.CurrentBlock().NumberU64()) ) - if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), h.server); err != nil { + if err := p.Handshake(td, hash, number, h.blockchain.Genesis().Hash(), forkID, h.forkFilter, h.server); err != nil { p.Log().Debug("Light Ethereum handshake failed", "err", err) return err } - if p.server { - if err := h.server.serverset.register(p); err != nil { - return err + // Reject the duplicated peer, otherwise register it to peerset. 
+ var registered bool + if err := h.server.ns.Operation(func() { + if h.server.ns.GetField(p.Node(), clientPeerField) != nil { + registered = true + } else { + h.server.ns.SetFieldSub(p.Node(), clientPeerField, p) } + }); err != nil { + return err + } + if registered { + return errAlreadyRegistered + } + + defer func() { + h.server.ns.SetField(p.Node(), clientPeerField, nil) + if p.fcClient != nil { // is nil when connecting another server + p.fcClient.Disconnect() + } + }() + if p.server { // connected to another server, no messages expected, just wait for disconnection _, err := p.rw.ReadMsg() return err } // Reject light clients if server is not synced. + // + // Put this checking here, so that "non-synced" les-server peers are still allowed + // to keep the connection. if !h.synced() { p.Log().Debug("Light server not synced, rejecting peer") return p2p.DiscRequested } - defer p.fcClient.Disconnect() - // Disconnect the inbound peer if it's rejected by clientPool if cap, err := h.server.clientPool.connect(p); cap != p.fcParams.MinRecharge || err != nil { p.Log().Debug("Light Ethereum peer rejected", "err", errFullClientPool) return errFullClientPool } - p.balance, _ = h.server.clientPool.ns.GetField(p.Node(), h.server.clientPool.BalanceField).(*lps.NodeBalance) + p.balance, _ = h.server.ns.GetField(p.Node(), h.server.clientPool.BalanceField).(*lps.NodeBalance) if p.balance == nil { return p2p.DiscRequested } - // Register the peer locally - if err := h.server.peers.register(p); err != nil { - h.server.clientPool.disconnect(p) - p.Log().Error("Light Ethereum peer registration failed", "err", err) - return err - } - clientConnectionGauge.Update(int64(h.server.peers.len())) + activeCount, _ := h.server.clientPool.pp.Active() + clientConnectionGauge.Update(int64(activeCount)) var wg sync.WaitGroup // Wait group used to track all in-flight task routines. connectedAt := mclock.Now() defer func() { wg.Wait() // Ensure all background task routines have exited. - h.server.peers.unregister(p.id) h.server.clientPool.disconnect(p) p.balance = nil - clientConnectionGauge.Update(int64(h.server.peers.len())) + activeCount, _ := h.server.clientPool.pp.Active() + clientConnectionGauge.Update(int64(activeCount)) connectionTimer.Update(time.Duration(mclock.Now() - connectedAt)) }() // Mark the peer starts to be served. 
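
The duplicate check at the top of this hunk has to be atomic with the field write, which is exactly what ns.Operation buys: the whole closure runs as a single state transition. A sketch of that check-and-set using the public nodestate API (names illustrative, semantics as used above):

```go
package main

import (
	"errors"
	"fmt"
	"reflect"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

var (
	setup     = &nodestate.Setup{}
	connFlag  = setup.NewFlag("connected")
	peerField = setup.NewField("peer", reflect.TypeOf(""))

	errAlreadyRegistered = errors.New("peer is already registered")
)

// register reproduces the shape of the check in serverHandler.handle: since
// the closure runs as one transition, two racing connections from the same
// node cannot both observe an empty field and register twice.
func register(ns *nodestate.NodeStateMachine, node *enode.Node, peer string) error {
	var registered bool
	if err := ns.Operation(func() {
		if ns.GetField(node, peerField) != nil {
			registered = true
		} else {
			ns.SetFieldSub(node, peerField, peer) // Sub variant: already inside an operation
		}
	}); err != nil {
		return err
	}
	if registered {
		return errAlreadyRegistered
	}
	return nil
}

func main() {
	ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, setup)
	ns.Start()
	defer ns.Stop()

	node := enode.SignNull(&enr.Record{}, enode.ID{7})
	ns.SetState(node, connFlag, nodestate.Flags{}, 0) // track the node first

	fmt.Println("first: ", register(ns, node, "peer-7"))
	fmt.Println("second:", register(ns, node, "peer-7"))
}
```
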
@@ -334,7 +355,6 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { origin = h.blockchain.GetHeaderByNumber(query.Origin.Number) } if origin == nil { - p.bumpInvalid() break } headers = append(headers, origin) @@ -594,6 +614,7 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { var ( lastBHash common.Hash root common.Hash + header *types.Header ) reqCnt := len(req.Reqs) if accept(req.ReqID, uint64(reqCnt), MaxProofsFetch) { @@ -608,10 +629,6 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { return } // Look up the root hash belonging to the request - var ( - header *types.Header - trie state.Trie - ) if request.BHash != lastBHash { root, lastBHash = common.Hash{}, request.BHash @@ -638,6 +655,7 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { // Open the account or storage trie for the request statedb := h.blockchain.StateCache() + var trie state.Trie switch len(request.AccKey) { case 0: // No account key specified, open an account trie @@ -911,11 +929,11 @@ func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus { return stat } -// broadcastHeaders broadcasts new block information to all connected light +// broadcastLoop broadcasts new block information to all connected light // clients. According to the agreement between client and server, server should // only broadcast new announcement if the total difficulty is higher than the // last one. Besides server will add the signature if client requires. -func (h *serverHandler) broadcastHeaders() { +func (h *serverHandler) broadcastLoop() { defer h.wg.Done() headCh := make(chan core.ChainHeadEvent, 10) @@ -929,10 +947,6 @@ func (h *serverHandler) broadcastHeaders() { for { select { case ev := <-headCh: - peers := h.server.peers.allPeers() - if len(peers) == 0 { - continue - } header := ev.Block.Header() hash, number := header.Hash(), header.Number.Uint64() td := h.blockchain.GetTd(hash, number) @@ -944,33 +958,79 @@ func (h *serverHandler) broadcastHeaders() { reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64() } lastHead, lastTd = header, td - log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg) - var ( - signed bool - signedAnnounce announceData - ) - announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg} - for _, p := range peers { - p := p - switch p.announceType { - case announceTypeSimple: - if !p.queueSend(func() { p.sendAnnounce(announce) }) { - log.Debug("Drop announcement because queue is full", "number", number, "hash", hash) - } - case announceTypeSigned: - if !signed { - signedAnnounce = announce - signedAnnounce.sign(h.server.privateKey) - signed = true - } - if !p.queueSend(func() { p.sendAnnounce(signedAnnounce) }) { - log.Debug("Drop announcement because queue is full", "number", number, "hash", hash) - } - } - } + h.server.broadcaster.broadcast(announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}) case <-h.closeCh: return } } } + +// broadcaster sends new header announcements to active client peers +type broadcaster struct { + ns *nodestate.NodeStateMachine + privateKey *ecdsa.PrivateKey + lastAnnounce, signedAnnounce announceData +} + +// newBroadcaster creates a new broadcaster +func newBroadcaster(ns *nodestate.NodeStateMachine) *broadcaster { + b := &broadcaster{ns: ns} + ns.SubscribeState(priorityPoolSetup.ActiveFlag, func(node *enode.Node, oldState, newState 
nodestate.Flags) { + if newState.Equals(priorityPoolSetup.ActiveFlag) { + // send last announcement to activated peers + b.sendTo(node) + } + }) + return b +} + +// setSignerKey sets the signer key for signed announcements. Should be called before +// starting the protocol handler. +func (b *broadcaster) setSignerKey(privateKey *ecdsa.PrivateKey) { + b.privateKey = privateKey +} + +// broadcast sends the given announcements to all active peers +func (b *broadcaster) broadcast(announce announceData) { + b.ns.Operation(func() { + // iterate in an Operation to ensure that the active set does not change while iterating + b.lastAnnounce = announce + b.ns.ForEach(priorityPoolSetup.ActiveFlag, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) { + b.sendTo(node) + }) + }) +} + +// sendTo sends the most recent announcement to the given node unless the same or higher Td +// announcement has already been sent. +func (b *broadcaster) sendTo(node *enode.Node) { + if b.lastAnnounce.Td == nil { + return + } + if p, _ := b.ns.GetField(node, clientPeerField).(*clientPeer); p != nil { + if p.headInfo.Td == nil || b.lastAnnounce.Td.Cmp(p.headInfo.Td) > 0 { + announce := b.lastAnnounce + switch p.announceType { + case announceTypeSimple: + if !p.queueSend(func() { p.sendAnnounce(announce) }) { + log.Debug("Drop announcement because queue is full", "number", announce.Number, "hash", announce.Hash) + } else { + log.Debug("Sent announcement", "number", announce.Number, "hash", announce.Hash) + } + case announceTypeSigned: + if b.signedAnnounce.Hash != b.lastAnnounce.Hash { + b.signedAnnounce = b.lastAnnounce + b.signedAnnounce.sign(b.privateKey) + } + announce := b.signedAnnounce + if !p.queueSend(func() { p.sendAnnounce(announce) }) { + log.Debug("Drop announcement because queue is full", "number", announce.Number, "hash", announce.Hash) + } else { + log.Debug("Sent announcement", "number", announce.Number, "hash", announce.Hash) + } + } + p.headInfo = blockInfo{b.lastAnnounce.Hash, b.lastAnnounce.Number, b.lastAnnounce.Td} + } + } +} diff --git a/les/serverpool.go b/les/serverpool.go index 9bfa0bd7259d..6cf4affff097 100644 --- a/les/serverpool.go +++ b/les/serverpool.go @@ -112,9 +112,8 @@ var ( } enc, err := rlp.EncodeToBytes(&ne) return enc, err - } else { - return nil, errors.New("invalid field type") } + return nil, errors.New("invalid field type") }, func(enc []byte) (interface{}, error) { var ne nodeHistoryEnc diff --git a/les/serverpool_test.go b/les/serverpool_test.go index 3d0487d102e9..70a41b74c6f8 100644 --- a/les/serverpool_test.go +++ b/les/serverpool_test.go @@ -119,31 +119,28 @@ func (s *serverPoolTest) start() { s.clock.Sleep(time.Second * 5) s.endWait() return -1 - } else { - switch idx % 3 { - case 0: - // pre-neg returns true only if connection is possible - if canConnect { - return 1 - } else { - return 0 - } - case 1: - // pre-neg returns true but connection might still fail + } + switch idx % 3 { + case 0: + // pre-neg returns true only if connection is possible + if canConnect { + return 1 + } + return 0 + case 1: + // pre-neg returns true but connection might still fail + return 1 + case 2: + // pre-neg returns true if connection is possible, otherwise timeout (node unresponsive) + if canConnect { return 1 - case 2: - // pre-neg returns true if connection is possible, otherwise timeout (node unresponsive) - if canConnect { - return 1 - } else { - s.beginWait() - s.clock.Sleep(time.Second * 5) - s.endWait() - return -1 - } } + s.beginWait() + s.clock.Sleep(time.Second * 5) 
+ s.endWait() return -1 } + return -1 } } diff --git a/les/sync.go b/les/sync.go index d2568d45bcb5..ad3a0e0f3cdc 100644 --- a/les/sync.go +++ b/les/sync.go @@ -56,8 +56,8 @@ func (h *clientHandler) validateCheckpoint(peer *serverPeer) error { defer cancel() // Fetch the block header corresponding to the checkpoint registration. - cp := peer.checkpoint - header, err := light.GetUntrustedHeaderByNumber(ctx, h.backend.odr, peer.checkpointNumber, peer.id) + wrapPeer := &peerConnection{handler: h, peer: peer} + header, err := wrapPeer.RetrieveSingleHeaderByNumber(ctx, peer.checkpointNumber) if err != nil { return err } @@ -66,7 +66,7 @@ func (h *clientHandler) validateCheckpoint(peer *serverPeer) error { if err != nil { return err } - events := h.backend.oracle.Contract().LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash()) + events := h.backend.oracle.Contract().LookupCheckpointEvents(logs, peer.checkpoint.SectionIndex, peer.checkpoint.Hash()) if len(events) == 0 { return errInvalidCheckpoint } diff --git a/les/sync_test.go b/les/sync_test.go index ffce4d8df290..2eb0f88bf97a 100644 --- a/les/sync_test.go +++ b/les/sync_test.go @@ -53,7 +53,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { time.Sleep(10 * time.Millisecond) } } - // Generate 512+4 blocks (totally 1 CHT sections) + // Generate 128+1 blocks (totally 1 CHT sections) server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, nil, 0, false, false, true) defer tearDown() @@ -80,7 +80,8 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...) sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey) sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper - if _, err := server.handler.server.oracle.Contract().RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { + auth, _ := bind.NewKeyedTransactorWithChainID(signerKey, big.NewInt(1337)) + if _, err := server.handler.server.oracle.Contract().RegisterCheckpoint(auth, cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { t.Error("register checkpoint failed", err) } server.backend.Commit() @@ -162,7 +163,8 @@ func testMissOracleBackend(t *testing.T, hasCheckpoint bool) { data := append([]byte{0x19, 0x00}, append(registrarAddr.Bytes(), append([]byte{0, 0, 0, 0, 0, 0, 0, 0}, cp.Hash().Bytes()...)...)...) 
sig, _ := crypto.Sign(crypto.Keccak256(data), signerKey) sig[64] += 27 // Transform V from 0/1 to 27/28 according to the yellow paper - if _, err := server.handler.server.oracle.Contract().RegisterCheckpoint(bind.NewKeyedTransactor(signerKey), cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { + auth, _ := bind.NewKeyedTransactorWithChainID(signerKey, big.NewInt(1337)) + if _, err := server.handler.server.oracle.Contract().RegisterCheckpoint(auth, cp.SectionIndex, cp.Hash().Bytes(), new(big.Int).Sub(header.Number, big.NewInt(1)), header.ParentHash, [][]byte{sig}); err != nil { t.Error("register checkpoint failed", err) } server.backend.Commit() diff --git a/les/test_helper.go b/les/test_helper.go index 9f9b28721e44..d108a8dacefb 100644 --- a/les/test_helper.go +++ b/les/test_helper.go @@ -46,6 +46,7 @@ import ( "github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/nodestate" "github.com/ethereum/go-ethereum/params" ) @@ -111,7 +112,8 @@ func prepare(n int, backend *backends.SimulatedBackend) { switch i { case 0: // deploy checkpoint contract - registrarAddr, _, _, _ = contract.DeployCheckpointOracle(bind.NewKeyedTransactor(bankKey), backend, []common.Address{signerAddr}, sectionSize, processConfirms, big.NewInt(1)) + auth, _ := bind.NewKeyedTransactorWithChainID(bankKey, big.NewInt(1337)) + registrarAddr, _, _, _ = contract.DeployCheckpointOracle(auth, backend, []common.Address{signerAddr}, sectionSize, processConfirms, big.NewInt(1)) // bankUser transfers some ether to user1 nonce, _ := backend.PendingNonceAt(ctx, bankAddr) tx, _ := types.SignTx(types.NewTransaction(nonce, userAddr1, big.NewInt(10000), params.TxGas, nil, nil), signer, bankKey) @@ -227,7 +229,7 @@ func newTestClientHandler(backend *backends.SimulatedBackend, odr *LesOdr, index return client.handler } -func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, peers *clientPeerSet, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) { +func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) { var ( gspec = core.Genesis{ Config: params.AllEthashProtocolChanges, @@ -263,6 +265,7 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da } oracle = checkpointoracle.New(checkpointConfig, getLocal) } + ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup) server := &LesServer{ lesCommons: lesCommons{ genesis: genesis.Hash(), @@ -274,7 +277,8 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da oracle: oracle, closeCh: make(chan struct{}), }, - peers: peers, + ns: ns, + broadcaster: newBroadcaster(ns), servingQueue: newServingQueue(int64(time.Millisecond*10), 1), defParams: flowcontrol.ServerParams{ BufLimit: testBufLimit, @@ -284,13 +288,14 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da } server.costTracker, server.minCapacity = newCostTracker(db, server.config) server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism. 
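
The test changes above and below replace bind.NewKeyedTransactor with its chain-ID-aware variant: the simulated backend now enforces EIP-155 replay-protected signatures, and its chain config carries ID 1337, hence the explicit big.NewInt(1337) in the tests. A minimal usage sketch (funded key and gas limit are illustrative):

```go
package main

import (
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/accounts/abi/bind"
	"github.com/ethereum/go-ethereum/accounts/abi/bind/backends"
	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	key, _ := crypto.GenerateKey()
	addr := crypto.PubkeyToAddress(key.PublicKey)

	// The simulated backend runs with chain ID 1337.
	backend := backends.NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(1e18)}}, 10000000)
	defer backend.Close()

	// EIP-155 aware transactor; fails fast if the chain ID is nil.
	auth, err := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337))
	if err != nil {
		panic(err)
	}
	fmt.Println("transactor ready for", auth.From.Hex())
}
```
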
- server.clientPool = newClientPool(db, testBufRecharge, defaultConnectedBias, clock, func(id enode.ID) {}) + server.clientPool = newClientPool(ns, db, testBufRecharge, defaultConnectedBias, clock, func(id enode.ID) {}) server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true }) if server.oracle != nil { server.oracle.Start(simulation) } server.servingQueue.setThreads(4) + ns.Start() server.handler.start() return server.handler, simulation } @@ -463,7 +468,7 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba if simClock { clock = &mclock.Simulated{} } - handler, b := newTestServerHandler(blocks, indexers, db, newClientPeerSet(), clock) + handler, b := newTestServerHandler(blocks, indexers, db, clock) var peer *testPeer if newPeer { @@ -502,7 +507,7 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, ulcServers []string, ulcFraction int, simClock bool, connect bool, disablePruning bool) (*testServer, *testClient, func()) { sdb, cdb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase() - speers, cpeers := newServerPeerSet(), newClientPeerSet() + speers := newServerPeerSet() var clock mclock.Clock = &mclock.System{} if simClock { @@ -519,7 +524,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer ccIndexer, cbIndexer, cbtIndexer := cIndexers[0], cIndexers[1], cIndexers[2] odr.SetIndexers(ccIndexer, cbIndexer, cbtIndexer) - server, b := newTestServerHandler(blocks, sindexers, sdb, cpeers, clock) + server, b := newTestServerHandler(blocks, sindexers, sdb, clock) client := newTestClientHandler(b, odr, cIndexers, cdb, speers, ulcServers, ulcFraction) scIndexer.Start(server.blockchain) diff --git a/les/txrelay.go b/les/txrelay.go index 4f6c15025e78..57f2412eba5d 100644 --- a/les/txrelay.go +++ b/les/txrelay.go @@ -35,7 +35,7 @@ type lesTxRelay struct { txPending map[common.Hash]struct{} peerList []*serverPeer peerStartPos int - lock sync.RWMutex + lock sync.Mutex stop chan struct{} retriever *retrieveManager diff --git a/les/utils/expiredvalue.go b/les/utils/expiredvalue.go index 55e82cee4819..3fd52616fac5 100644 --- a/les/utils/expiredvalue.go +++ b/les/utils/expiredvalue.go @@ -86,10 +86,15 @@ func (e *ExpiredValue) Add(amount int64, logOffset Fixed64) int64 { e.Exp = integer } if base >= 0 || uint64(-base) <= e.Base { - // This is a temporary fix to circumvent a golang - // uint conversion issue on arm64, which needs to - // be investigated further. FIXME - e.Base = uint64(int64(e.Base) + int64(base)) + // The conversion from negative float64 to + // uint64 is undefined in golang, and doesn't + // work with ARMv8. 
More details at: + // https://github.com/golang/go/issues/43047 + if base >= 0 { + e.Base += uint64(base) + } else { + e.Base -= uint64(-base) + } return amount } net := int64(-float64(e.Base) / factor) diff --git a/les/utils/weighted_select.go b/les/utils/weighted_select.go index d6db3c0e657c..9b4413afb564 100644 --- a/les/utils/weighted_select.go +++ b/les/utils/weighted_select.go @@ -17,7 +17,10 @@ package utils import ( + "math" "math/rand" + + "github.com/ethereum/go-ethereum/log" ) type ( @@ -54,6 +57,14 @@ func (w *WeightedRandomSelect) IsEmpty() bool { // setWeight sets an item's weight to a specific value (removes it if zero) func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) { + if weight > math.MaxInt64-w.root.sumWeight { + // old weight is still included in sumWeight, remove and check again + w.setWeight(item, 0) + if weight > math.MaxInt64-w.root.sumWeight { + log.Error("WeightedRandomSelect overflow", "sumWeight", w.root.sumWeight, "new weight", weight) + weight = math.MaxInt64 - w.root.sumWeight + } + } idx, ok := w.idx[item] if ok { w.root.setWeight(idx, weight) diff --git a/light/lightchain.go b/light/lightchain.go index 6fc321ae0b57..ca6fbfac49de 100644 --- a/light/lightchain.go +++ b/light/lightchain.go @@ -396,24 +396,26 @@ func (lc *LightChain) InsertHeaderChain(chain []*types.Header, checkFreq int) (i lc.wg.Add(1) defer lc.wg.Done() - var events []interface{} - whFunc := func(header *types.Header) error { - status, err := lc.hc.WriteHeader(header) - - switch status { - case core.CanonStatTy: - log.Debug("Inserted new header", "number", header.Number, "hash", header.Hash()) - events = append(events, core.ChainEvent{Block: types.NewBlockWithHeader(header), Hash: header.Hash()}) - - case core.SideStatTy: - log.Debug("Inserted forked header", "number", header.Number, "hash", header.Hash()) - events = append(events, core.ChainSideEvent{Block: types.NewBlockWithHeader(header)}) - } - return err + status, err := lc.hc.InsertHeaderChain(chain, start) + if err != nil || len(chain) == 0 { + return 0, err + } + + // Create chain event for the new head block of this insertion. + var ( + events = make([]interface{}, 0, 1) + lastHeader = chain[len(chain)-1] + block = types.NewBlockWithHeader(lastHeader) + ) + switch status { + case core.CanonStatTy: + events = append(events, core.ChainEvent{Block: block, Hash: block.Hash()}) + case core.SideStatTy: + events = append(events, core.ChainSideEvent{Block: block}) } - i, err := lc.hc.InsertHeaderChain(chain, whFunc, start) lc.postChainEvents(events) - return i, err + + return 0, err } // CurrentHeader retrieves the current head header of the canonical chain. 
The diff --git a/light/lightchain_test.go b/light/lightchain_test.go index 70d2e70c189d..2aed08d74edf 100644 --- a/light/lightchain_test.go +++ b/light/lightchain_test.go @@ -18,6 +18,7 @@ package light import ( "context" + "errors" "math/big" "testing" @@ -321,7 +322,7 @@ func TestBadHeaderHashes(t *testing.T) { var err error headers := makeHeaderChainWithDiff(bc.genesisBlock, []int{1, 2, 4}, 10) core.BadHashes[headers[2].Hash()] = true - if _, err = bc.InsertHeaderChain(headers, 1); err != core.ErrBlacklistedHash { + if _, err = bc.InsertHeaderChain(headers, 1); !errors.Is(err, core.ErrBlacklistedHash) { t.Errorf("error mismatch: have: %v, want %v", err, core.ErrBlacklistedHash) } } diff --git a/light/odr.go b/light/odr.go index 0b854b0b6c9e..bb243f915215 100644 --- a/light/odr.go +++ b/light/odr.go @@ -133,10 +133,8 @@ func (req *ReceiptsRequest) StoreResult(db ethdb.Database) { } } -// ChtRequest is the ODR request type for state/storage trie entries +// ChtRequest is the ODR request type for retrieving header by Canonical Hash Trie type ChtRequest struct { - Untrusted bool // Indicator whether the result retrieved is trusted or not - PeerId string // The specified peer id from which to retrieve data. Config *IndexerConfig ChtNum, BlockNum uint64 ChtRoot common.Hash @@ -148,12 +146,9 @@ type ChtRequest struct { // StoreResult stores the retrieved data in local database func (req *ChtRequest) StoreResult(db ethdb.Database) { hash, num := req.Header.Hash(), req.Header.Number.Uint64() - - if !req.Untrusted { - rawdb.WriteHeader(db, req.Header) - rawdb.WriteTd(db, hash, num, req.Td) - rawdb.WriteCanonicalHash(db, hash, num) - } + rawdb.WriteHeader(db, req.Header) + rawdb.WriteTd(db, hash, num, req.Td) + rawdb.WriteCanonicalHash(db, hash, num) } // BloomRequest is the ODR request type for retrieving bloom filters from a CHT structure diff --git a/light/odr_test.go b/light/odr_test.go index 5f7f4d96cb1e..cb22334fdc9f 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -195,8 +195,9 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain // Perform read-only call. st.SetBalance(testBankAddress, math.MaxBig256) msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, new(big.Int), data, false)} - context := core.NewEVMContext(msg, header, chain, nil) - vmenv := vm.NewEVM(context, st, config, vm.Config{}) + txContext := core.NewEVMTxContext(msg) + context := core.NewEVMBlockContext(header, chain, nil) + vmenv := vm.NewEVM(context, txContext, st, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) result, _ := core.ApplyMessage(vmenv, msg, gp) res = append(res, result.Return()...) diff --git a/light/odr_util.go b/light/odr_util.go index aec0c7b69f7d..ec2d1e64913d 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -19,20 +19,23 @@ package light import ( "bytes" "context" + "errors" "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" ) -var sha3Nil = crypto.Keccak256Hash(nil) +// errNonCanonicalHash is returned if the requested chain data doesn't belong +// to the canonical chain. ODR can only retrieve the canonical chain data covered +// by the CHT or Bloom trie for verification. 
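
The errNonCanonicalHash sentinel declared just below guards every ODR lookup keyed by (hash, number): data is only trustworthy if the CHT-proven header for that number actually carries the requested hash. A self-contained sketch of the invariant; the fetcher stub stands in for light.GetHeaderByNumber and is not part of this patch:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"math/big"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
)

var errNonCanonicalHash = errors.New("hash is not currently canonical")

// headerByNumber stands in for light.GetHeaderByNumber: in the real code the
// returned header is proven against the local CHT.
type headerByNumber func(ctx context.Context, number uint64) (*types.Header, error)

// tdByHash has the shape of the fixed GetTd: fetch the canonical header for
// the number, then insist its hash matches the one requested, because the CHT
// can only vouch for canonical chain data.
func tdByHash(ctx context.Context, fetch headerByNumber, hash common.Hash, number uint64) (*big.Int, error) {
	header, err := fetch(ctx, number)
	if err != nil {
		return nil, err
	}
	if header.Hash() != hash {
		return nil, errNonCanonicalHash // e.g. an uncle or stale fork block was requested
	}
	return header.Difficulty, nil // stand-in for the stored td lookup
}

func main() {
	canonical := &types.Header{Number: big.NewInt(1), Difficulty: big.NewInt(2)}
	fetch := func(ctx context.Context, number uint64) (*types.Header, error) { return canonical, nil }

	if _, err := tdByHash(context.Background(), fetch, common.HexToHash("deadbeef"), 1); err != nil {
		fmt.Println("rejected:", err)
	}
	if td, err := tdByHash(context.Background(), fetch, canonical.Hash(), 1); err == nil {
		fmt.Println("td:", td)
	}
}
```
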
+var errNonCanonicalHash = errors.New("hash is not currently canonical")
 
 // GetHeaderByNumber retrieves the canonical block header corresponding to the
-// given number.
+// given number. The returned header is proven by local CHT.
 func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*types.Header, error) {
 	// Try to find it in the local database first.
 	db := odr.Database()
@@ -63,25 +66,6 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ
 	return r.Header, nil
 }
 
-// GetUntrustedHeaderByNumber retrieves specified block header without
-// correctness checking. Note this function should only be used in light
-// client checkpoint syncing.
-func GetUntrustedHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64, peerId string) (*types.Header, error) {
-	// todo(rjl493456442) it's a hack to retrieve headers which is not covered
-	// by CHT. Fix it in LES4
-	r := &ChtRequest{
-		BlockNum:  number,
-		ChtNum:    number / odr.IndexerConfig().ChtSize,
-		Untrusted: true,
-		PeerId:    peerId,
-		Config:    odr.IndexerConfig(),
-	}
-	if err := odr.Retrieve(ctx, r); err != nil {
-		return nil, err
-	}
-	return r.Header, nil
-}
-
 // GetCanonicalHash retrieves the canonical block hash corresponding to the number.
 func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (common.Hash, error) {
 	hash := rawdb.ReadCanonicalHash(odr.Database(), number)
@@ -102,10 +86,13 @@ func GetTd(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64)
 	if td != nil {
 		return td, nil
 	}
-	_, err := GetHeaderByNumber(ctx, odr, number)
+	header, err := GetHeaderByNumber(ctx, odr, number)
 	if err != nil {
 		return nil, err
 	}
+	if header.Hash() != hash {
+		return nil, errNonCanonicalHash
+	}
 	// <hash, number> -> td mapping already be stored in db, get it.
return rawdb.ReadTd(odr.Database(), hash, number), nil } @@ -120,6 +107,9 @@ func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number ui if err != nil { return nil, errNoHeader } + if header.Hash() != hash { + return nil, errNonCanonicalHash + } r := &BlockRequest{Hash: hash, Number: number, Header: header} if err := odr.Retrieve(ctx, r); err != nil { return nil, err @@ -167,6 +157,9 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num if err != nil { return nil, errNoHeader } + if header.Hash() != hash { + return nil, errNonCanonicalHash + } r := &ReceiptsRequest{Hash: hash, Number: number, Header: header} if err := odr.Retrieve(ctx, r); err != nil { return nil, err diff --git a/light/postprocess.go b/light/postprocess.go index de207ad4a3bf..891c8a5869db 100644 --- a/light/postprocess.go +++ b/light/postprocess.go @@ -147,7 +147,7 @@ func NewChtIndexer(db ethdb.Database, odr OdrBackend, size, confirms uint64, dis diskdb: db, odr: odr, trieTable: trieTable, - triedb: trie.NewDatabaseWithCache(trieTable, 1, ""), // Use a tiny cache only to keep memory down + triedb: trie.NewDatabaseWithConfig(trieTable, &trie.Config{Cache: 1}), // Use a tiny cache only to keep memory down trieset: mapset.NewSet(), sectionSize: size, disablePruning: disablePruning, @@ -340,7 +340,7 @@ func NewBloomTrieIndexer(db ethdb.Database, odr OdrBackend, parentSize, size uin diskdb: db, odr: odr, trieTable: trieTable, - triedb: trie.NewDatabaseWithCache(trieTable, 1, ""), // Use a tiny cache only to keep memory down + triedb: trie.NewDatabaseWithConfig(trieTable, &trie.Config{Cache: 1}), // Use a tiny cache only to keep memory down trieset: mapset.NewSet(), parentSize: parentSize, size: size, diff --git a/light/trie.go b/light/trie.go index 3eb05f4a3fef..0516b944860d 100644 --- a/light/trie.go +++ b/light/trie.go @@ -30,6 +30,10 @@ import ( "github.com/ethereum/go-ethereum/trie" ) +var ( + sha3Nil = crypto.Keccak256Hash(nil) +) + func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB { state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr), nil) return state diff --git a/metrics/cpu_enabled.go b/metrics/cpu_enabled.go index 52a3c2e96676..02192928b70c 100644 --- a/metrics/cpu_enabled.go +++ b/metrics/cpu_enabled.go @@ -31,6 +31,10 @@ func ReadCPUStats(stats *CPUStats) { log.Error("Could not read cpu stats", "err", err) return } + if len(timeStats) == 0 { + log.Error("Empty cpu stats") + return + } // requesting all cpu times will always return an array with only one time stats entry timeStat := timeStats[0] stats.GlobalTime = int64((timeStat.User + timeStat.Nice + timeStat.System) * cpu.ClocksPerSec) diff --git a/miner/miner.go b/miner/miner.go index 8cbd70b424b5..20169f5007e2 100644 --- a/miner/miner.go +++ b/miner/miner.go @@ -85,15 +85,22 @@ func New(eth Backend, config *Config, chainConfig *params.ChainConfig, mux *even // and halt your mining operation for as long as the DOS continues. 
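
The reworked update loop below unsubscribes after the first DoneEvent and, when the subscription channel closes, parks it on nil so the select keeps serving the start/stop channels. A standalone sketch of that event.TypeMux pattern; the event names and timeout are illustrative:

```go
package main

import (
	"fmt"
	"time"

	"github.com/ethereum/go-ethereum/event"
)

type startEvent struct{}
type doneEvent struct{}

func main() {
	mux := new(event.TypeMux)
	events := mux.Subscribe(startEvent{}, doneEvent{})
	ch := events.Chan()

	go func() {
		mux.Post(startEvent{})
		mux.Post(doneEvent{})
	}()

	for timeout := time.After(100 * time.Millisecond); ; {
		select {
		case ev := <-ch:
			if ev == nil {
				// Chan() is closed once Unsubscribe has run; a nil channel
				// blocks forever, disabling this case without leaving the loop.
				ch = nil
				continue
			}
			switch ev.Data.(type) {
			case startEvent:
				fmt.Println("sync started: pause mining")
			case doneEvent:
				fmt.Println("sync done: resume mining for good")
				events.Unsubscribe() // one-shot, as in miner.update after DoneEvent
			}
		case <-timeout:
			return
		}
	}
}
```
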
func (miner *Miner) update() { events := miner.mux.Subscribe(downloader.StartEvent{}, downloader.DoneEvent{}, downloader.FailedEvent{}) - defer events.Unsubscribe() + defer func() { + if !events.Closed() { + events.Unsubscribe() + } + }() shouldStart := false canStart := true + dlEventCh := events.Chan() for { select { - case ev := <-events.Chan(): + case ev := <-dlEventCh: if ev == nil { - return + // Unsubscription done, stop listening + dlEventCh = nil + continue } switch ev.Data.(type) { case downloader.StartEvent: @@ -105,16 +112,24 @@ func (miner *Miner) update() { shouldStart = true log.Info("Mining aborted due to sync") } - case downloader.DoneEvent, downloader.FailedEvent: + case downloader.FailedEvent: + canStart = true + if shouldStart { + miner.SetEtherbase(miner.coinbase) + miner.worker.start() + } + case downloader.DoneEvent: canStart = true if shouldStart { miner.SetEtherbase(miner.coinbase) miner.worker.start() } + // Stop reacting to downloader events + events.Unsubscribe() } case addr := <-miner.startCh: + miner.SetEtherbase(addr) if canStart { - miner.SetEtherbase(addr) miner.worker.start() } shouldStart = true diff --git a/miner/miner_test.go b/miner/miner_test.go index 2ed03a2397a9..127b4c7687ae 100644 --- a/miner/miner_test.go +++ b/miner/miner_test.go @@ -22,7 +22,7 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/consensus/clique" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" @@ -31,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/event" - "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/trie" ) @@ -89,12 +88,75 @@ func TestMiner(t *testing.T) { // Stop the downloader and wait for the update loop to run mux.Post(downloader.DoneEvent{}) waitForMiningState(t, miner, true) - // Start the downloader and wait for the update loop to run + + // Subsequent downloader events after a successful DoneEvent should not cause the + // miner to start or stop. This prevents a security vulnerability + // that would allow entities to present fake high blocks that would + // stop mining operations by causing a downloader sync + // until it was discovered they were invalid, whereon mining would resume. + mux.Post(downloader.StartEvent{}) + waitForMiningState(t, miner, true) + + mux.Post(downloader.FailedEvent{}) + waitForMiningState(t, miner, true) +} + +// TestMinerDownloaderFirstFails tests that mining is only +// permitted to run indefinitely once the downloader sees a DoneEvent (success). +// An initial FailedEvent should allow mining to stop on a subsequent +// downloader StartEvent. +func TestMinerDownloaderFirstFails(t *testing.T) { + miner, mux := createMiner(t) + miner.Start(common.HexToAddress("0x12345")) + waitForMiningState(t, miner, true) + // Start the downloader mux.Post(downloader.StartEvent{}) waitForMiningState(t, miner, false) + // Stop the downloader and wait for the update loop to run mux.Post(downloader.FailedEvent{}) waitForMiningState(t, miner, true) + + // Since the downloader hasn't yet emitted a successful DoneEvent, + // we expect the miner to stop on next StartEvent. + mux.Post(downloader.StartEvent{}) + waitForMiningState(t, miner, false) + + // Downloader finally succeeds. 
+ mux.Post(downloader.DoneEvent{}) + waitForMiningState(t, miner, true) + + // Downloader starts again. + // Since it has achieved a DoneEvent once, we expect miner + // state to be unchanged. + mux.Post(downloader.StartEvent{}) + waitForMiningState(t, miner, true) + + mux.Post(downloader.FailedEvent{}) + waitForMiningState(t, miner, true) +} + +func TestMinerStartStopAfterDownloaderEvents(t *testing.T) { + miner, mux := createMiner(t) + + miner.Start(common.HexToAddress("0x12345")) + waitForMiningState(t, miner, true) + // Start the downloader + mux.Post(downloader.StartEvent{}) + waitForMiningState(t, miner, false) + + // Downloader finally succeeds. + mux.Post(downloader.DoneEvent{}) + waitForMiningState(t, miner, true) + + miner.Stop() + waitForMiningState(t, miner, false) + + miner.Start(common.HexToAddress("0x678910")) + waitForMiningState(t, miner, true) + + miner.Stop() + waitForMiningState(t, miner, false) } func TestStartWhileDownload(t *testing.T) { @@ -129,6 +191,28 @@ func TestCloseMiner(t *testing.T) { waitForMiningState(t, miner, false) } +// TestMinerSetEtherbase checks that etherbase becomes set even if mining isn't +// possible at the moment +func TestMinerSetEtherbase(t *testing.T) { + miner, mux := createMiner(t) + // Start with a 'bad' mining address + miner.Start(common.HexToAddress("0xdead")) + waitForMiningState(t, miner, true) + // Start the downloader + mux.Post(downloader.StartEvent{}) + waitForMiningState(t, miner, false) + // Now user tries to configure proper mining address + miner.Start(common.HexToAddress("0x1337")) + // Stop the downloader and wait for the update loop to run + mux.Post(downloader.DoneEvent{}) + + waitForMiningState(t, miner, true) + // The miner should now be using the good address + if got, exp := miner.coinbase, common.HexToAddress("0x1337"); got != exp { + t.Fatalf("Wrong coinbase, got %x expected %x", got, exp) + } +} + // waitForMiningState waits until either // * the desired mining state was reached // * a timeout was reached which fails the test @@ -137,10 +221,10 @@ func waitForMiningState(t *testing.T, m *Miner, mining bool) { var state bool for i := 0; i < 100; i++ { + time.Sleep(10 * time.Millisecond) if state = m.Mining(); state == mining { return } - time.Sleep(10 * time.Millisecond) } t.Fatalf("Mining() == %t, want %t", state, mining) } @@ -158,26 +242,20 @@ func createMiner(t *testing.T) (*Miner, *event.TypeMux) { if err != nil { t.Fatalf("can't create new chain config: %v", err) } - // Create event Mux - mux := new(event.TypeMux) // Create consensus engine - engine := ethash.New(ethash.Config{}, []string{}, false) - engine.SetThreads(-1) - // Create isLocalBlock - isLocalBlock := func(block *types.Block) bool { - return true - } + engine := clique.New(chainConfig.Clique, chainDB) // Create Ethereum backend - limit := uint64(1000) - bc, err := core.NewBlockChain(chainDB, new(core.CacheConfig), chainConfig, engine, vm.Config{}, isLocalBlock, &limit) + bc, err := core.NewBlockChain(chainDB, nil, chainConfig, engine, vm.Config{}, nil, nil) if err != nil { t.Fatalf("can't create new chain %v", err) } - statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + statedb, _ := state.New(common.Hash{}, state.NewDatabase(chainDB), nil) blockchain := &testBlockChain{statedb, 10000000, new(event.Feed)} - pool := core.NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) + pool := core.NewTxPool(testTxPoolConfig, chainConfig, blockchain) backend := NewMockBackend(bc, pool) + // Create event Mux 
+	mux := new(event.TypeMux)
 	// Create Miner
-	return New(backend, &config, chainConfig, mux, engine, isLocalBlock), mux
+	return New(backend, &config, chainConfig, mux, engine, nil), mux
 }
diff --git a/miner/unconfirmed.go b/miner/unconfirmed.go
index 3a176e8bd6f6..0489f1ea4adb 100644
--- a/miner/unconfirmed.go
+++ b/miner/unconfirmed.go
@@ -50,7 +50,7 @@ type unconfirmedBlocks struct {
 	chain chainRetriever // Blockchain to verify canonical status through
 	depth uint // Depth after which to discard previous blocks
 	blocks *ring.Ring // Block infos to allow canonical chain cross checks
-	lock sync.RWMutex // Protects the fields from concurrent access
+	lock sync.Mutex // Protects the fields from concurrent access
 }
 
 // newUnconfirmedBlocks returns new data structure to track currently unconfirmed blocks.
diff --git a/miner/unconfirmed_test.go b/miner/unconfirmed_test.go
index 42e77f3e648c..dc83cb92652d 100644
--- a/miner/unconfirmed_test.go
+++ b/miner/unconfirmed_test.go
@@ -19,7 +19,6 @@ package miner
 import (
 	"testing"
 
-	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 )
 
@@ -43,7 +42,7 @@ func TestUnconfirmedInsertBounds(t *testing.T) {
 	for depth := uint64(0); depth < 2*uint64(limit); depth++ {
 		// Insert multiple blocks for the same level just to stress it
 		for i := 0; i < int(depth); i++ {
-			pool.Insert(depth, common.Hash([32]byte{byte(depth), byte(i)}))
+			pool.Insert(depth, [32]byte{byte(depth), byte(i)})
 		}
 		// Validate that no blocks below the depth allowance are left in
 		pool.blocks.Do(func(block interface{}) {
@@ -63,7 +62,7 @@ func TestUnconfirmedShifts(t *testing.T) {
 	pool := newUnconfirmedBlocks(new(noopChainRetriever), limit)
 	for depth := start; depth < start+uint64(limit); depth++ {
-		pool.Insert(depth, common.Hash([32]byte{byte(depth)}))
+		pool.Insert(depth, [32]byte{byte(depth)})
 	}
 	// Try to shift below the limit and ensure no blocks are dropped
 	pool.Shift(start + uint64(limit) - 1)
diff --git a/miner/worker.go b/miner/worker.go
index 16f4c1c31393..5f07affdc412 100644
--- a/miner/worker.go
+++ b/miner/worker.go
@@ -347,7 +347,11 @@ func (w *worker) newWorkLoop(recommit time.Duration) {
 			atomic.StoreInt32(interrupt, s)
 		}
 		interrupt = new(int32)
-		w.newWorkCh <- &newWorkReq{interrupt: interrupt, noempty: noempty, timestamp: timestamp}
+		select {
+		case w.newWorkCh <- &newWorkReq{interrupt: interrupt, noempty: noempty, timestamp: timestamp}:
+		case <-w.exitCh:
+			return
+		}
 		timer.Reset(recommit)
 		atomic.StoreInt32(&w.newTxs, 0)
 	}
@@ -793,23 +797,23 @@ func (w *worker) commitTransactions(txs *types.TransactionsByPriceAndNonce, coin
 		w.current.state.Prepare(tx.Hash(), common.Hash{}, w.current.tcount)
 
 		logs, err := w.commitTransaction(tx, coinbase)
-		switch err {
-		case core.ErrGasLimitReached:
+		switch {
+		case errors.Is(err, core.ErrGasLimitReached):
 			// Pop the current out-of-gas transaction without shifting in the next from the account
 			log.Trace("Gas limit exceeded for current block", "sender", from)
 			txs.Pop()
 
-		case core.ErrNonceTooLow:
+		case errors.Is(err, core.ErrNonceTooLow):
 			// New head notification data race between the transaction pool and miner, shift
 			log.Trace("Skipping transaction with low nonce", "sender", from, "nonce", tx.Nonce())
 			txs.Shift()
 
-		case core.ErrNonceTooHigh:
+		case errors.Is(err, core.ErrNonceTooHigh):
 			// Reorg notification data race between the transaction pool and miner, skip account
 			log.Trace("Skipping account with high nonce", "sender", from, "nonce", tx.Nonce())
 			txs.Pop()
 
-		case nil:
+		case errors.Is(err, nil):
 			//
Everything ok, collect the logs and shift in the next transaction from the same account coalescedLogs = append(coalescedLogs, logs...) w.current.tcount++ diff --git a/mobile/big.go b/mobile/big.go index 86ea93245aab..c08bcf93f285 100644 --- a/mobile/big.go +++ b/mobile/big.go @@ -35,6 +35,16 @@ func NewBigInt(x int64) *BigInt { return &BigInt{big.NewInt(x)} } +// NewBigIntFromString allocates and returns a new BigInt set to x +// interpreted in the provided base. +func NewBigIntFromString(x string, base int) *BigInt { + b, success := new(big.Int).SetString(x, base) + if !success { + return nil + } + return &BigInt{b} +} + // GetBytes returns the absolute value of x as a big-endian byte slice. func (bi *BigInt) GetBytes() []byte { return bi.bigint.Bytes() diff --git a/mobile/bind.go b/mobile/bind.go index f64b37ec1580..e32d864aa58a 100644 --- a/mobile/bind.go +++ b/mobile/bind.go @@ -40,7 +40,7 @@ type MobileSigner struct { } func (s *MobileSigner) Sign(addr *Address, unsignedTx *Transaction) (signedTx *Transaction, _ error) { - sig, err := s.sign(types.EIP155Signer{}, addr.address, unsignedTx.tx) + sig, err := s.sign(addr.address, unsignedTx.tx) if err != nil { return nil, err } @@ -82,12 +82,16 @@ func NewTransactOpts() *TransactOpts { // NewKeyedTransactOpts is a utility method to easily create a transaction signer // from a single private key. -func NewKeyedTransactOpts(keyJson []byte, passphrase string) (*TransactOpts, error) { +func NewKeyedTransactOpts(keyJson []byte, passphrase string, chainID *big.Int) (*TransactOpts, error) { key, err := keystore.DecryptKey(keyJson, passphrase) if err != nil { return nil, err } - return &TransactOpts{*bind.NewKeyedTransactor(key.PrivateKey)}, nil + auth, err := bind.NewKeyedTransactorWithChainID(key.PrivateKey, chainID) + if err != nil { + return nil, err + } + return &TransactOpts{*auth}, nil } func (opts *TransactOpts) GetFrom() *Address { return &Address{opts.opts.From} } @@ -106,7 +110,7 @@ func (opts *TransactOpts) GetGasLimit() int64 { return int64(opts.opts.GasLimi func (opts *TransactOpts) SetFrom(from *Address) { opts.opts.From = from.address } func (opts *TransactOpts) SetNonce(nonce int64) { opts.opts.Nonce = big.NewInt(nonce) } func (opts *TransactOpts) SetSigner(s Signer) { - opts.opts.Signer = func(signer types.Signer, addr common.Address, tx *types.Transaction) (*types.Transaction, error) { + opts.opts.Signer = func(addr common.Address, tx *types.Transaction) (*types.Transaction, error) { sig, err := s.Sign(&Address{addr}, &Transaction{tx}) if err != nil { return nil, err @@ -171,20 +175,12 @@ func (c *BoundContract) GetDeployer() *Transaction { // Call invokes the (constant) contract method with params as input values and // sets the output to result. 
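The rewritten Call below drops the old single-output special case and always round-trips through a plain []interface{} slice, so one code path serves both single and multiple return values. A small self-contained sketch of that copy-in/copy-out pattern (callIntoSlice and the trimmed-down Interfaces type are hypothetical stand-ins, not the real bind API):

package main

import "fmt"

// callIntoSlice stands in for the underlying contract call, which
// unpacks outputs into the slice behind the *[]interface{} pointer.
func callIntoSlice(results *[]interface{}) error {
	(*results)[0] = "hello"
	if len(*results) > 1 {
		(*results)[1] = uint64(42)
	}
	return nil
}

// Interfaces mirrors the mobile wrapper's opaque slice-of-values type.
type Interfaces struct{ objects []interface{} }

// Call copies the caller's placeholders into a fresh []interface{},
// lets the callee fill it, and copies the results back out.
func (i *Interfaces) Call() error {
	results := make([]interface{}, len(i.objects))
	copy(results, i.objects)
	if err := callIntoSlice(&results); err != nil {
		return err
	}
	copy(i.objects, results)
	return nil
}

func main() {
	out := &Interfaces{objects: make([]interface{}, 2)}
	if err := out.Call(); err != nil {
		panic(err)
	}
	fmt.Println(out.objects...) // hello 42
}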
func (c *BoundContract) Call(opts *CallOpts, out *Interfaces, method string, args *Interfaces) error { - if len(out.objects) == 1 { - result := out.objects[0] - if err := c.contract.Call(&opts.opts, result, method, args.objects...); err != nil { - return err - } - out.objects[0] = result - } else { - results := make([]interface{}, len(out.objects)) - copy(results, out.objects) - if err := c.contract.Call(&opts.opts, &results, method, args.objects...); err != nil { - return err - } - copy(out.objects, results) + results := make([]interface{}, len(out.objects)) + copy(results, out.objects) + if err := c.contract.Call(&opts.opts, &results, method, args.objects...); err != nil { + return err } + copy(out.objects, results) return nil } diff --git a/mobile/common.go b/mobile/common.go index d7e0457261a0..124712b4b1f2 100644 --- a/mobile/common.go +++ b/mobile/common.go @@ -87,6 +87,11 @@ func (h *Hash) GetHex() string { return h.hash.Hex() } +// String implements Stringer interface for printable representation of the hash. +func (h *Hash) String() string { + return h.GetHex() +} + // Hashes represents a slice of hashes. type Hashes struct{ hashes []common.Hash } @@ -188,6 +193,11 @@ func (a *Address) GetHex() string { return a.address.Hex() } +// String returns a printable representation of the address. +func (a *Address) String() string { + return a.GetHex() +} + // Addresses represents a slice of addresses. type Addresses struct{ addresses []common.Address } diff --git a/mobile/ethereum.go b/mobile/ethereum.go index 59da85239744..97c46ddca71f 100644 --- a/mobile/ethereum.go +++ b/mobile/ethereum.go @@ -21,7 +21,7 @@ package geth import ( "errors" - ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" ) diff --git a/mobile/geth.go b/mobile/geth.go index ba58507d634c..b561e33675df 100644 --- a/mobile/geth.go +++ b/mobile/geth.go @@ -90,6 +90,22 @@ func NewNodeConfig() *NodeConfig { return &config } +// AddBootstrapNode adds an additional bootstrap node to the node config. +func (conf *NodeConfig) AddBootstrapNode(node *Enode) { + conf.BootstrapNodes.Append(node) +} + +// EncodeJSON encodes a NodeConfig into a JSON data dump. +func (conf *NodeConfig) EncodeJSON() (string, error) { + data, err := json.Marshal(conf) + return string(data), err +} + +// String returns a printable representation of the node config. +func (conf *NodeConfig) String() string { + return encodeOrError(conf) +} + // Node represents a Geth Ethereum node instance. type Node struct { node *node.Node diff --git a/mobile/types.go b/mobile/types.go index 9d7552028206..de6457e7e17c 100644 --- a/mobile/types.go +++ b/mobile/types.go @@ -28,6 +28,20 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) +type jsonEncoder interface { + EncodeJSON() (string, error) +} + +// encodeOrError tries to encode the object into json. +// If the encoding fails the resulting error is returned. +func encodeOrError(encoder jsonEncoder) string { + enc, err := encoder.EncodeJSON() + if err != nil { + return err.Error() + } + return enc +} + // A Nonce is a 64-bit hash which proves (combined with the mix-hash) that // a sufficient amount of computation has been carried out on a block. type Nonce struct { @@ -44,6 +58,11 @@ func (n *Nonce) GetHex() string { return fmt.Sprintf("0x%x", n.nonce[:]) } +// String returns a printable representation of the nonce. +func (n *Nonce) String() string { + return n.GetHex() +} + // Bloom represents a 256 bit bloom filter. 
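All of the String methods added to these mobile wrapper types funnel through the new encodeOrError helper, so a value prints as its JSON encoding and degrades to the error text if encoding fails. A compact standalone sketch of the pattern (the header type here is a hypothetical stand-in for wrappers like Header or Block):

package main

import (
	"encoding/json"
	"fmt"
)

// jsonEncoder mirrors the interface introduced in mobile/types.go.
type jsonEncoder interface {
	EncodeJSON() (string, error)
}

// encodeOrError returns the JSON encoding, or the error text on failure.
func encodeOrError(enc jsonEncoder) string {
	s, err := enc.EncodeJSON()
	if err != nil {
		return err.Error()
	}
	return s
}

type header struct{ Number uint64 }

func (h *header) EncodeJSON() (string, error) {
	data, err := json.Marshal(h)
	return string(data), err
}

// String makes the type printable via %v and from mobile bindings.
func (h *header) String() string { return encodeOrError(h) }

func main() {
	fmt.Println(&header{Number: 1}) // {"Number":1}
}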
type Bloom struct { bloom types.Bloom @@ -59,6 +78,11 @@ func (b *Bloom) GetHex() string { return fmt.Sprintf("0x%x", b.bloom[:]) } +// String returns a printable representation of the bloom filter. +func (b *Bloom) String() string { + return b.GetHex() +} + // Header represents a block header in the Ethereum blockchain. type Header struct { header *types.Header @@ -97,6 +121,11 @@ func (h *Header) EncodeJSON() (string, error) { return string(data), err } +// String returns a printable representation of the header. +func (h *Header) String() string { + return encodeOrError(h) +} + func (h *Header) GetParentHash() *Hash { return &Hash{h.header.ParentHash} } func (h *Header) GetUncleHash() *Hash { return &Hash{h.header.UncleHash} } func (h *Header) GetCoinbase() *Address { return &Address{h.header.Coinbase} } @@ -168,6 +197,11 @@ func (b *Block) EncodeJSON() (string, error) { return string(data), err } +// String returns a printable representation of the block. +func (b *Block) String() string { + return encodeOrError(b) +} + func (b *Block) GetParentHash() *Hash { return &Hash{b.block.ParentHash()} } func (b *Block) GetUncleHash() *Hash { return &Hash{b.block.UncleHash()} } func (b *Block) GetCoinbase() *Address { return &Address{b.block.Coinbase()} } @@ -244,6 +278,11 @@ func (tx *Transaction) EncodeJSON() (string, error) { return string(data), err } +// String returns a printable representation of the transaction. +func (tx *Transaction) String() string { + return encodeOrError(tx) +} + func (tx *Transaction) GetData() []byte { return tx.tx.Data() } func (tx *Transaction) GetGas() int64 { return int64(tx.tx.Gas()) } func (tx *Transaction) GetGasPrice() *BigInt { return &BigInt{tx.tx.GasPrice()} } @@ -336,6 +375,11 @@ func (r *Receipt) EncodeJSON() (string, error) { return string(data), err } +// String returns a printable representation of the receipt. +func (r *Receipt) String() string { + return encodeOrError(r) +} + func (r *Receipt) GetStatus() int { return int(r.receipt.Status) } func (r *Receipt) GetPostState() []byte { return r.receipt.PostState } func (r *Receipt) GetCumulativeGasUsed() int64 { return int64(r.receipt.CumulativeGasUsed) } diff --git a/node/rpcstack.go b/node/rpcstack.go index caf7e5b7a83d..81e054ec995d 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -313,7 +313,7 @@ func (h *httpServer) wsAllowed() bool { // isWebsocket checks the header of an http request for a websocket upgrade request. 
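The relaxed check in isWebsocket below matters because browsers and proxies routinely send multi-token values such as "Connection: upgrade,keep-alive", which the previous exact-match comparison rejected. A runnable sketch of the difference (isUpgrade is a hypothetical copy, independent of geth's server types):

package main

import (
	"fmt"
	"net/http"
	"strings"
)

// isUpgrade mirrors the fixed check: a substring match on Connection
// instead of exact equality, case-insensitive on both headers.
func isUpgrade(r *http.Request) bool {
	return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" &&
		strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade")
}

func main() {
	r, _ := http.NewRequest("GET", "/", nil)
	r.Header.Set("Upgrade", "websocket")

	r.Header.Set("Connection", "upgrade")
	fmt.Println(isUpgrade(r)) // true either way

	r.Header.Set("Connection", " UPGRADE,keep-alive")
	fmt.Println(isUpgrade(r)) // true with Contains, false with ==
}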
func isWebsocket(r *http.Request) bool {
 	return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" &&
-		strings.ToLower(r.Header.Get("Connection")) == "upgrade"
+		strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade")
 }
 
 // NewHTTPHandlerStack returns wrapped http-related handlers
@@ -448,6 +448,7 @@ func (is *ipcServer) start(apis []rpc.API) error {
 	}
 	listener, srv, err := rpc.StartIPCEndpoint(is.endpoint, apis)
 	if err != nil {
+		is.log.Warn("IPC opening failed", "url", is.endpoint, "error", err)
 		return err
 	}
 	is.log.Info("IPC endpoint opened", "url", is.endpoint)
diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go
index efab5b37dc5a..8267fb2f1ddb 100644
--- a/node/rpcstack_test.go
+++ b/node/rpcstack_test.go
@@ -19,6 +19,7 @@ package node
 import (
 	"bytes"
 	"net/http"
+	"strings"
 	"testing"
 
 	"github.com/ethereum/go-ethereum/internal/testlog"
@@ -52,25 +53,119 @@ func TestVhosts(t *testing.T) {
 	assert.Equal(t, resp2.StatusCode, http.StatusForbidden)
 }
 
+type originTest struct {
+	spec string
+	expOk []string
+	expFail []string
+}
+
+// splitAndTrim splits input separated by a comma
+// and trims excessive white space from the substrings.
+// Copied over from flags.go
+func splitAndTrim(input string) (ret []string) {
+	l := strings.Split(input, ",")
+	for _, r := range l {
+		r = strings.TrimSpace(r)
+		if len(r) > 0 {
+			ret = append(ret, r)
+		}
+	}
+	return ret
+}
+
 // TestWebsocketOrigins makes sure the websocket origins are properly handled on the websocket server.
 func TestWebsocketOrigins(t *testing.T) {
-	srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: []string{"test"}})
-	defer srv.stop()
-
-	dialer := websocket.DefaultDialer
-	_, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{
-		"Content-type": []string{"application/json"},
-		"Sec-WebSocket-Version": []string{"13"},
-		"Origin": []string{"test"},
-	})
-	assert.NoError(t, err)
+	tests := []originTest{
+		{
+			spec: "*", // allow all
+			expOk: []string{"", "http://test", "https://test", "http://test:8540", "https://test:8540",
+				"http://test.com", "https://foo.test", "http://testa", "http://atestb:8540", "https://atestb:8540"},
+		},
+		{
+			spec: "test",
+			expOk: []string{"http://test", "https://test", "http://test:8540", "https://test:8540"},
+			expFail: []string{"http://test.com", "https://foo.test", "http://testa", "http://atestb:8540", "https://atestb:8540"},
+		},
+		// scheme tests
+		{
+			spec: "https://test",
+			expOk: []string{"https://test", "https://test:9999"},
+			expFail: []string{
+				"test", // no scheme, required by spec
+				"http://test", // wrong scheme
+				"http://test.foo", "https://a.test.x", // subdomain variations
+				"http://testx:8540", "https://xtest:8540"},
+		},
+		// ip tests
+		{
+			spec: "https://12.34.56.78",
+			expOk: []string{"https://12.34.56.78", "https://12.34.56.78:8540"},
+			expFail: []string{
+				"http://12.34.56.78", // wrong scheme
+				"http://12.34.56.78:443", // wrong scheme
+				"http://1.12.34.56.78", // wrong 'domain name'
+				"http://12.34.56.78.a", // wrong 'domain name'
+				"https://87.65.43.21", "http://87.65.43.21:8540", "https://87.65.43.21:8540"},
+		},
+		// port tests
+		{
+			spec: "test:8540",
+			expOk: []string{"http://test:8540", "https://test:8540"},
+			expFail: []string{
+				"http://test", "https://test", // spec says port required
+				"http://test:8541", "https://test:8541", // wrong port
+				"http://bad", "https://bad", "http://bad:8540", "https://bad:8540"},
+		},
+		// scheme and port
+		{
+			spec: "https://test:8540",
+			expOk: []string{"https://test:8540"},
+			expFail:
[]string{
+				"https://test", // missing port
+				"http://test", // missing port, + wrong scheme
+				"http://test:8540", // wrong scheme
+				"http://test:8541", "https://test:8541", // wrong port
+				"http://bad", "https://bad", "http://bad:8540", "https://bad:8540"},
+		},
+		// several allowed origins
+		{
+			spec: "localhost,http://127.0.0.1",
+			expOk: []string{"localhost", "http://localhost", "https://localhost:8443",
+				"http://127.0.0.1", "http://127.0.0.1:8080"},
+			expFail: []string{
+				"https://127.0.0.1", // wrong scheme
+				"http://bad", "https://bad", "http://bad:8540", "https://bad:8540"},
+		},
+	}
+	for _, tc := range tests {
+		srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: splitAndTrim(tc.spec)})
+		for _, origin := range tc.expOk {
+			if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err != nil {
+				t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err)
+			}
+		}
+		for _, origin := range tc.expFail {
+			if err := attemptWebsocketConnectionFromOrigin(t, srv, origin); err == nil {
+				t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin)
+			}
+		}
+		srv.stop()
+	}
+}
 
-	_, _, err = dialer.Dial("ws://"+srv.listenAddr(), http.Header{
-		"Content-type": []string{"application/json"},
-		"Sec-WebSocket-Version": []string{"13"},
-		"Origin": []string{"bad"},
-	})
-	assert.Error(t, err)
+// TestIsWebsocket tests if an incoming websocket upgrade request is handled properly.
+func TestIsWebsocket(t *testing.T) {
+	r, _ := http.NewRequest("GET", "/", nil)
+
+	assert.False(t, isWebsocket(r))
+	r.Header.Set("upgrade", "websocket")
+	assert.False(t, isWebsocket(r))
+	r.Header.Set("connection", "upgrade")
+	assert.True(t, isWebsocket(r))
+	r.Header.Set("connection", "upgrade,keep-alive")
+	assert.True(t, isWebsocket(r))
+	r.Header.Set("connection", " UPGRADE,keep-alive")
+	assert.True(t, isWebsocket(r))
 }
 
 func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer {
@@ -88,6 +183,17 @@ func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfi
 	return srv
 }
 
+func attemptWebsocketConnectionFromOrigin(t *testing.T, srv *httpServer, browserOrigin string) error {
+	t.Helper()
+	dialer := websocket.DefaultDialer
+	_, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{
+		"Content-type": []string{"application/json"},
+		"Sec-WebSocket-Version": []string{"13"},
+		"Origin": []string{browserOrigin},
+	})
+	return err
+}
+
 func testRequest(t *testing.T, key, value, host string, srv *httpServer) *http.Response {
 	t.Helper()
 
diff --git a/oss-fuzz.sh b/oss-fuzz.sh
new file mode 100644
index 000000000000..e060ea88e1af
--- /dev/null
+++ b/oss-fuzz.sh
@@ -0,0 +1,74 @@
+#!/bin/bash -eu
+# Copyright 2020 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+# This file is for integration with Google OSS-Fuzz.
+# The following ENV variables are available when executing on OSS-fuzz:
+#
+# /out/ $OUT Directory to store build artifacts (fuzz targets, dictionaries, options files, seed corpus archives).
+# /src/ $SRC Directory to checkout source files.
+# /work/ $WORK Directory to store intermediate files.
+#
+# $CC, $CXX, $CCC The C and C++ compiler binaries.
+# $CFLAGS, $CXXFLAGS C and C++ compiler flags.
+# $LIB_FUZZING_ENGINE C++ compiler argument to link fuzz target against the prebuilt engine library (e.g. libFuzzer).
+
+function compile_fuzzer {
+  path=$SRC/go-ethereum/$1
+  func=$2
+  fuzzer=$3
+  corpusfile="${path}/testdata/${fuzzer}_seed_corpus.zip"
+  echo "Building $fuzzer (expecting corpus at $corpusfile)"
+  (cd $path && \
+    go-fuzz -func $func -o $WORK/$fuzzer.a . && \
+    echo "First stage built OK" && \
+    $CXX $CXXFLAGS $LIB_FUZZING_ENGINE $WORK/$fuzzer.a -o $OUT/$fuzzer && \
+    echo "Second stage built ok" )
+
+  ## Check if there exists a seed corpus file
+  if [ -f $corpusfile ]
+  then
+    cp $corpusfile $OUT/
+    echo "Found seed corpus: $corpusfile"
+  fi
+}
+
+compile_fuzzer common/bitutil Fuzz fuzzBitutilCompress
+compile_fuzzer crypto/bn256 FuzzAdd fuzzBn256Add
+compile_fuzzer crypto/bn256 FuzzMul fuzzBn256Mul
+compile_fuzzer crypto/bn256 FuzzPair fuzzBn256Pair
+compile_fuzzer core/vm/runtime Fuzz fuzzVmRuntime
+compile_fuzzer crypto/blake2b Fuzz fuzzBlake2b
+compile_fuzzer tests/fuzzers/keystore Fuzz fuzzKeystore
+compile_fuzzer tests/fuzzers/txfetcher Fuzz fuzzTxfetcher
+compile_fuzzer tests/fuzzers/rlp Fuzz fuzzRlp
+compile_fuzzer tests/fuzzers/trie Fuzz fuzzTrie
+compile_fuzzer tests/fuzzers/stacktrie Fuzz fuzzStackTrie
+compile_fuzzer tests/fuzzers/difficulty Fuzz fuzzDifficulty
+
+compile_fuzzer tests/fuzzers/bls12381 FuzzG1Add fuzz_g1_add
+compile_fuzzer tests/fuzzers/bls12381 FuzzG1Mul fuzz_g1_mul
+compile_fuzzer tests/fuzzers/bls12381 FuzzG1MultiExp fuzz_g1_multiexp
+compile_fuzzer tests/fuzzers/bls12381 FuzzG2Add fuzz_g2_add
+compile_fuzzer tests/fuzzers/bls12381 FuzzG2Mul fuzz_g2_mul
+compile_fuzzer tests/fuzzers/bls12381 FuzzG2MultiExp fuzz_g2_multiexp
+compile_fuzzer tests/fuzzers/bls12381 FuzzPairing fuzz_pairing
+compile_fuzzer tests/fuzzers/bls12381 FuzzMapG1 fuzz_map_g1
+compile_fuzzer tests/fuzzers/bls12381 FuzzMapG2 fuzz_map_g2
+
+# This doesn't work very well @TODO
+#compile_fuzzer tests/fuzzers/abi Fuzz fuzzAbi
+
diff --git a/p2p/discover/node.go b/p2p/discover/node.go
index e635c64ac902..9ffe101ccff8 100644
--- a/p2p/discover/node.go
+++ b/p2p/discover/node.go
@@ -46,7 +46,10 @@ func encodePubkey(key *ecdsa.PublicKey) encPubkey {
 	return e
 }
 
-func decodePubkey(curve elliptic.Curve, e encPubkey) (*ecdsa.PublicKey, error) {
+func decodePubkey(curve elliptic.Curve, e []byte) (*ecdsa.PublicKey, error) {
+	if len(e) != len(encPubkey{}) {
+		return nil, errors.New("wrong size public key data")
+	}
 	p := &ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)}
 	half := len(e) / 2
 	p.X.SetBytes(e[:half])
diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go
index 44b62e751b41..47a2e7ac3caf 100644
--- a/p2p/discover/table_util_test.go
+++ b/p2p/discover/table_util_test.go
@@ -146,7 +146,6 @@ func (t *pingRecorder) updateRecord(n *enode.Node) {
 func (t *pingRecorder) Self() *enode.Node { return nullNode }
 func (t *pingRecorder) lookupSelf() []*enode.Node { return nil }
 func (t *pingRecorder) lookupRandom() []*enode.Node { return nil }
-func (t *pingRecorder) close() {}
 
 // ping simulates a ping request.
func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) { @@ -188,15 +187,16 @@ func hasDuplicates(slice []*node) bool { return false } +// checkNodesEqual checks whether the two given node lists contain the same nodes. func checkNodesEqual(got, want []*enode.Node) error { if len(got) == len(want) { for i := range got { if !nodeEqual(got[i], want[i]) { goto NotEqual } - return nil } } + return nil NotEqual: output := new(bytes.Buffer) @@ -227,6 +227,7 @@ func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { }) } +// hexEncPrivkey decodes h as a private key. func hexEncPrivkey(h string) *ecdsa.PrivateKey { b, err := hex.DecodeString(h) if err != nil { @@ -239,6 +240,7 @@ func hexEncPrivkey(h string) *ecdsa.PrivateKey { return key } +// hexEncPubkey decodes h as a public key. func hexEncPubkey(h string) (ret encPubkey) { b, err := hex.DecodeString(h) if err != nil { diff --git a/p2p/discover/v4_lookup_test.go b/p2p/discover/v4_lookup_test.go index 20093852622a..a00de9ca18c7 100644 --- a/p2p/discover/v4_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -34,7 +34,7 @@ func TestUDPv4_Lookup(t *testing.T) { test := newUDPTest(t) // Lookup on empty table returns no nodes. - targetKey, _ := decodePubkey(crypto.S256(), lookupTestnet.target) + targetKey, _ := decodePubkey(crypto.S256(), lookupTestnet.target[:]) if results := test.udp.LookupPubkey(targetKey); len(results) > 0 { t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) } @@ -279,17 +279,21 @@ func (tn *preminedTestnet) nodesAtDistance(dist int) []v4wire.Node { return result } -func (tn *preminedTestnet) neighborsAtDistance(base *enode.Node, distance uint, elems int) []*enode.Node { - nodes := nodesByDistance{target: base.ID()} +func (tn *preminedTestnet) neighborsAtDistances(base *enode.Node, distances []uint, elems int) []*enode.Node { + var result []*enode.Node for d := range lookupTestnet.dists { for i := range lookupTestnet.dists[d] { n := lookupTestnet.node(d, i) - if uint(enode.LogDist(n.ID(), base.ID())) == distance { - nodes.push(wrapNode(n), elems) + d := enode.LogDist(base.ID(), n.ID()) + if containsUint(uint(d), distances) { + result = append(result, n) + if len(result) >= elems { + return result + } } } } - return unwrapNodes(nodes.entries) + return result } func (tn *preminedTestnet) closest(n int) (nodes []*enode.Node) { diff --git a/p2p/discover/v5_encoding.go b/p2p/discover/v5_encoding.go deleted file mode 100644 index 842234e790b9..000000000000 --- a/p2p/discover/v5_encoding.go +++ /dev/null @@ -1,659 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . 
- -package discover - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/ecdsa" - "crypto/elliptic" - crand "crypto/rand" - "crypto/sha256" - "errors" - "fmt" - "hash" - "net" - "time" - - "github.com/ethereum/go-ethereum/common/math" - "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/enr" - "github.com/ethereum/go-ethereum/rlp" - "golang.org/x/crypto/hkdf" -) - -// TODO concurrent WHOAREYOU tie-breaker -// TODO deal with WHOAREYOU amplification factor (min packet size?) -// TODO add counter to nonce -// TODO rehandshake after X packets - -// Discovery v5 packet types. -const ( - p_pingV5 byte = iota + 1 - p_pongV5 - p_findnodeV5 - p_nodesV5 - p_requestTicketV5 - p_ticketV5 - p_regtopicV5 - p_regconfirmationV5 - p_topicqueryV5 - p_unknownV5 = byte(255) // any non-decryptable packet - p_whoareyouV5 = byte(254) // the WHOAREYOU packet -) - -// Discovery v5 packet structures. -type ( - // unknownV5 represents any packet that can't be decrypted. - unknownV5 struct { - AuthTag []byte - } - - // WHOAREYOU contains the handshake challenge. - whoareyouV5 struct { - AuthTag []byte - IDNonce [32]byte // To be signed by recipient. - RecordSeq uint64 // ENR sequence number of recipient - - node *enode.Node - sent mclock.AbsTime - } - - // PING is sent during liveness checks. - pingV5 struct { - ReqID []byte - ENRSeq uint64 - } - - // PONG is the reply to PING. - pongV5 struct { - ReqID []byte - ENRSeq uint64 - ToIP net.IP // These fields should mirror the UDP envelope address of the ping - ToPort uint16 // packet, which provides a way to discover the the external address (after NAT). - } - - // FINDNODE is a query for nodes in the given bucket. - findnodeV5 struct { - ReqID []byte - Distance uint - } - - // NODES is the reply to FINDNODE and TOPICQUERY. - nodesV5 struct { - ReqID []byte - Total uint8 - Nodes []*enr.Record - } - - // REQUESTTICKET requests a ticket for a topic queue. - requestTicketV5 struct { - ReqID []byte - Topic []byte - } - - // TICKET is the response to REQUESTTICKET. - ticketV5 struct { - ReqID []byte - Ticket []byte - } - - // REGTOPIC registers the sender in a topic queue using a ticket. - regtopicV5 struct { - ReqID []byte - Ticket []byte - ENR *enr.Record - } - - // REGCONFIRMATION is the reply to REGTOPIC. - regconfirmationV5 struct { - ReqID []byte - Registered bool - } - - // TOPICQUERY asks for nodes with the given topic. - topicqueryV5 struct { - ReqID []byte - Topic []byte - } -) - -const ( - // Encryption/authentication parameters. - authSchemeName = "gcm" - aesKeySize = 16 - gcmNonceSize = 12 - idNoncePrefix = "discovery-id-nonce" - handshakeTimeout = time.Second -) - -var ( - errTooShort = errors.New("packet too short") - errUnexpectedHandshake = errors.New("unexpected auth response, not in handshake") - errHandshakeNonceMismatch = errors.New("wrong nonce in auth response") - errInvalidAuthKey = errors.New("invalid ephemeral pubkey") - errUnknownAuthScheme = errors.New("unknown auth scheme in handshake") - errNoRecord = errors.New("expected ENR in handshake but none sent") - errInvalidNonceSig = errors.New("invalid ID nonce signature") - zeroNonce = make([]byte, gcmNonceSize) -) - -// wireCodec encodes and decodes discovery v5 packets. 
-type wireCodec struct { - sha256 hash.Hash - localnode *enode.LocalNode - privkey *ecdsa.PrivateKey - myChtagHash enode.ID - myWhoareyouMagic []byte - - sc *sessionCache -} - -type handshakeSecrets struct { - writeKey, readKey, authRespKey []byte -} - -type authHeader struct { - authHeaderList - isHandshake bool -} - -type authHeaderList struct { - Auth []byte // authentication info of packet - IDNonce [32]byte // IDNonce of WHOAREYOU - Scheme string // name of encryption/authentication scheme - EphemeralKey []byte // ephemeral public key - Response []byte // encrypted authResponse -} - -type authResponse struct { - Version uint - Signature []byte - Record *enr.Record `rlp:"nil"` // sender's record -} - -func (h *authHeader) DecodeRLP(r *rlp.Stream) error { - k, _, err := r.Kind() - if err != nil { - return err - } - if k == rlp.Byte || k == rlp.String { - return r.Decode(&h.Auth) - } - h.isHandshake = true - return r.Decode(&h.authHeaderList) -} - -// ephemeralKey decodes the ephemeral public key in the header. -func (h *authHeaderList) ephemeralKey(curve elliptic.Curve) *ecdsa.PublicKey { - var key encPubkey - copy(key[:], h.EphemeralKey) - pubkey, _ := decodePubkey(curve, key) - return pubkey -} - -// newWireCodec creates a wire codec. -func newWireCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *wireCodec { - c := &wireCodec{ - sha256: sha256.New(), - localnode: ln, - privkey: key, - sc: newSessionCache(1024, clock), - } - // Create magic strings for packet matching. - self := ln.ID() - c.myWhoareyouMagic = c.sha256sum(self[:], []byte("WHOAREYOU")) - copy(c.myChtagHash[:], c.sha256sum(self[:])) - return c -} - -// encode encodes a packet to a node. 'id' and 'addr' specify the destination node. The -// 'challenge' parameter should be the most recently received WHOAREYOU packet from that -// node. -func (c *wireCodec) encode(id enode.ID, addr string, packet packetV5, challenge *whoareyouV5) ([]byte, []byte, error) { - if packet.kind() == p_whoareyouV5 { - p := packet.(*whoareyouV5) - enc, err := c.encodeWhoareyou(id, p) - if err == nil { - c.sc.storeSentHandshake(id, addr, p) - } - return enc, nil, err - } - // Ensure calling code sets node if needed. - if challenge != nil && challenge.node == nil { - panic("BUG: missing challenge.node in encode") - } - writeKey := c.sc.writeKey(id, addr) - if writeKey != nil || challenge != nil { - return c.encodeEncrypted(id, addr, packet, writeKey, challenge) - } - return c.encodeRandom(id) -} - -// encodeRandom encodes a random packet. -func (c *wireCodec) encodeRandom(toID enode.ID) ([]byte, []byte, error) { - tag := xorTag(c.sha256sum(toID[:]), c.localnode.ID()) - r := make([]byte, 44) // TODO randomize size - if _, err := crand.Read(r); err != nil { - return nil, nil, err - } - nonce := make([]byte, gcmNonceSize) - if _, err := crand.Read(nonce); err != nil { - return nil, nil, fmt.Errorf("can't get random data: %v", err) - } - b := new(bytes.Buffer) - b.Write(tag[:]) - rlp.Encode(b, nonce) - b.Write(r) - return b.Bytes(), nonce, nil -} - -// encodeWhoareyou encodes WHOAREYOU. -func (c *wireCodec) encodeWhoareyou(toID enode.ID, packet *whoareyouV5) ([]byte, error) { - // Sanity check node field to catch misbehaving callers. - if packet.RecordSeq > 0 && packet.node == nil { - panic("BUG: missing node in whoareyouV5 with non-zero seq") - } - b := new(bytes.Buffer) - b.Write(c.sha256sum(toID[:], []byte("WHOAREYOU"))) - err := rlp.Encode(b, packet) - return b.Bytes(), err -} - -// encodeEncrypted encodes an encrypted packet. 
-func (c *wireCodec) encodeEncrypted(toID enode.ID, toAddr string, packet packetV5, writeKey []byte, challenge *whoareyouV5) (enc []byte, authTag []byte, err error) { - nonce := make([]byte, gcmNonceSize) - if _, err := crand.Read(nonce); err != nil { - return nil, nil, fmt.Errorf("can't get random data: %v", err) - } - - var headEnc []byte - if challenge == nil { - // Regular packet, use existing key and simply encode nonce. - headEnc, _ = rlp.EncodeToBytes(nonce) - } else { - // We're answering WHOAREYOU, generate new keys and encrypt with those. - header, sec, err := c.makeAuthHeader(nonce, challenge) - if err != nil { - return nil, nil, err - } - if headEnc, err = rlp.EncodeToBytes(header); err != nil { - return nil, nil, err - } - c.sc.storeNewSession(toID, toAddr, sec.readKey, sec.writeKey) - writeKey = sec.writeKey - } - - // Encode the packet. - body := new(bytes.Buffer) - body.WriteByte(packet.kind()) - if err := rlp.Encode(body, packet); err != nil { - return nil, nil, err - } - tag := xorTag(c.sha256sum(toID[:]), c.localnode.ID()) - headsize := len(tag) + len(headEnc) - headbuf := make([]byte, headsize) - copy(headbuf[:], tag[:]) - copy(headbuf[len(tag):], headEnc) - - // Encrypt the body. - enc, err = encryptGCM(headbuf, writeKey, nonce, body.Bytes(), tag[:]) - return enc, nonce, err -} - -// encodeAuthHeader creates the auth header on a call packet following WHOAREYOU. -func (c *wireCodec) makeAuthHeader(nonce []byte, challenge *whoareyouV5) (*authHeaderList, *handshakeSecrets, error) { - resp := &authResponse{Version: 5} - - // Add our record to response if it's newer than what remote - // side has. - ln := c.localnode.Node() - if challenge.RecordSeq < ln.Seq() { - resp.Record = ln.Record() - } - - // Create the ephemeral key. This needs to be first because the - // key is part of the ID nonce signature. - var remotePubkey = new(ecdsa.PublicKey) - if err := challenge.node.Load((*enode.Secp256k1)(remotePubkey)); err != nil { - return nil, nil, fmt.Errorf("can't find secp256k1 key for recipient") - } - ephkey, err := crypto.GenerateKey() - if err != nil { - return nil, nil, fmt.Errorf("can't generate ephemeral key") - } - ephpubkey := encodePubkey(&ephkey.PublicKey) - - // Add ID nonce signature to response. - idsig, err := c.signIDNonce(challenge.IDNonce[:], ephpubkey[:]) - if err != nil { - return nil, nil, fmt.Errorf("can't sign: %v", err) - } - resp.Signature = idsig - - // Create session keys. - sec := c.deriveKeys(c.localnode.ID(), challenge.node.ID(), ephkey, remotePubkey, challenge) - if sec == nil { - return nil, nil, fmt.Errorf("key derivation failed") - } - - // Encrypt the authentication response and assemble the auth header. - respRLP, err := rlp.EncodeToBytes(resp) - if err != nil { - return nil, nil, fmt.Errorf("can't encode auth response: %v", err) - } - respEnc, err := encryptGCM(nil, sec.authRespKey, zeroNonce, respRLP, nil) - if err != nil { - return nil, nil, fmt.Errorf("can't encrypt auth response: %v", err) - } - head := &authHeaderList{ - Auth: nonce, - Scheme: authSchemeName, - IDNonce: challenge.IDNonce, - EphemeralKey: ephpubkey[:], - Response: respEnc, - } - return head, sec, err -} - -// deriveKeys generates session keys using elliptic-curve Diffie-Hellman key agreement. -func (c *wireCodec) deriveKeys(n1, n2 enode.ID, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, challenge *whoareyouV5) *handshakeSecrets { - eph := ecdh(priv, pub) - if eph == nil { - return nil - } - - info := []byte("discovery v5 key agreement") - info = append(info, n1[:]...) 
- info = append(info, n2[:]...) - kdf := hkdf.New(c.sha256reset, eph, challenge.IDNonce[:], info) - sec := handshakeSecrets{ - writeKey: make([]byte, aesKeySize), - readKey: make([]byte, aesKeySize), - authRespKey: make([]byte, aesKeySize), - } - kdf.Read(sec.writeKey) - kdf.Read(sec.readKey) - kdf.Read(sec.authRespKey) - for i := range eph { - eph[i] = 0 - } - return &sec -} - -// signIDNonce creates the ID nonce signature. -func (c *wireCodec) signIDNonce(nonce, ephkey []byte) ([]byte, error) { - idsig, err := crypto.Sign(c.idNonceHash(nonce, ephkey), c.privkey) - if err != nil { - return nil, fmt.Errorf("can't sign: %v", err) - } - return idsig[:len(idsig)-1], nil // remove recovery ID -} - -// idNonceHash computes the hash of id nonce with prefix. -func (c *wireCodec) idNonceHash(nonce, ephkey []byte) []byte { - h := c.sha256reset() - h.Write([]byte(idNoncePrefix)) - h.Write(nonce) - h.Write(ephkey) - return h.Sum(nil) -} - -// decode decodes a discovery packet. -func (c *wireCodec) decode(input []byte, addr string) (enode.ID, *enode.Node, packetV5, error) { - // Delete timed-out handshakes. This must happen before decoding to avoid - // processing the same handshake twice. - c.sc.handshakeGC() - - if len(input) < 32 { - return enode.ID{}, nil, nil, errTooShort - } - if bytes.HasPrefix(input, c.myWhoareyouMagic) { - p, err := c.decodeWhoareyou(input) - return enode.ID{}, nil, p, err - } - sender := xorTag(input[:32], c.myChtagHash) - p, n, err := c.decodeEncrypted(sender, addr, input) - return sender, n, p, err -} - -// decodeWhoareyou decode a WHOAREYOU packet. -func (c *wireCodec) decodeWhoareyou(input []byte) (packetV5, error) { - packet := new(whoareyouV5) - err := rlp.DecodeBytes(input[32:], packet) - return packet, err -} - -// decodeEncrypted decodes an encrypted discovery packet. -func (c *wireCodec) decodeEncrypted(fromID enode.ID, fromAddr string, input []byte) (packetV5, *enode.Node, error) { - // Decode packet header. - var head authHeader - r := bytes.NewReader(input[32:]) - err := rlp.Decode(r, &head) - if err != nil { - return nil, nil, err - } - - // Decrypt and process auth response. - readKey, node, err := c.decodeAuth(fromID, fromAddr, &head) - if err != nil { - return nil, nil, err - } - - // Decrypt and decode the packet body. - headsize := len(input) - r.Len() - bodyEnc := input[headsize:] - body, err := decryptGCM(readKey, head.Auth, bodyEnc, input[:32]) - if err != nil { - if !head.isHandshake { - // Can't decrypt, start handshake. - return &unknownV5{AuthTag: head.Auth}, nil, nil - } - return nil, nil, fmt.Errorf("handshake failed: %v", err) - } - if len(body) == 0 { - return nil, nil, errTooShort - } - p, err := decodePacketBodyV5(body[0], body[1:]) - return p, node, err -} - -// decodeAuth processes an auth header. -func (c *wireCodec) decodeAuth(fromID enode.ID, fromAddr string, head *authHeader) ([]byte, *enode.Node, error) { - if !head.isHandshake { - return c.sc.readKey(fromID, fromAddr), nil, nil - } - - // Remote is attempting handshake. Verify against our last WHOAREYOU. - challenge := c.sc.getHandshake(fromID, fromAddr) - if challenge == nil { - return nil, nil, errUnexpectedHandshake - } - if head.IDNonce != challenge.IDNonce { - return nil, nil, errHandshakeNonceMismatch - } - sec, n, err := c.decodeAuthResp(fromID, fromAddr, &head.authHeaderList, challenge) - if err != nil { - return nil, n, err - } - // Swap keys to match remote. 
- sec.readKey, sec.writeKey = sec.writeKey, sec.readKey - c.sc.storeNewSession(fromID, fromAddr, sec.readKey, sec.writeKey) - c.sc.deleteHandshake(fromID, fromAddr) - return sec.readKey, n, err -} - -// decodeAuthResp decodes and verifies an authentication response. -func (c *wireCodec) decodeAuthResp(fromID enode.ID, fromAddr string, head *authHeaderList, challenge *whoareyouV5) (*handshakeSecrets, *enode.Node, error) { - // Decrypt / decode the response. - if head.Scheme != authSchemeName { - return nil, nil, errUnknownAuthScheme - } - ephkey := head.ephemeralKey(c.privkey.Curve) - if ephkey == nil { - return nil, nil, errInvalidAuthKey - } - sec := c.deriveKeys(fromID, c.localnode.ID(), c.privkey, ephkey, challenge) - respPT, err := decryptGCM(sec.authRespKey, zeroNonce, head.Response, nil) - if err != nil { - return nil, nil, fmt.Errorf("can't decrypt auth response header: %v", err) - } - var resp authResponse - if err := rlp.DecodeBytes(respPT, &resp); err != nil { - return nil, nil, fmt.Errorf("invalid auth response: %v", err) - } - - // Verify response node record. The remote node should include the record - // if we don't have one or if ours is older than the latest version. - node := challenge.node - if resp.Record != nil { - if node == nil || node.Seq() < resp.Record.Seq() { - n, err := enode.New(enode.ValidSchemes, resp.Record) - if err != nil { - return nil, nil, fmt.Errorf("invalid node record: %v", err) - } - if n.ID() != fromID { - return nil, nil, fmt.Errorf("record in auth respose has wrong ID: %v", n.ID()) - } - node = n - } - } - if node == nil { - return nil, nil, errNoRecord - } - - // Verify ID nonce signature. - err = c.verifyIDSignature(challenge.IDNonce[:], head.EphemeralKey, resp.Signature, node) - if err != nil { - return nil, nil, err - } - return sec, node, nil -} - -// verifyIDSignature checks that signature over idnonce was made by the node with given record. -func (c *wireCodec) verifyIDSignature(nonce, ephkey, sig []byte, n *enode.Node) error { - switch idscheme := n.Record().IdentityScheme(); idscheme { - case "v4": - var pk ecdsa.PublicKey - n.Load((*enode.Secp256k1)(&pk)) // cannot fail because record is valid - if !crypto.VerifySignature(crypto.FromECDSAPub(&pk), c.idNonceHash(nonce, ephkey), sig) { - return errInvalidNonceSig - } - return nil - default: - return fmt.Errorf("can't verify ID nonce signature against scheme %q", idscheme) - } -} - -// decodePacketBody decodes the body of an encrypted discovery packet. -func decodePacketBodyV5(ptype byte, body []byte) (packetV5, error) { - var dec packetV5 - switch ptype { - case p_pingV5: - dec = new(pingV5) - case p_pongV5: - dec = new(pongV5) - case p_findnodeV5: - dec = new(findnodeV5) - case p_nodesV5: - dec = new(nodesV5) - case p_requestTicketV5: - dec = new(requestTicketV5) - case p_ticketV5: - dec = new(ticketV5) - case p_regtopicV5: - dec = new(regtopicV5) - case p_regconfirmationV5: - dec = new(regconfirmationV5) - case p_topicqueryV5: - dec = new(topicqueryV5) - default: - return nil, fmt.Errorf("unknown packet type %d", ptype) - } - if err := rlp.DecodeBytes(body, dec); err != nil { - return nil, err - } - return dec, nil -} - -// sha256reset returns the shared hash instance. -func (c *wireCodec) sha256reset() hash.Hash { - c.sha256.Reset() - return c.sha256 -} - -// sha256sum computes sha256 on the concatenation of inputs. 
-func (c *wireCodec) sha256sum(inputs ...[]byte) []byte { - c.sha256.Reset() - for _, b := range inputs { - c.sha256.Write(b) - } - return c.sha256.Sum(nil) -} - -func xorTag(a []byte, b enode.ID) enode.ID { - var r enode.ID - for i := range r { - r[i] = a[i] ^ b[i] - } - return r -} - -// ecdh creates a shared secret. -func ecdh(privkey *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey) []byte { - secX, secY := pubkey.ScalarMult(pubkey.X, pubkey.Y, privkey.D.Bytes()) - if secX == nil { - return nil - } - sec := make([]byte, 33) - sec[0] = 0x02 | byte(secY.Bit(0)) - math.ReadBits(secX, sec[1:]) - return sec -} - -// encryptGCM encrypts pt using AES-GCM with the given key and nonce. -func encryptGCM(dest, key, nonce, pt, authData []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - panic(fmt.Errorf("can't create block cipher: %v", err)) - } - aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize) - if err != nil { - panic(fmt.Errorf("can't create GCM: %v", err)) - } - return aesgcm.Seal(dest, nonce, pt, authData), nil -} - -// decryptGCM decrypts ct using AES-GCM with the given key and nonce. -func decryptGCM(key, nonce, ct, authData []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, fmt.Errorf("can't create block cipher: %v", err) - } - if len(nonce) != gcmNonceSize { - return nil, fmt.Errorf("invalid GCM nonce size: %d", len(nonce)) - } - aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize) - if err != nil { - return nil, fmt.Errorf("can't create GCM: %v", err) - } - pt := make([]byte, 0, len(ct)) - return aesgcm.Open(pt, nonce, ct, authData) -} diff --git a/p2p/discover/v5_encoding_test.go b/p2p/discover/v5_encoding_test.go deleted file mode 100644 index 77e6bae6aef8..000000000000 --- a/p2p/discover/v5_encoding_test.go +++ /dev/null @@ -1,373 +0,0 @@ -// Copyright 2019 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . 
- -package discover - -import ( - "bytes" - "crypto/ecdsa" - "encoding/hex" - "fmt" - "net" - "reflect" - "testing" - - "github.com/davecgh/go-spew/spew" - "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/p2p/enode" -) - -var ( - testKeyA, _ = crypto.HexToECDSA("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") - testKeyB, _ = crypto.HexToECDSA("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") - testIDnonce = [32]byte{5, 6, 7, 8, 9, 10, 11, 12} -) - -func TestDeriveKeysV5(t *testing.T) { - t.Parallel() - - var ( - n1 = enode.ID{1} - n2 = enode.ID{2} - challenge = &whoareyouV5{} - db, _ = enode.OpenDB("") - ln = enode.NewLocalNode(db, testKeyA) - c = newWireCodec(ln, testKeyA, mclock.System{}) - ) - defer db.Close() - - sec1 := c.deriveKeys(n1, n2, testKeyA, &testKeyB.PublicKey, challenge) - sec2 := c.deriveKeys(n1, n2, testKeyB, &testKeyA.PublicKey, challenge) - if sec1 == nil || sec2 == nil { - t.Fatal("key agreement failed") - } - if !reflect.DeepEqual(sec1, sec2) { - t.Fatalf("keys not equal:\n %+v\n %+v", sec1, sec2) - } -} - -// This test checks the basic handshake flow where A talks to B and A has no secrets. -func TestHandshakeV5(t *testing.T) { - t.Parallel() - net := newHandshakeTest() - defer net.close() - - // A -> B RANDOM PACKET - packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) - resp := net.nodeB.expectDecode(t, p_unknownV5, packet) - - // A <- B WHOAREYOU - challenge := &whoareyouV5{ - AuthTag: resp.(*unknownV5).AuthTag, - IDNonce: testIDnonce, - RecordSeq: 0, - } - whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) - net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) - - // A -> B FINDNODE - findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) - net.nodeB.expectDecode(t, p_findnodeV5, findnode) - if len(net.nodeB.c.sc.handshakes) > 0 { - t.Fatalf("node B didn't remove handshake from challenge map") - } - - // A <- B NODES - nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) - net.nodeA.expectDecode(t, p_nodesV5, nodes) -} - -// This test checks that handshake attempts are removed within the timeout. -func TestHandshakeV5_timeout(t *testing.T) { - t.Parallel() - net := newHandshakeTest() - defer net.close() - - // A -> B RANDOM PACKET - packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) - resp := net.nodeB.expectDecode(t, p_unknownV5, packet) - - // A <- B WHOAREYOU - challenge := &whoareyouV5{ - AuthTag: resp.(*unknownV5).AuthTag, - IDNonce: testIDnonce, - RecordSeq: 0, - } - whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) - net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) - - // A -> B FINDNODE after timeout - net.clock.Run(handshakeTimeout + 1) - findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) - net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, findnode) -} - -// This test checks handshake behavior when no record is sent in the auth response. 
-func TestHandshakeV5_norecord(t *testing.T) { - t.Parallel() - net := newHandshakeTest() - defer net.close() - - // A -> B RANDOM PACKET - packet, _ := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) - resp := net.nodeB.expectDecode(t, p_unknownV5, packet) - - // A <- B WHOAREYOU - nodeA := net.nodeA.n() - if nodeA.Seq() == 0 { - t.Fatal("need non-zero sequence number") - } - challenge := &whoareyouV5{ - AuthTag: resp.(*unknownV5).AuthTag, - IDNonce: testIDnonce, - RecordSeq: nodeA.Seq(), - node: nodeA, - } - whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) - net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) - - // A -> B FINDNODE - findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) - net.nodeB.expectDecode(t, p_findnodeV5, findnode) - - // A <- B NODES - nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) - net.nodeA.expectDecode(t, p_nodesV5, nodes) -} - -// In this test, A tries to send FINDNODE with existing secrets but B doesn't know -// anything about A. -func TestHandshakeV5_rekey(t *testing.T) { - t.Parallel() - net := newHandshakeTest() - defer net.close() - - initKeys := &handshakeSecrets{ - readKey: []byte("BBBBBBBBBBBBBBBB"), - writeKey: []byte("AAAAAAAAAAAAAAAA"), - } - net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeys.readKey, initKeys.writeKey) - - // A -> B FINDNODE (encrypted with zero keys) - findnode, authTag := net.nodeA.encode(t, net.nodeB, &findnodeV5{}) - net.nodeB.expectDecode(t, p_unknownV5, findnode) - - // A <- B WHOAREYOU - challenge := &whoareyouV5{AuthTag: authTag, IDNonce: testIDnonce} - whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) - net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) - - // Check that new keys haven't been stored yet. - if s := net.nodeA.c.sc.session(net.nodeB.id(), net.nodeB.addr()); !bytes.Equal(s.writeKey, initKeys.writeKey) || !bytes.Equal(s.readKey, initKeys.readKey) { - t.Fatal("node A stored keys too early") - } - if s := net.nodeB.c.sc.session(net.nodeA.id(), net.nodeA.addr()); s != nil { - t.Fatal("node B stored keys too early") - } - - // A -> B FINDNODE encrypted with new keys - findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) - net.nodeB.expectDecode(t, p_findnodeV5, findnode) - - // A <- B NODES - nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) - net.nodeA.expectDecode(t, p_nodesV5, nodes) -} - -// In this test A and B have different keys before the handshake. 
-func TestHandshakeV5_rekey2(t *testing.T) { - t.Parallel() - net := newHandshakeTest() - defer net.close() - - initKeysA := &handshakeSecrets{ - readKey: []byte("BBBBBBBBBBBBBBBB"), - writeKey: []byte("AAAAAAAAAAAAAAAA"), - } - initKeysB := &handshakeSecrets{ - readKey: []byte("CCCCCCCCCCCCCCCC"), - writeKey: []byte("DDDDDDDDDDDDDDDD"), - } - net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA.readKey, initKeysA.writeKey) - net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB.readKey, initKeysA.writeKey) - - // A -> B FINDNODE encrypted with initKeysA - findnode, authTag := net.nodeA.encode(t, net.nodeB, &findnodeV5{Distance: 3}) - net.nodeB.expectDecode(t, p_unknownV5, findnode) - - // A <- B WHOAREYOU - challenge := &whoareyouV5{AuthTag: authTag, IDNonce: testIDnonce} - whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) - net.nodeA.expectDecode(t, p_whoareyouV5, whoareyou) - - // A -> B FINDNODE encrypted with new keys - findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &findnodeV5{}) - net.nodeB.expectDecode(t, p_findnodeV5, findnode) - - // A <- B NODES - nodes, _ := net.nodeB.encode(t, net.nodeA, &nodesV5{Total: 1}) - net.nodeA.expectDecode(t, p_nodesV5, nodes) -} - -// This test checks some malformed packets. -func TestDecodeErrorsV5(t *testing.T) { - t.Parallel() - net := newHandshakeTest() - defer net.close() - - net.nodeA.expectDecodeErr(t, errTooShort, []byte{}) - // TODO some more tests would be nice :) -} - -// This benchmark checks performance of authHeader decoding, verification and key derivation. -func BenchmarkV5_DecodeAuthSecp256k1(b *testing.B) { - net := newHandshakeTest() - defer net.close() - - var ( - idA = net.nodeA.id() - addrA = net.nodeA.addr() - challenge = &whoareyouV5{AuthTag: []byte("authresp"), RecordSeq: 0, node: net.nodeB.n()} - nonce = make([]byte, gcmNonceSize) - ) - header, _, _ := net.nodeA.c.makeAuthHeader(nonce, challenge) - challenge.node = nil // force ENR signature verification in decoder - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, _, err := net.nodeB.c.decodeAuthResp(idA, addrA, header, challenge) - if err != nil { - b.Fatal(err) - } - } -} - -// This benchmark checks how long it takes to decode an encrypted ping packet. 
-func BenchmarkV5_DecodePing(b *testing.B) { - net := newHandshakeTest() - defer net.close() - - r := []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17} - w := []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134} - net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), r, w) - net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), w, r) - addrB := net.nodeA.addr() - ping := &pingV5{ReqID: []byte("reqid"), ENRSeq: 5} - enc, _, err := net.nodeA.c.encode(net.nodeB.id(), addrB, ping, nil) - if err != nil { - b.Fatalf("can't encode: %v", err) - } - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, _, p, _ := net.nodeB.c.decode(enc, addrB) - if _, ok := p.(*pingV5); !ok { - b.Fatalf("wrong packet type %T", p) - } - } -} - -var pp = spew.NewDefaultConfig() - -type handshakeTest struct { - nodeA, nodeB handshakeTestNode - clock mclock.Simulated -} - -type handshakeTestNode struct { - ln *enode.LocalNode - c *wireCodec -} - -func newHandshakeTest() *handshakeTest { - t := new(handshakeTest) - t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock) - t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock) - return t -} - -func (t *handshakeTest) close() { - t.nodeA.ln.Database().Close() - t.nodeB.ln.Database().Close() -} - -func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) { - db, _ := enode.OpenDB("") - n.ln = enode.NewLocalNode(db, key) - n.ln.SetStaticIP(ip) - n.c = newWireCodec(n.ln, key, clock) -} - -func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p packetV5) ([]byte, []byte) { - t.Helper() - return n.encodeWithChallenge(t, to, nil, p) -} - -func (n *handshakeTestNode) encodeWithChallenge(t testing.TB, to handshakeTestNode, c *whoareyouV5, p packetV5) ([]byte, []byte) { - t.Helper() - // Copy challenge and add destination node. This avoids sharing 'c' among the two codecs. - var challenge *whoareyouV5 - if c != nil { - challengeCopy := *c - challenge = &challengeCopy - challenge.node = to.n() - } - // Encode to destination. 
- enc, authTag, err := n.c.encode(to.id(), to.addr(), p, challenge) - if err != nil { - t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err)) - } - t.Logf("(%s) -> (%s) %s\n%s", n.ln.ID().TerminalString(), to.id().TerminalString(), p.name(), hex.Dump(enc)) - return enc, authTag -} - -func (n *handshakeTestNode) expectDecode(t *testing.T, ptype byte, p []byte) packetV5 { - t.Helper() - dec, err := n.decode(p) - if err != nil { - t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err)) - } - t.Logf("(%s) %#v", n.ln.ID().TerminalString(), pp.NewFormatter(dec)) - if dec.kind() != ptype { - t.Fatalf("expected packet type %d, got %d", ptype, dec.kind()) - } - return dec -} - -func (n *handshakeTestNode) expectDecodeErr(t *testing.T, wantErr error, p []byte) { - t.Helper() - if _, err := n.decode(p); !reflect.DeepEqual(err, wantErr) { - t.Fatal(fmt.Errorf("(%s) got err %q, want %q", n.ln.ID().TerminalString(), err, wantErr)) - } -} - -func (n *handshakeTestNode) decode(input []byte) (packetV5, error) { - _, _, p, err := n.c.decode(input, "127.0.0.1") - return p, err -} - -func (n *handshakeTestNode) n() *enode.Node { - return n.ln.Node() -} - -func (n *handshakeTestNode) addr() string { - return n.ln.Node().IP().String() -} - -func (n *handshakeTestNode) id() enode.ID { - return n.ln.ID() -} diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index d53375b48b63..9dd2b31733db 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -31,6 +31,7 @@ import ( "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover/v5wire" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" @@ -38,36 +39,24 @@ import ( const ( lookupRequestLimit = 3 // max requests against a single node during lookup - findnodeResultLimit = 15 // applies in FINDNODE handler + findnodeResultLimit = 16 // applies in FINDNODE handler totalNodesResponseLimit = 5 // applies in waitForNodes nodesResponseItemLimit = 3 // applies in sendNodes respTimeoutV5 = 700 * time.Millisecond ) -// codecV5 is implemented by wireCodec (and testCodec). +// codecV5 is implemented by v5wire.Codec (and testCodec). // // The UDPv5 transport is split into two objects: the codec object deals with // encoding/decoding and with the handshake; the UDPv5 object handles higher-level concerns. type codecV5 interface { - // encode encodes a packet. The 'challenge' parameter is non-nil for calls which got a - // WHOAREYOU response. - encode(fromID enode.ID, fromAddr string, p packetV5, challenge *whoareyouV5) (enc []byte, authTag []byte, err error) + // Encode encodes a packet. + Encode(enode.ID, string, v5wire.Packet, *v5wire.Whoareyou) ([]byte, v5wire.Nonce, error) - // decode decodes a packet. It returns an *unknownV5 packet if decryption fails. - // The fromNode return value is non-nil when the input contains a handshake response. - decode(input []byte, fromAddr string) (fromID enode.ID, fromNode *enode.Node, p packetV5, err error) -} - -// packetV5 is implemented by all discv5 packet type structs. -type packetV5 interface { - // These methods provide information and set the request ID. - name() string - kind() byte - setreqid([]byte) - // handle should perform the appropriate action to handle the packet, i.e. this is the - // place to send the response. - handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) + // decode decodes a packet. 
It returns a *v5wire.Unknown packet if decryption fails. + // The *enode.Node return value is non-nil when the input contains a handshake response. + Decode([]byte, string) (enode.ID, *enode.Node, v5wire.Packet, error) } // UDPv5 is the implementation of protocol version 5. @@ -83,6 +72,10 @@ type UDPv5 struct { clock mclock.Clock validSchemes enr.IdentityScheme + // talkreq handler registry + trlock sync.Mutex + trhandlers map[string]func([]byte) []byte + // channels into dispatch packetInCh chan ReadPacket readNextCh chan struct{} @@ -93,7 +86,7 @@ type UDPv5 struct { // state of dispatch codec codecV5 activeCallByNode map[enode.ID]*callV5 - activeCallByAuth map[string]*callV5 + activeCallByAuth map[v5wire.Nonce]*callV5 callQueue map[enode.ID][]*callV5 // shutdown stuff @@ -106,16 +99,16 @@ type UDPv5 struct { // callV5 represents a remote procedure call against another node. type callV5 struct { node *enode.Node - packet packetV5 + packet v5wire.Packet responseType byte // expected packet type of response reqid []byte - ch chan packetV5 // responses sent here - err chan error // errors sent here + ch chan v5wire.Packet // responses sent here + err chan error // errors sent here // Valid for active calls only: - authTag []byte // authTag of request packet - handshakeCount int // # times we attempted handshake for this call - challenge *whoareyouV5 // last sent handshake challenge + nonce v5wire.Nonce // nonce of request packet + handshakeCount int // # times we attempted handshake for this call + challenge *v5wire.Whoareyou // last sent handshake challenge timeout mclock.Timer } @@ -152,6 +145,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { log: cfg.Log, validSchemes: cfg.ValidSchemes, clock: cfg.Clock, + trhandlers: make(map[string]func([]byte) []byte), // channels into dispatch packetInCh: make(chan ReadPacket, 1), readNextCh: make(chan struct{}, 1), @@ -159,9 +153,9 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { callDoneCh: make(chan *callV5), respTimeoutCh: make(chan *callTimeout), // state of dispatch - codec: newWireCodec(ln, cfg.PrivateKey, cfg.Clock), + codec: v5wire.NewCodec(ln, cfg.PrivateKey, cfg.Clock), activeCallByNode: make(map[enode.ID]*callV5), - activeCallByAuth: make(map[string]*callV5), + activeCallByAuth: make(map[v5wire.Nonce]*callV5), callQueue: make(map[enode.ID][]*callV5), // shutdown closeCtx: closeCtx, @@ -236,6 +230,29 @@ func (t *UDPv5) LocalNode() *enode.LocalNode { return t.localNode } +// RegisterTalkHandler adds a handler for 'talk requests'. The handler function is called +// whenever a request for the given protocol is received and should return the response +// data or nil. +func (t *UDPv5) RegisterTalkHandler(protocol string, handler func([]byte) []byte) { + t.trlock.Lock() + defer t.trlock.Unlock() + t.trhandlers[protocol] = handler +} + +// TalkRequest sends a talk request to n and waits for a response. +func (t *UDPv5) TalkRequest(n *enode.Node, protocol string, request []byte) ([]byte, error) { + req := &v5wire.TalkRequest{Protocol: protocol, Message: request} + resp := t.call(n, v5wire.TalkResponseMsg, req) + defer t.callDone(resp) + select { + case respMsg := <-resp.ch: + return respMsg.(*v5wire.TalkResponse).Message, nil + case err := <-resp.err: + return nil, err + } +} + +// RandomNodes returns an iterator that finds random nodes in the DHT. func (t *UDPv5) RandomNodes() enode.Iterator { if t.tab.len() == 0 { // All nodes were dropped, refresh. 
The very first query will hit this @@ -283,16 +300,14 @@ func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) { nodes = nodesByDistance{target: target} err error ) - for i := 0; i < lookupRequestLimit && len(nodes.entries) < findnodeResultLimit; i++ { - var r []*enode.Node - r, err = t.findnode(unwrapNode(destNode), dists[i]) - if err == errClosed { - return nil, err - } - for _, n := range r { - if n.ID() != t.Self().ID() { - nodes.push(wrapNode(n), findnodeResultLimit) - } + var r []*enode.Node + r, err = t.findnode(unwrapNode(destNode), dists) + if err == errClosed { + return nil, err + } + for _, n := range r { + if n.ID() != t.Self().ID() { + nodes.push(wrapNode(n), findnodeResultLimit) } } return nodes.entries, err @@ -301,15 +316,15 @@ func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) { // lookupDistances computes the distance parameter for FINDNODE calls to dest. // It chooses distances adjacent to logdist(target, dest), e.g. for a target // with logdist(target, dest) = 255 the result is [255, 256, 254]. -func lookupDistances(target, dest enode.ID) (dists []int) { +func lookupDistances(target, dest enode.ID) (dists []uint) { td := enode.LogDist(target, dest) - dists = append(dists, td) + dists = append(dists, uint(td)) for i := 1; len(dists) < lookupRequestLimit; i++ { if td+i < 256 { - dists = append(dists, td+i) + dists = append(dists, uint(td+i)) } if td-i > 0 { - dists = append(dists, td-i) + dists = append(dists, uint(td-i)) } } return dists @@ -317,11 +332,13 @@ func lookupDistances(target, dest enode.ID) (dists []int) { // ping calls PING on a node and waits for a PONG response. func (t *UDPv5) ping(n *enode.Node) (uint64, error) { - resp := t.call(n, p_pongV5, &pingV5{ENRSeq: t.localNode.Node().Seq()}) + req := &v5wire.Ping{ENRSeq: t.localNode.Node().Seq()} + resp := t.call(n, v5wire.PongMsg, req) defer t.callDone(resp) + select { case pong := <-resp.ch: - return pong.(*pongV5).ENRSeq, nil + return pong.(*v5wire.Pong).ENRSeq, nil case err := <-resp.err: return 0, err } @@ -329,7 +346,7 @@ func (t *UDPv5) ping(n *enode.Node) (uint64, error) { // requestENR requests n's record. func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { - nodes, err := t.findnode(n, 0) + nodes, err := t.findnode(n, []uint{0}) if err != nil { return nil, err } @@ -339,26 +356,14 @@ func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { return nodes[0], nil } -// requestTicket calls REQUESTTICKET on a node and waits for a TICKET response. -func (t *UDPv5) requestTicket(n *enode.Node) ([]byte, error) { - resp := t.call(n, p_ticketV5, &pingV5{}) - defer t.callDone(resp) - select { - case response := <-resp.ch: - return response.(*ticketV5).Ticket, nil - case err := <-resp.err: - return nil, err - } -} - // findnode calls FINDNODE on a node and waits for responses. -func (t *UDPv5) findnode(n *enode.Node, distance int) ([]*enode.Node, error) { - resp := t.call(n, p_nodesV5, &findnodeV5{Distance: uint(distance)}) - return t.waitForNodes(resp, distance) +func (t *UDPv5) findnode(n *enode.Node, distances []uint) ([]*enode.Node, error) { + resp := t.call(n, v5wire.NodesMsg, &v5wire.Findnode{Distances: distances}) + return t.waitForNodes(resp, distances) } // waitForNodes waits for NODES responses to the given call. 
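+// A single FINDNODE call may be answered by several NODES packets; per the
+// constants above, waitForNodes collects at most totalNodesResponseLimit (5)
+// of them for one call.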
-func (t *UDPv5) waitForNodes(c *callV5, distance int) ([]*enode.Node, error) { +func (t *UDPv5) waitForNodes(c *callV5, distances []uint) ([]*enode.Node, error) { defer t.callDone(c) var ( @@ -369,11 +374,11 @@ func (t *UDPv5) waitForNodes(c *callV5, distance int) ([]*enode.Node, error) { for { select { case responseP := <-c.ch: - response := responseP.(*nodesV5) + response := responseP.(*v5wire.Nodes) for _, record := range response.Nodes { - node, err := t.verifyResponseNode(c, record, distance, seen) + node, err := t.verifyResponseNode(c, record, distances, seen) if err != nil { - t.log.Debug("Invalid record in "+response.name(), "id", c.node.ID(), "err", err) + t.log.Debug("Invalid record in "+response.Name(), "id", c.node.ID(), "err", err) continue } nodes = append(nodes, node) @@ -391,7 +396,7 @@ func (t *UDPv5) waitForNodes(c *callV5, distance int) ([]*enode.Node, error) { } // verifyResponseNode checks validity of a record in a NODES response. -func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distance int, seen map[enode.ID]struct{}) (*enode.Node, error) { +func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, seen map[enode.ID]struct{}) (*enode.Node, error) { node, err := enode.New(t.validSchemes, r) if err != nil { return nil, err @@ -402,9 +407,10 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distance int, seen if c.node.UDP() <= 1024 { return nil, errLowPort } - if distance != -1 { - if d := enode.LogDist(c.node.ID(), node.ID()); d != distance { - return nil, fmt.Errorf("wrong distance %d", d) + if distances != nil { + nd := enode.LogDist(c.node.ID(), node.ID()) + if !containsUint(uint(nd), distances) { + return nil, errors.New("does not match any requested distance") } } if _, ok := seen[node.ID()]; ok { @@ -414,20 +420,29 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distance int, seen return node, nil } -// call sends the given call and sets up a handler for response packets (of type c.responseType). -// Responses are dispatched to the call's response channel. -func (t *UDPv5) call(node *enode.Node, responseType byte, packet packetV5) *callV5 { +func containsUint(x uint, xs []uint) bool { + for _, v := range xs { + if x == v { + return true + } + } + return false +} + +// call sends the given call and sets up a handler for response packets (of message type +// responseType). Responses are dispatched to the call's response channel. +func (t *UDPv5) call(node *enode.Node, responseType byte, packet v5wire.Packet) *callV5 { c := &callV5{ node: node, packet: packet, responseType: responseType, reqid: make([]byte, 8), - ch: make(chan packetV5, 1), + ch: make(chan v5wire.Packet, 1), err: make(chan error, 1), } // Assign request ID. crand.Read(c.reqid) - packet.setreqid(c.reqid) + packet.SetRequestID(c.reqid) // Send call to dispatch. select { case t.callCh <- c: @@ -439,9 +454,20 @@ func (t *UDPv5) call(node *enode.Node, responseType byte, packet packetV5) *call // callDone tells dispatch that the active call is done. func (t *UDPv5) callDone(c *callV5) { - select { - case t.callDoneCh <- c: - case <-t.closeCtx.Done(): + // This needs a loop because further responses may be incoming until the + // send to callDoneCh has completed. Such responses need to be discarded + // in order to avoid blocking the dispatch loop. + for { + select { + case <-c.ch: + // late response, discard. + case <-c.err: + // late error, discard. 
+ case t.callDoneCh <- c: + return + case <-t.closeCtx.Done(): + return + } } } @@ -482,7 +508,7 @@ func (t *UDPv5) dispatch() { panic("BUG: callDone for inactive call") } c.timeout.Stop() - delete(t.activeCallByAuth, string(c.authTag)) + delete(t.activeCallByAuth, c.nonce) delete(t.activeCallByNode, id) t.sendNextCall(id) @@ -502,7 +528,7 @@ func (t *UDPv5) dispatch() { for id, c := range t.activeCallByNode { c.err <- errClosed delete(t.activeCallByNode, id) - delete(t.activeCallByAuth, string(c.authTag)) + delete(t.activeCallByAuth, c.nonce) } return } @@ -548,38 +574,37 @@ func (t *UDPv5) sendNextCall(id enode.ID) { // sendCall encodes and sends a request packet to the call's recipient node. // This performs a handshake if needed. func (t *UDPv5) sendCall(c *callV5) { - if len(c.authTag) > 0 { - // The call already has an authTag from a previous handshake attempt. Remove the - // entry for the authTag because we're about to generate a new authTag for this - // call. - delete(t.activeCallByAuth, string(c.authTag)) + // The call might have a nonce from a previous handshake attempt. Remove the entry for + // the old nonce because we're about to generate a new nonce for this call. + if c.nonce != (v5wire.Nonce{}) { + delete(t.activeCallByAuth, c.nonce) } addr := &net.UDPAddr{IP: c.node.IP(), Port: c.node.UDP()} - newTag, _ := t.send(c.node.ID(), addr, c.packet, c.challenge) - c.authTag = newTag - t.activeCallByAuth[string(c.authTag)] = c + newNonce, _ := t.send(c.node.ID(), addr, c.packet, c.challenge) + c.nonce = newNonce + t.activeCallByAuth[newNonce] = c t.startResponseTimeout(c) } // sendResponse sends a response packet to the given node. // This doesn't trigger a handshake even if no keys are available. -func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet packetV5) error { +func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) error { _, err := t.send(toID, toAddr, packet, nil) return err } // send sends a packet to the given node. -func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet packetV5, c *whoareyouV5) ([]byte, error) { +func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) { addr := toAddr.String() - enc, authTag, err := t.codec.encode(toID, addr, packet, c) + enc, nonce, err := t.codec.Encode(toID, addr, packet, c) if err != nil { - t.log.Warn(">> "+packet.name(), "id", toID, "addr", addr, "err", err) - return authTag, err + t.log.Warn(">> "+packet.Name(), "id", toID, "addr", addr, "err", err) + return nonce, err } _, err = t.conn.WriteToUDP(enc, toAddr) - t.log.Trace(">> "+packet.name(), "id", toID, "addr", addr) - return authTag, err + t.log.Trace(">> "+packet.Name(), "id", toID, "addr", addr) + return nonce, err } // readLoop runs in its own goroutine and reads packets from the network. @@ -617,7 +642,7 @@ func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool { // handlePacket decodes and processes an incoming packet from the network. func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { addr := fromAddr.String() - fromID, fromNode, packet, err := t.codec.decode(rawpacket, addr) + fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr) if err != nil { t.log.Debug("Bad discv5 packet", "id", fromID, "addr", addr, "err", err) return err @@ -626,31 +651,32 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { // Handshake succeeded, add to table. 
t.tab.addSeenNode(wrapNode(fromNode)) } - if packet.kind() != p_whoareyouV5 { - // WHOAREYOU logged separately to report the sender ID. - t.log.Trace("<< "+packet.name(), "id", fromID, "addr", addr) + if packet.Kind() != v5wire.WhoareyouPacket { + // WHOAREYOU logged separately to report errors. + t.log.Trace("<< "+packet.Name(), "id", fromID, "addr", addr) } - packet.handle(t, fromID, fromAddr) + t.handle(packet, fromID, fromAddr) return nil } // handleCallResponse dispatches a response packet to the call waiting for it. -func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, reqid []byte, p packetV5) { +func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, p v5wire.Packet) bool { ac := t.activeCallByNode[fromID] - if ac == nil || !bytes.Equal(reqid, ac.reqid) { - t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.name()), "id", fromID, "addr", fromAddr) - return + if ac == nil || !bytes.Equal(p.RequestID(), ac.reqid) { + t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.Name()), "id", fromID, "addr", fromAddr) + return false } if !fromAddr.IP.Equal(ac.node.IP()) || fromAddr.Port != ac.node.UDP() { - t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.name()), "id", fromID, "addr", fromAddr) - return + t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.Name()), "id", fromID, "addr", fromAddr) + return false } - if p.kind() != ac.responseType { - t.log.Debug(fmt.Sprintf("Wrong disv5 response type %s", p.name()), "id", fromID, "addr", fromAddr) - return + if p.Kind() != ac.responseType { + t.log.Debug(fmt.Sprintf("Wrong discv5 response type %s", p.Name()), "id", fromID, "addr", fromAddr) + return false } t.startResponseTimeout(ac) ac.ch <- p + return true } // getNode looks for a node record in table and database. @@ -664,50 +690,65 @@ func (t *UDPv5) getNode(id enode.ID) *enode.Node { return nil } -// UNKNOWN - -func (p *unknownV5) name() string { return "UNKNOWN/v5" } -func (p *unknownV5) kind() byte { return p_unknownV5 } -func (p *unknownV5) setreqid(id []byte) {} +// handle processes incoming packets according to their message type. +func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) { + switch p := p.(type) { + case *v5wire.Unknown: + t.handleUnknown(p, fromID, fromAddr) + case *v5wire.Whoareyou: + t.handleWhoareyou(p, fromID, fromAddr) + case *v5wire.Ping: + t.handlePing(p, fromID, fromAddr) + case *v5wire.Pong: + if t.handleCallResponse(fromID, fromAddr, p) { + t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}) + } + case *v5wire.Findnode: + t.handleFindnode(p, fromID, fromAddr) + case *v5wire.Nodes: + t.handleCallResponse(fromID, fromAddr, p) + case *v5wire.TalkRequest: + t.handleTalkRequest(p, fromID, fromAddr) + case *v5wire.TalkResponse: + t.handleCallResponse(fromID, fromAddr, p) + } +} -func (p *unknownV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - challenge := &whoareyouV5{AuthTag: p.AuthTag} +// handleUnknown initiates a handshake by responding with WHOAREYOU. 
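+//
+// The resulting exchange, as implemented by handleUnknown and handleWhoareyou
+// (a sketch; A is the remote initiator, B is this node):
+//
+//	A -> B  packet with an unreadable nonce (decodes as *v5wire.Unknown)
+//	A <- B  WHOAREYOU {Nonce, IDNonce, RecordSeq}
+//	A -> B  the original request, re-sent as a handshake packet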
+func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr *net.UDPAddr) { + challenge := &v5wire.Whoareyou{Nonce: p.Nonce} crand.Read(challenge.IDNonce[:]) if n := t.getNode(fromID); n != nil { - challenge.node = n + challenge.Node = n challenge.RecordSeq = n.Seq() } t.sendResponse(fromID, fromAddr, challenge) } -// WHOAREYOU - -func (p *whoareyouV5) name() string { return "WHOAREYOU/v5" } -func (p *whoareyouV5) kind() byte { return p_whoareyouV5 } -func (p *whoareyouV5) setreqid(id []byte) {} +var ( + errChallengeNoCall = errors.New("no matching call") + errChallengeTwice = errors.New("second handshake") +) -func (p *whoareyouV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - c, err := p.matchWithCall(t, p.AuthTag) +// handleWhoareyou resends the active call as a handshake packet. +func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr *net.UDPAddr) { + c, err := t.matchWithCall(fromID, p.Nonce) if err != nil { - t.log.Debug("Invalid WHOAREYOU/v5", "addr", fromAddr, "err", err) + t.log.Debug("Invalid "+p.Name(), "addr", fromAddr, "err", err) return } + // Resend the call that was answered by WHOAREYOU. - t.log.Trace("<< "+p.name(), "id", c.node.ID(), "addr", fromAddr) + t.log.Trace("<< "+p.Name(), "id", c.node.ID(), "addr", fromAddr) c.handshakeCount++ c.challenge = p - p.node = c.node + p.Node = c.node t.sendCall(c) } -var ( - errChallengeNoCall = errors.New("no matching call") - errChallengeTwice = errors.New("second handshake") -) - -// matchWithCall checks whether the handshake attempt matches the active call. -func (p *whoareyouV5) matchWithCall(t *UDPv5, authTag []byte) (*callV5, error) { - c := t.activeCallByAuth[string(authTag)] +// matchWithCall checks whether a handshake attempt matches the active call. +func (t *UDPv5) matchWithCall(fromID enode.ID, nonce v5wire.Nonce) (*callV5, error) { + c := t.activeCallByAuth[nonce] if c == nil { return nil, errChallengeNoCall } @@ -717,14 +758,9 @@ func (p *whoareyouV5) matchWithCall(t *UDPv5, authTag []byte) (*callV5, error) { return c, nil } -// PING - -func (p *pingV5) name() string { return "PING/v5" } -func (p *pingV5) kind() byte { return p_pingV5 } -func (p *pingV5) setreqid(id []byte) { p.ReqID = id } - -func (p *pingV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - t.sendResponse(fromID, fromAddr, &pongV5{ +// handlePing sends a PONG response. +func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr *net.UDPAddr) { + t.sendResponse(fromID, fromAddr, &v5wire.Pong{ ReqID: p.ReqID, ToIP: fromAddr.IP, ToPort: uint16(fromAddr.Port), @@ -732,121 +768,81 @@ func (p *pingV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { }) } -// PONG - -func (p *pongV5) name() string { return "PONG/v5" } -func (p *pongV5) kind() byte { return p_pongV5 } -func (p *pongV5) setreqid(id []byte) { p.ReqID = id } - -func (p *pongV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}) - t.handleCallResponse(fromID, fromAddr, p.ReqID, p) +// handleFindnode returns nodes to the requester. +func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr *net.UDPAddr) { + nodes := t.collectTableNodes(fromAddr.IP, p.Distances, findnodeResultLimit) + for _, resp := range packNodes(p.ReqID, nodes) { + t.sendResponse(fromID, fromAddr, resp) + } } -// FINDNODE +// collectTableNodes creates a FINDNODE result set for the given distances. 
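+//
+// Duplicate and out-of-range distances are skipped, distance 0 yields the local
+// node record, and the result is capped at 'limit' entries. For example, per the
+// handler test, distances [249 248] with 4 nodes in bucket 249 and 10 in bucket
+// 248 return all 14 nodes, since 14 < findnodeResultLimit (16).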
+func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node { + var nodes []*enode.Node + var processed = make(map[uint]struct{}) + for _, dist := range distances { + // Reject duplicate / invalid distances. + _, seen := processed[dist] + if seen || dist > 256 { + continue + } -func (p *findnodeV5) name() string { return "FINDNODE/v5" } -func (p *findnodeV5) kind() byte { return p_findnodeV5 } -func (p *findnodeV5) setreqid(id []byte) { p.ReqID = id } + // Get the nodes. + var bn []*enode.Node + if dist == 0 { + bn = []*enode.Node{t.Self()} + } else if dist <= 256 { + t.tab.mutex.Lock() + bn = unwrapNodes(t.tab.bucketAtDistance(int(dist)).entries) + t.tab.mutex.Unlock() + } + processed[dist] = struct{}{} -func (p *findnodeV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - if p.Distance == 0 { - t.sendNodes(fromID, fromAddr, p.ReqID, []*enode.Node{t.Self()}) - return - } - if p.Distance > 256 { - p.Distance = 256 - } - // Get bucket entries. - t.tab.mutex.Lock() - nodes := unwrapNodes(t.tab.bucketAtDistance(int(p.Distance)).entries) - t.tab.mutex.Unlock() - if len(nodes) > findnodeResultLimit { - nodes = nodes[:findnodeResultLimit] + // Apply some pre-checks to avoid sending invalid nodes. + for _, n := range bn { + // TODO livenessChecks > 1 + if netutil.CheckRelayIP(rip, n.IP()) != nil { + continue + } + nodes = append(nodes, n) + if len(nodes) >= limit { + return nodes + } + } } - t.sendNodes(fromID, fromAddr, p.ReqID, nodes) + return nodes } -// sendNodes sends the given records in one or more NODES packets. -func (t *UDPv5) sendNodes(toID enode.ID, toAddr *net.UDPAddr, reqid []byte, nodes []*enode.Node) { - // TODO livenessChecks > 1 - // TODO CheckRelayIP +// packNodes creates NODES response packets for the given node list. +func packNodes(reqid []byte, nodes []*enode.Node) []*v5wire.Nodes { + if len(nodes) == 0 { + return []*v5wire.Nodes{{ReqID: reqid, Total: 1}} + } + total := uint8(math.Ceil(float64(len(nodes)) / 3)) - resp := &nodesV5{ReqID: reqid, Total: total, Nodes: make([]*enr.Record, 3)} - sent := false + var resp []*v5wire.Nodes for len(nodes) > 0 { + p := &v5wire.Nodes{ReqID: reqid, Total: total} items := min(nodesResponseItemLimit, len(nodes)) - resp.Nodes = resp.Nodes[:items] for i := 0; i < items; i++ { - resp.Nodes[i] = nodes[i].Record() + p.Nodes = append(p.Nodes, nodes[i].Record()) } - t.sendResponse(toID, toAddr, resp) nodes = nodes[items:] - sent = true - } - // Ensure at least one response is sent. 
- if !sent { - resp.Total = 1 - resp.Nodes = nil - t.sendResponse(toID, toAddr, resp) + resp = append(resp, p) } + return resp } -// NODES - -func (p *nodesV5) name() string { return "NODES/v5" } -func (p *nodesV5) kind() byte { return p_nodesV5 } -func (p *nodesV5) setreqid(id []byte) { p.ReqID = id } - -func (p *nodesV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - t.handleCallResponse(fromID, fromAddr, p.ReqID, p) -} - -// REQUESTTICKET - -func (p *requestTicketV5) name() string { return "REQUESTTICKET/v5" } -func (p *requestTicketV5) kind() byte { return p_requestTicketV5 } -func (p *requestTicketV5) setreqid(id []byte) { p.ReqID = id } - -func (p *requestTicketV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - t.sendResponse(fromID, fromAddr, &ticketV5{ReqID: p.ReqID}) -} - -// TICKET - -func (p *ticketV5) name() string { return "TICKET/v5" } -func (p *ticketV5) kind() byte { return p_ticketV5 } -func (p *ticketV5) setreqid(id []byte) { p.ReqID = id } - -func (p *ticketV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - t.handleCallResponse(fromID, fromAddr, p.ReqID, p) -} - -// REGTOPIC - -func (p *regtopicV5) name() string { return "REGTOPIC/v5" } -func (p *regtopicV5) kind() byte { return p_regtopicV5 } -func (p *regtopicV5) setreqid(id []byte) { p.ReqID = id } - -func (p *regtopicV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - t.sendResponse(fromID, fromAddr, ®confirmationV5{ReqID: p.ReqID, Registered: false}) -} - -// REGCONFIRMATION +// handleTalkRequest runs the talk request handler of the requested protocol. +func (t *UDPv5) handleTalkRequest(p *v5wire.TalkRequest, fromID enode.ID, fromAddr *net.UDPAddr) { + t.trlock.Lock() + handler := t.trhandlers[p.Protocol] + t.trlock.Unlock() -func (p *regconfirmationV5) name() string { return "REGCONFIRMATION/v5" } -func (p *regconfirmationV5) kind() byte { return p_regconfirmationV5 } -func (p *regconfirmationV5) setreqid(id []byte) { p.ReqID = id } - -func (p *regconfirmationV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { - t.handleCallResponse(fromID, fromAddr, p.ReqID, p) -} - -// TOPICQUERY - -func (p *topicqueryV5) name() string { return "TOPICQUERY/v5" } -func (p *topicqueryV5) kind() byte { return p_topicqueryV5 } -func (p *topicqueryV5) setreqid(id []byte) { p.ReqID = id } - -func (p *topicqueryV5) handle(t *UDPv5, fromID enode.ID, fromAddr *net.UDPAddr) { + var response []byte + if handler != nil { + response = handler(p.Message) + } + resp := &v5wire.TalkResponse{ReqID: p.ReqID, Message: response} + t.sendResponse(fromID, fromAddr, resp) } diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index 7d3915e2dceb..d91a2097db25 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -24,22 +24,25 @@ import ( "math/rand" "net" "reflect" + "sort" "testing" "time" "github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover/v5wire" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/rlp" ) // Real sockets, real crypto: this test checks end-to-end connectivity for UDPv5. 
-func TestEndToEndV5(t *testing.T) { +func TestUDPv5_lookupE2E(t *testing.T) { t.Parallel() + const N = 5 var nodes []*UDPv5 - for i := 0; i < 5; i++ { + for i := 0; i < N; i++ { var cfg Config if len(nodes) > 0 { bn := nodes[0].Self() @@ -49,12 +52,22 @@ func TestEndToEndV5(t *testing.T) { nodes = append(nodes, node) defer node.Close() } + last := nodes[N-1] + target := nodes[rand.Intn(N-2)].Self() - last := nodes[len(nodes)-1] - target := nodes[rand.Intn(len(nodes)-2)].Self() + // It is expected that all nodes can be found. + expectedResult := make([]*enode.Node, len(nodes)) + for i := range nodes { + expectedResult[i] = nodes[i].Self() + } + sort.Slice(expectedResult, func(i, j int) bool { + return enode.DistCmp(target.ID(), expectedResult[i].ID(), expectedResult[j].ID()) < 0 + }) + + // Do the lookup. results := last.Lookup(target.ID()) - if len(results) == 0 || results[0].ID() != target.ID() { - t.Fatalf("lookup returned wrong results: %v", results) + if err := checkNodesEqual(results, expectedResult); err != nil { + t.Fatalf("lookup returned wrong results: %v", err) } } @@ -93,8 +106,8 @@ func TestUDPv5_pingHandling(t *testing.T) { test := newUDPV5Test(t) defer test.close() - test.packetIn(&pingV5{ReqID: []byte("foo")}) - test.waitPacketOut(func(p *pongV5, addr *net.UDPAddr, authTag []byte) { + test.packetIn(&v5wire.Ping{ReqID: []byte("foo")}) + test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, []byte("foo")) { t.Error("wrong request ID in response:", p.ReqID) } @@ -110,13 +123,13 @@ func TestUDPv5_unknownPacket(t *testing.T) { test := newUDPV5Test(t) defer test.close() - authTag := [12]byte{1, 2, 3} - check := func(p *whoareyouV5, wantSeq uint64) { + nonce := v5wire.Nonce{1, 2, 3} + check := func(p *v5wire.Whoareyou, wantSeq uint64) { t.Helper() - if !bytes.Equal(p.AuthTag, authTag[:]) { - t.Error("wrong token in WHOAREYOU:", p.AuthTag, authTag[:]) + if p.Nonce != nonce { + t.Error("wrong nonce in WHOAREYOU:", p.Nonce, nonce) } - if p.IDNonce == ([32]byte{}) { + if p.IDNonce == ([16]byte{}) { t.Error("all zero ID nonce") } if p.RecordSeq != wantSeq { @@ -125,8 +138,8 @@ func TestUDPv5_unknownPacket(t *testing.T) { } // Unknown packet from unknown node. - test.packetIn(&unknownV5{AuthTag: authTag[:]}) - test.waitPacketOut(func(p *whoareyouV5, addr *net.UDPAddr, _ []byte) { + test.packetIn(&v5wire.Unknown{Nonce: nonce}) + test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { check(p, 0) }) @@ -134,8 +147,8 @@ func TestUDPv5_unknownPacket(t *testing.T) { n := test.getNode(test.remotekey, test.remoteaddr).Node() test.table.addSeenNode(wrapNode(n)) - test.packetIn(&unknownV5{AuthTag: authTag[:]}) - test.waitPacketOut(func(p *whoareyouV5, addr *net.UDPAddr, _ []byte) { + test.packetIn(&v5wire.Unknown{Nonce: nonce}) + test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { check(p, n.Seq()) }) } @@ -147,24 +160,40 @@ func TestUDPv5_findnodeHandling(t *testing.T) { defer test.close() // Create test nodes and insert them into the table. 
- nodes := nodesAtDistance(test.table.self().ID(), 253, 10) - fillTable(test.table, wrapNodes(nodes)) + nodes253 := nodesAtDistance(test.table.self().ID(), 253, 10) + nodes249 := nodesAtDistance(test.table.self().ID(), 249, 4) + nodes248 := nodesAtDistance(test.table.self().ID(), 248, 10) + fillTable(test.table, wrapNodes(nodes253)) + fillTable(test.table, wrapNodes(nodes249)) + fillTable(test.table, wrapNodes(nodes248)) // Requesting with distance zero should return the node's own record. - test.packetIn(&findnodeV5{ReqID: []byte{0}, Distance: 0}) + test.packetIn(&v5wire.Findnode{ReqID: []byte{0}, Distances: []uint{0}}) test.expectNodes([]byte{0}, 1, []*enode.Node{test.udp.Self()}) - // Requesting with distance > 256 caps it at 256. - test.packetIn(&findnodeV5{ReqID: []byte{1}, Distance: 4234098}) + // Requesting with distance > 256 shouldn't crash. + test.packetIn(&v5wire.Findnode{ReqID: []byte{1}, Distances: []uint{4234098}}) test.expectNodes([]byte{1}, 1, nil) - // This request gets no nodes because the corresponding bucket is empty. - test.packetIn(&findnodeV5{ReqID: []byte{2}, Distance: 254}) + // Requesting with empty distance list shouldn't crash either. + test.packetIn(&v5wire.Findnode{ReqID: []byte{2}, Distances: []uint{}}) test.expectNodes([]byte{2}, 1, nil) - // This request gets all test nodes. - test.packetIn(&findnodeV5{ReqID: []byte{3}, Distance: 253}) - test.expectNodes([]byte{3}, 4, nodes) + // This request gets no nodes because the corresponding bucket is empty. + test.packetIn(&v5wire.Findnode{ReqID: []byte{3}, Distances: []uint{254}}) + test.expectNodes([]byte{3}, 1, nil) + + // This request gets all the distance-253 nodes. + test.packetIn(&v5wire.Findnode{ReqID: []byte{4}, Distances: []uint{253}}) + test.expectNodes([]byte{4}, 4, nodes253) + + // This request gets all the distance-249 nodes and some more at 248 because + // the bucket at 249 is not full. + test.packetIn(&v5wire.Findnode{ReqID: []byte{5}, Distances: []uint{249, 248}}) + var nodes []*enode.Node + nodes = append(nodes, nodes249...) + nodes = append(nodes, nodes248[:10]...) 
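+	// 4 + 10 = 14 records stay within findnodeResultLimit (16) and are delivered
+	// in ceil(14/3) = 5 NODES packets of at most nodesResponseItemLimit records each.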
+ test.expectNodes([]byte{5}, 5, nodes) } func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes []*enode.Node) { @@ -172,16 +201,17 @@ func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes for _, n := range wantNodes { nodeSet[n.ID()] = n.Record() } + for { - test.waitPacketOut(func(p *nodesV5, addr *net.UDPAddr, authTag []byte) { + test.waitPacketOut(func(p *v5wire.Nodes, addr *net.UDPAddr, _ v5wire.Nonce) { + if !bytes.Equal(p.ReqID, wantReqID) { + test.t.Fatalf("wrong request ID %v in response, want %v", p.ReqID, wantReqID) + } if len(p.Nodes) > 3 { test.t.Fatalf("too many nodes in response") } if p.Total != wantTotal { - test.t.Fatalf("wrong total response count %d", p.Total) - } - if !bytes.Equal(p.ReqID, wantReqID) { - test.t.Fatalf("wrong request ID in response: %v", p.ReqID) + test.t.Fatalf("wrong total response count %d, want %d", p.Total, wantTotal) } for _, record := range p.Nodes { n, _ := enode.New(enode.ValidSchemesForTesting, record) @@ -215,7 +245,7 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) {}) + test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {}) if err := <-done; err != errTimeout { t.Fatalf("want errTimeout, got %q", err) } @@ -225,8 +255,8 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { - test.packetInFrom(test.remotekey, test.remoteaddr, &pongV5{ReqID: p.ReqID}) + test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != nil { t.Fatal(err) @@ -237,9 +267,9 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { + test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101} - test.packetInFrom(test.remotekey, wrongAddr, &pongV5{ReqID: p.ReqID}) + test.packetInFrom(test.remotekey, wrongAddr, &v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != errTimeout { t.Fatalf("want errTimeout for reply from wrong IP, got %q", err) @@ -255,29 +285,29 @@ func TestUDPv5_findnodeCall(t *testing.T) { // Launch the request: var ( - distance = 230 - remote = test.getNode(test.remotekey, test.remoteaddr).Node() - nodes = nodesAtDistance(remote.ID(), distance, 8) - done = make(chan error, 1) - response []*enode.Node + distances = []uint{230} + remote = test.getNode(test.remotekey, test.remoteaddr).Node() + nodes = nodesAtDistance(remote.ID(), int(distances[0]), 8) + done = make(chan error, 1) + response []*enode.Node ) go func() { var err error - response, err = test.udp.findnode(remote, distance) + response, err = test.udp.findnode(remote, distances) done <- err }() // Serve the responses: - test.waitPacketOut(func(p *findnodeV5, addr *net.UDPAddr, authTag []byte) { - if p.Distance != uint(distance) { - t.Fatalf("wrong bucket: %d", p.Distance) + test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) { + if !reflect.DeepEqual(p.Distances, distances) { + t.Fatalf("wrong distances in request: %v", p.Distances) } - test.packetIn(&nodesV5{ + test.packetIn(&v5wire.Nodes{ ReqID: p.ReqID, Total: 2, Nodes: nodesToRecords(nodes[:4]), }) - 
test.packetIn(&nodesV5{ + test.packetIn(&v5wire.Nodes{ ReqID: p.ReqID, Total: 2, Nodes: nodesToRecords(nodes[4:]), @@ -314,16 +344,16 @@ func TestUDPv5_callResend(t *testing.T) { }() // Ping answered by WHOAREYOU. - test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { - test.packetIn(&whoareyouV5{AuthTag: authTag}) + test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) // Ping should be re-sent. - test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { - test.packetIn(&pongV5{ReqID: p.ReqID}) + test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.packetIn(&v5wire.Pong{ReqID: p.ReqID}) }) // Answer the other ping. - test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { - test.packetIn(&pongV5{ReqID: p.ReqID}) + test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.packetIn(&v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != nil { t.Fatalf("unexpected ping error: %v", err) @@ -347,12 +377,12 @@ func TestUDPv5_multipleHandshakeRounds(t *testing.T) { }() // Ping answered by WHOAREYOU. - test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { - test.packetIn(&whoareyouV5{AuthTag: authTag}) + test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) // Ping answered by WHOAREYOU again. - test.waitPacketOut(func(p *pingV5, addr *net.UDPAddr, authTag []byte) { - test.packetIn(&whoareyouV5{AuthTag: authTag}) + test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) if err := <-done; err != errTimeout { t.Fatalf("unexpected ping error: %q", err) @@ -367,27 +397,27 @@ func TestUDPv5_callTimeoutReset(t *testing.T) { // Launch the request: var ( - distance = 230 + distance = uint(230) remote = test.getNode(test.remotekey, test.remoteaddr).Node() - nodes = nodesAtDistance(remote.ID(), distance, 8) + nodes = nodesAtDistance(remote.ID(), int(distance), 8) done = make(chan error, 1) ) go func() { - _, err := test.udp.findnode(remote, distance) + _, err := test.udp.findnode(remote, []uint{distance}) done <- err }() // Serve two responses, slowly. - test.waitPacketOut(func(p *findnodeV5, addr *net.UDPAddr, authTag []byte) { + test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) { time.Sleep(respTimeout - 50*time.Millisecond) - test.packetIn(&nodesV5{ + test.packetIn(&v5wire.Nodes{ ReqID: p.ReqID, Total: 2, Nodes: nodesToRecords(nodes[:4]), }) time.Sleep(respTimeout - 50*time.Millisecond) - test.packetIn(&nodesV5{ + test.packetIn(&v5wire.Nodes{ ReqID: p.ReqID, Total: 2, Nodes: nodesToRecords(nodes[4:]), @@ -398,6 +428,97 @@ func TestUDPv5_callTimeoutReset(t *testing.T) { } } +// This test checks that TALKREQ calls the registered handler function. 
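+// For reference, the application-level flow under test looks roughly like this
+// (a sketch with a hypothetical protocol name "echo"; a and b are two running
+// *UDPv5 instances):
+//
+//	a.RegisterTalkHandler("echo", func(msg []byte) []byte { return msg })
+//	resp, err := b.TalkRequest(a.Self(), "echo", []byte("hi")) // resp == "hi" on success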
+func TestUDPv5_talkHandling(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + var recvMessage []byte + test.udp.RegisterTalkHandler("test", func(message []byte) []byte { + recvMessage = message + return []byte("test response") + }) + + // Successful case: + test.packetIn(&v5wire.TalkRequest{ + ReqID: []byte("foo"), + Protocol: "test", + Message: []byte("test request"), + }) + test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) { + if !bytes.Equal(p.ReqID, []byte("foo")) { + t.Error("wrong request ID in response:", p.ReqID) + } + if string(p.Message) != "test response" { + t.Errorf("wrong talk response message: %q", p.Message) + } + if string(recvMessage) != "test request" { + t.Errorf("wrong message received in handler: %q", recvMessage) + } + }) + + // Check that empty response is returned for unregistered protocols. + recvMessage = nil + test.packetIn(&v5wire.TalkRequest{ + ReqID: []byte("2"), + Protocol: "wrong", + Message: []byte("test request"), + }) + test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) { + if !bytes.Equal(p.ReqID, []byte("2")) { + t.Error("wrong request ID in response:", p.ReqID) + } + if string(p.Message) != "" { + t.Errorf("wrong talk response message: %q", p.Message) + } + if recvMessage != nil { + t.Errorf("handler was called for wrong protocol: %q", recvMessage) + } + }) +} + +// This test checks that outgoing TALKREQ calls work. +func TestUDPv5_talkRequest(t *testing.T) { + t.Parallel() + test := newUDPV5Test(t) + defer test.close() + + remote := test.getNode(test.remotekey, test.remoteaddr).Node() + done := make(chan error, 1) + + // This request times out. + go func() { + _, err := test.udp.TalkRequest(remote, "test", []byte("test request")) + done <- err + }() + test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {}) + if err := <-done; err != errTimeout { + t.Fatalf("want errTimeout, got %q", err) + } + + // This request works. + go func() { + _, err := test.udp.TalkRequest(remote, "test", []byte("test request")) + done <- err + }() + test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) { + if p.Protocol != "test" { + t.Errorf("wrong protocol ID in talk request: %q", p.Protocol) + } + if string(p.Message) != "test request" { + t.Errorf("wrong message talk request: %q", p.Message) + } + test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.TalkResponse{ + ReqID: p.ReqID, + Message: []byte("test response"), + }) + }) + if err := <-done; err != nil { + t.Fatal(err) + } +} + // This test checks that lookup works. func TestUDPv5_lookup(t *testing.T) { t.Parallel() @@ -417,7 +538,8 @@ func TestUDPv5_lookup(t *testing.T) { } // Seed table with initial node. - fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}) + initialNode := lookupTestnet.node(256, 0) + fillTable(test.table, []*node{wrapNode(initialNode)}) // Start the lookup. resultC := make(chan []*enode.Node, 1) @@ -427,22 +549,30 @@ func TestUDPv5_lookup(t *testing.T) { }() // Answer lookup packets. 
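+	// Each testnet node may be asked via FINDNODE at most once, and answers are
+	// packed with packNodes so multi-packet NODES responses are exercised as well.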
+ asked := make(map[enode.ID]bool) for done := false; !done; { - done = test.waitPacketOut(func(p packetV5, to *net.UDPAddr, authTag []byte) { + done = test.waitPacketOut(func(p v5wire.Packet, to *net.UDPAddr, _ v5wire.Nonce) { recipient, key := lookupTestnet.nodeByAddr(to) switch p := p.(type) { - case *pingV5: - test.packetInFrom(key, to, &pongV5{ReqID: p.ReqID}) - case *findnodeV5: - nodes := lookupTestnet.neighborsAtDistance(recipient, p.Distance, 3) - response := &nodesV5{ReqID: p.ReqID, Total: 1, Nodes: nodesToRecords(nodes)} - test.packetInFrom(key, to, response) + case *v5wire.Ping: + test.packetInFrom(key, to, &v5wire.Pong{ReqID: p.ReqID}) + case *v5wire.Findnode: + if asked[recipient.ID()] { + t.Error("Asked node", recipient.ID(), "twice") + } + asked[recipient.ID()] = true + nodes := lookupTestnet.neighborsAtDistances(recipient, p.Distances, 16) + t.Logf("Got FINDNODE for %v, returning %d nodes", p.Distances, len(nodes)) + for _, resp := range packNodes(p.ReqID, nodes) { + test.packetInFrom(key, to, resp) + } } }) } // Verify result nodes. - checkLookupResults(t, lookupTestnet, <-resultC) + results := <-resultC + checkLookupResults(t, lookupTestnet, results) } // This test checks the local node can be utilised to set key-values. @@ -481,6 +611,7 @@ type udpV5Test struct { nodesByIP map[string]*enode.LocalNode } +// testCodec is the packet encoding used by protocol tests. This codec does not perform encryption. type testCodec struct { test *udpV5Test id enode.ID @@ -489,46 +620,44 @@ type testCodec struct { type testCodecFrame struct { NodeID enode.ID - AuthTag []byte + AuthTag v5wire.Nonce Ptype byte Packet rlp.RawValue } -func (c *testCodec) encode(toID enode.ID, addr string, p packetV5, _ *whoareyouV5) ([]byte, []byte, error) { +func (c *testCodec) Encode(toID enode.ID, addr string, p v5wire.Packet, _ *v5wire.Whoareyou) ([]byte, v5wire.Nonce, error) { c.ctr++ - authTag := make([]byte, 8) - binary.BigEndian.PutUint64(authTag, c.ctr) + var authTag v5wire.Nonce + binary.BigEndian.PutUint64(authTag[:], c.ctr) + penc, _ := rlp.EncodeToBytes(p) - frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.kind(), penc}) + frame, err := rlp.EncodeToBytes(testCodecFrame{c.id, authTag, p.Kind(), penc}) return frame, authTag, err } -func (c *testCodec) decode(input []byte, addr string) (enode.ID, *enode.Node, packetV5, error) { +func (c *testCodec) Decode(input []byte, addr string) (enode.ID, *enode.Node, v5wire.Packet, error) { frame, p, err := c.decodeFrame(input) if err != nil { return enode.ID{}, nil, nil, err } - if p.kind() == p_whoareyouV5 { - frame.NodeID = enode.ID{} // match wireCodec behavior - } return frame.NodeID, nil, p, nil } -func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p packetV5, err error) { +func (c *testCodec) decodeFrame(input []byte) (frame testCodecFrame, p v5wire.Packet, err error) { if err = rlp.DecodeBytes(input, &frame); err != nil { return frame, nil, fmt.Errorf("invalid frame: %v", err) } switch frame.Ptype { - case p_unknownV5: - dec := new(unknownV5) + case v5wire.UnknownPacket: + dec := new(v5wire.Unknown) err = rlp.DecodeBytes(frame.Packet, &dec) p = dec - case p_whoareyouV5: - dec := new(whoareyouV5) + case v5wire.WhoareyouPacket: + dec := new(v5wire.Whoareyou) err = rlp.DecodeBytes(frame.Packet, &dec) p = dec default: - p, err = decodePacketBodyV5(frame.Ptype, frame.Packet) + p, err = v5wire.DecodeMessage(frame.Ptype, frame.Packet) } return frame, p, err } @@ -561,20 +690,20 @@ func newUDPV5Test(t *testing.T) *udpV5Test { 
}

// handles a packet as if it had been sent to the transport.
-func (test *udpV5Test) packetIn(packet packetV5) {
+func (test *udpV5Test) packetIn(packet v5wire.Packet) {
 	test.t.Helper()
 	test.packetInFrom(test.remotekey, test.remoteaddr, packet)
 }
 
 // handles a packet as if it had been sent to the transport by the key/endpoint.
-func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet packetV5) {
+func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet v5wire.Packet) {
 	test.t.Helper()
 
 	ln := test.getNode(key, addr)
 	codec := &testCodec{test: test, id: ln.ID()}
-	enc, _, err := codec.encode(test.udp.Self().ID(), addr.String(), packet, nil)
+	enc, _, err := codec.Encode(test.udp.Self().ID(), addr.String(), packet, nil)
 	if err != nil {
-		test.t.Errorf("%s encode error: %v", packet.name(), err)
+		test.t.Errorf("%s encode error: %v", packet.Name(), err)
 	}
 	if test.udp.dispatchReadPacket(addr, enc) {
 		<-test.udp.readNextCh // unblock UDPv5.dispatch
@@ -596,8 +725,12 @@ func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr *net.UDPAddr) *enode.
 	return ln
 }
 
+// waitPacketOut waits for the next output packet and handles it using the given 'validate'
+// function. The function must be of type func (X, *net.UDPAddr, v5wire.Nonce) where X is
+// assignable to v5wire.Packet.
 func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
 	test.t.Helper()
+
 	fn := reflect.ValueOf(validate)
 	exptype := fn.Type().In(0)
diff --git a/p2p/discover/v5wire/crypto.go b/p2p/discover/v5wire/crypto.go
new file mode 100644
index 000000000000..fc0a0edef594
--- /dev/null
+++ b/p2p/discover/v5wire/crypto.go
@@ -0,0 +1,180 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package v5wire
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"errors"
+	"fmt"
+	"hash"
+
+	"github.com/ethereum/go-ethereum/common/math"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"golang.org/x/crypto/hkdf"
+)
+
+const (
+	// Encryption/authentication parameters.
+	aesKeySize   = 16
+	gcmNonceSize = 12
+)
+
+// Nonce represents a nonce used for AES/GCM.
+type Nonce [gcmNonceSize]byte
+
+// EncodePubkey encodes a public key.
+func EncodePubkey(key *ecdsa.PublicKey) []byte {
+	switch key.Curve {
+	case crypto.S256():
+		return crypto.CompressPubkey(key)
+	default:
+		panic("unsupported curve " + key.Curve.Params().Name + " in EncodePubkey")
+	}
+}
+
+// DecodePubkey decodes a public key in compressed format.
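+// Only secp256k1 is supported and the input must be 33 bytes. DecodePubkey is
+// the inverse of EncodePubkey; a round-trip sketch, where key is any secp256k1
+// *ecdsa.PrivateKey:
+//
+//	pub, err := DecodePubkey(crypto.S256(), EncodePubkey(&key.PublicKey))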
+func DecodePubkey(curve elliptic.Curve, e []byte) (*ecdsa.PublicKey, error) { + switch curve { + case crypto.S256(): + if len(e) != 33 { + return nil, errors.New("wrong size public key data") + } + return crypto.DecompressPubkey(e) + default: + return nil, fmt.Errorf("unsupported curve %s in DecodePubkey", curve.Params().Name) + } +} + +// idNonceHash computes the ID signature hash used in the handshake. +func idNonceHash(h hash.Hash, challenge, ephkey []byte, destID enode.ID) []byte { + h.Reset() + h.Write([]byte("discovery v5 identity proof")) + h.Write(challenge) + h.Write(ephkey) + h.Write(destID[:]) + return h.Sum(nil) +} + +// makeIDSignature creates the ID nonce signature. +func makeIDSignature(hash hash.Hash, key *ecdsa.PrivateKey, challenge, ephkey []byte, destID enode.ID) ([]byte, error) { + input := idNonceHash(hash, challenge, ephkey, destID) + switch key.Curve { + case crypto.S256(): + idsig, err := crypto.Sign(input, key) + if err != nil { + return nil, err + } + return idsig[:len(idsig)-1], nil // remove recovery ID + default: + return nil, fmt.Errorf("unsupported curve %s", key.Curve.Params().Name) + } +} + +// s256raw is an unparsed secp256k1 public key ENR entry. +type s256raw []byte + +func (s256raw) ENRKey() string { return "secp256k1" } + +// verifyIDSignature checks that signature over idnonce was made by the given node. +func verifyIDSignature(hash hash.Hash, sig []byte, n *enode.Node, challenge, ephkey []byte, destID enode.ID) error { + switch idscheme := n.Record().IdentityScheme(); idscheme { + case "v4": + var pubkey s256raw + if n.Load(&pubkey) != nil { + return errors.New("no secp256k1 public key in record") + } + input := idNonceHash(hash, challenge, ephkey, destID) + if !crypto.VerifySignature(pubkey, input, sig) { + return errInvalidNonceSig + } + return nil + default: + return fmt.Errorf("can't verify ID nonce signature against scheme %q", idscheme) + } +} + +type hashFn func() hash.Hash + +// deriveKeys creates the session keys. +func deriveKeys(hash hashFn, priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, n1, n2 enode.ID, challenge []byte) *session { + const text = "discovery v5 key agreement" + var info = make([]byte, 0, len(text)+len(n1)+len(n2)) + info = append(info, text...) + info = append(info, n1[:]...) + info = append(info, n2[:]...) + + eph := ecdh(priv, pub) + if eph == nil { + return nil + } + kdf := hkdf.New(hash, eph, challenge, info) + sec := session{writeKey: make([]byte, aesKeySize), readKey: make([]byte, aesKeySize)} + kdf.Read(sec.writeKey) + kdf.Read(sec.readKey) + for i := range eph { + eph[i] = 0 + } + return &sec +} + +// ecdh creates a shared secret. +func ecdh(privkey *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey) []byte { + secX, secY := pubkey.ScalarMult(pubkey.X, pubkey.Y, privkey.D.Bytes()) + if secX == nil { + return nil + } + sec := make([]byte, 33) + sec[0] = 0x02 | byte(secY.Bit(0)) + math.ReadBits(secX, sec[1:]) + return sec +} + +// encryptGCM encrypts pt using AES-GCM with the given key and nonce. The ciphertext is +// appended to dest, which must not overlap with plaintext. The resulting ciphertext is 16 +// bytes longer than plaintext because it contains an authentication tag. 
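+//
+// A round-trip sketch (key is a 16-byte session key, nonce a 12-byte GCM nonce,
+// both assumed to be in scope):
+//
+//	ct, _ := encryptGCM(nil, key, nonce, msg, authData)
+//	pt, err := decryptGCM(key, nonce, ct, authData) // pt equals msg when err == nil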
+func encryptGCM(dest, key, nonce, plaintext, authData []byte) ([]byte, error) {
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		panic(fmt.Errorf("can't create block cipher: %v", err))
+	}
+	aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize)
+	if err != nil {
+		panic(fmt.Errorf("can't create GCM: %v", err))
+	}
+	return aesgcm.Seal(dest, nonce, plaintext, authData), nil
+}
+
+// decryptGCM decrypts ct using AES-GCM with the given key and nonce.
+func decryptGCM(key, nonce, ct, authData []byte) ([]byte, error) {
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, fmt.Errorf("can't create block cipher: %v", err)
+	}
+	if len(nonce) != gcmNonceSize {
+		return nil, fmt.Errorf("invalid GCM nonce size: %d", len(nonce))
+	}
+	aesgcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize)
+	if err != nil {
+		return nil, fmt.Errorf("can't create GCM: %v", err)
+	}
+	pt := make([]byte, 0, len(ct))
+	return aesgcm.Open(pt, nonce, ct, authData)
+}
diff --git a/p2p/discover/v5wire/crypto_test.go b/p2p/discover/v5wire/crypto_test.go
new file mode 100644
index 000000000000..72169b43141a
--- /dev/null
+++ b/p2p/discover/v5wire/crypto_test.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+ +package v5wire + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/sha256" + "reflect" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +func TestVector_ECDH(t *testing.T) { + var ( + staticKey = hexPrivkey("0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736") + publicKey = hexPubkey(crypto.S256(), "0x039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231") + want = hexutil.MustDecode("0x033b11a2a1f214567e1537ce5e509ffd9b21373247f2a3ff6841f4976f53165e7e") + ) + result := ecdh(staticKey, publicKey) + check(t, "shared-secret", result, want) +} + +func TestVector_KDF(t *testing.T) { + var ( + ephKey = hexPrivkey("0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736") + cdata = hexutil.MustDecode("0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000") + net = newHandshakeTest() + ) + defer net.close() + + destKey := &testKeyB.PublicKey + s := deriveKeys(sha256.New, ephKey, destKey, net.nodeA.id(), net.nodeB.id(), cdata) + t.Logf("ephemeral-key = %#x", ephKey.D) + t.Logf("dest-pubkey = %#x", EncodePubkey(destKey)) + t.Logf("node-id-a = %#x", net.nodeA.id().Bytes()) + t.Logf("node-id-b = %#x", net.nodeB.id().Bytes()) + t.Logf("challenge-data = %#x", cdata) + check(t, "initiator-key", s.writeKey, hexutil.MustDecode("0xdccc82d81bd610f4f76d3ebe97a40571")) + check(t, "recipient-key", s.readKey, hexutil.MustDecode("0xac74bb8773749920b0d3a8881c173ec5")) +} + +func TestVector_IDSignature(t *testing.T) { + var ( + key = hexPrivkey("0xfb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736") + destID = enode.HexID("0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") + ephkey = hexutil.MustDecode("0x039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231") + cdata = hexutil.MustDecode("0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000") + ) + + sig, err := makeIDSignature(sha256.New(), key, cdata, ephkey, destID) + if err != nil { + t.Fatal(err) + } + t.Logf("static-key = %#x", key.D) + t.Logf("challenge-data = %#x", cdata) + t.Logf("ephemeral-pubkey = %#x", ephkey) + t.Logf("node-id-B = %#x", destID.Bytes()) + expected := "0x94852a1e2318c4e5e9d422c98eaf19d1d90d876b29cd06ca7cb7546d0fff7b484fe86c09a064fe72bdbef73ba8e9c34df0cd2b53e9d65528c2c7f336d5dfc6e6" + check(t, "id-signature", sig, hexutil.MustDecode(expected)) +} + +func TestDeriveKeys(t *testing.T) { + t.Parallel() + + var ( + n1 = enode.ID{1} + n2 = enode.ID{2} + cdata = []byte{1, 2, 3, 4} + ) + sec1 := deriveKeys(sha256.New, testKeyA, &testKeyB.PublicKey, n1, n2, cdata) + sec2 := deriveKeys(sha256.New, testKeyB, &testKeyA.PublicKey, n1, n2, cdata) + if sec1 == nil || sec2 == nil { + t.Fatal("key agreement failed") + } + if !reflect.DeepEqual(sec1, sec2) { + t.Fatalf("keys not equal:\n %+v\n %+v", sec1, sec2) + } +} + +func check(t *testing.T, what string, x, y []byte) { + t.Helper() + + if !bytes.Equal(x, y) { + t.Errorf("wrong %s: %#x != %#x", what, x, y) + } else { + t.Logf("%s = %#x", what, x) + } +} + +func hexPrivkey(input string) *ecdsa.PrivateKey { + key, err := crypto.HexToECDSA(strings.TrimPrefix(input, "0x")) + if err != nil { + panic(err) + } + return key +} + +func hexPubkey(curve elliptic.Curve, input string) *ecdsa.PublicKey { + key, err := DecodePubkey(curve, 
hexutil.MustDecode(input))
+	if err != nil {
+		panic(err)
+	}
+	return key
+}
diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go
new file mode 100644
index 000000000000..f502339e1e56
--- /dev/null
+++ b/p2p/discover/v5wire/encoding.go
@@ -0,0 +1,648 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package v5wire
+
+import (
+	"bytes"
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/ecdsa"
+	crand "crypto/rand"
+	"crypto/sha256"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"hash"
+
+	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// TODO concurrent WHOAREYOU tie-breaker
+// TODO rehandshake after X packets
+
+// Header represents a packet header.
+type Header struct {
+	IV [sizeofMaskingIV]byte
+	StaticHeader
+	AuthData []byte
+
+	src enode.ID // used by decoder
+}
+
+// StaticHeader contains the static fields of a packet header.
+type StaticHeader struct {
+	ProtocolID [6]byte
+	Version    uint16
+	Flag       byte
+	Nonce      Nonce
+	AuthSize   uint16
+}
+
+// Authdata layouts.
+type (
+	whoareyouAuthData struct {
+		IDNonce   [16]byte // ID proof data
+		RecordSeq uint64   // highest known ENR sequence of requester
+	}
+
+	handshakeAuthData struct {
+		h struct {
+			SrcID      enode.ID
+			SigSize    byte // size of signature data
+			PubkeySize byte // size of public key data
+		}
+		// Trailing variable-size data.
+		signature, pubkey, record []byte
+	}
+
+	messageAuthData struct {
+		SrcID enode.ID
+	}
+)
+
+// Packet header flag values.
+const (
+	flagMessage = iota
+	flagWhoareyou
+	flagHandshake
+)
+
+// Protocol constants.
+const (
+	version         = 1
+	minVersion      = 1
+	sizeofMaskingIV = 16
+
+	minMessageSize      = 48 // this refers to data after static headers
+	randomPacketMsgSize = 20
+)
+
+var protocolID = [6]byte{'d', 'i', 's', 'c', 'v', '5'}
+
+// Errors.
+var (
+	errTooShort            = errors.New("packet too short")
+	errInvalidHeader       = errors.New("invalid packet header")
+	errInvalidFlag         = errors.New("invalid flag value in header")
+	errMinVersion          = errors.New("version of packet header below minimum")
+	errMsgTooShort         = errors.New("message/handshake packet below minimum size")
+	errAuthSize            = errors.New("declared auth size is beyond packet length")
+	errUnexpectedHandshake = errors.New("unexpected auth response, not in handshake")
+	errInvalidAuthKey      = errors.New("invalid ephemeral pubkey")
+	errNoRecord            = errors.New("expected ENR in handshake but none sent")
+	errInvalidNonceSig     = errors.New("invalid ID nonce signature")
+	errMessageTooShort     = errors.New("message contains no data")
+	errMessageDecrypt      = errors.New("cannot decrypt message")
+)
+
+// Public errors.
+var (
+	ErrInvalidReqID = errors.New("request ID larger than 8 bytes")
+)
+
+// Packet sizes.
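+// The values below follow from the layouts above: the static header is
+// 23 bytes (6-byte protocol ID, 2-byte version, 1-byte flag, 12-byte nonce,
+// 2-byte auth size), so every packet starts with 16 + 23 = 39 bytes of fixed
+// data. WHOAREYOU authdata is 24 bytes, message authdata is 32 bytes, and
+// the fixed part of handshake authdata is 34 bytes.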
+var ( + sizeofStaticHeader = binary.Size(StaticHeader{}) + sizeofWhoareyouAuthData = binary.Size(whoareyouAuthData{}) + sizeofHandshakeAuthData = binary.Size(handshakeAuthData{}.h) + sizeofMessageAuthData = binary.Size(messageAuthData{}) + sizeofStaticPacketData = sizeofMaskingIV + sizeofStaticHeader +) + +// Codec encodes and decodes Discovery v5 packets. +// This type is not safe for concurrent use. +type Codec struct { + sha256 hash.Hash + localnode *enode.LocalNode + privkey *ecdsa.PrivateKey + sc *SessionCache + + // encoder buffers + buf bytes.Buffer // whole packet + headbuf bytes.Buffer // packet header + msgbuf bytes.Buffer // message RLP plaintext + msgctbuf []byte // message data ciphertext + + // decoder buffer + reader bytes.Reader +} + +// NewCodec creates a wire codec. +func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock) *Codec { + c := &Codec{ + sha256: sha256.New(), + localnode: ln, + privkey: key, + sc: NewSessionCache(1024, clock), + } + return c +} + +// Encode encodes a packet to a node. 'id' and 'addr' specify the destination node. The +// 'challenge' parameter should be the most recently received WHOAREYOU packet from that +// node. +func (c *Codec) Encode(id enode.ID, addr string, packet Packet, challenge *Whoareyou) ([]byte, Nonce, error) { + // Create the packet header. + var ( + head Header + session *session + msgData []byte + err error + ) + switch { + case packet.Kind() == WhoareyouPacket: + head, err = c.encodeWhoareyou(id, packet.(*Whoareyou)) + case challenge != nil: + // We have an unanswered challenge, send handshake. + head, session, err = c.encodeHandshakeHeader(id, addr, challenge) + default: + session = c.sc.session(id, addr) + if session != nil { + // There is a session, use it. + head, err = c.encodeMessageHeader(id, session) + } else { + // No keys, send random data to kick off the handshake. + head, msgData, err = c.encodeRandom(id) + } + } + if err != nil { + return nil, Nonce{}, err + } + + // Generate masking IV. + if err := c.sc.maskingIVGen(head.IV[:]); err != nil { + return nil, Nonce{}, fmt.Errorf("can't generate masking IV: %v", err) + } + + // Encode header data. + c.writeHeaders(&head) + + // Store sent WHOAREYOU challenges. + if challenge, ok := packet.(*Whoareyou); ok { + challenge.ChallengeData = bytesCopy(&c.buf) + c.sc.storeSentHandshake(id, addr, challenge) + } else if msgData == nil { + headerData := c.buf.Bytes() + msgData, err = c.encryptMessage(session, packet, &head, headerData) + if err != nil { + return nil, Nonce{}, err + } + } + + enc, err := c.EncodeRaw(id, head, msgData) + return enc, head.Nonce, err +} + +// EncodeRaw encodes a packet with the given header. +func (c *Codec) EncodeRaw(id enode.ID, head Header, msgdata []byte) ([]byte, error) { + c.writeHeaders(&head) + + // Apply masking. + masked := c.buf.Bytes()[sizeofMaskingIV:] + mask := head.mask(id) + mask.XORKeyStream(masked[:], masked[:]) + + // Write message data. + c.buf.Write(msgdata) + return c.buf.Bytes(), nil +} + +func (c *Codec) writeHeaders(head *Header) { + c.buf.Reset() + c.buf.Write(head.IV[:]) + binary.Write(&c.buf, binary.BigEndian, &head.StaticHeader) + c.buf.Write(head.AuthData) +} + +// makeHeader creates a packet header. 
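+// The auth size recorded in the static header is the fixed size for the given
+// flag plus any variable-length data; for handshake packets the caller passes
+// the combined length of signature, pubkey and record as authsizeExtra.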
+func (c *Codec) makeHeader(toID enode.ID, flag byte, authsizeExtra int) Header {
+	var authsize int
+	switch flag {
+	case flagMessage:
+		authsize = sizeofMessageAuthData
+	case flagWhoareyou:
+		authsize = sizeofWhoareyouAuthData
+	case flagHandshake:
+		authsize = sizeofHandshakeAuthData
+	default:
+		panic(fmt.Errorf("BUG: invalid packet header flag %x", flag))
+	}
+	authsize += authsizeExtra
+	if authsize > int(^uint16(0)) {
+		panic(fmt.Errorf("BUG: auth size %d overflows uint16", authsize))
+	}
+	return Header{
+		StaticHeader: StaticHeader{
+			ProtocolID: protocolID,
+			Version:    version,
+			Flag:       flag,
+			AuthSize:   uint16(authsize),
+		},
+	}
+}
+
+// encodeRandom encodes a packet with random content.
+func (c *Codec) encodeRandom(toID enode.ID) (Header, []byte, error) {
+	head := c.makeHeader(toID, flagMessage, 0)
+
+	// Encode auth data.
+	auth := messageAuthData{SrcID: c.localnode.ID()}
+	if _, err := crand.Read(head.Nonce[:]); err != nil {
+		return head, nil, fmt.Errorf("can't get random data: %v", err)
+	}
+	c.headbuf.Reset()
+	binary.Write(&c.headbuf, binary.BigEndian, auth)
+	head.AuthData = c.headbuf.Bytes()
+
+	// Fill message ciphertext buffer with random bytes.
+	c.msgctbuf = append(c.msgctbuf[:0], make([]byte, randomPacketMsgSize)...)
+	crand.Read(c.msgctbuf)
+	return head, c.msgctbuf, nil
+}
+
+// encodeWhoareyou encodes a WHOAREYOU packet.
+func (c *Codec) encodeWhoareyou(toID enode.ID, packet *Whoareyou) (Header, error) {
+	// Sanity check node field to catch misbehaving callers.
+	if packet.RecordSeq > 0 && packet.Node == nil {
+		panic("BUG: missing node in whoareyou with non-zero seq")
+	}
+
+	// Create header.
+	head := c.makeHeader(toID, flagWhoareyou, 0)
+	head.AuthData = bytesCopy(&c.buf)
+	head.Nonce = packet.Nonce
+
+	// Encode auth data.
+	auth := &whoareyouAuthData{
+		IDNonce:   packet.IDNonce,
+		RecordSeq: packet.RecordSeq,
+	}
+	c.headbuf.Reset()
+	binary.Write(&c.headbuf, binary.BigEndian, auth)
+	head.AuthData = c.headbuf.Bytes()
+	return head, nil
+}
+
+// encodeHandshakeHeader encodes the handshake message packet header.
+func (c *Codec) encodeHandshakeHeader(toID enode.ID, addr string, challenge *Whoareyou) (Header, *session, error) {
+	// Ensure calling code sets challenge.Node.
+	if challenge.Node == nil {
+		panic("BUG: missing challenge.Node in encode")
+	}
+
+	// Generate new secrets.
+	auth, session, err := c.makeHandshakeAuth(toID, addr, challenge)
+	if err != nil {
+		return Header{}, nil, err
+	}
+
+	// Generate nonce for message.
+	nonce, err := c.sc.nextNonce(session)
+	if err != nil {
+		return Header{}, nil, fmt.Errorf("can't generate nonce: %v", err)
+	}
+
+	// TODO: this should happen when the first authenticated message is received
+	c.sc.storeNewSession(toID, addr, session)
+
+	// Encode the auth header.
+	var (
+		authsizeExtra = len(auth.pubkey) + len(auth.signature) + len(auth.record)
+		head          = c.makeHeader(toID, flagHandshake, authsizeExtra)
+	)
+	c.headbuf.Reset()
+	binary.Write(&c.headbuf, binary.BigEndian, &auth.h)
+	c.headbuf.Write(auth.signature)
+	c.headbuf.Write(auth.pubkey)
+	c.headbuf.Write(auth.record)
+	head.AuthData = c.headbuf.Bytes()
+	head.Nonce = nonce
+	return head, session, err
+}
+
+// makeHandshakeAuth creates the auth header on a request packet following WHOAREYOU.
+func (c *Codec) makeHandshakeAuth(toID enode.ID, addr string, challenge *Whoareyou) (*handshakeAuthData, *session, error) {
+	auth := new(handshakeAuthData)
+	auth.h.SrcID = c.localnode.ID()
+
+	// Create the ephemeral key.
+	// This needs to be first because the key is part of the ID nonce signature.
+	var remotePubkey = new(ecdsa.PublicKey)
+	if err := challenge.Node.Load((*enode.Secp256k1)(remotePubkey)); err != nil {
+		return nil, nil, errors.New("can't find secp256k1 key for recipient")
+	}
+	ephkey, err := c.sc.ephemeralKeyGen()
+	if err != nil {
+		return nil, nil, errors.New("can't generate ephemeral key")
+	}
+	ephpubkey := EncodePubkey(&ephkey.PublicKey)
+	auth.pubkey = ephpubkey[:]
+	auth.h.PubkeySize = byte(len(auth.pubkey))
+
+	// Add ID nonce signature to response.
+	cdata := challenge.ChallengeData
+	idsig, err := makeIDSignature(c.sha256, c.privkey, cdata, ephpubkey[:], toID)
+	if err != nil {
+		return nil, nil, fmt.Errorf("can't sign: %v", err)
+	}
+	auth.signature = idsig
+	auth.h.SigSize = byte(len(auth.signature))
+
+	// Add our record to response if it's newer than what remote side has.
+	ln := c.localnode.Node()
+	if challenge.RecordSeq < ln.Seq() {
+		auth.record, _ = rlp.EncodeToBytes(ln.Record())
+	}
+
+	// Create session keys.
+	sec := deriveKeys(sha256.New, ephkey, remotePubkey, c.localnode.ID(), challenge.Node.ID(), cdata)
+	if sec == nil {
+		return nil, nil, errors.New("key derivation failed")
+	}
+	return auth, sec, err
+}
+
+// encodeMessageHeader encodes an encrypted message packet.
+func (c *Codec) encodeMessageHeader(toID enode.ID, s *session) (Header, error) {
+	head := c.makeHeader(toID, flagMessage, 0)
+
+	// Create the header.
+	nonce, err := c.sc.nextNonce(s)
+	if err != nil {
+		return Header{}, fmt.Errorf("can't generate nonce: %v", err)
+	}
+	auth := messageAuthData{SrcID: c.localnode.ID()}
+	c.buf.Reset()
+	binary.Write(&c.buf, binary.BigEndian, &auth)
+	head.AuthData = bytesCopy(&c.buf)
+	head.Nonce = nonce
+	return head, err
+}
+
+func (c *Codec) encryptMessage(s *session, p Packet, head *Header, headerData []byte) ([]byte, error) {
+	// Encode message plaintext.
+	c.msgbuf.Reset()
+	c.msgbuf.WriteByte(p.Kind())
+	if err := rlp.Encode(&c.msgbuf, p); err != nil {
+		return nil, err
+	}
+	messagePT := c.msgbuf.Bytes()
+
+	// Encrypt into message ciphertext buffer.
+	messageCT, err := encryptGCM(c.msgctbuf[:0], s.writeKey, head.Nonce[:], messagePT, headerData)
+	if err == nil {
+		c.msgctbuf = messageCT
+	}
+	return messageCT, err
+}
+
+// Decode decodes a discovery packet.
+func (c *Codec) Decode(input []byte, addr string) (src enode.ID, n *enode.Node, p Packet, err error) {
+	// Unmask the static header.
+	if len(input) < sizeofStaticPacketData {
+		return enode.ID{}, nil, nil, errTooShort
+	}
+	var head Header
+	copy(head.IV[:], input[:sizeofMaskingIV])
+	mask := head.mask(c.localnode.ID())
+	staticHeader := input[sizeofMaskingIV:sizeofStaticPacketData]
+	mask.XORKeyStream(staticHeader, staticHeader)
+
+	// Decode and verify the static header.
+	c.reader.Reset(staticHeader)
+	binary.Read(&c.reader, binary.BigEndian, &head.StaticHeader)
+	remainingInput := len(input) - sizeofStaticPacketData
+	if err := head.checkValid(remainingInput); err != nil {
+		return enode.ID{}, nil, nil, err
+	}
+
+	// Unmask auth data.
+	authDataEnd := sizeofStaticPacketData + int(head.AuthSize)
+	authData := input[sizeofStaticPacketData:authDataEnd]
+	mask.XORKeyStream(authData, authData)
+	head.AuthData = authData
+
+	// Delete timed-out handshakes. This must happen before decoding to avoid
+	// processing the same handshake twice.
+	c.sc.handshakeGC()
+
+	// Decode auth part and message.
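+	// Note that masking is only obfuscation: real authentication of message
+	// packets happens in AES-GCM, which binds the unmasked header bytes as
+	// associated data during decryption.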
+	headerData := input[:authDataEnd]
+	msgData := input[authDataEnd:]
+	switch head.Flag {
+	case flagWhoareyou:
+		p, err = c.decodeWhoareyou(&head, headerData)
+	case flagHandshake:
+		n, p, err = c.decodeHandshakeMessage(addr, &head, headerData, msgData)
+	case flagMessage:
+		p, err = c.decodeMessage(addr, &head, headerData, msgData)
+	default:
+		err = errInvalidFlag
+	}
+	return head.src, n, p, err
+}
+
+// decodeWhoareyou reads packet data after the header as a WHOAREYOU packet.
+func (c *Codec) decodeWhoareyou(head *Header, headerData []byte) (Packet, error) {
+	if len(head.AuthData) != sizeofWhoareyouAuthData {
+		return nil, fmt.Errorf("invalid auth size %d for WHOAREYOU", len(head.AuthData))
+	}
+	var auth whoareyouAuthData
+	c.reader.Reset(head.AuthData)
+	binary.Read(&c.reader, binary.BigEndian, &auth)
+	p := &Whoareyou{
+		Nonce:         head.Nonce,
+		IDNonce:       auth.IDNonce,
+		RecordSeq:     auth.RecordSeq,
+		ChallengeData: make([]byte, len(headerData)),
+	}
+	copy(p.ChallengeData, headerData)
+	return p, nil
+}
+
+func (c *Codec) decodeHandshakeMessage(fromAddr string, head *Header, headerData, msgData []byte) (n *enode.Node, p Packet, err error) {
+	node, auth, session, err := c.decodeHandshake(fromAddr, head)
+	if err != nil {
+		c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
+		return nil, nil, err
+	}
+
+	// Decrypt the message using the new session keys.
+	msg, err := c.decryptMessage(msgData, head.Nonce[:], headerData, session.readKey)
+	if err != nil {
+		c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
+		return node, msg, err
+	}
+
+	// Handshake OK, drop the challenge and store the new session keys.
+	c.sc.storeNewSession(auth.h.SrcID, fromAddr, session)
+	c.sc.deleteHandshake(auth.h.SrcID, fromAddr)
+	return node, msg, nil
+}
+
+func (c *Codec) decodeHandshake(fromAddr string, head *Header) (n *enode.Node, auth handshakeAuthData, s *session, err error) {
+	if auth, err = c.decodeHandshakeAuthData(head); err != nil {
+		return nil, auth, nil, err
+	}
+
+	// Verify against our last WHOAREYOU.
+	challenge := c.sc.getHandshake(auth.h.SrcID, fromAddr)
+	if challenge == nil {
+		return nil, auth, nil, errUnexpectedHandshake
+	}
+	// Get node record.
+	n, err = c.decodeHandshakeRecord(challenge.Node, auth.h.SrcID, auth.record)
+	if err != nil {
+		return nil, auth, nil, err
+	}
+	// Verify ID nonce signature.
+	sig := auth.signature
+	cdata := challenge.ChallengeData
+	err = verifyIDSignature(c.sha256, sig, n, cdata, auth.pubkey, c.localnode.ID())
+	if err != nil {
+		return nil, auth, nil, err
+	}
+	// Verify ephemeral key is on curve.
+	ephkey, err := DecodePubkey(c.privkey.Curve, auth.pubkey)
+	if err != nil {
+		return nil, auth, nil, errInvalidAuthKey
+	}
+	// Derive session keys.
+	session := deriveKeys(sha256.New, c.privkey, ephkey, auth.h.SrcID, c.localnode.ID(), cdata)
+	session = session.keysFlipped()
+	return n, auth, session, nil
+}
+
+// decodeHandshakeAuthData reads the authdata section of a handshake packet.
+func (c *Codec) decodeHandshakeAuthData(head *Header) (auth handshakeAuthData, err error) {
+	// Decode fixed size part.
+	if len(head.AuthData) < sizeofHandshakeAuthData {
+		return auth, fmt.Errorf("header authsize %d too low for handshake", head.AuthSize)
+	}
+	c.reader.Reset(head.AuthData)
+	binary.Read(&c.reader, binary.BigEndian, &auth.h)
+	head.src = auth.h.SrcID
+
+	// Decode variable-size part.
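+	// It holds, in order: the ID nonce signature (SigSize bytes), the
+	// compressed ephemeral public key (PubkeySize bytes), and an optional
+	// node record taking up the remainder of the authdata.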
+	var (
+		vardata       = head.AuthData[sizeofHandshakeAuthData:]
+		sigAndKeySize = int(auth.h.SigSize) + int(auth.h.PubkeySize)
+		keyOffset     = int(auth.h.SigSize)
+		recOffset     = keyOffset + int(auth.h.PubkeySize)
+	)
+	if len(vardata) < sigAndKeySize {
+		return auth, errTooShort
+	}
+	auth.signature = vardata[:keyOffset]
+	auth.pubkey = vardata[keyOffset:recOffset]
+	auth.record = vardata[recOffset:]
+	return auth, nil
+}
+
+// decodeHandshakeRecord verifies the node record contained in a handshake packet. The
+// remote node should include the record if we don't have one or if ours is older than the
+// latest sequence number.
+func (c *Codec) decodeHandshakeRecord(local *enode.Node, wantID enode.ID, remote []byte) (*enode.Node, error) {
+	node := local
+	if len(remote) > 0 {
+		var record enr.Record
+		if err := rlp.DecodeBytes(remote, &record); err != nil {
+			return nil, err
+		}
+		if local == nil || local.Seq() < record.Seq() {
+			n, err := enode.New(enode.ValidSchemes, &record)
+			if err != nil {
+				return nil, fmt.Errorf("invalid node record: %v", err)
+			}
+			if n.ID() != wantID {
+				return nil, fmt.Errorf("record in handshake has wrong ID: %v", n.ID())
+			}
+			node = n
+		}
+	}
+	if node == nil {
+		return nil, errNoRecord
+	}
+	return node, nil
+}
+
+// decodeMessage reads packet data following the header as an ordinary message packet.
+func (c *Codec) decodeMessage(fromAddr string, head *Header, headerData, msgData []byte) (Packet, error) {
+	if len(head.AuthData) != sizeofMessageAuthData {
+		return nil, fmt.Errorf("invalid auth size %d for message packet", len(head.AuthData))
+	}
+	var auth messageAuthData
+	c.reader.Reset(head.AuthData)
+	binary.Read(&c.reader, binary.BigEndian, &auth)
+	head.src = auth.SrcID
+
+	// Try decrypting the message.
+	key := c.sc.readKey(auth.SrcID, fromAddr)
+	msg, err := c.decryptMessage(msgData, head.Nonce[:], headerData, key)
+	if err == errMessageDecrypt {
+		// It didn't work. Start the handshake since this is an ordinary message packet.
+		return &Unknown{Nonce: head.Nonce}, nil
+	}
+	return msg, err
+}
+
+func (c *Codec) decryptMessage(input, nonce, headerData, readKey []byte) (Packet, error) {
+	msgdata, err := decryptGCM(readKey, nonce, input, headerData)
+	if err != nil {
+		return nil, errMessageDecrypt
+	}
+	if len(msgdata) == 0 {
+		return nil, errMessageTooShort
+	}
+	return DecodeMessage(msgdata[0], msgdata[1:])
+}
+
+// checkValid performs some basic validity checks on the header.
+// The packetLen here is the length remaining after the static header.
+func (h *StaticHeader) checkValid(packetLen int) error {
+	if h.ProtocolID != protocolID {
+		return errInvalidHeader
+	}
+	if h.Version < minVersion {
+		return errMinVersion
+	}
+	if h.Flag != flagWhoareyou && packetLen < minMessageSize {
+		return errMsgTooShort
+	}
+	if int(h.AuthSize) > packetLen {
+		return errAuthSize
+	}
+	return nil
+}
+
+// mask returns a cipher for 'masking' / 'unmasking' packet headers.
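+// The stream is AES-CTR keyed with the first 16 bytes of the destination
+// node ID and initialized with the packet's masking IV. Since node IDs are
+// public, this is traffic obfuscation rather than encryption: it keeps
+// static protocol bytes from appearing on the wire.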
+func (h *Header) mask(destID enode.ID) cipher.Stream {
+	block, err := aes.NewCipher(destID[:16])
+	if err != nil {
+		panic("can't create cipher")
+	}
+	return cipher.NewCTR(block, h.IV[:])
+}
+
+func bytesCopy(r *bytes.Buffer) []byte {
+	b := make([]byte, r.Len())
+	copy(b, r.Bytes())
+	return b
+}
diff --git a/p2p/discover/v5wire/encoding_test.go b/p2p/discover/v5wire/encoding_test.go
new file mode 100644
index 000000000000..d9c807e0a8db
--- /dev/null
+++ b/p2p/discover/v5wire/encoding_test.go
@@ -0,0 +1,636 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package v5wire
+
+import (
+	"bytes"
+	"crypto/ecdsa"
+	"encoding/hex"
+	"flag"
+	"fmt"
+	"io/ioutil"
+	"net"
+	"os"
+	"path/filepath"
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/davecgh/go-spew/spew"
+	"github.com/ethereum/go-ethereum/common/hexutil"
+	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+)
+
+// To regenerate discv5 test vectors, run
+//
+//   go test -run TestVectors -write-test-vectors
+//
+var writeTestVectorsFlag = flag.Bool("write-test-vectors", false, "Overwrite discv5 test vectors in testdata/")
+
+var (
+	testKeyA, _   = crypto.HexToECDSA("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f")
+	testKeyB, _   = crypto.HexToECDSA("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628")
+	testEphKey, _ = crypto.HexToECDSA("0288ef00023598499cb6c940146d050d2b1fb914198c327f76aad590bead68b6")
+	testIDnonce   = [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
+)
+
+// This test checks that the minMessageSize and randomPacketMsgSize constants are well-defined.
+func TestMinSizes(t *testing.T) {
+	var (
+		gcmTagSize = 16
+		emptyMsg   = sizeofMessageAuthData + gcmTagSize
+	)
+	t.Log("static header size", sizeofStaticPacketData)
+	t.Log("whoareyou size", sizeofStaticPacketData+sizeofWhoareyouAuthData)
+	t.Log("empty msg size", sizeofStaticPacketData+emptyMsg)
+	if want := emptyMsg; minMessageSize != want {
+		t.Fatalf("wrong minMessageSize %d, want %d", minMessageSize, want)
+	}
+	if sizeofMessageAuthData+randomPacketMsgSize < minMessageSize {
+		t.Fatalf("randomPacketMsgSize %d too small", randomPacketMsgSize)
+	}
+}
+
+// This test checks the basic handshake flow where A talks to B and A has no secrets.
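+// The expected sequence is: A sends a random packet because it has no keys,
+// B answers with WHOAREYOU, A re-sends FINDNODE as a handshake packet
+// carrying the ID signature and ephemeral key, and from then on both sides
+// exchange ordinary encrypted messages.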
+func TestHandshake(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &Findnode{}) + resp := net.nodeB.expectDecode(t, UnknownPacket, packet) + + // A <- B WHOAREYOU + challenge := &Whoareyou{ + Nonce: resp.(*Unknown).Nonce, + IDNonce: testIDnonce, + RecordSeq: 0, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou) + + // A -> B FINDNODE (handshake packet) + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{}) + net.nodeB.expectDecode(t, FindnodeMsg, findnode) + if len(net.nodeB.c.sc.handshakes) > 0 { + t.Fatalf("node B didn't remove handshake from challenge map") + } + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &Nodes{Total: 1}) + net.nodeA.expectDecode(t, NodesMsg, nodes) +} + +// This test checks that handshake attempts are removed within the timeout. +func TestHandshake_timeout(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &Findnode{}) + resp := net.nodeB.expectDecode(t, UnknownPacket, packet) + + // A <- B WHOAREYOU + challenge := &Whoareyou{ + Nonce: resp.(*Unknown).Nonce, + IDNonce: testIDnonce, + RecordSeq: 0, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou) + + // A -> B FINDNODE (handshake packet) after timeout + net.clock.Run(handshakeTimeout + 1) + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{}) + net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, findnode) +} + +// This test checks handshake behavior when no record is sent in the auth response. +func TestHandshake_norecord(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &Findnode{}) + resp := net.nodeB.expectDecode(t, UnknownPacket, packet) + + // A <- B WHOAREYOU + nodeA := net.nodeA.n() + if nodeA.Seq() == 0 { + t.Fatal("need non-zero sequence number") + } + challenge := &Whoareyou{ + Nonce: resp.(*Unknown).Nonce, + IDNonce: testIDnonce, + RecordSeq: nodeA.Seq(), + Node: nodeA, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou) + + // A -> B FINDNODE + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{}) + net.nodeB.expectDecode(t, FindnodeMsg, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &Nodes{Total: 1}) + net.nodeA.expectDecode(t, NodesMsg, nodes) +} + +// In this test, A tries to send FINDNODE with existing secrets but B doesn't know +// anything about A. +func TestHandshake_rekey(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + session := &session{ + readKey: []byte("BBBBBBBBBBBBBBBB"), + writeKey: []byte("AAAAAAAAAAAAAAAA"), + } + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session) + + // A -> B FINDNODE (encrypted with zero keys) + findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{}) + net.nodeB.expectDecode(t, UnknownPacket, findnode) + + // A <- B WHOAREYOU + challenge := &Whoareyou{Nonce: authTag, IDNonce: testIDnonce} + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou) + + // Check that new keys haven't been stored yet. 
+ sa := net.nodeA.c.sc.session(net.nodeB.id(), net.nodeB.addr()) + if !bytes.Equal(sa.writeKey, session.writeKey) || !bytes.Equal(sa.readKey, session.readKey) { + t.Fatal("node A stored keys too early") + } + if s := net.nodeB.c.sc.session(net.nodeA.id(), net.nodeA.addr()); s != nil { + t.Fatal("node B stored keys too early") + } + + // A -> B FINDNODE encrypted with new keys + findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{}) + net.nodeB.expectDecode(t, FindnodeMsg, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &Nodes{Total: 1}) + net.nodeA.expectDecode(t, NodesMsg, nodes) +} + +// In this test A and B have different keys before the handshake. +func TestHandshake_rekey2(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + initKeysA := &session{ + readKey: []byte("BBBBBBBBBBBBBBBB"), + writeKey: []byte("AAAAAAAAAAAAAAAA"), + } + initKeysB := &session{ + readKey: []byte("CCCCCCCCCCCCCCCC"), + writeKey: []byte("DDDDDDDDDDDDDDDD"), + } + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), initKeysA) + net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), initKeysB) + + // A -> B FINDNODE encrypted with initKeysA + findnode, authTag := net.nodeA.encode(t, net.nodeB, &Findnode{Distances: []uint{3}}) + net.nodeB.expectDecode(t, UnknownPacket, findnode) + + // A <- B WHOAREYOU + challenge := &Whoareyou{Nonce: authTag, IDNonce: testIDnonce} + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou) + + // A -> B FINDNODE (handshake packet) + findnode, _ = net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{}) + net.nodeB.expectDecode(t, FindnodeMsg, findnode) + + // A <- B NODES + nodes, _ := net.nodeB.encode(t, net.nodeA, &Nodes{Total: 1}) + net.nodeA.expectDecode(t, NodesMsg, nodes) +} + +func TestHandshake_BadHandshakeAttack(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + // A -> B RANDOM PACKET + packet, _ := net.nodeA.encode(t, net.nodeB, &Findnode{}) + resp := net.nodeB.expectDecode(t, UnknownPacket, packet) + + // A <- B WHOAREYOU + challenge := &Whoareyou{ + Nonce: resp.(*Unknown).Nonce, + IDNonce: testIDnonce, + RecordSeq: 0, + } + whoareyou, _ := net.nodeB.encode(t, net.nodeA, challenge) + net.nodeA.expectDecode(t, WhoareyouPacket, whoareyou) + + // A -> B FINDNODE + incorrect_challenge := &Whoareyou{ + IDNonce: [16]byte{5, 6, 7, 8, 9, 6, 11, 12}, + RecordSeq: challenge.RecordSeq, + Node: challenge.Node, + sent: challenge.sent, + } + incorrect_findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, incorrect_challenge, &Findnode{}) + incorrect_findnode2 := make([]byte, len(incorrect_findnode)) + copy(incorrect_findnode2, incorrect_findnode) + + net.nodeB.expectDecodeErr(t, errInvalidNonceSig, incorrect_findnode) + + // Reject new findnode as previous handshake is now deleted. + net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, incorrect_findnode2) + + // The findnode packet is again rejected even with a valid challenge this time. + findnode, _ := net.nodeA.encodeWithChallenge(t, net.nodeB, challenge, &Findnode{}) + net.nodeB.expectDecodeErr(t, errUnexpectedHandshake, findnode) +} + +// This test checks some malformed packets. 
+func TestDecodeErrorsV5(t *testing.T) { + t.Parallel() + net := newHandshakeTest() + defer net.close() + + net.nodeA.expectDecodeErr(t, errTooShort, []byte{}) + // TODO some more tests would be nice :) + // - check invalid authdata sizes + // - check invalid handshake data sizes +} + +// This test checks that all test vectors can be decoded. +func TestTestVectorsV5(t *testing.T) { + var ( + idA = enode.PubkeyToIDV4(&testKeyA.PublicKey) + idB = enode.PubkeyToIDV4(&testKeyB.PublicKey) + addr = "127.0.0.1" + session = &session{ + writeKey: hexutil.MustDecode("0x00000000000000000000000000000000"), + readKey: hexutil.MustDecode("0x01010101010101010101010101010101"), + } + challenge0A, challenge1A, challenge0B Whoareyou + ) + + // Create challenge packets. + c := Whoareyou{ + Nonce: Nonce{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + IDNonce: testIDnonce, + } + challenge0A, challenge1A, challenge0B = c, c, c + challenge1A.RecordSeq = 1 + net := newHandshakeTest() + challenge0A.Node = net.nodeA.n() + challenge0B.Node = net.nodeB.n() + challenge1A.Node = net.nodeA.n() + net.close() + + type testVectorTest struct { + name string // test vector name + packet Packet // the packet to be encoded + challenge *Whoareyou // handshake challenge passed to encoder + prep func(*handshakeTest) // called before encode/decode + } + tests := []testVectorTest{ + { + name: "v5.1-whoareyou", + packet: &challenge0B, + }, + { + name: "v5.1-ping-message", + packet: &Ping{ + ReqID: []byte{0, 0, 0, 1}, + ENRSeq: 2, + }, + prep: func(net *handshakeTest) { + net.nodeA.c.sc.storeNewSession(idB, addr, session) + net.nodeB.c.sc.storeNewSession(idA, addr, session.keysFlipped()) + }, + }, + { + name: "v5.1-ping-handshake-enr", + packet: &Ping{ + ReqID: []byte{0, 0, 0, 1}, + ENRSeq: 1, + }, + challenge: &challenge0A, + prep: func(net *handshakeTest) { + // Update challenge.Header.AuthData. + net.nodeA.c.Encode(idB, "", &challenge0A, nil) + net.nodeB.c.sc.storeSentHandshake(idA, addr, &challenge0A) + }, + }, + { + name: "v5.1-ping-handshake", + packet: &Ping{ + ReqID: []byte{0, 0, 0, 1}, + ENRSeq: 1, + }, + challenge: &challenge1A, + prep: func(net *handshakeTest) { + // Update challenge data. + net.nodeA.c.Encode(idB, "", &challenge1A, nil) + net.nodeB.c.sc.storeSentHandshake(idA, addr, &challenge1A) + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + net := newHandshakeTest() + defer net.close() + + // Override all random inputs. + net.nodeA.c.sc.nonceGen = func(counter uint32) (Nonce, error) { + return Nonce{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, nil + } + net.nodeA.c.sc.maskingIVGen = func(buf []byte) error { + return nil // all zero + } + net.nodeA.c.sc.ephemeralKeyGen = func() (*ecdsa.PrivateKey, error) { + return testEphKey, nil + } + + // Prime the codec for encoding/decoding. + if test.prep != nil { + test.prep(net) + } + + file := filepath.Join("testdata", test.name+".txt") + if *writeTestVectorsFlag { + // Encode the packet. + d, nonce := net.nodeA.encodeWithChallenge(t, net.nodeB, test.challenge, test.packet) + comment := testVectorComment(net, test.packet, test.challenge, nonce) + writeTestVector(file, comment, d) + } + enc := hexFile(file) + net.nodeB.expectDecode(t, test.packet.Kind(), enc) + }) + } +} + +// testVectorComment creates the commentary for discv5 test vector files. 
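+// The commentary records every input needed to re-derive a vector (node IDs,
+// nonces, keys and challenge data); it is emitted as '#' comment lines, which
+// hexFile skips when the vectors are read back.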
+func testVectorComment(net *handshakeTest, p Packet, challenge *Whoareyou, nonce Nonce) string { + o := new(strings.Builder) + printWhoareyou := func(p *Whoareyou) { + fmt.Fprintf(o, "whoareyou.challenge-data = %#x\n", p.ChallengeData) + fmt.Fprintf(o, "whoareyou.request-nonce = %#x\n", p.Nonce[:]) + fmt.Fprintf(o, "whoareyou.id-nonce = %#x\n", p.IDNonce[:]) + fmt.Fprintf(o, "whoareyou.enr-seq = %d\n", p.RecordSeq) + } + + fmt.Fprintf(o, "src-node-id = %#x\n", net.nodeA.id().Bytes()) + fmt.Fprintf(o, "dest-node-id = %#x\n", net.nodeB.id().Bytes()) + switch p := p.(type) { + case *Whoareyou: + // WHOAREYOU packet. + printWhoareyou(p) + case *Ping: + fmt.Fprintf(o, "nonce = %#x\n", nonce[:]) + fmt.Fprintf(o, "read-key = %#x\n", net.nodeA.c.sc.session(net.nodeB.id(), net.nodeB.addr()).writeKey) + fmt.Fprintf(o, "ping.req-id = %#x\n", p.ReqID) + fmt.Fprintf(o, "ping.enr-seq = %d\n", p.ENRSeq) + if challenge != nil { + // Handshake message packet. + fmt.Fprint(o, "\nhandshake inputs:\n\n") + printWhoareyou(challenge) + fmt.Fprintf(o, "ephemeral-key = %#x\n", testEphKey.D.Bytes()) + fmt.Fprintf(o, "ephemeral-pubkey = %#x\n", crypto.CompressPubkey(&testEphKey.PublicKey)) + } + default: + panic(fmt.Errorf("unhandled packet type %T", p)) + } + return o.String() +} + +// This benchmark checks performance of handshake packet decoding. +func BenchmarkV5_DecodeHandshakePingSecp256k1(b *testing.B) { + net := newHandshakeTest() + defer net.close() + + var ( + idA = net.nodeA.id() + challenge = &Whoareyou{Node: net.nodeB.n()} + message = &Ping{ReqID: []byte("reqid")} + ) + enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), "", message, challenge) + if err != nil { + b.Fatal("can't encode handshake packet") + } + challenge.Node = nil // force ENR signature verification in decoder + b.ResetTimer() + + input := make([]byte, len(enc)) + for i := 0; i < b.N; i++ { + copy(input, enc) + net.nodeB.c.sc.storeSentHandshake(idA, "", challenge) + _, _, _, err := net.nodeB.c.Decode(input, "") + if err != nil { + b.Fatal(err) + } + } +} + +// This benchmark checks how long it takes to decode an encrypted ping packet. 
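+// Unlike the handshake benchmark above, this one reuses an established
+// session, so each iteration covers only header unmasking, AES-GCM
+// decryption and RLP decoding.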
+func BenchmarkV5_DecodePing(b *testing.B) { + net := newHandshakeTest() + defer net.close() + + session := &session{ + readKey: []byte{233, 203, 93, 195, 86, 47, 177, 186, 227, 43, 2, 141, 244, 230, 120, 17}, + writeKey: []byte{79, 145, 252, 171, 167, 216, 252, 161, 208, 190, 176, 106, 214, 39, 178, 134}, + } + net.nodeA.c.sc.storeNewSession(net.nodeB.id(), net.nodeB.addr(), session) + net.nodeB.c.sc.storeNewSession(net.nodeA.id(), net.nodeA.addr(), session.keysFlipped()) + addrB := net.nodeA.addr() + ping := &Ping{ReqID: []byte("reqid"), ENRSeq: 5} + enc, _, err := net.nodeA.c.Encode(net.nodeB.id(), addrB, ping, nil) + if err != nil { + b.Fatalf("can't encode: %v", err) + } + b.ResetTimer() + + input := make([]byte, len(enc)) + for i := 0; i < b.N; i++ { + copy(input, enc) + _, _, packet, _ := net.nodeB.c.Decode(input, addrB) + if _, ok := packet.(*Ping); !ok { + b.Fatalf("wrong packet type %T", packet) + } + } +} + +var pp = spew.NewDefaultConfig() + +type handshakeTest struct { + nodeA, nodeB handshakeTestNode + clock mclock.Simulated +} + +type handshakeTestNode struct { + ln *enode.LocalNode + c *Codec +} + +func newHandshakeTest() *handshakeTest { + t := new(handshakeTest) + t.nodeA.init(testKeyA, net.IP{127, 0, 0, 1}, &t.clock) + t.nodeB.init(testKeyB, net.IP{127, 0, 0, 1}, &t.clock) + return t +} + +func (t *handshakeTest) close() { + t.nodeA.ln.Database().Close() + t.nodeB.ln.Database().Close() +} + +func (n *handshakeTestNode) init(key *ecdsa.PrivateKey, ip net.IP, clock mclock.Clock) { + db, _ := enode.OpenDB("") + n.ln = enode.NewLocalNode(db, key) + n.ln.SetStaticIP(ip) + if n.ln.Node().Seq() != 1 { + panic(fmt.Errorf("unexpected seq %d", n.ln.Node().Seq())) + } + n.c = NewCodec(n.ln, key, clock) +} + +func (n *handshakeTestNode) encode(t testing.TB, to handshakeTestNode, p Packet) ([]byte, Nonce) { + t.Helper() + return n.encodeWithChallenge(t, to, nil, p) +} + +func (n *handshakeTestNode) encodeWithChallenge(t testing.TB, to handshakeTestNode, c *Whoareyou, p Packet) ([]byte, Nonce) { + t.Helper() + + // Copy challenge and add destination node. This avoids sharing 'c' among the two codecs. + var challenge *Whoareyou + if c != nil { + challengeCopy := *c + challenge = &challengeCopy + challenge.Node = to.n() + } + // Encode to destination. 
+ enc, nonce, err := n.c.Encode(to.id(), to.addr(), p, challenge) + if err != nil { + t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err)) + } + t.Logf("(%s) -> (%s) %s\n%s", n.ln.ID().TerminalString(), to.id().TerminalString(), p.Name(), hex.Dump(enc)) + return enc, nonce +} + +func (n *handshakeTestNode) expectDecode(t *testing.T, ptype byte, p []byte) Packet { + t.Helper() + + dec, err := n.decode(p) + if err != nil { + t.Fatal(fmt.Errorf("(%s) %v", n.ln.ID().TerminalString(), err)) + } + t.Logf("(%s) %#v", n.ln.ID().TerminalString(), pp.NewFormatter(dec)) + if dec.Kind() != ptype { + t.Fatalf("expected packet type %d, got %d", ptype, dec.Kind()) + } + return dec +} + +func (n *handshakeTestNode) expectDecodeErr(t *testing.T, wantErr error, p []byte) { + t.Helper() + if _, err := n.decode(p); !reflect.DeepEqual(err, wantErr) { + t.Fatal(fmt.Errorf("(%s) got err %q, want %q", n.ln.ID().TerminalString(), err, wantErr)) + } +} + +func (n *handshakeTestNode) decode(input []byte) (Packet, error) { + _, _, p, err := n.c.Decode(input, "127.0.0.1") + return p, err +} + +func (n *handshakeTestNode) n() *enode.Node { + return n.ln.Node() +} + +func (n *handshakeTestNode) addr() string { + return n.ln.Node().IP().String() +} + +func (n *handshakeTestNode) id() enode.ID { + return n.ln.ID() +} + +// hexFile reads the given file and decodes the hex data contained in it. +// Whitespace and any lines beginning with the # character are ignored. +func hexFile(file string) []byte { + fileContent, err := ioutil.ReadFile(file) + if err != nil { + panic(err) + } + + // Gather hex data, ignore comments. + var text []byte + for _, line := range bytes.Split(fileContent, []byte("\n")) { + line = bytes.TrimSpace(line) + if len(line) > 0 && line[0] == '#' { + continue + } + text = append(text, line...) + } + + // Parse the hex. + if bytes.HasPrefix(text, []byte("0x")) { + text = text[2:] + } + data := make([]byte, hex.DecodedLen(len(text))) + if _, err := hex.Decode(data, text); err != nil { + panic("invalid hex in " + file) + } + return data +} + +// writeTestVector writes a test vector file with the given commentary and binary data. +func writeTestVector(file, comment string, data []byte) { + fd, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + panic(err) + } + defer fd.Close() + + if len(comment) > 0 { + for _, line := range strings.Split(strings.TrimSpace(comment), "\n") { + fmt.Fprintf(fd, "# %s\n", line) + } + fmt.Fprintln(fd) + } + for len(data) > 0 { + var chunk []byte + if len(data) < 32 { + chunk = data + } else { + chunk = data[:32] + } + data = data[len(chunk):] + fmt.Fprintf(fd, "%x\n", chunk) + } +} diff --git a/p2p/discover/v5wire/msg.go b/p2p/discover/v5wire/msg.go new file mode 100644 index 000000000000..7c3686111b7f --- /dev/null +++ b/p2p/discover/v5wire/msg.go @@ -0,0 +1,249 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. 
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package v5wire
+
+import (
+	"fmt"
+	"net"
+
+	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// Packet is implemented by all message types.
+type Packet interface {
+	Name() string        // Name returns a string corresponding to the message type.
+	Kind() byte          // Kind returns the message type.
+	RequestID() []byte   // Returns the request ID.
+	SetRequestID([]byte) // Sets the request ID.
+}
+
+// Message types.
+const (
+	PingMsg byte = iota + 1
+	PongMsg
+	FindnodeMsg
+	NodesMsg
+	TalkRequestMsg
+	TalkResponseMsg
+	RequestTicketMsg
+	TicketMsg
+	RegtopicMsg
+	RegconfirmationMsg
+	TopicQueryMsg
+
+	UnknownPacket   = byte(255) // any non-decryptable packet
+	WhoareyouPacket = byte(254) // the WHOAREYOU packet
+)
+
+// Protocol messages.
+type (
+	// Unknown represents any packet that can't be decrypted.
+	Unknown struct {
+		Nonce Nonce
+	}
+
+	// WHOAREYOU contains the handshake challenge.
+	Whoareyou struct {
+		ChallengeData []byte   // Encoded challenge
+		Nonce         Nonce    // Nonce of request packet
+		IDNonce       [16]byte // Identity proof data
+		RecordSeq     uint64   // ENR sequence number of recipient
+
+		// Node is the locally known node record of recipient.
+		// This must be set by the caller of Encode.
+		Node *enode.Node
+
+		sent mclock.AbsTime // for handshake GC.
+	}
+
+	// PING is sent during liveness checks.
+	Ping struct {
+		ReqID  []byte
+		ENRSeq uint64
+	}
+
+	// PONG is the reply to PING.
+	Pong struct {
+		ReqID  []byte
+		ENRSeq uint64
+		ToIP   net.IP // These fields should mirror the UDP envelope address of the ping
+		ToPort uint16 // packet, which provides a way to discover the external address (after NAT).
+	}
+
+	// FINDNODE is a query for nodes in the given bucket.
+	Findnode struct {
+		ReqID     []byte
+		Distances []uint
+	}
+
+	// NODES is the reply to FINDNODE and TOPICQUERY.
+	Nodes struct {
+		ReqID []byte
+		Total uint8
+		Nodes []*enr.Record
+	}
+
+	// TALKREQ is an application-level request.
+	TalkRequest struct {
+		ReqID    []byte
+		Protocol string
+		Message  []byte
+	}
+
+	// TALKRESP is the reply to TALKREQ.
+	TalkResponse struct {
+		ReqID   []byte
+		Message []byte
+	}
+
+	// REQUESTTICKET requests a ticket for a topic queue.
+	RequestTicket struct {
+		ReqID []byte
+		Topic []byte
+	}
+
+	// TICKET is the response to REQUESTTICKET.
+	Ticket struct {
+		ReqID  []byte
+		Ticket []byte
+	}
+
+	// REGTOPIC registers the sender in a topic queue using a ticket.
+	Regtopic struct {
+		ReqID  []byte
+		Ticket []byte
+		ENR    *enr.Record
+	}
+
+	// REGCONFIRMATION is the reply to REGTOPIC.
+	Regconfirmation struct {
+		ReqID      []byte
+		Registered bool
+	}
+
+	// TOPICQUERY asks for nodes with the given topic.
+	TopicQuery struct {
+		ReqID []byte
+		Topic []byte
+	}
+)
+
+// DecodeMessage decodes the message body of a packet.
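+// The first byte of a decrypted message plaintext selects the type;
+// everything after it is the RLP-encoded message body. Request IDs longer
+// than 8 bytes are rejected with ErrInvalidReqID.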
+func DecodeMessage(ptype byte, body []byte) (Packet, error) { + var dec Packet + switch ptype { + case PingMsg: + dec = new(Ping) + case PongMsg: + dec = new(Pong) + case FindnodeMsg: + dec = new(Findnode) + case NodesMsg: + dec = new(Nodes) + case TalkRequestMsg: + dec = new(TalkRequest) + case TalkResponseMsg: + dec = new(TalkResponse) + case RequestTicketMsg: + dec = new(RequestTicket) + case TicketMsg: + dec = new(Ticket) + case RegtopicMsg: + dec = new(Regtopic) + case RegconfirmationMsg: + dec = new(Regconfirmation) + case TopicQueryMsg: + dec = new(TopicQuery) + default: + return nil, fmt.Errorf("unknown packet type %d", ptype) + } + if err := rlp.DecodeBytes(body, dec); err != nil { + return nil, err + } + if dec.RequestID() != nil && len(dec.RequestID()) > 8 { + return nil, ErrInvalidReqID + } + return dec, nil +} + +func (*Whoareyou) Name() string { return "WHOAREYOU/v5" } +func (*Whoareyou) Kind() byte { return WhoareyouPacket } +func (*Whoareyou) RequestID() []byte { return nil } +func (*Whoareyou) SetRequestID([]byte) {} + +func (*Unknown) Name() string { return "UNKNOWN/v5" } +func (*Unknown) Kind() byte { return UnknownPacket } +func (*Unknown) RequestID() []byte { return nil } +func (*Unknown) SetRequestID([]byte) {} + +func (*Ping) Name() string { return "PING/v5" } +func (*Ping) Kind() byte { return PingMsg } +func (p *Ping) RequestID() []byte { return p.ReqID } +func (p *Ping) SetRequestID(id []byte) { p.ReqID = id } + +func (*Pong) Name() string { return "PONG/v5" } +func (*Pong) Kind() byte { return PongMsg } +func (p *Pong) RequestID() []byte { return p.ReqID } +func (p *Pong) SetRequestID(id []byte) { p.ReqID = id } + +func (*Findnode) Name() string { return "FINDNODE/v5" } +func (*Findnode) Kind() byte { return FindnodeMsg } +func (p *Findnode) RequestID() []byte { return p.ReqID } +func (p *Findnode) SetRequestID(id []byte) { p.ReqID = id } + +func (*Nodes) Name() string { return "NODES/v5" } +func (*Nodes) Kind() byte { return NodesMsg } +func (p *Nodes) RequestID() []byte { return p.ReqID } +func (p *Nodes) SetRequestID(id []byte) { p.ReqID = id } + +func (*TalkRequest) Name() string { return "TALKREQ/v5" } +func (*TalkRequest) Kind() byte { return TalkRequestMsg } +func (p *TalkRequest) RequestID() []byte { return p.ReqID } +func (p *TalkRequest) SetRequestID(id []byte) { p.ReqID = id } + +func (*TalkResponse) Name() string { return "TALKRESP/v5" } +func (*TalkResponse) Kind() byte { return TalkResponseMsg } +func (p *TalkResponse) RequestID() []byte { return p.ReqID } +func (p *TalkResponse) SetRequestID(id []byte) { p.ReqID = id } + +func (*RequestTicket) Name() string { return "REQTICKET/v5" } +func (*RequestTicket) Kind() byte { return RequestTicketMsg } +func (p *RequestTicket) RequestID() []byte { return p.ReqID } +func (p *RequestTicket) SetRequestID(id []byte) { p.ReqID = id } + +func (*Regtopic) Name() string { return "REGTOPIC/v5" } +func (*Regtopic) Kind() byte { return RegtopicMsg } +func (p *Regtopic) RequestID() []byte { return p.ReqID } +func (p *Regtopic) SetRequestID(id []byte) { p.ReqID = id } + +func (*Ticket) Name() string { return "TICKET/v5" } +func (*Ticket) Kind() byte { return TicketMsg } +func (p *Ticket) RequestID() []byte { return p.ReqID } +func (p *Ticket) SetRequestID(id []byte) { p.ReqID = id } + +func (*Regconfirmation) Name() string { return "REGCONFIRMATION/v5" } +func (*Regconfirmation) Kind() byte { return RegconfirmationMsg } +func (p *Regconfirmation) RequestID() []byte { return p.ReqID } +func (p *Regconfirmation) 
SetRequestID(id []byte) { p.ReqID = id } + +func (*TopicQuery) Name() string { return "TOPICQUERY/v5" } +func (*TopicQuery) Kind() byte { return TopicQueryMsg } +func (p *TopicQuery) RequestID() []byte { return p.ReqID } +func (p *TopicQuery) SetRequestID(id []byte) { p.ReqID = id } diff --git a/p2p/discover/v5_session.go b/p2p/discover/v5wire/session.go similarity index 58% rename from p2p/discover/v5_session.go rename to p2p/discover/v5wire/session.go index 8a0eeb6977c9..d52b5c118101 100644 --- a/p2p/discover/v5_session.go +++ b/p2p/discover/v5wire/session.go @@ -14,22 +14,33 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package discover +package v5wire import ( + "crypto/ecdsa" crand "crypto/rand" + "encoding/binary" + "time" "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/hashicorp/golang-lru/simplelru" ) -// The sessionCache keeps negotiated encryption keys and +const handshakeTimeout = time.Second + +// The SessionCache keeps negotiated encryption keys and // state for in-progress handshakes in the Discovery v5 wire protocol. -type sessionCache struct { +type SessionCache struct { sessions *simplelru.LRU - handshakes map[sessionID]*whoareyouV5 + handshakes map[sessionID]*Whoareyou clock mclock.Clock + + // hooks for overriding randomness. + nonceGen func(uint32) (Nonce, error) + maskingIVGen func([]byte) error + ephemeralKeyGen func() (*ecdsa.PrivateKey, error) } // sessionID identifies a session or handshake. @@ -45,27 +56,45 @@ type session struct { nonceCounter uint32 } -func newSessionCache(maxItems int, clock mclock.Clock) *sessionCache { +// keysFlipped returns a copy of s with the read and write keys flipped. +func (s *session) keysFlipped() *session { + return &session{s.readKey, s.writeKey, s.nonceCounter} +} + +func NewSessionCache(maxItems int, clock mclock.Clock) *SessionCache { cache, err := simplelru.NewLRU(maxItems, nil) if err != nil { panic("can't create session cache") } - return &sessionCache{ - sessions: cache, - handshakes: make(map[sessionID]*whoareyouV5), - clock: clock, + return &SessionCache{ + sessions: cache, + handshakes: make(map[sessionID]*Whoareyou), + clock: clock, + nonceGen: generateNonce, + maskingIVGen: generateMaskingIV, + ephemeralKeyGen: crypto.GenerateKey, } } +func generateNonce(counter uint32) (n Nonce, err error) { + binary.BigEndian.PutUint32(n[:4], counter) + _, err = crand.Read(n[4:]) + return n, err +} + +func generateMaskingIV(buf []byte) error { + _, err := crand.Read(buf) + return err +} + // nextNonce creates a nonce for encrypting a message to the given session. -func (sc *sessionCache) nextNonce(id enode.ID, addr string) []byte { - n := make([]byte, gcmNonceSize) - crand.Read(n) - return n +func (sc *SessionCache) nextNonce(s *session) (Nonce, error) { + s.nonceCounter++ + return sc.nonceGen(s.nonceCounter) } // session returns the current session for the given node, if any. -func (sc *sessionCache) session(id enode.ID, addr string) *session { +func (sc *SessionCache) session(id enode.ID, addr string) *session { item, ok := sc.sessions.Get(sessionID{id, addr}) if !ok { return nil @@ -74,46 +103,36 @@ func (sc *sessionCache) session(id enode.ID, addr string) *session { } // readKey returns the current read key for the given node. 
-func (sc *sessionCache) readKey(id enode.ID, addr string) []byte { +func (sc *SessionCache) readKey(id enode.ID, addr string) []byte { if s := sc.session(id, addr); s != nil { return s.readKey } return nil } -// writeKey returns the current read key for the given node. -func (sc *sessionCache) writeKey(id enode.ID, addr string) []byte { - if s := sc.session(id, addr); s != nil { - return s.writeKey - } - return nil -} - // storeNewSession stores new encryption keys in the cache. -func (sc *sessionCache) storeNewSession(id enode.ID, addr string, r, w []byte) { - sc.sessions.Add(sessionID{id, addr}, &session{ - readKey: r, writeKey: w, - }) +func (sc *SessionCache) storeNewSession(id enode.ID, addr string, s *session) { + sc.sessions.Add(sessionID{id, addr}, s) } // getHandshake gets the handshake challenge we previously sent to the given remote node. -func (sc *sessionCache) getHandshake(id enode.ID, addr string) *whoareyouV5 { +func (sc *SessionCache) getHandshake(id enode.ID, addr string) *Whoareyou { return sc.handshakes[sessionID{id, addr}] } // storeSentHandshake stores the handshake challenge sent to the given remote node. -func (sc *sessionCache) storeSentHandshake(id enode.ID, addr string, challenge *whoareyouV5) { +func (sc *SessionCache) storeSentHandshake(id enode.ID, addr string, challenge *Whoareyou) { challenge.sent = sc.clock.Now() sc.handshakes[sessionID{id, addr}] = challenge } // deleteHandshake deletes handshake data for the given node. -func (sc *sessionCache) deleteHandshake(id enode.ID, addr string) { +func (sc *SessionCache) deleteHandshake(id enode.ID, addr string) { delete(sc.handshakes, sessionID{id, addr}) } // handshakeGC deletes timed-out handshakes. -func (sc *sessionCache) handshakeGC() { +func (sc *SessionCache) handshakeGC() { deadline := sc.clock.Now().Add(-handshakeTimeout) for key, challenge := range sc.handshakes { if challenge.sent < deadline { diff --git a/p2p/discover/v5wire/testdata/v5.1-ping-handshake-enr.txt b/p2p/discover/v5wire/testdata/v5.1-ping-handshake-enr.txt new file mode 100644 index 000000000000..477f9e15a826 --- /dev/null +++ b/p2p/discover/v5wire/testdata/v5.1-ping-handshake-enr.txt @@ -0,0 +1,27 @@ +# src-node-id = 0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb +# dest-node-id = 0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9 +# nonce = 0xffffffffffffffffffffffff +# read-key = 0x53b1c075f41876423154e157470c2f48 +# ping.req-id = 0x00000001 +# ping.enr-seq = 1 +# +# handshake inputs: +# +# whoareyou.challenge-data = 0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000 +# whoareyou.request-nonce = 0x0102030405060708090a0b0c +# whoareyou.id-nonce = 0x0102030405060708090a0b0c0d0e0f10 +# whoareyou.enr-seq = 0 +# ephemeral-key = 0x0288ef00023598499cb6c940146d050d2b1fb914198c327f76aad590bead68b6 +# ephemeral-pubkey = 0x039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5 + +00000000000000000000000000000000088b3d4342774649305f313964a39e55 +ea96c005ad539c8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3 +4c4f53245d08da4bb23698868350aaad22e3ab8dd034f548a1c43cd246be9856 +2fafa0a1fa86d8e7a3b95ae78cc2b988ded6a5b59eb83ad58097252188b902b2 +1481e30e5e285f19735796706adff216ab862a9186875f9494150c4ae06fa4d1 +f0396c93f215fa4ef524e0ed04c3c21e39b1868e1ca8105e585ec17315e755e6 +cfc4dd6cb7fd8e1a1f55e49b4b5eb024221482105346f3c82b15fdaae36a3bb1 +2a494683b4a3c7f2ae41306252fed84785e2bbff3b022812d0882f06978df84a 
+80d443972213342d04b9048fc3b1d5fcb1df0f822152eced6da4d3f6df27e70e +4539717307a0208cd208d65093ccab5aa596a34d7511401987662d8cf62b1394 +71 diff --git a/p2p/discover/v5wire/testdata/v5.1-ping-handshake.txt b/p2p/discover/v5wire/testdata/v5.1-ping-handshake.txt new file mode 100644 index 000000000000..b3f304766cc5 --- /dev/null +++ b/p2p/discover/v5wire/testdata/v5.1-ping-handshake.txt @@ -0,0 +1,23 @@ +# src-node-id = 0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb +# dest-node-id = 0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9 +# nonce = 0xffffffffffffffffffffffff +# read-key = 0x4f9fac6de7567d1e3b1241dffe90f662 +# ping.req-id = 0x00000001 +# ping.enr-seq = 1 +# +# handshake inputs: +# +# whoareyou.challenge-data = 0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000001 +# whoareyou.request-nonce = 0x0102030405060708090a0b0c +# whoareyou.id-nonce = 0x0102030405060708090a0b0c0d0e0f10 +# whoareyou.enr-seq = 1 +# ephemeral-key = 0x0288ef00023598499cb6c940146d050d2b1fb914198c327f76aad590bead68b6 +# ephemeral-pubkey = 0x039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5 + +00000000000000000000000000000000088b3d4342774649305f313964a39e55 +ea96c005ad521d8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3 +4c4f53245d08da4bb252012b2cba3f4f374a90a75cff91f142fa9be3e0a5f3ef +268ccb9065aeecfd67a999e7fdc137e062b2ec4a0eb92947f0d9a74bfbf44dfb +a776b21301f8b65efd5796706adff216ab862a9186875f9494150c4ae06fa4d1 +f0396c93f215fa4ef524f1eadf5f0f4126b79336671cbcf7a885b1f8bd2a5d83 +9cf8 diff --git a/p2p/discover/v5wire/testdata/v5.1-ping-message.txt b/p2p/discover/v5wire/testdata/v5.1-ping-message.txt new file mode 100644 index 000000000000..f82b99c3bc75 --- /dev/null +++ b/p2p/discover/v5wire/testdata/v5.1-ping-message.txt @@ -0,0 +1,10 @@ +# src-node-id = 0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb +# dest-node-id = 0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9 +# nonce = 0xffffffffffffffffffffffff +# read-key = 0x00000000000000000000000000000000 +# ping.req-id = 0x00000001 +# ping.enr-seq = 2 + +00000000000000000000000000000000088b3d4342774649325f313964a39e55 +ea96c005ad52be8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3 +4c4f53245d08dab84102ed931f66d1492acb308fa1c6715b9d139b81acbdcc diff --git a/p2p/discover/v5wire/testdata/v5.1-whoareyou.txt b/p2p/discover/v5wire/testdata/v5.1-whoareyou.txt new file mode 100644 index 000000000000..1a75f525ee96 --- /dev/null +++ b/p2p/discover/v5wire/testdata/v5.1-whoareyou.txt @@ -0,0 +1,9 @@ +# src-node-id = 0xaaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb +# dest-node-id = 0xbbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9 +# whoareyou.challenge-data = 0x000000000000000000000000000000006469736376350001010102030405060708090a0b0c00180102030405060708090a0b0c0d0e0f100000000000000000 +# whoareyou.request-nonce = 0x0102030405060708090a0b0c +# whoareyou.id-nonce = 0x0102030405060708090a0b0c0d0e0f10 +# whoareyou.enr-seq = 0 + +00000000000000000000000000000000088b3d434277464933a1ccc59f5967ad +1d6035f15e528627dde75cd68292f9e6c27d6b66c8100a873fcbaed4e16b8d diff --git a/p2p/enode/node.go b/p2p/enode/node.go index 3f6cda6d4a21..c2429e0e8eb4 100644 --- a/p2p/enode/node.go +++ b/p2p/enode/node.go @@ -23,7 +23,6 @@ import ( "errors" "fmt" "math/bits" - "math/rand" "net" "strings" @@ -278,23 +277,3 @@ func LogDist(a, b ID) int { } return len(a)*8 - lz } - -// RandomID returns a random ID 
b such that logdist(a, b) == n. -func RandomID(a ID, n int) (b ID) { - if n == 0 { - return a - } - // flip bit at position n, fill the rest with random bits - b = a - pos := len(a) - n/8 - 1 - bit := byte(0x01) << (byte(n%8) - 1) - if bit == 0 { - pos++ - bit = 0x80 - } - b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits - for i := pos + 1; i < len(a); i++ { - b[i] = byte(rand.Intn(255)) - } - return b -} diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index bd066ce85738..d62f383f0bc9 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -61,6 +61,10 @@ const ( dbVersion = 9 ) +var ( + errInvalidIP = errors.New("invalid IP") +) + var zeroIP = make(net.IP, 16) // DB is the node database, storing previously seen nodes and any collected metadata about @@ -163,7 +167,7 @@ func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) { } key = key[len(dbDiscoverRoot)+1:] // Split out the IP. - ip = net.IP(key[:16]) + ip = key[:16] if ip4 := ip.To4(); ip4 != nil { ip = ip4 } @@ -359,16 +363,25 @@ func (db *DB) expireNodes() { // LastPingReceived retrieves the time of the last ping packet received from // a remote node. func (db *DB) LastPingReceived(id ID, ip net.IP) time.Time { + if ip = ip.To16(); ip == nil { + return time.Time{} + } return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0) } // UpdateLastPingReceived updates the last time we tried contacting a remote node. func (db *DB) UpdateLastPingReceived(id ID, ip net.IP, instance time.Time) error { + if ip = ip.To16(); ip == nil { + return errInvalidIP + } return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix()) } // LastPongReceived retrieves the time of the last successful pong from remote node. func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time { + if ip = ip.To16(); ip == nil { + return time.Time{} + } // Launch expirer db.ensureExpirer() return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePong)), 0) @@ -376,26 +389,41 @@ func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time { // UpdateLastPongReceived updates the last pong time of a node. func (db *DB) UpdateLastPongReceived(id ID, ip net.IP, instance time.Time) error { + if ip = ip.To16(); ip == nil { + return errInvalidIP + } return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix()) } // FindFails retrieves the number of findnode failures since bonding. func (db *DB) FindFails(id ID, ip net.IP) int { + if ip = ip.To16(); ip == nil { + return 0 + } return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails))) } // UpdateFindFails updates the number of findnode failures since bonding. func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error { + if ip = ip.To16(); ip == nil { + return errInvalidIP + } return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails)) } // FindFailsV5 retrieves the discv5 findnode failure counter. func (db *DB) FindFailsV5(id ID, ip net.IP) int { + if ip = ip.To16(); ip == nil { + return 0 + } return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails))) } // UpdateFindFailsV5 stores the discv5 findnode failure counter. 
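The new `ip.To16()` guards above matter because `net.IP` values come in 4-byte and 16-byte forms: the node database slices the address into a fixed 16-byte key, so an unnormalized 4-byte IP would produce a malformed key, and a nil IP would previously slip through unchecked. A small standard-library-only demonstration of the two representations:

```go
package main

import (
	"fmt"
	"net"
)

func main() {
	a := net.ParseIP("10.0.0.1") // ParseIP yields the 16-byte (v4-in-v6) form
	b := net.IP{10, 0, 0, 1}     // literal 4-byte form of the same address
	fmt.Println(len(a), len(b))  // 16 4
	fmt.Println(a.Equal(b))      // true, despite different raw bytes
	fmt.Println(len(b.To16()))   // 16 after normalization
	var bad net.IP               // To16 on an invalid IP returns nil,
	fmt.Println(bad.To16())      // which the guards above map to errInvalidIP
}
```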
func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error { + if ip = ip.To16(); ip == nil { + return errInvalidIP + } return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails)) } diff --git a/p2p/message_test.go b/p2p/message_test.go index a01f75556162..e575c5d96e66 100644 --- a/p2p/message_test.go +++ b/p2p/message_test.go @@ -18,11 +18,9 @@ package p2p import ( "bytes" - "encoding/hex" "fmt" "io" "runtime" - "strings" "testing" "time" ) @@ -141,12 +139,3 @@ func TestEOFSignal(t *testing.T) { default: } } - -func unhex(str string) []byte { - r := strings.NewReplacer("\t", "", " ", "", "\n", "") - b, err := hex.DecodeString(r.Replace(str)) - if err != nil { - panic(fmt.Sprintf("invalid hex string: %q", str)) - } - return b -} diff --git a/p2p/nat/nat.go b/p2p/nat/nat.go index f47534784c61..9d5519b9c4d5 100644 --- a/p2p/nat/nat.go +++ b/p2p/nat/nat.go @@ -219,9 +219,8 @@ func (n *autodisc) String() string { defer n.mu.Unlock() if n.found == nil { return n.what - } else { - return n.found.String() } + return n.found.String() } // wait blocks until auto-discovery has been performed. diff --git a/p2p/netutil/error.go b/p2p/netutil/error.go index cb21b9cd4ce4..5d3d9bfd653a 100644 --- a/p2p/netutil/error.go +++ b/p2p/netutil/error.go @@ -23,3 +23,11 @@ func IsTemporaryError(err error) bool { }) return ok && tempErr.Temporary() || isPacketTooBig(err) } + +// IsTimeout checks whether the given error is a timeout. +func IsTimeout(err error) bool { + timeoutErr, ok := err.(interface { + Timeout() bool + }) + return ok && timeoutErr.Timeout() +} diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go index ab28b47a159f..def93bac43fa 100644 --- a/p2p/nodestate/nodestate.go +++ b/p2p/nodestate/nodestate.go @@ -725,7 +725,7 @@ func (ns *NodeStateMachine) opFinish() { } ns.opPending = nil ns.opFlag = false - ns.opWait.Signal() + ns.opWait.Broadcast() } // Operation calls the given function as an operation callback. This allows the caller diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go index 5f99a3da74cb..d06ad755e22e 100644 --- a/p2p/nodestate/nodestate_test.go +++ b/p2p/nodestate/nodestate_test.go @@ -240,9 +240,8 @@ func uint64FieldEnc(field interface{}) ([]byte, error) { if u, ok := field.(uint64); ok { enc, err := rlp.EncodeToBytes(&u) return enc, err - } else { - return nil, errors.New("invalid field type") } + return nil, errors.New("invalid field type") } func uint64FieldDec(enc []byte) (interface{}, error) { @@ -254,9 +253,8 @@ func uint64FieldDec(enc []byte) (interface{}, error) { func stringFieldEnc(field interface{}) ([]byte, error) { if s, ok := field.(string); ok { return []byte(s), nil - } else { - return nil, errors.New("invalid field type") } + return nil, errors.New("invalid field type") } func stringFieldDec(enc []byte) (interface{}, error) { diff --git a/p2p/peer.go b/p2p/peer.go index 54fb653e247d..a9c3cf01dac2 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -138,8 +138,17 @@ func (p *Peer) Node() *enode.Node { return p.rw.node } -// Name returns the node name that the remote node advertised. +// Name returns an abbreviated form of the name func (p *Peer) Name() string { + s := p.rw.name + if len(s) > 20 { + return s[:20] + "..." + } + return s +} + +// Fullname returns the node name that the remote node advertised. 
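The one-line `Signal()` to `Broadcast()` fix in `opFinish` above is subtle: with several goroutines blocked on the condition variable, `Signal` wakes exactly one, and the others may sleep forever. A standalone illustration of the pattern using plain `sync.Cond` (not the nodestate machinery itself):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var (
		mu    sync.Mutex
		cond  = sync.NewCond(&mu)
		ready bool
		wg    sync.WaitGroup
	)
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			mu.Lock()
			for !ready { // re-check the predicate after every wakeup
				cond.Wait()
			}
			mu.Unlock()
			fmt.Println("woke", id)
		}(i)
	}
	mu.Lock()
	ready = true
	mu.Unlock()
	cond.Broadcast() // a single Signal() could strand two of the waiters
	wg.Wait()
}
```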
+func (p *Peer) Fullname() string { return p.rw.name } @@ -463,7 +472,7 @@ func (p *Peer) Info() *PeerInfo { info := &PeerInfo{ Enode: p.Node().URLv4(), ID: p.ID().String(), - Name: p.Name(), + Name: p.Fullname(), Caps: caps, Protocols: make(map[string]interface{}), } diff --git a/p2p/peer_test.go b/p2p/peer_test.go index e40deb98f018..4308bbd2eb4b 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -86,9 +86,15 @@ func newNode(id enode.ID, addr string) *enode.Node { } func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan error) { - fd1, fd2 := net.Pipe() - c1 := &conn{fd: fd1, node: newNode(randomID(), ""), transport: newTestTransport(&newkey().PublicKey, fd1)} - c2 := &conn{fd: fd2, node: newNode(randomID(), ""), transport: newTestTransport(&newkey().PublicKey, fd2)} + var ( + fd1, fd2 = net.Pipe() + key1, key2 = newkey(), newkey() + t1 = newTestTransport(&key2.PublicKey, fd1, nil) + t2 = newTestTransport(&key1.PublicKey, fd2, &key1.PublicKey) + ) + + c1 := &conn{fd: fd1, node: newNode(uintID(1), ""), transport: t1} + c2 := &conn{fd: fd2, node: newNode(uintID(2), ""), transport: t2} for _, p := range protos { c1.caps = append(c1.caps, p.cap()) c2.caps = append(c2.caps, p.cap()) @@ -173,9 +179,12 @@ func TestPeerPing(t *testing.T) { } } +// This test checks that a disconnect message sent by a peer is returned +// as the error from Peer.run. func TestPeerDisconnect(t *testing.T) { closer, rw, _, disc := testPeer(nil) defer closer() + if err := SendItems(rw, discMsg, DiscQuitting); err != nil { t.Fatal(err) } diff --git a/p2p/rlpx.go b/p2p/rlpx/rlpx.go similarity index 63% rename from p2p/rlpx.go rename to p2p/rlpx/rlpx.go index 4d903a08a0f6..2021bf08be11 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx/rlpx.go @@ -14,7 +14,8 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package p2p +// Package rlpx implements the RLPx transport protocol. +package rlpx import ( "bytes" @@ -29,169 +30,312 @@ import ( "fmt" "hash" "io" - "io/ioutil" mrand "math/rand" "net" - "sync" "time" - "github.com/ethereum/go-ethereum/common/bitutil" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" "github.com/golang/snappy" "golang.org/x/crypto/sha3" ) -const ( - maxUint24 = ^uint32(0) >> 8 - - sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 - sigLen = crypto.SignatureLength // elliptic S256 - pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte - shaLen = 32 // hash length (for nonce etc) - - authMsgLen = sigLen + shaLen + pubLen + shaLen + 1 - authRespLen = pubLen + shaLen + 1 - - eciesOverhead = 65 /* pubkey */ + 16 /* IV */ + 32 /* MAC */ - - encAuthMsgLen = authMsgLen + eciesOverhead // size of encrypted pre-EIP-8 initiator handshake - encAuthRespLen = authRespLen + eciesOverhead // size of encrypted pre-EIP-8 handshake reply +// Conn is an RLPx network connection. It wraps a low-level network connection. The +// underlying connection should not be used for other activity when it is wrapped by Conn. +// +// Before sending messages, a handshake must be performed by calling the Handshake method. +// This type is not generally safe for concurrent use, but reading and writing of messages +// may happen concurrently after the handshake. 
+type Conn struct { + dialDest *ecdsa.PublicKey + conn net.Conn + handshake *handshakeState + snappy bool +} - // total timeout for encryption handshake and protocol - // handshake in both directions. - handshakeTimeout = 5 * time.Second +type handshakeState struct { + enc cipher.Stream + dec cipher.Stream - // This is the timeout for sending the disconnect reason. - // This is shorter than the usual timeout because we don't want - // to wait if the connection is known to be bad anyway. - discWriteTimeout = 1 * time.Second -) - -// errPlainMessageTooLarge is returned if a decompressed message length exceeds -// the allowed 24 bits (i.e. length >= 16MB). -var errPlainMessageTooLarge = errors.New("message length >= 16MB") + macCipher cipher.Block + egressMAC hash.Hash + ingressMAC hash.Hash +} -// rlpx is the transport protocol used by actual (non-test) connections. -// It wraps the frame encoder with locks and read/write deadlines. -type rlpx struct { - fd net.Conn +// NewConn wraps the given network connection. If dialDest is non-nil, the connection +// behaves as the initiator during the handshake. +func NewConn(conn net.Conn, dialDest *ecdsa.PublicKey) *Conn { + return &Conn{ + dialDest: dialDest, + conn: conn, + } +} - rmu, wmu sync.Mutex - rw *rlpxFrameRW +// SetSnappy enables or disables snappy compression of messages. This is usually called +// after the devp2p Hello message exchange when the negotiated version indicates that +// compression is available on both ends of the connection. +func (c *Conn) SetSnappy(snappy bool) { + c.snappy = snappy } -func newRLPX(fd net.Conn) transport { - fd.SetDeadline(time.Now().Add(handshakeTimeout)) - return &rlpx{fd: fd} +// SetReadDeadline sets the deadline for all future read operations. +func (c *Conn) SetReadDeadline(time time.Time) error { + return c.conn.SetReadDeadline(time) } -func (t *rlpx) ReadMsg() (Msg, error) { - t.rmu.Lock() - defer t.rmu.Unlock() - t.fd.SetReadDeadline(time.Now().Add(frameReadTimeout)) - return t.rw.ReadMsg() +// SetWriteDeadline sets the deadline for all future write operations. +func (c *Conn) SetWriteDeadline(time time.Time) error { + return c.conn.SetWriteDeadline(time) } -func (t *rlpx) WriteMsg(msg Msg) error { - t.wmu.Lock() - defer t.wmu.Unlock() - t.fd.SetWriteDeadline(time.Now().Add(frameWriteTimeout)) - return t.rw.WriteMsg(msg) +// SetDeadline sets the deadline for all future read and write operations. +func (c *Conn) SetDeadline(time time.Time) error { + return c.conn.SetDeadline(time) } -func (t *rlpx) close(err error) { - t.wmu.Lock() - defer t.wmu.Unlock() - // Tell the remote end why we're disconnecting if possible. - if t.rw != nil { - if r, ok := err.(DiscReason); ok && r != DiscNetworkError { - // rlpx tries to send DiscReason to disconnected peer - // if the connection is net.Pipe (in-memory simulation) - // it hangs forever, since net.Pipe does not implement - // a write deadline. Because of this only try to send - // the disconnect reason message if there is no error. - if err := t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout)); err == nil { - SendItems(t.rw, discMsg, r) - } +// Read reads a message from the connection. 
+func (c *Conn) Read() (code uint64, data []byte, wireSize int, err error) { + if c.handshake == nil { + panic("can't ReadMsg before handshake") + } + + frame, err := c.handshake.readFrame(c.conn) + if err != nil { + return 0, nil, 0, err + } + code, data, err = rlp.SplitUint64(frame) + if err != nil { + return 0, nil, 0, fmt.Errorf("invalid message code: %v", err) + } + wireSize = len(data) + + // If snappy is enabled, verify and decompress message. + if c.snappy { + var actualSize int + actualSize, err = snappy.DecodedLen(data) + if err != nil { + return code, nil, 0, err + } + if actualSize > maxUint24 { + return code, nil, 0, errPlainMessageTooLarge } + data, err = snappy.Decode(nil, data) } - t.fd.Close() + return code, data, wireSize, err } -func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) { - // Writing our handshake happens concurrently, we prefer - // returning the handshake read error. If the remote side - // disconnects us early with a valid reason, we should return it - // as the error so it can be tracked elsewhere. - werr := make(chan error, 1) - go func() { werr <- Send(t.rw, handshakeMsg, our) }() - if their, err = readProtocolHandshake(t.rw); err != nil { - <-werr // make sure the write terminates too +func (h *handshakeState) readFrame(conn io.Reader) ([]byte, error) { + // read the header + headbuf := make([]byte, 32) + if _, err := io.ReadFull(conn, headbuf); err != nil { return nil, err } - if err := <-werr; err != nil { - return nil, fmt.Errorf("write error: %v", err) + + // verify header mac + shouldMAC := updateMAC(h.ingressMAC, h.macCipher, headbuf[:16]) + if !hmac.Equal(shouldMAC, headbuf[16:]) { + return nil, errors.New("bad header MAC") } - // If the protocol version supports Snappy encoding, upgrade immediately - t.rw.snappy = their.Version >= snappyProtocolVersion + h.dec.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now decrypted + fsize := readInt24(headbuf) + // ignore protocol type for now - return their, nil -} + // read the frame content + var rsize = fsize // frame size rounded up to 16 byte boundary + if padding := fsize % 16; padding > 0 { + rsize += 16 - padding + } + framebuf := make([]byte, rsize) + if _, err := io.ReadFull(conn, framebuf); err != nil { + return nil, err + } -func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) { - msg, err := rw.ReadMsg() - if err != nil { + // read and validate frame MAC. we can re-use headbuf for that. + h.ingressMAC.Write(framebuf) + fmacseed := h.ingressMAC.Sum(nil) + if _, err := io.ReadFull(conn, headbuf[:16]); err != nil { return nil, err } - if msg.Size > baseProtocolMaxMsgSize { - return nil, fmt.Errorf("message too big") + shouldMAC = updateMAC(h.ingressMAC, h.macCipher, fmacseed) + if !hmac.Equal(shouldMAC, headbuf[:16]) { + return nil, errors.New("bad frame MAC") + } + + // decrypt frame content + h.dec.XORKeyStream(framebuf, framebuf) + return framebuf[:fsize], nil +} + +// Write writes a message to the connection. +// +// Write returns the written size of the message data. This may be less than or equal to +// len(data) depending on whether snappy compression is enabled. +func (c *Conn) Write(code uint64, data []byte) (uint32, error) { + if c.handshake == nil { + panic("can't WriteMsg before handshake") } - if msg.Code == discMsg { - // Disconnect before protocol handshake is valid according to the - // spec and we send it ourself if the post-handshake checks fail. 
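Note the order of operations in `Read` above: `snappy.DecodedLen` only parses the length header of the compressed payload, so oversized messages are rejected against `maxUint24` before any decompression buffer is allocated. A quick sketch of that check in isolation:

```go
package main

import (
	"fmt"

	"github.com/golang/snappy"
)

func main() {
	const maxUint24 = int(^uint32(0) >> 8) // 16MB - 1, the RLPx message limit
	payload := snappy.Encode(nil, make([]byte, 1024))
	size, err := snappy.DecodedLen(payload)  // reads the header only
	fmt.Println(size, err, size > maxUint24) // 1024 <nil> false
}
```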
- // We can't return the reason directly, though, because it is echoed - // back otherwise. Wrap it in a string instead. - var reason [1]DiscReason - rlp.Decode(msg.Payload, &reason) - return nil, reason[0] + if len(data) > maxUint24 { + return 0, errPlainMessageTooLarge } - if msg.Code != handshakeMsg { - return nil, fmt.Errorf("expected handshake, got %x", msg.Code) + if c.snappy { + data = snappy.Encode(nil, data) } - var hs protoHandshake - if err := msg.Decode(&hs); err != nil { - return nil, err + + wireSize := uint32(len(data)) + err := c.handshake.writeFrame(c.conn, code, data) + return wireSize, err +} + +func (h *handshakeState) writeFrame(conn io.Writer, code uint64, data []byte) error { + ptype, _ := rlp.EncodeToBytes(code) + + // write header + headbuf := make([]byte, 32) + fsize := len(ptype) + len(data) + if fsize > maxUint24 { + return errPlainMessageTooLarge + } + putInt24(uint32(fsize), headbuf) + copy(headbuf[3:], zeroHeader) + h.enc.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now encrypted + + // write header MAC + copy(headbuf[16:], updateMAC(h.egressMAC, h.macCipher, headbuf[:16])) + if _, err := conn.Write(headbuf); err != nil { + return err + } + + // write encrypted frame, updating the egress MAC hash with + // the data written to conn. + tee := cipher.StreamWriter{S: h.enc, W: io.MultiWriter(conn, h.egressMAC)} + if _, err := tee.Write(ptype); err != nil { + return err + } + if _, err := tee.Write(data); err != nil { + return err } - if len(hs.ID) != 64 || !bitutil.TestBytes(hs.ID) { - return nil, DiscInvalidIdentity + if padding := fsize % 16; padding > 0 { + if _, err := tee.Write(zero16[:16-padding]); err != nil { + return err + } } - return &hs, nil + + // write frame MAC. egress MAC hash is up to date because + // frame content was written to it as well. + fmacseed := h.egressMAC.Sum(nil) + mac := updateMAC(h.egressMAC, h.macCipher, fmacseed) + _, err := conn.Write(mac) + return err +} + +func readInt24(b []byte) uint32 { + return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16 } -// doEncHandshake runs the protocol handshake using authenticated -// messages. the protocol handshake is the first authenticated message -// and also verifies whether the encryption handshake 'worked' and the -// remote side actually provided the right public key. -func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *ecdsa.PublicKey) (*ecdsa.PublicKey, error) { +func putInt24(v uint32, b []byte) { + b[0] = byte(v >> 16) + b[1] = byte(v >> 8) + b[2] = byte(v) +} + +// updateMAC reseeds the given hash with encrypted seed. +// it returns the first 16 bytes of the hash sum after seeding. +func updateMAC(mac hash.Hash, block cipher.Block, seed []byte) []byte { + aesbuf := make([]byte, aes.BlockSize) + block.Encrypt(aesbuf, mac.Sum(nil)) + for i := range aesbuf { + aesbuf[i] ^= seed[i] + } + mac.Write(aesbuf) + return mac.Sum(nil)[:16] +} + +// Handshake performs the handshake. This must be called before any data is written +// or read from the connection. 
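As a worked example of the size arithmetic in `writeFrame`/`readFrame`: the golden frame in the test further below carries message code 8 (one byte as RLP) plus `rlp([1,2,3,4])` (five bytes), so `fsize = 6`, the body is padded up to a 16-byte AES block, and the wire frame totals 64 bytes:

```go
package main

import "fmt"

func main() {
	fsize := uint32(6) // len(rlp(code)) + len(data) = 1 + 5
	rsize := fsize     // body rounded up to the 16-byte cipher block
	if padding := fsize % 16; padding > 0 {
		rsize += 16 - padding
	}
	onWire := 16 + 16 + rsize + 16 // enc header + header MAC + body + frame MAC
	fmt.Println(rsize, onWire)     // 16 64
}
```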
+func (c *Conn) Handshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) { var ( - sec secrets + sec Secrets err error ) - if dial == nil { - sec, err = receiverEncHandshake(t.fd, prv) + if c.dialDest != nil { + sec, err = initiatorEncHandshake(c.conn, prv, c.dialDest) } else { - sec, err = initiatorEncHandshake(t.fd, prv, dial) + sec, err = receiverEncHandshake(c.conn, prv) } if err != nil { return nil, err } - t.wmu.Lock() - t.rw = newRLPXFrameRW(t.fd, sec) - t.wmu.Unlock() - return sec.Remote.ExportECDSA(), nil + c.InitWithSecrets(sec) + return sec.remote, err +} + +// InitWithSecrets injects connection secrets as if a handshake had +// been performed. This cannot be called after the handshake. +func (c *Conn) InitWithSecrets(sec Secrets) { + if c.handshake != nil { + panic("can't handshake twice") + } + macc, err := aes.NewCipher(sec.MAC) + if err != nil { + panic("invalid MAC secret: " + err.Error()) + } + encc, err := aes.NewCipher(sec.AES) + if err != nil { + panic("invalid AES secret: " + err.Error()) + } + // we use an all-zeroes IV for AES because the key used + // for encryption is ephemeral. + iv := make([]byte, encc.BlockSize()) + c.handshake = &handshakeState{ + enc: cipher.NewCTR(encc, iv), + dec: cipher.NewCTR(encc, iv), + macCipher: macc, + egressMAC: sec.EgressMAC, + ingressMAC: sec.IngressMAC, + } +} + +// Close closes the underlying network connection. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// Constants for the handshake. +const ( + maxUint24 = int(^uint32(0) >> 8) + + sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 + sigLen = crypto.SignatureLength // elliptic S256 + pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte + shaLen = 32 // hash length (for nonce etc) + + authMsgLen = sigLen + shaLen + pubLen + shaLen + 1 + authRespLen = pubLen + shaLen + 1 + + eciesOverhead = 65 /* pubkey */ + 16 /* IV */ + 32 /* MAC */ + + encAuthMsgLen = authMsgLen + eciesOverhead // size of encrypted pre-EIP-8 initiator handshake + encAuthRespLen = authRespLen + eciesOverhead // size of encrypted pre-EIP-8 handshake reply +) + +var ( + // this is used in place of actual frame header data. + // TODO: replace this when Msg contains the protocol type code. + zeroHeader = []byte{0xC2, 0x80, 0x80} + // sixteen zero bytes + zero16 = make([]byte, 16) + + // errPlainMessageTooLarge is returned if a decompressed message length exceeds + // the allowed 24 bits (i.e. length >= 16MB). + errPlainMessageTooLarge = errors.New("message length >= 16MB") +) + +// Secrets represents the connection secrets which are negotiated during the handshake. +type Secrets struct { + AES, MAC []byte + EgressMAC, IngressMAC hash.Hash + remote *ecdsa.PublicKey } // encHandshake contains the state of the encryption handshake. @@ -203,15 +347,6 @@ type encHandshake struct { remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk } -// secrets represents the connection secrets -// which are negotiated during the encryption handshake. -type secrets struct { - Remote *ecies.PublicKey - AES, MAC []byte - EgressMAC, IngressMAC hash.Hash - Token []byte -} - // RLPx v4 handshake auth (defined in EIP-8). type authMsgV4 struct { gotPlain bool // whether read packet had plain format. @@ -235,19 +370,85 @@ type authRespV4 struct { Rest []rlp.RawValue `rlp:"tail"` } +// receiverEncHandshake negotiates a session token on conn. +// it should be called on the listening side of the connection. +// +// prv is the local client's private key. 
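A hedged usage sketch of the new `rlpx.Conn` API (the helper below is hypothetical; the timeout is a placeholder): the initiator passes the remote public key to `NewConn`, the listener passes nil, and `Handshake` must complete before any `Read` or `Write`.

```go
package rlpxexample

import (
	"crypto/ecdsa"
	"net"
	"time"

	"github.com/ethereum/go-ethereum/p2p/rlpx"
)

// dialRLPx shows the intended call order on the dialing side.
func dialRLPx(addr string, ourKey *ecdsa.PrivateKey, remotePub *ecdsa.PublicKey) (uint64, []byte, error) {
	fd, err := net.Dial("tcp", addr)
	if err != nil {
		return 0, nil, err
	}
	conn := rlpx.NewConn(fd, remotePub) // non-nil key: we act as initiator
	conn.SetDeadline(time.Now().Add(5 * time.Second))
	if _, err := conn.Handshake(ourKey); err != nil {
		fd.Close()
		return 0, nil, err
	}
	// SetSnappy(true) would follow here once the devp2p Hello exchange
	// negotiates protocol version >= 5.
	code, data, _, err := conn.Read()
	return code, data, err
}
```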
+func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey) (s Secrets, err error) { + authMsg := new(authMsgV4) + authPacket, err := readHandshakeMsg(authMsg, encAuthMsgLen, prv, conn) + if err != nil { + return s, err + } + h := new(encHandshake) + if err := h.handleAuthMsg(authMsg, prv); err != nil { + return s, err + } + + authRespMsg, err := h.makeAuthResp() + if err != nil { + return s, err + } + var authRespPacket []byte + if authMsg.gotPlain { + authRespPacket, err = authRespMsg.sealPlain(h) + } else { + authRespPacket, err = sealEIP8(authRespMsg, h) + } + if err != nil { + return s, err + } + if _, err = conn.Write(authRespPacket); err != nil { + return s, err + } + return h.secrets(authPacket, authRespPacket) +} + +func (h *encHandshake) handleAuthMsg(msg *authMsgV4, prv *ecdsa.PrivateKey) error { + // Import the remote identity. + rpub, err := importPublicKey(msg.InitiatorPubkey[:]) + if err != nil { + return err + } + h.initNonce = msg.Nonce[:] + h.remote = rpub + + // Generate random keypair for ECDH. + // If a private key is already set, use it instead of generating one (for testing). + if h.randomPrivKey == nil { + h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil) + if err != nil { + return err + } + } + + // Check the signature. + token, err := h.staticSharedSecret(prv) + if err != nil { + return err + } + signedMsg := xor(token, h.initNonce) + remoteRandomPub, err := crypto.Ecrecover(signedMsg, msg.Signature[:]) + if err != nil { + return err + } + h.remoteRandomPub, _ = importPublicKey(remoteRandomPub) + return nil +} + // secrets is called after the handshake is completed. // It extracts the connection secrets from the handshake values. -func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) { +func (h *encHandshake) secrets(auth, authResp []byte) (Secrets, error) { ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen) if err != nil { - return secrets{}, err + return Secrets{}, err } // derive base secrets from ephemeral key agreement sharedSecret := crypto.Keccak256(ecdheSecret, crypto.Keccak256(h.respNonce, h.initNonce)) aesSecret := crypto.Keccak256(ecdheSecret, sharedSecret) - s := secrets{ - Remote: h.remote, + s := Secrets{ + remote: h.remote.ExportECDSA(), AES: aesSecret, MAC: crypto.Keccak256(ecdheSecret, aesSecret), } @@ -278,7 +479,7 @@ func (h *encHandshake) staticSharedSecret(prv *ecdsa.PrivateKey) ([]byte, error) // it should be called on the dialing side of the connection. // // prv is the local client's private key. -func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remote *ecdsa.PublicKey) (s secrets, err error) { +func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remote *ecdsa.PublicKey) (s Secrets, err error) { h := &encHandshake{initiator: true, remote: ecies.ImportECDSAPublic(remote)} authMsg, err := h.makeAuthMsg(prv) if err != nil { @@ -288,6 +489,7 @@ func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remote *ec if err != nil { return s, err } + if _, err = conn.Write(authPacket); err != nil { return s, err } @@ -342,72 +544,6 @@ func (h *encHandshake) handleAuthResp(msg *authRespV4) (err error) { return err } -// receiverEncHandshake negotiates a session token on conn. -// it should be called on the listening side of the connection. -// -// prv is the local client's private key. 
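The derivation chain in `(h *encHandshake).secrets` above is three chained Keccak-256 hashes over the ephemeral ECDH output. A sketch of the same chain with placeholder inputs instead of a real key agreement:

```go
package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/crypto"
)

func main() {
	ecdheSecret := make([]byte, 32) // stand-in for the ECDHE shared secret
	initNonce := make([]byte, 32)   // nonce from the auth message
	respNonce := make([]byte, 32)   // nonce from the auth response

	sharedSecret := crypto.Keccak256(ecdheSecret, crypto.Keccak256(respNonce, initNonce))
	aesSecret := crypto.Keccak256(ecdheSecret, sharedSecret)
	macSecret := crypto.Keccak256(ecdheSecret, aesSecret)
	fmt.Printf("aes=%x\nmac=%x\n", aesSecret, macSecret)
}
```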
-func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey) (s secrets, err error) { - authMsg := new(authMsgV4) - authPacket, err := readHandshakeMsg(authMsg, encAuthMsgLen, prv, conn) - if err != nil { - return s, err - } - h := new(encHandshake) - if err := h.handleAuthMsg(authMsg, prv); err != nil { - return s, err - } - - authRespMsg, err := h.makeAuthResp() - if err != nil { - return s, err - } - var authRespPacket []byte - if authMsg.gotPlain { - authRespPacket, err = authRespMsg.sealPlain(h) - } else { - authRespPacket, err = sealEIP8(authRespMsg, h) - } - if err != nil { - return s, err - } - if _, err = conn.Write(authRespPacket); err != nil { - return s, err - } - return h.secrets(authPacket, authRespPacket) -} - -func (h *encHandshake) handleAuthMsg(msg *authMsgV4, prv *ecdsa.PrivateKey) error { - // Import the remote identity. - rpub, err := importPublicKey(msg.InitiatorPubkey[:]) - if err != nil { - return err - } - h.initNonce = msg.Nonce[:] - h.remote = rpub - - // Generate random keypair for ECDH. - // If a private key is already set, use it instead of generating one (for testing). - if h.randomPrivKey == nil { - h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil) - if err != nil { - return err - } - } - - // Check the signature. - token, err := h.staticSharedSecret(prv) - if err != nil { - return err - } - signedMsg := xor(token, h.initNonce) - remoteRandomPub, err := crypto.Ecrecover(signedMsg, msg.Signature[:]) - if err != nil { - return err - } - h.remoteRandomPub, _ = importPublicKey(remoteRandomPub) - return nil -} - func (h *encHandshake) makeAuthResp() (msg *authRespV4, err error) { // Generate random nonce. h.respNonce = make([]byte, shaLen) @@ -531,201 +667,3 @@ func xor(one, other []byte) (xor []byte) { } return xor } - -var ( - // this is used in place of actual frame header data. - // TODO: replace this when Msg contains the protocol type code. - zeroHeader = []byte{0xC2, 0x80, 0x80} - // sixteen zero bytes - zero16 = make([]byte, 16) -) - -// rlpxFrameRW implements a simplified version of RLPx framing. -// chunked messages are not supported and all headers are equal to -// zeroHeader. -// -// rlpxFrameRW is not safe for concurrent use from multiple goroutines. -type rlpxFrameRW struct { - conn io.ReadWriter - enc cipher.Stream - dec cipher.Stream - - macCipher cipher.Block - egressMAC hash.Hash - ingressMAC hash.Hash - - snappy bool -} - -func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW { - macc, err := aes.NewCipher(s.MAC) - if err != nil { - panic("invalid MAC secret: " + err.Error()) - } - encc, err := aes.NewCipher(s.AES) - if err != nil { - panic("invalid AES secret: " + err.Error()) - } - // we use an all-zeroes IV for AES because the key used - // for encryption is ephemeral. 
- iv := make([]byte, encc.BlockSize()) - return &rlpxFrameRW{ - conn: conn, - enc: cipher.NewCTR(encc, iv), - dec: cipher.NewCTR(encc, iv), - macCipher: macc, - egressMAC: s.EgressMAC, - ingressMAC: s.IngressMAC, - } -} - -func (rw *rlpxFrameRW) WriteMsg(msg Msg) error { - ptype, _ := rlp.EncodeToBytes(msg.Code) - - // if snappy is enabled, compress message now - if rw.snappy { - if msg.Size > maxUint24 { - return errPlainMessageTooLarge - } - payload, _ := ioutil.ReadAll(msg.Payload) - payload = snappy.Encode(nil, payload) - - msg.Payload = bytes.NewReader(payload) - msg.Size = uint32(len(payload)) - } - msg.meterSize = msg.Size - if metrics.Enabled && msg.meterCap.Name != "" { // don't meter non-subprotocol messages - m := fmt.Sprintf("%s/%s/%d/%#02x", egressMeterName, msg.meterCap.Name, msg.meterCap.Version, msg.meterCode) - metrics.GetOrRegisterMeter(m, nil).Mark(int64(msg.meterSize)) - metrics.GetOrRegisterMeter(m+"/packets", nil).Mark(1) - } - // write header - headbuf := make([]byte, 32) - fsize := uint32(len(ptype)) + msg.Size - if fsize > maxUint24 { - return errors.New("message size overflows uint24") - } - putInt24(fsize, headbuf) // TODO: check overflow - copy(headbuf[3:], zeroHeader) - rw.enc.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now encrypted - - // write header MAC - copy(headbuf[16:], updateMAC(rw.egressMAC, rw.macCipher, headbuf[:16])) - if _, err := rw.conn.Write(headbuf); err != nil { - return err - } - - // write encrypted frame, updating the egress MAC hash with - // the data written to conn. - tee := cipher.StreamWriter{S: rw.enc, W: io.MultiWriter(rw.conn, rw.egressMAC)} - if _, err := tee.Write(ptype); err != nil { - return err - } - if _, err := io.Copy(tee, msg.Payload); err != nil { - return err - } - if padding := fsize % 16; padding > 0 { - if _, err := tee.Write(zero16[:16-padding]); err != nil { - return err - } - } - - // write frame MAC. egress MAC hash is up to date because - // frame content was written to it as well. - fmacseed := rw.egressMAC.Sum(nil) - mac := updateMAC(rw.egressMAC, rw.macCipher, fmacseed) - _, err := rw.conn.Write(mac) - return err -} - -func (rw *rlpxFrameRW) ReadMsg() (msg Msg, err error) { - // read the header - headbuf := make([]byte, 32) - if _, err := io.ReadFull(rw.conn, headbuf); err != nil { - return msg, err - } - // verify header mac - shouldMAC := updateMAC(rw.ingressMAC, rw.macCipher, headbuf[:16]) - if !hmac.Equal(shouldMAC, headbuf[16:]) { - return msg, errors.New("bad header MAC") - } - rw.dec.XORKeyStream(headbuf[:16], headbuf[:16]) // first half is now decrypted - fsize := readInt24(headbuf) - // ignore protocol type for now - - // read the frame content - var rsize = fsize // frame size rounded up to 16 byte boundary - if padding := fsize % 16; padding > 0 { - rsize += 16 - padding - } - framebuf := make([]byte, rsize) - if _, err := io.ReadFull(rw.conn, framebuf); err != nil { - return msg, err - } - - // read and validate frame MAC. we can re-use headbuf for that. 
- rw.ingressMAC.Write(framebuf) - fmacseed := rw.ingressMAC.Sum(nil) - if _, err := io.ReadFull(rw.conn, headbuf[:16]); err != nil { - return msg, err - } - shouldMAC = updateMAC(rw.ingressMAC, rw.macCipher, fmacseed) - if !hmac.Equal(shouldMAC, headbuf[:16]) { - return msg, errors.New("bad frame MAC") - } - - // decrypt frame content - rw.dec.XORKeyStream(framebuf, framebuf) - - // decode message code - content := bytes.NewReader(framebuf[:fsize]) - if err := rlp.Decode(content, &msg.Code); err != nil { - return msg, err - } - msg.Size = uint32(content.Len()) - msg.meterSize = msg.Size - msg.Payload = content - - // if snappy is enabled, verify and decompress message - if rw.snappy { - payload, err := ioutil.ReadAll(msg.Payload) - if err != nil { - return msg, err - } - size, err := snappy.DecodedLen(payload) - if err != nil { - return msg, err - } - if size > int(maxUint24) { - return msg, errPlainMessageTooLarge - } - payload, err = snappy.Decode(nil, payload) - if err != nil { - return msg, err - } - msg.Size, msg.Payload = uint32(size), bytes.NewReader(payload) - } - return msg, nil -} - -// updateMAC reseeds the given hash with encrypted seed. -// it returns the first 16 bytes of the hash sum after seeding. -func updateMAC(mac hash.Hash, block cipher.Block, seed []byte) []byte { - aesbuf := make([]byte, aes.BlockSize) - block.Encrypt(aesbuf, mac.Sum(nil)) - for i := range aesbuf { - aesbuf[i] ^= seed[i] - } - mac.Write(aesbuf) - return mac.Sum(nil)[:16] -} - -func readInt24(b []byte) uint32 { - return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16 -} - -func putInt24(v uint32, b []byte) { - b[0] = byte(v >> 16) - b[1] = byte(v >> 8) - b[2] = byte(v) -} diff --git a/p2p/rlpx_test.go b/p2p/rlpx/rlpx_test.go similarity index 58% rename from p2p/rlpx_test.go rename to p2p/rlpx/rlpx_test.go index 3f686fe09f46..127a01816454 100644 --- a/p2p/rlpx_test.go +++ b/p2p/rlpx/rlpx_test.go @@ -1,4 +1,4 @@ -// Copyright 2015 The go-ethereum Authors +// Copyright 2020 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify @@ -14,298 +14,145 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . 
-package p2p +package rlpx import ( "bytes" "crypto/ecdsa" - "crypto/rand" - "errors" + "encoding/hex" "fmt" "io" - "io/ioutil" "net" "reflect" "strings" - "sync" "testing" - "time" "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/ecies" - "github.com/ethereum/go-ethereum/p2p/simulations/pipes" "github.com/ethereum/go-ethereum/rlp" - "golang.org/x/crypto/sha3" + "github.com/stretchr/testify/assert" ) -func TestSharedSecret(t *testing.T) { - prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) - pub0 := &prv0.PublicKey - prv1, _ := crypto.GenerateKey() - pub1 := &prv1.PublicKey - - ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen) - if err != nil { - return - } - ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen) - if err != nil { - return - } - t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1) - if !bytes.Equal(ss0, ss1) { - t.Errorf("don't match :(") - } +type message struct { + code uint64 + data []byte + err error } -func TestEncHandshake(t *testing.T) { - for i := 0; i < 10; i++ { - start := time.Now() - if err := testEncHandshake(nil); err != nil { - t.Fatalf("i=%d %v", i, err) - } - t.Logf("(without token) %d %v\n", i+1, time.Since(start)) - } - for i := 0; i < 10; i++ { - tok := make([]byte, shaLen) - rand.Reader.Read(tok) - start := time.Now() - if err := testEncHandshake(tok); err != nil { - t.Fatalf("i=%d %v", i, err) - } - t.Logf("(with token) %d %v\n", i+1, time.Since(start)) - } +func TestHandshake(t *testing.T) { + p1, p2 := createPeers(t) + p1.Close() + p2.Close() } -func testEncHandshake(token []byte) error { - type result struct { - side string - pubkey *ecdsa.PublicKey - err error - } - var ( - prv0, _ = crypto.GenerateKey() - prv1, _ = crypto.GenerateKey() - fd0, fd1 = net.Pipe() - c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx) - output = make(chan result) - ) - - go func() { - r := result{side: "initiator"} - defer func() { output <- r }() - defer fd0.Close() +// This test checks that messages can be sent and received through WriteMsg/ReadMsg. 
+func TestReadWriteMsg(t *testing.T) { + peer1, peer2 := createPeers(t) + defer peer1.Close() + defer peer2.Close() - r.pubkey, r.err = c0.doEncHandshake(prv0, &prv1.PublicKey) - if r.err != nil { - return - } - if !reflect.DeepEqual(r.pubkey, &prv1.PublicKey) { - r.err = fmt.Errorf("remote pubkey mismatch: got %v, want: %v", r.pubkey, &prv1.PublicKey) - } - }() - go func() { - r := result{side: "receiver"} - defer func() { output <- r }() - defer fd1.Close() - - r.pubkey, r.err = c1.doEncHandshake(prv1, nil) - if r.err != nil { - return - } - if !reflect.DeepEqual(r.pubkey, &prv0.PublicKey) { - r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.pubkey, &prv0.PublicKey) - } - }() + testCode := uint64(23) + testData := []byte("test") + checkMsgReadWrite(t, peer1, peer2, testCode, testData) - // wait for results from both sides - r1, r2 := <-output, <-output - if r1.err != nil { - return fmt.Errorf("%s side error: %v", r1.side, r1.err) - } - if r2.err != nil { - return fmt.Errorf("%s side error: %v", r2.side, r2.err) - } - - // compare derived secrets - if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) { - return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC) - } - if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) { - return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC) - } - if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) { - return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc) - } - if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) { - return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec) - } - return nil + t.Log("enabling snappy") + peer1.SetSnappy(true) + peer2.SetSnappy(true) + checkMsgReadWrite(t, peer1, peer2, testCode, testData) } -func TestProtocolHandshake(t *testing.T) { - var ( - prv0, _ = crypto.GenerateKey() - pub0 = crypto.FromECDSAPub(&prv0.PublicKey)[1:] - hs0 = &protoHandshake{Version: 3, ID: pub0, Caps: []Cap{{"a", 0}, {"b", 2}}} - - prv1, _ = crypto.GenerateKey() - pub1 = crypto.FromECDSAPub(&prv1.PublicKey)[1:] - hs1 = &protoHandshake{Version: 3, ID: pub1, Caps: []Cap{{"c", 1}, {"d", 3}}} - - wg sync.WaitGroup - ) +func checkMsgReadWrite(t *testing.T, p1, p2 *Conn, msgCode uint64, msgData []byte) { + // Set up the reader. + ch := make(chan message, 1) + go func() { + var msg message + msg.code, msg.data, _, msg.err = p1.Read() + ch <- msg + }() - fd0, fd1, err := pipes.TCPPipe() + // Write the message. + _, err := p2.Write(msgCode, msgData) if err != nil { t.Fatal(err) } - wg.Add(2) - go func() { - defer wg.Done() - defer fd0.Close() - rlpx := newRLPX(fd0) - rpubkey, err := rlpx.doEncHandshake(prv0, &prv1.PublicKey) - if err != nil { - t.Errorf("dial side enc handshake failed: %v", err) - return - } - if !reflect.DeepEqual(rpubkey, &prv1.PublicKey) { - t.Errorf("dial side remote pubkey mismatch: got %v, want %v", rpubkey, &prv1.PublicKey) - return - } + // Check it was received correctly. 
+ msg := <-ch + assert.Equal(t, msgCode, msg.code, "wrong message code returned from ReadMsg") + assert.Equal(t, msgData, msg.data, "wrong message data returned from ReadMsg") +} - phs, err := rlpx.doProtoHandshake(hs0) - if err != nil { - t.Errorf("dial side proto handshake error: %v", err) - return - } - phs.Rest = nil - if !reflect.DeepEqual(phs, hs1) { - t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1)) - return - } - rlpx.close(DiscQuitting) - }() - go func() { - defer wg.Done() - defer fd1.Close() - rlpx := newRLPX(fd1) - rpubkey, err := rlpx.doEncHandshake(prv1, nil) - if err != nil { - t.Errorf("listen side enc handshake failed: %v", err) - return - } - if !reflect.DeepEqual(rpubkey, &prv0.PublicKey) { - t.Errorf("listen side remote pubkey mismatch: got %v, want %v", rpubkey, &prv0.PublicKey) - return - } +func createPeers(t *testing.T) (peer1, peer2 *Conn) { + conn1, conn2 := net.Pipe() + key1, key2 := newkey(), newkey() + peer1 = NewConn(conn1, &key2.PublicKey) // dialer + peer2 = NewConn(conn2, nil) // listener + doHandshake(t, peer1, peer2, key1, key2) + return peer1, peer2 +} - phs, err := rlpx.doProtoHandshake(hs1) +func doHandshake(t *testing.T, peer1, peer2 *Conn, key1, key2 *ecdsa.PrivateKey) { + keyChan := make(chan *ecdsa.PublicKey, 1) + go func() { + pubKey, err := peer2.Handshake(key2) if err != nil { - t.Errorf("listen side proto handshake error: %v", err) - return - } - phs.Rest = nil - if !reflect.DeepEqual(phs, hs0) { - t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0)) - return - } - - if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil { - t.Errorf("error receiving disconnect: %v", err) + t.Errorf("peer2 could not do handshake: %v", err) } + keyChan <- pubKey }() - wg.Wait() -} -func TestProtocolHandshakeErrors(t *testing.T) { - tests := []struct { - code uint64 - msg interface{} - err error - }{ - { - code: discMsg, - msg: []DiscReason{DiscQuitting}, - err: DiscQuitting, - }, - { - code: 0x989898, - msg: []byte{1}, - err: errors.New("expected handshake, got 989898"), - }, - { - code: handshakeMsg, - msg: make([]byte, baseProtocolMaxMsgSize+2), - err: errors.New("message too big"), - }, - { - code: handshakeMsg, - msg: []byte{1, 2, 3}, - err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"), - }, - { - code: handshakeMsg, - msg: &protoHandshake{Version: 3}, - err: DiscInvalidIdentity, - }, + pubKey2, err := peer1.Handshake(key1) + if err != nil { + t.Errorf("peer1 could not do handshake: %v", err) } + pubKey1 := <-keyChan - for i, test := range tests { - p1, p2 := MsgPipe() - go Send(p1, test.code, test.msg) - _, err := readProtocolHandshake(p2) - if !reflect.DeepEqual(err, test.err) { - t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) - } + // Confirm the handshake was successful. + if !reflect.DeepEqual(pubKey1, &key1.PublicKey) || !reflect.DeepEqual(pubKey2, &key2.PublicKey) { + t.Fatal("unsuccessful handshake") } } -func TestRLPXFrameFake(t *testing.T) { - buf := new(bytes.Buffer) +// This test checks the frame data of written messages. 
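`doHandshake` above runs one side on a goroutine for a reason worth spelling out: `net.Pipe` is fully synchronous, so a `Write` blocks until the peer `Read`s, and driving both handshake halves from a single goroutine would deadlock. A minimal standalone demonstration:

```go
package main

import (
	"fmt"
	"net"
)

func main() {
	a, b := net.Pipe()
	go func() { // the "remote" side must run concurrently
		buf := make([]byte, 5)
		b.Read(buf)
		b.Write([]byte("world"))
	}()
	a.Write([]byte("hello")) // would block forever without the goroutine
	reply := make([]byte, 5)
	a.Read(reply)
	fmt.Println(string(reply)) // world
}
```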
+func TestFrameReadWrite(t *testing.T) { + conn := NewConn(nil, nil) hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}) - rw := newRLPXFrameRW(buf, secrets{ + conn.InitWithSecrets(Secrets{ AES: crypto.Keccak256(), MAC: crypto.Keccak256(), IngressMAC: hash, EgressMAC: hash, }) + h := conn.handshake golden := unhex(` -00828ddae471818bb0bfa6b551d1cb42 -01010101010101010101010101010101 -ba628a4ba590cb43f7848f41c4382885 -01010101010101010101010101010101 -`) - - // Check WriteMsg. This puts a message into the buffer. - if err := Send(rw, 8, []uint{1, 2, 3, 4}); err != nil { + 00828ddae471818bb0bfa6b551d1cb42 + 01010101010101010101010101010101 + ba628a4ba590cb43f7848f41c4382885 + 01010101010101010101010101010101 + `) + msgCode := uint64(8) + msg := []uint{1, 2, 3, 4} + msgEnc, _ := rlp.EncodeToBytes(msg) + + // Check writeFrame. The frame that's written should be equal to the test vector. + buf := new(bytes.Buffer) + if err := h.writeFrame(buf, msgCode, msgEnc); err != nil { t.Fatalf("WriteMsg error: %v", err) } - written := buf.Bytes() - if !bytes.Equal(written, golden) { - t.Fatalf("output mismatch:\n got: %x\n want: %x", written, golden) + if !bytes.Equal(buf.Bytes(), golden) { + t.Fatalf("output mismatch:\n got: %x\n want: %x", buf.Bytes(), golden) } - // Check ReadMsg. It reads the message encoded by WriteMsg, which - // is equivalent to the golden message above. - msg, err := rw.ReadMsg() + // Check readFrame on the test vector. + content, err := h.readFrame(bytes.NewReader(golden)) if err != nil { t.Fatalf("ReadMsg error: %v", err) } - if msg.Size != 5 { - t.Errorf("msg size mismatch: got %d, want %d", msg.Size, 5) - } - if msg.Code != 8 { - t.Errorf("msg code mismatch: got %d, want %d", msg.Code, 8) - } - payload, _ := ioutil.ReadAll(msg.Payload) - wantPayload := unhex("C401020304") - if !bytes.Equal(payload, wantPayload) { - t.Errorf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload) + wantContent := unhex("08C401020304") + if !bytes.Equal(content, wantContent) { + t.Errorf("frame content mismatch:\ngot %x\nwant %x", content, wantContent) } } @@ -314,66 +161,8 @@ type fakeHash []byte func (fakeHash) Write(p []byte) (int, error) { return len(p), nil } func (fakeHash) Reset() {} func (fakeHash) BlockSize() int { return 0 } - -func (h fakeHash) Size() int { return len(h) } -func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) 
} - -func TestRLPXFrameRW(t *testing.T) { - var ( - aesSecret = make([]byte, 16) - macSecret = make([]byte, 16) - egressMACinit = make([]byte, 32) - ingressMACinit = make([]byte, 32) - ) - for _, s := range [][]byte{aesSecret, macSecret, egressMACinit, ingressMACinit} { - rand.Read(s) - } - conn := new(bytes.Buffer) - - s1 := secrets{ - AES: aesSecret, - MAC: macSecret, - EgressMAC: sha3.NewLegacyKeccak256(), - IngressMAC: sha3.NewLegacyKeccak256(), - } - s1.EgressMAC.Write(egressMACinit) - s1.IngressMAC.Write(ingressMACinit) - rw1 := newRLPXFrameRW(conn, s1) - - s2 := secrets{ - AES: aesSecret, - MAC: macSecret, - EgressMAC: sha3.NewLegacyKeccak256(), - IngressMAC: sha3.NewLegacyKeccak256(), - } - s2.EgressMAC.Write(ingressMACinit) - s2.IngressMAC.Write(egressMACinit) - rw2 := newRLPXFrameRW(conn, s2) - - // send some messages - for i := 0; i < 10; i++ { - // write message into conn buffer - wmsg := []interface{}{"foo", "bar", strings.Repeat("test", i)} - err := Send(rw1, uint64(i), wmsg) - if err != nil { - t.Fatalf("WriteMsg error (i=%d): %v", i, err) - } - - // read message that rw1 just wrote - msg, err := rw2.ReadMsg() - if err != nil { - t.Fatalf("ReadMsg error (i=%d): %v", i, err) - } - if msg.Code != uint64(i) { - t.Fatalf("msg code mismatch: got %d, want %d", msg.Code, i) - } - payload, _ := ioutil.ReadAll(msg.Payload) - wantPayload, _ := rlp.EncodeToBytes(wmsg) - if !bytes.Equal(payload, wantPayload) { - t.Fatalf("msg payload mismatch:\ngot %x\nwant %x", payload, wantPayload) - } - } -} +func (h fakeHash) Size() int { return len(h) } +func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) } type handshakeAuthTest struct { input string @@ -598,3 +387,20 @@ func TestHandshakeForwardCompatibility(t *testing.T) { t.Errorf("ingress-mac('foo') mismatch:\ngot %x\nwant %x", fooIngressHash, wantFooIngressHash) } } + +func unhex(str string) []byte { + r := strings.NewReplacer("\t", "", " ", "", "\n", "") + b, err := hex.DecodeString(r.Replace(str)) + if err != nil { + panic(fmt.Sprintf("invalid hex string: %q", str)) + } + return b +} + +func newkey() *ecdsa.PrivateKey { + key, err := crypto.GenerateKey() + if err != nil { + panic("couldn't generate key: " + err.Error()) + } + return key +} diff --git a/p2p/server.go b/p2p/server.go index 1fe5f3978923..275cb5ea5c34 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -166,7 +166,7 @@ type Server struct { // Hooks for testing. These are useful because we can inhibit // the whole protocol stack. - newTransport func(net.Conn) transport + newTransport func(net.Conn, *ecdsa.PublicKey) transport newPeerHook func(*Peer) listenFunc func(network, addr string) (net.Listener, error) @@ -231,7 +231,7 @@ type conn struct { type transport interface { // The two handshakes. - doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) + doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) // The MsgReadWriter can only be used after the encryption // handshake has completed. The code uses conn.id to track this @@ -757,7 +757,7 @@ running: // The handshakes are done and it passed all checks. 
p := srv.launchPeer(c) peers[c.node.ID()] = p - srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), "conn", c.flags, "addr", p.RemoteAddr(), "name", truncateName(c.name)) + srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), "conn", c.flags, "addr", p.RemoteAddr(), "name", p.Name()) srv.dialsched.peerAdded(c) if p.Inbound() { inboundCount++ @@ -854,13 +854,18 @@ func (srv *Server) listenLoop() { <-slots var ( - fd net.Conn - err error + fd net.Conn + err error + lastLog time.Time ) for { fd, err = srv.listener.Accept() if netutil.IsTemporaryError(err) { - srv.log.Debug("Temporary read error", "err", err) + if time.Since(lastLog) > 1*time.Second { + srv.log.Debug("Temporary read error", "err", err) + lastLog = time.Now() + } + time.Sleep(time.Millisecond * 200) continue } else if err != nil { srv.log.Debug("Read error", "err", err) @@ -914,7 +919,13 @@ func (srv *Server) checkInboundConn(fd net.Conn, remoteIP net.IP) error { // as a peer. It returns when the connection has been added as a peer // or the handshakes have failed. func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error { - c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)} + c := &conn{fd: fd, flags: flags, cont: make(chan error)} + if dialDest == nil { + c.transport = srv.newTransport(fd, nil) + } else { + c.transport = srv.newTransport(fd, dialDest.Pubkey()) + } + err := srv.setupConn(c, flags, dialDest) if err != nil { c.close(err) @@ -943,16 +954,12 @@ func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) erro } // Run the RLPx handshake. - remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey) + remotePubkey, err := c.doEncHandshake(srv.PrivateKey) if err != nil { srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err) return err } if dialDest != nil { - // For dialed connections, check that the remote public key matches. - if dialPubkey.X.Cmp(remotePubkey.X) != 0 || dialPubkey.Y.Cmp(remotePubkey.Y) != 0 { - return DiscUnexpectedIdentity - } c.node = dialDest } else { c.node = nodeFromConn(remotePubkey, c.fd) @@ -994,13 +1001,6 @@ func nodeFromConn(pubkey *ecdsa.PublicKey, conn net.Conn) *enode.Node { return enode.NewV4(pubkey, ip, port, port) } -func truncateName(s string) string { - if len(s) > 20 { - return s[:20] + "..." - } - return s -} - // checkpoint sends the conn to run, which performs the // post-handshake checks for the stage (posthandshake, addpeer). 
func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error { diff --git a/p2p/server_test.go b/p2p/server_test.go index 7dc344a67df5..a5b3190aede2 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -18,6 +18,7 @@ package p2p import ( "crypto/ecdsa" + "crypto/sha256" "errors" "io" "math/rand" @@ -31,28 +32,27 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" - "golang.org/x/crypto/sha3" + "github.com/ethereum/go-ethereum/p2p/rlpx" ) type testTransport struct { - rpub *ecdsa.PublicKey - *rlpx - + *rlpxTransport + rpub *ecdsa.PublicKey closeErr error } -func newTestTransport(rpub *ecdsa.PublicKey, fd net.Conn) transport { - wrapped := newRLPX(fd).(*rlpx) - wrapped.rw = newRLPXFrameRW(fd, secrets{ - MAC: zero16, - AES: zero16, - IngressMAC: sha3.NewLegacyKeccak256(), - EgressMAC: sha3.NewLegacyKeccak256(), +func newTestTransport(rpub *ecdsa.PublicKey, fd net.Conn, dialDest *ecdsa.PublicKey) transport { + wrapped := newRLPX(fd, dialDest).(*rlpxTransport) + wrapped.conn.InitWithSecrets(rlpx.Secrets{ + AES: make([]byte, 16), + MAC: make([]byte, 16), + EgressMAC: sha256.New(), + IngressMAC: sha256.New(), }) - return &testTransport{rpub: rpub, rlpx: wrapped} + return &testTransport{rpub: rpub, rlpxTransport: wrapped} } -func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) { +func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) { return c.rpub, nil } @@ -62,7 +62,7 @@ func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, } func (c *testTransport) close(err error) { - c.rlpx.fd.Close() + c.conn.Close() c.closeErr = err } @@ -76,9 +76,11 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) * Logger: testlog.Logger(t, log.LvlTrace), } server := &Server{ - Config: config, - newPeerHook: pf, - newTransport: func(fd net.Conn) transport { return newTestTransport(remoteKey, fd) }, + Config: config, + newPeerHook: pf, + newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { + return newTestTransport(remoteKey, fd, dialDest) + }, } if err := server.Start(); err != nil { t.Fatalf("Could not start server: %v", err) @@ -253,7 +255,7 @@ func TestServerAtCap(t *testing.T) { newconn := func(id enode.ID) *conn { fd, _ := net.Pipe() - tx := newTestTransport(&trustedNode.PublicKey, fd) + tx := newTestTransport(&trustedNode.PublicKey, fd, nil) node := enode.SignNull(new(enr.Record), id) return &conn{fd: fd, transport: tx, flags: inboundConn, node: node, cont: make(chan error)} } @@ -321,7 +323,7 @@ func TestServerPeerLimits(t *testing.T) { Protocols: []Protocol{discard}, Logger: testlog.Logger(t, log.LvlTrace), }, - newTransport: func(fd net.Conn) transport { return tp }, + newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { return tp }, } if err := srv.Start(); err != nil { t.Fatalf("couldn't start server: %v", err) @@ -390,13 +392,6 @@ func TestServerSetupConn(t *testing.T) { wantCalls: "doEncHandshake,close,", wantCloseErr: errors.New("read error"), }, - { - tt: &setupTransport{pubkey: clientpub}, - dialDest: enode.NewV4(&newkey().PublicKey, nil, 0, 0), - flags: dynDialedConn, - wantCalls: "doEncHandshake,close,", - wantCloseErr: DiscUnexpectedIdentity, - }, { tt: &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: randomID().Bytes()}}, dialDest: enode.NewV4(clientpub, nil, 0, 0), @@ -437,7 +432,7 @@ func TestServerSetupConn(t 
*testing.T) { } srv := &Server{ Config: cfg, - newTransport: func(fd net.Conn) transport { return test.tt }, + newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { return test.tt }, log: cfg.Logger, } if !test.dontstart { @@ -468,7 +463,7 @@ type setupTransport struct { closeErr error } -func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) { +func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) { c.calls += "doEncHandshake," return c.pubkey, c.encHandshakeErr } @@ -522,9 +517,9 @@ func TestServerInboundThrottle(t *testing.T) { Protocols: []Protocol{discard}, Logger: testlog.Logger(t, log.LvlTrace), }, - newTransport: func(fd net.Conn) transport { + newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { newTransportCalled <- struct{}{} - return newRLPX(fd) + return newRLPX(fd, dialDest) }, listenFunc: func(network, laddr string) (net.Listener, error) { fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444} diff --git a/p2p/simulations/adapters/exec.go b/p2p/simulations/adapters/exec.go index 7ef908814d79..0ed3deab38a6 100644 --- a/p2p/simulations/adapters/exec.go +++ b/p2p/simulations/adapters/exec.go @@ -184,7 +184,19 @@ func (n *ExecNode) Start(snapshots map[string][]byte) (err error) { if err != nil { return fmt.Errorf("error generating node config: %s", err) } - + // expose the admin namespace via websocket if it's not enabled + exposed := confCopy.Stack.WSExposeAll + if !exposed { + for _, api := range confCopy.Stack.WSModules { + if api == "admin" { + exposed = true + break + } + } + } + if !exposed { + confCopy.Stack.WSModules = append(confCopy.Stack.WSModules, "admin") + } // start the one-shot server that waits for startup information ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -362,13 +374,44 @@ type execNodeConfig struct { PeerAddrs map[string]string `json:"peer_addrs,omitempty"` } +func initLogging() { + // Initialize the logging by default first. + glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.LogfmtFormat())) + glogger.Verbosity(log.LvlInfo) + log.Root().SetHandler(glogger) + + confEnv := os.Getenv(envNodeConfig) + if confEnv == "" { + return + } + var conf execNodeConfig + if err := json.Unmarshal([]byte(confEnv), &conf); err != nil { + return + } + var writer = os.Stderr + if conf.Node.LogFile != "" { + logWriter, err := os.Create(conf.Node.LogFile) + if err != nil { + return + } + writer = logWriter + } + var verbosity = log.LvlInfo + if conf.Node.LogVerbosity <= log.LvlTrace && conf.Node.LogVerbosity >= log.LvlCrit { + verbosity = conf.Node.LogVerbosity + } + // Reinitialize the logger + glogger = log.NewGlogHandler(log.StreamHandler(writer, log.TerminalFormat(true))) + glogger.Verbosity(verbosity) + log.Root().SetHandler(glogger) +} + // execP2PNode starts a simulation node when the current binary is executed with // argv[0] being "p2p-node", reading the service / ID from argv[1] / argv[2] // and the node config from an environment variable. 
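// Logging is initialized first via initLogging, so the child process honors the LogFile and LogVerbosity settings carried in that config.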
func execP2PNode() { - glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.LogfmtFormat())) - glogger.Verbosity(log.LvlInfo) - log.Root().SetHandler(glogger) + initLogging() + statusURL := os.Getenv(envStatusURL) if statusURL == "" { log.Crit("missing " + envStatusURL) } @@ -380,7 +423,7 @@ func execP2PNode() { if stackErr != nil { status.Err = stackErr.Error() } else { - status.WSEndpoint = "ws://" + stack.WSEndpoint() + status.WSEndpoint = stack.WSEndpoint() status.NodeInfo = stack.Server().NodeInfo() } @@ -454,7 +497,6 @@ func startExecNodeStack() (*node.Node, error) { return nil, err } services[name] = service - stack.RegisterLifecycle(service) } // Add the snapshot API. diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index fd10da431929..4fc7abc06a43 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -99,8 +99,9 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { Dialer: s, EnableMsgEvents: config.EnableMsgEvents, }, - NoUSB: true, - Logger: log.New("node.id", id.String()), + ExternalSigner: config.ExternalSigner, + NoUSB: true, + Logger: log.New("node.id", id.String()), }) if err != nil { return nil, err @@ -263,7 +264,6 @@ func (sn *SimNode) Start(snapshots map[string][]byte) error { continue } sn.running[name] = service - sn.node.RegisterLifecycle(service) } }) if regErr != nil { diff --git a/p2p/simulations/adapters/types.go b/p2p/simulations/adapters/types.go index 716cde6a6c7d..1da464a10d41 100644 --- a/p2p/simulations/adapters/types.go +++ b/p2p/simulations/adapters/types.go @@ -107,6 +107,9 @@ type NodeConfig struct { // These values need to be checked and acted upon by node Services Properties []string + // ExternalSigner specifies an external URI for a clef-type signer + ExternalSigner string + // Enode node *enode.Node @@ -117,6 +120,17 @@ type NodeConfig struct { Reachable func(id enode.ID) bool Port uint16 + + // LogFile is the log file name of the p2p node at runtime. + // + // The default value is empty, in which case the default log writer + // is the process's standard error stream (see initLogging). + LogFile string + + // LogVerbosity is the log verbosity of the p2p node at runtime. + // + // The default verbosity is INFO.
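+ // Levels follow the log package's ordering, from LvlCrit (0) up to LvlTrace (5).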
+ LogVerbosity log.Lvl } // nodeConfigJSON is used to encode and decode NodeConfig as JSON by encoding @@ -125,10 +139,12 @@ type nodeConfigJSON struct { ID string `json:"id"` PrivateKey string `json:"private_key"` Name string `json:"name"` - Services []string `json:"services"` + Lifecycles []string `json:"lifecycles"` Properties []string `json:"properties"` EnableMsgEvents bool `json:"enable_msg_events"` Port uint16 `json:"port"` + LogFile string `json:"logfile"` + LogVerbosity int `json:"log_verbosity"` } // MarshalJSON implements the json.Marshaler interface by encoding the config @@ -137,10 +153,12 @@ func (n *NodeConfig) MarshalJSON() ([]byte, error) { confJSON := nodeConfigJSON{ ID: n.ID.String(), Name: n.Name, - Services: n.Lifecycles, + Lifecycles: n.Lifecycles, Properties: n.Properties, Port: n.Port, EnableMsgEvents: n.EnableMsgEvents, + LogFile: n.LogFile, + LogVerbosity: int(n.LogVerbosity), } if n.PrivateKey != nil { confJSON.PrivateKey = hex.EncodeToString(crypto.FromECDSA(n.PrivateKey)) @@ -175,10 +193,12 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error { } n.Name = confJSON.Name - n.Lifecycles = confJSON.Services + n.Lifecycles = confJSON.Lifecycles n.Properties = confJSON.Properties n.Port = confJSON.Port n.EnableMsgEvents = confJSON.EnableMsgEvents + n.LogFile = confJSON.LogFile + n.LogVerbosity = log.Lvl(confJSON.LogVerbosity) return nil } @@ -208,6 +228,7 @@ func RandomNodeConfig() *NodeConfig { Name: fmt.Sprintf("node_%s", enodId.String()), Port: port, EnableMsgEvents: true, + LogVerbosity: log.LvlInfo, } } diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go index a54db4ea68d3..9b5e2c37f54c 100644 --- a/p2p/simulations/network.go +++ b/p2p/simulations/network.go @@ -454,9 +454,8 @@ func (net *Network) getNodeIDs(excludeIDs []enode.ID) []enode.ID { if len(excludeIDs) > 0 { // Return the difference of nodeIDs and excludeIDs return filterIDs(nodeIDs, excludeIDs) - } else { - return nodeIDs } + return nodeIDs } // GetNodes returns the existing nodes. @@ -472,9 +471,8 @@ func (net *Network) getNodes(excludeIDs []enode.ID) []*Node { if len(excludeIDs) > 0 { nodeIDs := net.getNodeIDs(excludeIDs) return net.getNodesByID(nodeIDs) - } else { - return net.Nodes } + return net.Nodes } // GetNodesByID returns existing nodes with the given enode.IDs. @@ -1098,7 +1096,6 @@ func (net *Network) executeNodeEvent(e *Event) error { func (net *Network) executeConnEvent(e *Event) error { if e.Conn.Up { return net.Connect(e.Conn.One, e.Conn.Other) - } else { - return net.Disconnect(e.Conn.One, e.Conn.Other) } + return net.Disconnect(e.Conn.One, e.Conn.Other) } diff --git a/p2p/transport.go b/p2p/transport.go new file mode 100644 index 000000000000..3f1cd7d64f20 --- /dev/null +++ b/p2p/transport.go @@ -0,0 +1,177 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. 
If not, see <http://www.gnu.org/licenses/>. + +package p2p + +import ( + "bytes" + "crypto/ecdsa" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common/bitutil" + "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/p2p/rlpx" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + // total timeout for encryption handshake and protocol + // handshake in both directions. + handshakeTimeout = 5 * time.Second + + // This is the timeout for sending the disconnect reason. + // This is shorter than the usual timeout because we don't want + // to wait if the connection is known to be bad anyway. + discWriteTimeout = 1 * time.Second +) + +// rlpxTransport is the transport used by actual (non-test) connections. +// It wraps an RLPx connection with locks and read/write deadlines. +type rlpxTransport struct { + rmu, wmu sync.Mutex + wbuf bytes.Buffer + conn *rlpx.Conn +} + +func newRLPX(conn net.Conn, dialDest *ecdsa.PublicKey) transport { + return &rlpxTransport{conn: rlpx.NewConn(conn, dialDest)} +} + +func (t *rlpxTransport) ReadMsg() (Msg, error) { + t.rmu.Lock() + defer t.rmu.Unlock() + + var msg Msg + t.conn.SetReadDeadline(time.Now().Add(frameReadTimeout)) + code, data, wireSize, err := t.conn.Read() + if err == nil { + msg = Msg{ + ReceivedAt: time.Now(), + Code: code, + Size: uint32(len(data)), + meterSize: uint32(wireSize), + Payload: bytes.NewReader(data), + } + } + return msg, err +} + +func (t *rlpxTransport) WriteMsg(msg Msg) error { + t.wmu.Lock() + defer t.wmu.Unlock() + + // Copy message data to write buffer. + t.wbuf.Reset() + if _, err := io.CopyN(&t.wbuf, msg.Payload, int64(msg.Size)); err != nil { + return err + } + + // Write the message. + t.conn.SetWriteDeadline(time.Now().Add(frameWriteTimeout)) + size, err := t.conn.Write(msg.Code, t.wbuf.Bytes()) + if err != nil { + return err + } + + // Set metrics. + msg.meterSize = size + if metrics.Enabled && msg.meterCap.Name != "" { // don't meter non-subprotocol messages + m := fmt.Sprintf("%s/%s/%d/%#02x", egressMeterName, msg.meterCap.Name, msg.meterCap.Version, msg.meterCode) + metrics.GetOrRegisterMeter(m, nil).Mark(int64(msg.meterSize)) + metrics.GetOrRegisterMeter(m+"/packets", nil).Mark(1) + } + return nil +} + +func (t *rlpxTransport) close(err error) { + t.wmu.Lock() + defer t.wmu.Unlock() + + // Tell the remote end why we're disconnecting if possible. + // We only bother doing this if the underlying connection supports + // setting a timeout, though. + if t.conn != nil { + if r, ok := err.(DiscReason); ok && r != DiscNetworkError { + deadline := time.Now().Add(discWriteTimeout) + if err := t.conn.SetWriteDeadline(deadline); err == nil { + // Connection supports write deadline. + t.wbuf.Reset() + rlp.Encode(&t.wbuf, []DiscReason{r}) + t.conn.Write(discMsg, t.wbuf.Bytes()) + } + } + } + t.conn.Close() +} + +func (t *rlpxTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) { + t.conn.SetDeadline(time.Now().Add(handshakeTimeout)) + return t.conn.Handshake(prv) +} + +func (t *rlpxTransport) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) { + // Writing our handshake happens concurrently; we prefer + // returning the handshake read error. If the remote side + // disconnects us early with a valid reason, we should return it + // as the error so it can be tracked elsewhere.
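+ // The write below runs in its own goroutine so a stalled peer cannot block the handshake read; werr is buffered so the goroutine can always finish.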
+ werr := make(chan error, 1) + go func() { werr <- Send(t, handshakeMsg, our) }() + if their, err = readProtocolHandshake(t); err != nil { + <-werr // make sure the write terminates too + return nil, err + } + if err := <-werr; err != nil { + return nil, fmt.Errorf("write error: %v", err) + } + // If the protocol version supports Snappy encoding, upgrade immediately + t.conn.SetSnappy(their.Version >= snappyProtocolVersion) + + return their, nil +} + +func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) { + msg, err := rw.ReadMsg() + if err != nil { + return nil, err + } + if msg.Size > baseProtocolMaxMsgSize { + return nil, fmt.Errorf("message too big") + } + if msg.Code == discMsg { + // Disconnect before protocol handshake is valid according to the + // spec and we send it ourselves if the post-handshake checks fail. + // We can't return the reason directly, though, because it is echoed + // back otherwise. Wrap it in a string instead. + var reason [1]DiscReason + rlp.Decode(msg.Payload, &reason) + return nil, reason[0] + } + if msg.Code != handshakeMsg { + return nil, fmt.Errorf("expected handshake, got %x", msg.Code) + } + var hs protoHandshake + if err := msg.Decode(&hs); err != nil { + return nil, err + } + if len(hs.ID) != 64 || !bitutil.TestBytes(hs.ID) { + return nil, DiscInvalidIdentity + } + return &hs, nil +} diff --git a/p2p/transport_test.go b/p2p/transport_test.go new file mode 100644 index 000000000000..753ea30bf196 --- /dev/null +++ b/p2p/transport_test.go @@ -0,0 +1,148 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+ +package p2p + +import ( + "errors" + "reflect" + "sync" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/simulations/pipes" +) + +func TestProtocolHandshake(t *testing.T) { + var ( + prv0, _ = crypto.GenerateKey() + pub0 = crypto.FromECDSAPub(&prv0.PublicKey)[1:] + hs0 = &protoHandshake{Version: 3, ID: pub0, Caps: []Cap{{"a", 0}, {"b", 2}}} + + prv1, _ = crypto.GenerateKey() + pub1 = crypto.FromECDSAPub(&prv1.PublicKey)[1:] + hs1 = &protoHandshake{Version: 3, ID: pub1, Caps: []Cap{{"c", 1}, {"d", 3}}} + + wg sync.WaitGroup + ) + + fd0, fd1, err := pipes.TCPPipe() + if err != nil { + t.Fatal(err) + } + + wg.Add(2) + go func() { + defer wg.Done() + defer fd0.Close() + frame := newRLPX(fd0, &prv1.PublicKey) + rpubkey, err := frame.doEncHandshake(prv0) + if err != nil { + t.Errorf("dial side enc handshake failed: %v", err) + return + } + if !reflect.DeepEqual(rpubkey, &prv1.PublicKey) { + t.Errorf("dial side remote pubkey mismatch: got %v, want %v", rpubkey, &prv1.PublicKey) + return + } + + phs, err := frame.doProtoHandshake(hs0) + if err != nil { + t.Errorf("dial side proto handshake error: %v", err) + return + } + phs.Rest = nil + if !reflect.DeepEqual(phs, hs1) { + t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1)) + return + } + frame.close(DiscQuitting) + }() + go func() { + defer wg.Done() + defer fd1.Close() + rlpx := newRLPX(fd1, nil) + rpubkey, err := rlpx.doEncHandshake(prv1) + if err != nil { + t.Errorf("listen side enc handshake failed: %v", err) + return + } + if !reflect.DeepEqual(rpubkey, &prv0.PublicKey) { + t.Errorf("listen side remote pubkey mismatch: got %v, want %v", rpubkey, &prv0.PublicKey) + return + } + + phs, err := rlpx.doProtoHandshake(hs1) + if err != nil { + t.Errorf("listen side proto handshake error: %v", err) + return + } + phs.Rest = nil + if !reflect.DeepEqual(phs, hs0) { + t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0)) + return + } + + if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil { + t.Errorf("error receiving disconnect: %v", err) + } + }() + wg.Wait() +} + +func TestProtocolHandshakeErrors(t *testing.T) { + tests := []struct { + code uint64 + msg interface{} + err error + }{ + { + code: discMsg, + msg: []DiscReason{DiscQuitting}, + err: DiscQuitting, + }, + { + code: 0x989898, + msg: []byte{1}, + err: errors.New("expected handshake, got 989898"), + }, + { + code: handshakeMsg, + msg: make([]byte, baseProtocolMaxMsgSize+2), + err: errors.New("message too big"), + }, + { + code: handshakeMsg, + msg: []byte{1, 2, 3}, + err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"), + }, + { + code: handshakeMsg, + msg: &protoHandshake{Version: 3}, + err: DiscInvalidIdentity, + }, + } + + for i, test := range tests { + p1, p2 := MsgPipe() + go Send(p1, test.code, test.msg) + _, err := readProtocolHandshake(p2) + if !reflect.DeepEqual(err, test.err) { + t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) + } + } +} diff --git a/params/bootnodes.go b/params/bootnodes.go index c8736b8ae87e..d4512bf789a5 100644 --- a/params/bootnodes.go +++ b/params/bootnodes.go @@ -56,17 +56,21 @@ var GoerliBootnodes = []string{ "enode://011f758e6552d105183b1761c5e2dea0111bc20fd5f6422bc7f91e0fabbec9a6595caf6239b37feb773dddd3f87240d99d859431891e4a642cf2a0a9e6cbb98a@51.141.78.53:30303", 
"enode://176b9417f511d05b6b2cf3e34b756cf0a7096b3094572a8f6ef4cdcb9d1f9d00683bf0f83347eebdf3b81c3521c2332086d9592802230bf528eaf606a1d9677b@13.93.54.137:30303", "enode://46add44b9f13965f7b9875ac6b85f016f341012d84f975377573800a863526f4da19ae2c620ec73d11591fa9510e992ecc03ad0751f53cc02f7c7ed6d55c7291@94.237.54.114:30313", - "enode://c1f8b7c2ac4453271fa07d8e9ecf9a2e8285aa0bd0c07df0131f47153306b0736fd3db8924e7a9bf0bed6b1d8d4f87362a71b033dc7c64547728d953e43e59b2@52.64.155.147:30303", - "enode://f4a9c6ee28586009fb5a96c8af13a58ed6d8315a9eee4772212c1d4d9cebe5a8b8a78ea4434f318726317d04a3f531a1ef0420cf9752605a562cfe858c46e263@213.186.16.82:30303", + "enode://b5948a2d3e9d486c4d75bf32713221c2bd6cf86463302339299bd227dc2e276cd5a1c7ca4f43a0e9122fe9af884efed563bd2a1fd28661f3b5f5ad7bf1de5949@18.218.250.66:30303", // Ethereum Foundation bootnode "enode://a61215641fb8714a373c80edbfa0ea8878243193f57c96eeb44d0bc019ef295abd4e044fd619bfc4c59731a73fb79afe84e9ab6da0c743ceb479cbb6d263fa91@3.11.147.67:30303", + + // Goerli Initiative bootnodes + "enode://a869b02cec167211fb4815a82941db2e7ed2936fd90e78619c53eb17753fcf0207463e3419c264e2a1dd8786de0df7e68cf99571ab8aeb7c4e51367ef186b1dd@51.15.116.226:30303", + "enode://807b37ee4816ecf407e9112224494b74dd5933625f655962d892f2f0f02d7fbbb3e2a94cf87a96609526f30c998fd71e93e2f53015c558ffc8b03eceaf30ee33@51.15.119.157:30303", + "enode://a59e33ccd2b3e52d578f1fbd70c6f9babda2650f0760d6ff3b37742fdcdfdb3defba5d56d315b40c46b70198c7621e63ffa3f987389c7118634b0fefbbdfa7fd@51.15.119.157:40303", } -// YoloV1Bootnodes are the enode URLs of the P2P bootstrap nodes running on the -// YOLOv1 ephemeral test network. -var YoloV1Bootnodes = []string{ - "enode://9e1096aa59862a6f164994cb5cb16f5124d6c992cdbf4535ff7dea43ea1512afe5448dca9df1b7ab0726129603f1a3336b631e4d7a1a44c94daddd03241587f9@35.178.210.161:30303", +// YoloV2Bootnodes are the enode URLs of the P2P bootstrap nodes running on the +// YOLOv2 ephemeral test network. +var YoloV2Bootnodes = []string{ + "enode://9e1096aa59862a6f164994cb5cb16f5124d6c992cdbf4535ff7dea43ea1512afe5448dca9df1b7ab0726129603f1a3336b631e4d7a1a44c94daddd03241587f9@3.9.20.133:30303", } const dnsPrefix = "enrtree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@" diff --git a/params/config.go b/params/config.go index 2cd66acc2f92..588d848f1c2f 100644 --- a/params/config.go +++ b/params/config.go @@ -31,7 +31,8 @@ var ( RopstenGenesisHash = common.HexToHash("0x41941023680923e0fe4d74a34bdac8141f2540e3ae90623718e47d66d1ca4a2d") RinkebyGenesisHash = common.HexToHash("0x6341fd3daf94b748c72ced5a5b26028f2474f5f00d824504e4fa37a75767e177") GoerliGenesisHash = common.HexToHash("0xbf7e331f7f7c1dd2e05159666b3bf8bc7a8a3a9eb1d518969eab529dd9b88c1a") - YoloV1GenesisHash = common.HexToHash("0xc3fd235071f24f93865b0850bd2a2119b30f7224d18a0e34c7bbf549ad7e3d36") + // TODO: update with yolov2 values + YoloV2GenesisHash = common.HexToHash("0x498a7239036dd2cd09e2bb8a80922b78632017958c332b42044c250d603a8a3e") ) // TrustedCheckpoints associates each known checkpoint with the genesis hash of @@ -73,10 +74,10 @@ var ( // MainnetTrustedCheckpoint contains the light client trusted checkpoint for the main network. 
MainnetTrustedCheckpoint = &TrustedCheckpoint{ - SectionIndex: 329, - SectionHead: common.HexToHash("0x96bb6d286ded20a18480dd98d537ab503bd81110c6b9c3f8ad1f9338f3b9852d"), - CHTRoot: common.HexToHash("0x10627ff648077adeaab9dbd4e5bbed8671c86005b2aef5f5d4857acca19a49d8"), - BloomRoot: common.HexToHash("0xf499b0cfaf426a490b7b5ddca58d3031b008f0c15338f8f25c20f3df050bf785"), + SectionIndex: 345, + SectionHead: common.HexToHash("0x5453bab878704adebc934b41fd214a07ea7a72b8572ff088dca7f7956cd0ef28"), + CHTRoot: common.HexToHash("0x7693d432595846c094f47cb37f5c868b0b7b1968fc6b0fc411ded1345fdaffab"), + BloomRoot: common.HexToHash("0x8b0e7895bc39840d8dac857e26bdf3d0a07684b0b962b252546659e0337a9f70"), } // MainnetCheckpointOracle contains a set of configs for the main network oracle. @@ -112,10 +113,10 @@ var ( // RopstenTrustedCheckpoint contains the light client trusted checkpoint for the Ropsten test network. RopstenTrustedCheckpoint = &TrustedCheckpoint{ - SectionIndex: 262, - SectionHead: common.HexToHash("0x12b068f285789b966a983b632266484f1bc93803df6c78773538a5777f57a236"), - CHTRoot: common.HexToHash("0x14000a1407e866f174f3a20fe9f271acd704bcf929b5205d83b70a1bba8c82c2"), - BloomRoot: common.HexToHash("0x2f4f4a34a55e35d0691c79a79e39b6f661259345080fb880da5195c11c2413be"), + SectionIndex: 279, + SectionHead: common.HexToHash("0x4a4912848d4c06090097073357c10015d11c6f4544a0f93cbdd584701c3b7d58"), + CHTRoot: common.HexToHash("0x9053b7867ae921e80a4e2f5a4b15212e4af3d691ca712fb33dc150e9c6ea221c"), + BloomRoot: common.HexToHash("0x3dc04cb1be7ddc271f3f83469b47b76184a79d7209ef51d85b1539ea6d25a645"), } // RopstenCheckpointOracle contains a set of configs for the Ropsten test network oracle. @@ -154,10 +155,10 @@ var ( // RinkebyTrustedCheckpoint contains the light client trusted checkpoint for the Rinkeby test network. RinkebyTrustedCheckpoint = &TrustedCheckpoint{ - SectionIndex: 217, - SectionHead: common.HexToHash("0x9afa4900a60cb44b102eb2eb5e5ef1d7f4cc1911c1c0588518995fb778ffe894"), - CHTRoot: common.HexToHash("0xcc963e5085622c7cb6b3bf747fbfdfe71887e0d5bc9e4b3fb0474d44fc97942a"), - BloomRoot: common.HexToHash("0x1064ca3a36b6f129783cff51bb18fb038bade47d2b776d1cccb9c74925106703"), + SectionIndex: 232, + SectionHead: common.HexToHash("0x8170fca4039b11a008c11f9996ff112151cbb17411437bb2f86288e11158b2f0"), + CHTRoot: common.HexToHash("0x4526560d92ae1b3a6d3ee780c3ad289ba2bbf1b5da58d9ea107f2f26412b631f"), + BloomRoot: common.HexToHash("0x82a889098a35d6a21ea8894d35a1db69b94bad61b988bbe5ae4601437320e331"), } // RinkebyCheckpointOracle contains a set of configs for the Rinkeby test network oracle. @@ -194,10 +195,10 @@ var ( // GoerliTrustedCheckpoint contains the light client trusted checkpoint for the Görli test network. GoerliTrustedCheckpoint = &TrustedCheckpoint{ - SectionIndex: 101, - SectionHead: common.HexToHash("0x396f5dd8e526edfb550873bcfe0e93dc00d70be4b881ab256980833b97a18c3e"), - CHTRoot: common.HexToHash("0x0d145657a6595508ef878c9bbf8eca045631986f664bfab0d898fc64804a4e64"), - BloomRoot: common.HexToHash("0x12df34d07cf1268abe22d40ee6deb199b8918e3d57d52f9e70f9b2883f57d74f"), + SectionIndex: 116, + SectionHead: common.HexToHash("0xf2d200f636f213c9c7bb4e747ff564813da7708253037103aef3d8be5203c5e1"), + CHTRoot: common.HexToHash("0xb0ac83e2ccf6c2776945e099c4e3df50fe6200499c8b2045c34cafdf57d15087"), + BloomRoot: common.HexToHash("0xfb580ad1c611230a4bfc56534f58bcb156d028bc6ce70e35403dc019c7c02d90"), } // GoerliCheckpointOracle contains a set of configs for the Goerli test network oracle. 
@@ -213,9 +214,9 @@ var ( Threshold: 2, } - // YoloV1ChainConfig contains the chain parameters to run a node on the YOLOv1 test network. - YoloV1ChainConfig = &ChainConfig{ - ChainID: big.NewInt(133519467574833), + // YoloV2ChainConfig contains the chain parameters to run a node on the YOLOv2 test network. + YoloV2ChainConfig = &ChainConfig{ + ChainID: big.NewInt(133519467574834), HomesteadBlock: big.NewInt(0), DAOForkBlock: nil, DAOForkSupport: true, @@ -227,7 +228,7 @@ var ( PetersburgBlock: big.NewInt(0), IstanbulBlock: big.NewInt(0), MuirGlacierBlock: nil, - YoloV1Block: big.NewInt(0), + YoloV2Block: big.NewInt(0), Clique: &CliqueConfig{ Period: 15, Epoch: 30000, @@ -320,7 +321,7 @@ type ChainConfig struct { IstanbulBlock *big.Int `json:"istanbulBlock,omitempty"` // Istanbul switch block (nil = no fork, 0 = already on istanbul) MuirGlacierBlock *big.Int `json:"muirGlacierBlock,omitempty"` // Eip-2384 (bomb delay) switch block (nil = no fork, 0 = already activated) - YoloV1Block *big.Int `json:"yoloV1Block,omitempty"` // YOLO v1: https://github.com/ethereum/EIPs/pull/2657 (Ephemeral testnet) + YoloV2Block *big.Int `json:"yoloV2Block,omitempty"` // YOLO v2: Gas repricings TODO @holiman add EIP references EWASMBlock *big.Int `json:"ewasmBlock,omitempty"` // EWASM switch block (nil = no fork, 0 = already activated) // Various consensus engines @@ -358,7 +359,7 @@ func (c *ChainConfig) String() string { default: engine = "unknown" } - return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Petersburg: %v Istanbul: %v, Muir Glacier: %v, YOLO v1: %v, Engine: %v}", + return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Petersburg: %v Istanbul: %v, Muir Glacier: %v, YOLO v2: %v, Engine: %v}", c.ChainID, c.HomesteadBlock, c.DAOForkBlock, @@ -371,7 +372,7 @@ func (c *ChainConfig) String() string { c.PetersburgBlock, c.IstanbulBlock, c.MuirGlacierBlock, - c.YoloV1Block, + c.YoloV2Block, engine, ) } @@ -428,9 +429,9 @@ func (c *ChainConfig) IsIstanbul(num *big.Int) bool { return isForked(c.IstanbulBlock, num) } -// IsYoloV1 returns whether num is either equal to the YoloV1 fork block or greater. -func (c *ChainConfig) IsYoloV1(num *big.Int) bool { - return isForked(c.YoloV1Block, num) +// IsYoloV2 returns whether num is either equal to the YoloV2 fork block or greater.
+func (c *ChainConfig) IsYoloV2(num *big.Int) bool { + return isForked(c.YoloV2Block, num) } // IsEWASM returns whether num represents a block number after the EWASM fork @@ -476,7 +477,7 @@ func (c *ChainConfig) CheckConfigForkOrder() error { {name: "petersburgBlock", block: c.PetersburgBlock}, {name: "istanbulBlock", block: c.IstanbulBlock}, {name: "muirGlacierBlock", block: c.MuirGlacierBlock, optional: true}, - {name: "yoloV1Block", block: c.YoloV1Block}, + {name: "yoloV2Block", block: c.YoloV2Block}, } { if lastFork.name != "" { // Next one must be higher number @@ -540,8 +541,8 @@ func (c *ChainConfig) checkCompatible(newcfg *ChainConfig, head *big.Int) *Confi if isForkIncompatible(c.MuirGlacierBlock, newcfg.MuirGlacierBlock, head) { return newCompatError("Muir Glacier fork block", c.MuirGlacierBlock, newcfg.MuirGlacierBlock) } - if isForkIncompatible(c.YoloV1Block, newcfg.YoloV1Block, head) { - return newCompatError("YOLOv1 fork block", c.YoloV1Block, newcfg.YoloV1Block) + if isForkIncompatible(c.YoloV2Block, newcfg.YoloV2Block, head) { + return newCompatError("YOLOv2 fork block", c.YoloV2Block, newcfg.YoloV2Block) } if isForkIncompatible(c.EWASMBlock, newcfg.EWASMBlock, head) { return newCompatError("ewasm fork block", c.EWASMBlock, newcfg.EWASMBlock) @@ -613,7 +614,7 @@ type Rules struct { ChainID *big.Int IsHomestead, IsEIP150, IsEIP155, IsEIP158 bool IsByzantium, IsConstantinople, IsPetersburg, IsIstanbul bool - IsYoloV1 bool + IsYoloV2 bool } // Rules ensures c's ChainID is not nil. @@ -632,6 +633,6 @@ func (c *ChainConfig) Rules(num *big.Int) Rules { IsConstantinople: c.IsConstantinople(num), IsPetersburg: c.IsPetersburg(num), IsIstanbul: c.IsIstanbul(num), - IsYoloV1: c.IsYoloV1(num), + IsYoloV2: c.IsYoloV2(num), } } diff --git a/params/protocol_params.go b/params/protocol_params.go index eae935743cb0..ffb901ec3d3b 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -52,14 +52,10 @@ const ( NetSstoreResetRefund uint64 = 4800 // Once per SSTORE operation for resetting to the original non-zero value NetSstoreResetClearRefund uint64 = 19800 // Once per SSTORE operation for resetting to the original zero value - SstoreSentryGasEIP2200 uint64 = 2300 // Minimum gas required to be present for an SSTORE call, not consumed - SstoreNoopGasEIP2200 uint64 = 800 // Once per SSTORE operation if the value doesn't change. - SstoreDirtyGasEIP2200 uint64 = 800 // Once per SSTORE operation if a dirty value is changed. - SstoreInitGasEIP2200 uint64 = 20000 // Once per SSTORE operation from clean zero to non-zero - SstoreInitRefundEIP2200 uint64 = 19200 // Once per SSTORE operation for resetting to the original zero value - SstoreCleanGasEIP2200 uint64 = 5000 // Once per SSTORE operation from clean non-zero to something else - SstoreCleanRefundEIP2200 uint64 = 4200 // Once per SSTORE operation for resetting to the original non-zero value - SstoreClearRefundEIP2200 uint64 = 15000 // Once per SSTORE operation for clearing an originally existing storage slot + SstoreSentryGasEIP2200 uint64 = 2300 // Minimum gas required to be present for an SSTORE call, not consumed + SstoreSetGasEIP2200 uint64 = 20000 // Once per SSTORE operation from clean zero to non-zero + SstoreResetGasEIP2200 uint64 = 5000 // Once per SSTORE operation from clean non-zero to something else + SstoreClearsScheduleRefundEIP2200 uint64 = 15000 // Once per SSTORE operation for clearing an originally existing storage slot JumpdestGas uint64 = 1 // Once per JUMPDEST operation. 
EpochDuration uint64 = 30000 // Duration between proof-of-work epochs. @@ -120,7 +116,6 @@ const ( Ripemd160PerWordGas uint64 = 120 // Per-word price for a RIPEMD160 operation IdentityBaseGas uint64 = 15 // Base price for a data copy operation IdentityPerWordGas uint64 = 3 // Per-word price for a data copy operation - ModExpQuadCoeffDiv uint64 = 20 // Divisor for the quadratic particle of the big int modular exponentiation Bn256AddGasByzantium uint64 = 500 // Byzantium gas needed for an elliptic curve addition Bn256AddGasIstanbul uint64 = 150 // Gas needed for an elliptic curve addition diff --git a/params/version.go b/params/version.go index 521efe15f5a7..b6d6bd3f1d0f 100644 --- a/params/version.go +++ b/params/version.go @@ -23,7 +23,7 @@ import ( const ( VersionMajor = 1 // Major version component of the current release VersionMinor = 9 // Minor version component of the current release - VersionPatch = 22 // Patch version component of the current release + VersionPatch = 26 // Patch version component of the current release VersionMeta = "unstable" // Version metadata to append to the version string ) diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 68037451971b..418ee10a350e 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -39,9 +39,8 @@ func (e *testEncoder) EncodeRLP(w io.Writer) error { } if e.err != nil { return e.err - } else { - w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) } + w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) return nil } diff --git a/rlp/raw.go b/rlp/raw.go index c2a8517f62f1..3071e99cab1b 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -180,3 +180,74 @@ func readSize(b []byte, slen byte) (uint64, error) { } return s, nil } + +// AppendUint64 appends the RLP encoding of i to b, and returns the resulting slice.
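+// Values below 128 encode as themselves in a single byte (zero encodes as the empty string, 0x80); larger values get a 0x81-0x88 length prefix followed by the value's big-endian bytes.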
+func AppendUint64(b []byte, i uint64) []byte { + if i == 0 { + return append(b, 0x80) + } else if i < 128 { + return append(b, byte(i)) + } + switch { + case i < (1 << 8): + return append(b, 0x81, byte(i)) + case i < (1 << 16): + return append(b, 0x82, + byte(i>>8), + byte(i), + ) + case i < (1 << 24): + return append(b, 0x83, + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 32): + return append(b, 0x84, + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 40): + return append(b, 0x85, + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + case i < (1 << 48): + return append(b, 0x86, + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 56): + return append(b, 0x87, + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + default: + return append(b, 0x88, + byte(i>>56), + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + } +} diff --git a/rlp/raw_test.go b/rlp/raw_test.go index cdae4ff08859..c976c4f73429 100644 --- a/rlp/raw_test.go +++ b/rlp/raw_test.go @@ -21,6 +21,7 @@ import ( "io" "reflect" "testing" + "testing/quick" ) func TestCountValues(t *testing.T) { @@ -239,3 +240,40 @@ func TestReadSize(t *testing.T) { } } } + +func TestAppendUint64(t *testing.T) { + tests := []struct { + input uint64 + slice []byte + output string + }{ + {0, nil, "80"}, + {1, nil, "01"}, + {2, nil, "02"}, + {127, nil, "7F"}, + {128, nil, "8180"}, + {129, nil, "8181"}, + {0xFFFFFF, nil, "83FFFFFF"}, + {127, []byte{1, 2, 3}, "0102037F"}, + {0xFFFFFF, []byte{1, 2, 3}, "01020383FFFFFF"}, + } + + for _, test := range tests { + x := AppendUint64(test.slice, test.input) + if !bytes.Equal(x, unhex(test.output)) { + t.Errorf("AppendUint64(%v, %d): got %x, want %s", test.slice, test.input, x, test.output) + } + } +} + +func TestAppendUint64Random(t *testing.T) { + fn := func(i uint64) bool { + enc, _ := EncodeToBytes(i) + encAppend := AppendUint64(nil, i) + return bytes.Equal(enc, encAppend) + } + config := quick.Config{MaxCountScale: 50} + if err := quick.Check(fn, &config); err != nil { + t.Fatal(err) + } +} diff --git a/rpc/client.go b/rpc/client.go index 91e68e73e692..9393adb77903 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -404,9 +404,8 @@ func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) if c.isHTTP { return c.sendHTTP(ctx, op, msg) - } else { - return c.send(ctx, op, msg) } + return c.send(ctx, op, msg) } // EthSubscribe registers a subscription under the "eth" namespace. diff --git a/rpc/endpoints.go b/rpc/endpoints.go index 9fc07051723f..d78ebe2858bc 100644 --- a/rpc/endpoints.go +++ b/rpc/endpoints.go @@ -18,6 +18,7 @@ package rpc import ( "net" + "strings" "github.com/ethereum/go-ethereum/log" ) @@ -25,13 +26,22 @@ import ( // StartIPCEndpoint starts an IPC endpoint. func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) { // Register all the APIs exposed by the services.
- handler := NewServer() + var ( + handler = NewServer() + regMap = make(map[string]struct{}) + registered []string + ) for _, api := range apis { if err := handler.RegisterName(api.Namespace, api.Service); err != nil { + log.Info("IPC registration failed", "namespace", api.Namespace, "error", err) return nil, nil, err } - log.Debug("IPC registered", "namespace", api.Namespace) + if _, ok := regMap[api.Namespace]; !ok { + registered = append(registered, api.Namespace) + regMap[api.Namespace] = struct{}{} + } } + log.Debug("IPCs registered", "namespaces", strings.Join(registered, ",")) // All APIs registered, start the IPC listener. listener, err := ipcListen(ipcEndpoint) if err != nil { diff --git a/rpc/websocket.go b/rpc/websocket.go index a716383be91a..cd60eeb613cd 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -75,14 +75,14 @@ func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { allowAllOrigins = true } if origin != "" { - origins.Add(strings.ToLower(origin)) + origins.Add(origin) } } // allow localhost if no allowedOrigins are specified. if len(origins.ToSlice()) == 0 { origins.Add("http://localhost") if hostname, err := os.Hostname(); err == nil { - origins.Add("http://" + strings.ToLower(hostname)) + origins.Add("http://" + hostname) } } log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) @@ -97,7 +97,7 @@ func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { } // Verify origin against whitelist. origin := strings.ToLower(req.Header.Get("Origin")) - if allowAllOrigins || origins.Contains(origin) { + if allowAllOrigins || originIsAllowed(origins, origin) { return true } log.Warn("Rejected WebSocket connection", "origin", origin) @@ -120,6 +120,65 @@ func (e wsHandshakeError) Error() string { return s } +func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { + it := allowedOrigins.Iterator() + for origin := range it.C { + if ruleAllowsOrigin(origin.(string), browserOrigin) { + return true + } + } + return false +} + +func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { + var ( + allowedScheme, allowedHostname, allowedPort string + browserScheme, browserHostname, browserPort string + err error + ) + allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) + if err != nil { + log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) + return false + } + browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) + if err != nil { + log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) + return false + } + if allowedScheme != "" && allowedScheme != browserScheme { + return false + } + if allowedHostname != "" && allowedHostname != browserHostname { + return false + } + if allowedPort != "" && allowedPort != browserPort { + return false + } + return true +} + +func parseOriginURL(origin string) (string, string, string, error) { + parsedURL, err := url.Parse(strings.ToLower(origin)) + if err != nil { + return "", "", "", err + } + var scheme, hostname, port string + if strings.Contains(origin, "://") { + scheme = parsedURL.Scheme + hostname = parsedURL.Hostname() + port = parsedURL.Port() + } else { + scheme = "" + hostname = parsedURL.Scheme + port = parsedURL.Opaque + if hostname == "" { + hostname = origin + } + } + return scheme, hostname, port, nil +} + // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC 
server // that is listening on the given endpoint using the provided dialer. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { diff --git a/signer/core/api.go b/signer/core/api.go index 3817345c8f1a..7595d2d484b6 100644 --- a/signer/core/api.go +++ b/signer/core/api.go @@ -41,7 +41,7 @@ const ( // numberOfAccountsToDerive is the number of accounts to derive for hardware wallets numberOfAccountsToDerive = 10 // ExternalAPIVersion -- see extapi_changelog.md - ExternalAPIVersion = "6.0.0" + ExternalAPIVersion = "6.1.0" // InternalAPIVersion -- see intapi_changelog.md InternalAPIVersion = "7.0.1" ) @@ -62,6 +62,8 @@ type ExternalAPI interface { EcRecover(ctx context.Context, data hexutil.Bytes, sig hexutil.Bytes) (common.Address, error) // Version info about the APIs Version(ctx context.Context) (string, error) + // SignGnosisSafeTx signs/confirms a gnosis-safe multisig transaction + SignGnosisSafeTx(ctx context.Context, signerAddress common.MixedcaseAddress, gnosisTx GnosisSafeTx, methodSelector *string) (*GnosisSafeTx, error) } // UIClientAPI specifies what method a UI needs to implement to be able to be used as a @@ -234,6 +236,7 @@ type ( Address common.MixedcaseAddress `json:"address"` Rawdata []byte `json:"raw_data"` Messages []*NameValueType `json:"messages"` + Callinfo []ValidationInfo `json:"call_info"` Hash hexutil.Bytes `json:"hash"` Meta Metadata `json:"meta"` } @@ -319,68 +322,73 @@ func (api *SignerAPI) openTrezor(url accounts.URL) { // startUSBListener starts a listener for USB events, for hardware wallet interaction func (api *SignerAPI) startUSBListener() { - events := make(chan accounts.WalletEvent, 16) + eventCh := make(chan accounts.WalletEvent, 16) am := api.am - am.Subscribe(events) - go func() { + am.Subscribe(eventCh) + // Open any wallets already attached + for _, wallet := range am.Wallets() { + if err := wallet.Open(""); err != nil { + log.Warn("Failed to open wallet", "url", wallet.URL(), "err", err) + if err == usbwallet.ErrTrezorPINNeeded { + go api.openTrezor(wallet.URL()) + } + } + } + go api.derivationLoop(eventCh) +} - // Open any wallets already attached - for _, wallet := range am.Wallets() { - if err := wallet.Open(""); err != nil { - log.Warn("Failed to open wallet", "url", wallet.URL(), "err", err) +// derivationLoop listens for wallet events +func (api *SignerAPI) derivationLoop(events chan accounts.WalletEvent) { + // Listen for wallet events until termination + for event := range events { + switch event.Kind { + case accounts.WalletArrived: + if err := event.Wallet.Open(""); err != nil { + log.Warn("New wallet appeared, failed to open", "url", event.Wallet.URL(), "err", err) if err == usbwallet.ErrTrezorPINNeeded { - go api.openTrezor(wallet.URL()) + go api.openTrezor(event.Wallet.URL()) } } - } - // Listen for wallet event till termination - for event := range events { - switch event.Kind { - case accounts.WalletArrived: - if err := event.Wallet.Open(""); err != nil { - log.Warn("New wallet appeared, failed to open", "url", event.Wallet.URL(), "err", err) - if err == usbwallet.ErrTrezorPINNeeded { - go api.openTrezor(event.Wallet.URL()) - } - } - case accounts.WalletOpened: - status, _ := event.Wallet.Status() - log.Info("New wallet appeared", "url", event.Wallet.URL(), "status", status) - var derive = func(numToDerive int, base accounts.DerivationPath) { - // Derive first N accounts, hardcoded for now - var nextPath = make(accounts.DerivationPath, len(base)) - copy(nextPath[:],
base[:]) - - for i := 0; i < numToDerive; i++ { - acc, err := event.Wallet.Derive(nextPath, true) - if err != nil { - log.Warn("Account derivation failed", "error", err) - } else { - log.Info("Derived account", "address", acc.Address, "path", nextPath) - } - nextPath[len(nextPath)-1]++ + case accounts.WalletOpened: + status, _ := event.Wallet.Status() + log.Info("New wallet appeared", "url", event.Wallet.URL(), "status", status) + var derive = func(limit int, next func() accounts.DerivationPath) { + // Derive first N accounts, hardcoded for now + for i := 0; i < limit; i++ { + path := next() + if acc, err := event.Wallet.Derive(path, true); err != nil { + log.Warn("Account derivation failed", "error", err) + } else { + log.Info("Derived account", "address", acc.Address, "path", path) } } - if event.Wallet.URL().Scheme == "ledger" { - log.Info("Deriving ledger default paths") - derive(numberOfAccountsToDerive/2, accounts.DefaultBaseDerivationPath) - log.Info("Deriving ledger legacy paths") - derive(numberOfAccountsToDerive/2, accounts.LegacyLedgerBaseDerivationPath) - } else { - derive(numberOfAccountsToDerive, accounts.DefaultBaseDerivationPath) - } - case accounts.WalletDropped: - log.Info("Old wallet dropped", "url", event.Wallet.URL()) - event.Wallet.Close() } + log.Info("Deriving default paths") + derive(numberOfAccountsToDerive, accounts.DefaultIterator(accounts.DefaultBaseDerivationPath)) + if event.Wallet.URL().Scheme == "ledger" { + log.Info("Deriving ledger legacy paths") + derive(numberOfAccountsToDerive, accounts.DefaultIterator(accounts.LegacyLedgerBaseDerivationPath)) + log.Info("Deriving ledger live paths") + // For ledger live, since it's based off the same (DefaultBaseDerivationPath) + // as one we've already used, we need to advance it one step to avoid + // hitting the same path again + nextFn := accounts.LedgerLiveIterator(accounts.DefaultBaseDerivationPath) + nextFn() + derive(numberOfAccountsToDerive, nextFn) + } + case accounts.WalletDropped: + log.Info("Old wallet dropped", "url", event.Wallet.URL()) + event.Wallet.Close() } - }() + } } // List returns the set of wallets this signer manages. Each wallet can contain // multiple accounts. func (api *SignerAPI) List(ctx context.Context) ([]common.Address, error) { - var accs []accounts.Account + var accs = make([]accounts.Account, 0) + // accs is initialized as empty list, not nil. We use 'nil' to signal + // rejection, as opposed to an empty list. for _, wallet := range api.am.Wallets() { accs = append(accs, wallet.Accounts()...)
} @@ -390,13 +398,11 @@ func (api *SignerAPI) List(ctx context.Context) ([]common.Address, error) { } if result.Accounts == nil { return nil, ErrRequestDenied - } addresses := make([]common.Address, 0) for _, acc := range result.Accounts { addresses = append(addresses, acc.Address) } - return addresses, nil } @@ -581,6 +587,33 @@ func (api *SignerAPI) SignTransaction(ctx context.Context, args SendTxArgs, meth } +func (api *SignerAPI) SignGnosisSafeTx(ctx context.Context, signerAddress common.MixedcaseAddress, gnosisTx GnosisSafeTx, methodSelector *string) (*GnosisSafeTx, error) { + // Do the usual validations, but on the last-stage transaction + args := gnosisTx.ArgsForValidation() + msgs, err := api.validator.ValidateTransaction(methodSelector, args) + if err != nil { + return nil, err + } + // If we are in 'rejectMode', then reject rather than show the user warnings + if api.rejectMode { + if err := msgs.getWarnings(); err != nil { + return nil, err + } + } + typedData := gnosisTx.ToTypedData() + signature, preimage, err := api.signTypedData(ctx, signerAddress, typedData, msgs) + if err != nil { + return nil, err + } + checkSummedSender, _ := common.NewMixedcaseAddressFromString(signerAddress.Address().Hex()) + + gnosisTx.Signature = signature + gnosisTx.SafeTxHash = common.BytesToHash(preimage) + gnosisTx.Sender = *checkSummedSender // Must be checksummed to be accepted by relay + + return &gnosisTx, nil +} + // Returns the external api version. This method does not require user acceptance. Available methods are // available via enumeration anyway, and this info does not contain user-specific data func (api *SignerAPI) Version(ctx context.Context) (string, error) { diff --git a/signer/core/auditlog.go b/signer/core/auditlog.go index 1092e7a92340..bda88a8b2e55 100644 --- a/signer/core/auditlog.go +++ b/signer/core/auditlog.go @@ -18,6 +18,7 @@ package core import ( "context" + "encoding/json" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -61,13 +62,32 @@ func (l *AuditLogger) SignTransaction(ctx context.Context, args SendTxArgs, meth } func (l *AuditLogger) SignData(ctx context.Context, contentType string, addr common.MixedcaseAddress, data interface{}) (hexutil.Bytes, error) { + marshalledData, _ := json.Marshal(data) // can ignore error, marshalling what we just unmarshalled l.log.Info("SignData", "type", "request", "metadata", MetadataFromContext(ctx).String(), - "addr", addr.String(), "data", data, "content-type", contentType) + "addr", addr.String(), "data", marshalledData, "content-type", contentType) b, e := l.api.SignData(ctx, contentType, addr, data) l.log.Info("SignData", "type", "response", "data", common.Bytes2Hex(b), "error", e) return b, e } +func (l *AuditLogger) SignGnosisSafeTx(ctx context.Context, addr common.MixedcaseAddress, gnosisTx GnosisSafeTx, methodSelector *string) (*GnosisSafeTx, error) { + sel := "" + if methodSelector != nil { + sel = *methodSelector + } + data, _ := json.Marshal(gnosisTx) // can ignore error, marshalling what we just unmarshalled + l.log.Info("SignGnosisSafeTx", "type", "request", "metadata", MetadataFromContext(ctx).String(), + "addr", addr.String(), "data", string(data), "selector", sel) + res, e := l.api.SignGnosisSafeTx(ctx, addr, gnosisTx, methodSelector) + if res != nil { + data, _ := json.Marshal(res) // can ignore error, marshalling what we just unmarshalled + l.log.Info("SignGnosisSafeTx", "type", "response", "data", string(data), "error", e) + } else { + l.log.Info("SignGnosisSafeTx",
"type", "response", "data", res, "error", e) + } + return res, e +} + func (l *AuditLogger) SignTypedData(ctx context.Context, addr common.MixedcaseAddress, data TypedData) (hexutil.Bytes, error) { l.log.Info("SignTypedData", "type", "request", "metadata", MetadataFromContext(ctx).String(), "addr", addr.String(), "data", data) diff --git a/signer/core/cliui.go b/signer/core/cliui.go index 27a2f71aaaf7..cbfb56c9df9e 100644 --- a/signer/core/cliui.go +++ b/signer/core/cliui.go @@ -148,6 +148,13 @@ func (ui *CommandlineUI) ApproveSignData(request *SignDataRequest) (SignDataResp fmt.Printf("-------- Sign data request--------------\n") fmt.Printf("Account: %s\n", request.Address.String()) + if len(request.Callinfo) != 0 { + fmt.Printf("\nValidation messages:\n") + for _, m := range request.Callinfo { + fmt.Printf(" * %s : %s\n", m.Typ, m.Message) + } + fmt.Println() + } fmt.Printf("messages:\n") for _, nvt := range request.Messages { fmt.Printf("\u00a0\u00a0%v\n", strings.TrimSpace(nvt.Pprint(1))) diff --git a/signer/core/gnosis_safe.go b/signer/core/gnosis_safe.go new file mode 100644 index 000000000000..e4385f9dc37a --- /dev/null +++ b/signer/core/gnosis_safe.go @@ -0,0 +1,91 @@ +package core + +import ( + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/common/math" +) + +// GnosisSafeTx is a type to parse the safe-tx returned by the relayer, +// it also conforms to the API required by the Gnosis Safe tx relay service. +// See 'SafeMultisigTransaction' on https://safe-transaction.mainnet.gnosis.io/ +type GnosisSafeTx struct { + // These fields are only used on output + Signature hexutil.Bytes `json:"signature"` + SafeTxHash common.Hash `json:"contractTransactionHash"` + Sender common.MixedcaseAddress `json:"sender"` + // These fields are used both on input and output + Safe common.MixedcaseAddress `json:"safe"` + To common.MixedcaseAddress `json:"to"` + Value math.Decimal256 `json:"value"` + GasPrice math.Decimal256 `json:"gasPrice"` + Data *hexutil.Bytes `json:"data"` + Operation uint8 `json:"operation"` + GasToken common.Address `json:"gasToken"` + RefundReceiver common.Address `json:"refundReceiver"` + BaseGas big.Int `json:"baseGas"` + SafeTxGas big.Int `json:"safeTxGas"` + Nonce big.Int `json:"nonce"` + InputExpHash common.Hash `json:"safeTxHash"` +} + +// ToTypedData converts the tx to a EIP-712 Typed Data structure for signing +func (tx *GnosisSafeTx) ToTypedData() TypedData { + var data hexutil.Bytes + if tx.Data != nil { + data = *tx.Data + } + gnosisTypedData := TypedData{ + Types: Types{ + "EIP712Domain": []Type{{Name: "verifyingContract", Type: "address"}}, + "SafeTx": []Type{ + {Name: "to", Type: "address"}, + {Name: "value", Type: "uint256"}, + {Name: "data", Type: "bytes"}, + {Name: "operation", Type: "uint8"}, + {Name: "safeTxGas", Type: "uint256"}, + {Name: "baseGas", Type: "uint256"}, + {Name: "gasPrice", Type: "uint256"}, + {Name: "gasToken", Type: "address"}, + {Name: "refundReceiver", Type: "address"}, + {Name: "nonce", Type: "uint256"}, + }, + }, + Domain: TypedDataDomain{ + VerifyingContract: tx.Safe.Address().Hex(), + }, + PrimaryType: "SafeTx", + Message: TypedDataMessage{ + "to": tx.To.Address().Hex(), + "value": tx.Value.String(), + "data": data, + "operation": fmt.Sprintf("%d", tx.Operation), + "safeTxGas": fmt.Sprintf("%#d", &tx.SafeTxGas), + "baseGas": fmt.Sprintf("%#d", &tx.BaseGas), + "gasPrice": tx.GasPrice.String(), + "gasToken": tx.GasToken.Hex(), + 
"refundReceiver": tx.RefundReceiver.Hex(), + "nonce": fmt.Sprintf("%d", tx.Nonce.Uint64()), + }, + } + return gnosisTypedData +} + +// ArgsForValidation returns a SendTxArgs struct, which can be used for the +// common validations, e.g. look up 4byte destinations +func (tx *GnosisSafeTx) ArgsForValidation() *SendTxArgs { + args := &SendTxArgs{ + From: tx.Safe, + To: &tx.To, + Gas: hexutil.Uint64(tx.SafeTxGas.Uint64()), + GasPrice: hexutil.Big(tx.GasPrice), + Value: hexutil.Big(tx.Value), + Nonce: hexutil.Uint64(tx.Nonce.Uint64()), + Data: tx.Data, + Input: nil, + } + return args +} diff --git a/signer/core/signed_data.go b/signer/core/signed_data.go index 7fc66b4b740a..3bff1e1f200d 100644 --- a/signer/core/signed_data.go +++ b/signer/core/signed_data.go @@ -125,7 +125,7 @@ var typedDataReferenceTypeRegexp = regexp.MustCompile(`^[A-Z](\w*)(\[\])?$`) // // Note, the produced signature conforms to the secp256k1 curve R, S and V values, // where the V value will be 27 or 28 for legacy reasons, if legacyV==true. -func (api *SignerAPI) sign(addr common.MixedcaseAddress, req *SignDataRequest, legacyV bool) (hexutil.Bytes, error) { +func (api *SignerAPI) sign(req *SignDataRequest, legacyV bool) (hexutil.Bytes, error) { // We make the request prior to looking up if we actually have the account, to prevent // account-enumeration via the API res, err := api.UI.ApproveSignData(req) @@ -136,7 +136,7 @@ func (api *SignerAPI) sign(addr common.MixedcaseAddress, req *SignDataRequest, l return nil, ErrRequestDenied } // Look up the wallet containing the requested signer - account := accounts.Account{Address: addr.Address()} + account := accounts.Account{Address: req.Address.Address()} wallet, err := api.am.Find(account) if err != nil { return nil, err @@ -167,7 +167,7 @@ func (api *SignerAPI) SignData(ctx context.Context, contentType string, addr com if err != nil { return nil, err } - signature, err := api.sign(addr, req, transformV) + signature, err := api.sign(req, transformV) if err != nil { api.UI.ShowError(err.Error()) return nil, err @@ -312,28 +312,47 @@ func cliqueHeaderHashAndRlp(header *types.Header) (hash, rlp []byte, err error) // SignTypedData signs EIP-712 conformant typed data // hash = keccak256("\x19${byteVersion}${domainSeparator}${hashStruct(message)}") +// It returns +// - the signature, +// - and/or any error func (api *SignerAPI) SignTypedData(ctx context.Context, addr common.MixedcaseAddress, typedData TypedData) (hexutil.Bytes, error) { + signature, _, err := api.signTypedData(ctx, addr, typedData, nil) + return signature, err +} + +// signTypedData is identical to the capitalized version, except that it also returns the hash (preimage) +// - the signature preimage (hash) +func (api *SignerAPI) signTypedData(ctx context.Context, addr common.MixedcaseAddress, + typedData TypedData, validationMessages *ValidationMessages) (hexutil.Bytes, hexutil.Bytes, error) { domainSeparator, err := typedData.HashStruct("EIP712Domain", typedData.Domain.Map()) if err != nil { - return nil, err + return nil, nil, err } typedDataHash, err := typedData.HashStruct(typedData.PrimaryType, typedData.Message) if err != nil { - return nil, err + return nil, nil, err } rawData := []byte(fmt.Sprintf("\x19\x01%s%s", string(domainSeparator), string(typedDataHash))) sighash := crypto.Keccak256(rawData) messages, err := typedData.Format() if err != nil { - return nil, err - } - req := &SignDataRequest{ContentType: DataTyped.Mime, Rawdata: rawData, Messages: messages, Hash: sighash} - signature, err := 
api.sign(addr, req, true) + return nil, nil, err + } + req := &SignDataRequest{ + ContentType: DataTyped.Mime, + Rawdata: rawData, + Messages: messages, + Hash: sighash, + Address: addr} + if validationMessages != nil { + req.Callinfo = validationMessages.Messages + } + signature, err := api.sign(req, true) if err != nil { api.UI.ShowError(err.Error()) - return nil, err + return nil, nil, err } - return signature, nil + return signature, sighash, nil } // HashStruct generates a keccak256 hash of the encoding of the provided data @@ -420,8 +439,8 @@ func (typedData *TypedData) EncodeData(primaryType string, data map[string]inter buffer := bytes.Buffer{} // Verify extra data - if len(typedData.Types[primaryType]) < len(data) { - return nil, errors.New("there is extra data provided in the message") + if exp, got := len(typedData.Types[primaryType]), len(data); exp < got { + return nil, fmt.Errorf("there is extra data provided in the message (%d < %d)", exp, got) } // Add typehash @@ -487,7 +506,7 @@ func parseBytes(encType interface{}) ([]byte, bool) { case []byte: return v, true case hexutil.Bytes: - return []byte(v), true + return v, true case string: bytes, err := hexutil.Decode(v) if err != nil { @@ -834,7 +853,11 @@ func (nvt *NameValueType) Pprint(depth int) string { output.WriteString(sublevel) } } else { - output.WriteString(fmt.Sprintf("%q\n", nvt.Value)) + if nvt.Value != nil { + output.WriteString(fmt.Sprintf("%q\n", nvt.Value)) + } else { + output.WriteString("\n") + } } return output.String() } diff --git a/signer/core/signed_data_test.go b/signer/core/signed_data_test.go index ab5f2cc9627d..23b7b9897bca 100644 --- a/signer/core/signed_data_test.go +++ b/signer/core/signed_data_test.go @@ -17,6 +17,7 @@ package core_test import ( + "bytes" "context" "encoding/json" "fmt" @@ -414,3 +415,119 @@ func TestFuzzerFiles(t *testing.T) { typedData.Format() } } + +var gnosisTypedData = ` +{ + "types": { + "EIP712Domain": [ + { "type": "address", "name": "verifyingContract" } + ], + "SafeTx": [ + { "type": "address", "name": "to" }, + { "type": "uint256", "name": "value" }, + { "type": "bytes", "name": "data" }, + { "type": "uint8", "name": "operation" }, + { "type": "uint256", "name": "safeTxGas" }, + { "type": "uint256", "name": "baseGas" }, + { "type": "uint256", "name": "gasPrice" }, + { "type": "address", "name": "gasToken" }, + { "type": "address", "name": "refundReceiver" }, + { "type": "uint256", "name": "nonce" } + ] + }, + "domain": { + "verifyingContract": "0x25a6c4BBd32B2424A9c99aEB0584Ad12045382B3" + }, + "primaryType": "SafeTx", + "message": { + "to": "0x9eE457023bB3De16D51A003a247BaEaD7fce313D", + "value": "20000000000000000", + "data": "0x", + "operation": 0, + "safeTxGas": 27845, + "baseGas": 0, + "gasPrice": "0", + "gasToken": "0x0000000000000000000000000000000000000000", + "refundReceiver": "0x0000000000000000000000000000000000000000", + "nonce": 3 + } +}` + +var gnosisTx = ` +{ + "safe": "0x25a6c4BBd32B2424A9c99aEB0584Ad12045382B3", + "to": "0x9eE457023bB3De16D51A003a247BaEaD7fce313D", + "value": "20000000000000000", + "data": null, + "operation": 0, + "gasToken": "0x0000000000000000000000000000000000000000", + "safeTxGas": 27845, + "baseGas": 0, + "gasPrice": "0", + "refundReceiver": "0x0000000000000000000000000000000000000000", + "nonce": 3, + "executionDate": null, + "submissionDate": "2020-09-15T21:59:23.815748Z", + "modified": "2020-09-15T21:59:23.815748Z", + "blockNumber": null, + "transactionHash": null, + "safeTxHash": 
"0x28bae2bd58d894a1d9b69e5e9fde3570c4b98a6fc5499aefb54fb830137e831f", + "executor": null, + "isExecuted": false, + "isSuccessful": null, + "ethGasPrice": null, + "gasUsed": null, + "fee": null, + "origin": null, + "dataDecoded": null, + "confirmationsRequired": null, + "confirmations": [ + { + "owner": "0xAd2e180019FCa9e55CADe76E4487F126Fd08DA34", + "submissionDate": "2020-09-15T21:59:28.281243Z", + "transactionHash": null, + "confirmationType": "CONFIRMATION", + "signature": "0x5e562065a0cb15d766dac0cd49eb6d196a41183af302c4ecad45f1a81958d7797753f04424a9b0aa1cb0448e4ec8e189540fbcdda7530ef9b9d95dfc2d36cb521b", + "signatureType": "EOA" + } + ], + "signatures": null + } +` + +// TestGnosisTypedData tests the scenario where a user submits a full EIP-712 +// struct without using the gnosis-specific endpoint +func TestGnosisTypedData(t *testing.T) { + var td core.TypedData + err := json.Unmarshal([]byte(gnosisTypedData), &td) + if err != nil { + t.Fatalf("unmarshalling failed '%v'", err) + } + _, sighash, err := sign(td) + if err != nil { + t.Fatal(err) + } + expSigHash := common.FromHex("0x28bae2bd58d894a1d9b69e5e9fde3570c4b98a6fc5499aefb54fb830137e831f") + if !bytes.Equal(expSigHash, sighash) { + t.Fatalf("Error, got %x, wanted %x", sighash, expSigHash) + } +} + +// TestGnosisCustomData tests the scenario where a user submits only the gnosis-safe +// specific data, and we fill the TypedData struct on our side +func TestGnosisCustomData(t *testing.T) { + var tx core.GnosisSafeTx + err := json.Unmarshal([]byte(gnosisTx), &tx) + if err != nil { + t.Fatal(err) + } + var td = tx.ToTypedData() + _, sighash, err := sign(td) + if err != nil { + t.Fatal(err) + } + expSigHash := common.FromHex("0x28bae2bd58d894a1d9b69e5e9fde3570c4b98a6fc5499aefb54fb830137e831f") + if !bytes.Equal(expSigHash, sighash) { + t.Fatalf("Error, got %x, wanted %x", sighash, expSigHash) + } +} diff --git a/tests/block_test_util.go b/tests/block_test_util.go index be9cdb70cd6e..c043f0b3eb80 100644 --- a/tests/block_test_util.go +++ b/tests/block_test_util.go @@ -147,7 +147,7 @@ func (t *BlockTest) Run(snapshotter bool) error { } // Cross-check the snapshot-to-hash against the trie hash if snapshotter { - if err := snapshot.VerifyState(chain.Snapshot(), chain.CurrentBlock().Root()); err != nil { + if err := snapshot.VerifyState(chain.Snapshots(), chain.CurrentBlock().Root()); err != nil { return err } } diff --git a/tests/fuzzers/abi/abifuzzer.go b/tests/fuzzers/abi/abifuzzer.go index ed5c7c05868b..76d3c800f742 100644 --- a/tests/fuzzers/abi/abifuzzer.go +++ b/tests/fuzzers/abi/abifuzzer.go @@ -30,7 +30,7 @@ import ( func unpackPack(abi abi.ABI, method string, inputType []interface{}, input []byte) bool { outptr := reflect.New(reflect.TypeOf(inputType)) - if err := abi.Unpack(outptr.Interface(), method, input); err == nil { + if err := abi.UnpackIntoInterface(outptr.Interface(), method, input); err == nil { output, err := abi.Pack(method, input) if err != nil { // We have some false positives as we can unpack these type successfully, but not pack them @@ -51,7 +51,7 @@ func unpackPack(abi abi.ABI, method string, inputType []interface{}, input []byt func packUnpack(abi abi.ABI, method string, input []interface{}) bool { if packed, err := abi.Pack(method, input); err == nil { outptr := reflect.New(reflect.TypeOf(input)) - err := abi.Unpack(outptr.Interface(), method, packed) + err := abi.UnpackIntoInterface(outptr.Interface(), method, packed) if err != nil { panic(err) } diff --git a/tests/fuzzers/bls12381/bls_fuzzer.go 
diff --git a/tests/fuzzers/bls12381/bls_fuzzer.go b/tests/fuzzers/bls12381/bls_fuzzer.go new file mode 100644 index 000000000000..7e3f94c2aac4 --- /dev/null +++ b/tests/fuzzers/bls12381/bls_fuzzer.go @@ -0,0 +1,101 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package bls + +import ( + "bytes" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/vm" +) + +const ( + blsG1Add = byte(10) + blsG1Mul = byte(11) + blsG1MultiExp = byte(12) + blsG2Add = byte(13) + blsG2Mul = byte(14) + blsG2MultiExp = byte(15) + blsPairing = byte(16) + blsMapG1 = byte(17) + blsMapG2 = byte(18) +) + +func FuzzG1Add(data []byte) int { return fuzz(blsG1Add, data) } +func FuzzG1Mul(data []byte) int { return fuzz(blsG1Mul, data) } +func FuzzG1MultiExp(data []byte) int { return fuzz(blsG1MultiExp, data) } +func FuzzG2Add(data []byte) int { return fuzz(blsG2Add, data) } +func FuzzG2Mul(data []byte) int { return fuzz(blsG2Mul, data) } +func FuzzG2MultiExp(data []byte) int { return fuzz(blsG2MultiExp, data) } +func FuzzPairing(data []byte) int { return fuzz(blsPairing, data) } +func FuzzMapG1(data []byte) int { return fuzz(blsMapG1, data) } +func FuzzMapG2(data []byte) int { return fuzz(blsMapG2, data) } + +func checkInput(id byte, inputLen int) bool { + switch id { + case blsG1Add: + return inputLen == 256 + case blsG1Mul: + return inputLen == 160 + case blsG1MultiExp: + return inputLen%160 == 0 + case blsG2Add: + return inputLen == 512 + case blsG2Mul: + return inputLen == 288 + case blsG2MultiExp: + return inputLen%288 == 0 + case blsPairing: + return inputLen%384 == 0 + case blsMapG1: + return inputLen == 64 + case blsMapG2: + return inputLen == 128 + } + panic("programmer error") +} + +// The fuzzer functions must return +// 1 if the fuzzer should increase priority of the +// given input during subsequent fuzzing (for example, the input is lexically +// correct and was parsed successfully); +// -1 if the input must not be added to corpus even if gives new coverage; and +// 0 otherwise +// other values are reserved for future use.
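As an aside, the sizes accepted by checkInput above follow from the EIP-2537 encodings: a G1 point is 128 bytes (two 64-byte field elements), a G2 point is 256 bytes, and multiplication scalars are 32 bytes. A sketch making the arithmetic explicit (the constants and map are illustrative, not part of the diff):

```go
// Input sizes implied by the EIP-2537 encodings checked above.
const (
	g1Point = 128 // 2 x 64-byte field elements
	g2Point = 256 // 4 x 64-byte field elements
	scalar  = 32  // multiplication scalar
)

var expectedLens = map[string]int{
	"G1Add":   2 * g1Point,       // 256
	"G1Mul":   g1Point + scalar,  // 160
	"G2Add":   2 * g2Point,       // 512
	"G2Mul":   g2Point + scalar,  // 288
	"Pairing": g1Point + g2Point, // 384 per (G1, G2) pair
	"MapG1":   64,                // one field element
	"MapG2":   128,               // two field elements
}
```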
+func fuzz(id byte, data []byte) int { + // Even on bad input, it should not crash, so we still test the gas calc + precompile := vm.PrecompiledContractsYoloV2[common.BytesToAddress([]byte{id})] + gas := precompile.RequiredGas(data) + if !checkInput(id, len(data)) { + return 0 + } + // If the gas cost is too large (25M), bail out + if gas > 25*1000*1000 { + return 0 + } + cpy := make([]byte, len(data)) + copy(cpy, data) + _, err := precompile.Run(cpy) + if !bytes.Equal(cpy, data) { + panic(fmt.Sprintf("input data modified, precompile %d: %x %x", id, data, cpy)) + } + if err != nil { + return 0 + } + return 1 +} diff --git a/tests/fuzzers/bls12381/testdata/fuzz_g1_add_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_g1_add_seed_corpus.zip new file mode 100644 index 000000000000..16498c1cba89 Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_g1_add_seed_corpus.zip differ diff --git a/tests/fuzzers/bls12381/testdata/fuzz_g1_mul_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_g1_mul_seed_corpus.zip new file mode 100644 index 000000000000..57f9d6696d8c Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_g1_mul_seed_corpus.zip differ diff --git a/tests/fuzzers/bls12381/testdata/fuzz_g1_multiexp_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_g1_multiexp_seed_corpus.zip new file mode 100644 index 000000000000..7271f040f3b8 Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_g1_multiexp_seed_corpus.zip differ diff --git a/tests/fuzzers/bls12381/testdata/fuzz_g2_add_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_g2_add_seed_corpus.zip new file mode 100644 index 000000000000..cd5206ca0bcb Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_g2_add_seed_corpus.zip differ diff --git a/tests/fuzzers/bls12381/testdata/fuzz_g2_mul_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_g2_mul_seed_corpus.zip new file mode 100644 index 000000000000..f784a5a3d7a9 Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_g2_mul_seed_corpus.zip differ diff --git a/tests/fuzzers/bls12381/testdata/fuzz_g2_multiexp_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_g2_multiexp_seed_corpus.zip new file mode 100644 index 000000000000..c205117a4680 Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_g2_multiexp_seed_corpus.zip differ diff --git a/tests/fuzzers/bls12381/testdata/fuzz_map_g1_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_map_g1_seed_corpus.zip new file mode 100644 index 000000000000..70382fbe53db Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_map_g1_seed_corpus.zip differ diff --git a/tests/fuzzers/bls12381/testdata/fuzz_map_g2_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_map_g2_seed_corpus.zip new file mode 100644 index 000000000000..67adc5b5e8d6 Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_map_g2_seed_corpus.zip differ diff --git a/tests/fuzzers/bls12381/testdata/fuzz_pairing_seed_corpus.zip b/tests/fuzzers/bls12381/testdata/fuzz_pairing_seed_corpus.zip new file mode 100644 index 000000000000..e24d2b0a52fe Binary files /dev/null and b/tests/fuzzers/bls12381/testdata/fuzz_pairing_seed_corpus.zip differ diff --git a/tests/fuzzers/difficulty/debug/main.go b/tests/fuzzers/difficulty/debug/main.go new file mode 100644 index 000000000000..23516b3a0dfa --- /dev/null +++ b/tests/fuzzers/difficulty/debug/main.go @@ -0,0 +1,23 @@ +package main + +import ( + "fmt" + "io/ioutil" + "os" + + 
"github.com/ethereum/go-ethereum/tests/fuzzers/difficulty" +) + +func main() { + if len(os.Args) != 2 { + fmt.Fprintf(os.Stderr, "Usage: debug ") + os.Exit(1) + } + crasher := os.Args[1] + data, err := ioutil.ReadFile(crasher) + if err != nil { + fmt.Fprintf(os.Stderr, "error loading crasher %v: %v", crasher, err) + os.Exit(1) + } + difficulty.Fuzz(data) +} diff --git a/tests/fuzzers/difficulty/difficulty-fuzz.go b/tests/fuzzers/difficulty/difficulty-fuzz.go new file mode 100644 index 000000000000..e4c5dcf57cbc --- /dev/null +++ b/tests/fuzzers/difficulty/difficulty-fuzz.go @@ -0,0 +1,145 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package difficulty + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math/big" + + "github.com/ethereum/go-ethereum/consensus/ethash" + "github.com/ethereum/go-ethereum/core/types" +) + +type fuzzer struct { + input io.Reader + exhausted bool + debugging bool +} + +func (f *fuzzer) read(size int) []byte { + out := make([]byte, size) + if _, err := f.input.Read(out); err != nil { + f.exhausted = true + } + return out +} + +func (f *fuzzer) readSlice(min, max int) []byte { + var a uint16 + binary.Read(f.input, binary.LittleEndian, &a) + size := min + int(a)%(max-min) + out := make([]byte, size) + if _, err := f.input.Read(out); err != nil { + f.exhausted = true + } + return out +} + +func (f *fuzzer) readUint64(min, max uint64) uint64 { + if min == max { + return min + } + var a uint64 + if err := binary.Read(f.input, binary.LittleEndian, &a); err != nil { + f.exhausted = true + } + a = min + a%(max-min) + return a +} +func (f *fuzzer) readBool() bool { + return f.read(1)[0]&0x1 == 0 +} + +// The function must return +// 1 if the fuzzer should increase priority of the +// given input during subsequent fuzzing (for example, the input is lexically +// correct and was parsed successfully); +// -1 if the input must not be added to corpus even if gives new coverage; and +// 0 otherwise +// other values are reserved for future use. 
+ +// The function must return +// 1 if the fuzzer should increase priority of the +// given input during subsequent fuzzing (for example, the input is lexically +// correct and was parsed successfully); +// -1 if the input must not be added to corpus even if gives new coverage; and +// 0 otherwise +// other values are reserved for future use. +func Fuzz(data []byte) int { + f := fuzzer{ + input: bytes.NewReader(data), + exhausted: false, + } + return f.fuzz() +} + +var minDifficulty = big.NewInt(0x2000) + +type calculator func(time uint64, parent *types.Header) *big.Int + +func (f *fuzzer) fuzz() int { + // A parent header + header := &types.Header{} + if f.readBool() { + header.UncleHash = types.EmptyUncleHash + } + // Difficulty can range between 0x2000 (2 bytes) and up to 32 bytes + { + diff := new(big.Int).SetBytes(f.readSlice(2, 32)) + if diff.Cmp(minDifficulty) < 0 { + diff.Set(minDifficulty) + } + header.Difficulty = diff + } + // Number can range between 0 and up to 32 bytes (but not so that the child exceeds it) + { + // However, if we use astronomical numbers, then the bomb exp karatsuba + // calculation (in the legacy methods) times out, so we limit it to fit + // within reasonable bounds + number := new(big.Int).SetBytes(f.readSlice(0, 4)) // 4 bytes: 32 bits: block num max 4 billion + header.Number = number + } + // Both parent and child time must fit within uint64 + var time uint64 + { + childTime := f.readUint64(1, 0xFFFFFFFFFFFFFFFF) + //fmt.Printf("childTime: %x\n",childTime) + delta := f.readUint64(1, childTime) + //fmt.Printf("delta: %v\n", delta) + pTime := childTime - delta + header.Time = pTime + time = childTime + } + // Bomb delay will never exceed uint64 + bombDelay := new(big.Int).SetUint64(f.readUint64(1, 0xFFFFFFFFFFFFFFFe)) + + if f.exhausted { + return 0 + } + + for i, pair := range []struct { + bigFn calculator + u256Fn calculator + }{ + {ethash.FrontierDifficultyCalulator, ethash.CalcDifficultyFrontierU256}, + {ethash.HomesteadDifficultyCalulator, ethash.CalcDifficultyHomesteadU256}, + {ethash.DynamicDifficultyCalculator(bombDelay), ethash.MakeDifficultyCalculatorU256(bombDelay)}, + } { + want := pair.bigFn(time, header) + have := pair.u256Fn(time, header) + if want.Cmp(have) != 0 { + panic(fmt.Sprintf("pair %d: want %x have %x\nparent.Number: %x\np.Time: %x\nc.Time: %x\nBombdelay: %v\n", i, want, have, + header.Number, header.Time, time, bombDelay)) + } + } + return 1 +} diff --git a/tests/fuzzers/keystore/keystore-fuzzer.go b/tests/fuzzers/keystore/keystore-fuzzer.go index 704f29dc482e..e3bcae92e159 100644 --- a/tests/fuzzers/keystore/keystore-fuzzer.go +++ b/tests/fuzzers/keystore/keystore-fuzzer.go @@ -33,5 +33,5 @@ func Fuzz(input []byte) int { panic(err) } os.Remove(a.URL.Path) - return 0 + return 1 } diff --git a/tests/fuzzers/rangeproof/corpus/1c14030f26872e57bf1481084f151d71eed8161c-1 b/tests/fuzzers/rangeproof/corpus/1c14030f26872e57bf1481084f151d71eed8161c-1 new file mode 100644 index 000000000000..31c08bafaff6 Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/1c14030f26872e57bf1481084f151d71eed8161c-1 differ diff --git a/tests/fuzzers/rangeproof/corpus/27e54254422543060a13ea8a4bc913d768e4adb6-2 b/tests/fuzzers/rangeproof/corpus/27e54254422543060a13ea8a4bc913d768e4adb6-2 new file mode 100644 index 000000000000..7bce13ef80e5 Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/27e54254422543060a13ea8a4bc913d768e4adb6-2 differ diff --git a/tests/fuzzers/rangeproof/corpus/6bfc2cbe2d7a43361e240118439785445a0fdfb7-5 b/tests/fuzzers/rangeproof/corpus/6bfc2cbe2d7a43361e240118439785445a0fdfb7-5 new file mode 100644 index 000000000000..613e76a020f5 Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/6bfc2cbe2d7a43361e240118439785445a0fdfb7-5 differ diff --git a/tests/fuzzers/rangeproof/corpus/a67e63bc0c0004bd009944a6061297cb7d4ac238-1
b/tests/fuzzers/rangeproof/corpus/a67e63bc0c0004bd009944a6061297cb7d4ac238-1 new file mode 100644 index 000000000000..805ad8df77be Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/a67e63bc0c0004bd009944a6061297cb7d4ac238-1 differ diff --git a/tests/fuzzers/rangeproof/corpus/ae892bbae0a843950bc8316496e595b1a194c009-4 b/tests/fuzzers/rangeproof/corpus/ae892bbae0a843950bc8316496e595b1a194c009-4 new file mode 100644 index 000000000000..605acf81c1a9 Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/ae892bbae0a843950bc8316496e595b1a194c009-4 differ diff --git a/tests/fuzzers/rangeproof/corpus/ee05d0d813f6261b3dba16506f9ea03d9c5e993d-2 b/tests/fuzzers/rangeproof/corpus/ee05d0d813f6261b3dba16506f9ea03d9c5e993d-2 new file mode 100644 index 000000000000..8f32dd775ada Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/ee05d0d813f6261b3dba16506f9ea03d9c5e993d-2 differ diff --git a/tests/fuzzers/rangeproof/corpus/f50a6d57a46d30184aa294af5b252ab9701af7c9-2 b/tests/fuzzers/rangeproof/corpus/f50a6d57a46d30184aa294af5b252ab9701af7c9-2 new file mode 100644 index 000000000000..af96210f204e Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/f50a6d57a46d30184aa294af5b252ab9701af7c9-2 differ diff --git a/tests/fuzzers/rangeproof/corpus/random.dat b/tests/fuzzers/rangeproof/corpus/random.dat new file mode 100644 index 000000000000..2c998ad81297 Binary files /dev/null and b/tests/fuzzers/rangeproof/corpus/random.dat differ diff --git a/tests/fuzzers/rangeproof/debug/main.go b/tests/fuzzers/rangeproof/debug/main.go new file mode 100644 index 000000000000..a81c69fea532 --- /dev/null +++ b/tests/fuzzers/rangeproof/debug/main.go @@ -0,0 +1,41 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package main + +import ( + "fmt" + "io/ioutil" + "os" + + "github.com/ethereum/go-ethereum/tests/fuzzers/rangeproof" +) + +func main() { + if len(os.Args) != 2 { + fmt.Fprintf(os.Stderr, "Usage: debug <file>\n") + fmt.Fprintf(os.Stderr, "Example\n") + fmt.Fprintf(os.Stderr, " $ debug ../crashers/4bbef6857c733a87ecf6fd8b9e7238f65eb9862a\n") + os.Exit(1) + } + crasher := os.Args[1] + data, err := ioutil.ReadFile(crasher) + if err != nil { + fmt.Fprintf(os.Stderr, "error loading crasher %v: %v", crasher, err) + os.Exit(1) + } + rangeproof.Fuzz(data) +} diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go new file mode 100644 index 000000000000..b82a38072351 --- /dev/null +++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go @@ -0,0 +1,218 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library.
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package rangeproof + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "sort" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/ethereum/go-ethereum/trie" +) + +type kv struct { + k, v []byte + t bool +} + +type entrySlice []*kv + +func (p entrySlice) Len() int { return len(p) } +func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 } +func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +type fuzzer struct { + input io.Reader + exhausted bool +} + +func (f *fuzzer) randBytes(n int) []byte { + r := make([]byte, n) + if _, err := f.input.Read(r); err != nil { + f.exhausted = true + } + return r +} + +func (f *fuzzer) readInt() uint64 { + var x uint64 + if err := binary.Read(f.input, binary.LittleEndian, &x); err != nil { + f.exhausted = true + } + return x +} + +func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) { + + trie := new(trie.Trie) + vals := make(map[string]*kv) + size := f.readInt() + // Fill it with some fluff + for i := byte(0); i < byte(size); i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + if f.exhausted { + return nil, nil + } + // And now fill with some random + for i := 0; i < n; i++ { + k := f.randBytes(32) + v := f.randBytes(20) + value := &kv{k, v, false} + trie.Update(k, v) + vals[string(k)] = value + if f.exhausted { + return nil, nil + } + } + return trie, vals +} + +func (f *fuzzer) fuzz() int { + maxSize := 200 + tr, vals := f.randomTrie(1 + int(f.readInt())%maxSize) + if f.exhausted { + return 0 // input too short + } + var entries entrySlice + for _, kv := range vals { + entries = append(entries, kv) + } + if len(entries) <= 1 { + return 0 + } + sort.Sort(entries) + + var ok = 0 + for { + start := int(f.readInt() % uint64(len(entries))) + end := 1 + int(f.readInt()%uint64(len(entries)-1)) + testcase := int(f.readInt() % uint64(6)) + index := int(f.readInt() & 0xFFFFFFFF) + index2 := int(f.readInt() & 0xFFFFFFFF) + if f.exhausted { + break + } + proof := memorydb.New() + if err := tr.Prove(entries[start].k, 0, proof); err != nil { + panic(fmt.Sprintf("Failed to prove the first node %v", err)) + } + if err := tr.Prove(entries[end-1].k, 0, proof); err != nil { + panic(fmt.Sprintf("Failed to prove the last node %v", err)) + } + var keys [][]byte + var vals [][]byte + for i := start; i < end; i++ { + keys = append(keys, entries[i].k) + vals = append(vals, entries[i].v) + } + if len(keys) == 0 { + return 0 + } + var first, last = keys[0], keys[len(keys)-1] + testcase %= 6 + switch testcase { + case 0: + // Modified key + keys[index%len(keys)]
= f.randBytes(32) // In theory it can't be the same + case 1: + // Modified val + vals[index%len(vals)] = f.randBytes(20) // In theory it can't be the same + case 2: + // Gapped entry slice + index = index % len(keys) + keys = append(keys[:index], keys[index+1:]...) + vals = append(vals[:index], vals[index+1:]...) + case 3: + // Out of order + index1 := index % len(keys) + index2 := index2 % len(keys) + keys[index1], keys[index2] = keys[index2], keys[index1] + vals[index1], vals[index2] = vals[index2], vals[index1] + case 4: + // Set random key to nil, do nothing + keys[index%len(keys)] = nil + case 5: + // Set random value to nil, deletion + vals[index%len(vals)] = nil + + // Other cases: + // Modify something in the proof db + // add stuff to proof db + // drop stuff from proof db + + } + if f.exhausted { + break + } + ok = 1 + nodes, subtrie, notary, hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof) + if err != nil { + if nodes != nil { + panic("err != nil && nodes != nil") + } + if subtrie != nil { + panic("err != nil && subtrie != nil") + } + if notary != nil { + panic("err != nil && notary != nil") + } + if hasMore { + panic("err != nil && hasMore == true") + } + } else { + if nodes == nil { + panic("err == nil && nodes == nil") + } + if subtrie == nil { + panic("err == nil && subtrie == nil") + } + if notary == nil { + panic("err == nil && notary == nil") + } + } + } + return ok +}
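For orientation, the happy path that the loop above keeps perturbing looks roughly like this (a sketch assuming the five-value VerifyRangeProof signature used in this diff; verifyRange is a hypothetical helper):

```go
package rangeproof

import (
	"github.com/ethereum/go-ethereum/ethdb/memorydb"
	"github.com/ethereum/go-ethereum/trie"
)

// verifyRange proves the two edge keys into a fresh proof db, then verifies
// the contiguous key/value range between them against the trie root.
func verifyRange(tr *trie.Trie, keys, vals [][]byte) error {
	first, last := keys[0], keys[len(keys)-1]
	proof := memorydb.New()
	if err := tr.Prove(first, 0, proof); err != nil {
		return err
	}
	if err := tr.Prove(last, 0, proof); err != nil {
		return err
	}
	// nodes, subtrie and notary are all non-nil exactly when err == nil,
	// which is the invariant the fuzzer asserts above.
	_, _, _, _, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof)
	return err
}
```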
+ +// The function must return +// 1 if the fuzzer should increase priority of the +// given input during subsequent fuzzing (for example, the input is lexically +// correct and was parsed successfully); +// -1 if the input must not be added to corpus even if gives new coverage; and +// 0 otherwise; other values are reserved for future use. +func Fuzz(input []byte) int { + if len(input) < 100 { + return 0 + } + r := bytes.NewReader(input) + f := fuzzer{ + input: r, + exhausted: false, + } + return f.fuzz() +} diff --git a/tests/fuzzers/rlp/rlp_fuzzer.go b/tests/fuzzers/rlp/rlp_fuzzer.go index 534540476cb4..18b36287b53c 100644 --- a/tests/fuzzers/rlp/rlp_fuzzer.go +++ b/tests/fuzzers/rlp/rlp_fuzzer.go @@ -37,17 +37,17 @@ func decodeEncode(input []byte, val interface{}, i int) { } func Fuzz(input []byte) int { + if len(input) == 0 { + return 0 + } + var i int { - if len(input) > 0 { - rlp.Split(input) - } + rlp.Split(input) } { - if len(input) > 0 { - if elems, _, err := rlp.SplitList(input); err == nil { - rlp.CountValues(elems) - } + if elems, _, err := rlp.SplitList(input); err == nil { + rlp.CountValues(elems) } } @@ -123,5 +123,5 @@ func Fuzz(input []byte) int { var rs types.Receipts decodeEncode(input, &rs, i) } - return 0 + return 1 } diff --git a/tests/fuzzers/stacktrie/debug/main.go b/tests/fuzzers/stacktrie/debug/main.go new file mode 100644 index 000000000000..1ec28a8ef155 --- /dev/null +++ b/tests/fuzzers/stacktrie/debug/main.go @@ -0,0 +1,23 @@ +package main + +import ( + "fmt" + "io/ioutil" + "os" + + "github.com/ethereum/go-ethereum/tests/fuzzers/stacktrie" +) + +func main() { + if len(os.Args) != 2 { + fmt.Fprintf(os.Stderr, "Usage: debug <file>") + os.Exit(1) + } + crasher := os.Args[1] + data, err := ioutil.ReadFile(crasher) + if err != nil { + fmt.Fprintf(os.Stderr, "error loading crasher %v: %v", crasher, err) + os.Exit(1) + } + stacktrie.Debug(data) +} diff --git a/tests/fuzzers/stacktrie/trie_fuzzer.go b/tests/fuzzers/stacktrie/trie_fuzzer.go new file mode 100644 index 000000000000..5cea7769c284 --- /dev/null +++ b/tests/fuzzers/stacktrie/trie_fuzzer.go @@ -0,0 +1,204 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+ +package stacktrie + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + "sort" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/trie" + "golang.org/x/crypto/sha3" +) + +type fuzzer struct { + input io.Reader + exhausted bool + debugging bool +} + +func (f *fuzzer) read(size int) []byte { + out := make([]byte, size) + if _, err := f.input.Read(out); err != nil { + f.exhausted = true + } + return out +} + +func (f *fuzzer) readSlice(min, max int) []byte { + var a uint16 + binary.Read(f.input, binary.LittleEndian, &a) + size := min + int(a)%(max-min) + out := make([]byte, size) + if _, err := f.input.Read(out); err != nil { + f.exhausted = true + } + return out +} + +// spongeDb is a dummy db backend which accumulates writes in a sponge +type spongeDb struct { + sponge hash.Hash + debug bool +} + +func (s *spongeDb) Has(key []byte) (bool, error) { panic("implement me") } +func (s *spongeDb) Get(key []byte) ([]byte, error) { return nil, errors.New("no such elem") } +func (s *spongeDb) Delete(key []byte) error { panic("implement me") } +func (s *spongeDb) NewBatch() ethdb.Batch { return &spongeBatch{s} } +func (s *spongeDb) Stat(property string) (string, error) { panic("implement me") } +func (s *spongeDb) Compact(start []byte, limit []byte) error { panic("implement me") } +func (s *spongeDb) Close() error { return nil } + +func (s *spongeDb) Put(key []byte, value []byte) error { + if s.debug { + fmt.Printf("db.Put %x : %x\n", key, value) + } + s.sponge.Write(key) + s.sponge.Write(value) + return nil +} +func (s *spongeDb) NewIterator(prefix []byte, start []byte) ethdb.Iterator { panic("implement me") } + +// spongeBatch is a dummy batch which immediately writes to the underlying spongedb +type spongeBatch struct { + db *spongeDb +} + +func (b *spongeBatch) Put(key, value []byte) error { + b.db.Put(key, value) + return nil +} +func (b *spongeBatch) Delete(key []byte) error { panic("implement me") } +func (b *spongeBatch) ValueSize() int { return 100 } +func (b *spongeBatch) Write() error { return nil } +func (b *spongeBatch) Reset() {} +func (b *spongeBatch) Replay(w ethdb.KeyValueWriter) error { return nil } + +type kv struct { + k, v []byte +} +type kvs []kv + +func (k kvs) Len() int { + return len(k) +} + +func (k kvs) Less(i, j int) bool { + return bytes.Compare(k[i].k, k[j].k) < 0 +} + +func (k kvs) Swap(i, j int) { + k[j], k[i] = k[i], k[j] +} + +// The function must return +// 1 if the fuzzer should increase priority of the +// given input during subsequent fuzzing (for example, the input is lexically +// correct and was parsed successfully); +// -1 if the input must not be added to corpus even if gives new coverage; and +// 0 otherwise +// other values are reserved for future use. 
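The Fuzz/Debug entry points below feed the same random key/value set into a regular trie and a stacktrie, then require both the resulting roots and the ordered database write sequences (captured by the keccak sponges) to match. A condensed excerpt-style sketch of that invariant, reusing the names from this diff:

```go
// Condensed form of the equivalence enforced below: identical roots and
// identical write sequences for trie vs. stacktrie over the same data.
rootA, _ := trieA.Commit(nil) // regular trie
rootB := trieB.Hash()         // stacktrie, fed the same sorted pairs
if rootA != rootB {
	panic("roots differ")
}
if !bytes.Equal(spongeA.sponge.Sum(nil), spongeB.sponge.Sum(nil)) {
	panic("write sequences differ")
}
```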
+func Fuzz(data []byte) int { + f := fuzzer{ + input: bytes.NewReader(data), + exhausted: false, + } + return f.fuzz() +} + +func Debug(data []byte) int { + f := fuzzer{ + input: bytes.NewReader(data), + exhausted: false, + debugging: true, + } + return f.fuzz() +} + +func (f *fuzzer) fuzz() int { + + // This spongeDb is used to check the sequence of disk-db-writes + var ( + spongeA = &spongeDb{sponge: sha3.NewLegacyKeccak256()} + dbA = trie.NewDatabase(spongeA) + trieA, _ = trie.New(common.Hash{}, dbA) + spongeB = &spongeDb{sponge: sha3.NewLegacyKeccak256()} + trieB = trie.NewStackTrie(spongeB) + vals kvs + useful bool + maxElements = 10000 + // operate on unique keys only + keys = make(map[string]struct{}) + ) + // Fill the trie with elements + for i := 0; !f.exhausted && i < maxElements; i++ { + k := f.read(32) + v := f.readSlice(1, 500) + if f.exhausted { + // If it was exhausted while reading, the value may be all zeroes, + // thus 'deletion' which is not supported on stacktrie + break + } + if _, present := keys[string(k)]; present { + // This key is a duplicate, ignore it + continue + } + keys[string(k)] = struct{}{} + vals = append(vals, kv{k: k, v: v}) + trieA.Update(k, v) + useful = true + } + if !useful { + return 0 + } + // Flush trie -> database + rootA, err := trieA.Commit(nil) + if err != nil { + panic(err) + } + // Flush memdb -> disk (sponge) + dbA.Commit(rootA, false, nil) + + // Stacktrie requires sorted insertion + sort.Sort(vals) + for _, kv := range vals { + if f.debugging { + fmt.Printf("{\"0x%x\" , \"0x%x\"} // stacktrie.Update\n", kv.k, kv.v) + } + trieB.Update(kv.k, kv.v) + } + rootB := trieB.Hash() + if _, err := trieB.Commit(); err != nil { + panic(err) + } + if rootA != rootB { + panic(fmt.Sprintf("roots differ: (trie) %x != %x (stacktrie)", rootA, rootB)) + } + sumA := spongeA.sponge.Sum(nil) + sumB := spongeB.sponge.Sum(nil) + if !bytes.Equal(sumA, sumB) { + panic(fmt.Sprintf("sequence differ: (trie) %x != %x (stacktrie)", sumA, sumB)) + } + return 1 +} diff --git a/tests/fuzzers/trie/trie-fuzzer.go b/tests/fuzzers/trie/trie-fuzzer.go index 98188380535f..762ab5f3470a 100644 --- a/tests/fuzzers/trie/trie-fuzzer.go +++ b/tests/fuzzers/trie/trie-fuzzer.go @@ -122,15 +122,22 @@ func Generate(input []byte) randTest { return steps } +// The function must return +// 1 if the fuzzer should increase priority of the +// given input during subsequent fuzzing (for example, the input is lexically +// correct and was parsed successfully); +// -1 if the input must not be added to corpus even if gives new coverage; and +// 0 otherwise +// other values are reserved for future use. func Fuzz(input []byte) int { program := Generate(input) if len(program) == 0 { - return -1 + return 0 } if err := runRandTest(program); err != nil { panic(err) } - return 0 + return 1 } func runRandTest(rt randTest) error { diff --git a/tests/fuzzers/txfetcher/txfetcher_fuzzer.go b/tests/fuzzers/txfetcher/txfetcher_fuzzer.go index 10c7eb942496..d1d6fdc66592 100644 --- a/tests/fuzzers/txfetcher/txfetcher_fuzzer.go +++ b/tests/fuzzers/txfetcher/txfetcher_fuzzer.go @@ -51,8 +51,9 @@ func init() { func Fuzz(input []byte) int { // Don't generate insanely large test cases, not much value in them if len(input) > 16*1024 { - return -1 + return 0 } + verbose := false r := bytes.NewReader(input) // Reduce the problem space for certain fuzz runs. 
Small tx space is better @@ -124,7 +125,9 @@ func Fuzz(input []byte) int { announceIdxs[i] = (int(annBuf[0])*256 + int(annBuf[1])) % len(txs) announces[i] = txs[announceIdxs[i]].Hash() } - fmt.Println("Notify", peer, announceIdxs) + if verbose { + fmt.Println("Notify", peer, announceIdxs) + } if err := f.Notify(peer, announces); err != nil { panic(err) } @@ -163,8 +166,9 @@ func Fuzz(input []byte) int { return 0 } direct := (directFlag % 2) == 0 - - fmt.Println("Enqueue", peer, deliverIdxs, direct) + if verbose { + fmt.Println("Enqueue", peer, deliverIdxs, direct) + } if err := f.Enqueue(peer, deliveries, direct); err != nil { panic(err) } @@ -177,8 +181,9 @@ func Fuzz(input []byte) int { return 0 } peer := peers[int(peerIdx)%len(peers)] - - fmt.Println("Drop", peer) + if verbose { + fmt.Println("Drop", peer) + } if err := f.Drop(peer); err != nil { panic(err) } @@ -191,8 +196,9 @@ func Fuzz(input []byte) int { return 0 } tick := time.Duration(tickCnt) * 100 * time.Millisecond - - fmt.Println("Sleep", tick) + if verbose { + fmt.Println("Sleep", tick) + } clock.Run(tick) } } diff --git a/tests/init.go b/tests/init.go index d920c70e2e62..607c69ddb3b5 100644 --- a/tests/init.go +++ b/tests/init.go @@ -141,7 +141,7 @@ var Forks = map[string]*params.ChainConfig{ PetersburgBlock: big.NewInt(0), IstanbulBlock: big.NewInt(5), }, - "YOLOv1": { + "YOLOv2": { ChainID: big.NewInt(1), HomesteadBlock: big.NewInt(0), EIP150Block: big.NewInt(0), @@ -151,9 +151,9 @@ var Forks = map[string]*params.ChainConfig{ ConstantinopleBlock: big.NewInt(0), PetersburgBlock: big.NewInt(0), IstanbulBlock: big.NewInt(0), - YoloV1Block: big.NewInt(0), + YoloV2Block: big.NewInt(0), }, - // This specification is subject to change, but is for now identical to YOLOv1 + // This specification is subject to change, but is for now identical to YOLOv2 // for cross-client testing purposes "Berlin": { ChainID: big.NewInt(1), @@ -165,7 +165,7 @@ var Forks = map[string]*params.ChainConfig{ ConstantinopleBlock: big.NewInt(0), PetersburgBlock: big.NewInt(0), IstanbulBlock: big.NewInt(0), - YoloV1Block: big.NewInt(0), + YoloV2Block: big.NewInt(0), }, } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index a999cba4715c..96c731696940 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -182,10 +182,21 @@ func (t *StateTest) RunNoVerify(subtest StateSubtest, vmconfig vm.Config, snapsh if err != nil { return nil, nil, common.Hash{}, err } - context := core.NewEVMContext(msg, block.Header(), nil, &t.json.Env.Coinbase) + txContext := core.NewEVMTxContext(msg) + context := core.NewEVMBlockContext(block.Header(), nil, &t.json.Env.Coinbase) context.GetHash = vmTestBlockHash - evm := vm.NewEVM(context, statedb, config, vmconfig) + evm := vm.NewEVM(context, txContext, statedb, config, vmconfig) + if config.IsYoloV2(context.BlockNumber) { + statedb.AddAddressToAccessList(msg.From()) + if dst := msg.To(); dst != nil { + statedb.AddAddressToAccessList(*dst) + // If it's a create-tx, the destination will be added inside evm.create + } + for _, addr := range evm.ActivePrecompiles() { + statedb.AddAddressToAccessList(addr) + } + } gaspool := new(core.GasPool) gaspool.AddGas(block.GasLimit()) snapshot := statedb.Snapshot() @@ -225,7 +236,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc, snapshotter boo var snaps *snapshot.Tree if snapshotter { - snaps = snapshot.New(db, sdb.TrieDB(), 1, root, false) + snaps = snapshot.New(db, sdb.TrieDB(), 1, root, false, false) } statedb, _ = state.New(root, sdb, 
snaps) return snaps, statedb diff --git a/tests/vm_test_util.go b/tests/vm_test_util.go index ad124b7b2534..418cc6716864 100644 --- a/tests/vm_test_util.go +++ b/tests/vm_test_util.go @@ -138,20 +138,22 @@ func (t *VMTest) newEVM(statedb *state.StateDB, vmconfig vm.Config) *vm.EVM { return core.CanTransfer(db, address, amount) } transfer := func(db vm.StateDB, sender, recipient common.Address, amount *big.Int) {} - context := vm.Context{ + txContext := vm.TxContext{ + Origin: t.json.Exec.Origin, + GasPrice: t.json.Exec.GasPrice, + } + context := vm.BlockContext{ CanTransfer: canTransfer, Transfer: transfer, GetHash: vmTestBlockHash, - Origin: t.json.Exec.Origin, Coinbase: t.json.Env.Coinbase, BlockNumber: new(big.Int).SetUint64(t.json.Env.Number), Time: new(big.Int).SetUint64(t.json.Env.Timestamp), GasLimit: t.json.Env.GasLimit, Difficulty: t.json.Env.Difficulty, - GasPrice: t.json.Exec.GasPrice, } vmconfig.NoRecursion = true - return vm.NewEVM(context, statedb, params.MainnetChainConfig, vmconfig) + return vm.NewEVM(context, txContext, statedb, params.MainnetChainConfig, vmconfig) } func vmTestBlockHash(n uint64) common.Hash { diff --git a/trie/committer.go b/trie/committer.go index fc8b7ceda52c..33fd9e982339 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -23,7 +23,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/rlp" "golang.org/x/crypto/sha3" ) @@ -33,10 +32,9 @@ const leafChanSize = 200 // leaf represents a trie leaf value type leaf struct { - size int // size of the rlp data (estimate) - hash common.Hash // hash of rlp data - node node // the node to commit - vnodes bool // set to true if the node (possibly) contains a valueNode + size int // size of the rlp data (estimate) + hash common.Hash // hash of rlp data + node node // the node to commit } // committer is a type used for the trie Commit operation. A committer has some @@ -74,18 +72,12 @@ func returnCommitterToPool(h *committer) { committerPool.Put(h) } -// commitNeeded returns 'false' if the given node is already in sync with db -func (c *committer) commitNeeded(n node) bool { - hash, dirty := n.cache() - return hash == nil || dirty -} - // commit collapses a node down into a hash node and inserts it into the database func (c *committer) Commit(n node, db *Database) (hashNode, error) { if db == nil { return nil, errors.New("no db provided") } - h, err := c.commit(n, db, true) + h, err := c.commit(n, db) if err != nil { return nil, err } @@ -93,7 +85,7 @@ func (c *committer) Commit(n node, db *Database) (hashNode, error) { } // commit collapses a node down into a hash node and inserts it into the database -func (c *committer) commit(n node, db *Database, force bool) (node, error) { +func (c *committer) commit(n node, db *Database) (node, error) { // if this path is clean, use available cached data hash, dirty := n.cache() if hash != nil && !dirty { @@ -104,8 +96,11 @@ func (c *committer) commit(n node, db *Database, force bool) (node, error) { case *shortNode: // Commit child collapsed := cn.copy() - if _, ok := cn.Val.(valueNode); !ok { - childV, err := c.commit(cn.Val, db, false) + + // If the child is fullnode, recursively commit. + // Otherwise it can only be hashNode or valueNode. 
+ if _, ok := cn.Val.(*fullNode); ok { + childV, err := c.commit(cn.Val, db) if err != nil { return nil, err } @@ -113,78 +108,78 @@ } // The key needs to be copied, since we're delivering it to database collapsed.Key = hexToCompact(cn.Key) - hashedNode := c.store(collapsed, db, force, true) + hashedNode := c.store(collapsed, db) if hn, ok := hashedNode.(hashNode); ok { return hn, nil } return collapsed, nil case *fullNode: - hashedKids, hasVnodes, err := c.commitChildren(cn, db, force) + hashedKids, err := c.commitChildren(cn, db) if err != nil { return nil, err } collapsed := cn.copy() collapsed.Children = hashedKids - hashedNode := c.store(collapsed, db, force, hasVnodes) + hashedNode := c.store(collapsed, db) if hn, ok := hashedNode.(hashNode); ok { return hn, nil } return collapsed, nil - case valueNode: - return c.store(cn, db, force, false), nil - // hashnodes aren't stored case hashNode: return cn, nil + default: + // nil and valuenode shouldn't be committed + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) } - return hash, nil } // commitChildren commits the children of the given fullnode -func (c *committer) commitChildren(n *fullNode, db *Database, force bool) ([17]node, bool, error) { +func (c *committer) commitChildren(n *fullNode, db *Database) ([17]node, error) { var children [17]node - var hasValueNodeChildren = false - for i, child := range n.Children { + for i := 0; i < 16; i++ { + child := n.Children[i] if child == nil { continue } - hnode, err := c.commit(child, db, false) - if err != nil { - return children, false, err + // If it's the hashed child, save the hash value directly. + // Note: it's impossible that the child in range [0, 15] + // is a valuenode. + if hn, ok := child.(hashNode); ok { + children[i] = hn + continue } - children[i] = hnode - if _, ok := hnode.(valueNode); ok { - hasValueNodeChildren = true + // Commit the child recursively and store the "hashed" value. + // Note the returned node can be an embedded node, so it's + // possible the type is not hashnode. + hashed, err := c.commit(child, db) + if err != nil { + return children, err } + children[i] = hashed + } + // For the 17th child, it's possible the type is valuenode. + if n.Children[16] != nil { + children[16] = n.Children[16] } - return children, hasValueNodeChildren, nil + return children, nil } // store hashes the node n and if we have a storage layer specified, it writes // the key/value pair to it and tracks any node->child references as well as any // node->external trie references. -func (c *committer) store(n node, db *Database, force bool, hasVnodeChildren bool) node { +func (c *committer) store(n node, db *Database) node { // Larger nodes are replaced by their hash and stored in the database. var ( hash, _ = n.cache() size int ) if hash == nil { - if vn, ok := n.(valueNode); ok { - c.tmp.Reset() - if err := rlp.Encode(&c.tmp, vn); err != nil { - panic("encode error: " + err.Error()) - } - size = len(c.tmp) - if size < 32 && !force { - return n // Nodes smaller than 32 bytes are stored inside their parent - } - hash = c.makeHashNode(c.tmp) - } else { - // This was not generated - must be a small node stored in the parent - // No need to do anything here - return n - } + // This was not generated - must be a small node stored in the parent. + // In theory we should apply the leaf callback here if it's not nil (an + // embedded node usually contains a value). But small values (less than + // 32 bytes) are not our target.
+ return n } else { // We have the hash already, estimate the RLP encoding-size of the node. // The size is used for mem tracking, does not need to be exact @@ -194,10 +189,9 @@ // The leaf channel will be active only when there is an active leaf-callback if c.leafCh != nil { c.leafCh <- &leaf{ - size: size, - hash: common.BytesToHash(hash), - node: n, - vnodes: hasVnodeChildren, + size: size, + hash: common.BytesToHash(hash), + node: n, } } else if db != nil { // No leaf-callback used, but there's still a database. Do serial @@ -209,30 +203,30 @@ return hash } -// commitLoop does the actual insert + leaf callback for nodes +// commitLoop does the actual insert + leaf callback for nodes. func (c *committer) commitLoop(db *Database) { for item := range c.leafCh { var ( - hash = item.hash - size = item.size - n = item.node - hasVnodes = item.vnodes + hash = item.hash + size = item.size + n = item.node ) // We are pooling the trie nodes into an intermediate memory cache db.lock.Lock() db.insert(hash, size, n) db.lock.Unlock() - if c.onleaf != nil && hasVnodes { + + if c.onleaf != nil { switch n := n.(type) { case *shortNode: if child, ok := n.Val.(valueNode); ok { c.onleaf(nil, child, hash) } case *fullNode: - for i := 0; i < 16; i++ { - if child, ok := n.Children[i].(valueNode); ok { - c.onleaf(nil, child, hash) - } + // For children in range [0, 15], it's impossible for them + // to contain a valuenode. Only check the 17th child. + if n.Children[16] != nil { + c.onleaf(nil, n.Children[16].(valueNode), hash) } } } @@ -273,6 +267,5 @@ return 1 + len(n) default: panic(fmt.Sprintf("node type %T", n)) - } } diff --git a/trie/database.go b/trie/database.go index fa8906b7a3d0..7c2f2070979c 100644 --- a/trie/database.go +++ b/trie/database.go @@ -99,6 +99,11 @@ type rawNode []byte func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawNode) EncodeRLP(w io.Writer) error { + _, err := w.Write(n) + return err +} + // rawFullNode represents only the useful data content of a full node, with the // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. @@ -199,7 +204,7 @@ func forGatherChildren(n node, onChild func(hash common.Hash)) { } case hashNode: onChild(common.BytesToHash(n)) - case valueNode, nil: + case valueNode, nil, rawNode: default: panic(fmt.Sprintf("unknown node type: %T", n)) } @@ -267,33 +272,43 @@ func expandNode(hash hashNode, n node) node { } } +// Config defines all necessary options for the database. +type Config struct { + Cache int // Memory allowance (MB) to use for caching trie nodes in memory + Journal string // Journal of clean cache to survive node restarts + Preimages bool // Flag whether the preimage of trie key is recorded +} +
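A quick illustration of the new Config knobs (a hypothetical helper with arbitrary values; passing a nil config keeps the old defaults, with preimage recording still enabled per the TODO below):

```go
package main

import (
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/trie"
)

// newTrieDB opens a trie database with an explicit cache/journal/preimage setup.
func newTrieDB(diskdb ethdb.KeyValueStore) *trie.Database {
	return trie.NewDatabaseWithConfig(diskdb, &trie.Config{
		Cache:     256,                 // MB of memory allowance for clean-node caching
		Journal:   "triecache.journal", // journal file so the cache survives restarts
		Preimages: false,               // disable preimage recording entirely
	})
}
```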
// NewDatabase creates a new trie database to store ephemeral trie content before // it's written out to disk or garbage collected. No read cache is created, so all // data retrievals will hit the underlying disk database. func NewDatabase(diskdb ethdb.KeyValueStore) *Database { - return NewDatabaseWithCache(diskdb, 0, "") + return NewDatabaseWithConfig(diskdb, nil) } -// NewDatabaseWithCache creates a new trie database to store ephemeral trie content +// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content // before it's written out to disk or garbage collected. It also acts as a read cache // for nodes loaded from disk. -func NewDatabaseWithCache(diskdb ethdb.KeyValueStore, cache int, journal string) *Database { +func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database { var cleans *fastcache.Cache - if cache > 0 { - if journal == "" { - cleans = fastcache.New(cache * 1024 * 1024) + if config != nil && config.Cache > 0 { + if config.Journal == "" { + cleans = fastcache.New(config.Cache * 1024 * 1024) } else { - cleans = fastcache.LoadFromFileOrNew(journal, cache*1024*1024) + cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024) } } - return &Database{ + db := &Database{ diskdb: diskdb, cleans: cleans, dirties: map[common.Hash]*cachedNode{{}: { children: make(map[common.Hash]uint16), }}, - preimages: make(map[common.Hash][]byte), } + if config == nil || config.Preimages { // TODO(karalabe): Flip to default off in the future + db.preimages = make(map[common.Hash][]byte) + } + return db } // DiskDB retrieves the persistent storage backing the trie database. @@ -340,6 +355,11 @@ func (db *Database) insert(hash common.Hash, size int, node node) { // // Note, this method assumes that the database's lock is held! func (db *Database) insertPreimage(hash common.Hash, preimage []byte) { + // Short circuit if preimage collection is disabled + if db.preimages == nil { + return + } + // Track the preimage if a yet unknown one if _, ok := db.preimages[hash]; ok { return } @@ -426,6 +446,10 @@ func (db *Database) Node(hash common.Hash) ([]byte, error) { // preimage retrieves a cached trie node pre-image from memory. If it cannot be // found cached, the method queries the persistent database for the content. func (db *Database) preimage(hash common.Hash) []byte { + // Short circuit if preimage collection is disabled + if db.preimages == nil { + return nil + } // Retrieve the node from cache if available db.lock.RLock() preimage := db.preimages[hash] @@ -583,12 +607,16 @@ func (db *Database) Cap(limit common.StorageSize) error { // leave for later to deduplicate writes.
flushPreimages := db.preimagesSize > 4*1024*1024 if flushPreimages { - rawdb.WritePreimages(batch, db.preimages) - if batch.ValueSize() > ethdb.IdealBatchSize { - if err := batch.Write(); err != nil { - return err + if db.preimages == nil { + log.Error("Attempted to write preimages whilst disabled") + } else { + rawdb.WritePreimages(batch, db.preimages) + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + return err + } + batch.Reset() } - batch.Reset() } } // Keep committing nodes from the flush-list until we're below allowance @@ -625,7 +653,11 @@ func (db *Database) Cap(limit common.StorageSize) error { defer db.lock.Unlock() if flushPreimages { - db.preimages, db.preimagesSize = make(map[common.Hash][]byte), 0 + if db.preimages == nil { + log.Error("Attempted to reset preimage cache whilst disabled") + } else { + db.preimages, db.preimagesSize = make(map[common.Hash][]byte), 0 + } } for db.oldest != oldest { node := db.dirties[db.oldest] @@ -669,20 +701,21 @@ func (db *Database) Commit(node common.Hash, report bool, callback func(common.H batch := db.diskdb.NewBatch() // Move all of the accumulated preimages into a write batch - rawdb.WritePreimages(batch, db.preimages) - if batch.ValueSize() > ethdb.IdealBatchSize { + if db.preimages != nil { + rawdb.WritePreimages(batch, db.preimages) + if batch.ValueSize() > ethdb.IdealBatchSize { + if err := batch.Write(); err != nil { + return err + } + batch.Reset() + } + // Since we're going to replay trie node writes into the clean cache, flush out + // any batched pre-images before continuing. if err := batch.Write(); err != nil { return err } batch.Reset() } - // Since we're going to replay trie node writes into the clean cache, flush out - // any batched pre-images before continuing. 
- if err := batch.Write(); err != nil { - return err - } - batch.Reset() - // Move the trie itself into the batch, flushing if enough data is accumulated nodes, storage := len(db.dirties), db.dirtiesSize @@ -704,8 +737,9 @@ func (db *Database) Commit(node common.Hash, report bool, callback func(common.H batch.Reset() // Reset the storage counters and bumped metrics - db.preimages, db.preimagesSize = make(map[common.Hash][]byte), 0 - + if db.preimages != nil { + db.preimages, db.preimagesSize = make(map[common.Hash][]byte), 0 + } memcacheCommitTimeTimer.Update(time.Since(start)) memcacheCommitSizeMeter.Mark(int64(storage - db.dirtiesSize)) memcacheCommitNodesMeter.Mark(int64(nodes - len(db.dirties))) diff --git a/trie/encoding.go b/trie/encoding.go index 1955a3e664f5..8ee0022ef3a0 100644 --- a/trie/encoding.go +++ b/trie/encoding.go @@ -51,6 +51,35 @@ func hexToCompact(hex []byte) []byte { return buf } +// hexToCompactInPlace places the compact key in the input buffer, returning the length +// needed for the representation +func hexToCompactInPlace(hex []byte) int { + var ( + hexLen = len(hex) // length of the hex input + firstByte = byte(0) + ) + // Check if we have a terminator there + if hexLen > 0 && hex[hexLen-1] == 16 { + firstByte = 1 << 5 + hexLen-- // last part was the terminator, ignore that + } + var ( + binLen = hexLen/2 + 1 + ni = 0 // index in hex + bi = 1 // index in bin (compact) + ) + if hexLen&1 == 1 { + firstByte |= 1 << 4 // odd flag + firstByte |= hex[0] // first nibble is contained in the first byte + ni++ + } + for ; ni < hexLen; bi, ni = bi+1, ni+2 { + hex[bi] = hex[ni]<<4 | hex[ni+1] + } + hex[0] = firstByte + return binLen +} + func compactToHex(compact []byte) []byte { if len(compact) == 0 { return compact } diff --git a/trie/encoding_test.go b/trie/encoding_test.go index 97d8da136134..16393313f743 100644 --- a/trie/encoding_test.go +++ b/trie/encoding_test.go @@ -18,6 +18,8 @@ package trie import ( "bytes" + "encoding/hex" + "math/rand" "testing" ) @@ -75,6 +77,40 @@ func TestHexKeybytes(t *testing.T) { } } +func TestHexToCompactInPlace(t *testing.T) { + for i, keyS := range []string{ + "00", + "060a040c0f000a090b040803010801010900080d090a0a0d0903000b10", + "10", + } { + hexBytes, _ := hex.DecodeString(keyS) + exp := hexToCompact(hexBytes) + sz := hexToCompactInPlace(hexBytes) + got := hexBytes[:sz] + if !bytes.Equal(exp, got) { + t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, keyS, got, exp) + } + } +} + +func TestHexToCompactInPlaceRandom(t *testing.T) { + for i := 0; i < 10000; i++ { + l := rand.Intn(128) + key := make([]byte, l) + rand.Read(key) + hexBytes := keybytesToHex(key) + hexOrig := []byte(string(hexBytes)) + exp := hexToCompact(hexBytes) + sz := hexToCompactInPlace(hexBytes) + got := hexBytes[:sz] + + if !bytes.Equal(exp, got) { + t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n", + key, hexOrig, got, exp) + } + } +} + func BenchmarkHexToCompact(b *testing.B) { testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} for i := 0; i < b.N; i++ {
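The in-place variant mirrors hexToCompact but reuses the input buffer instead of allocating, which is exactly what the tests above verify. A minimal usage sketch (hypothetical snippet; it must live in package trie, since the helpers are unexported):

```go
package trie

// compactInPlaceExample converts a key to hex nibbles, compacts the buffer in
// place and slices the result; the bytes match what hexToCompact allocates.
func compactInPlaceExample() []byte {
	hexKey := keybytesToHex([]byte{0x12, 0x34}) // nibbles plus terminator
	size := hexToCompactInPlace(hexKey)         // rewrites hexKey's backing array
	return hexKey[:size]
}
```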
diff --git a/trie/hasher.go b/trie/hasher.go index 57cd3e1f3676..3a62a2f1199c 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -66,11 +66,11 @@ func returnHasherToPool(h *hasher) { // hash collapses a node down into a hash node, also returning a copy of the // original node initialized with the computed hash to replace the original one. func (h *hasher) hash(n node, force bool) (hashed node, cached node) { - // We're not storing the node, just hashing, use available cached data + // Return the cached hash if it's available if hash, _ := n.cache(); hash != nil { return hash, n } - // Trie not processed yet or needs storage, walk the children + // Trie not processed yet, walk the children switch n := n.(type) { case *shortNode: collapsed, cached := h.hashShortNodeChildren(n) diff --git a/trie/iterator.go b/trie/iterator.go index bb4025d8f35b..76d437c40388 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -173,7 +173,7 @@ func (it *nodeIterator) LeafKey() []byte { func (it *nodeIterator) LeafBlob() []byte { if len(it.stack) > 0 { if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { - return []byte(node) + return node } } panic("not at leaf") diff --git a/trie/notary.go b/trie/notary.go new file mode 100644 index 000000000000..5a64727aa72d --- /dev/null +++ b/trie/notary.go @@ -0,0 +1,57 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. + +package trie + +import ( + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/ethdb/memorydb" +) + +// KeyValueNotary tracks which keys have been accessed through a key-value reader +// with the scope of verifying if certain proof datasets are maliciously bloated. +type KeyValueNotary struct { + ethdb.KeyValueReader + reads map[string]struct{} +} + +// NewKeyValueNotary wraps a key-value database with an access notary to track +// which items have been accessed. +func NewKeyValueNotary(db ethdb.KeyValueReader) *KeyValueNotary { + return &KeyValueNotary{ + KeyValueReader: db, + reads: make(map[string]struct{}), + } +} + +// Get retrieves an item from the underlying database, but also tracks it as an +// accessed slot for bloat checks. +func (k *KeyValueNotary) Get(key []byte) ([]byte, error) { + k.reads[string(key)] = struct{}{} + return k.KeyValueReader.Get(key) +} + +// Accessed returns a snapshot of the original key-value store containing only the +// data accessed through the notary. +func (k *KeyValueNotary) Accessed() ethdb.KeyValueStore { + db := memorydb.New() + for keystr := range k.reads { + key := []byte(keystr) + val, _ := k.KeyValueReader.Get(key) + db.Put(key, val) + } + return db +}
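The notary exists so range-proof verification can detect proofs that are maliciously bloated: wrap the untrusted proof database, perform all reads through the wrapper, then compare what was actually touched against what was supplied. A minimal sketch (auditProofReads is a hypothetical helper):

```go
package trie

import "github.com/ethereum/go-ethereum/ethdb"

// auditProofReads records every key read from an untrusted proof database and
// returns the minimal store containing only the data that was accessed.
func auditProofReads(proofDb ethdb.KeyValueReader, key []byte) (ethdb.KeyValueStore, error) {
	notary := NewKeyValueNotary(proofDb)
	if _, err := notary.Get(key); err != nil { // the read is recorded
		return nil, err
	}
	return notary.Accessed(), nil
}
```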
+// this function is recovering a node path from the merkle proof stream. All
+// necessary nodes will be resolved, leaving the remaining as hashnodes.
+//
+// The given edge proof is allowed to be an existent or non-existent proof.
 func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyValueReader, allowNonExistent bool) (node, []byte, error) {
 	// resolveNode retrieves and resolves trie node from merkle proof stream
 	resolveNode := func(hash common.Hash) (node, error) {
@@ -205,54 +206,61 @@ func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyV
 }
 
 // unsetInternal removes all internal node references(hashnode, embedded node).
-// It should be called after a trie is constructed with two edge proofs. Also
-// the given boundary keys must be the one used to construct the edge proofs.
+// It should be called after a trie is constructed with two edge paths. Also
+// the given boundary keys must be the ones used to construct the edge paths.
 //
 // It's the key step for range proof. All visited nodes should be marked dirty
 // since the node content might be modified. Besides it can happen that some
 // fullnodes only have one child which is disallowed. But if the proof is valid,
 // the missing children will be filled, otherwise it will be thrown anyway.
+//
+// Note we assume here that the given boundary keys are different and that
+// right is larger than left.
 func unsetInternal(n node, left []byte, right []byte) error {
 	left, right = keybytesToHex(left), keybytesToHex(right)
 
-	// todo(rjl493456442) different length edge keys should be supported
-	if len(left) != len(right) {
-		return errors.New("inconsistent edge path")
-	}
 	// Step down to the fork point. There are two scenarios that can happen:
-	// - the fork point is a shortnode: the left proof MUST point to a
-	//   non-existent key and the key doesn't match with the shortnode
-	// - the fork point is a fullnode: the left proof can point to an
-	//   existent key or not.
+	// - the fork point is a shortnode: either the key of the left proof or
+	//   the right proof doesn't match the shortnode's key.
+	// - the fork point is a fullnode: both edge proofs are allowed to
+	//   point to a non-existent key.
 	var (
 		pos    = 0
 		parent node
+
+		// fork indicator, 0 means no fork, -1 means proof is less, 1 means proof is greater
+		shortForkLeft, shortForkRight int
 	)
findFork:
 	for {
 		switch rn := (n).(type) {
 		case *shortNode:
-			// The right proof must point to an existent key.
-			if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) {
-				return errors.New("invalid edge path")
-			}
 			rn.flags = nodeFlag{dirty: true}
-			// Special case, the non-existent proof points to the same path
-			// as the existent proof, but the path of existent proof is longer.
-			// In this case, the fork point is this shortnode.
-			if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) {
+
+			// If either the key of the left proof or the right proof doesn't
+			// match the shortnode, stop here: the forkpoint is the shortnode.
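// To make the two-sided comparison below concrete (nibble values invented for
// the example): with rn.Key = [0xa 0xb 0xc], a left path [0xa 0xb] that is
// shorter or smaller yields shortForkLeft = -1, while a right path
// [0xa 0xd ...] yields shortForkRight = 1. Either indicator being non-zero
// stops the descent: this shortnode is the fork point, and -1/1 together mean
// the node lies strictly inside the queried range.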
+			if len(left)-pos < len(rn.Key) {
+				shortForkLeft = bytes.Compare(left[pos:], rn.Key)
+			} else {
+				shortForkLeft = bytes.Compare(left[pos:pos+len(rn.Key)], rn.Key)
+			}
+			if len(right)-pos < len(rn.Key) {
+				shortForkRight = bytes.Compare(right[pos:], rn.Key)
+			} else {
+				shortForkRight = bytes.Compare(right[pos:pos+len(rn.Key)], rn.Key)
+			}
+			if shortForkLeft != 0 || shortForkRight != 0 {
 				break findFork
 			}
 			parent = n
 			n, pos = rn.Val, pos+len(rn.Key)
 		case *fullNode:
-			leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]]
-			// The right proof must point to an existent key.
-			if rightnode == nil {
-				return errors.New("invalid edge path")
-			}
 			rn.flags = nodeFlag{dirty: true}
-			if leftnode != rightnode {
+
+			// If either the node pointed to by the left proof or the right proof
+			// is nil, stop here: the forkpoint is the fullnode.
+			leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]]
+			if leftnode == nil || rightnode == nil || leftnode != rightnode {
 				break findFork
 			}
 			parent = n
@@ -263,12 +271,42 @@ findFork:
 	}
 	switch rn := n.(type) {
 	case *shortNode:
-		if _, ok := rn.Val.(valueNode); ok {
-			parent.(*fullNode).Children[right[pos-1]] = nil
+		// There are five possible scenarios:
+		// - both proofs are less than the trie path => no valid range
+		// - both proofs are greater than the trie path => no valid range
+		// - left proof is less and right proof is greater => valid range, unset the shortnode entirely
+		// - left proof points to the shortnode, but right proof is greater
+		// - right proof points to the shortnode, but left proof is less
+		if shortForkLeft == -1 && shortForkRight == -1 {
+			return errors.New("empty range")
+		}
+		if shortForkLeft == 1 && shortForkRight == 1 {
+			return errors.New("empty range")
+		}
+		if shortForkLeft != 0 && shortForkRight != 0 {
+			parent.(*fullNode).Children[left[pos-1]] = nil
 			return nil
 		}
-		return unset(rn, rn.Val, right[pos:], len(rn.Key), true)
+		// Only one proof points to a non-existent key.
+		if shortForkRight != 0 {
+			// Unset the left proof's path
+			if _, ok := rn.Val.(valueNode); ok {
+				parent.(*fullNode).Children[left[pos-1]] = nil
+				return nil
+			}
+			return unset(rn, rn.Val, left[pos:], len(rn.Key), false)
+		}
+		if shortForkLeft != 0 {
+			// Unset the right proof's path.
+			if _, ok := rn.Val.(valueNode); ok {
+				parent.(*fullNode).Children[right[pos-1]] = nil
+				return nil
+			}
+			return unset(rn, rn.Val, right[pos:], len(rn.Key), true)
+		}
+		return nil
 	case *fullNode:
+		// unset all internal nodes within the forkpoint
 		for i := left[pos] + 1; i < right[pos]; i++ {
 			rn.Children[i] = nil
 		}
@@ -285,19 +323,17 @@ findFork:
 }
 
 // unset removes all internal node references on either the leftmost or rightmost side.
-// If we try to unset all right most references, it can meet these scenarios:
+// It can meet these scenarios:
 //
-// - The given path is existent in the trie, unset the associated shortnode
+// - The given path is existent in the trie, unset the associated nodes with the
+//   specific direction
 // - The given path is non-existent in the trie
 //   - the fork point is a fullnode, the corresponding child pointed by path
 //     is nil, return
-//   - the fork point is a shortnode, the key of shortnode is less than path,
+//   - the fork point is a shortnode, the shortnode is included in the range,
 //     keep the entire branch and return.
-//   - the fork point is a shortnode, the key of shortnode is greater than path,
+//   - the fork point is a shortnode, the shortnode is excluded from the range,
 //     unset the entire branch.
-//
-// If we try to unset all left most references, then the given path should
-// be existent.
 func unset(parent node, child node, key []byte, pos int, removeLeft bool) error {
 	switch cld := child.(type) {
 	case *fullNode:
@@ -317,18 +353,29 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error
 		if len(key[pos:]) < len(cld.Key) || !bytes.Equal(cld.Key, key[pos:pos+len(cld.Key)]) {
 			// Find the fork point, it's a non-existent branch.
 			if removeLeft {
-				return errors.New("invalid right edge proof")
-			}
-			if bytes.Compare(cld.Key, key[pos:]) > 0 {
-				// The key of fork shortnode is greater than the
-				// path(it belongs to the range), unset the entrie
-				// branch. The parent must be a fullnode.
-				fn := parent.(*fullNode)
-				fn.Children[key[pos-1]] = nil
+				if bytes.Compare(cld.Key, key[pos:]) < 0 {
+					// The key of fork shortnode is less than the path
+					// (it belongs to the range), unset the entire
+					// branch. The parent must be a fullnode.
+					fn := parent.(*fullNode)
+					fn.Children[key[pos-1]] = nil
+				} else {
+					// The key of fork shortnode is greater than the
+					// path (it doesn't belong to the range), keep
+					// it with the cached hash available.
+				}
 			} else {
-				// The key of fork shortnode is less than the
-				// path(it doesn't belong to the range), keep
-				// it with the cached hash available.
+				if bytes.Compare(cld.Key, key[pos:]) > 0 {
+					// The key of fork shortnode is greater than the
+					// path (it belongs to the range), unset the entire
+					// branch. The parent must be a fullnode.
+					fn := parent.(*fullNode)
+					fn.Children[key[pos-1]] = nil
+				} else {
+					// The key of fork shortnode is less than the
+					// path (it doesn't belong to the range), keep
+					// it with the cached hash available.
+				}
 			}
 			return nil
 		}
@@ -340,11 +387,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error
 		cld.flags = nodeFlag{dirty: true}
 		return unset(cld, cld.Val, key, pos+len(cld.Key), removeLeft)
 	case nil:
-		// If the node is nil, it's a child of the fork point
-		// fullnode(it's an non-existent branch).
-		if removeLeft {
-			return errors.New("invalid right edge proof")
-		}
+		// If the node is nil, then it's a child of the fork point
+		// fullnode (it's a non-existent branch).
 		return nil
 	default:
 		panic("it shouldn't happen") // hashNode, valueNode
@@ -380,98 +424,166 @@ func hasRightElement(node node, key []byte) bool {
 	return false
 }
 
-// VerifyRangeProof checks whether the given leaf nodes and edge proofs
-// can prove the given trie leaves range is matched with given root hash
-// and the range is consecutive(no gap inside) and monotonic increasing.
+// VerifyRangeProof checks whether the given leaf nodes and edge proof
+// can prove that the given trie leaf range matches the specific root.
+// Besides, the range should be consecutive (no gap inside) and monotonically
+// increasing.
 //
-// Note the given first edge proof can be non-existing proof. For example
-// the first proof is for an non-existent values 0x03. The given batch
-// leaves are [0x04, 0x05, .. 0x09]. It's still feasible to prove. But the
-// last edge proof should always be an existent proof.
+// Note the given proof actually contains two edge proofs. Both of them can
+// be non-existent proofs. For example, the first proof is for a non-existent
+// key 0x03, the last proof is for a non-existent key 0x10. The given batch
+// leaves are [0x04, 0x05, .. 0x09]. It's still feasible to prove the given
+// batch is valid.
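// As a concrete caller-side sketch of the example above (keys shortened to a
// single byte for readability, and keys/vals assumed to hold the leaves
// 0x04..0x09 and their values):
//
//	proof := memorydb.New()
//	tr.Prove([]byte{0x03}, 0, proof) // left edge, key may be absent from the trie
//	tr.Prove([]byte{0x10}, 0, proof) // right edge, key may be absent as well
//	_, _, notary, more, err := VerifyRangeProof(tr.Hash(), []byte{0x03}, []byte{0x10}, keys, vals, proof)
//
// The returned notary records which proof entries were actually read, letting
// the caller detect maliciously bloated proofs (see TestBloatedProof below).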
 //
 // The firstKey is paired with firstProof, not necessarily the same as keys[0]
-// (unless firstProof is an existent proof).
+// (unless firstProof is an existent proof). Similarly, lastKey and lastProof
+// are paired.
 //
 // Except for the normal case, this function can also be used to verify the following
-// range proofs(note this function doesn't accept zero element proof):
+// range proofs:
 //
-// - All elements proof. In this case the left and right proof can be nil, but the
-//   range should be all the leaves in the trie.
+// - All elements proof. In this case the proof can be nil, but the range should
+//   be all the leaves in the trie.
 //
-// - One element proof. In this case no matter the left edge proof is a non-existent
+// - One element proof. In this case no matter the edge proof is a non-existent
 //   proof or not, we can always verify the correctness of the proof.
 //
+// - Zero element proof. In this case a single non-existent proof is enough to prove it.
+//   Moreover, if there are still some other leaves available on the right side, an
+//   error will be returned.
+//
 // Besides returning the error to indicate whether the proof is valid or not, the function will
 // also return a flag to indicate whether there exist more accounts/slots in the trie.
-func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, values [][]byte, firstProof ethdb.KeyValueReader, lastProof ethdb.KeyValueReader) (error, bool) {
+func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, *Trie, *KeyValueNotary, bool, error) {
 	if len(keys) != len(values) {
-		return fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)), false
-	}
-	if len(keys) == 0 {
-		return errors.New("empty proof"), false
+		return nil, nil, nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
 	}
 	// Ensure the received batch is monotonically increasing.
 	for i := 0; i < len(keys)-1; i++ {
 		if bytes.Compare(keys[i], keys[i+1]) >= 0 {
-			return errors.New("range is not monotonically increasing"), false
+			return nil, nil, nil, false, errors.New("range is not monotonically increasing")
 		}
 	}
+	// Create a key-value notary to track which items from the given proof the
+	// range prover actually needed to verify the data
+	notary := NewKeyValueNotary(proof)
+
 	// Special case, there is no edge proof at all. The given range is expected
 	// to be the whole leaf-set in the trie.
-	if firstProof == nil && lastProof == nil {
-		emptytrie, err := New(common.Hash{}, NewDatabase(memorydb.New()))
+	if proof == nil {
+		var (
+			diskdb = memorydb.New()
+			triedb = NewDatabase(diskdb)
+		)
+		tr, err := New(common.Hash{}, triedb)
 		if err != nil {
-			return err, false
+			return nil, nil, nil, false, err
 		}
 		for index, key := range keys {
-			emptytrie.TryUpdate(key, values[index])
+			tr.TryUpdate(key, values[index])
+		}
+		if tr.Hash() != rootHash {
+			return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
+		}
+		// Proof seems valid, serialize all the nodes into the database
+		if _, err := tr.Commit(nil); err != nil {
+			return nil, nil, nil, false, err
+		}
+		if err := triedb.Commit(rootHash, false, nil); err != nil {
+			return nil, nil, nil, false, err
+		}
+		return diskdb, tr, notary, false, nil // No more elements
+	}
+	// Special case, there is a provided edge proof but zero key/value
+	// pairs, ensure there are no more accounts / slots in the trie.
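// A sketch of that zero-element case from the prover's side: prove any key at
// or beyond the last one in the trie, send no leaves at all, and verification
// succeeds only if nothing exists to the right of it (this mirrors
// TestEmptyRangeProof further down; beyondLastKey is a hypothetical name):
//
//	proof := memorydb.New()
//	tr.Prove(beyondLastKey, 0, proof)
//	_, _, _, _, err := VerifyRangeProof(tr.Hash(), beyondLastKey, nil, nil, nil, proof)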
+ if len(keys) == 0 { + root, val, err := proofToPath(rootHash, nil, firstKey, notary, true) + if err != nil { + return nil, nil, nil, false, err + } + if val != nil || hasRightElement(root, firstKey) { + return nil, nil, nil, false, errors.New("more entries available") } - if emptytrie.Hash() != rootHash { - return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, emptytrie.Hash()), false + // Since the entire proof is a single path, we can construct a trie and a + // node database directly out of the inputs, no need to generate them + diskdb := notary.Accessed() + tr := &Trie{ + db: NewDatabase(diskdb), + root: root, } - return nil, false // no more element. + return diskdb, tr, notary, hasRightElement(root, firstKey), nil } - // Special case, there is only one element and left edge - // proof is an existent one. - if len(keys) == 1 && bytes.Equal(keys[0], firstKey) { - root, val, err := proofToPath(rootHash, nil, firstKey, firstProof, false) + // Special case, there is only one element and two edge keys are same. + // In this case, we can't construct two edge paths. So handle it here. + if len(keys) == 1 && bytes.Equal(firstKey, lastKey) { + root, val, err := proofToPath(rootHash, nil, firstKey, notary, false) if err != nil { - return err, false + return nil, nil, nil, false, err + } + if !bytes.Equal(firstKey, keys[0]) { + return nil, nil, nil, false, errors.New("correct proof but invalid key") } if !bytes.Equal(val, values[0]) { - return fmt.Errorf("correct proof but invalid data"), false + return nil, nil, nil, false, errors.New("correct proof but invalid data") + } + // Since the entire proof is a single path, we can construct a trie and a + // node database directly out of the inputs, no need to generate them + diskdb := notary.Accessed() + tr := &Trie{ + db: NewDatabase(diskdb), + root: root, } - return nil, hasRightElement(root, keys[0]) + return diskdb, tr, notary, hasRightElement(root, firstKey), nil + } + // Ok, in all other cases, we require two edge paths available. + // First check the validity of edge keys. + if bytes.Compare(firstKey, lastKey) >= 0 { + return nil, nil, nil, false, errors.New("invalid edge keys") + } + // todo(rjl493456442) different length edge keys should be supported + if len(firstKey) != len(lastKey) { + return nil, nil, nil, false, errors.New("inconsistent edge keys") } // Convert the edge proofs to edge trie paths. Then we can // have the same tree architecture with the original one. // For the first edge proof, non-existent proof is allowed. - root, _, err := proofToPath(rootHash, nil, firstKey, firstProof, true) + root, _, err := proofToPath(rootHash, nil, firstKey, notary, true) if err != nil { - return err, false + return nil, nil, nil, false, err } // Pass the root node here, the second path will be merged // with the first one. For the last edge proof, non-existent - // proof is not allowed. - root, _, err = proofToPath(rootHash, root, keys[len(keys)-1], lastProof, false) + // proof is also allowed. + root, _, err = proofToPath(rootHash, root, lastKey, notary, true) if err != nil { - return err, false + return nil, nil, nil, false, err } // Remove all internal references. All the removed parts should // be re-filled(or re-constructed) by the given leaves range. 
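// For orientation, the shape of the remaining verification:
//  1. the two proofToPath calls above materialized both edge paths into one
//     partial trie;
//  2. unsetInternal below strips every internal reference between the two
//     paths, marking the affected nodes dirty;
//  3. the supplied leaves are re-inserted and the recomputed root must equal
//     rootHash -- any gap, extra entry or reordering changes the root.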
- if err := unsetInternal(root, firstKey, keys[len(keys)-1]); err != nil { - return err, false + if err := unsetInternal(root, firstKey, lastKey); err != nil { + return nil, nil, nil, false, err } - // Rebuild the trie with the leave stream, the shape of trie + // Rebuild the trie with the leaf stream, the shape of trie // should be same with the original one. - newtrie := &Trie{root: root, db: NewDatabase(memorydb.New())} + var ( + diskdb = memorydb.New() + triedb = NewDatabase(diskdb) + ) + tr := &Trie{root: root, db: triedb} for index, key := range keys { - newtrie.TryUpdate(key, values[index]) + tr.TryUpdate(key, values[index]) + } + if tr.Hash() != rootHash { + return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) + } + // Proof seems valid, serialize all the nodes into the database + if _, err := tr.Commit(nil); err != nil { + return nil, nil, nil, false, err } - if newtrie.Hash() != rootHash { - return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, newtrie.Hash()), false + if err := triedb.Commit(rootHash, false, nil); err != nil { + return nil, nil, nil, false, err } - return nil, hasRightElement(root, keys[len(keys)-1]) + return diskdb, tr, notary, hasRightElement(root, keys[len(keys)-1]), nil } // get returns the child of the given node. Return nil if the diff --git a/trie/proof_test.go b/trie/proof_test.go index 55585e4daf46..3ecd318886f7 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -19,6 +19,7 @@ package trie import ( "bytes" crand "crypto/rand" + "encoding/binary" mrand "math/rand" "sort" "testing" @@ -166,15 +167,13 @@ func TestRangeProof(t *testing.T) { sort.Sort(entries) for i := 0; i < 500; i++ { start := mrand.Intn(len(entries)) - end := mrand.Intn(len(entries)-start) + start - if start == end { - continue - } - firstProof, lastProof := memorydb.New(), memorydb.New() - if err := trie.Prove(entries[start].k, 0, firstProof); err != nil { + end := mrand.Intn(len(entries)-start) + start + 1 + + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil { + if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } var keys [][]byte @@ -183,15 +182,15 @@ func TestRangeProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } - err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, firstProof, lastProof) + _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) if err != nil { t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) } } } -// TestRangeProof tests normal range proof with the first edge proof -// as the non-existent proof. The test cases are generated randomly. +// TestRangeProof tests normal range proof with two non-existent proofs. +// The test cases are generated randomly. 
 func TestRangeProofWithNonExistentProof(t *testing.T) {
 	trie, vals := randomTrie(4096)
 	var entries entrySlice
@@ -201,20 +200,31 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
 	sort.Sort(entries)
 	for i := 0; i < 500; i++ {
 		start := mrand.Intn(len(entries))
-		end := mrand.Intn(len(entries)-start) + start
-		if start == end {
-			continue
-		}
-		firstProof, lastProof := memorydb.New(), memorydb.New()
+		end := mrand.Intn(len(entries)-start) + start + 1
+		proof := memorydb.New()
+
+		// Short circuit if the decreased key is the same as the previous key
 		first := decreseKey(common.CopyBytes(entries[start].k))
 		if start != 0 && bytes.Equal(first, entries[start-1].k) {
 			continue
 		}
-		if err := trie.Prove(first, 0, firstProof); err != nil {
+		// Short circuit if the decreased key underflows
+		if bytes.Compare(first, entries[start].k) > 0 {
+			continue
+		}
+		// Short circuit if the increased key is the same as the next key
+		last := increseKey(common.CopyBytes(entries[end-1].k))
+		if end != len(entries) && bytes.Equal(last, entries[end].k) {
+			continue
+		}
+		// Short circuit if the increased key overflows
+		if bytes.Compare(last, entries[end-1].k) < 0 {
+			continue
+		}
+		if err := trie.Prove(first, 0, proof); err != nil {
 			t.Fatalf("Failed to prove the first node %v", err)
 		}
-		if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil {
+		if err := trie.Prove(last, 0, proof); err != nil {
 			t.Fatalf("Failed to prove the last node %v", err)
 		}
 		var keys [][]byte
@@ -223,16 +233,36 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
 			keys = append(keys, entries[i].k)
 			vals = append(vals, entries[i].v)
 		}
-		err, _ := VerifyRangeProof(trie.Hash(), first, keys, vals, firstProof, lastProof)
+		_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
 		if err != nil {
 			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
 	}
+	// Special case, two edge proofs for two edge keys.
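// The all-zero and all-0xff keys used below act as universal edge keys:
// 0x00..00 sorts at or before every 32-byte key and 0xff..ff at or after, so
// proofs for them bracket the entire key space without needing to exist in
// the trie.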
+	proof := memorydb.New()
+	first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
+	last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes()
+	if err := trie.Prove(first, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(last, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	var k [][]byte
+	var v [][]byte
+	for i := 0; i < len(entries); i++ {
+		k = append(k, entries[i].k)
+		v = append(v, entries[i].v)
+	}
+	_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+	if err != nil {
+		t.Fatal("Failed to verify whole range with non-existent edges")
+	}
 }
 
 // TestRangeProofWithInvalidNonExistentProof tests such scenarios:
-// - The last edge proof is an non-existent proof
 // - There exists a gap between the first element and the left edge proof
+// - There exists a gap between the last element and the right edge proof
 func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
 	trie, vals := randomTrie(4096)
 	var entries entrySlice
@@ -243,44 +273,45 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
 
 	// Case 1
 	start, end := 100, 200
-	first, last := decreseKey(common.CopyBytes(entries[start].k)), increseKey(common.CopyBytes(entries[end].k))
-	firstProof, lastProof := memorydb.New(), memorydb.New()
-	if err := trie.Prove(first, 0, firstProof); err != nil {
+	first := decreseKey(common.CopyBytes(entries[start].k))
+
+	proof := memorydb.New()
+	if err := trie.Prove(first, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the first node %v", err)
 	}
-	if err := trie.Prove(last, 0, lastProof); err != nil {
+	if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	var k [][]byte
-	var v [][]byte
+	start = 105 // Gap created
+	k := make([][]byte, 0)
+	v := make([][]byte, 0)
 	for i := start; i < end; i++ {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	err, _ := VerifyRangeProof(trie.Hash(), first, k, v, firstProof, lastProof)
+	_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
 	if err == nil {
 		t.Fatalf("Expected to detect the error, got nil")
 	}
 
 	// Case 2
 	start, end = 100, 200
-	first = decreseKey(common.CopyBytes(entries[start].k))
-
-	firstProof, lastProof = memorydb.New(), memorydb.New()
-	if err := trie.Prove(first, 0, firstProof); err != nil {
+	last := increseKey(common.CopyBytes(entries[end-1].k))
+	proof = memorydb.New()
+	if err := trie.Prove(entries[start].k, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the first node %v", err)
 	}
-	if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil {
+	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	start = 105 // Gap created
+	end = 195 // Capped slice
 	k = make([][]byte, 0)
 	v = make([][]byte, 0)
 	for i := start; i < end; i++ {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	err, _ = VerifyRangeProof(trie.Hash(), first, k, v, firstProof, lastProof)
+	_, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
 	if err == nil {
 		t.Fatalf("Expected to detect the error, got nil")
 	}
@@ -297,31 +328,59 @@ func TestOneElementRangeProof(t *testing.T) {
 	}
 	sort.Sort(entries)
 
-	// One element with existent edge proof
+	// One element with existent edge proof, both edge proofs
+	// point to the SAME key.
start := 1000 - firstProof, lastProof := memorydb.New(), memorydb.New() - if err := trie.Prove(entries[start].k, 0, firstProof); err != nil { + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - if err := trie.Prove(entries[start].k, 0, lastProof); err != nil { + _, _, _, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // One element with left non-existent edge proof + start = 1000 + first := decreseKey(common.CopyBytes(entries[start].k)) + proof = memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - err, _ := VerifyRangeProof(trie.Hash(), entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, firstProof, lastProof) + _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } - // One element with non-existent edge proof + // One element with right non-existent edge proof start = 1000 - first := decreseKey(common.CopyBytes(entries[start].k)) - firstProof, lastProof = memorydb.New(), memorydb.New() - if err := trie.Prove(first, 0, firstProof); err != nil { + last := increseKey(common.CopyBytes(entries[start].k)) + proof = memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, _, _, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // One element with two non-existent edge proofs + start = 1000 + first, last = decreseKey(common.CopyBytes(entries[start].k)), increseKey(common.CopyBytes(entries[start].k)) + proof = memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - if err := trie.Prove(entries[start].k, 0, lastProof); err != nil { + if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - err, _ = VerifyRangeProof(trie.Hash(), first, [][]byte{entries[start].k}, [][]byte{entries[start].v}, firstProof, lastProof) + _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -343,20 +402,35 @@ func TestAllElementsProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - err, _ := VerifyRangeProof(trie.Hash(), k[0], k, v, nil, nil) + _, _, _, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) if err != nil { t.Fatalf("Expected no error, got %v", err) } - // Even with edge proofs, it should still work. - firstProof, lastProof := memorydb.New(), memorydb.New() - if err := trie.Prove(entries[0].k, 0, firstProof); err != nil { + // With edge proofs, it should still work. 
+ proof := memorydb.New() + if err := trie.Prove(entries[0].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - if err := trie.Prove(entries[len(entries)-1].k, 0, lastProof); err != nil { + if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - err, _ = VerifyRangeProof(trie.Hash(), k[0], k, v, firstProof, lastProof) + _, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Even with non-existent edge proofs, it should still work. + proof = memorydb.New() + first := common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() + last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -376,11 +450,11 @@ func TestSingleSideRangeProof(t *testing.T) { var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} for _, pos := range cases { - firstProof, lastProof := memorydb.New(), memorydb.New() - if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil { + proof := memorydb.New() + if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil { + if err := trie.Prove(entries[pos].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } k := make([][]byte, 0) @@ -389,7 +463,43 @@ func TestSingleSideRangeProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - err, _ := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof) + _, _, _, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + } + } +} + +// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff. 
+func TestReverseSingleSideRangeProof(t *testing.T) { + for i := 0; i < 64; i++ { + trie := new(Trie) + var entries entrySlice + for i := 0; i < 4096; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + entries = append(entries, value) + } + sort.Sort(entries) + + var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} + for _, pos := range cases { + proof := memorydb.New() + if err := trie.Prove(entries[pos].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + if err := trie.Prove(last.Bytes(), 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := pos; i < len(entries); i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, _, _, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -409,15 +519,12 @@ func TestBadRangeProof(t *testing.T) { for i := 0; i < 500; i++ { start := mrand.Intn(len(entries)) - end := mrand.Intn(len(entries)-start) + start - if start == end { - continue - } - firstProof, lastProof := memorydb.New(), memorydb.New() - if err := trie.Prove(entries[start].k, 0, firstProof); err != nil { + end := mrand.Intn(len(entries)-start) + start + 1 + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil { + if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } var keys [][]byte @@ -426,6 +533,7 @@ func TestBadRangeProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } + var first, last = keys[0], keys[len(keys)-1] testcase := mrand.Intn(6) var index int switch testcase { @@ -439,38 +547,31 @@ func TestBadRangeProof(t *testing.T) { vals[index] = randBytes(20) // In theory it can't be same case 2: // Gapped entry slice - - // There are only two elements, skip it. Dropped any element - // will lead to single edge proof which is always correct. - if end-start <= 2 { - continue - } - // If the dropped element is the first or last one and it's a - // batch of small size elements. In this special case, it can - // happen that the proof for the edge element is exactly same - // with the first/last second element(since small values are - // embedded in the parent). Avoid this case. index = mrand.Intn(end - start) - if (index == end-start-1 || index == 0) && end <= 100 { + if (index == 0 && start < 100) || (index == end-start-1 && end <= 100) { continue } keys = append(keys[:index], keys[index+1:]...) vals = append(vals[:index], vals[index+1:]...) 
case 3: - // Switched entry slice, same effect with gapped - index = mrand.Intn(end - start) - keys[index] = entries[len(entries)-1].k - vals[index] = entries[len(entries)-1].v + // Out of order + index1 := mrand.Intn(end - start) + index2 := mrand.Intn(end - start) + if index1 == index2 { + continue + } + keys[index1], keys[index2] = keys[index2], keys[index1] + vals[index1], vals[index2] = vals[index2], vals[index1] case 4: - // Set random key to nil + // Set random key to nil, do nothing index = mrand.Intn(end - start) keys[index] = nil case 5: - // Set random value to nil + // Set random value to nil, deletion index = mrand.Intn(end - start) vals[index] = nil } - err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, firstProof, lastProof) + _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) if err == nil { t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1) } @@ -488,11 +589,11 @@ func TestGappedRangeProof(t *testing.T) { entries = append(entries, value) } first, last := 2, 8 - firstProof, lastProof := memorydb.New(), memorydb.New() - if err := trie.Prove(entries[first].k, 0, firstProof); err != nil { + proof := memorydb.New() + if err := trie.Prove(entries[first].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - if err := trie.Prove(entries[last-1].k, 0, lastProof); err != nil { + if err := trie.Prove(entries[last-1].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } var keys [][]byte @@ -504,12 +605,55 @@ func TestGappedRangeProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } - err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, firstProof, lastProof) + _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) if err == nil { t.Fatal("expect error, got nil") } } +// TestSameSideProofs tests the element is not in the range covered by proofs +func TestSameSideProofs(t *testing.T) { + trie, vals := randomTrie(4096) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, kv) + } + sort.Sort(entries) + + pos := 1000 + first := decreseKey(common.CopyBytes(entries[pos].k)) + first = decreseKey(first) + last := decreseKey(common.CopyBytes(entries[pos].k)) + + proof := memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) + if err == nil { + t.Fatalf("Expected error, got nil") + } + + first = increseKey(common.CopyBytes(entries[pos].k)) + last = increseKey(common.CopyBytes(entries[pos].k)) + last = increseKey(last) + + proof = memorydb.New() + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) + if err == nil { + t.Fatalf("Expected error, got nil") + } +} + func TestHasRightElement(t *testing.T) { trie := new(Trie) var entries entrySlice @@ -530,38 +674,49 @@ func TestHasRightElement(t *testing.T) { {0, 10, true}, {50, 100, true}, {50, len(entries), 
false}, // No more element expected - {len(entries) - 1, len(entries), false}, // Single last element + {len(entries) - 1, len(entries), false}, // Single last element with two existent proofs(point to same key) + {len(entries) - 1, -1, false}, // Single last element with non-existent right proof {0, len(entries), false}, // The whole set with existent left proof {-1, len(entries), false}, // The whole set with non-existent left proof + {-1, -1, false}, // The whole set with non-existent left/right proof } for _, c := range cases { var ( - firstKey []byte - start = c.start - firstProof = memorydb.New() - lastProof = memorydb.New() + firstKey []byte + lastKey []byte + start = c.start + end = c.end + proof = memorydb.New() ) if c.start == -1 { firstKey, start = common.Hash{}.Bytes(), 0 - if err := trie.Prove(firstKey, 0, firstProof); err != nil { + if err := trie.Prove(firstKey, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } } else { firstKey = entries[c.start].k - if err := trie.Prove(entries[c.start].k, 0, firstProof); err != nil { + if err := trie.Prove(entries[c.start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } } - if err := trie.Prove(entries[c.end-1].k, 0, lastProof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) + if c.end == -1 { + lastKey, end = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes(), len(entries) + if err := trie.Prove(lastKey, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + } else { + lastKey = entries[c.end-1].k + if err := trie.Prove(entries[c.end-1].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } } k := make([][]byte, 0) v := make([][]byte, 0) - for i := start; i < c.end; i++ { + for i := start; i < end; i++ { k = append(k, entries[i].k) v = append(v, entries[i].v) } - err, hasMore := VerifyRangeProof(trie.Hash(), firstKey, k, v, firstProof, lastProof) + _, _, _, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -571,6 +726,83 @@ func TestHasRightElement(t *testing.T) { } } +// TestEmptyRangeProof tests the range proof with "no" element. +// The first edge proof must be a non-existent proof. 
+func TestEmptyRangeProof(t *testing.T) { + trie, vals := randomTrie(4096) + var entries entrySlice + for _, kv := range vals { + entries = append(entries, kv) + } + sort.Sort(entries) + + var cases = []struct { + pos int + err bool + }{ + {len(entries) - 1, false}, + {500, true}, + } + for _, c := range cases { + proof := memorydb.New() + first := increseKey(common.CopyBytes(entries[c.pos].k)) + if err := trie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + db, tr, not, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) + if c.err && err == nil { + t.Fatalf("Expected error, got nil") + } + if !c.err && err != nil { + t.Fatalf("Expected no error, got %v", err) + } + // If no error was returned, ensure the returned trie and database contains + // the entire proof, since there's no value + if !c.err { + if err := tr.Prove(first, 0, memorydb.New()); err != nil { + t.Errorf("returned trie doesn't contain original proof: %v", err) + } + if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() { + t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len()) + } + if not == nil { + t.Errorf("missing notary") + } + } + } +} + +// TestBloatedProof tests a malicious proof, where the proof is more or less the +// whole trie. +func TestBloatedProof(t *testing.T) { + // Use a small trie + trie, kvs := nonRandomTrie(100) + var entries entrySlice + for _, kv := range kvs { + entries = append(entries, kv) + } + sort.Sort(entries) + var keys [][]byte + var vals [][]byte + + proof := memorydb.New() + for i, entry := range entries { + trie.Prove(entry.k, 0, proof) + if i == 50 { + keys = append(keys, entry.k) + vals = append(vals, entry.v) + } + } + want := memorydb.New() + trie.Prove(keys[0], 0, want) + trie.Prove(keys[len(keys)-1], 0, want) + + _, _, notary, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) + if used := notary.Accessed().(*memorydb.Database); used.Len() != want.Len() { + t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len()) + } +} + // mutateByte changes one byte in b. 
func mutateByte(b []byte) { for r := mrand.Intn(len(b)); ; { @@ -655,11 +887,11 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) { start := 2 end := start + size - firstProof, lastProof := memorydb.New(), memorydb.New() - if err := trie.Prove(entries[start].k, 0, firstProof); err != nil { + proof := memorydb.New() + if err := trie.Prove(entries[start].k, 0, proof); err != nil { b.Fatalf("Failed to prove the first node %v", err) } - if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil { + if err := trie.Prove(entries[end-1].k, 0, proof); err != nil { b.Fatalf("Failed to prove the last node %v", err) } var keys [][]byte @@ -671,7 +903,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) { b.ResetTimer() for i := 0; i < b.N; i++ { - err, _ := VerifyRangeProof(trie.Hash(), keys[0], keys, values, firstProof, lastProof) + _, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) if err != nil { b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) } @@ -702,3 +934,20 @@ func randBytes(n int) []byte { crand.Read(r) return r } + +func nonRandomTrie(n int) (*Trie, map[string]*kv) { + trie := new(Trie) + vals := make(map[string]*kv) + max := uint64(0xffffffffffffffff) + for i := uint64(0); i < uint64(n); i++ { + value := make([]byte, 32) + key := make([]byte, 32) + binary.LittleEndian.PutUint64(key, i) + binary.LittleEndian.PutUint64(value, i-max) + //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + elem := &kv{key, value, false} + trie.Update(elem.k, elem.v) + vals[string(elem.k)] = elem + } + return trie, vals +} diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 87b364fb1bba..e38471c1b76f 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -147,12 +147,13 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) { // Write all the pre-images to the actual disk database if len(t.getSecKeyCache()) > 0 { - t.trie.db.lock.Lock() - for hk, key := range t.secKeyCache { - t.trie.db.insertPreimage(common.BytesToHash([]byte(hk)), key) + if t.trie.db.preimages != nil { // Ugly direct check but avoids the below write lock + t.trie.db.lock.Lock() + for hk, key := range t.secKeyCache { + t.trie.db.insertPreimage(common.BytesToHash([]byte(hk)), key) + } + t.trie.db.lock.Unlock() } - t.trie.db.lock.Unlock() - t.secKeyCache = make(map[string][]byte) } // Commit the trie to its intermediate node database diff --git a/trie/stacktrie.go b/trie/stacktrie.go new file mode 100644 index 000000000000..575a04022f3a --- /dev/null +++ b/trie/stacktrie.go @@ -0,0 +1,426 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package trie + +import ( + "errors" + "fmt" + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" +) + +var ErrCommitDisabled = errors.New("no database for committing") + +var stPool = sync.Pool{ + New: func() interface{} { + return NewStackTrie(nil) + }, +} + +func stackTrieFromPool(db ethdb.KeyValueStore) *StackTrie { + st := stPool.Get().(*StackTrie) + st.db = db + return st +} + +func returnToPool(st *StackTrie) { + st.Reset() + stPool.Put(st) +} + +// StackTrie is a trie implementation that expects keys to be inserted +// in order. Once it determines that a subtree will no longer be inserted +// into, it will hash it and free up the memory it uses. +type StackTrie struct { + nodeType uint8 // node type (as in branch, ext, leaf) + val []byte // value contained by this node if it's a leaf + key []byte // key chunk covered by this (full|ext) node + keyOffset int // offset of the key chunk inside a full key + children [16]*StackTrie // list of children (for fullnodes and exts) + + db ethdb.KeyValueStore // Pointer to the commit db, can be nil +} + +// NewStackTrie allocates and initializes an empty trie. +func NewStackTrie(db ethdb.KeyValueStore) *StackTrie { + return &StackTrie{ + nodeType: emptyNode, + db: db, + } +} + +func newLeaf(ko int, key, val []byte, db ethdb.KeyValueStore) *StackTrie { + st := stackTrieFromPool(db) + st.nodeType = leafNode + st.keyOffset = ko + st.key = append(st.key, key[ko:]...) + st.val = val + return st +} + +func newExt(ko int, key []byte, child *StackTrie, db ethdb.KeyValueStore) *StackTrie { + st := stackTrieFromPool(db) + st.nodeType = extNode + st.keyOffset = ko + st.key = append(st.key, key[ko:]...) + st.children[0] = child + return st +} + +// List all values that StackTrie#nodeType can hold +const ( + emptyNode = iota + branchNode + extNode + leafNode + hashedNode +) + +// TryUpdate inserts a (key, value) pair into the stack trie +func (st *StackTrie) TryUpdate(key, value []byte) error { + k := keybytesToHex(key) + if len(value) == 0 { + panic("deletion not supported") + } + st.insert(k[:len(k)-1], value) + return nil +} + +func (st *StackTrie) Update(key, value []byte) { + if err := st.TryUpdate(key, value); err != nil { + log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + } +} + +func (st *StackTrie) Reset() { + st.db = nil + st.key = st.key[:0] + st.val = nil + for i := range st.children { + st.children[i] = nil + } + st.nodeType = emptyNode + st.keyOffset = 0 +} + +// Helper function that, given a full key, determines the index +// at which the chunk pointed by st.keyOffset is different from +// the same chunk in the full key. +func (st *StackTrie) getDiffIndex(key []byte) int { + diffindex := 0 + for ; diffindex < len(st.key) && st.key[diffindex] == key[st.keyOffset+diffindex]; diffindex++ { + } + return diffindex +} + +// Helper function to that inserts a (key, value) pair into +// the trie. 
+func (st *StackTrie) insert(key, value []byte) {
+	switch st.nodeType {
+	case branchNode: /* Branch */
+		idx := int(key[st.keyOffset])
+		// Unresolve elder siblings
+		for i := idx - 1; i >= 0; i-- {
+			if st.children[i] != nil {
+				if st.children[i].nodeType != hashedNode {
+					st.children[i].hash()
+				}
+				break
+			}
+		}
+		// Add new child
+		if st.children[idx] == nil {
+			st.children[idx] = stackTrieFromPool(st.db)
+			st.children[idx].keyOffset = st.keyOffset + 1
+		}
+		st.children[idx].insert(key, value)
+	case extNode: /* Ext */
+		// Compare both key chunks and see where they differ
+		diffidx := st.getDiffIndex(key)
+
+		// Check if chunks are identical. If so, recurse into
+		// the child node. Otherwise, the key has to be split
+		// into 1) an optional common prefix, 2) the fullnode
+		// representing the two differing paths, and 3) a leaf
+		// for each of the differentiated subtrees.
+		if diffidx == len(st.key) {
+			// Ext key and key segment are identical, recurse into
+			// the child node.
+			st.children[0].insert(key, value)
+			return
+		}
+		// Save the original part. Depending on whether the break is
+		// at the extension's last byte or not, create an
+		// intermediate extension or use the extension's child
+		// node directly.
+		var n *StackTrie
+		if diffidx < len(st.key)-1 {
+			n = newExt(diffidx+1, st.key, st.children[0], st.db)
+		} else {
+			// Break on the last byte, no need to insert
+			// an extension node: reuse the current node
+			n = st.children[0]
+		}
+		// Convert to hash
+		n.hash()
+		var p *StackTrie
+		if diffidx == 0 {
+			// the break is on the first byte, so
+			// the current node is converted into
+			// a branch node.
+			st.children[0] = nil
+			p = st
+			st.nodeType = branchNode
+		} else {
+			// the common prefix is at least one byte
+			// long, insert a new intermediate branch
+			// node.
+			st.children[0] = stackTrieFromPool(st.db)
+			st.children[0].nodeType = branchNode
+			st.children[0].keyOffset = st.keyOffset + diffidx
+			p = st.children[0]
+		}
+		// Create a leaf for the inserted part
+		o := newLeaf(st.keyOffset+diffidx+1, key, value, st.db)
+
+		// Insert both child leaves where they belong:
+		origIdx := st.key[diffidx]
+		newIdx := key[diffidx+st.keyOffset]
+		p.children[origIdx] = n
+		p.children[newIdx] = o
+		st.key = st.key[:diffidx]
+
+	case leafNode: /* Leaf */
+		// Compare both key chunks and see where they differ
+		diffidx := st.getDiffIndex(key)
+
+		// Overwriting a key isn't supported, which means that
+		// the current leaf is expected to be split into 1) an
+		// optional extension for the common prefix of these 2
+		// keys, 2) a fullnode selecting the path on which the
+		// keys differ, and 3) one leaf for the differentiated
+		// component of each key.
+		if diffidx >= len(st.key) {
+			panic("Trying to insert into existing key")
+		}
+
+		// Check if the split occurs at the first nibble of the
+		// chunk. In that case, no prefix extnode is necessary.
+		// Otherwise, create one.
+		var p *StackTrie
+		if diffidx == 0 {
+			// Convert current leaf into a branch
+			st.nodeType = branchNode
+			p = st
+			st.children[0] = nil
+		} else {
+			// Convert current node into an ext,
+			// and insert a child branch node.
+			st.nodeType = extNode
+			st.children[0] = NewStackTrie(st.db)
+			st.children[0].nodeType = branchNode
+			st.children[0].keyOffset = st.keyOffset + diffidx
+			p = st.children[0]
+		}
+
+		// Create the two child leaves: the one containing the
+		// original value and the one containing the new value.
+		// The child leaf holding the original value is hashed
+		// directly in order to free up some memory.
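// For intuition, inserting key "abd" into a stack trie whose current leaf
// holds "abc" (nibble keys and values invented for the example) yields:
//
//	ext("ab") -> branch
//	              [c] -> leaf("", v1)  // hashed immediately, memory freed
//	              [d] -> leaf("", v2)  // stays expanded: latest insertion path
//
// Only the path to the most recently inserted leaf is ever kept un-hashed,
// which is what bounds the stack trie's memory usage.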
+ origIdx := st.key[diffidx] + p.children[origIdx] = newLeaf(diffidx+1, st.key, st.val, st.db) + p.children[origIdx].hash() + + newIdx := key[diffidx+st.keyOffset] + p.children[newIdx] = newLeaf(p.keyOffset+1, key, value, st.db) + + // Finally, cut off the key part that has been passed + // over to the children. + st.key = st.key[:diffidx] + st.val = nil + case emptyNode: /* Empty */ + st.nodeType = leafNode + st.key = key[st.keyOffset:] + st.val = value + case hashedNode: + panic("trying to insert into hash") + default: + panic("invalid type") + } +} + +// hash() hashes the node 'st' and converts it into 'hashedNode', if possible. +// Possible outcomes: +// 1. The rlp-encoded value was >= 32 bytes: +// - Then the 32-byte `hash` will be accessible in `st.val`. +// - And the 'st.type' will be 'hashedNode' +// 2. The rlp-encoded value was < 32 bytes +// - Then the <32 byte rlp-encoded value will be accessible in 'st.val'. +// - And the 'st.type' will be 'hashedNode' AGAIN +// +// This method will also: +// set 'st.type' to hashedNode +// clear 'st.key' +func (st *StackTrie) hash() { + /* Shortcut if node is already hashed */ + if st.nodeType == hashedNode { + return + } + // The 'hasher' is taken from a pool, but we don't actually + // claim an instance until all children are done with their hashing, + // and we actually need one + var h *hasher + + switch st.nodeType { + case branchNode: + var nodes [17]node + for i, child := range st.children { + if child == nil { + nodes[i] = nilValueNode + continue + } + child.hash() + if len(child.val) < 32 { + nodes[i] = rawNode(child.val) + } else { + nodes[i] = hashNode(child.val) + } + st.children[i] = nil // Reclaim mem from subtree + returnToPool(child) + } + nodes[16] = nilValueNode + h = newHasher(false) + defer returnHasherToPool(h) + h.tmp.Reset() + if err := rlp.Encode(&h.tmp, nodes); err != nil { + panic(err) + } + case extNode: + st.children[0].hash() + h = newHasher(false) + defer returnHasherToPool(h) + h.tmp.Reset() + var valuenode node + if len(st.children[0].val) < 32 { + valuenode = rawNode(st.children[0].val) + } else { + valuenode = hashNode(st.children[0].val) + } + n := struct { + Key []byte + Val node + }{ + Key: hexToCompact(st.key), + Val: valuenode, + } + if err := rlp.Encode(&h.tmp, n); err != nil { + panic(err) + } + returnToPool(st.children[0]) + st.children[0] = nil // Reclaim mem from subtree + case leafNode: + h = newHasher(false) + defer returnHasherToPool(h) + h.tmp.Reset() + st.key = append(st.key, byte(16)) + sz := hexToCompactInPlace(st.key) + n := [][]byte{st.key[:sz], st.val} + if err := rlp.Encode(&h.tmp, n); err != nil { + panic(err) + } + case emptyNode: + st.val = st.val[:0] + st.val = append(st.val, emptyRoot[:]...) + st.key = st.key[:0] + st.nodeType = hashedNode + return + default: + panic("Invalid node type") + } + st.key = st.key[:0] + st.nodeType = hashedNode + if len(h.tmp) < 32 { + st.val = st.val[:0] + st.val = append(st.val, h.tmp...) + return + } + // Going to write the hash to the 'val'. Need to ensure it's properly sized first + // Typically, 'branchNode's will have no 'val', and require this allocation + if required := 32 - len(st.val); required > 0 { + buf := make([]byte, required) + st.val = append(st.val, buf...) + } + st.val = st.val[:32] + h.sha.Reset() + h.sha.Write(h.tmp) + h.sha.Read(st.val) + if st.db != nil { + // TODO! Is it safe to Put the slice here? + // Do all db implementations copy the value provided? 
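// For what it's worth, the in-tree backends copy on Put: memorydb stores
// common.CopyBytes(value) and leveldb copies the data into its batch/memtable.
// If a backend that retained the slice were ever used here, a defensive copy
// would be the fix, e.g. st.db.Put(common.CopyBytes(st.val), common.CopyBytes(h.tmp)).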
+		st.db.Put(st.val, h.tmp)
+	}
+}
+
+// Hash returns the hash of the current node
+func (st *StackTrie) Hash() (h common.Hash) {
+	st.hash()
+	if len(st.val) != 32 {
+		// If the node's RLP isn't 32 bytes long, the node will not
+		// be hashed, and instead contain the rlp-encoding of the
+		// node. For the top level node, we need to force the hashing.
+		ret := make([]byte, 32)
+		h := newHasher(false)
+		defer returnHasherToPool(h)
+		h.sha.Reset()
+		h.sha.Write(st.val)
+		h.sha.Read(ret)
+		return common.BytesToHash(ret)
+	}
+	return common.BytesToHash(st.val)
+}
+
+// Commit will first hash the entire trie if it hasn't been hashed yet
+// and then commit all nodes to the associated database. Actually most
+// of the trie nodes MAY have been committed already. The main purpose
+// here is to commit the root node.
+//
+// An associated database is expected; otherwise, the whole commit
+// functionality is disabled (ErrCommitDisabled is returned).
+func (st *StackTrie) Commit() (common.Hash, error) {
+	if st.db == nil {
+		return common.Hash{}, ErrCommitDisabled
+	}
+	st.hash()
+	if len(st.val) != 32 {
+		// If the node's RLP isn't 32 bytes long, the node will not
+		// be hashed (and committed), and instead contain the rlp-encoding of the
+		// node. For the top level node, we need to force the hashing+commit.
+		ret := make([]byte, 32)
+		h := newHasher(false)
+		defer returnHasherToPool(h)
+		h.sha.Reset()
+		h.sha.Write(st.val)
+		h.sha.Read(ret)
+		st.db.Put(ret, st.val)
+		return common.BytesToHash(ret), nil
+	}
+	return common.BytesToHash(st.val), nil
+}
diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go
new file mode 100644
index 000000000000..d4488b4029c4
--- /dev/null
+++ b/trie/stacktrie_test.go
@@ -0,0 +1,291 @@
+package trie
+
+import (
+	"bytes"
+	"fmt"
+	"math/big"
+	mrand "math/rand"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/hexutil"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/ethdb/memorydb"
+)
+
+func TestSizeBug(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt, _ := New(common.Hash{}, NewDatabase(memorydb.New()))
+
+	leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+	value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+
+	nt.TryUpdate(leaf, value)
+	st.TryUpdate(leaf, value)
+
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+func TestEmptyBug(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt, _ := New(common.Hash{}, NewDatabase(memorydb.New()))
+
+	//leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+	//value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "9496f4ec2bf9dab484cac6be589e8417d84781be08"},
+		{K: "40edb63a35fcf86c08022722aa3287cdd36440d671b4918131b2514795fefa9c", V: "01"},
+		{K: "b10e2d527612073b26eecdfd717e6a320cf44b4afac2b0732d9fcbe2b7fa0cf6", V: "947a30f7736e48d6599356464ba4c150d8da0302ff"},
+		{K: "c2575a0e9e593c00f959f8c92f12db2869c3395a3b0502d05e2516446f71f85b", V: "02"},
+	}
+
+	for _, kv := range kvs {
+		nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V))
+		st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V))
+	}
+
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+func TestValLength56(t *testing.T) {
+	st := NewStackTrie(nil)
+	
nt, _ := New(common.Hash{}, NewDatabase(memorydb.New())) + + //leaf := common.FromHex("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563") + //value := common.FromHex("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3") + kvs := []struct { + K string + V string + }{ + {K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"}, + } + + for _, kv := range kvs { + nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) + st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V)) + } + + if nt.Hash() != st.Hash() { + t.Fatalf("error %x != %x", st.Hash(), nt.Hash()) + } +} + +func genTxs(num uint64) (types.Transactions, error) { + key, err := crypto.HexToECDSA("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") + if err != nil { + return nil, err + } + var addr = crypto.PubkeyToAddress(key.PublicKey) + newTx := func(i uint64) (*types.Transaction, error) { + signer := types.NewEIP155Signer(big.NewInt(18)) + tx, err := types.SignTx(types.NewTransaction(i, addr, new(big.Int), 0, new(big.Int).SetUint64(10000000), nil), signer, key) + return tx, err + } + var txs types.Transactions + for i := uint64(0); i < num; i++ { + tx, err := newTx(i) + if err != nil { + return nil, err + } + txs = append(txs, tx) + } + return txs, nil +} + +func TestDeriveSha(t *testing.T) { + txs, err := genTxs(0) + if err != nil { + t.Fatal(err) + } + for len(txs) < 1000 { + exp := types.DeriveSha(txs, newEmpty()) + got := types.DeriveSha(txs, NewStackTrie(nil)) + if !bytes.Equal(got[:], exp[:]) { + t.Fatalf("%d txs: got %x exp %x", len(txs), got, exp) + } + newTxs, err := genTxs(uint64(len(txs) + 1)) + if err != nil { + t.Fatal(err) + } + txs = append(txs, newTxs...) 
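+ // Each round appends len+1 more txs, doubling the list (plus one), so
+ // lengths 0, 1, 3, 7, ..., 511 are all cross-checked before the loop
+ // exits at 1023.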
+ }
+}
+
+func BenchmarkDeriveSha200(b *testing.B) {
+ txs, err := genTxs(200)
+ if err != nil {
+ b.Fatal(err)
+ }
+ var exp common.Hash
+ var got common.Hash
+ b.Run("std_trie", func(b *testing.B) {
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ exp = types.DeriveSha(txs, newEmpty())
+ }
+ })
+
+ b.Run("stack_trie", func(b *testing.B) {
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ got = types.DeriveSha(txs, NewStackTrie(nil))
+ }
+ })
+ if got != exp {
+ b.Errorf("got %x exp %x", got, exp)
+ }
+}
+
+type dummyDerivableList struct {
+ len int
+ seed int
+}
+
+func newDummy(seed int) *dummyDerivableList {
+ d := &dummyDerivableList{}
+ src := mrand.NewSource(int64(seed))
+ // don't use lists longer than 4K items
+ d.len = int(src.Int63() & 0x0FFF)
+ d.seed = seed
+ return d
+}
+
+func (d *dummyDerivableList) Len() int {
+ return d.len
+}
+
+func (d *dummyDerivableList) GetRlp(i int) []byte {
+ src := mrand.NewSource(int64(d.seed + i))
+ // max item size 256, at least 1 byte per item
+ size := 1 + src.Int63()&0x00FF
+ data := make([]byte, size)
+ _, err := mrand.New(src).Read(data)
+ if err != nil {
+ panic(err)
+ }
+ return data
+}
+
+func printList(l types.DerivableList) {
+ fmt.Printf("list length: %d\n", l.Len())
+ fmt.Printf("{\n")
+ for i := 0; i < l.Len(); i++ {
+ v := l.GetRlp(i)
+ fmt.Printf("\"0x%x\",\n", v)
+ }
+ fmt.Printf("},\n")
+}
+
+func TestFuzzDeriveSha(t *testing.T) {
+ // increase this for longer runs -- it's set to quite low for travis
+ rndSeed := mrand.Int()
+ for i := 0; i < 10; i++ {
+ seed := rndSeed + i
+ exp := types.DeriveSha(newDummy(seed), newEmpty())
+ got := types.DeriveSha(newDummy(seed), NewStackTrie(nil))
+ if !bytes.Equal(got[:], exp[:]) {
+ printList(newDummy(seed))
+ t.Fatalf("seed %d: got %x exp %x", seed, got, exp)
+ }
+ }
+}
+
+type flatList struct {
+ rlpvals []string
+}
+
+func newFlatList(rlpvals []string) *flatList {
+ return &flatList{rlpvals}
+}
+func (f *flatList) Len() int {
+ return len(f.rlpvals)
+}
+func (f *flatList) GetRlp(i int) []byte {
+ return hexutil.MustDecode(f.rlpvals[i])
+}
+
+// TestDerivableList contains testcases found via fuzzing
+func TestDerivableList(t *testing.T) {
+ type tcase []string
+ tcs := []tcase{
+ {
+ "0xc041",
+ },
+ {
+ "0xf04cf757812428b0763112efb33b6f4fad7deb445e",
+ "0xf04cf757812428b0763112efb33b6f4fad7deb445e",
+ },
+ {
+ "0xca410605310cdc3bb8d4977ae4f0143df54a724ed873457e2272f39d66e0460e971d9d",
+ "0x6cd850eca0a7ac46bb1748d7b9cb88aa3bd21c57d852c28198ad8fa422c4595032e88a4494b4778b36b944fe47a52b8c5cd312910139dfcb4147ab8e972cc456bcb063f25dd78f54c4d34679e03142c42c662af52947d45bdb6e555751334ace76a5080ab5a0256a1d259855dfc5c0b8023b25befbb13fd3684f9f755cbd3d63544c78ee2001452dd54633a7593ade0b183891a0a4e9c7844e1254005fbe592b1b89149a502c24b6e1dca44c158aebedf01beae9c30cabe16a",
+ "0x14abd5c47c0be87b0454596baad2",
+ "0xca410605310cdc3bb8d4977ae4f0143df54a724ed873457e2272f39d66e0460e971d9d",
+ },
+ }
+ for i, tc := range tcs {
+ exp := types.DeriveSha(newFlatList(tc), newEmpty())
+ got := types.DeriveSha(newFlatList(tc), NewStackTrie(nil))
+ if !bytes.Equal(got[:], exp[:]) {
+ t.Fatalf("case %d: got %x exp %x", i, got, exp)
+ }
+ }
+}
+
+// TestUpdateSmallNodes tests a case where the leaves are small (both key and value),
+// which causes a lot of node-within-node. This case was found via fuzzing.
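+// (Nodes whose rlp-encoding is shorter than 32 bytes are embedded verbatim
+// in their parent rather than referenced by hash, which is what produces
+// the node-within-node encoding exercised here.)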
+func TestUpdateSmallNodes(t *testing.T) {
+ st := NewStackTrie(nil)
+ nt, _ := New(common.Hash{}, NewDatabase(memorydb.New()))
+ kvs := []struct {
+ K string
+ V string
+ }{
+ {"63303030", "3041"}, // stacktrie.Update
+ {"65", "3000"}, // stacktrie.Update
+ }
+ for _, kv := range kvs {
+ nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V))
+ st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V))
+ }
+ if nt.Hash() != st.Hash() {
+ t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+ }
+}
+
+// TestUpdateVariableKeys documents a case where the stacktrie fails: when keys
+// of different sizes are used, and the second one has the same prefix as the
+// first, the stacktrie fails, since it's unable to 'expand' on an already
+// added leaf. For all practical purposes this is fine, since keys have a
+// fixed length in account and storage tries.
+//
+// The test is marked as 'skipped', and exists just to have the behaviour documented.
+// This case was found via fuzzing.
+func TestUpdateVariableKeys(t *testing.T) {
+ t.SkipNow()
+ st := NewStackTrie(nil)
+ nt, _ := New(common.Hash{}, NewDatabase(memorydb.New()))
+ kvs := []struct {
+ K string
+ V string
+ }{
+ {"0x33303534636532393561313031676174", "303030"},
+ {"0x3330353463653239356131303167617430", "313131"},
+ }
+ for _, kv := range kvs {
+ nt.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V))
+ st.TryUpdate(common.FromHex(kv.K), common.FromHex(kv.V))
+ }
+ if nt.Hash() != st.Hash() {
+ t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+ }
+}
diff --git a/trie/sync_bloom.go b/trie/sync_bloom.go
index 89f61d66d9fd..979f4748f3d0 100644
--- a/trie/sync_bloom.go
+++ b/trie/sync_bloom.go
@@ -125,14 +125,14 @@ func (b *SyncBloom) init(database ethdb.Iteratee) {
 it.Release()
 it = database.NewIterator(nil, key)
- log.Info("Initializing fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
+ log.Info("Initializing state bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
 swap = time.Now()
 }
 }
 it.Release()
 // Mark the bloom filter inited and return
- log.Info("Initialized fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
+ log.Info("Initialized state bloom", "items", b.bloom.N(), "errorrate", b.errorRate(), "elapsed", common.PrettyDuration(time.Since(start)))
 atomic.StoreUint32(&b.inited, 1)
 }
@@ -162,7 +162,7 @@ func (b *SyncBloom) Close() error {
 b.pend.Wait()
 // Wipe the bloom, but mark it "uninited" just in case someone attempts an access
- log.Info("Deallocated fast sync bloom", "items", b.bloom.N(), "errorrate", b.errorRate())
+ log.Info("Deallocated state bloom", "items", b.bloom.N(), "errorrate", b.errorRate())
 atomic.StoreUint32(&b.inited, 0)
 b.bloom = nil
diff --git a/trie/sync_test.go b/trie/sync_test.go
index 39e0f9575edc..cb3283875d6c 100644
--- a/trie/sync_test.go
+++ b/trie/sync_test.go
@@ -377,7 +377,6 @@ func TestIncompleteSync(t *testing.T) {
 nodes, _, codes := sched.Missing(1)
 queue := append(append([]common.Hash{}, nodes...), codes...)
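+ // Seed the queue with everything currently missing, both trie nodes and contract codes.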
- for len(queue) > 0 { // Fetch a batch of trie nodes results := make([]SyncResult, len(queue)) @@ -401,10 +400,8 @@ func TestIncompleteSync(t *testing.T) { batch.Write() for _, result := range results { added = append(added, result.Hash) - } - // Check that all known sub-tries in the synced trie are complete - for _, root := range added { - if err := checkTrieConsistency(triedb, root); err != nil { + // Check that all known sub-tries in the synced trie are complete + if err := checkTrieConsistency(triedb, result.Hash); err != nil { t.Fatalf("trie inconsistent: %v", err) } } diff --git a/trie/trie.go b/trie/trie.go index 1e1749a4ff7f..87b72ecf17f8 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -19,13 +19,13 @@ package trie import ( "bytes" + "errors" "fmt" "sync" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/rlp" ) var ( @@ -159,29 +159,26 @@ func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { if item == nil { return nil, resolved, nil } - enc, err := rlp.EncodeToBytes(item) - if err != nil { - log.Error("Encoding existing trie node failed", "err", err) - return nil, resolved, err - } - return enc, resolved, err + return item, resolved, err } -func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item node, newnode node, resolved int, err error) { +func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) { // If we reached the requested path, return the current node if pos >= len(path) { - // Don't return collapsed hash nodes though - if _, ok := origNode.(hashNode); !ok { - // Short nodes have expanded keys, compact them before returning - item := origNode - if sn, ok := item.(*shortNode); ok { - item = &shortNode{ - Key: hexToCompact(sn.Key), - Val: sn.Val, - } - } - return item, origNode, 0, nil + // Although we most probably have the original node expanded, encoding + // that into consensus form can be nasty (needs to cascade down) and + // time consuming. Instead, just pull the hash up from disk directly. + var hash hashNode + if node, ok := origNode.(hashNode); ok { + hash = node + } else { + hash, _ = origNode.cache() + } + if hash == nil { + return nil, origNode, 0, errors.New("non-consensus node") } + blob, err := t.db.Node(common.BytesToHash(hash)) + return blob, origNode, 1, err } // Path still needs to be traversed, descend into children switch n := (origNode).(type) { @@ -491,7 +488,7 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash { - hash, cached, _ := t.hashRoot(nil) + hash, cached, _ := t.hashRoot() t.root = cached return common.BytesToHash(hash.(hashNode)) } @@ -505,13 +502,16 @@ func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) { if t.root == nil { return emptyRoot, nil } + // Derive the hash for all dirty nodes first. We hold the assumption + // in the following procedure that all nodes are hashed. rootHash := t.Hash() h := newCommitter() defer returnCommitterToPool(h) + // Do a quick check if we really need to commit, before we spin // up goroutines. This can happen e.g. if we load a trie for reading storage // values, but don't write to it. 
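+ // A node's cache() returns its cached hash and a dirty flag, which
+ // tracks whether anything changed since the trie was last committed;
+ // a clean root means there is nothing new to write.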
- if !h.commitNeeded(t.root) { + if _, dirty := t.root.cache(); !dirty { return rootHash, nil } var wg sync.WaitGroup @@ -542,7 +542,7 @@ func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) { } // hashRoot calculates the root hash of the given trie -func (t *Trie) hashRoot(db *Database) (node, node, error) { +func (t *Trie) hashRoot() (node, node, error) { if t.root == nil { return hashNode(emptyRoot.Bytes()), nil, nil } diff --git a/trie/trie_test.go b/trie/trie_test.go index 2356b7a74647..ddbdcbbd5b6f 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -19,7 +19,9 @@ package trie import ( "bytes" "encoding/binary" + "errors" "fmt" + "hash" "io/ioutil" "math/big" "math/rand" @@ -31,9 +33,11 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb/leveldb" "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/rlp" + "golang.org/x/crypto/sha3" ) func init() { @@ -590,21 +594,20 @@ func benchmarkCommitAfterHash(b *testing.B, onleaf LeafCallback) { func TestTinyTrie(t *testing.T) { // Create a realistic account trie to hash - _, accounts := makeAccounts(10000) + _, accounts := makeAccounts(5) trie := newEmpty() trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3]) - if exp, root := common.HexToHash("4fa6efd292cffa2db0083b8bedd23add2798ae73802442f52486e95c3df7111c"), trie.Hash(); exp != root { - t.Fatalf("1: got %x, exp %x", root, exp) + if exp, root := common.HexToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"), trie.Hash(); exp != root { + t.Errorf("1: got %x, exp %x", root, exp) } trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4]) - if exp, root := common.HexToHash("cb5fb1213826dad9e604f095f8ceb5258fe6b5c01805ce6ef019a50699d2d479"), trie.Hash(); exp != root { - t.Fatalf("2: got %x, exp %x", root, exp) + if exp, root := common.HexToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"), trie.Hash(); exp != root { + t.Errorf("2: got %x, exp %x", root, exp) } trie.Update(common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4]) - if exp, root := common.HexToHash("ed7e06b4010057d8703e7b9a160a6d42cf4021f9020da3c8891030349a646987"), trie.Hash(); exp != root { - t.Fatalf("3: got %x, exp %x", root, exp) + if exp, root := common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), trie.Hash(); exp != root { + t.Errorf("3: got %x, exp %x", root, exp) } - checktr, _ := New(common.Hash{}, trie.db) it := NewIterator(trie.NodeIterator(nil)) for it.Next() { @@ -626,7 +629,7 @@ func TestCommitAfterHash(t *testing.T) { trie.Hash() trie.Commit(nil) root := trie.Hash() - exp := common.HexToHash("e5e9c29bb50446a4081e6d1d748d2892c6101c1e883a1f77cf21d4094b697822") + exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6") if exp != root { t.Errorf("got %x, exp %x", root, exp) } @@ -642,23 +645,257 @@ func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) { // Create a realistic account trie to hash addresses = make([][20]byte, size) for i := 0; i < len(addresses); i++ { - for j := 0; j < len(addresses[i]); j++ { - addresses[i][j] = byte(random.Intn(256)) - } + data := make([]byte, 20) + random.Read(data) + copy(addresses[i][:], data) } 
accounts = make([][]byte, len(addresses))
 for i := 0; i < len(accounts); i++ {
 var (
- nonce = uint64(random.Int63())
- balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
- root = emptyRoot
- code = crypto.Keccak256(nil)
+ nonce = uint64(random.Int63())
+ root = emptyRoot
+ code = crypto.Keccak256(nil)
 )
- accounts[i], _ = rlp.EncodeToBytes(&account{nonce, balance, root, code})
+ // The big.Int.Rand function is not deterministic across 64 and 32 bit systems,
+ // as it consumes a different amount of data from the rand source on each.
+ //balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+ // Therefore, we instead just read the balance via a byte buffer.
+ numBytes := random.Uint32() % 33 // [0, 32] bytes
+ balanceBytes := make([]byte, numBytes)
+ random.Read(balanceBytes)
+ balance := new(big.Int).SetBytes(balanceBytes)
+ data, _ := rlp.EncodeToBytes(&account{nonce, balance, root, code})
+ accounts[i] = data
 }
 return addresses, accounts
}
+
+// spongeDb is a dummy db backend which accumulates writes in a sponge
+type spongeDb struct {
+ sponge hash.Hash
+ id string
+ journal []string
+}
+
+func (s *spongeDb) Has(key []byte) (bool, error) { panic("implement me") }
+func (s *spongeDb) Get(key []byte) ([]byte, error) { return nil, errors.New("no such elem") }
+func (s *spongeDb) Delete(key []byte) error { panic("implement me") }
+func (s *spongeDb) NewBatch() ethdb.Batch { return &spongeBatch{s} }
+func (s *spongeDb) Stat(property string) (string, error) { panic("implement me") }
+func (s *spongeDb) Compact(start []byte, limit []byte) error { panic("implement me") }
+func (s *spongeDb) Close() error { return nil }
+func (s *spongeDb) Put(key []byte, value []byte) error {
+ valbrief := value
+ if len(valbrief) > 8 {
+ valbrief = valbrief[:8]
+ }
+ s.journal = append(s.journal, fmt.Sprintf("%v: PUT([%x...], [%d bytes] %x...)\n", s.id, key[:8], len(value), valbrief))
+ s.sponge.Write(key)
+ s.sponge.Write(value)
+ return nil
+}
+func (s *spongeDb) NewIterator(prefix []byte, start []byte) ethdb.Iterator { panic("implement me") }
+
+// spongeBatch is a dummy batch which immediately writes to the underlying spongedb
+type spongeBatch struct {
+ db *spongeDb
+}
+
+func (b *spongeBatch) Put(key, value []byte) error {
+ b.db.Put(key, value)
+ return nil
+}
+func (b *spongeBatch) Delete(key []byte) error { panic("implement me") }
+func (b *spongeBatch) ValueSize() int { return 100 }
+func (b *spongeBatch) Write() error { return nil }
+func (b *spongeBatch) Reset() {}
+func (b *spongeBatch) Replay(w ethdb.KeyValueWriter) error { return nil }
+
+// TestCommitSequence tests that the trie.Commit operation writes the elements of the trie
+// in the expected order, and calls the callbacks in the expected order.
+// The test data was based on the 'master' code, and is basically random. It can be used
+// to check whether changes to the trie modify the write order or data in any way.
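+// (The 'sponge' is simply a keccak256 accumulator: every database write is
+// absorbed in order, so a single digest fingerprints the entire write
+// sequence.)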
+func TestCommitSequence(t *testing.T) { + for i, tc := range []struct { + count int + expWriteSeqHash []byte + expCallbackSeqHash []byte + }{ + {20, common.FromHex("873c78df73d60e59d4a2bcf3716e8bfe14554549fea2fc147cb54129382a8066"), + common.FromHex("ff00f91ac05df53b82d7f178d77ada54fd0dca64526f537034a5dbe41b17df2a")}, + {200, common.FromHex("ba03d891bb15408c940eea5ee3d54d419595102648d02774a0268d892add9c8e"), + common.FromHex("f3cd509064c8d319bbdd1c68f511850a902ad275e6ed5bea11547e23d492a926")}, + {2000, common.FromHex("f7a184f20df01c94f09537401d11e68d97ad0c00115233107f51b9c287ce60c7"), + common.FromHex("ff795ea898ba1e4cfed4a33b4cf5535a347a02cf931f88d88719faf810f9a1c9")}, + } { + addresses, accounts := makeAccounts(tc.count) + // This spongeDb is used to check the sequence of disk-db-writes + s := &spongeDb{sponge: sha3.NewLegacyKeccak256()} + db := NewDatabase(s) + trie, _ := New(common.Hash{}, db) + // Another sponge is used to check the callback-sequence + callbackSponge := sha3.NewLegacyKeccak256() + // Fill the trie with elements + for i := 0; i < tc.count; i++ { + trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + } + // Flush trie -> database + root, _ := trie.Commit(nil) + // Flush memdb -> disk (sponge) + db.Commit(root, false, func(c common.Hash) { + // And spongify the callback-order + callbackSponge.Write(c[:]) + }) + if got, exp := s.sponge.Sum(nil), tc.expWriteSeqHash; !bytes.Equal(got, exp) { + t.Errorf("test %d, disk write sequence wrong:\ngot %x exp %x\n", i, got, exp) + } + if got, exp := callbackSponge.Sum(nil), tc.expCallbackSeqHash; !bytes.Equal(got, exp) { + t.Errorf("test %d, call back sequence wrong:\ngot: %x exp %x\n", i, got, exp) + } + } +} + +// TestCommitSequenceRandomBlobs is identical to TestCommitSequence +// but uses random blobs instead of 'accounts' +func TestCommitSequenceRandomBlobs(t *testing.T) { + for i, tc := range []struct { + count int + expWriteSeqHash []byte + expCallbackSeqHash []byte + }{ + {20, common.FromHex("8e4a01548551d139fa9e833ebc4e66fc1ba40a4b9b7259d80db32cff7b64ebbc"), + common.FromHex("450238d73bc36dc6cc6f926987e5428535e64be403877c4560e238a52749ba24")}, + {200, common.FromHex("6869b4e7b95f3097a19ddb30ff735f922b915314047e041614df06958fc50554"), + common.FromHex("0ace0b03d6cb8c0b82f6289ef5b1a1838306b455a62dafc63cada8e2924f2550")}, + {2000, common.FromHex("444200e6f4e2df49f77752f629a96ccf7445d4698c164f962bbd85a0526ef424"), + common.FromHex("117d30dafaa62a1eed498c3dfd70982b377ba2b46dd3e725ed6120c80829e518")}, + } { + prng := rand.New(rand.NewSource(int64(i))) + // This spongeDb is used to check the sequence of disk-db-writes + s := &spongeDb{sponge: sha3.NewLegacyKeccak256()} + db := NewDatabase(s) + trie, _ := New(common.Hash{}, db) + // Another sponge is used to check the callback-sequence + callbackSponge := sha3.NewLegacyKeccak256() + // Fill the trie with elements + for i := 0; i < tc.count; i++ { + key := make([]byte, 32) + var val []byte + // 50% short elements, 50% large elements + if prng.Intn(2) == 0 { + val = make([]byte, 1+prng.Intn(32)) + } else { + val = make([]byte, 1+prng.Intn(4096)) + } + prng.Read(key) + prng.Read(val) + trie.Update(key, val) + } + // Flush trie -> database + root, _ := trie.Commit(nil) + // Flush memdb -> disk (sponge) + db.Commit(root, false, func(c common.Hash) { + // And spongify the callback-order + callbackSponge.Write(c[:]) + }) + if got, exp := s.sponge.Sum(nil), tc.expWriteSeqHash; !bytes.Equal(got, exp) { + t.Fatalf("test %d, disk write sequence wrong:\ngot %x exp %x\n", i, got, exp) 
+ } + if got, exp := callbackSponge.Sum(nil), tc.expCallbackSeqHash; !bytes.Equal(got, exp) { + t.Fatalf("test %d, call back sequence wrong:\ngot: %x exp %x\n", i, got, exp) + } + } +} + +func TestCommitSequenceStackTrie(t *testing.T) { + for count := 1; count < 200; count++ { + prng := rand.New(rand.NewSource(int64(count))) + // This spongeDb is used to check the sequence of disk-db-writes + s := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "a"} + db := NewDatabase(s) + trie, _ := New(common.Hash{}, db) + // Another sponge is used for the stacktrie commits + stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"} + stTrie := NewStackTrie(stackTrieSponge) + // Fill the trie with elements + for i := 1; i < count; i++ { + // For the stack trie, we need to do inserts in proper order + key := make([]byte, 32) + binary.BigEndian.PutUint64(key, uint64(i)) + var val []byte + // 50% short elements, 50% large elements + if prng.Intn(2) == 0 { + val = make([]byte, 1+prng.Intn(32)) + } else { + val = make([]byte, 1+prng.Intn(1024)) + } + prng.Read(val) + trie.TryUpdate(key, common.CopyBytes(val)) + stTrie.TryUpdate(key, common.CopyBytes(val)) + } + // Flush trie -> database + root, _ := trie.Commit(nil) + // Flush memdb -> disk (sponge) + db.Commit(root, false, nil) + // And flush stacktrie -> disk + stRoot, err := stTrie.Commit() + if err != nil { + t.Fatalf("Failed to commit stack trie %v", err) + } + if stRoot != root { + t.Fatalf("root wrong, got %x exp %x", stRoot, root) + } + if got, exp := stackTrieSponge.sponge.Sum(nil), s.sponge.Sum(nil); !bytes.Equal(got, exp) { + // Show the journal + t.Logf("Expected:") + for i, v := range s.journal { + t.Logf("op %d: %v", i, v) + } + t.Logf("Stacktrie:") + for i, v := range stackTrieSponge.journal { + t.Logf("op %d: %v", i, v) + } + t.Fatalf("test %d, disk write sequence wrong:\ngot %x exp %x\n", count, got, exp) + } + } +} + +// TestCommitSequenceSmallRoot tests that a trie which is essentially only a +// small (<32 byte) shortnode with an included value is properly committed to a +// database. +// This case might not matter, since in practice, all keys are 32 bytes, which means +// that even a small trie which contains a leaf will have an extension making it +// not fit into 32 bytes, rlp-encoded. However, it's still the correct thing to do. +func TestCommitSequenceSmallRoot(t *testing.T) { + s := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "a"} + db := NewDatabase(s) + trie, _ := New(common.Hash{}, db) + // Another sponge is used for the stacktrie commits + stackTrieSponge := &spongeDb{sponge: sha3.NewLegacyKeccak256(), id: "b"} + stTrie := NewStackTrie(stackTrieSponge) + // Add a single small-element to the trie(s) + key := make([]byte, 5) + key[0] = 1 + trie.TryUpdate(key, []byte{0x1}) + stTrie.TryUpdate(key, []byte{0x1}) + // Flush trie -> database + root, _ := trie.Commit(nil) + // Flush memdb -> disk (sponge) + db.Commit(root, false, nil) + // And flush stacktrie -> disk + stRoot, err := stTrie.Commit() + if err != nil { + t.Fatalf("Failed to commit stack trie %v", err) + } + if stRoot != root { + t.Fatalf("root wrong, got %x exp %x", stRoot, root) + } + fmt.Printf("root: %x\n", stRoot) + if got, exp := stackTrieSponge.sponge.Sum(nil), s.sponge.Sum(nil); !bytes.Equal(got, exp) { + t.Fatalf("test, disk write sequence wrong:\ngot %x exp %x\n", got, exp) + } +} + // BenchmarkCommitAfterHashFixedSize benchmarks the Commit (after Hash) of a fixed number of updates to a trie. 
// This benchmark is meant to capture the difference on efficiency of small versus large changes. Typically, // storage tries are small (a couple of entries), whereas the full post-block account trie update is large (a couple