Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
kevindiu authored and actions-user committed Jul 8, 2020
1 parent 37b66db commit 1572b71
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions internal/net/tcp/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package tcp
import (
"context"
"crypto/tls"
"errors"
"strings"
"time"

Expand All @@ -29,6 +30,7 @@ import (
"github.com/vdaas/vald/internal/safety"
)

// Dialer is an interface to get the dialer instance to connect to an address.
type Dialer interface {
GetDialer() func(ctx context.Context, network, addr string) (net.Conn, error)
StartDialerCache(ctx context.Context)
Expand Down Expand Up @@ -56,6 +58,10 @@ func NewDialer(opts ...DialerOption) (der Dialer, err error) {
opt(d)
}

if d.dnsRefreshDuration > d.dnsCacheExpiration {
return nil, errors.New("dnsRefreshDuration > dnsCacheExpiration")
}

d.der = &net.Dialer{
Timeout: d.dialerTimeout,
KeepAlive: d.dialerKeepAlive,
Expand Down Expand Up @@ -83,13 +89,6 @@ func NewDialer(opts ...DialerOption) (der Dialer, err error) {
return d, nil
}

if d.dnsRefreshDuration > d.dnsCacheExpiration {
d.dnsRefreshDuration, d.dnsCacheExpiration =
d.dnsCacheExpiration, d.dnsRefreshDuration
d.dnsRefreshDurationStr, d.dnsCacheExpirationStr =
d.dnsCacheExpirationStr, d.dnsRefreshDurationStr
}

if d.cache == nil {
d.cache, err = cache.New(
cache.WithExpireDuration(d.dnsCacheExpirationStr),
Expand Down Expand Up @@ -118,26 +117,24 @@ func NewDialer(opts ...DialerOption) (der Dialer, err error) {
return d, nil
}

func (d *dialer) GetDialer() func(ctx context.Context,
network, addr string) (net.Conn, error) {
func (d *dialer) GetDialer() func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dialer
}

func (d *dialer) lookup(ctx context.Context,
addr string) (ips map[int]string, err error) {
func (d *dialer) lookup(ctx context.Context, addr string) (ips []string, err error) {
cache, ok := d.cache.Get(addr)
if ok {
return cache.(map[int]string), nil
return cache.([]string), nil
}

r, err := d.der.Resolver.LookupIPAddr(ctx, addr)
if err != nil {
return nil, err
}

ips = make(map[int]string, len(r))
for i, ip := range r {
ips[i] = ip.String()
ips = make([]string, 0, len(r)+2)
for _, ip := range r {
ips = append(ips, ip.String())
}

d.cache.Set(addr, ips)
Expand Down

0 comments on commit 1572b71

Please sign in to comment.