Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anti-replay (RFC4303 Appendix A) #988

Merged
merged 12 commits into from Aug 22, 2016
Merged
93 changes: 73 additions & 20 deletions src/lib/ipsec/esp.lua
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ local seq_no_t = require("lib.ipsec.seq_no_t")
local lib = require("core.lib")
local ffi = require("ffi")
local C = ffi.C
require("lib.ipsec.track_seq_no_h")
local logger = lib.logger_new({ rate = 32, module = 'esp' });

require("lib.ipsec.track_seq_no_h")
local window_t = ffi.typeof("uint8_t[?]")

local ETHERNET_SIZE = ethernet:sizeof()
local IPV6_SIZE = ipv6:sizeof()
Expand Down Expand Up @@ -123,6 +125,8 @@ function esp_v6_decrypt:new (conf)
o.CTEXT_OFFSET = ESP_SIZE + gcm.IV_SIZE
o.PLAIN_OVERHEAD = PAYLOAD_OFFSET + ESP_SIZE + gcm.IV_SIZE + gcm.AUTH_SIZE
o.window_size = conf.window_size or 128
assert(o.window_size % 8 == 0, "window_size must be a multiple of 8.")
o.window = ffi.new(window_t, o.window_size / 8)
return setmetatable(o, {__index=esp_v6_decrypt})
end

Expand All @@ -143,10 +147,9 @@ function esp_v6_decrypt:decapsulate (p)
local ctext_start = payload + self.CTEXT_OFFSET
local ctext_length = length - self.PLAIN_OVERHEAD
local seq_low = self.esp:seq_no()
local seq_high = C.track_seq_no(seq_low, self.seq:low(), self.seq:high(), self.window_size)
if gcm:decrypt(ctext_start, seq_low, seq_high, iv_start, ctext_start, ctext_length) then
self.seq:low(seq_low)
self.seq:high(seq_high)
local seq_high = tonumber(C.check_seq_no(seq_low, self.seq.no, self.window, self.window_size))
if seq_high >= 0 and gcm:decrypt(ctext_start, seq_low, seq_high, iv_start, ctext_start, ctext_length) then
self.seq.no = C.track_seq_no(seq_high, seq_low, self.seq.no, self.window, self.window_size)
local esp_tail_start = ctext_start + ctext_length - ESP_TAIL_SIZE
self.esp_tail:new_from_mem(esp_tail_start, ESP_TAIL_SIZE)
local ptext_length = ctext_length - self.esp_tail:pad_length() - ESP_TAIL_SIZE
Expand All @@ -156,6 +159,15 @@ function esp_v6_decrypt:decapsulate (p)
packet.resize(p, PAYLOAD_OFFSET + ptext_length)
return true
else
local reason = seq_high == -1 and 'replayed' or 'integrity error'
-- This is the information RFC4303 says we SHOULD log
local info = "SPI=" .. tostring(self.spi) .. ", " ..
"src_addr='" .. tostring(self.ip:ntop(self.ip:src())) .. "', " ..
"dst_addr='" .. tostring(self.ip:ntop(self.ip:dst())) .. "', " ..
"seq_low=" .. tostring(seq_low) .. ", " ..
"flow_id=" .. tostring(self.ip:flow_label()) .. ", " ..
"reason='" .. reason .. "'";
logger:log("Rejecting packet ("..info..")")
return false
end
end
Expand Down Expand Up @@ -212,25 +224,66 @@ ABCDEFGHIJKLMNOPQRSTUVWXYZ
and C.memcmp(p_min, e_min, p_min.length) == 0,
"integrity check failed")
-- Check transmitted Sequence Number wrap around
enc.seq:low(0)
enc.seq:high(1)
dec.seq:low(2^32 - dec.window_size)
dec.seq:high(0)
local p3 = packet.clone(p)
enc:encapsulate(p3)
assert(dec:decapsulate(p3),
C.memset(dec.window, 0, dec.window_size / 8); -- clear window
enc.seq.no = 2^32 - 1 -- so next encapsulated will be seq 2^32
dec.seq.no = 2^32 - 1 -- pretend to have seen 2^32-1
local px = packet.clone(p)
enc:encapsulate(px)
assert(dec:decapsulate(px),
"Transmitted Sequence Number wrap around failed.")
assert(dec.seq:high() == 1 and dec.seq:low() == 1,
assert(dec.seq:high() == 1 and dec.seq:low() == 0,
"Lost Sequence Number synchronization.")
-- Check Sequence Number exceeding window
enc.seq:low(0)
enc.seq:high(1)
dec.seq:low(dec.window_size+1)
dec.seq:high(1)
local p4 = packet.clone(p)
enc:encapsulate(p4)
assert(not dec:decapsulate(p4),
C.memset(dec.window, 0, dec.window_size / 8); -- clear window
enc.seq.no = 2^32
dec.seq.no = 2^32 + dec.window_size + 1
px = packet.clone(p)
enc:encapsulate(px)
assert(not dec:decapsulate(px),
"Accepted out of window Sequence Number.")
assert(dec.seq:high() == 1 and dec.seq:low() == dec.window_size+1,
"Corrupted Sequence Number.")
-- Test anti-replay: From a set of 15 packets, first send all those
-- that have an even sequence number. Then, send all 15. Verify that
-- in the 2nd run, packets with even sequence numbers are rejected while
-- the others are not.
-- Then do the same thing again, but with offset sequence numbers so that
-- we have a 32bit wraparound in the middle.
local offset = 0 -- close to 2^32 in the 2nd iteration
for offset = 0, 2^32-7, 2^32-7 do -- duh
C.memset(dec.window, 0, dec.window_size / 8); -- clear window
dec.seq.no = offset
for i = 1+offset, 15+offset do
if (i % 2 == 0) then
enc.seq.no = i-1 -- so next seq will be i
px = packet.clone(p)
enc:encapsulate(px);
assert(dec:decapsulate(px), "rejected legitimate packet seq=" .. i)
assert(dec.seq.no == i, "Lost sequence number synchronization")
end
end
for i = 1+offset, 15+offset do
enc.seq.no = i-1
px = packet.clone(p)
enc:encapsulate(px);
if (i % 2 == 0) then
assert(not dec:decapsulate(px), "accepted replayed packet seq=" .. i)
else
assert(dec:decapsulate(px), "rejected legitimate packet seq=" .. i)
end
end
end
-- Check that packets from way in the past/way in the future
-- (further than the biggest allowable window size) are rejected
-- This is where we ultimately want resynchronization (wrt. future packets)
C.memset(dec.window, 0, dec.window_size / 8); -- clear window
dec.seq.no = 2^34 + 42;
enc.seq.no = 2^36 + 42;
px = packet.clone(p)
enc:encapsulate(px);
assert(not dec:decapsulate(px), "accepted packet from way into the future")
enc.seq.no = 2^32 + 42;
px = packet.clone(p)
enc:encapsulate(px);
assert(not dec:decapsulate(px), "accepted packet from way into the past")
end
2 changes: 0 additions & 2 deletions src/lib/ipsec/seq_no_t.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ local ffi = require("ffi")
local seq_no_t = ffi.typeof("union { uint64_t no; uint32_t no32[2]; }")
local seq_no = {}

function seq_no:full () return self.no end

local low, high
if ffi.abi("le") then low = 0; high = 1
elseif ffi.abi("be") then low = 1; high = 0 end
Expand Down
106 changes: 98 additions & 8 deletions src/lib/ipsec/track_seq_no.c
Original file line number Diff line number Diff line change
@@ -1,14 +1,104 @@
// See https://tools.ietf.org/html/rfc4303#page-38

#include <stdbool.h>
#include <stdint.h>

// See https://tools.ietf.org/html/rfc4303#page-38
// This is a only partial implementation that attempts to keep track of the
// ESN counter, but does not detect replayed packets.
uint32_t track_seq_no (uint32_t seq_no, uint32_t Tl, uint32_t Th, uint32_t W) {
/* LO32/HI32: Get lower/upper 32 bit of a 64 bit value. Should be completely
portable and endian-independent. It shouldn't be difficult for the compiler
to recognize this as division/multiplication by powers of two and hence
replace it with bit shifts, automagically in the right direction. */
#define LO32(U64) ((uint32_t)((uint64_t)(U64) % 4294967296ull)) // 2**32
#define HI32(U64) ((uint32_t)((uint64_t)(U64) / 4294967296ull)) // 2**32

/* MK64: Likewise, make a 64 bit value from two 32 bit values */
#define MK64(L, H) ((uint64_t)(((uint32_t)(H)) * 4294967296ull + ((uint32_t)(L))))


/* Set/clear the bit in our window that corresponds to sequence number `seq` */
static inline void set_bit (bool on, uint64_t seq,
uint8_t* window, uint32_t W) {
uint32_t bitno = seq % W;
uint32_t blockno = bitno / 8;
bitno %= 8;

/* First clear the bit, then maybe set it. This way we don't have to branch
on `on` */
window[blockno] &= ~(1u << bitno);
window[blockno] |= (uint8_t)on << bitno;
}

/* Get the bit in our window that corresponds to sequence number `seq` */
static inline bool get_bit (uint64_t seq,
uint8_t *window, uint32_t W) {
uint32_t bitno = seq % W;
uint32_t blockno = bitno / 8;
bitno = bitno % 8;

return window[blockno] & (1u << bitno);
}

/* Advance the window so that the "head" bit corresponds to sequence
* number `seq`. Clear all bits for the new sequence numbers that are
* now considered in-window.
*/
static void advance_window (uint64_t seq,
uint64_t T, uint8_t* window, uint32_t W) {
uint64_t diff = seq - T;

/* For advances greater than the window size, don't clear more bits than the
window has */
if (diff > W) diff = W;

/* Clear all bits corresponding to the sequence numbers that used to be ahead
of, but are now inside our window since we haven't seen them yet */
while (diff--) set_bit(0, seq--, window, W);
}


/* Determine whether a packet with the sequence number made from
* `seq_hi` and `seq_lo` (where `seq_hi` is inferred from our window
* state) could be a legitimate packet.
*
* "Could", because we can't really tell if the received packet with
* this sequence number is in fact valid (since we haven't yet
* integrity-checked it - the sequence number may be spoofed), but we
* can give an authoritative "no" if we have already seen and accepted
* a packet with this number.
*
* If our answer is NOT "no", the caller will, provided the packet was
* valid, use track_seq_no() for us to mark the sequence number as seen.
*/
int64_t check_seq_no (uint32_t seq_lo,
uint64_t T, uint8_t *window, uint32_t W) {
uint32_t Tl = LO32(T);
uint32_t Th = HI32(T);
uint32_t seq_hi;
uint64_t seq;

if (Tl >= W - 1) { // Case A
if (seq_no >= Tl - W + 1) return Th;
else return Th + 1;
if (seq_lo >= Tl - W + 1) seq_hi = Th;
else seq_hi = Th + 1;
} else { // Case B
if (seq_no >= Tl - W + 1) return Th - 1;
else return Th;
if (seq_lo >= Tl - W + 1) seq_hi = Th - 1;
else seq_hi = Th;
}
seq = MK64(seq_lo, seq_hi);
if (seq <= T && get_bit(seq, window, W)) return -1;
else return seq_hi;
}

/* Signal that the packet received with this sequence number was
* in fact valid -- we record that we have seen it so as to prevent
* future replays of it.
*/
uint64_t track_seq_no (uint32_t seq_hi, uint32_t seq_lo,
uint64_t T, uint8_t *window, uint32_t W) {
uint64_t seq = MK64(seq_lo, seq_hi);

if (seq > T) {
advance_window(seq, T, window, W);
T = seq;
}
set_bit(1, seq, window, W);
return T;
}
3 changes: 2 additions & 1 deletion src/lib/ipsec/track_seq_no.h
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
uint32_t track_seq_no (uint32_t, uint32_t, uint32_t, uint32_t);
uint64_t track_seq_no (uint32_t, uint32_t, uint64_t, uint8_t *, uint32_t);
int64_t check_seq_no (uint32_t, uint64_t, uint8_t *, uint32_t);
2 changes: 2 additions & 0 deletions src/program/snabbmark/snabbmark.lua
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ function esp (npackets, packet_size, mode, profile)
for i = 1, npackets do
plain = packet.clone(encapsulated)
dec:decapsulate(plain)
dec.seq.no = 0
dec.window[0] = 0
packet.free(plain)
end
local finish = C.get_monotonic_time()
Expand Down