Skip to content

Commit

Permalink
add optional simulated addresses to pipeconn and inmemorylistener (#1449
Browse files Browse the repository at this point in the history
)

* add optional simulated addresses to pipeconn and inmemorylistener

* add mutexes to addresses of pipeConn and InmemoryListener
  • Loading branch information
tobikris authored Dec 7, 2022
1 parent dbf457e commit 951f5a1
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 6 deletions.
49 changes: 43 additions & 6 deletions fasthttputil/inmemory_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ var ErrInmemoryListenerClosed = errors.New("InmemoryListener is already closed:
// It may be used either for fast in-process client<->server communications
// without network stack overhead or for client<->server tests.
type InmemoryListener struct {
lock sync.Mutex
closed bool
conns chan acceptConn
lock sync.Mutex
closed bool
conns chan acceptConn
listenerAddr net.Addr
addrLock sync.RWMutex
}

type acceptConn struct {
Expand All @@ -31,6 +33,14 @@ func NewInmemoryListener() *InmemoryListener {
}
}

// SetLocalAddr sets the (simulated) local address for the listener.
func (ln *InmemoryListener) SetLocalAddr(localAddr net.Addr) {
ln.addrLock.Lock()
defer ln.addrLock.Unlock()

ln.listenerAddr = localAddr
}

// Accept implements net.Listener's Accept.
//
// It is safe calling Accept from concurrently running goroutines.
Expand Down Expand Up @@ -60,12 +70,26 @@ func (ln *InmemoryListener) Close() error {
return err
}

type inmemoryAddr int

func (inmemoryAddr) Network() string {
return "inmemory"
}

func (inmemoryAddr) String() string {
return "InmemoryListener"
}

// Addr implements net.Listener's Addr.
func (ln *InmemoryListener) Addr() net.Addr {
return &net.UnixAddr{
Name: "InmemoryListener",
Net: "memory",
ln.addrLock.RLock()
defer ln.addrLock.RUnlock()

if ln.listenerAddr != nil {
return ln.listenerAddr
}

return inmemoryAddr(0)
}

// Dial creates new client<->server connection.
Expand All @@ -74,7 +98,20 @@ func (ln *InmemoryListener) Addr() net.Addr {
//
// It is safe calling Dial from concurrently running goroutines.
func (ln *InmemoryListener) Dial() (net.Conn, error) {
return ln.DialWithLocalAddr(nil)
}

// DialWithLocalAddr creates new client<->server connection.
// Just like a real Dial it only returns once the server
// has accepted the connection. The local address of the
// client connection can be set with local.
//
// It is safe calling Dial from concurrently running goroutines.
func (ln *InmemoryListener) DialWithLocalAddr(local net.Addr) (net.Conn, error) {
pc := NewPipeConns()

pc.SetAddresses(local, ln.Addr(), ln.Addr(), local)

cConn := pc.Conn1()
sConn := pc.Conn2()
ln.lock.Lock()
Expand Down
93 changes: 93 additions & 0 deletions fasthttputil/inmemory_listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,96 @@ func TestInmemoryListenerHTTPConcurrent(t *testing.T) {
wg.Wait()
})
}

func acceptLoop(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
panic(err)
}

conn.Close()
}
}

func TestInmemoryListenerAddrDefault(t *testing.T) {
t.Parallel()

ln := NewInmemoryListener()

verifyAddr(t, ln.Addr(), inmemoryAddr(0))

go func() {
c, err := ln.Dial()
if err != nil {
panic(err)
}

c.Close()
}()

lc, err := ln.Accept()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

verifyAddr(t, lc.LocalAddr(), inmemoryAddr(0))
verifyAddr(t, lc.RemoteAddr(), pipeAddr(0))

go acceptLoop(ln)

c, err := ln.Dial()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

verifyAddr(t, c.LocalAddr(), pipeAddr(0))
verifyAddr(t, c.RemoteAddr(), inmemoryAddr(0))
}

func verifyAddr(t *testing.T, got, expected net.Addr) {
if got != expected {
t.Fatalf("unexpected addr: %v. Expecting %v", got, expected)
}
}

func TestInmemoryListenerAddrCustom(t *testing.T) {
t.Parallel()

ln := NewInmemoryListener()

listenerAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}

ln.SetLocalAddr(listenerAddr)

verifyAddr(t, ln.Addr(), listenerAddr)

go func() {
c, err := ln.Dial()
if err != nil {
panic(err)
}

c.Close()
}()

lc, err := ln.Accept()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

verifyAddr(t, lc.LocalAddr(), listenerAddr)
verifyAddr(t, lc.RemoteAddr(), pipeAddr(0))

go acceptLoop(ln)

clientAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 2), Port: 65432}

c, err := ln.DialWithLocalAddr(clientAddr)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

verifyAddr(t, c.LocalAddr(), clientAddr)
verifyAddr(t, c.RemoteAddr(), listenerAddr)
}
33 changes: 33 additions & 0 deletions fasthttputil/pipeconns.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ type PipeConns struct {
stopChLock sync.Mutex
}

// SetAddresses sets the local and remote addresses for the connection.
func (pc *PipeConns) SetAddresses(localAddr1, remoteAddr1, localAddr2, remoteAddr2 net.Addr) {
pc.c1.addrLock.Lock()
defer pc.c1.addrLock.Unlock()

pc.c2.addrLock.Lock()
defer pc.c2.addrLock.Unlock()

pc.c1.localAddr = localAddr1
pc.c1.remoteAddr = remoteAddr1

pc.c2.localAddr = localAddr2
pc.c2.remoteAddr = remoteAddr2
}

// Conn1 returns the first end of bi-directional pipe.
//
// Data written to Conn1 may be read from Conn2.
Expand Down Expand Up @@ -92,6 +107,10 @@ type pipeConn struct {
writeDeadlineCh <-chan time.Time

readDeadlineChLock sync.Mutex

localAddr net.Addr
remoteAddr net.Addr
addrLock sync.RWMutex
}

func (c *pipeConn) Write(p []byte) (int, error) {
Expand Down Expand Up @@ -224,10 +243,24 @@ func (c *pipeConn) Close() error {
}

func (c *pipeConn) LocalAddr() net.Addr {
c.addrLock.RLock()
defer c.addrLock.RUnlock()

if c.localAddr != nil {
return c.localAddr
}

return pipeAddr(0)
}

func (c *pipeConn) RemoteAddr() net.Addr {
c.addrLock.RLock()
defer c.addrLock.RUnlock()

if c.remoteAddr != nil {
return c.remoteAddr
}

return pipeAddr(0)
}

Expand Down
48 changes: 48 additions & 0 deletions fasthttputil/pipeconns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,51 @@ func testConcurrency(t *testing.T, concurrency int, f func(*testing.T)) {
}
}
}

func TestPipeConnsAddrDefault(t *testing.T) {
t.Parallel()

pc := NewPipeConns()
c1 := pc.Conn1()

if c1.LocalAddr() != pipeAddr(0) {
t.Fatalf("unexpected local address: %v", c1.LocalAddr())
}

if c1.RemoteAddr() != pipeAddr(0) {
t.Fatalf("unexpected remote address: %v", c1.RemoteAddr())
}
}

func TestPipeConnsAddrCustom(t *testing.T) {
t.Parallel()

pc := NewPipeConns()

addr1 := &net.TCPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
addr2 := &net.TCPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 5678}
addr3 := &net.TCPAddr{IP: net.IPv4(9, 10, 11, 12), Port: 9012}
addr4 := &net.TCPAddr{IP: net.IPv4(13, 14, 15, 16), Port: 3456}

pc.SetAddresses(addr1, addr2, addr3, addr4)

c1 := pc.Conn1()

if c1.LocalAddr() != addr1 {
t.Fatalf("unexpected local address: %v", c1.LocalAddr())
}

if c1.RemoteAddr() != addr2 {
t.Fatalf("unexpected remote address: %v", c1.RemoteAddr())
}

c2 := pc.Conn1()

if c2.LocalAddr() != addr1 {
t.Fatalf("unexpected local address: %v", c2.LocalAddr())
}

if c2.RemoteAddr() != addr2 {
t.Fatalf("unexpected remote address: %v", c2.RemoteAddr())
}
}

0 comments on commit 951f5a1

Please sign in to comment.