Skip to content

Commit

Permalink
Ability to pass custom DNS resolver to TCPDialer (#689)
Browse files Browse the repository at this point in the history
* Ability to pass custom DNS resolver to TCPdialer

* Update tcpdialer.go

Co-Authored-By: Erik Dubbelboer <erik@dubbelboer.com>
  • Loading branch information
enchantner and erikdubbelboer committed Nov 10, 2019
1 parent 70223a1 commit 3fb2eba
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions tcpdialer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fasthttp

import (
"context"
"errors"
"net"
"strconv"
Expand Down Expand Up @@ -129,6 +130,19 @@ type TCPDialer struct {
// Changes made after the first Dial will not affect anything.
Concurrency int

// This may be used to override DNS resolving policy, like this:
// var dialer = &fasthttp.TCPDialer{
// Resolver: &net.Resolver{
// PreferGo: true,
// StrictErrors: false,
// Dial: func (ctx context.Context, network, address string) (net.Conn, error) {
// d := net.Dialer{}
// return d.DialContext(ctx, "udp", "8.8.8.8:53")
// },
// },
// }
Resolver *net.Resolver

tcpAddrsLock sync.Mutex
tcpAddrsMap map[string]*tcpAddrEntry

Expand Down Expand Up @@ -387,7 +401,7 @@ func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uin
d.tcpAddrsLock.Unlock()

if e == nil {
addrs, err := resolveTCPAddrs(addr, dualStack)
addrs, err := resolveTCPAddrs(addr, dualStack, d.Resolver)
if err != nil {
d.tcpAddrsLock.Lock()
e = d.tcpAddrsMap[addr]
Expand All @@ -412,7 +426,7 @@ func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uin
return e.addrs, idx, nil
}

func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) {
func resolveTCPAddrs(addr string, dualStack bool, resolver *net.Resolver) ([]net.TCPAddr, error) {
host, portS, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
Expand All @@ -422,20 +436,25 @@ func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) {
return nil, err
}

ips, err := net.LookupIP(host)
if resolver == nil {
resolver = net.DefaultResolver
}

ctx := context.Background()
ipaddrs, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}

n := len(ips)
n := len(ipaddrs)
addrs := make([]net.TCPAddr, 0, n)
for i := 0; i < n; i++ {
ip := ips[i]
if !dualStack && ip.To4() == nil {
ip := ipaddrs[i]
if !dualStack && ip.IP.To4() == nil {
continue
}
addrs = append(addrs, net.TCPAddr{
IP: ip,
IP: ip.IP,
Port: port,
})
}
Expand Down

0 comments on commit 3fb2eba

Please sign in to comment.