Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Read from DNS cache if within TTL #677

Merged
merged 4 commits into from
Apr 8, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions synapse/http/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import collections
import logging
import random
import time


logger = logging.getLogger(__name__)
Expand All @@ -31,7 +32,7 @@


_Server = collections.namedtuple(
"_Server", "priority weight host port"
"_Server", "priority weight host port expires"
)


Expand Down Expand Up @@ -92,7 +93,8 @@ def __init__(self, reactor, service, domain, protocol="tcp",
host=domain,
port=default_port,
priority=0,
weight=0
weight=0,
expires=0,
)
else:
self.default_server = None
Expand Down Expand Up @@ -153,7 +155,13 @@ def connect(self, protocolFactory):


@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)

servers = []

try:
Expand All @@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
continue

payload = answer.payload

host = str(payload.target)
srv_ttl = answer.ttl

try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue

ips = [
answer.payload.dottedQuad()
for answer in answers
if answer.type == dns.A and answer.payload
]

for ip in ips:
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight)
))
for answer in answers:
if answer.type == dns.A and answer.payload:
ip = answer.payload.dottedQuad()
host_ttl = min(srv_ttl, answer.ttl)

servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + host_ttl,
))

servers.sort()
cache[service_name] = list(servers)
Expand Down
34 changes: 32 additions & 2 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from synapse.http.endpoint import resolve_service

from tests.utils import MockClock


class DnsTestCase(unittest.TestCase):

Expand Down Expand Up @@ -63,14 +65,17 @@ def test_resolve(self):
self.assertEquals(servers[0].host, ip_address)

@defer.inlineCallbacks
def test_from_cache(self):
def test_from_cache_expired_and_dns_fail(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())

service_name = "test_service.examle.com"

entry = Mock(spec_set=["expires"])
entry.expires = 0

cache = {
service_name: [object()]
service_name: [entry]
}

servers = yield resolve_service(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that resolve_service takes a clock, it might be worth using a mock clock here?
Also might be nice to have a test that checks that things expire.

Expand All @@ -82,6 +87,31 @@ def test_from_cache(self):
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])

@defer.inlineCallbacks
def test_from_cache(self):
clock = MockClock()

dns_client_mock = Mock(spec_set=['lookupService'])
dns_client_mock.lookupService = Mock(spec_set=[])

service_name = "test_service.examle.com"

entry = Mock(spec_set=["expires"])
entry.expires = 999999999

cache = {
service_name: [entry]
}

servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache, clock=clock,
)

self.assertFalse(dns_client_mock.lookupService.called)

self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])

@defer.inlineCallbacks
def test_empty_cache(self):
dns_client_mock = Mock()
Expand Down