diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index ef234bc5bc..f7d7253c9c 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -7,11 +7,13 @@ import ( "fmt" "io/ioutil" mrand "math/rand" + "net" "time" ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" tpt "github.com/libp2p/go-libp2p-core/transport" + filter "github.com/libp2p/go-maddr-filter" ma "github.com/multiformats/go-multiaddr" . "github.com/onsi/ginkgo" @@ -62,12 +64,12 @@ var _ = Describe("Connection", func() { }) It("handshakes on IPv4", func() { - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) @@ -86,12 +88,12 @@ var _ = Describe("Connection", func() { }) It("handshakes on IPv6", func() { - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip6/::1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) @@ -110,12 +112,12 @@ var _ = Describe("Connection", func() { }) It("opens and accepts streams", func() { - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) @@ -139,11 +141,11 @@ var _ = Describe("Connection", func() { It("fails if the peer ID doesn't match", func() { thirdPartyID, _ := createPeer() - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) // dial, but expect the wrong peer ID _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID) @@ -161,14 +163,47 @@ var _ = Describe("Connection", func() { Eventually(done).Should(BeClosed()) }) + It("filters addresses", func() { + filters := filter.NewFilters() + ipNet := net.IPNet{ + IP: net.IPv4(127, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 255), + } + filters.AddFilter(ipNet, filter.ActionDeny) + testMA, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234/quic") + Expect(err).ToNot(HaveOccurred()) + Expect(filters.AddrBlocked(testMA)).To(BeTrue()) + + serverTransport, err := NewTransport(serverKey, nil, filters) + Expect(err).ToNot(HaveOccurred()) + ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") + defer ln.Close() + + clientTransport, err := NewTransport(clientKey, nil, nil) + Expect(err).ToNot(HaveOccurred()) + + // make sure that connection attempts fails + quicConfig.HandshakeTimeout = 250 * time.Millisecond + _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + + // now allow the address and make sure the connection goes through + quicConfig.HandshakeTimeout = 2 * time.Second + filters.AddFilter(ipNet, filter.ActionAccept) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + Expect(err).ToNot(HaveOccurred()) + conn.Close() + }) + It("dials to two servers at the same time", func() { serverID2, serverKey2 := createPeer() - serverTransport, err := NewTransport(serverKey, nil) + serverTransport, err := NewTransport(serverKey, nil, nil) Expect(err).ToNot(HaveOccurred()) ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") - serverTransport2, err := NewTransport(serverKey2, nil) defer ln1.Close() + serverTransport2, err := NewTransport(serverKey2, nil, nil) Expect(err).ToNot(HaveOccurred()) ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic") defer ln2.Close() @@ -194,7 +229,7 @@ var _ = Describe("Connection", func() { } }() - clientTransport, err := NewTransport(clientKey, nil) + clientTransport, err := NewTransport(clientKey, nil, nil) Expect(err).ToNot(HaveOccurred()) c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID) Expect(err).ToNot(HaveOccurred()) diff --git a/p2p/transport/quic/filtered_conn.go b/p2p/transport/quic/filtered_conn.go new file mode 100644 index 0000000000..dc60bb08e2 --- /dev/null +++ b/p2p/transport/quic/filtered_conn.go @@ -0,0 +1,34 @@ +package libp2pquic + +import ( + "net" + + filter "github.com/libp2p/go-maddr-filter" +) + +type filteredConn struct { + net.PacketConn + + filters *filter.Filters +} + +func newFilteredConn(c net.PacketConn, filters *filter.Filters) net.PacketConn { + return &filteredConn{PacketConn: c, filters: filters} +} + +func (c *filteredConn) ReadFrom(b []byte) (n int, addr net.Addr, rerr error) { + for { + n, addr, rerr = c.PacketConn.ReadFrom(b) + // Short Header packet, see https://tools.ietf.org/html/draft-ietf-quic-invariants-07#section-4.2. + if n < 1 || b[0]&0x80 == 0 { + return + } + maddr, err := toQuicMultiaddr(addr) + if err != nil { + panic(err) + } + if !c.filters.AddrBlocked(maddr) { + return + } + } +} diff --git a/p2p/transport/quic/libp2pquic_suite_test.go b/p2p/transport/quic/libp2pquic_suite_test.go index a2e1df400d..d6ae1510fd 100644 --- a/p2p/transport/quic/libp2pquic_suite_test.go +++ b/p2p/transport/quic/libp2pquic_suite_test.go @@ -5,12 +5,13 @@ import ( mrand "math/rand" "runtime/pprof" "strings" + "testing" "time" + "github.com/lucas-clemente/quic-go" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - - "testing" ) func TestLibp2pQuicTransport(t *testing.T) { @@ -22,8 +23,11 @@ var _ = BeforeSuite(func() { mrand.Seed(GinkgoRandomSeed()) }) -var garbageCollectIntervalOrig time.Duration -var maxUnusedDurationOrig time.Duration +var ( + garbageCollectIntervalOrig time.Duration + maxUnusedDurationOrig time.Duration + origQuicConfig *quic.Config +) func isGarbageCollectorRunning() bool { var b bytes.Buffer @@ -37,10 +41,13 @@ var _ = BeforeEach(func() { maxUnusedDurationOrig = maxUnusedDuration garbageCollectInterval = 50 * time.Millisecond maxUnusedDuration = 0 + origQuicConfig = quicConfig + quicConfig = quicConfig.Clone() }) var _ = AfterEach(func() { Eventually(isGarbageCollectorRunning).Should(BeFalse()) garbageCollectInterval = garbageCollectIntervalOrig maxUnusedDuration = maxUnusedDurationOrig + quicConfig = origQuicConfig }) diff --git a/p2p/transport/quic/listener_test.go b/p2p/transport/quic/listener_test.go index ac7716f353..3388f1ac37 100644 --- a/p2p/transport/quic/listener_test.go +++ b/p2p/transport/quic/listener_test.go @@ -23,7 +23,7 @@ var _ = Describe("Listener", func() { Expect(err).ToNot(HaveOccurred()) key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) Expect(err).ToNot(HaveOccurred()) - t, err = NewTransport(key, nil) + t, err = NewTransport(key, nil, nil) Expect(err).ToNot(HaveOccurred()) }) diff --git a/p2p/transport/quic/reuse_base.go b/p2p/transport/quic/reuse_base.go index 347a9a6cf7..053b777395 100644 --- a/p2p/transport/quic/reuse_base.go +++ b/p2p/transport/quic/reuse_base.go @@ -4,6 +4,8 @@ import ( "net" "sync" "time" + + filter "github.com/libp2p/go-maddr-filter" ) // Constant. Defined as variables to simplify testing. @@ -20,7 +22,10 @@ type reuseConn struct { unusedSince time.Time } -func newReuseConn(conn net.PacketConn) *reuseConn { +func newReuseConn(conn net.PacketConn, filters *filter.Filters) *reuseConn { + if filters != nil { + conn = newFilteredConn(conn, filters) + } return &reuseConn{PacketConn: conn} } @@ -49,6 +54,8 @@ func (c *reuseConn) ShouldGarbageCollect(now time.Time) bool { type reuseBase struct { mutex sync.Mutex + filters *filter.Filters + garbageCollectorRunning bool unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn @@ -56,8 +63,9 @@ type reuseBase struct { global map[int]*reuseConn } -func newReuseBase() reuseBase { +func newReuseBase(filters *filter.Filters) reuseBase { return reuseBase{ + filters: filters, unicast: make(map[string]map[int]*reuseConn), global: make(map[int]*reuseConn), } @@ -139,7 +147,7 @@ func (r *reuseBase) dialLocked(network string, raddr *net.UDPAddr, ips []net.IP) if err != nil { return nil, err } - rconn := newReuseConn(conn) + rconn := newReuseConn(conn, r.filters) r.global[conn.LocalAddr().(*net.UDPAddr).Port] = rconn return rconn, nil } @@ -151,7 +159,7 @@ func (r *reuseBase) Listen(network string, laddr *net.UDPAddr) (*reuseConn, erro } localAddr := conn.LocalAddr().(*net.UDPAddr) - rconn := newReuseConn(conn) + rconn := newReuseConn(conn, r.filters) rconn.IncreaseCount() r.mutex.Lock() diff --git a/p2p/transport/quic/reuse_linux_test.go b/p2p/transport/quic/reuse_linux_test.go index 8bc401a0c8..6fe77dbbf6 100644 --- a/p2p/transport/quic/reuse_linux_test.go +++ b/p2p/transport/quic/reuse_linux_test.go @@ -14,7 +14,7 @@ var _ = Describe("Reuse (on Linux)", func() { BeforeEach(func() { var err error - reuse, err = newReuse() + reuse, err = newReuse(nil) Expect(err).ToNot(HaveOccurred()) }) diff --git a/p2p/transport/quic/reuse_not_win.go b/p2p/transport/quic/reuse_not_win.go index fb36b83e89..57097a3098 100644 --- a/p2p/transport/quic/reuse_not_win.go +++ b/p2p/transport/quic/reuse_not_win.go @@ -5,6 +5,8 @@ package libp2pquic import ( "net" + filter "github.com/libp2p/go-maddr-filter" + "github.com/vishvananda/netlink" ) @@ -14,7 +16,7 @@ type reuse struct { handle *netlink.Handle // Only set on Linux. nil on other systems. } -func newReuse() (*reuse, error) { +func newReuse(filters *filter.Filters) (*reuse, error) { handle, err := netlink.NewHandle(SupportedNlFamilies...) if err == netlink.ErrNotImplemented { handle = nil @@ -22,7 +24,7 @@ func newReuse() (*reuse, error) { return nil, err } return &reuse{ - reuseBase: newReuseBase(), + reuseBase: newReuseBase(filters), handle: handle, }, nil } diff --git a/p2p/transport/quic/reuse_test.go b/p2p/transport/quic/reuse_test.go index e7739156e7..6a24d5d575 100644 --- a/p2p/transport/quic/reuse_test.go +++ b/p2p/transport/quic/reuse_test.go @@ -37,7 +37,7 @@ var _ = Describe("Reuse", func() { BeforeEach(func() { var err error - reuse, err = newReuse() + reuse, err = newReuse(nil) Expect(err).ToNot(HaveOccurred()) }) diff --git a/p2p/transport/quic/reuse_win.go b/p2p/transport/quic/reuse_win.go index 0f57c8e0ea..14ea1babfa 100644 --- a/p2p/transport/quic/reuse_win.go +++ b/p2p/transport/quic/reuse_win.go @@ -2,14 +2,18 @@ package libp2pquic -import "net" +import ( + "net" + + filter "github.com/libp2p/go-maddr-filter" +) type reuse struct { reuseBase } -func newReuse() (*reuse, error) { - return &reuse{reuseBase: newReuseBase()}, nil +func newReuse(filters *filter.Filters) (*reuse, error) { + return &reuse{reuseBase: newReuseBase(filters)}, nil } func (r *reuse) Dial(network string, raddr *net.UDPAddr) (*reuseConn, error) { diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 5e3a8b5522..a60b28f65a 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p-core/pnet" tpt "github.com/libp2p/go-libp2p-core/transport" p2ptls "github.com/libp2p/go-libp2p-tls" + filter "github.com/libp2p/go-maddr-filter" quic "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -36,12 +37,12 @@ type connManager struct { reuseUDP6 *reuse } -func newConnManager() (*connManager, error) { - reuseUDP4, err := newReuse() +func newConnManager(filters *filter.Filters) (*connManager, error) { + reuseUDP4, err := newReuse(filters) if err != nil { return nil, err } - reuseUDP6, err := newReuse() + reuseUDP6, err := newReuse(filters) if err != nil { return nil, err } @@ -89,7 +90,7 @@ type transport struct { var _ tpt.Transport = &transport{} // NewTransport creates a new QUIC transport -func NewTransport(key ic.PrivKey, psk pnet.PSK) (tpt.Transport, error) { +func NewTransport(key ic.PrivKey, psk pnet.PSK, filters *filter.Filters) (tpt.Transport, error) { if len(psk) > 0 { log.Error("QUIC doesn't support private networks yet.") return nil, errors.New("QUIC doesn't support private networks yet") @@ -102,7 +103,7 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK) (tpt.Transport, error) { if err != nil { return nil, err } - connManager, err := newConnManager() + connManager, err := newConnManager(filters) if err != nil { return nil, err }