Skip to content

Commit

Permalink
Improved unpacking robustness (not nearly finished)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bouke committed Apr 5, 2017
1 parent 741bd70 commit 52a1c41
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 57 deletions.
121 changes: 78 additions & 43 deletions Sources/DNS/Bytes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,50 @@ enum EncodeError: Swift.Error {
case unicodeEncodingNotSupported
}

enum DecodeError: Swift.Error {
case invalidMessageSize
case invalidOperationCode
case invalidReturnCode
case invalidLabelSize
case invalidLabelOffset
case unicodeDecodingError
case unicodeEncodingNotSupported
case invalidIntegerSize
case invalidResourceRecordType
}

func unpackName(_ data: Data, _ position: inout Data.Index) -> String {
func unpackName(_ data: Data, _ position: inout Data.Index) throws -> String {
var components = [String]()
while true {
let step = data[position]
if step & 0xc0 == 0xc0 {
var pointer = data.index(data.startIndex, offsetBy: Int(UInt16(bytes: data[position..<position+2]) ^ 0xc000))
components += unpackName(data, &pointer).components(separatedBy: ".").filter({ $0 != "" })
guard position + 2 <= data.endIndex else {
throw DecodeError.invalidLabelOffset
}
let offset = Int(UInt16(bytes: data[position..<position+2]) ^ 0xc000)
guard var pointer = data.index(data.startIndex, offsetBy: offset, limitedBy: data.endIndex) else {
throw DecodeError.invalidLabelOffset
}
components += try unpackName(data, &pointer).components(separatedBy: ".").filter({ $0 != "" })
position += 2
break
}

let start = data.index(position, offsetBy: 1)
let end = data.index(start, offsetBy: Int(step))
guard let start = data.index(position, offsetBy: 1, limitedBy: data.endIndex),
let end = data.index(start, offsetBy: Int(step), limitedBy: data.endIndex) else
{
throw DecodeError.invalidLabelSize
}
if step > 0 {
for byte in data[start..<end] {
precondition((0x20..<0xff).contains(byte))
guard (0x20..<0xff).contains(byte) else {
throw DecodeError.unicodeEncodingNotSupported
}
}
components.append(String(bytes: data[start..<end], encoding: .utf8)!)
guard let component = String(bytes: data[start..<end], encoding: .utf8) else {
throw DecodeError.unicodeDecodingError
}
components.append(component)
} else {
position = end
break
Expand Down Expand Up @@ -58,20 +83,20 @@ func packName(_ name: String, onto buffer: inout Data, labels: inout Labels) thr
}
}

func unpack<T: Integer>(_ data: Data, _ position: inout Data.Index) -> T {
func unpack<T: Integer>(_ data: Data, _ position: inout Data.Index) throws -> T {
let size = MemoryLayout<T>.size
defer { position += size }
return T(bytes: data[position..<position+size])
}

typealias RecordCommonFields = (name: String, type: UInt16, unique: Bool, internetClass: UInt16, ttl: UInt32)

func unpackRecordCommonFields(_ data: Data, _ position: inout Data.Index) -> RecordCommonFields {
return (unpackName(data, &position),
unpack(data, &position),
data[position] & 0x80 == 0x80,
unpack(data, &position),
unpack(data, &position))
func unpackRecordCommonFields(_ data: Data, _ position: inout Data.Index) throws -> RecordCommonFields {
return try (unpackName(data, &position),
unpack(data, &position),
data[position] & 0x80 == 0x80,
unpack(data, &position),
unpack(data, &position))
}

func packRecordCommonFields(_ common: RecordCommonFields, onto buffer: inout Data, labels: inout Labels) throws {
Expand All @@ -82,15 +107,15 @@ func packRecordCommonFields(_ common: RecordCommonFields, onto buffer: inout Dat
}


func unpackRecord(_ data: Data, _ position: inout Data.Index) -> ResourceRecord {
let common = unpackRecordCommonFields(data, &position)
func unpackRecord(_ data: Data, _ position: inout Data.Index) throws -> ResourceRecord {
let common = try unpackRecordCommonFields(data, &position)
switch ResourceRecordType(rawValue: common.type) {
case .host?: return HostRecord<IPv4>(unpack: data, position: &position, common: common)
case .host6?: return HostRecord<IPv6>(unpack: data, position: &position, common: common)
case .service?: return ServiceRecord(unpack: data, position: &position, common: common)
case .text?: return TextRecord(unpack: data, position: &position, common: common)
case .pointer?: return PointerRecord(unpack: data, position: &position, common: common)
default: return Record(unpack: data, position: &position, common: common)
case .host?: return try HostRecord<IPv4>(unpack: data, position: &position, common: common)
case .host6?: return try HostRecord<IPv6>(unpack: data, position: &position, common: common)
case .service?: return try ServiceRecord(unpack: data, position: &position, common: common)
case .text?: return try TextRecord(unpack: data, position: &position, common: common)
case .pointer?: return try PointerRecord(unpack: data, position: &position, common: common)
default: return try Record(unpack: data, position: &position, common: common)
}
}

Expand Down Expand Up @@ -142,23 +167,33 @@ extension Message {
}

public init(unpack bytes: Data) throws {
guard bytes.count >= 12 else {
throw DecodeError.invalidMessageSize
}

let flags = UInt16(bytes: bytes[2..<4])
guard let operationCode = OperationCode(rawValue: UInt8(flags >> 11 & 0x7)) else {
throw DecodeError.invalidOperationCode
}
guard let returnCode = ReturnCode(rawValue: UInt8(flags & 0x7)) else {
throw DecodeError.invalidReturnCode
}

header = Header(id: UInt16(bytes: bytes[0..<2]),
response: flags >> 15 & 1 == 1,
operationCode: OperationCode(rawValue: UInt8(flags >> 11 & 0x7))!,
operationCode: operationCode,
authoritativeAnswer: flags >> 10 & 0x1 == 0x1,
truncation: flags >> 9 & 0x1 == 0x1,
recursionDesired: flags >> 8 & 0x1 == 0x1,
recursionAvailable: flags >> 7 & 0x1 == 0x1,
returnCode: ReturnCode(rawValue: UInt8(flags & 0x7))!)
returnCode: returnCode)

var position = bytes.index(bytes.startIndex, offsetBy: 12)

questions = (0..<UInt16(bytes: bytes[4..<6])).map { _ in Question(unpack: bytes, position: &position) }
answers = (0..<UInt16(bytes: bytes[6..<8])).map { _ in unpackRecord(bytes, &position) }
authorities = (0..<UInt16(bytes: bytes[8..<10])).map { _ in unpackRecord(bytes, &position) }
additional = (0..<UInt16(bytes: bytes[10..<12])).map { _ in unpackRecord(bytes, &position) }
questions = try (0..<UInt16(bytes: bytes[4..<6])).map { _ in try Question(unpack: bytes, position: &position) }
answers = try (0..<UInt16(bytes: bytes[6..<8])).map { _ in try unpackRecord(bytes, &position) }
authorities = try (0..<UInt16(bytes: bytes[8..<10])).map { _ in try unpackRecord(bytes, &position) }
additional = try (0..<UInt16(bytes: bytes[10..<12])).map { _ in try unpackRecord(bytes, &position) }
}

func tcp() throws -> Data {
Expand All @@ -168,9 +203,9 @@ extension Message {
}

extension Record: ResourceRecord {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) throws {
(name, type, unique, internetClass, ttl) = common
let size = Int(unpack(data, &position) as UInt16)
let size = Int(try unpack(data, &position) as UInt16)
self.data = Data(data[position..<position+size])
position += size
}
Expand All @@ -182,9 +217,9 @@ extension Record: ResourceRecord {
}

extension HostRecord: ResourceRecord {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) throws {
(name, _, unique, internetClass, ttl) = common
let size = Int(unpack(data, &position) as UInt16)
let size = Int(try unpack(data, &position) as UInt16)
ip = IPType(networkBytes: Data(data[position..<position+size]))!
position += size
}
Expand All @@ -208,14 +243,14 @@ extension HostRecord: ResourceRecord {
}

extension ServiceRecord: ResourceRecord {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) throws {
(name, _, unique, internetClass, ttl) = common
let length = unpack(data, &position) as UInt16
let length = try unpack(data, &position) as UInt16
let expectedPosition = position + Data.Index(length)
priority = unpack(data, &position)
weight = unpack(data, &position)
port = unpack(data, &position)
server = unpackName(data, &position)
priority = try unpack(data, &position)
weight = try unpack(data, &position)
port = try unpack(data, &position)
server = try unpackName(data, &position)
precondition(position == expectedPosition, "Unexpected length")
}

Expand All @@ -235,14 +270,14 @@ extension ServiceRecord: ResourceRecord {
}

extension TextRecord: ResourceRecord {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) throws {
(name, _, unique, internetClass, ttl) = common
let endIndex = Int(unpack(data, &position) as UInt16) + position
let endIndex = Int(try UInt16(data: data, position: &position)) + position

var attrs = [String: String]()
var other = [String]()
while position < endIndex {
let size = Int(unpack(data, &position) as UInt8)
let size = Int(try UInt8(data: data, position: &position))
guard size > 0 else { break }
var attr = String(bytes: data[position..<position+size], encoding: .utf8)!.characters.split(separator: "=", maxSplits: 1, omittingEmptySubsequences: false).map { String($0) }
if attr.count == 2 {
Expand All @@ -268,10 +303,10 @@ extension TextRecord: ResourceRecord {
}

extension PointerRecord: ResourceRecord {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) {
init(unpack data: Data, position: inout Data.Index, common: RecordCommonFields) throws {
(name, _, unique, internetClass, ttl) = common
position += 2
destination = unpackName(data, &position)
destination = try unpackName(data, &position)
}

public func pack(onto buffer: inout Data, labels: inout Labels) throws {
Expand Down
13 changes: 13 additions & 0 deletions Sources/DNS/Integer+Data.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ extension Integer {
init<S: Sequence>(bytes: S) where S.Iterator.Element == UInt8 {
self.init(bytes: Array(bytes))
}

init(data: Data, position: inout Data.Index) throws {
let start = position
guard data.formIndex(&position, offsetBy: MemoryLayout<Self>.size, limitedBy: data.endIndex) else {
throw DecodeError.invalidIntegerSize
}
let bytes = Array(data[start..<position].reversed())
self = bytes.withUnsafeBufferPointer() {
$0.baseAddress!.withMemoryRebound(to: Self.self, capacity: 1) {
return $0.pointee
}
}
}
}

extension UnsignedInteger {
Expand Down
14 changes: 8 additions & 6 deletions Sources/DNS/Types.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,14 @@ public struct Question {
self.internetClass = internetClass
}

init(unpack data: Data, position: inout Data.Index) {
name = unpackName(data, &position)
type = ResourceRecordType(rawValue: UInt16(bytes: data[position..<position+2]))!
unique = data[position+2] & 0x80 == 0x80
internetClass = UInt16(bytes: data[position+2..<position+4]) & 0x7fff
position += 4
init(unpack data: Data, position: inout Data.Index) throws {
name = try unpackName(data, &position)
guard let recordType = ResourceRecordType(rawValue: try UInt16(data: data, position: &position)) else {
throw DecodeError.invalidResourceRecordType
}
type = recordType
unique = data[position] & 0x80 == 0x80
internetClass = try UInt16(data: data, position: &position) & 0x7fff
}
}

Expand Down
16 changes: 8 additions & 8 deletions Tests/DNSTests/DNSTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class DNSTests: XCTestCase {
XCTAssertEqual(packed0.hex, packed1.hex)

var position = packed0.startIndex
let rcf = unpackRecordCommonFields(packed0, &position)
let pointer0copy = PointerRecord(unpack: packed0, position: &position, common: rcf)
let rcf = try! unpackRecordCommonFields(packed0, &position)
let pointer0copy = try! PointerRecord(unpack: packed0, position: &position, common: rcf)

XCTAssertEqual(pointer0, pointer0copy)
}
Expand All @@ -40,7 +40,7 @@ class DNSTests: XCTestCase {
let message0 = Message(header: Header(id: 4529, response: true, operationCode: .query, authoritativeAnswer: false, truncation: false, recursionDesired: false, recursionAvailable: false, returnCode: .NXDOMAIN))
let packed0 = try! message0.pack()
XCTAssertEqual(packed0.hex, "11b180030000000000000000")
let message1 = Message(unpack: packed0)
let message1 = try! Message(unpack: packed0)
let packed1 = try! message1.pack()
XCTAssertEqual(packed0.hex, packed1.hex)
}
Expand All @@ -49,7 +49,7 @@ class DNSTests: XCTestCase {
let message0 = Message(header: Header(id: 18765, response: true, operationCode: .query, authoritativeAnswer: true, truncation: true, recursionDesired: true, recursionAvailable: true, returnCode: .NOERROR))
let packed0 = try! message0.pack()
XCTAssertEqual(packed0.hex, "494d87800000000000000000")
let message1 = Message(unpack: packed0)
let message1 = try! Message(unpack: packed0)
let packed1 = try! message1.pack()
XCTAssertEqual(packed0.hex, packed1.hex)
}
Expand All @@ -58,7 +58,7 @@ class DNSTests: XCTestCase {
let message0 = Message(header: Header(response: false),
questions: [Question(name: "_airplay._tcp._local", type: .pointer)])
let packed0 = try! message0.pack()
let message1 = Message(unpack: packed0)
let message1 = try! Message(unpack: packed0)
let packed1 = try! message1.pack()
XCTAssertEqual(packed0.hex, packed1.hex)
}
Expand All @@ -70,7 +70,7 @@ class DNSTests: XCTestCase {
questions: [Question(name: service, type: .pointer)],
answers: [PointerRecord(name: service, ttl: 120, destination: name)])
let packed0 = try! message0.pack()
let message1 = Message(unpack: packed0)
let message1 = try! Message(unpack: packed0)
let packed1 = try! message1.pack()
XCTAssertEqual(packed0.hex, packed1.hex)
}
Expand All @@ -86,7 +86,7 @@ class DNSTests: XCTestCase {
additional: [HostRecord<IPv4>(name: server, ttl: 120, ip: IPv4("10.0.1.2")!),
TextRecord(name: service, ttl: 120, attributes: ["hello": "world"])])
let packed0 = try! message0.pack()
let message1 = Message(unpack: packed0)
let message1 = try! Message(unpack: packed0)
let packed1 = try! message1.pack()
XCTAssertEqual(packed0.hex, packed1.hex)
}
Expand All @@ -95,7 +95,7 @@ class DNSTests: XCTestCase {
// This is part of a record. The name can be found by following two pointers indicated by 0xc000 (mask).
let data = Data(hex: "000084000000000200000006075a6974686f656b0c5f6465766963652d696e666f045f746370056c6f63616c000010000100001194000d0c6d6f64656c3d4a3432644150085f616972706c6179c021000c000100001194000a075a6974686f656bc044")!
var position = 89
let name = unpackName(data, &position)
let name = try! unpackName(data, &position)
XCTAssertEqual(name, "Zithoek._airplay._tcp.local.")
XCTAssertEqual(position, 99)
}
Expand Down

0 comments on commit 52a1c41

Please sign in to comment.