diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 6e6876db6a..202cf14a2d 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -70,7 +70,7 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack if err != nil { return nil, err } - localMhBuf, _ := multihash.EncodeName(localMh, sdpHashToMh(localFingerprints[0].Algorithm)) + localMhBuf, _ := multihash.Encode(localMh, multihash.SHA2_256) localFpMultibase, _ := multibase.Encode(multibase.Base58BTC, localMhBuf) ctx, cancel := context.WithCancel(context.Background()) diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index 5a8105dc16..72803d6fc1 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -1,8 +1,8 @@ package libp2pwebrtc import ( - "bytes" "context" + "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -10,7 +10,6 @@ import ( "encoding/hex" "fmt" "net" - "sort" "strings" "sync" @@ -347,7 +346,7 @@ func (t *WebRTCTransport) getCertificateFingerprint() (webrtc.DTLSFingerprint, e return fps[0], nil } -func (t *WebRTCTransport) generateNoisePrologue(pc *webrtc.PeerConnection) ([]byte, error) { +func (t *WebRTCTransport) generateNoisePrologue(pc *webrtc.PeerConnection, inbound bool) ([]byte, error) { raw := pc.SCTP().Transport().GetRemoteCertificate() cert, err := x509.ParseCertificate(raw) if err != nil { @@ -359,12 +358,7 @@ func (t *WebRTCTransport) generateNoisePrologue(pc *webrtc.PeerConnection) ([]by return nil, err } - hashAlgo, err := fingerprint.HashFromString(localFp.Algorithm) - if err != nil { - log.Debugf("could not find hash algo: %s %v", localFp.Algorithm, err) - return nil, err - } - remoteFp, err := fingerprint.Fingerprint(cert, hashAlgo) + remoteFp, err := fingerprint.Fingerprint(cert, crypto.SHA256) if err != nil { return nil, err } @@ -374,39 +368,36 @@ func (t *WebRTCTransport) generateNoisePrologue(pc *webrtc.PeerConnection) ([]by return nil, err } - mhAlgoName := sdpHashToMh(localFp.Algorithm) - if mhAlgoName == "" { - mhAlgoName = localFp.Algorithm - } - local := strings.ReplaceAll(localFp.Value, ":", "") localBytes, err := hex.DecodeString(local) if err != nil { return nil, err } - localEncoded, err := multihash.EncodeName(localBytes, mhAlgoName) + localEncoded, err := multihash.Encode(localBytes, multihash.SHA2_256) if err != nil { log.Debugf("could not encode multihash for local fingerprint") return nil, err } - remoteEncoded, err := multihash.EncodeName(remoteFpBytes, mhAlgoName) + remoteEncoded, err := multihash.Encode(remoteFpBytes, multihash.SHA2_256) if err != nil { log.Debugf("could not encode multihash for remote fingerprint") return nil, err } - b := [][]byte{localEncoded, remoteEncoded} - sort.Slice(b, func(i, j int) bool { - return bytes.Compare(b[i], b[j]) < 0 - }) - result := append([]byte("libp2p-webrtc-noise:"), b[0]...) - result = append(result, b[1]...) + result := []byte("libp2p-webrtc-noise:") + if inbound { + result = append(result, remoteEncoded...) + result = append(result, localEncoded...) + } else { + result = append(result, localEncoded...) + result = append(result, remoteEncoded...) + } return result, nil } func (t *WebRTCTransport) noiseHandshake(ctx context.Context, pc *webrtc.PeerConnection, datachannel *dataChannel, peer peer.ID, inbound bool) (secureConn sec.SecureConn, err error) { - prologue, err := t.generateNoisePrologue(pc) + prologue, err := t.generateNoisePrologue(pc, inbound) if err != nil { return nil, errNoise("could not generate prologue", err) } diff --git a/p2p/transport/webrtc/util.go b/p2p/transport/webrtc/util.go index 712b111d12..572ddc1c2b 100644 --- a/p2p/transport/webrtc/util.go +++ b/p2p/transport/webrtc/util.go @@ -10,32 +10,6 @@ import ( "github.com/pion/webrtc/v3" ) -func mhToSdpHash(mh string) string { - switch mh { - case "sha1": - return "sha1" - case "sha2-256": - return "sha-256" - case "md5": - return "md5" - default: - return "" - } -} - -func sdpHashToMh(sdpHash string) string { - switch sdpHash { - case "sha-256": - return "sha2-256" - case "sha1": - return "sha1" - case "md5": - return "md5" - default: - return "" - } -} - func maFingerprintToSdp(fp string) string { result := "" first := true @@ -57,11 +31,7 @@ func fingerprintToSDP(fp *mh.DecodedMultihash) string { return "" } fpDigest := maFingerprintToSdp(hex.EncodeToString(fp.Digest)) - fpAlgo := mhToSdpHash(strings.ToLower(fp.Name)) - if fpAlgo == "" { - fpAlgo = strings.ToLower(fp.Name) - } - return fpAlgo + " " + fpDigest + return "sha-256 " + fpDigest } func decodeRemoteFingerprint(maddr ma.Multiaddr) (*mh.DecodedMultihash, error) { @@ -81,11 +51,7 @@ func encodeDTLSFingerprint(fp webrtc.DTLSFingerprint) (string, error) { if err != nil { return "", err } - algo := sdpHashToMh(strings.ToLower(fp.Algorithm)) - if algo == "" { - algo = fp.Algorithm - } - encoded, err := mh.EncodeName(digest, algo) + encoded, err := mh.Encode(digest, mh.SHA2_256) if err != nil { return "", err }