diff --git a/internal/net/tcp/dialer.go b/internal/net/tcp/dialer.go index 150e84b490..3e87143b12 100644 --- a/internal/net/tcp/dialer.go +++ b/internal/net/tcp/dialer.go @@ -20,6 +20,7 @@ package tcp import ( "context" "crypto/tls" + "errors" "strings" "time" @@ -29,6 +30,7 @@ import ( "github.com/vdaas/vald/internal/safety" ) +// Dialer is an interface to get the dialer instance to connect to an address. type Dialer interface { GetDialer() func(ctx context.Context, network, addr string) (net.Conn, error) StartDialerCache(ctx context.Context) @@ -56,6 +58,10 @@ func NewDialer(opts ...DialerOption) (der Dialer, err error) { opt(d) } + if d.dnsRefreshDuration > d.dnsCacheExpiration { + return nil, errors.New("dnsRefreshDuration > dnsCacheExpiration") + } + d.der = &net.Dialer{ Timeout: d.dialerTimeout, KeepAlive: d.dialerKeepAlive, @@ -83,13 +89,6 @@ func NewDialer(opts ...DialerOption) (der Dialer, err error) { return d, nil } - if d.dnsRefreshDuration > d.dnsCacheExpiration { - d.dnsRefreshDuration, d.dnsCacheExpiration = - d.dnsCacheExpiration, d.dnsRefreshDuration - d.dnsRefreshDurationStr, d.dnsCacheExpirationStr = - d.dnsCacheExpirationStr, d.dnsRefreshDurationStr - } - if d.cache == nil { d.cache, err = cache.New( cache.WithExpireDuration(d.dnsCacheExpirationStr), @@ -118,16 +117,14 @@ func NewDialer(opts ...DialerOption) (der Dialer, err error) { return d, nil } -func (d *dialer) GetDialer() func(ctx context.Context, - network, addr string) (net.Conn, error) { +func (d *dialer) GetDialer() func(ctx context.Context, network, addr string) (net.Conn, error) { return d.dialer } -func (d *dialer) lookup(ctx context.Context, - addr string) (ips map[int]string, err error) { +func (d *dialer) lookup(ctx context.Context, addr string) (ips []string, err error) { cache, ok := d.cache.Get(addr) if ok { - return cache.(map[int]string), nil + return cache.([]string), nil } r, err := d.der.Resolver.LookupIPAddr(ctx, addr) @@ -135,9 +132,9 @@ func (d *dialer) lookup(ctx context.Context, return nil, err } - ips = make(map[int]string, len(r)) - for i, ip := range r { - ips[i] = ip.String() + ips = make([]string, 0, len(r)+2) + for _, ip := range r { + ips = append(ips, ip.String()) } d.cache.Set(addr, ips)