Skip to content

Commit

Permalink
Merge pull request #423 from lesismal/async_dialer
Browse files Browse the repository at this point in the history
Async dialer
  • Loading branch information
lesismal authored Apr 29, 2024
2 parents 3059984 + 13653bb commit 9aa509c
Show file tree
Hide file tree
Showing 16 changed files with 613 additions and 49 deletions.
9 changes: 5 additions & 4 deletions conn_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ var (
func (c *Conn) newToWriteBuf(buf []byte) {
c.left += len(buf)

allocator := c.p.g.BodyAllocator
appendBuffer := func() {
t := poolToWrite.New().(*toWrite)
b := c.p.g.BodyAllocator.Malloc(len(buf))
b := allocator.Malloc(len(buf))
copy(b, buf)
t.buf = b
c.writeList = append(c.writeList, t)
Expand All @@ -55,12 +56,12 @@ func (c *Conn) newToWriteBuf(buf []byte) {
appendBuffer()
} else {
if cap(tail.buf) < tailLen+l {
b := c.p.g.BodyAllocator.Malloc(tailLen + l)[:tailLen]
b := allocator.Malloc(tailLen + l)[:tailLen]
copy(b, tail.buf)
c.p.g.BodyAllocator.Free(tail.buf)
allocator.Free(tail.buf)
tail.buf = b
}
tail.buf = append(tail.buf, buf...)
tail.buf = allocator.Append(tail.buf, buf...)
}
}
}
Expand Down
26 changes: 25 additions & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ const (
DefaultUDPReadTimeout = 120 * time.Second
)

const (
NETWORK_TCP = "tcp"
NETWORK_TCP4 = "tcp4"
NETWORK_TCP6 = "tcp6"
NETWORK_UDP = "udp"
NETWORK_UDP4 = "udp4"
NETWORK_UDP6 = "udp6"
NETWORK_UNIX = "unix"
NETWORK_UNIXGRAM = "unixgram"
NETWORK_UNIXPACKET = "unixpacket"
)

var (
// MaxOpenFiles .
MaxOpenFiles = 1024 * 1024 * 2
Expand Down Expand Up @@ -249,7 +261,19 @@ func (g *Engine) AddConn(conn net.Conn) (*Conn, error) {
}

p := g.pollers[c.Hash()%len(g.pollers)]
p.addConn(c)
err = p.addConn(c)
if err != nil {
return nil, err
}
return c, nil
}

func (g *Engine) addDialer(c *Conn) (*Conn, error) {
p := g.pollers[c.Hash()%len(g.pollers)]
err := p.addDialer(c)
if err != nil {
return nil, err
}
return c, nil
}

Expand Down
41 changes: 39 additions & 2 deletions engine_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net"
"runtime"
"strings"
"time"

"github.com/lesismal/nbio/logging"
"github.com/lesismal/nbio/mempool"
Expand All @@ -22,7 +23,7 @@ func (g *Engine) Start() error {
// Create listener pollers.
udpListeners := make([]*net.UDPConn, len(g.Addrs))[0:0]
switch g.Network {
case "tcp", "tcp4", "tcp6":
case NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6:
for i := range g.Addrs {
ln, err := newPoller(g, true, i)
if err != nil {
Expand All @@ -34,7 +35,7 @@ func (g *Engine) Start() error {
g.Addrs[i] = ln.listener.Addr().String()
g.listeners = append(g.listeners, ln)
}
case "udp", "udp4", "udp6":
case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6:
for i, addrStr := range g.Addrs {
addr, err := net.ResolveUDPAddr(g.Network, addrStr)
if err != nil {
Expand Down Expand Up @@ -165,3 +166,39 @@ func NewEngine(conf Config) *Engine {

return g
}

// DialAsync connects asynchrony to the address on the named network.
func (engine *Engine) DialAsync(network, addr string, onConnected func(*Conn, error)) error {
return engine.DialAsyncTimeout(network, addr, 0, onConnected)
}

// DialAsync connects asynchrony to the address on the named network with timeout.
func (engine *Engine) DialAsyncTimeout(network, addr string, timeout time.Duration, onConnected func(*Conn, error)) error {
go func() {
var err error
var conn net.Conn
if timeout > 0 {
conn, err = net.DialTimeout(network, addr, timeout)
} else {
conn, err = net.Dial(network, addr)
}
if err != nil {
onConnected(nil, err)
return
}
nbc, err := NBConn(conn)
if err != nil {
onConnected(nil, err)
return
}
engine.wgConn.Add(1)
nbc, err = engine.addDialer(nbc)
if err == nil {
nbc.SetWriteDeadline(time.Time{})
} else {
engine.wgConn.Done()
}
onConnected(nbc, err)
}()
return nil
}
118 changes: 116 additions & 2 deletions engine_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
package nbio

import (
"errors"
"net"
"runtime"
"strings"
"syscall"
"time"

"github.com/lesismal/nbio/logging"
"github.com/lesismal/nbio/mempool"
Expand All @@ -28,7 +31,7 @@ func (g *Engine) Start() error {
udpListeners := make([]*net.UDPConn, len(g.Addrs))[0:0]

switch g.Network {
case "unix", "tcp", "tcp4", "tcp6":
case NETWORK_UNIX, NETWORK_TCP, NETWORK_TCP4, NETWORK_TCP6:
for i := range g.Addrs {
ln, err := newPoller(g, true, i)
if err != nil {
Expand All @@ -40,7 +43,7 @@ func (g *Engine) Start() error {
g.Addrs[i] = ln.listener.Addr().String()
g.listeners = append(g.listeners, ln)
}
case "udp", "udp4", "udp6":
case NETWORK_UDP, NETWORK_UDP4, NETWORK_UDP6:
for i, addrStr := range g.Addrs {
addr, err := net.ResolveUDPAddr(g.Network, addrStr)
if err != nil {
Expand Down Expand Up @@ -139,6 +142,117 @@ func (g *Engine) Start() error {
return nil
}

// DialAsync connects asynchrony to the address on the named network.
func (engine *Engine) DialAsync(network, addr string, onConnected func(*Conn, error)) error {
return engine.DialAsyncTimeout(network, addr, 0, onConnected)
}

// DialAsync connects asynchrony to the address on the named network with timeout.
func (engine *Engine) DialAsyncTimeout(network, addr string, timeout time.Duration, onConnected func(*Conn, error)) error {
h := func(c *Conn, err error) {
if err == nil {
c.SetWriteDeadline(time.Time{})
}
onConnected(c, err)
}
domain, typ, dialaddr, raddr, connType, err := parseDomainAndType(network, addr)
if err != nil {
return err
}
fd, err := syscall.Socket(domain, typ, 0)
if err != nil {
return err
}
err = syscall.SetNonblock(fd, true)
if err != nil {
syscall.Close(fd)
return err
}
err = syscall.Connect(fd, dialaddr)
inprogress := false
if err != nil {
if errors.Is(err, syscall.EINPROGRESS) {
inprogress = true
} else {
syscall.Close(fd)
return err
}
}
sa, _ := syscall.Getsockname(fd)
c := &Conn{
fd: fd,
rAddr: raddr,
typ: connType,
}
if inprogress {
c.onConnected = h
}
switch vt := sa.(type) {
case *syscall.SockaddrInet4:
switch connType {
case ConnTypeTCP:
c.lAddr = &net.TCPAddr{
IP: []byte{vt.Addr[0], vt.Addr[1], vt.Addr[2], vt.Addr[3]},
Port: vt.Port,
}
case ConnTypeUDPClientFromDial:
c.lAddr = &net.TCPAddr{
IP: []byte{vt.Addr[0], vt.Addr[1], vt.Addr[2], vt.Addr[3]},
Port: vt.Port,
}
c.connUDP = &udpConn{
parent: c,
}
}
case *syscall.SockaddrInet6:
var iface *net.Interface
iface, err = net.InterfaceByIndex(int(vt.ZoneId))
if err != nil {
syscall.Close(fd)
return err
}
switch connType {
case ConnTypeTCP:
c.lAddr = &net.TCPAddr{
IP: make([]byte, len(vt.Addr)),
Port: vt.Port,
Zone: iface.Name,
}
case ConnTypeUDPClientFromDial:
c.lAddr = &net.UDPAddr{
IP: make([]byte, len(vt.Addr)),
Port: vt.Port,
Zone: iface.Name,
}
c.connUDP = &udpConn{
parent: c,
}
}
case *syscall.SockaddrUnix:
c.lAddr = &net.UnixAddr{
Net: network,
Name: vt.Name,
}
}

engine.wgConn.Add(1)
_, err = engine.addDialer(c)
if err != nil {
engine.wgConn.Done()
return err
}

if !inprogress {
engine.Async(func() {
h(c, nil)
})
} else if timeout > 0 {
c.setDeadline(&c.wTimer, ErrDialTimeout, time.Now().Add(timeout))
}

return nil
}

// NewEngine creates an Engine and init default configurations.
func NewEngine(conf Config) *Engine {
if conf.Name == "" {
Expand Down
2 changes: 2 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ var (
ErrOverflow = errors.New("write overflow")
errOverflow = ErrOverflow

ErrDialTimeout = errors.New("dial timeout")

ErrUnsupported = errors.New("unsupported operation")
)
Loading

0 comments on commit 9aa509c

Please sign in to comment.