Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove UDP connection caching in embedded DNS server #1352

Merged
merged 1 commit into from
Jul 25, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,6 @@ func (ep *endpoint) sbJoin(sb *sandbox, options ...EndpointOption) error {
}
}

if sb.resolver != nil {
sb.resolver.FlushExtServers()
}
}

if !sb.needDefaultGW() {
Expand Down Expand Up @@ -619,10 +616,6 @@ func (ep *endpoint) Leave(sbox Sandbox, options ...EndpointOption) error {
sb.joinLeaveStart()
defer sb.joinLeaveEnd()

if sb.resolver != nil {
sb.resolver.FlushExtServers()
}

return ep.sbLeave(sb, false, options...)
}

Expand Down
243 changes: 22 additions & 221 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ type Resolver interface {
// SetExtServers configures the external nameservers the resolver
// should use to forward queries
SetExtServers([]string)
// FlushExtServers clears the cached UDP connections to external
// nameservers
FlushExtServers()
// ResolverOptions returns resolv.conf options that should be set
ResolverOptions() []string
}
Expand All @@ -48,35 +45,12 @@ const (
defaultRespSize = 512
maxConcurrent = 100
logInterval = 2 * time.Second
maxDNSID = 65536
)

type clientConn struct {
dnsID uint16
respWriter dns.ResponseWriter
}

type extDNSEntry struct {
ipStr string
extConn net.Conn
extOnce sync.Once
}

type sboxQuery struct {
sboxID string
dnsID uint16
ipStr string
}

type clientConnGC struct {
toDelete bool
client clientConn
}

var (
queryGCMutex sync.Mutex
queryGC map[sboxQuery]*clientConnGC
)

// resolver implements the Resolver interface
type resolver struct {
sb *sandbox
Expand All @@ -89,34 +63,17 @@ type resolver struct {
count int32
tStamp time.Time
queryLock sync.Mutex
client map[uint16]clientConn
}

func init() {
rand.Seed(time.Now().Unix())
queryGC = make(map[sboxQuery]*clientConnGC)
go func() {
ticker := time.NewTicker(1 * time.Minute)
for range ticker.C {
queryGCMutex.Lock()
for query, conn := range queryGC {
if !conn.toDelete {
conn.toDelete = true
continue
}
delete(queryGC, query)
}
queryGCMutex.Unlock()
}
}()
}

// NewResolver creates a new instance of the Resolver
func NewResolver(sb *sandbox) Resolver {
return &resolver{
sb: sb,
err: fmt.Errorf("setup not done yet"),
client: make(map[uint16]clientConn),
sb: sb,
err: fmt.Errorf("setup not done yet"),
}
}

Expand Down Expand Up @@ -173,20 +130,7 @@ func (r *resolver) Start() error {
return nil
}

func (r *resolver) FlushExtServers() {
for i := 0; i < maxExtDNS; i++ {
if r.extDNSList[i].extConn != nil {
r.extDNSList[i].extConn.Close()
}

r.extDNSList[i].extConn = nil
r.extDNSList[i].extOnce = sync.Once{}
}
}

func (r *resolver) Stop() {
r.FlushExtServers()

if r.server != nil {
r.server.Shutdown()
}
Expand Down Expand Up @@ -355,7 +299,6 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
extConn net.Conn
resp *dns.Msg
err error
writer dns.ResponseWriter
)

if query == nil || len(query.Question) == 0 {
Expand Down Expand Up @@ -397,10 +340,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
if resp.Len() > maxSize {
truncateResp(resp, maxSize, proto == "tcp")
}
writer = w
} else {
queryID := query.Id
extQueryLoop:
for i := 0; i < maxExtDNS; i++ {
extDNS := &r.extDNSList[i]
if extDNS.ipStr == "" {
Expand All @@ -411,30 +351,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
}

// For udp clients connection is persisted to reuse for further queries.
// Accessing extDNS.extConn be a race here between go rouines. Hence the
// connection setup is done in a Once block and fetch the extConn again
extConn = extDNS.extConn
if extConn == nil || proto == "tcp" {
if proto == "udp" {
extDNS.extOnce.Do(func() {
r.sb.execFunc(extConnect)
extDNS.extConn = extConn
})
extConn = extDNS.extConn
} else {
r.sb.execFunc(extConnect)
}
if err != nil {
log.Debugf("Connect failed, %s", err)
continue
}
}
// If two go routines are executing in parralel one will
// block on the Once.Do and in case of error connecting
// to the external server it will end up with a nil err
// but extConn also being nil.
if extConn == nil {
r.sb.execFunc(extConnect)
if err != nil {
log.Debugf("Connect failed, %s", err)
continue
}
log.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
Expand All @@ -443,10 +362,10 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
// Timeout has to be set for every IO operation.
extConn.SetDeadline(time.Now().Add(extIOTimeout))
co := &dns.Conn{Conn: extConn}
defer co.Close()

// forwardQueryStart stores required context to mux multiple client queries over
// one connection; and limits the number of outstanding concurrent queries.
if r.forwardQueryStart(w, query, queryID) == false {
// limits the number of outstanding concurrent queries.
if r.forwardQueryStart() == false {
old := r.tStamp
r.tStamp = time.Now()
if r.tStamp.Sub(old) > logInterval {
Expand All @@ -455,74 +374,38 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
continue
}

defer func() {
if proto == "tcp" {
co.Close()
}
}()
err = co.WriteMsg(query)
if err != nil {
r.forwardQueryEnd(w, query)
r.forwardQueryEnd()
log.Debugf("Send to DNS server failed, %s", err)
continue
}
for {
// If a reply comes after a read timeout it will remain in the socket buffer
// and will be read after sending next query. To ignore such stale replies
// save the query context in a GC queue when read timesout. On the next reply
// if the context is present in the GC queue its a old reply. Ignore it and
// read again
resp, err = co.ReadMsg()
if err != nil {
// Truncated DNS replies should be sent to the client so that the
// client can retry over TCP
if err == dns.ErrTruncated && resp != nil {
break
}
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
r.addQueryToGC(w, query)
}
r.forwardQueryEnd(w, query)
log.Debugf("Read from DNS server failed, %s", err)
continue extQueryLoop
}

if !r.checkRespInGC(w, resp) {
break
}
}
// Retrieves the context for the forwarded query and returns the client connection
// to send the reply to
writer = r.forwardQueryEnd(w, resp)
if writer == nil {
resp, err = co.ReadMsg()
// Truncated DNS replies should be sent to the client so that the
// client can retry over TCP
if err != nil && err != dns.ErrTruncated {
r.forwardQueryEnd()
log.Debugf("Read from DNS server failed, %s", err)
continue
}

r.forwardQueryEnd()

resp.Compress = true
break
}
if resp == nil || writer == nil {
if resp == nil {
return
}
}

if writer == nil {
return
}
if err = writer.WriteMsg(resp); err != nil {
if err = w.WriteMsg(resp); err != nil {
log.Errorf("error writing resolver resp, %s", err)
}
}

func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg, queryID uint16) bool {
proto := w.LocalAddr().Network()
dnsID := uint16(rand.Intn(maxDNSID))

cc := clientConn{
dnsID: queryID,
respWriter: w,
}

func (r *resolver) forwardQueryStart() bool {
r.queryLock.Lock()
defer r.queryLock.Unlock()

Expand All @@ -531,74 +414,10 @@ func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg, queryID
}
r.count++

switch proto {
case "tcp":
break
case "udp":
for ok := true; ok == true; dnsID = uint16(rand.Intn(maxDNSID)) {
_, ok = r.client[dnsID]
}
log.Debugf("client dns id %v, changed id %v", queryID, dnsID)
r.client[dnsID] = cc
msg.Id = dnsID
default:
log.Errorf("Invalid protocol..")
return false
}

return true
}

func (r *resolver) addQueryToGC(w dns.ResponseWriter, msg *dns.Msg) {
if w.LocalAddr().Network() != "udp" {
return
}

r.queryLock.Lock()
cc, ok := r.client[msg.Id]
r.queryLock.Unlock()
if !ok {
return
}

query := sboxQuery{
sboxID: r.sb.ID(),
dnsID: msg.Id,
}
clientGC := &clientConnGC{
client: cc,
}
queryGCMutex.Lock()
queryGC[query] = clientGC
queryGCMutex.Unlock()
}

func (r *resolver) checkRespInGC(w dns.ResponseWriter, msg *dns.Msg) bool {
if w.LocalAddr().Network() != "udp" {
return false
}

query := sboxQuery{
sboxID: r.sb.ID(),
dnsID: msg.Id,
}

queryGCMutex.Lock()
defer queryGCMutex.Unlock()
if _, ok := queryGC[query]; ok {
delete(queryGC, query)
return true
}
return false
}

func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.ResponseWriter {
var (
cc clientConn
ok bool
)
proto := w.LocalAddr().Network()

func (r *resolver) forwardQueryEnd() {
r.queryLock.Lock()
defer r.queryLock.Unlock()

Expand All @@ -607,22 +426,4 @@ func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.Respo
} else {
r.count--
}

switch proto {
case "tcp":
break
case "udp":
if cc, ok = r.client[msg.Id]; ok == false {
log.Debugf("Can't retrieve client context for dns id %v", msg.Id)
return nil
}
log.Debugf("dns msg id %v, client id %v", msg.Id, cc.dnsID)
delete(r.client, msg.Id)
msg.Id = cc.dnsID
w = cc.respWriter
default:
log.Errorf("Invalid protocol")
return nil
}
return w
}