Skip to content

Commit

Permalink
Clarify ceiling, add tests, fix doc, fix python3
Browse files Browse the repository at this point in the history
- bad class doc
- explicit ceil
- python3 encoding issue
- improve edns testing
  • Loading branch information
kalou committed Nov 8, 2016
1 parent 349ddca commit 2613b44
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
20 changes: 13 additions & 7 deletions dns/edns.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

from __future__ import absolute_import

import math
import struct
import sys

import dns.inet

Expand Down Expand Up @@ -145,9 +147,10 @@ def __init__(self, address, srclen=None, scopelen=0):
@ivar srclen: prefix length, leftmost number of bits of the address
to be used for the lookup. Sent by client, mirrored by server in
responses. If not provided at init, will use /24 for v4 and /56 for v6
@ivar srclen: int
@type srclen: int
@ivar scopelen: prefix length, leftmost number of bits of the address
that the response covers. 0 in queries, set by server.
@type scopelen: int
"""
super(ECSOption, self).__init__(ECS)
af = dns.inet.af_for_address(address)
Expand All @@ -167,13 +170,16 @@ def __init__(self, address, srclen=None, scopelen=0):
self.scopelen = scopelen
self.address = address

self.addrdata = dns.inet.inet_pton(af, address)
addrdata = dns.inet.inet_pton(af, address)
nbytes = int(math.ceil(srclen/8.0))

# Truncate to srclen and pad to the end of the last octet needed
# See RFC section 6
self.addrdata = self.addrdata[:-(-srclen//8)]
last = ord(self.addrdata[-1:]) & (0xff << srclen % 8)
self.addrdata = self.addrdata[:-1] + chr(last).encode('latin1')
self.addrdata = addrdata[:nbytes]
last = chr(ord(self.addrdata[-1:]) & (0xff << srclen % 8))
if sys.version_info >= (3,):
last = last.encode('latin1')
self.addrdata = self.addrdata[:-1] + last

def to_text(self):
return "ECS %s/%s scope/%s" % (self.address, self.srclen,
Expand All @@ -191,7 +197,7 @@ def from_wire(cls, otype, wire, cur, olen):
family, src, scope = struct.unpack('!HBB', wire[cur:cur+4])
cur += 4

addrlen = -(-src//8)
addrlen = int(math.ceil(src/8.0))

if family == 1:
af = dns.inet.AF_INET
Expand All @@ -202,7 +208,7 @@ def from_wire(cls, otype, wire, cur, olen):
else:
raise ValueError('unsupported family')

addr = dns.inet.inet_ntop(af, wire[cur:cur+addrlen] + '\x00' * pad)
addr = dns.inet.inet_ntop(af, wire[cur:cur+addrlen] + b'\x00' * pad)
return cls(addr, src, scope)

def _cmp(self, other):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_option.py → tests/test_edns.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ def testGenericOption(self):
data = io.getvalue()
self.assertEqual(data, b'data')

def testECSOption_prefix_length(self):
opt = dns.edns.ECSOption('1.2.255.33', 20)
io = BytesIO()
opt.to_wire(io)
data = io.getvalue()
self.assertEqual(data, b'\x00\x01\x14\x00\x01\x02\xf0')

def testECSOption_from_wire(self):
opt = dns.edns.option_from_wire(8, b'\x00\x01\x14\x00\x01\x02\xf0',
0, 7)
self.assertEqual(opt.otype, dns.edns.ECS)
self.assertEqual(opt.address, b'1.2.240.0')
self.assertEqual(opt.srclen, 20)
self.assertEqual(opt.scopelen, 0)

def testECSOption(self):
opt = dns.edns.ECSOption('1.2.3.4', 24)
io = BytesIO()
Expand Down

0 comments on commit 2613b44

Please sign in to comment.