Skip to content

Commit

Permalink
More consistent handling of big endian values (part 3)
Browse files Browse the repository at this point in the history
  • Loading branch information
mafik committed Dec 8, 2023
1 parent a888d1b commit 644709c
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 91 deletions.
19 changes: 0 additions & 19 deletions src/big_endian.cc

This file was deleted.

17 changes: 10 additions & 7 deletions src/big_endian.hh
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include "int.hh"
#include "vec.hh"

#include <bit>

Expand All @@ -15,21 +14,25 @@ template <> constexpr maf::U24 byteswap(maf::U24 x) noexcept {

namespace maf {

template <typename T> void AppendBigEndian(Vec<> &s, T x);

template <> void AppendBigEndian(Vec<> &s, U16 x);
template <> void AppendBigEndian(Vec<> &s, U24 x);
template <> void AppendBigEndian(Vec<> &s, U32 x);

// A type that can be operated just like any other integral type, but its memory
// representation is big-endian.
template <typename T> struct Big {
T big_endian;

Big() = default;
constexpr Big(T host_value) : big_endian(std::byteswap(host_value)) {}

constexpr static Big<T> FromBig(T big_endian) {
Big<T> ret;
ret.big_endian = big_endian;
return ret;
}

T Get() const { return std::byteswap(big_endian); }
void Set(T host_value) { big_endian = std::byteswap(host_value); }
operator T() const { return Get(); }

auto operator<=>(const Big<T> &other) const { return Get() <=> other.Get(); }
} __attribute__((packed));

static_assert(Big<U16>(0x1122).big_endian == 0x2211);
Expand Down
33 changes: 16 additions & 17 deletions src/dhcp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <sys/socket.h>

#include "arp.hh"
#include "big_endian.hh"
#include "config.hh"
#include "etc.hh"
#include "expirable.hh"
Expand All @@ -24,9 +25,9 @@ using chrono::steady_clock;
namespace dhcp {

const IP kBroadcastIP(255, 255, 255, 255);
const uint16_t kServerPort = 67;
const uint16_t kClientPort = 68;
const uint32_t kMagicCookie = 0x63825363;
const U16 kServerPort = 67;
const U16 kClientPort = 68;
const U32 kMagicCookie = 0x63825363;

namespace options {

Expand Down Expand Up @@ -379,9 +380,9 @@ struct __attribute__((__packed__)) RequestedIPAddress : Base {
};

struct __attribute__((__packed__)) IPAddressLeaseTime : Base {
const uint32_t seconds;
IPAddressLeaseTime(uint32_t seconds)
: Base(OptionCode::IPAddressLeaseTime, 4), seconds(htonl(seconds)) {}
const Big<U32> seconds;
IPAddressLeaseTime(U32 seconds)
: Base(OptionCode::IPAddressLeaseTime, 4), seconds(seconds) {}
Str ToStr() const {
return "IPAddressLeaseTime(" + ::ToStr(ntohl(seconds)) + ")";
}
Expand Down Expand Up @@ -489,18 +490,16 @@ struct __attribute__((__packed__)) ParameterRequestList {

// RFC 2132, section 9.10
struct __attribute__((__packed__)) MaximumDHCPMessageSize {
const uint8_t code = 57;
const uint8_t length = 2;
const uint16_t value = htons(1500);
Str ToStr() const {
return "MaximumDHCPMessageSize(" + ::ToStr(ntohs(value)) + ")";
}
const U8 code = 57;
const U8 length = 2;
const Big<U16> value = 1500;
Str ToStr() const { return "MaximumDHCPMessageSize(" + ::ToStr(value) + ")"; }
};

struct __attribute__((__packed__)) VendorClassIdentifier {
const uint8_t code = 60;
const uint8_t length;
const uint8_t value[0];
const U8 code = 60;
const U8 length;
const U8 value[0];
Str ToStr() const {
return "VendorClassIdentifier(" + Str((const char *)value, length) + ")";
}
Expand Down Expand Up @@ -582,7 +581,7 @@ struct __attribute__((__packed__)) Header {
};
uint8_t server_name[64] = {};
uint8_t boot_filename[128] = {};
uint32_t magic_cookie = htonl(kMagicCookie);
Big<U32> magic_cookie = kMagicCookie;

Str ToStr() const {
string s = "dhcp::Header {\n";
Expand Down Expand Up @@ -875,7 +874,7 @@ void Server::HandleRequest(string_view buf, IP source_ip, uint16_t port) {
ERROR << log_error;
return;
}
if (ntohl(packet.magic_cookie) != kMagicCookie) {
if (packet.magic_cookie != kMagicCookie) {
ERROR << "DHCP server received a packet with an invalid magic cookie: "
<< ValToHex(packet.magic_cookie);
return;
Expand Down
22 changes: 12 additions & 10 deletions src/dns_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <optional>
#include <unordered_set>

#include "big_endian.hh"
#include "chrono.hh"
#include "config.hh"
#include "dns_utils.hh"
Expand All @@ -26,9 +27,10 @@ static constexpr U16 kClientPort = 338;

U16 AllocateRequestId() {
// Randomize initial request ID
static U16 request_id = random<uint16_t>();
static Big<U16> request_id = random<U16>();
// Subsequent request IDs are incremented by 1
return request_id = htons(ntohs(request_id) + 1);
++request_id.big_endian;
return request_id;
}

struct Entry {
Expand Down Expand Up @@ -225,10 +227,10 @@ void LookupBase::Start(Str domain, U16 type) {
.reply = true,
.response_code = cached->response_code,
.recursion_available = true,
.question_count = htons(1),
.answer_count = htons(cached->answers.size()),
.authority_count = htons(cached->authority.size()),
.additional_count = htons(cached->additional.size()),
.question_count = 1,
.answer_count = cached->answers.size(),
.authority_count = cached->authority.size(),
.additional_count = cached->additional.size(),
},
.questions = {question},
.answers = cached->answers,
Expand Down Expand Up @@ -380,8 +382,8 @@ Client client;
PendingEntry::PendingEntry(Question question, U16 id, LookupBase *lookup)
: Expirable(kPendingTTL), Entry(question), id(id), in_progress({lookup}) {
string buffer;
Header{.id = id, .recursion_desired = true, .question_count = htons(1)}
.write_to(buffer);
Header{.id = id, .recursion_desired = true, .question_count = 1}.write_to(
buffer);
question.write_to(buffer);
IP upstream_ip =
etc::resolv[(++server_i) % etc::resolv.size()]; // Round-robin
Expand All @@ -400,8 +402,8 @@ void InjectAuthoritativeEntry(const Str &domain, IP ip) {
.reply = true,
.response_code = ResponseCode::NO_ERROR,
.recursion_available = true,
.question_count = htons(1),
.answer_count = htons(1)},
.question_count = 1,
.answer_count = 1},
.questions = {Question{.domain_name = domain}},
.answers = {Record{Question{.domain_name = domain},
std::chrono::steady_clock::time_point::max(),
Expand Down
16 changes: 8 additions & 8 deletions src/dns_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ struct Server : UDPListener {
.response_code = ResponseCode::NO_ERROR,
.reserved = 0,
.recursion_available = msg.header.recursion_available,
.question_count = htons(0),
.answer_count = htons(0),
.authority_count = htons(0),
.additional_count = htons(0),
.question_count = 0,
.answer_count = 0,
.authority_count = 0,
.additional_count = 0,
};
}

Expand Down Expand Up @@ -179,10 +179,10 @@ void ProxyLookup::OnAnswer(const Message &msg) {
.response_code = msg.header.response_code,
.reserved = 0,
.recursion_available = true,
.question_count = htons(1),
.answer_count = htons(msg.answers.size()),
.authority_count = htons(msg.authority.size()),
.additional_count = htons(msg.additional.size()),
.question_count = 1,
.answer_count = msg.answers.size(),
.authority_count = msg.authority.size(),
.additional_count = msg.additional.size(),
};
response_header.write_to(buffer);
msg.questions.front().write_to(buffer);
Expand Down
24 changes: 12 additions & 12 deletions src/dns_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ Str Header::ToStr() const {
r += " recursion_desired: " + ::ToStr(recursion_desired) + "\n";
r += " recursion_available: " + ::ToStr(recursion_available) + "\n";
r += " response_code: " + string(dns::ToStr(response_code)) + "\n";
r += " question_count: " + ::ToStr(ntohs(question_count)) + "\n";
r += " answer_count: " + ::ToStr(ntohs(answer_count)) + "\n";
r += " authority_count: " + ::ToStr(ntohs(authority_count)) + "\n";
r += " additional_count: " + ::ToStr(ntohs(additional_count)) + "\n";
r += " question_count: " + ::ToStr(question_count) + "\n";
r += " answer_count: " + ::ToStr(answer_count) + "\n";
r += " authority_count: " + ::ToStr(authority_count) + "\n";
r += " additional_count: " + ::ToStr(additional_count) + "\n";
r += "}";
return r;
}
Expand Down Expand Up @@ -378,7 +378,7 @@ void Message::Parse(const char *ptr, size_t len, string &err) {

size_t offset = sizeof(Header);

for (int i = 0; i < ntohs(header.question_count); ++i) {
for (int i = 0; i < header.question_count.Get(); ++i) {
if (auto q_size = questions.emplace_back().LoadFrom(ptr, len, offset)) {
offset += q_size;
} else {
Expand All @@ -387,7 +387,7 @@ void Message::Parse(const char *ptr, size_t len, string &err) {
}
}

auto LoadRecordList = [&](vector<Record> &v, uint16_t n) {
auto LoadRecordList = [&](Vec<Record> &v, U16 n) {
for (int i = 0; i < n; ++i) {
Record &r = v.emplace_back();
if (auto r_size = r.LoadFrom(ptr, len, offset)) {
Expand All @@ -403,18 +403,18 @@ void Message::Parse(const char *ptr, size_t len, string &err) {
}
};

LoadRecordList(answers, ntohs(header.answer_count));
LoadRecordList(answers, header.answer_count);
if (!err.empty())
return;
LoadRecordList(authority, ntohs(header.authority_count));
LoadRecordList(authority, header.authority_count);
if (!err.empty())
return;
LoadRecordList(additional, ntohs(header.additional_count));
LoadRecordList(additional, header.additional_count);
if (!err.empty())
return;
}
string Message::ToStr() const {
string r = "dns::Message {\n";
Str Message::ToStr() const {
Str r = "dns::Message {\n";
r += IndentString(header.ToStr()) + "\n";
for (auto &q : questions) {
r += " " + q.ToStr() + "\n";
Expand All @@ -431,7 +431,7 @@ string Message::ToStr() const {
r += "}";
return r;
}
void Message::ForEachRecord(function<void(const Record &)> f) const {
void Message::ForEachRecord(Fn<void(const Record &)> f) const {
for (const Record &r : answers) {
f(r);
}
Expand Down
9 changes: 5 additions & 4 deletions src/dns_utils.hh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <chrono>

#include "big_endian.hh"
#include "fn.hh"
#include "int.hh"
#include "optional.hh"
Expand Down Expand Up @@ -94,10 +95,10 @@ struct __attribute__((__packed__)) Header {
U8 reserved : 3;
bool recursion_available : 1;

U16 question_count; // big endian
U16 answer_count; // big endian
U16 authority_count; // big endian
U16 additional_count; // big endian
Big<U16> question_count;
Big<U16> answer_count;
Big<U16> authority_count;
Big<U16> additional_count;
void write_to(Str &buffer);
Str ToStr() const;
};
Expand Down
6 changes: 3 additions & 3 deletions src/ip.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ IP IP::NetmaskFromInterface(std::string_view interface_name, Status &status) {
}

IP IP::NetmaskFromPrefixLength(int prefix_length) {
uint32_t mask = 0;
U32 mask = 0;
for (int i = 0; i < prefix_length; i++) {
mask |= 1 << (31 - i);
mask = std::byteswap(std::rotr(std::byteswap(mask), 1) | 0x80000000);
}
return IP(htonl(mask));
return IP(mask);
}

Str ToStr(IP ip) {
Expand Down
11 changes: 8 additions & 3 deletions src/ip.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <arpa/inet.h>
#include <bit>

#include "big_endian.hh"
#include "int.hh"
#include "status.hh"
#include "str.hh"
Expand All @@ -11,6 +12,7 @@ namespace maf {

union __attribute__((__packed__)) IP {
U32 addr; // network byte order
Big<U32> addr_big_endian;
U8 bytes[4];
U16 halves[2];
IP() : addr(0) {}
Expand All @@ -22,16 +24,19 @@ union __attribute__((__packed__)) IP {
Status &status);
static IP NetmaskFromPrefixLength(int prefix_length);
auto operator<=>(const IP &other) const {
return (int32_t)ntohl(addr) <=> (int32_t)ntohl(other.addr);
return addr_big_endian <=> other.addr_big_endian;
}
bool operator==(const IP &other) const { return addr == other.addr; }
bool operator!=(const IP &other) const { return addr != other.addr; }
IP operator&(const IP &other) const { return IP(addr & other.addr); }
IP operator|(const IP &other) const { return IP(addr | other.addr); }
IP operator~() const { return IP(~addr); }
IP operator+(int n) const { return IP(htonl(ntohl(addr) + n)); }
IP operator+(int n) const {
Big<U32> sum = addr_big_endian.Get() + n;
return IP(sum.big_endian);
}
IP &operator++() {
addr = htonl(ntohl(addr) + 1);
addr_big_endian.Set(addr_big_endian.Get() + 1);
return *this;
}
bool TryParse(const char *cp) {
Expand Down
Loading

0 comments on commit 644709c

Please sign in to comment.