diff --git a/wire/common.go b/wire/common.go index 404c72f3d8..59f254c5c3 100644 --- a/wire/common.go +++ b/wire/common.go @@ -308,6 +308,14 @@ func readElement(r io.Reader, element interface{}) error { } return nil + // Mixed message + case *[MixMsgSize]byte: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + case *[32]byte: _, err := io.ReadFull(r, e[:]) if err != nil { @@ -322,6 +330,38 @@ func readElement(r io.Reader, element interface{}) error { } return nil + // Mix identity + case *[33]byte: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + + // Mix signature + case *[64]byte: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + + // sntrup4591651 ciphertext + case *[1047]byte: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + + // sntrup4591651 public key + case *[1218]byte: + _, err := io.ReadFull(r, e[:]) + if err != nil { + return err + } + return nil + case *ServiceFlag: rv, err := binarySerializer.Uint64(r, littleEndian) if err != nil { @@ -377,6 +417,20 @@ func writeElement(w io.Writer, element interface{}) error { // Attempt to write the element based on the concrete type via fast // type assertions first. switch e := element.(type) { + case uint8: + err := binarySerializer.PutUint8(w, e) + if err != nil { + return err + } + return nil + + case uint16: + err := binarySerializer.PutUint16(w, littleEndian, e) + if err != nil { + return err + } + return nil + case int32: err := binarySerializer.PutUint32(w, littleEndian, uint32(e)) if err != nil { @@ -441,6 +495,21 @@ func writeElement(w io.Writer, element interface{}) error { } return nil + // Mixed message + case *[MixMsgSize]byte: + _, err := w.Write(e[:]) + if err != nil { + return err + } + return nil + + case *[32]byte: + _, err := w.Write(e[:]) + if err != nil { + return err + } + return nil + case *chainhash.Hash: _, err := w.Write(e[:]) if err != nil { @@ -448,6 +517,38 @@ func writeElement(w io.Writer, element interface{}) error { } return nil + // Mix identity + case *[33]byte: + _, err := w.Write(e[:]) + if err != nil { + return err + } + return nil + + // Mix signature + case *[64]byte: + _, err := w.Write(e[:]) + if err != nil { + return err + } + return nil + + // sntrup4591761 ciphertext + case *[1047]byte: + _, err := w.Write(e[:]) + if err != nil { + return err + } + return nil + + // sntrup4591761 public key + case *[1218]byte: + _, err := w.Write(e[:]) + if err != nil { + return err + } + return nil + case ServiceFlag: err := binarySerializer.PutUint64(w, littleEndian, uint64(e)) if err != nil { diff --git a/wire/common_test.go b/wire/common_test.go index 4901c67492..e02bc69098 100644 --- a/wire/common_test.go +++ b/wire/common_test.go @@ -907,3 +907,59 @@ func TestRandomUint64Errors(t *testing.T) { t.Errorf("Nonce is not 0 [%v]", nonce) } } + +// repeat returns the byte slice containing count elements of the byte b. +func repeat(b byte, count int) []byte { + s := make([]byte, count) + for i := range s { + s[i] = b + } + return s +} + +// rhash returns a chainhash.Hash with all bytes set to b. +func rhash(b byte) chainhash.Hash { + var h chainhash.Hash + for i := range h { + h[i] = b + } + return h +} + +// varBytesLen returns the size required to encode l bytes as a varint +// followed by the bytes themselves. +func varBytesLen(l uint32) uint32 { + return uint32(VarIntSerializeSize(uint64(l))) + l +} + +// expectedSerializationCompare compares serialized bytes to the expected +// sequence of bytes. When got and expected are not equal, the test t will be +// errored with descriptive messages of how the two encodings are different. +// Returns true if the serialization are equal, and false if the test +// errors. +func expectedSerializationEqual(t *testing.T, got, expected []byte) bool { + if bytes.Equal(got, expected) { + return true + } + + t.Errorf("encoded message differs from expected serialization") + minLen := len(expected) + if len(got) < minLen { + minLen = len(got) + } + for i := 0; i < minLen; i++ { + if b := got[i]; b != expected[i] { + t.Errorf("message differs at index %d (got 0x%x, expected 0x%x)", + i, b, expected[i]) + } + } + if len(got) > len(expected) { + t.Errorf("serialized message contains extra bytes [%x]", + got[len(expected):]) + } + if len(expected) > len(got) { + t.Errorf("serialization prematurely ends at index %d, missing bytes [%x]", + len(got), expected[len(got):]) + } + return false +} diff --git a/wire/error.go b/wire/error.go index d21279220d..81a104f255 100644 --- a/wire/error.go +++ b/wire/error.go @@ -133,6 +133,22 @@ const ( // ErrTooManyTSpends is returned when the number of tspend hashes // exceeds the maximum allowed. ErrTooManyTSpends + + // ErrTooManyManyMixPairReqs is returned when the number of mix pair + // request message hashes exceeds the maximum allowed. + ErrTooManyManyMixPairReqs + + // ErrMixPairReqScriptClassTooLong is returned when a mixing script + // class type string is longer than allowed by the protocol. + ErrMixPairReqScriptClassTooLong + + // ErrTooManyMixPairReqUTXOs is returned when a MixPairReq message + // contains more UTXOs than allowed by the protocol. + ErrTooManyMixPairReqUTXOs + + // ErrTooManyPrevMixMsgs is returned when too many previous messages of + // a mix run are referenced by a message. + ErrTooManyPrevMixMsgs ) // Map of ErrorCode values back to their constant names for pretty printing. @@ -168,6 +184,10 @@ var errorCodeStrings = map[ErrorCode]string{ ErrTooManyInitStateTypes: "ErrTooManyInitStateTypes", ErrInitStateTypeTooLong: "ErrInitStateTypeTooLong", ErrTooManyTSpends: "ErrTooManyTSpends", + ErrTooManyManyMixPairReqs: "ErrTooManyManyMixPairReqs", + ErrMixPairReqScriptClassTooLong: "ErrMixPairReqScriptClassTooLong", + ErrTooManyMixPairReqUTXOs: "ErrTooManyMixPairReqUTXOs", + ErrTooManyPrevMixMsgs: "ErrTooManyPrevMixMsgs", } // String returns the ErrorCode as a human-readable name. diff --git a/wire/error_test.go b/wire/error_test.go index 155738518f..4c162f88dd 100644 --- a/wire/error_test.go +++ b/wire/error_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2017 The btcsuite developers -// Copyright (c) 2015-2020 The Decred developers +// Copyright (c) 2015-2023 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -50,6 +50,10 @@ func TestMessageErrorCodeStringer(t *testing.T) { {ErrTooManyInitStateTypes, "ErrTooManyInitStateTypes"}, {ErrInitStateTypeTooLong, "ErrInitStateTypeTooLong"}, {ErrTooManyTSpends, "ErrTooManyTSpends"}, + {ErrTooManyManyMixPairReqs, "ErrTooManyManyMixPairReqs"}, + {ErrMixPairReqScriptClassTooLong, "ErrMixPairReqScriptClassTooLong"}, + {ErrTooManyMixPairReqUTXOs, "ErrTooManyMixPairReqUTXOs"}, + {ErrTooManyPrevMixMsgs, "ErrTooManyPrevMixMsgs"}, {0xffff, "Unknown ErrorCode (65535)"}, } diff --git a/wire/go.mod b/wire/go.mod index e2ca97f58f..7f6801e5df 100644 --- a/wire/go.mod +++ b/wire/go.mod @@ -5,10 +5,8 @@ go 1.17 require ( github.com/davecgh/go-spew v1.1.1 github.com/decred/dcrd/chaincfg/chainhash v1.0.4 + github.com/decred/dcrd/crypto/blake256 v1.0.1 lukechampine.com/blake3 v1.2.1 ) -require ( - github.com/decred/dcrd/crypto/blake256 v1.0.1 // indirect - github.com/klauspost/cpuid/v2 v2.0.9 // indirect -) +require github.com/klauspost/cpuid/v2 v2.0.9 // indirect diff --git a/wire/invvect.go b/wire/invvect.go index ca4b1aae14..d6443dd196 100644 --- a/wire/invvect.go +++ b/wire/invvect.go @@ -30,6 +30,7 @@ const ( InvTypeTx InvType = 1 InvTypeBlock InvType = 2 InvTypeFilteredBlock InvType = 3 + InvTypeMix InvType = 4 ) // Map of service flags back to their constant names for pretty printing. @@ -38,6 +39,7 @@ var ivStrings = map[InvType]string{ InvTypeTx: "MSG_TX", InvTypeBlock: "MSG_BLOCK", InvTypeFilteredBlock: "MSG_FILTERED_BLOCK", + InvTypeMix: "MSG_MIX", } // String returns the InvType in human-readable form. diff --git a/wire/invvect_test.go b/wire/invvect_test.go index 95158f9a6e..fb7b417d3d 100644 --- a/wire/invvect_test.go +++ b/wire/invvect_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2016 The Decred developers +// Copyright (c) 2015-2023 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -23,6 +23,7 @@ func TestInvTypeStringer(t *testing.T) { {InvTypeError, "ERROR"}, {InvTypeTx, "MSG_TX"}, {InvTypeBlock, "MSG_BLOCK"}, + {InvTypeMix, "MSG_MIX"}, {0xffffffff, "Unknown InvType (4294967295)"}, } diff --git a/wire/message.go b/wire/message.go index f0aa4ae96d..a0adfeddba 100644 --- a/wire/message.go +++ b/wire/message.go @@ -52,6 +52,13 @@ const ( CmdCFilterV2 = "cfilterv2" CmdGetInitState = "getinitstate" CmdInitState = "initstate" + CmdMixPairReq = "mixpairreq" + CmdMixKeyExchange = "mixkeyxchg" + CmdMixCiphertexts = "mixcphrtxt" + CmdMixSlotReserve = "mixslotres" + CmdMixDCNet = "mixdcnet" + CmdMixConfirm = "mixconfirm" + CmdMixSecrets = "mixsecrets" ) const ( @@ -188,6 +195,27 @@ func makeEmptyMessage(command string) (Message, error) { case CmdInitState: msg = &MsgInitState{} + case CmdMixPairReq: + msg = &MsgMixPairReq{} + + case CmdMixKeyExchange: + msg = &MsgMixKeyExchange{} + + case CmdMixCiphertexts: + msg = &MsgMixCiphertexts{} + + case CmdMixSlotReserve: + msg = &MsgMixSlotReserve{} + + case CmdMixDCNet: + msg = &MsgMixDCNet{} + + case CmdMixConfirm: + msg = &MsgMixConfirm{} + + case CmdMixSecrets: + msg = &MsgMixSecrets{} + default: str := fmt.Sprintf("unhandled command [%s]", command) return nil, messageError(op, ErrUnknownCmd, str) diff --git a/wire/message_test.go b/wire/message_test.go index b29f14d418..fee3629630 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2021 The Decred developers +// Copyright (c) 2015-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -80,6 +80,16 @@ func TestMessage(t *testing.T) { msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block") msgGetInitState := NewMsgGetInitState() msgInitState := NewMsgInitState() + msgMixPR, err := NewMsgMixPairReq([33]byte{}, 1, 1, "", 1, 1, 1, 1, []MixPairReqUTXO{}, NewTxOut(0, []byte{})) + if err != nil { + t.Errorf("NewMsgMixPairReq: %v", err) + } + msgMixKE := NewMsgMixKeyExchange([33]byte{}, [32]byte{}, 1, [33]byte{}, [1218]byte{}, [32]byte{}, []chainhash.Hash{}) + msgMixCT := NewMsgMixCiphertexts([33]byte{}, [32]byte{}, 1, [][1047]byte{}, []chainhash.Hash{}) + msgMixSR := NewMsgMixSlotReserve([33]byte{}, [32]byte{}, 1, [][][]byte{{{}}}, []chainhash.Hash{}) + msgMixDC := NewMsgMixDCNet([33]byte{}, [32]byte{}, 1, []MixVect{make(MixVect, 1)}, []chainhash.Hash{}) + msgMixCM := NewMsgMixConfirm([33]byte{}, [32]byte{}, 1, NewMsgTx(), []chainhash.Hash{}) + msgMixRS := NewMsgMixSecrets([33]byte{}, [32]byte{}, 1, [32]byte{}, [][]byte{}, MixVect{}) tests := []struct { in Message // Value to encode @@ -111,7 +121,14 @@ func TestMessage(t *testing.T) { {msgCFHeaders, msgCFHeaders, pver, MainNet, 58}, {msgCFTypes, msgCFTypes, pver, MainNet, 26}, {msgGetInitState, msgGetInitState, pver, MainNet, 25}, - {msgInitState, msgInitState, pver, MainNet, 27}, + {msgInitState, msgInitState, pver, MainNet, 28}, + {msgMixPR, msgMixPR, pver, MainNet, 165}, + {msgMixKE, msgMixKE, pver, MainNet, 1441}, + {msgMixCT, msgMixCT, pver, MainNet, 158}, + {msgMixSR, msgMixSR, pver, MainNet, 161}, + {msgMixDC, msgMixDC, pver, MainNet, 181}, + {msgMixCM, msgMixCM, pver, MainNet, 173}, + {msgMixRS, msgMixRS, pver, MainNet, 192}, } t.Logf("Running %d tests", len(tests)) diff --git a/wire/mixvect.go b/wire/mixvect.go new file mode 100644 index 0000000000..c2b983c22f --- /dev/null +++ b/wire/mixvect.go @@ -0,0 +1,29 @@ +// Copyright (c) 2023 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "strings" +) + +const MixMsgSize = 20 + +// MixVect is vector of padded or unpadded DC-net messages. +type MixVect [][MixMsgSize]byte + +func (v MixVect) String() string { + b := new(strings.Builder) + b.Grow(2 + len(v)*(2*MixMsgSize+1)) + b.WriteString("[") + for i := range v { + if i != 0 { + b.WriteString(" ") + } + fmt.Fprintf(b, "%x", v[i][:]) + } + b.WriteString("]") + return b.String() +} diff --git a/wire/msggetinitstate.go b/wire/msggetinitstate.go index 350b6fe4a9..c787c77248 100644 --- a/wire/msggetinitstate.go +++ b/wire/msggetinitstate.go @@ -29,6 +29,10 @@ const ( // InitStateTSpends is the init state type used to request tpends for // voting. InitStateTSpends = "tspends" + + // InitStateMixPairReqs is the init state type used to request mixing pair + // request messages. + InitStateMixPairReqs = "mixpairreqs" ) // MsgGetInitState implements the Message interface and represents a diff --git a/wire/msginitstate.go b/wire/msginitstate.go index 1ec60ec336..05a71fdfcb 100644 --- a/wire/msginitstate.go +++ b/wire/msginitstate.go @@ -22,6 +22,10 @@ const MaxISVotesAtHeadPerMsg = 40 // 8 * 5 // message. const MaxISTSpendsAtHeadPerMsg = 7 +// MaxISMixPairReqsPerMsg is the maximum number of pair request mix messages +// at head per message. +const MaxISMixPairReqsPerMsg = 100 + // MsgInitState implements the Message interface and represents an initial // state message. It is used to receive ephemeral startup information from a // remote peer, such as blocks that can be mined upon, votes for such blocks @@ -30,9 +34,10 @@ const MaxISTSpendsAtHeadPerMsg = 7 // The content of such a message depends upon what the local peer requested // during a previous GetInitState msg. type MsgInitState struct { - BlockHashes []chainhash.Hash - VoteHashes []chainhash.Hash - TSpendHashes []chainhash.Hash + BlockHashes []chainhash.Hash + VoteHashes []chainhash.Hash + TSpendHashes []chainhash.Hash + MixPairReqHashes []chainhash.Hash } // AddBlockHash adds a new block hash to the message. Up to @@ -77,6 +82,20 @@ func (msg *MsgInitState) AddTSpendHash(hash *chainhash.Hash) error { return nil } +// AddMixPairReqHash adds a new mixing pair request message hash to the message. +// Up to MaxISMixPRsPerMsg may be added before this function errors out. +func (msg *MsgInitState) AddMixPairReqHash(hash *chainhash.Hash) error { + const op = "MsgInitState.AddMixPairReqHash" + if len(msg.MixPairReqHashes)+1 > MaxISMixPairReqsPerMsg { + msg := fmt.Sprintf("too many mixpairreq hashes for message [max %v]", + MaxISMixPairReqsPerMsg) + return messageError(op, ErrTooManyManyMixPairReqs, msg) + } + + msg.MixPairReqHashes = append(msg.MixPairReqHashes, *hash) + return nil +} + // BtcDecode decodes r using the protocol encoding into the receiver. // This is part of the Message interface implementation. func (msg *MsgInitState) BtcDecode(r io.Reader, pver uint32) error { @@ -144,6 +163,29 @@ func (msg *MsgInitState) BtcDecode(r io.Reader, pver uint32) error { } } + // Read num mixpairreq hashes (when enabled by protocol). + if pver < MixVersion { + return nil + } + + count, err = ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxISMixPairReqsPerMsg { + msg := fmt.Sprintf("too many mixpairreq hashes for message "+ + "[count %v, max %v]", count, MaxISMixPairReqsPerMsg) + return messageError(op, ErrTooManyManyMixPairReqs, msg) + } + + msg.MixPairReqHashes = make([]chainhash.Hash, count) + for i := uint64(0); i < count; i++ { + err := readElement(r, &msg.MixPairReqHashes[i]) + if err != nil { + return err + } + } + return nil } @@ -217,6 +259,30 @@ func (msg *MsgInitState) BtcEncode(w io.Writer, pver uint32) error { } } + // Write mixpairreq hashes when enabled by protocol. + if pver < MixVersion { + return nil + } + + count = len(msg.MixPairReqHashes) + if count > MaxISMixPairReqsPerMsg { + msg := fmt.Sprintf("too many mixpairreq hashes for message "+ + "[count %v, max %v]", count, MaxISMixPairReqsPerMsg) + return messageError(op, ErrTooManyManyMixPairReqs, msg) + } + + err = WriteVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + + for i := range msg.MixPairReqHashes { + err = writeElement(w, &msg.MixPairReqHashes[i]) + if err != nil { + return err + } + } + return nil } @@ -233,21 +299,29 @@ func (msg *MsgInitState) MaxPayloadLength(pver uint32) uint32 { return 0 } - return uint32(VarIntSerializeSize(MaxISBlocksAtHeadPerMsg)) + + max := uint32(VarIntSerializeSize(MaxISBlocksAtHeadPerMsg)) + (MaxISBlocksAtHeadPerMsg * chainhash.HashSize) + uint32(VarIntSerializeSize(MaxISVotesAtHeadPerMsg)) + (MaxISVotesAtHeadPerMsg * chainhash.HashSize) + uint32(VarIntSerializeSize(MaxISTSpendsAtHeadPerMsg)) + (MaxISTSpendsAtHeadPerMsg * chainhash.HashSize) + + if pver >= MixVersion { + max += uint32(VarIntSerializeSize(MaxISMixPairReqsPerMsg)) + + (MaxISMixPairReqsPerMsg * chainhash.HashSize) + } + + return max } // NewMsgInitState returns a new Decred initstate message that conforms to the // Message interface using the defaults for the fields. func NewMsgInitState() *MsgInitState { return &MsgInitState{ - BlockHashes: make([]chainhash.Hash, 0, MaxISBlocksAtHeadPerMsg), - VoteHashes: make([]chainhash.Hash, 0, MaxISVotesAtHeadPerMsg), - TSpendHashes: make([]chainhash.Hash, 0, MaxISTSpendsAtHeadPerMsg), + BlockHashes: make([]chainhash.Hash, 0, MaxISBlocksAtHeadPerMsg), + VoteHashes: make([]chainhash.Hash, 0, MaxISVotesAtHeadPerMsg), + TSpendHashes: make([]chainhash.Hash, 0, MaxISTSpendsAtHeadPerMsg), + MixPairReqHashes: make([]chainhash.Hash, 0, MaxISMixPairReqsPerMsg), } } @@ -282,9 +356,17 @@ func NewMsgInitStateFilled(blockHashes []chainhash.Hash, voteHashes []chainhash. return nil, messageError(op, ErrTooManyTSpends, msg) } + // Not preallocated. The purpose of this function is to create a + // MsgInitState from already allocated data. Assume that if the + // caller also has mixpairreq hashes to include, they have also been + // preallocated, so there is no reason to allocate the memory here + // too. + var mixPairReqHashes []chainhash.Hash + return &MsgInitState{ - BlockHashes: blockHashes, - VoteHashes: voteHashes, - TSpendHashes: tspendHashes, + BlockHashes: blockHashes, + VoteHashes: voteHashes, + TSpendHashes: tspendHashes, + MixPairReqHashes: mixPairReqHashes, }, nil } diff --git a/wire/msginitstate_test.go b/wire/msginitstate_test.go index 520e7bd5ce..73814cd3db 100644 --- a/wire/msginitstate_test.go +++ b/wire/msginitstate_test.go @@ -30,7 +30,7 @@ func TestInitState(t *testing.T) { // Ensure max payload returns the expected value for latest protocol // version. a var int and n * hashes for each of block, vote and tspend // hashes. - wantPayload := uint32((1 + 32*8) + (1 + 32*40) + (1 + 32*7)) + wantPayload := uint32((1 + 32*8) + (1 + 32*40) + (1 + 32*7) + (1 + 32*100)) maxPayload := msg.MaxPayloadLength(pver) if maxPayload != wantPayload { t.Errorf("MaxPayloadLength: wrong max payload length for "+ @@ -119,6 +119,7 @@ func TestInitStateWire(t *testing.T) { 0x00, // Varint for number of blocks 0x00, // Varint for number of votes 0x00, // Varint for number of tspends + 0x00, // Varint for number of mixpairreqs } fakeBlock1, _ := chainhash.NewHashFromStr("4433221144332211443322114" + @@ -137,6 +138,12 @@ func TestInitStateWire(t *testing.T) { "999999999999999999991199999999999999919") fakeTSpend3, _ := chainhash.NewHashFromStr("aaaaaaaaaaaa9200aaaaaa" + "aaaaaaaaaaaaaaaaaaa9a9a9aaaaaaaaaaaaaaa") + fakeMixPairReq1, _ := chainhash.NewHashFromStr("bbbbbbbbbbbb9200bbbbbb" + + "bbbbbbbbbbbbbbbbbbb9b9b9bbbbbbbbbbbbbbb") + fakeMixPairReq2, _ := chainhash.NewHashFromStr("cccccccccccc9200cccccc" + + "ccccccccccccccccccc9c9c9ccccccccccccccc") + fakeMixPairReq3, _ := chainhash.NewHashFromStr("dddddddddddd9200dddddd" + + "ddddddddddddddddddd9d9d9ddddddddddddddd") // MsgInitState message with multiple values for each hash. multiData := NewMsgInitState() @@ -148,6 +155,9 @@ func TestInitStateWire(t *testing.T) { multiData.AddTSpendHash(fakeTSpend1) multiData.AddTSpendHash(fakeTSpend2) multiData.AddTSpendHash(fakeTSpend3) + multiData.AddMixPairReqHash(fakeMixPairReq1) + multiData.AddMixPairReqHash(fakeMixPairReq2) + multiData.AddMixPairReqHash(fakeMixPairReq3) multiDataEncoded := []byte{ 0x02, // Varint for number of block hashes @@ -185,6 +195,19 @@ func TestInitStateWire(t *testing.T) { 0x9a, 0x9a, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0x0a, 0x20, 0xa9, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0x0a, 0x00, + 0x03, // Varint for number of mixpr hashes + 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0x9b, // Fake mixpairreq 1 + 0x9b, 0x9b, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, + 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0x0b, 0x20, + 0xb9, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0x0b, 0x00, + 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0x9c, // Fake mixpairreq 2 + 0x9c, 0x9c, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, + 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0x0c, 0x20, + 0xc9, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0x0c, 0x00, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0x9d, // Fake mixpairreq 3 + 0x9d, 0x9d, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0x0d, 0x20, + 0xd9, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0x0d, 0x00, } tests := []struct { @@ -255,6 +278,12 @@ func TestInitStateWireErrors(t *testing.T) { "999999999999999999991199999999999999919") fakeTSpend3, _ := chainhash.NewHashFromStr("aaaaaaaaaaaa9200aaaaaa" + "aaaaaaaaaaaaaaaaaaa9a9a9aaaaaaaaaaaaaaa") + fakeMixPairReq1, _ := chainhash.NewHashFromStr("bbbbbbbbbbbb9200bbbbbb" + + "bbbbbbbbbbbbbbbbbbb9b9b9bbbbbbbbbbbbbbb") + fakeMixPairReq2, _ := chainhash.NewHashFromStr("cccccccccccc9200cccccc" + + "ccccccccccccccccccc9c9c9ccccccccccccccc") + fakeMixPairReq3, _ := chainhash.NewHashFromStr("dddddddddddd9200dddddd" + + "ddddddddddddddddddd9d9d9ddddddddddddddd") // MsgInitState message with multiple values for each hash. baseMsg := NewMsgInitState() @@ -266,6 +295,9 @@ func TestInitStateWireErrors(t *testing.T) { baseMsg.AddTSpendHash(fakeTSpend1) baseMsg.AddTSpendHash(fakeTSpend2) baseMsg.AddTSpendHash(fakeTSpend3) + baseMsg.AddMixPairReqHash(fakeMixPairReq1) + baseMsg.AddMixPairReqHash(fakeMixPairReq2) + baseMsg.AddMixPairReqHash(fakeMixPairReq3) baseMsgEncoded := []byte{ 0x02, // Varint for number of block hashes @@ -303,6 +335,19 @@ func TestInitStateWireErrors(t *testing.T) { 0x9a, 0x9a, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0x0a, 0x20, 0xa9, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0x0a, 0x00, + 0x03, // Varint for number of mixpairreq hashes + 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0x9b, // Fake mixpairreq 1 + 0x9b, 0x9b, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, + 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0x0b, 0x20, + 0xb9, 0xbb, 0xbb, 0xbb, 0xbb, 0xbb, 0x0b, 0x00, + 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0x9c, // Fake mixpairreq 2 + 0x9c, 0x9c, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, + 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0x0c, 0x20, + 0xc9, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0x0c, 0x00, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0x9d, // Fake mixpairreq 3 + 0x9d, 0x9d, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, + 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0x0d, 0x20, + 0xd9, 0xdd, 0xdd, 0xdd, 0xdd, 0xdd, 0x0d, 0x00, } // Message that forces an error by having more than the max allowed diff --git a/wire/msgmixciphertexts.go b/wire/msgmixciphertexts.go new file mode 100644 index 0000000000..277dab4ede --- /dev/null +++ b/wire/msgmixciphertexts.go @@ -0,0 +1,239 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "hash" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MsgMixCiphertexts is used by mixing peers to share SNTRUP4591761 +// ciphertexts with other peers who have published their public keys. It +// implements the Message interface. +type MsgMixCiphertexts struct { + Signature [64]byte + Identity [33]byte + SessionID [32]byte + Run uint32 + Ciphertexts [][1047]byte + SeenKeyExchanges []chainhash.Hash + + // hash records the hash of the message. It is a member of the + // message for convenience and performance, but is never automatically + // set during creation or deserialization. + hash chainhash.Hash +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgMixCiphertexts) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgMixCiphertexts.BtcDecode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := readElements(r, &msg.Signature, &msg.Identity, &msg.SessionID, + &msg.Run) + if err != nil { + return err + } + + // Count is of both Ciphertexts and seen SeenKeyExchanges. + count, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + ciphertexts := make([][1047]byte, count) + for i := range ciphertexts { + err := readElement(r, &ciphertexts[i]) + if err != nil { + return err + } + } + msg.Ciphertexts = ciphertexts + + seen := make([]chainhash.Hash, count) + for i := range seen { + err := readElement(r, &seen[i]) + if err != nil { + return err + } + } + msg.SeenKeyExchanges = seen + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgMixCiphertexts) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgMixCiphertexts.BtcEncode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := writeElement(w, &msg.Signature) + if err != nil { + return err + } + + err = msg.writeMessageNoSignature(op, w, pver) + if err != nil { + return err + } + + return nil +} + +// Hash returns the message hash calculated by WriteHash. +// +// Hash returns an invalid or zero hash if WriteHash has not been called yet. +// +// This method is not safe while concurrently calling WriteHash. +func (msg *MsgMixCiphertexts) Hash() chainhash.Hash { + return msg.hash +} + +// WriteHash serializes the message to a hasher and records the sum in the +// message's Hash field. +// +// The hasher's Size() must equal chainhash.HashSize, or this method will +// panic. This method is designed to work only with hashers returned by +// blake256.New. +func (msg *MsgMixCiphertexts) WriteHash(h hash.Hash) { + if h.Size() != chainhash.HashSize { + s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) + panic(s) + } + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + h.Sum(msg.hash[:0]) +} + +// writeMessageNoSignature serializes all elements of the message except for +// the signature. This allows code reuse between message serialization, and +// signing and verifying these message contents. +// +// If w implements hash.Hash, no errors will be returned for invalid message +// construction. +func (msg *MsgMixCiphertexts) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + _, hashing := w.(hash.Hash) + + count := len(msg.Ciphertexts) + if !hashing && count != len(msg.SeenKeyExchanges) { + msg := fmt.Sprintf("differing counts of ciphertexts (%d) "+ + "and seen key exchange messages (%d)", count, + len(msg.SeenKeyExchanges)) + return messageError(op, ErrInvalidMsg, msg) + } + if !hashing && count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run) + if err != nil { + return err + } + + err = WriteVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for i := range msg.Ciphertexts { + err = writeElement(w, &msg.Ciphertexts[i]) + if err != nil { + return err + } + } + for i := range msg.SeenKeyExchanges { + err = writeElement(w, &msg.SeenKeyExchanges[i]) + if err != nil { + return err + } + } + + return nil +} + +// WriteSignedData writes a tag identifying the message data, followed by all +// message fields excluding the signature. This is the data committed to when +// the message is signed. +func (msg *MsgMixCiphertexts) WriteSignedData(h hash.Hash) { + WriteVarString(h, MixVersion, CmdMixCiphertexts+"-sig") + msg.writeMessageNoSignature("", h, MixVersion) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgMixCiphertexts) Command() string { + return CmdMixCiphertexts +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgMixCiphertexts) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 552584 +} + +// Pub returns the message sender's public key identity. +func (msg *MsgMixCiphertexts) Pub() []byte { + return msg.Identity[:] +} + +// Sig returns the message signature. +func (msg *MsgMixCiphertexts) Sig() []byte { + return msg.Signature[:] +} + +// PrevMsgs returns the previous key exchange messages seen by the peer. +func (msg *MsgMixCiphertexts) PrevMsgs() []chainhash.Hash { + return msg.SeenKeyExchanges +} + +// Sid returns the session ID. +func (msg *MsgMixCiphertexts) Sid() []byte { + return msg.SessionID[:] +} + +// GetRun returns the run number. +func (msg *MsgMixCiphertexts) GetRun() uint32 { + return msg.Run +} + +// NewMsgMixCiphertexts returns a new mixcphrtxt message that conforms to the +// Message interface using the passed parameters and defaults for the +// remaining fields. +func NewMsgMixCiphertexts(identity [33]byte, sid [32]byte, run uint32, + ciphertexts [][1047]byte, seenKeyExchanges []chainhash.Hash) *MsgMixCiphertexts { + + return &MsgMixCiphertexts{ + Identity: identity, + SessionID: sid, + Run: run, + Ciphertexts: ciphertexts, + SeenKeyExchanges: seenKeyExchanges, + } +} diff --git a/wire/msgmixciphertexts_test.go b/wire/msgmixciphertexts_test.go new file mode 100644 index 0000000000..f3193c19b1 --- /dev/null +++ b/wire/msgmixciphertexts_test.go @@ -0,0 +1,185 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +func newTestMixCiphertexts() *MsgMixCiphertexts { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) + + const run = uint32(0x83838383) + + cts := make([][1047]byte, 4) + for b := byte(0x84); b < 0x88; b++ { + copy(cts[b-0x84][:], repeat(b, 1047)) + } + + seenKEs := make([]chainhash.Hash, 4) + for b := byte(0x88); b < 0x8C; b++ { + copy(seenKEs[b-0x88][:], repeat(b, 32)) + } + + ct := NewMsgMixCiphertexts(id, sid, run, cts, seenKEs) + ct.Signature = sig + + return ct +} + +func TestMsgMixCiphertextsWire(t *testing.T) { + pver := MixVersion + + ct := newTestMixCiphertexts() + + buf := new(bytes.Buffer) + err := ct.BtcEncode(buf, pver) + if err != nil { + t.Fatal(err) + } + + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Run + // Varint count of ciphertexts and seen KEs + expected = append(expected, 0x04) + // Four ciphextexts (repeating 1047 bytes of 0x84, 0x85, 0x86, 0x87) + expected = append(expected, repeat(0x84, 1047)...) + expected = append(expected, repeat(0x85, 1047)...) + expected = append(expected, repeat(0x86, 1047)...) + expected = append(expected, repeat(0x87, 1047)...) + // Four seen KEs (repeating 32 bytes of 0x88, 0x89, 0x8a, 0x8b) + expected = append(expected, repeat(0x88, 32)...) + expected = append(expected, repeat(0x89, 32)...) + expected = append(expected, repeat(0x8a, 32)...) + expected = append(expected, repeat(0x8b, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + + decodedCT := new(MsgMixCiphertexts) + err = decodedCT.BtcDecode(bytes.NewReader(buf.Bytes()), pver) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(ct, decodedCT) { + t.Errorf("BtcDecode got: %s want: %s", + spew.Sdump(decodedCT), spew.Sdump(ct)) + } +} + +func TestMsgMixCiphertextsCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixCiphertexts() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixCiphertexts) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixCiphertextsMaxPayloadLength tests the results returned by +// [MsgMixCiphertexts.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixCiphertextsMaxPayloadLength(t *testing.T) { + var ct *MsgMixCiphertexts + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := ct.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Run + uint32(VarIntSerializeSize(MaxMixPeers)) + // Ciphextext and KE hash count + MaxMixPeers*1047 + // Ciphextexts + 32*MaxMixPeers // Key exchange hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := ct.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +} diff --git a/wire/msgmixconfirm.go b/wire/msgmixconfirm.go new file mode 100644 index 0000000000..f6054bb821 --- /dev/null +++ b/wire/msgmixconfirm.go @@ -0,0 +1,234 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "hash" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MsgMixConfirm contains a partially-signed mix transaction, with signatures +// contributed from the peer identity. When all CM messages are received, +// signatures can be merged and the transaction may be published, ending a +// successful mix session. +// +// It implements the Message interface. +type MsgMixConfirm struct { + Signature [64]byte + Identity [33]byte + SessionID [32]byte + Run uint32 + Mix MsgTx + SeenDCNets []chainhash.Hash + + // hash records the hash of the message. It is a member of the + // message for convenience and performance, but is never automatically + // set during creation or deserialization. + hash chainhash.Hash +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgMixConfirm) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgMixConfirm.BtcDecode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := readElements(r, &msg.Signature, &msg.Identity, &msg.SessionID, + &msg.Run) + if err != nil { + return err + } + + err = msg.Mix.BtcDecode(r, pver) + if err != nil { + return err + } + + count, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + seen := make([]chainhash.Hash, count) + for i := range seen { + err := readElement(r, &seen[i]) + if err != nil { + return err + } + } + msg.SeenDCNets = seen + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgMixConfirm) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgMixConfirm.BtcEncode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := writeElement(w, &msg.Signature) + if err != nil { + return err + } + + err = msg.writeMessageNoSignature(op, w, pver) + if err != nil { + return err + } + + return nil +} + +// Hash returns the message hash calculated by WriteHash. +// +// Hash returns an invalid or zero hash if WriteHash has not been called yet. +// +// This method is not safe while concurrently calling WriteHash. +func (msg *MsgMixConfirm) Hash() chainhash.Hash { + return msg.hash +} + +// WriteHash serializes the message to a hasher and records the sum in the +// message's Hash field. +// +// The hasher's Size() must equal chainhash.HashSize, or this method will +// panic. This method is designed to work only with hashers returned by +// blake256.New. +func (msg *MsgMixConfirm) WriteHash(h hash.Hash) { + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + sum := h.Sum(msg.hash[:0]) + if len(sum) != len(msg.hash) { + s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) + panic(s) + } +} + +// writeMessageNoSignature serializes all elements of the message except for +// the signature. This allows code reuse between message serialization, and +// signing and verifying these message contents. +// +// If w implements hash.Hash, no errors will be returned for invalid message +// construction. +func (msg *MsgMixConfirm) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + _, hashing := w.(hash.Hash) + + count := len(msg.SeenDCNets) + if !hashing && count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run) + if err != nil { + return err + } + + err = msg.Mix.BtcEncode(w, pver) + if err != nil { + return err + } + + err = WriteVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for i := range msg.SeenDCNets { + err = writeElement(w, &msg.SeenDCNets[i]) + if err != nil { + return err + } + } + + return nil +} + +// WriteSignedData writes a tag identifying the message data, followed by all +// message fields excluding the signature. This is the data committed to when +// the message is signed. +func (msg *MsgMixConfirm) WriteSignedData(h hash.Hash) { + WriteVarString(h, MixVersion, CmdMixConfirm+"-sig") + msg.writeMessageNoSignature("", h, MixVersion) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgMixConfirm) Command() string { + return CmdMixConfirm +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgMixConfirm) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 1016520 +} + +// Pub returns the message sender's public key identity. +func (msg *MsgMixConfirm) Pub() []byte { + return msg.Identity[:] +} + +// Sig returns the message signature. +func (msg *MsgMixConfirm) Sig() []byte { + return msg.Signature[:] +} + +// PrevMsgs returns the previous DC messages seen by the peer. +func (msg *MsgMixConfirm) PrevMsgs() []chainhash.Hash { + return msg.SeenDCNets +} + +// Sid returns the session ID. +func (msg *MsgMixConfirm) Sid() []byte { + return msg.SessionID[:] +} + +// GetRun returns the run number. +func (msg *MsgMixConfirm) GetRun() uint32 { + return msg.Run +} + +// NewMsgMixConfirm returns a new mixconfirm message that conforms to the +// Message interface using the passed parameters and defaults for the +// remaining fields. +func NewMsgMixConfirm(identity [33]byte, sid [32]byte, run uint32, + mix *MsgTx, seenDCNets []chainhash.Hash) *MsgMixConfirm { + + if mix == nil { + mix = NewMsgTx() + } + + return &MsgMixConfirm{ + Identity: identity, + SessionID: sid, + Run: run, + Mix: *mix, + SeenDCNets: seenDCNets, + } +} diff --git a/wire/msgmixconfirm_test.go b/wire/msgmixconfirm_test.go new file mode 100644 index 0000000000..09be5b4b0a --- /dev/null +++ b/wire/msgmixconfirm_test.go @@ -0,0 +1,184 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +func newTestMixConfirm() *MsgMixConfirm { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) + + const run = uint32(0x83838383) + + mix := NewMsgTx() + + seenDCs := make([]chainhash.Hash, 4) + for b := byte(0x84); b < 0x88; b++ { + copy(seenDCs[b-0x84][:], repeat(b, 32)) + } + + cm := NewMsgMixConfirm(id, sid, run, mix, seenDCs) + cm.Signature = sig + + return cm +} + +func TestMsgMixConfirmWire(t *testing.T) { + pver := MixVersion + + cm := newTestMixConfirm() + + buf := new(bytes.Buffer) + err := cm.BtcEncode(buf, pver) + if err != nil { + t.Fatal(err) + } + + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Run + expected = append(expected, []byte{ // Mix transaction + 0x01, 0x00, 0x00, 0x00, // Version + 0x00, // Varint for number of input transactions + 0x00, // Varint for number of output transactions + 0x00, 0x00, 0x00, 0x00, // Lock time + 0x00, 0x00, 0x00, 0x00, // Expiry + 0x00, // Varint for number of input signatures + }...) + // Four seen DCs (repeating 32 bytes of 0x84, 0x85, 0x86, 0x87) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x84, 32)...) + expected = append(expected, repeat(0x85, 32)...) + expected = append(expected, repeat(0x86, 32)...) + expected = append(expected, repeat(0x87, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + + decodedCM := new(MsgMixConfirm) + err = decodedCM.BtcDecode(bytes.NewReader(buf.Bytes()), pver) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(cm, decodedCM) { + t.Errorf("BtcDecode got: %s want: %s", + spew.Sdump(decodedCM), spew.Sdump(cm)) + } +} + +func TestMsgMixConfirmCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixConfirm() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixConfirm) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixConfirmMaxPayloadLength tests the results returned by +// [MsgMixConfirm.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixConfirmMaxPayloadLength(t *testing.T) { + var cm *MsgMixConfirm + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := cm.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Run + MaxBlockPayloadV3 + // Maximum transaction size + uint32(VarIntSerializeSize(MaxMixPeers)) + // DC-net count + 32*MaxMixPeers // DC-net hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := cm.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +} diff --git a/wire/msgmixdcnet.go b/wire/msgmixdcnet.go new file mode 100644 index 0000000000..c6fc4db98c --- /dev/null +++ b/wire/msgmixdcnet.go @@ -0,0 +1,310 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "hash" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MsgMixDCNet is the DC-net broadcast. It implements the Message interface. +type MsgMixDCNet struct { + Signature [64]byte + Identity [33]byte + SessionID [32]byte + Run uint32 + DCNet []MixVect + SeenSlotReserves []chainhash.Hash + + // hash records the hash of the message. It is a member of the + // message for convenience and performance, but is never automatically + // set during creation or deserialization. + hash chainhash.Hash +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgMixDCNet) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgMixDCNet.BtcDecode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := readElements(r, &msg.Signature, &msg.Identity, &msg.SessionID, + &msg.Run) + if err != nil { + return err + } + + var dcnet []MixVect + err = readMixVects(op, r, pver, &dcnet) + if err != nil { + return err + } + msg.DCNet = dcnet + + count, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + seen := make([]chainhash.Hash, count) + for i := range seen { + err := readElement(r, &seen[i]) + if err != nil { + return err + } + } + msg.SeenSlotReserves = seen + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgMixDCNet) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgMixDCNet.BtcEncode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := writeElement(w, &msg.Signature) + if err != nil { + return err + } + + err = msg.writeMessageNoSignature(op, w, pver) + if err != nil { + return err + } + + return nil +} + +// Hash returns the message hash calculated by WriteHash. +// +// Hash returns an invalid or zero hash if WriteHash has not been called yet. +// +// This method is not safe while concurrently calling WriteHash. +func (msg *MsgMixDCNet) Hash() chainhash.Hash { + return msg.hash +} + +// WriteHash serializes the message to a hasher and records the sum in the +// message's Hash field. +// +// The hasher's Size() must equal chainhash.HashSize, or this method will +// panic. This method is designed to work only with hashers returned by +// blake256.New. +func (msg *MsgMixDCNet) WriteHash(h hash.Hash) { + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + sum := h.Sum(msg.hash[:0]) + if len(sum) != len(msg.hash) { + s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) + panic(s) + } +} + +// writeMessageNoSignature serializes all elements of the message except for +// the signature. This allows code reuse between message serialization, and +// signing and verifying these message contents. +// +// If w implements hash.Hash, no errors will be returned for invalid message +// construction. +func (msg *MsgMixDCNet) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + _, hashing := w.(hash.Hash) + + mcount := len(msg.DCNet) + if !hashing && mcount == 0 { + msg := fmt.Sprintf("too few mixed messages [%v]", mcount) + return messageError(op, ErrInvalidMsg, msg) + } + if !hashing && mcount > MaxMixMcount { + msg := fmt.Sprintf("too many total mixed messages [%v]", mcount) + return messageError(op, ErrInvalidMsg, msg) + } + srcount := len(msg.SeenSlotReserves) + if !hashing && srcount > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + srcount, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run) + if err != nil { + return err + } + + err = writeMixVects(w, pver, msg.DCNet) + if err != nil { + return err + } + + err = WriteVarInt(w, pver, uint64(srcount)) + if err != nil { + return err + } + for i := range msg.SeenSlotReserves { + err = writeElement(w, &msg.SeenSlotReserves[i]) + if err != nil { + return err + } + } + + return nil +} + +func writeMixVects(w io.Writer, pver uint32, vecs []MixVect) error { + // Write dimensions + err := WriteVarInt(w, pver, uint64(len(vecs))) + if err != nil { + return err + } + if len(vecs) == 0 { + return nil + } + err = WriteVarInt(w, pver, uint64(len(vecs[0]))) + if err != nil { + return err + } + err = WriteVarInt(w, pver, MixMsgSize) + if err != nil { + return err + } + + // Write messages + for i := range vecs { + for j := range vecs[i] { + err = writeElement(w, &vecs[i][j]) + if err != nil { + return err + } + } + } + + return nil +} + +func readMixVects(op string, r io.Reader, pver uint32, vecs *[]MixVect) error { + // Read dimensions + x, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if x == 0 { + return nil + } + y, err := ReadVarInt(r, pver) + if err != nil { + return err + } + msize, err := ReadVarInt(r, pver) + if err != nil { + return err + } + + if x > MaxMixMcount || y > MaxMixMcount { + msg := "DC-net mix vector dimensions are too large for maximum message count" + return messageError(op, ErrInvalidMsg, msg) + } + if msize != MixMsgSize { + msg := fmt.Sprintf("mixed message length must be %d [got: %d]", + MixMsgSize, msize) + return messageError(op, ErrInvalidMsg, msg) + } + + // Read messages + *vecs = make([]MixVect, x) + for i := uint64(0); i < x; i++ { + (*vecs)[i] = make(MixVect, y) + for j := uint64(0); j < y; j++ { + err = readElement(r, &(*vecs)[i][j]) + if err != nil { + return err + } + } + } + + return nil +} + +// WriteSignedData writes a tag identifying the message data, followed by all +// message fields excluding the signature. This is the data committed to when +// the message is signed. +func (msg *MsgMixDCNet) WriteSignedData(h hash.Hash) { + WriteVarString(h, MixVersion, CmdMixDCNet+"-sig") + msg.writeMessageNoSignature("", h, MixVersion) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgMixDCNet) Command() string { + return CmdMixDCNet +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgMixDCNet) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 20988047 +} + +// Pub returns the message sender's public key identity. +func (msg *MsgMixDCNet) Pub() []byte { + return msg.Identity[:] +} + +// Sig returns the message signature. +func (msg *MsgMixDCNet) Sig() []byte { + return msg.Signature[:] +} + +// PrevMsgs returns the previous SR messages seen by the peer. +func (msg *MsgMixDCNet) PrevMsgs() []chainhash.Hash { + return msg.SeenSlotReserves +} + +// Sid returns the session ID. +func (msg *MsgMixDCNet) Sid() []byte { + return msg.SessionID[:] +} + +// GetRun returns the run number. +func (msg *MsgMixDCNet) GetRun() uint32 { + return msg.Run +} + +// NewMsgMixDCNet returns a new mixdcnet message that conforms to the Message +// interface using the passed parameters and defaults for the remaining +// fields. +func NewMsgMixDCNet(identity [33]byte, sid [32]byte, run uint32, + dcnet []MixVect, seenSlotReserves []chainhash.Hash) *MsgMixDCNet { + + return &MsgMixDCNet{ + Identity: identity, + SessionID: sid, + Run: run, + DCNet: dcnet, + SeenSlotReserves: seenSlotReserves, + } +} diff --git a/wire/msgmixdcnet_test.go b/wire/msgmixdcnet_test.go new file mode 100644 index 0000000000..1963517be2 --- /dev/null +++ b/wire/msgmixdcnet_test.go @@ -0,0 +1,212 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +func newTestMixDCNet() *MsgMixDCNet { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) + + const run = uint32(0x83838383) + + const mcount = 4 + const kpcount = 4 + dcnet := make([]MixVect, mcount) + // will add 4x4 field numbers of incrementing repeating byte values to + // dcnet, ranging from 0x84 through 0x93 + b := byte(0x84) + for i := 0; i < mcount; i++ { + dcnet[i] = make(MixVect, kpcount) + for j := 0; j < kpcount; j++ { + copy(dcnet[i][j][:], repeat(b, 32)) + b++ + } + } + + seenSRs := make([]chainhash.Hash, 4) + for b := byte(0x94); b < 0x98; b++ { + copy(seenSRs[b-0x94][:], repeat(b, 32)) + } + + dc := NewMsgMixDCNet(id, sid, run, dcnet, seenSRs) + dc.Signature = sig + + return dc +} + +func TestMsgMixDCNetWire(t *testing.T) { + pver := MixVersion + + dc := newTestMixDCNet() + + buf := new(bytes.Buffer) + err := dc.BtcEncode(buf, pver) + if err != nil { + t.Fatal(err) + } + + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Run + // DC-net dimensions (4x4, message size of 20) + expected = append(expected, 0x04) + expected = append(expected, 0x04) + expected = append(expected, 0x14) // msize + // 16 padded messages + expected = append(expected, repeat(0x84, 20)...) + expected = append(expected, repeat(0x85, 20)...) + expected = append(expected, repeat(0x86, 20)...) + expected = append(expected, repeat(0x87, 20)...) + expected = append(expected, repeat(0x88, 20)...) + expected = append(expected, repeat(0x89, 20)...) + expected = append(expected, repeat(0x8a, 20)...) + expected = append(expected, repeat(0x8b, 20)...) + expected = append(expected, repeat(0x8c, 20)...) + expected = append(expected, repeat(0x8d, 20)...) + expected = append(expected, repeat(0x8e, 20)...) + expected = append(expected, repeat(0x8f, 20)...) + expected = append(expected, repeat(0x90, 20)...) + expected = append(expected, repeat(0x91, 20)...) + expected = append(expected, repeat(0x92, 20)...) + expected = append(expected, repeat(0x93, 20)...) + // Four seen DCs (repeating 32 bytes of 0x94, 0x95, 0x96, 0x97) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x94, 32)...) + expected = append(expected, repeat(0x95, 32)...) + expected = append(expected, repeat(0x96, 32)...) + expected = append(expected, repeat(0x97, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + + decodedDC := new(MsgMixDCNet) + err = decodedDC.BtcDecode(bytes.NewReader(buf.Bytes()), pver) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(dc, decodedDC) { + t.Errorf("BtcDecode got: %s want: %s", + spew.Sdump(decodedDC), spew.Sdump(dc)) + } +} + +func TestMsgMixDCNetCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixDCNet() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixDCNet) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixDCNetMaxPayloadLength tests the results returned by +// [MsgMixDCNet.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixDCNetMaxPayloadLength(t *testing.T) { + var dc *MsgMixDCNet + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := dc.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Run + uint32(VarIntSerializeSize(MaxMixMcount)) + // Message count (our mcount) + uint32(VarIntSerializeSize(MaxMixMcount)) + // Message count (total) + uint32(VarIntSerializeSize(MixMsgSize)) + // Message size + MaxMixMcount*MaxMixMcount*MixMsgSize + // Padded DC-net values + uint32(VarIntSerializeSize(MaxMixPeers)) + // Slot reserve count + 32*MaxMixPeers // Slot reserve hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := dc.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +} diff --git a/wire/msgmixkeyexchange.go b/wire/msgmixkeyexchange.go new file mode 100644 index 0000000000..e7bfa4ab8d --- /dev/null +++ b/wire/msgmixkeyexchange.go @@ -0,0 +1,234 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "hash" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +const ( + // MaxMixPeers is the maximum number of peers allowed together in a + // single mix. This restricts the maximum dimensions of the slot + // reservation and XOR DC-net matrices and the maximum number of + // previous messages that may be referenced by mix messages. + // This value is an high estimate of what a large mix may look like, + // based on statistics from the centralized mixing server. + MaxMixPeers = 512 +) + +// MsgMixKeyExchange implements the Message interface and represents a mixing key +// exchange message. It includes a commitment to secrets (private keys and +// discarded mixed addresses) in case they must be revealed for blame +// assignment. +type MsgMixKeyExchange struct { + Signature [64]byte + Identity [33]byte + SessionID [32]byte + Run uint32 + ECDH [33]byte // Secp256k1 public key + PQPK [1218]byte // Sntrup4591761 public key + Commitment [32]byte + SeenPRs []chainhash.Hash + + // hash records the hash of the message. It is a member of the + // message for convenience and performance, but is never automatically + // set during creation or deserialization. + hash chainhash.Hash +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgMixKeyExchange) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgMixKeyExchange.BtcDecode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := readElements(r, &msg.Signature, &msg.Identity, &msg.SessionID, + &msg.Run, &msg.ECDH, &msg.PQPK, &msg.Commitment) + if err != nil { + return err + } + + count, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + seen := make([]chainhash.Hash, count) + for i := range seen { + err := readElement(r, &seen[i]) + if err != nil { + return err + } + } + msg.SeenPRs = seen + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgMixKeyExchange) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgMixKeyExchange.BtcEncode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := writeElement(w, &msg.Signature) + if err != nil { + return err + } + + err = msg.writeMessageNoSignature(op, w, pver) + if err != nil { + return err + } + + return nil +} + +// Hash returns the message hash calculated by WriteHash. +// +// Hash returns an invalid or zero hash if WriteHash has not been called yet. +// +// This method is not safe while concurrently calling WriteHash. +func (msg *MsgMixKeyExchange) Hash() chainhash.Hash { + return msg.hash +} + +// WriteHash serializes the message to a hasher and records the sum in the +// message's Hash field. +// +// The hasher's Size() must equal chainhash.HashSize, or this method will +// panic. This method is designed to work only with hashers returned by +// blake256.New. +func (msg *MsgMixKeyExchange) WriteHash(h hash.Hash) { + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + sum := h.Sum(msg.hash[:0]) + if len(sum) != len(msg.hash) { + s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) + panic(s) + } +} + +// writeMessageNoSignature serializes all elements of the message except for +// the signature. This allows code reuse between message serialization, and +// signing and verifying these message contents. +// +// If w implements hash.Hash, no errors will be returned for invalid message +// construction. +func (msg *MsgMixKeyExchange) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + _, hashing := w.(hash.Hash) + + // Limit to max previous messages hashes. + count := len(msg.SeenPRs) + if !hashing && count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run, + &msg.ECDH, &msg.PQPK, &msg.Commitment) + if err != nil { + return err + } + + err = WriteVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for i := range msg.SeenPRs { + err := writeElement(w, &msg.SeenPRs[i]) + if err != nil { + return err + } + } + + return nil +} + +// WriteSignedData writes a tag identifying the message data, followed by all +// message fields excluding the signature. This is the data committed to when +// the message is signed. +func (msg *MsgMixKeyExchange) WriteSignedData(h hash.Hash) { + WriteVarString(h, MixVersion, CmdMixKeyExchange+"-sig") + msg.writeMessageNoSignature("", h, MixVersion) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgMixKeyExchange) Command() string { + return CmdMixKeyExchange +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgMixKeyExchange) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 17803 +} + +// Pub returns the message sender's public key identity. +func (msg *MsgMixKeyExchange) Pub() []byte { + return msg.Identity[:] +} + +// Sig returns the message signature. +func (msg *MsgMixKeyExchange) Sig() []byte { + return msg.Signature[:] +} + +// PrevMsgs returns the previous PR messages seen by the peer. +func (msg *MsgMixKeyExchange) PrevMsgs() []chainhash.Hash { + return msg.SeenPRs +} + +// Sid returns the session ID. +func (msg *MsgMixKeyExchange) Sid() []byte { + return msg.SessionID[:] +} + +// GetRun returns the run number. +func (msg *MsgMixKeyExchange) GetRun() uint32 { + return msg.Run +} + +// NewMsgMixKeyExchange returns a new mixkeyxchg message that conforms to the +// Message interface using the passed parameters and defaults for the +// remaining fields. +func NewMsgMixKeyExchange(identity [33]byte, sid [32]byte, run uint32, + ecdh [33]byte, pqpk [1218]byte, commitment [32]byte, seenPRs []chainhash.Hash) *MsgMixKeyExchange { + + return &MsgMixKeyExchange{ + Identity: identity, + SessionID: sid, + Run: run, + ECDH: ecdh, + PQPK: pqpk, + Commitment: commitment, + SeenPRs: seenPRs, + } +} diff --git a/wire/msgmixkeyexchange_test.go b/wire/msgmixkeyexchange_test.go new file mode 100644 index 0000000000..e5f83a32f5 --- /dev/null +++ b/wire/msgmixkeyexchange_test.go @@ -0,0 +1,186 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +func newTestMixKeyExchange() *MsgMixKeyExchange { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) + + const run = uint32(0x83838383) + + ecdh := *(*[33]byte)(repeat(0x84, 33)) + pqpk := *(*[1218]byte)(repeat(0x85, 1218)) + commitment := *(*[32]byte)(repeat(0x86, 32)) + + seenPRs := make([]chainhash.Hash, 4) + for b := byte(0x87); b < 0x8B; b++ { + copy(seenPRs[b-0x87][:], repeat(b, 32)) + } + + ke := NewMsgMixKeyExchange(id, sid, run, ecdh, pqpk, commitment, seenPRs) + ke.Signature = sig + + return ke +} + +func TestMsgMixKeyExchangeWire(t *testing.T) { + pver := MixVersion + + ke := newTestMixKeyExchange() + + buf := new(bytes.Buffer) + err := ke.BtcEncode(buf, pver) + if err != nil { + t.Fatal(err) + } + + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Run + expected = append(expected, repeat(0x84, 33)...) // ECDH public key + expected = append(expected, repeat(0x85, 1218)...) // PQ public key + expected = append(expected, repeat(0x86, 32)...) // Secrets commitment + // Four seen PRs (repeating 32 bytes of 0x87, 0x88, 0x89, 0x8a) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x87, 32)...) + expected = append(expected, repeat(0x88, 32)...) + expected = append(expected, repeat(0x89, 32)...) + expected = append(expected, repeat(0x8a, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + + decodedKE := new(MsgMixKeyExchange) + err = decodedKE.BtcDecode(bytes.NewReader(buf.Bytes()), pver) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(ke, decodedKE) { + t.Errorf("BtcDecode got: %s want: %s", + spew.Sdump(decodedKE), spew.Sdump(ke)) + } else { + t.Logf("bytes: %x", buf.Bytes()) + t.Logf("spew: %s", spew.Sdump(decodedKE)) + } +} + +func TestMsgMixKeyExchangeCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixKeyExchange() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixKeyExchange) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixKeyExchangeMaxPayloadLength tests the results returned by +// [MsgMixKeyExchange.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixKeyExchangeMaxPayloadLength(t *testing.T) { + var ke *MsgMixKeyExchange + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := ke.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Run + 33 + // ECDH public key + 1218 + // sntrup4591761 public key + 32 + // Secrets commitment + uint32(VarIntSerializeSize(MaxMixPeers)) + // Pair request count + 32*MaxMixPeers // Pair request hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := ke.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +} diff --git a/wire/msgmixpairreq.go b/wire/msgmixpairreq.go new file mode 100644 index 0000000000..4b6c74e98f --- /dev/null +++ b/wire/msgmixpairreq.go @@ -0,0 +1,455 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "fmt" + "hash" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +const ( + // MaxMixPairReqScriptClassLen is the maximum length allowable for a + // MsgMixPairReq script class. + MaxMixPairReqScriptClassLen = 32 + + // MaxMixPairReqUTXOs is the maximum number of unspent transaction + // outputs that may be contributed in a single mixpairreq message. + // This value is an high estimate of what a large mix may look like, + // based on statistics from the centralized mixing server. + MaxMixPairReqUTXOs = 512 + + // MaxMixPairReqUTXOScriptLen is the maximum length allowed for the + // unhashed P2SH script of a UTXO ownership proof. + MaxMixPairReqUTXOScriptLen = 16384 // txscript.MaxScriptSize + + // MaxMixPairReqUTXOPubKeyLen is the maximum length allowed for the + // pubkey of a UTXO ownership proof. + MaxMixPairReqUTXOPubKeyLen = 33 + + // MaxMixPairReqUTXOSignatureLen is the maximum length allowed for the + // signature of a UTXO ownership proof. + MaxMixPairReqUTXOSignatureLen = 64 +) + +// MixPairReqUTXO describes an unspent transaction output to be spent in a +// mix. It includes a proof that the output is able to be spent, by +// containing a signature and the necessary data needed to prove key +// possession. +type MixPairReqUTXO struct { + OutPoint OutPoint + Script []byte // Only used for P2SH + PubKey []byte + Signature []byte +} + +// MsgMixPairReq implements the Message interface and represents a mixing pair +// request message. It describes a type of coinjoin to be created, unmixed +// data being contributed to the coinjoin, and proof of ability to sign the +// resulting coinjoin. +type MsgMixPairReq struct { + Signature [64]byte + Identity [33]byte + Expiry uint32 + MixAmount int64 + ScriptClass string + TxVersion uint16 + LockTime uint32 + MessageCount uint32 + InputValue int64 + UTXOs []MixPairReqUTXO + Change *TxOut + + // hash records the hash of the message. It is a member of the + // message for convenience and performance, but is never automatically + // set during creation or deserialization. + hash chainhash.Hash +} + +// Pairing returns a description of the coinjoin transaction being created. +// Different mixpairreq messages are compatible to perform a mix together if +// their pairing descriptions are identical. +func (msg *MsgMixPairReq) Pairing() ([]byte, error) { + bufLen := 8 + // Mix amount + VarIntSerializeSize(uint64(len(msg.ScriptClass))) + // Script class + len(msg.ScriptClass) + + 2 + // Tx version + 4 // Locktime + w := bytes.NewBuffer(make([]byte, 0, bufLen)) + + err := writeElement(w, msg.MixAmount) + if err != nil { + return nil, err + } + + err = WriteVarString(w, MixVersion, msg.ScriptClass) + if err != nil { + return nil, err + } + + err = writeElements(w, msg.TxVersion, msg.LockTime) + if err != nil { + return nil, err + } + + return w.Bytes(), nil +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgMixPairReq) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgMixPairReq.BtcDecode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := readElements(r, &msg.Signature, &msg.Identity, &msg.Expiry, + &msg.MixAmount) + if err != nil { + return err + } + + if msg.MixAmount < 0 { + msg := "mixing pair request contains negative mixed amount" + return messageError(op, ErrInvalidMsg, msg) + } + + sc, err := ReadAsciiVarString(r, pver, MaxMixPairReqScriptClassLen) + if err != nil { + return err + } + msg.ScriptClass = sc + + err = readElements(r, &msg.TxVersion, &msg.LockTime, + &msg.MessageCount, &msg.InputValue) + if err != nil { + return err + } + + if msg.InputValue < 0 { + msg := "mixing pair request contains negative input value" + return messageError(op, ErrInvalidMsg, msg) + } + + count, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixPairReqUTXOs { + msg := fmt.Sprintf("too many UTXOs in message [count %v, max %v]", + count, MaxMixPairReqUTXOs) + return messageError(op, ErrTooManyMixPairReqUTXOs, msg) + } + utxos := make([]MixPairReqUTXO, count) + for i := range utxos { + utxo := &utxos[i] + + err := ReadOutPoint(r, pver, msg.TxVersion, &utxo.OutPoint) + if err != nil { + return err + } + + script, err := ReadVarBytes(r, pver, MaxMixPairReqUTXOScriptLen, + "MixPairReqUTXO.Script") + if err != nil { + return err + } + utxo.Script = script + + pubkey, err := ReadVarBytes(r, pver, MaxMixPairReqUTXOPubKeyLen, + "MixPairReqUTXO.PubKey") + if err != nil { + return err + } + utxo.PubKey = pubkey + + sig, err := ReadVarBytes(r, pver, MaxMixPairReqUTXOSignatureLen, + "MixPairReqUTXO.Signature") + if err != nil { + return err + } + utxo.Signature = sig + } + msg.UTXOs = utxos + + var hasChange uint8 + err = readElement(r, &hasChange) + if err != nil { + return err + } + switch hasChange { + case 0: + case 1: + change := new(TxOut) + err := readTxOut(r, pver, msg.TxVersion, change) + if err != nil { + return err + } + msg.Change = change + default: + msg := "invalid change TxOut encoding" + return messageError(op, ErrInvalidMsg, msg) + } + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgMixPairReq) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgMixPairReq.BtcEncode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := writeElement(w, &msg.Signature) + if err != nil { + return err + } + + err = msg.writeMessageNoSignature(op, w, pver) + if err != nil { + return err + } + + return nil +} + +// Hash returns the message hash calculated by WriteHash. +// +// Hash returns an invalid or zero hash if WriteHash has not been called yet. +// +// This method is not safe while concurrently calling WriteHash. +func (msg *MsgMixPairReq) Hash() chainhash.Hash { + return msg.hash +} + +// WriteHash serializes the message to a hasher and records the sum in the +// message's Hash field. +// +// The hasher's Size() must equal chainhash.HashSize, or this method will +// panic. This method is designed to work only with hashers returned by +// blake256.New. +func (msg *MsgMixPairReq) WriteHash(h hash.Hash) { + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + sum := h.Sum(msg.hash[:0]) + if len(sum) != len(msg.hash) { + s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) + panic(s) + } +} + +// writeMessageNoSignature serializes all elements of the message except for +// the signature. This allows code reuse between message serialization, and +// signing and verifying these message contents. +// +// If w implements hash.Hash, no errors will be returned for invalid message +// construction. +func (msg *MsgMixPairReq) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + _, hashing := w.(hash.Hash) + + // Require script class to be strict ASCII and not exceed the maximum + // length. + lenScriptClass := len(msg.ScriptClass) + if lenScriptClass > MaxMixPairReqScriptClassLen { + msg := fmt.Sprintf("script class length is too long "+ + "[len %d, max %d]", lenScriptClass, + MaxMixPairReqScriptClassLen) + return messageError(op, ErrMixPairReqScriptClassTooLong, msg) + } + if !isStrictAscii(msg.ScriptClass) { + msg := "script class string is not strict ASCII" + return messageError(op, ErrMalformedStrictString, msg) + } + + // Limit to max UTXOs per message. + count := len(msg.UTXOs) + if !hashing && count > MaxMixPairReqUTXOs { + msg := fmt.Sprintf("too many UTXOs in message [%v]", count) + return messageError(op, ErrTooManyMixPairReqUTXOs, msg) + } + + err := writeElements(w, &msg.Identity, msg.Expiry, msg.MixAmount) + if err != nil { + return err + } + + err = WriteVarString(w, pver, msg.ScriptClass) + if err != nil { + return err + } + + err = writeElements(w, msg.TxVersion, msg.LockTime, msg.MessageCount, + msg.InputValue) + if err != nil { + return err + } + + err = WriteVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for i := range msg.UTXOs { + utxo := &msg.UTXOs[i] + + err := WriteOutPoint(w, pver, msg.TxVersion, &utxo.OutPoint) + if err != nil { + return err + } + + if l := len(utxo.Script); !hashing && l > MaxMixPairReqUTXOScriptLen { + msg := fmt.Sprintf("UTXO script is too long [len %v, max %v]", + l, MaxMixPairReqUTXOScriptLen) + return messageError(op, ErrVarBytesTooLong, msg) + } + err = WriteVarBytes(w, pver, utxo.Script) + if err != nil { + return err + } + + if l := len(utxo.PubKey); !hashing && l > MaxMixPairReqUTXOPubKeyLen { + msg := fmt.Sprintf("UTXO public key is too long [len %v, max %v]", + l, MaxMixPairReqUTXOPubKeyLen) + return messageError(op, ErrVarBytesTooLong, msg) + } + err = WriteVarBytes(w, pver, utxo.PubKey) + if err != nil { + return err + } + + if l := len(utxo.Signature); !hashing && l > MaxMixPairReqUTXOSignatureLen { + msg := fmt.Sprintf("UTXO signature is too long [len %v, max %v]", + l, MaxMixPairReqUTXOSignatureLen) + return messageError(op, ErrVarBytesTooLong, msg) + } + err = WriteVarBytes(w, pver, utxo.Signature) + if err != nil { + return err + } + } + + var hasChange uint8 + if msg.Change != nil { + hasChange = 1 + } + err = writeElement(w, hasChange) + if err != nil { + return err + } + if msg.Change != nil { + err = writeTxOut(w, pver, msg.TxVersion, msg.Change) + if err != nil { + return err + } + } + + return nil +} + +// WriteSignedData writes a tag identifying the message data, followed by all +// message fields excluding the signature. This is the data committed to when +// the message is signed. +func (msg *MsgMixPairReq) WriteSignedData(h hash.Hash) { + WriteVarString(h, MixVersion, CmdMixPairReq+"-sig") + msg.writeMessageNoSignature("", h, MixVersion) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgMixPairReq) Command() string { + return CmdMixPairReq +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgMixPairReq) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 8476336 +} + +// Pub returns the message sender's public key identity. +func (msg *MsgMixPairReq) Pub() []byte { + return msg.Identity[:] +} + +// Sig returns the message signature. +func (msg *MsgMixPairReq) Sig() []byte { + return msg.Signature[:] +} + +// Expires returns the block height at which the message expires. +func (msg *MsgMixPairReq) Expires() uint32 { + return msg.Expiry +} + +// PrevMsgs always returns nil. Pair request messages are the initial message. +func (msg *MsgMixPairReq) PrevMsgs() []chainhash.Hash { + return nil +} + +// Sid always returns nil. Pair request messages do not belong to a session. +func (msg *MsgMixPairReq) Sid() []byte { + return nil +} + +// GetRun always returns 0. Pair request messages do not belong to a session. +func (msg *MsgMixPairReq) GetRun() uint32 { + return 0 +} + +// NewMsgMixPairReq returns a new mixpairreq message that conforms to the +// Message interface using the passed parameters and defaults for the +// remaining fields. +func NewMsgMixPairReq(identity [33]byte, expiry uint32, mixAmount int64, + scriptClass string, txVersion uint16, lockTime, messageCount uint32, + inputValue int64, utxos []MixPairReqUTXO, change *TxOut) (*MsgMixPairReq, error) { + + const op = "NewMsgMixPairReq" + lenScriptClass := len(scriptClass) + if lenScriptClass > MaxMixPairReqScriptClassLen { + msg := fmt.Sprintf("script class length is too long "+ + "[len %d, max %d]", lenScriptClass, + MaxMixPairReqScriptClassLen) + return nil, messageError(op, ErrMixPairReqScriptClassTooLong, msg) + } + + if !isStrictAscii(scriptClass) { + msg := "script class string is not strict ASCII" + return nil, messageError(op, ErrMalformedStrictString, msg) + } + + if len(utxos) > MaxMixPairReqUTXOs { + msg := fmt.Sprintf("too many input UTXOs [len %d, max %d]", + len(utxos), MaxMixPairReqUTXOs) + return nil, messageError(op, ErrTooManyMixPairReqUTXOs, msg) + } + + msg := &MsgMixPairReq{ + Identity: identity, + Expiry: expiry, + MixAmount: mixAmount, + ScriptClass: scriptClass, + TxVersion: txVersion, + LockTime: lockTime, + MessageCount: messageCount, + InputValue: inputValue, + UTXOs: utxos, + Change: change, + } + return msg, nil +} diff --git a/wire/msgmixpairreq_test.go b/wire/msgmixpairreq_test.go new file mode 100644 index 0000000000..36ac28f7cf --- /dev/null +++ b/wire/msgmixpairreq_test.go @@ -0,0 +1,326 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +type mixPairReqArgs struct { + identity [33]byte + signature [64]byte + expiry uint32 + mixAmount int64 + scriptClass string + txVersion uint16 + lockTime, messageCount uint32 + inputValue int64 + utxos []MixPairReqUTXO + change *TxOut +} + +func (a *mixPairReqArgs) msg() (*MsgMixPairReq, error) { + return NewMsgMixPairReq(a.identity, a.expiry, a.mixAmount, a.scriptClass, + a.txVersion, a.lockTime, a.messageCount, a.inputValue, a.utxos, a.change) +} + +func newMixPairReqArgs() *mixPairReqArgs { + // Use easily-distinguishable fields. + + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + + const expiry = uint32(0x82828282) + const mixAmount = int64(0x0383838383838383) + const sc = "P2PKH-secp256k1-v0" + const txVersion = uint16(0x8484) + const lockTime = uint32(0x85858585) + const messageCount = uint32(0x86868686) + const inputValue = int64(0x0787878787878787) + + utxos := []MixPairReqUTXO{ + { + OutPoint: OutPoint{ + Hash: rhash(0x88), + Index: 0x89898989, + Tree: 0x0A, + }, + Script: []byte{}, + PubKey: repeat(0x8B, 33), + Signature: repeat(0x8C, 64), + }, + { + OutPoint: OutPoint{ + Hash: rhash(0x8D), + Index: 0x8E8E8E8E, + Tree: 0x0F, + }, + Script: repeat(0x90, 25), + PubKey: repeat(0x91, 33), + Signature: repeat(0x92, 64), + }, + } + + const changeValue = int64(0x1393939393939393) + pkScript := repeat(0x94, 25) + change := NewTxOut(changeValue, pkScript) + + return &mixPairReqArgs{ + identity: id, + signature: sig, + expiry: expiry, + mixAmount: mixAmount, + scriptClass: sc, + txVersion: txVersion, + lockTime: lockTime, + messageCount: messageCount, + inputValue: inputValue, + utxos: utxos, + change: change, + } +} + +func TestMsgMixPairReqWire(t *testing.T) { + t.Parallel() + + pver := MixVersion + + a := newMixPairReqArgs() + pr, err := a.msg() + if err != nil { + t.Fatal(err) + } + pr.Signature = a.signature + + buf := new(bytes.Buffer) + err = pr.BtcEncode(buf, pver) + if err != nil { + t.Fatal(err) + } + + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 4)...) // Expiry + expected = append(expected, 0x83, 0x83, 0x83, 0x83, // Amount + 0x83, 0x83, 0x83, 0x03) + expected = append(expected, byte(len("P2PKH-secp256k1-v0"))) // Script class + expected = append(expected, []byte("P2PKH-secp256k1-v0")...) + expected = append(expected, 0x84, 0x84) // Tx version + expected = append(expected, repeat(0x85, 4)...) // Locktime + expected = append(expected, repeat(0x86, 4)...) // Message count + expected = append(expected, 0x87, 0x87, 0x87, 0x87, // Input value + 0x87, 0x87, 0x87, 0x07) + expected = append(expected, 0x02) // UTXO count + // First UTXO 8888888888888888888888888888888888888888888888888888888888888888:0x89898989 + expected = append(expected, repeat(0x88, 32)...) // Hash + expected = append(expected, repeat(0x89, 4)...) // Index + expected = append(expected, 0x0a) // Tree + expected = append(expected, 0x00) // Zero-length P2SH redeem script + expected = append(expected, 0x21) // 33-byte pubkey + expected = append(expected, repeat(0x8b, 33)...) + expected = append(expected, 0x40) // 64-byte signature + expected = append(expected, repeat(0x8c, 64)...) + // Second UTXO 8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d:0x8e8e8e8e + expected = append(expected, repeat(0x8d, 32)...) // Hash + expected = append(expected, repeat(0x8e, 4)...) // Index + expected = append(expected, 0x0f) // Tree + expected = append(expected, 0x19) // 25-byte P2SH redeem script + expected = append(expected, repeat(0x90, 25)...) + expected = append(expected, 0x21) // 33-byte pubkey + expected = append(expected, repeat(0x91, 33)...) + expected = append(expected, 0x40) // 64-byte signature + expected = append(expected, repeat(0x92, 64)...) + // Change output + expected = append(expected, 0x01) // Has change = true + expected = append(expected, []byte{ + 0x93, 0x93, 0x93, 0x93, 0x93, 0x93, 0x93, 0x13, // Amount + 0x00, 0x00, // Version + 0x19, // 25-byte Pkscript + 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, + 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, + 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, + 0x94, + }...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + + decodedPR := new(MsgMixPairReq) + err = decodedPR.BtcDecode(bytes.NewReader(buf.Bytes()), pver) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(pr, decodedPR) { + t.Errorf("BtcDecode got: %s want: %s", + spew.Sdump(decodedPR), spew.Sdump(pr)) + } +} + +func TestNewMixPairReqErrs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + modArgs func(*mixPairReqArgs) + err error + }{{ + name: "LongScriptClass", + modArgs: func(a *mixPairReqArgs) { + a.scriptClass = "scriptclassthatexceedsmaximumlength" + }, + err: ErrMixPairReqScriptClassTooLong, + }, { + name: "NonAsciiScriptClass", + modArgs: func(a *mixPairReqArgs) { + a.scriptClass = string([]byte{128}) + }, + err: ErrMalformedStrictString, + }, { + name: "TooManyUTXOs", + modArgs: func(a *mixPairReqArgs) { + a.utxos = make([]MixPairReqUTXO, MaxMixPairReqUTXOs+1) + }, + err: ErrTooManyMixPairReqUTXOs, + }} + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + a := newMixPairReqArgs() + tc.modArgs(a) + _, err := a.msg() + if !errors.Is(err, tc.err) { + t.Errorf("expected error %v; got %v", tc.err, err) + } + }) + } +} + +func TestMsgMixPairReqCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + a := newMixPairReqArgs() + msg, err := a.msg() + if err != nil { + t.Fatalf("%v", err) + } + + buf := new(bytes.Buffer) + err = msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixPairReq) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixPairReqMaxPayloadLength tests the results returned by +// [MsgMixPairReq.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixPairReqMaxPayloadLength(t *testing.T) { + var pr *MsgMixPairReq + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := pr.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var maxUTXOLen uint32 = 32 + // Hash + 4 + // Index + 1 + // Tree + varBytesLen(MaxMixPairReqUTXOScriptLen) + // P2SH redeem script + varBytesLen(33) + // Pubkey + varBytesLen(64) // Signature + var maxTxOutLen uint32 = 8 + // Value + 2 + // Version + varBytesLen(16384) // PkScript (txscript.MaxScriptLen) + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 4 + // Expiry + 8 + // Amount + varBytesLen(MaxMixPairReqScriptClassLen) + // Script class + 2 + // Tx version + 4 + // Locktime + 4 + // Message count + 8 + // Input value + uint32(VarIntSerializeSize(MaxMixPairReqUTXOs)) + // UTXO count + MaxMixPairReqUTXOs*maxUTXOLen + // UTXOs + maxTxOutLen // Change output + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := pr.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +} diff --git a/wire/msgmixsecrets.go b/wire/msgmixsecrets.go new file mode 100644 index 0000000000..53d0b06e2d --- /dev/null +++ b/wire/msgmixsecrets.go @@ -0,0 +1,291 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "hash" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +// MsgMixSecrets reveals secrets of a failed mix run. After secrets are +// exposed, peers can determine which peers (if any) misbehaved and remove +// them from the next run in the session. +// +// It implements the Message interface. +type MsgMixSecrets struct { + Signature [64]byte + Identity [33]byte + SessionID [32]byte + Run uint32 + Seed [32]byte + SlotReserveMsgs [][]byte + DCNetMsgs MixVect + + // hash records the hash of the message. It is a member of the + // message for convenience and performance, but is never automatically + // set during creation or deserialization. + hash chainhash.Hash +} + +func writeMixVect(op string, w io.Writer, pver uint32, vec MixVect) error { + if len(vec) > MaxMixMcount { + msg := "DC-net mix vector dimensions are too large for maximum message count" + return messageError(op, ErrInvalidMsg, msg) + } + + // Write dimensions + err := WriteVarInt(w, pver, uint64(len(vec))) + if err != nil { + return err + } + err = WriteVarInt(w, pver, MixMsgSize) + if err != nil { + return err + } + + // Write messages + for i := range vec { + err = writeElement(w, &vec[i]) + if err != nil { + return err + } + } + + return nil +} + +func readMixVect(op string, r io.Reader, pver uint32, vec *MixVect) error { + // Read dimensions + n, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if n == 0 { + *vec = MixVect{} + return nil + } + msize, err := ReadVarInt(r, pver) + if err != nil { + return err + } + + if n > MaxMixMcount { + msg := "DC-net mix vector dimensions are too large for maximum message count" + return messageError(op, ErrInvalidMsg, msg) + } + if msize != MixMsgSize { + msg := fmt.Sprintf("mixed message length must be %d [got: %d]", + MixMsgSize, msize) + return messageError(op, ErrInvalidMsg, msg) + } + + // Read messages + *vec = make(MixVect, n) + for i := uint64(0); i < n; i++ { + err = readElement(r, &(*vec)[i]) + if err != nil { + return err + } + } + + return nil +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgMixSecrets) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgMixSecrets.BtcDecode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := readElements(r, &msg.Signature, &msg.Identity, &msg.SessionID, + &msg.Run, &msg.Seed) + if err != nil { + return err + } + + numSRs, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if numSRs > MaxMixMcount { + msg := fmt.Sprintf("too many total mixed messages [count %v, max %v]", + numSRs, MaxMixMcount) + return messageError(op, ErrInvalidMsg, msg) + } + msg.SlotReserveMsgs = make([][]byte, numSRs) + for i := uint64(0); i < numSRs; i++ { + sr, err := ReadVarBytes(r, pver, MaxMixFieldValLen, + "slot reservation mixed message") + if err != nil { + return err + } + msg.SlotReserveMsgs[i] = sr + } + + var dcnetMsgs MixVect + err = readMixVect(op, r, pver, &dcnetMsgs) + if err != nil { + return err + } + msg.DCNetMsgs = dcnetMsgs + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgMixSecrets) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgMixSecrets.BtcEncode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := writeElement(w, &msg.Signature) + if err != nil { + return err + } + + err = msg.writeMessageNoSignature(op, w, pver) + if err != nil { + return err + } + + return nil +} + +// Hash returns the message hash calculated by WriteHash. +// +// Hash returns an invalid or zero hash if WriteHash has not been called yet. +// +// This method is not safe while concurrently calling WriteHash. +func (msg *MsgMixSecrets) Hash() chainhash.Hash { + return msg.hash +} + +// WriteHash serializes the message to a hasher and records the sum in the +// message's Hash field. +// +// The hasher's Size() must equal chainhash.HashSize, or this method will +// panic. This method is designed to work only with hashers returned by +// blake256.New. +func (msg *MsgMixSecrets) WriteHash(h hash.Hash) { + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + sum := h.Sum(msg.hash[:0]) + if len(sum) != len(msg.hash) { + s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) + panic(s) + } +} + +// writeMessageNoSignature serializes all elements of the message except for +// the signature. This allows code reuse between message serialization, and +// signing and verifying these message contents. +// +// This method never errors for invalid message construction. +func (msg *MsgMixSecrets) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run, &msg.Seed) + if err != nil { + return err + } + + err = WriteVarInt(w, pver, uint64(len(msg.SlotReserveMsgs))) + if err != nil { + return err + } + for _, sr := range msg.SlotReserveMsgs { + err := WriteVarBytes(w, pver, sr) + if err != nil { + return err + } + } + + err = writeMixVect(op, w, pver, msg.DCNetMsgs) + if err != nil { + return err + } + + return nil +} + +// WriteSignedData writes a tag identifying the message data, followed by all +// message fields excluding the signature. This is the data committed to when +// the message is signed. +func (msg *MsgMixSecrets) WriteSignedData(h hash.Hash) { + WriteVarString(h, MixVersion, CmdMixSecrets+"-sig") + msg.writeMessageNoSignature("", h, MixVersion) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgMixSecrets) Command() string { + return CmdMixSecrets +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgMixSecrets) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation + return 54444 +} + +// Pub returns the message sender's public key identity. +func (msg *MsgMixSecrets) Pub() []byte { + return msg.Identity[:] +} + +// Sig returns the message signature. +func (msg *MsgMixSecrets) Sig() []byte { + return msg.Signature[:] +} + +// PrevMsgs returns nil. Previous messages are not needed to perform blame +// assignment, because of the assumption that all previous messages must have +// been received for a blame stage to be necessary. Additionally, a +// commitment to the secrets message is included in the key exchange, and +// future message hashes are not available at that time. +func (msg *MsgMixSecrets) PrevMsgs() []chainhash.Hash { + return nil +} + +// Sid returns the session ID. +func (msg *MsgMixSecrets) Sid() []byte { + return msg.SessionID[:] +} + +// GetRun returns the run number. +func (msg *MsgMixSecrets) GetRun() uint32 { + return msg.Run +} + +// NewMsgMixSecrets returns a new mixsecrets message that conforms to the +// Message interface using the passed parameters and defaults for the +// remaining fields. +func NewMsgMixSecrets(identity [33]byte, sid [32]byte, run uint32, + seed [32]byte, slotReserveMsgs [][]byte, dcNetMsgs MixVect) *MsgMixSecrets { + + return &MsgMixSecrets{ + Identity: identity, + SessionID: sid, + Run: run, + Seed: seed, + SlotReserveMsgs: slotReserveMsgs, + DCNetMsgs: dcNetMsgs, + } +} diff --git a/wire/msgmixsecrets_test.go b/wire/msgmixsecrets_test.go new file mode 100644 index 0000000000..2905d94eb4 --- /dev/null +++ b/wire/msgmixsecrets_test.go @@ -0,0 +1,194 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func newTestMixSecrets() *MsgMixSecrets { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) + + const run = uint32(0x83838383) + + seed := *(*[32]byte)(repeat(0x84, 32)) + + sr := make([][]byte, 4) + for b := byte(0x85); b < 0x89; b++ { + sr[b-0x85] = repeat(b, 32) + } + + m := make(MixVect, 4) + for b := byte(0x89); b < 0x8D; b++ { + copy(m[b-0x89][:], repeat(b, 20)) + } + + rs := NewMsgMixSecrets(id, sid, run, seed, sr, m) + rs.Signature = sig + + return rs +} + +func TestMsgMixSecretsWire(t *testing.T) { + pver := MixVersion + + rs := newTestMixSecrets() + + buf := new(bytes.Buffer) + err := rs.BtcEncode(buf, pver) + if err != nil { + t.Fatal(err) + } + + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Run + expected = append(expected, repeat(0x84, 32)...) // Seed + // Four slot reservation mixed messages (repeating 32 bytes of 0x85, 0x86, 0x87, 0x88) + expected = append(expected, 0x04) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x85, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x86, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x87, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x88, 32)...) + // Four slot reservation mixed messages (repeating 20 bytes of 0x89, 0x8a, 0x8b, 0x8c) + expected = append(expected, 0x04, 0x14) + expected = append(expected, repeat(0x89, 20)...) + expected = append(expected, repeat(0x8a, 20)...) + expected = append(expected, repeat(0x8b, 20)...) + expected = append(expected, repeat(0x8c, 20)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + + decodedRS := new(MsgMixSecrets) + err = decodedRS.BtcDecode(bytes.NewReader(buf.Bytes()), pver) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(rs, decodedRS) { + t.Errorf("BtcDecode got: %s want: %s", + spew.Sdump(decodedRS), spew.Sdump(rs)) + } +} + +func TestMsgMixSecretsCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixSecrets() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixSecrets) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixSecretsMaxPayloadLength tests the results returned by +// [MsgMixSecrets.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixSecretsMaxPayloadLength(t *testing.T) { + var rs *MsgMixSecrets + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := rs.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Run + 32 + // Seed + uint32(VarIntSerializeSize(MaxMixMcount)) + // SR message count + MaxMixMcount*varBytesLen(MaxMixFieldValLen) + // Unpadded SR values + uint32(VarIntSerializeSize(MaxMixMcount)) + // DC-net message count + uint32(VarIntSerializeSize(MixMsgSize)) + // DC-net message size + MaxMixMcount*MixMsgSize // DC-net messages + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := rs.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +} diff --git a/wire/msgmixslotreserve.go b/wire/msgmixslotreserve.go new file mode 100644 index 0000000000..3c69fbf572 --- /dev/null +++ b/wire/msgmixslotreserve.go @@ -0,0 +1,314 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "fmt" + "hash" + "io" + + "github.com/decred/dcrd/chaincfg/chainhash" +) + +const ( + // MaxMixMcount is the maximum number of mixed messages that are allowed + // in a single mix. This restricts the total allowed size of the slot + // reservation and XOR DC-net matrices. + // This value is an high estimate of what a large mix may look like, + // based on statistics from the centralized mixing server. + MaxMixMcount = 1024 + + // MaxMixFieldValLen is the maximum number of bytes allowed to represent + // a value in the slot reservation mix bounded by the field prime. + MaxMixFieldValLen = 32 +) + +// MsgMixSlotReserve is the slot reservation broadcast. It implements the Message +// interface. +type MsgMixSlotReserve struct { + Signature [64]byte + Identity [33]byte + SessionID [32]byte + Run uint32 + DCMix [][][]byte // mcount-by-peers matrix of field numbers + SeenCiphertexts []chainhash.Hash + + // hash records the hash of the message. It is a member of the + // message for convenience and performance, but is never automatically + // set during creation or deserialization. + hash chainhash.Hash +} + +// BtcDecode decodes r using the Decred protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgMixSlotReserve) BtcDecode(r io.Reader, pver uint32) error { + const op = "MsgMixSlotReserve.BtcDecode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := readElements(r, &msg.Signature, &msg.Identity, &msg.SessionID, + &msg.Run) + if err != nil { + return err + } + + // Read the DCMix + mcount, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if mcount == 0 { + msg := fmt.Sprintf("too few mixed messages [%v]", mcount) + return messageError(op, ErrInvalidMsg, msg) + } + if mcount > MaxMixMcount { + msg := fmt.Sprintf("too many total mixed messages [%v]", mcount) + return messageError(op, ErrInvalidMsg, msg) + } + kpcount, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if kpcount == 0 { + msg := fmt.Sprintf("too few mixing peers [%v]", kpcount) + return messageError(op, ErrInvalidMsg, msg) + } + if kpcount > MaxMixPeers { + msg := fmt.Sprintf("too many mixing peers [count %v, max %v]", + kpcount, MaxMixPeers) + return messageError(op, ErrInvalidMsg, msg) + } + dcmix := make([][][]byte, mcount) + for i := range dcmix { + dcmix[i] = make([][]byte, kpcount) + for j := range dcmix[i] { + v, err := ReadVarBytes(r, pver, MaxMixFieldValLen, + "slot reservation field value") + if err != nil { + return err + } + dcmix[i][j] = v + } + } + msg.DCMix = dcmix + + count, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + seen := make([]chainhash.Hash, count) + for i := range seen { + err := readElement(r, &seen[i]) + if err != nil { + return err + } + } + msg.SeenCiphertexts = seen + + return nil +} + +// BtcEncode encodes the receiver to w using the Decred protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgMixSlotReserve) BtcEncode(w io.Writer, pver uint32) error { + const op = "MsgMixSlotReserve.BtcEncode" + if pver < MixVersion { + msg := fmt.Sprintf("%s message invalid for protocol version %d", + msg.Command(), pver) + return messageError(op, ErrMsgInvalidForPVer, msg) + } + + err := writeElement(w, &msg.Signature) + if err != nil { + return err + } + + err = msg.writeMessageNoSignature(op, w, pver) + if err != nil { + return err + } + + return nil +} + +// Hash returns the message hash calculated by WriteHash. +// +// Hash returns an invalid or zero hash if WriteHash has not been called yet. +// +// This method is not safe while concurrently calling WriteHash. +func (msg *MsgMixSlotReserve) Hash() chainhash.Hash { + return msg.hash +} + +// WriteHash serializes the message to a hasher and records the sum in the +// message's Hash field. +// +// The hasher's Size() must equal chainhash.HashSize, or this method will +// panic. This method is designed to work only with hashers returned by +// blake256.New. +func (msg *MsgMixSlotReserve) WriteHash(h hash.Hash) { + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + sum := h.Sum(msg.hash[:0]) + if len(sum) != len(msg.hash) { + s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) + panic(s) + } +} + +// writeMessageNoSignature serializes all elements of the message except for +// the signature. This allows code reuse between message serialization, and +// signing and verifying these message contents. +// +// If w implements hash.Hash, no errors will be returned for invalid message +// construction. +func (msg *MsgMixSlotReserve) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { + _, hashing := w.(hash.Hash) + + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Run) + if err != nil { + return err + } + + // Write the DCMix + mcount := len(msg.DCMix) + if !hashing && mcount == 0 { + msg := fmt.Sprintf("too few mixed messages [%v]", mcount) + return messageError(op, ErrInvalidMsg, msg) + } + if !hashing && mcount > MaxMixMcount { + msg := fmt.Sprintf("too many total mixed messages [%v]", mcount) + return messageError(op, ErrInvalidMsg, msg) + } + kpcount := len(msg.DCMix[0]) + if !hashing && kpcount == 0 { + msg := fmt.Sprintf("too few mixing peers [%v]", kpcount) + return messageError(op, ErrInvalidMsg, msg) + } + if !hashing && kpcount > MaxMixPeers { + msg := fmt.Sprintf("too many mixing peers [%v]", kpcount) + return messageError(op, ErrInvalidMsg, msg) + } + err = WriteVarInt(w, pver, uint64(mcount)) + if err != nil { + return err + } + err = WriteVarInt(w, pver, uint64(kpcount)) + if err != nil { + return err + } + for i := range msg.DCMix { + if !hashing && len(msg.DCMix[i]) != kpcount { + msg := "invalid matrix dimensions" + return messageError(op, ErrInvalidMsg, msg) + } + for j := range msg.DCMix[i] { + v := msg.DCMix[i][j] + if !hashing && len(v) > MaxMixFieldValLen { + msg := "value exceeds bytes necessary to represent number in field" + return messageError(op, ErrInvalidMsg, msg) + } + err := WriteVarBytes(w, pver, v) + if err != nil { + return err + } + } + } + + count := len(msg.SeenCiphertexts) + if !hashing && count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + + err = WriteVarInt(w, pver, uint64(count)) + if err != nil { + return err + } + for i := range msg.SeenCiphertexts { + err = writeElement(w, &msg.SeenCiphertexts[i]) + if err != nil { + return err + } + } + + return nil +} + +// WriteSignedData writes a tag identifying the message data, followed by all +// message fields excluding the signature. This is the data committed to when +// the message is signed. +func (msg *MsgMixSlotReserve) WriteSignedData(h hash.Hash) { + WriteVarString(h, MixVersion, CmdMixSlotReserve+"-sig") + msg.writeMessageNoSignature("", h, MixVersion) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgMixSlotReserve) Command() string { + return CmdMixSlotReserve +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgMixSlotReserve) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 17318030 +} + +// Pub returns the message sender's public key identity. +func (msg *MsgMixSlotReserve) Pub() []byte { + return msg.Identity[:] +} + +// Sig returns the message signature. +func (msg *MsgMixSlotReserve) Sig() []byte { + return msg.Signature[:] +} + +// PrevMsgs returns the previous CT messages seen by the peer. +func (msg *MsgMixSlotReserve) PrevMsgs() []chainhash.Hash { + return msg.SeenCiphertexts +} + +// Sid returns the session ID. +func (msg *MsgMixSlotReserve) Sid() []byte { + return msg.SessionID[:] +} + +// GetRun returns the run number. +func (msg *MsgMixSlotReserve) GetRun() uint32 { + return msg.Run +} + +// NewMsgMixSlotReserve returns a new mixslotres message that conforms to the +// Message interface using the passed parameters and defaults for the +// remaining fields. +func NewMsgMixSlotReserve(identity [33]byte, sid [32]byte, run uint32, + dcmix [][][]byte, seenCTs []chainhash.Hash) *MsgMixSlotReserve { + + return &MsgMixSlotReserve{ + Identity: identity, + SessionID: sid, + Run: run, + DCMix: dcmix, + SeenCiphertexts: seenCTs, + } +} diff --git a/wire/msgmixslotreserve_test.go b/wire/msgmixslotreserve_test.go new file mode 100644 index 0000000000..070f7df62a --- /dev/null +++ b/wire/msgmixslotreserve_test.go @@ -0,0 +1,223 @@ +// Copyright (c) 2023-2024 The Decred developers +// Use of this source code is governed by an ISC +// license that can be found in the LICENSE file. + +package wire + +import ( + "bytes" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/decred/dcrd/chaincfg/chainhash" +) + +func newTestMixSlotReserve() *MsgMixSlotReserve { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) + + const run = uint32(0x83838383) + + const mcount = 4 + const kpcount = 4 + dcmix := make([][][]byte, mcount) + // will add 4x4 field numbers of incrementing repeating byte values to + // dcmix, ranging from 0x84 through 0x93 + b := byte(0x84) + for i := 0; i < mcount; i++ { + dcmix[i] = make([][]byte, kpcount) + for j := 0; j < kpcount; j++ { + dcmix[i][j] = repeat(b, 32) + b++ + } + } + + seenCTs := make([]chainhash.Hash, 4) + for b := byte(0x94); b < 0x98; b++ { + copy(seenCTs[b-0x94][:], repeat(b, 32)) + } + + sr := NewMsgMixSlotReserve(id, sid, run, dcmix, seenCTs) + sr.Signature = sig + + return sr +} +func TestMsgMixSlotReserveWire(t *testing.T) { + pver := MixVersion + + sr := newTestMixSlotReserve() + + buf := new(bytes.Buffer) + err := sr.BtcEncode(buf, pver) + if err != nil { + t.Fatal(err) + } + + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Run + // 4x4 slot reservation mixed messages (repeating 32 bytes from 0x84 through 0x93) + expected = append(expected, 0x04, 0x04) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x84, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x85, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x86, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x87, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x88, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x89, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8a, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8b, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8c, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8d, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8e, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8f, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x90, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x91, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x92, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x93, 32)...) + // Four seen CTs (repeating 32 bytes of 0x94, 0x95, 0x96, 0x97) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x94, 32)...) + expected = append(expected, repeat(0x95, 32)...) + expected = append(expected, repeat(0x96, 32)...) + expected = append(expected, repeat(0x97, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + + decodedSR := new(MsgMixSlotReserve) + err = decodedSR.BtcDecode(bytes.NewReader(buf.Bytes()), pver) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(sr, decodedSR) { + t.Errorf("BtcDecode got: %s want: %s", + spew.Sdump(decodedSR), spew.Sdump(sr)) + } +} + +func TestMsgMixSlotReserveCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixSlotReserve() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixSlotReserve) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixSlotReserveMaxPayloadLength tests the results returned by +// [MsgMixSlotReserve.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixSlotReserveMaxPayloadLength(t *testing.T) { + var sr *MsgMixSlotReserve + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := sr.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Run + uint32(VarIntSerializeSize(MaxMixMcount)) + // Message count + uint32(VarIntSerializeSize(MaxMixPeers)) + // Peer count + MaxMixMcount*MaxMixPeers*varBytesLen(MaxMixFieldValLen) + // Padded SR values + uint32(VarIntSerializeSize(MaxMixPeers)) + // Ciphertext count + 32*MaxMixPeers // Ciphertext hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := sr.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +} diff --git a/wire/protocol.go b/wire/protocol.go index a49ddf1e99..c2585d4380 100644 --- a/wire/protocol.go +++ b/wire/protocol.go @@ -17,7 +17,7 @@ const ( InitialProcotolVersion uint32 = 1 // ProtocolVersion is the latest protocol version this package supports. - ProtocolVersion uint32 = 9 + ProtocolVersion uint32 = 10 // NodeBloomVersion is the protocol version which added the SFNodeBloom // service flag (unused). @@ -51,6 +51,9 @@ const ( // RemoveRejectVersion is the protocol version which removes support for the // reject message. RemoveRejectVersion uint32 = 9 + + // MixVersion is the protocol version which adds peer-to-peer mixing. + MixVersion uint32 = 10 ) // ServiceFlag identifies services supported by a Decred peer.