diff --git a/src/README.md b/src/README.md index b7d5da48b7..08c911a182 100644 --- a/src/README.md +++ b/src/README.md @@ -387,7 +387,7 @@ or equal to `length` of *packet*. — Function **packet.shiftright** *packet*, *length* -Move *packet* payload to the right by *length* bytes, growing *packet* by +Moves *packet* payload to the right by *length* bytes, growing *packet* by *length*. The sum of *length* and `length` of *packet* must be less than or equal to `packet.max_payload`. @@ -399,6 +399,10 @@ Allocate packet and fill it with *length* bytes from *pointer*. Allocate packet and fill it with the contents of *string*. +— Function **packet.clone_to_memory** *pointer*, *packet* + +Creates an exact copy of *packet* at memory pointed to by *pointer*. *Pointer* must +point to a `packet.packet_t`. ## Memory (core.memory) @@ -469,6 +473,10 @@ If *readonly* is non-nil the shared object is mapped in read-only mode. *Readonly* defaults to nil. Fails if the shared object does not already exist. Returns a pointer to the mapped object. +— Function **shm.exists** *name* + +Returns a true value if shared object by *name* exists. + — Function **shm.unmap** *pointer* Deletes the memory mapping for *pointer*. diff --git a/src/apps/virtio_net/README.md b/src/apps/virtio_net/README.md index 490c216a47..451cf47ac3 100644 --- a/src/apps/virtio_net/README.md +++ b/src/apps/virtio_net/README.md @@ -2,12 +2,10 @@ The `VirtioNet` app implements a subset of the driver part of the [virtio-net](http://docs.oasis-open.org/virtio/virtio/v1.0/csprd04/virtio-v1.0-csprd04.html) -specification. - -With `VirtioNet` SnabbSwitch can be used as a virtual ethernet interface -by *QEMU virtual machines*. When connected via a UNIX socket, packets can -be sent to the virtual machine by transmitting them on the `rx` port and -packets send by the virtual machine will arrive on the `tx` port. +specification. It can connect to a virtio-net device from within a QEMU virtual +machine. 
Packets can be sent out of the virtual machine by transmitting them on +the `rx` port, and packets sent to the virtual machine will arrive on the `tx` +port. DIAGRAM: VirtioNet +-----------+ diff --git a/src/core/packet.lua b/src/core/packet.lua index 1f8fc579a8..85da7583c8 100644 --- a/src/core/packet.lua +++ b/src/core/packet.lua @@ -71,12 +71,17 @@ function new_packet () return p end +-- Create an exact copy of srcp at memory pointed to by dstp. +function clone_to_memory(dstp, srcp) + ffi.copy(dstp, srcp, srcp.length) + dstp.length = srcp.length + return dstp +end + -- Create an exact copy of a packet. function clone (p) local p2 = allocate() - ffi.copy(p2, p, p.length) - p2.length = p.length - return p2 + return clone_to_memory(p2, p) end -- Append data to the end of a packet. diff --git a/src/core/shm.lua b/src/core/shm.lua index 929eebedec..15b2f3f440 100644 --- a/src/core/shm.lua +++ b/src/core/shm.lua @@ -55,6 +55,12 @@ function open (name, type, readonly) return map(name, type, readonly, false) end +function exists (name) + local path = resolve(name) + local fd = S.open(root..'/'..path, "rdonly") + return fd and fd:close() +end + function resolve (name) local q, p = name:match("^(/*)(.*)") -- split qualifier (/) local result = p @@ -191,6 +197,14 @@ function selftest () unmap(p1) unmap(p2) + print("checking exists..") + assert(not exists(name)) + local p1 = create(name, "struct { int x, y, z; }") + assert(exists(name)) + assert(unlink(name)) + unmap(p1) + assert(not exists(name)) + -- Test that we can open and cleanup many objects print("checking many objects..") local path = 'shm/selftest/manyobj' diff --git a/src/lib/ipsec/esp.lua b/src/lib/ipsec/esp.lua index 5df0fdc6db..d011fe56a5 100644 --- a/src/lib/ipsec/esp.lua +++ b/src/lib/ipsec/esp.lua @@ -37,8 +37,10 @@ local seq_no_t = require("lib.ipsec.seq_no_t") local lib = require("core.lib") local ffi = require("ffi") local C = ffi.C -require("lib.ipsec.track_seq_no_h") +local logger = lib.logger_new({ 
rate = 32, module = 'esp' }); +require("lib.ipsec.track_seq_no_h") +local window_t = ffi.typeof("uint8_t[?]") local ETHERNET_SIZE = ethernet:sizeof() local IPV6_SIZE = ipv6:sizeof() @@ -123,6 +125,8 @@ function esp_v6_decrypt:new (conf) o.CTEXT_OFFSET = ESP_SIZE + gcm.IV_SIZE o.PLAIN_OVERHEAD = PAYLOAD_OFFSET + ESP_SIZE + gcm.IV_SIZE + gcm.AUTH_SIZE o.window_size = conf.window_size or 128 + assert(o.window_size % 8 == 0, "window_size must be a multiple of 8.") + o.window = ffi.new(window_t, o.window_size / 8) return setmetatable(o, {__index=esp_v6_decrypt}) end @@ -143,10 +147,9 @@ function esp_v6_decrypt:decapsulate (p) local ctext_start = payload + self.CTEXT_OFFSET local ctext_length = length - self.PLAIN_OVERHEAD local seq_low = self.esp:seq_no() - local seq_high = C.track_seq_no(seq_low, self.seq:low(), self.seq:high(), self.window_size) - if gcm:decrypt(ctext_start, seq_low, seq_high, iv_start, ctext_start, ctext_length) then - self.seq:low(seq_low) - self.seq:high(seq_high) + local seq_high = tonumber(C.check_seq_no(seq_low, self.seq.no, self.window, self.window_size)) + if seq_high >= 0 and gcm:decrypt(ctext_start, seq_low, seq_high, iv_start, ctext_start, ctext_length) then + self.seq.no = C.track_seq_no(seq_high, seq_low, self.seq.no, self.window, self.window_size) local esp_tail_start = ctext_start + ctext_length - ESP_TAIL_SIZE self.esp_tail:new_from_mem(esp_tail_start, ESP_TAIL_SIZE) local ptext_length = ctext_length - self.esp_tail:pad_length() - ESP_TAIL_SIZE @@ -156,6 +159,15 @@ function esp_v6_decrypt:decapsulate (p) packet.resize(p, PAYLOAD_OFFSET + ptext_length) return true else + local reason = seq_high == -1 and 'replayed' or 'integrity error' + -- This is the information RFC4303 says we SHOULD log + local info = "SPI=" .. tostring(self.spi) .. ", " .. + "src_addr='" .. tostring(self.ip:ntop(self.ip:src())) .. "', " .. + "dst_addr='" .. tostring(self.ip:ntop(self.ip:dst())) .. "', " .. + "seq_low=" .. tostring(seq_low) .. ", " .. 
+ "flow_id=" .. tostring(self.ip:flow_label()) .. ", " .. + "reason='" .. reason .. "'"; + logger:log("Rejecting packet ("..info..")") return false end end @@ -212,25 +224,66 @@ ABCDEFGHIJKLMNOPQRSTUVWXYZ and C.memcmp(p_min, e_min, p_min.length) == 0, "integrity check failed") -- Check transmitted Sequence Number wrap around - enc.seq:low(0) - enc.seq:high(1) - dec.seq:low(2^32 - dec.window_size) - dec.seq:high(0) - local p3 = packet.clone(p) - enc:encapsulate(p3) - assert(dec:decapsulate(p3), + C.memset(dec.window, 0, dec.window_size / 8); -- clear window + enc.seq.no = 2^32 - 1 -- so next encapsulated will be seq 2^32 + dec.seq.no = 2^32 - 1 -- pretend to have seen 2^32-1 + local px = packet.clone(p) + enc:encapsulate(px) + assert(dec:decapsulate(px), "Transmitted Sequence Number wrap around failed.") - assert(dec.seq:high() == 1 and dec.seq:low() == 1, + assert(dec.seq:high() == 1 and dec.seq:low() == 0, "Lost Sequence Number synchronization.") -- Check Sequence Number exceeding window - enc.seq:low(0) - enc.seq:high(1) - dec.seq:low(dec.window_size+1) - dec.seq:high(1) - local p4 = packet.clone(p) - enc:encapsulate(p4) - assert(not dec:decapsulate(p4), + C.memset(dec.window, 0, dec.window_size / 8); -- clear window + enc.seq.no = 2^32 + dec.seq.no = 2^32 + dec.window_size + 1 + px = packet.clone(p) + enc:encapsulate(px) + assert(not dec:decapsulate(px), "Accepted out of window Sequence Number.") assert(dec.seq:high() == 1 and dec.seq:low() == dec.window_size+1, "Corrupted Sequence Number.") + -- Test anti-replay: From a set of 15 packets, first send all those + -- that have an even sequence number. Then, send all 15. Verify that + -- in the 2nd run, packets with even sequence numbers are rejected while + -- the others are not. + -- Then do the same thing again, but with offset sequence numbers so that + -- we have a 32bit wraparound in the middle. 
+ local offset = 0 -- close to 2^32 in the 2nd iteration + for offset = 0, 2^32-7, 2^32-7 do -- duh + C.memset(dec.window, 0, dec.window_size / 8); -- clear window + dec.seq.no = offset + for i = 1+offset, 15+offset do + if (i % 2 == 0) then + enc.seq.no = i-1 -- so next seq will be i + px = packet.clone(p) + enc:encapsulate(px); + assert(dec:decapsulate(px), "rejected legitimate packet seq=" .. i) + assert(dec.seq.no == i, "Lost sequence number synchronization") + end + end + for i = 1+offset, 15+offset do + enc.seq.no = i-1 + px = packet.clone(p) + enc:encapsulate(px); + if (i % 2 == 0) then + assert(not dec:decapsulate(px), "accepted replayed packet seq=" .. i) + else + assert(dec:decapsulate(px), "rejected legitimate packet seq=" .. i) + end + end + end + -- Check that packets from way in the past/way in the future + -- (further than the biggest allowable window size) are rejected + -- This is where we ultimately want resynchronization (wrt. future packets) + C.memset(dec.window, 0, dec.window_size / 8); -- clear window + dec.seq.no = 2^34 + 42; + enc.seq.no = 2^36 + 42; + px = packet.clone(p) + enc:encapsulate(px); + assert(not dec:decapsulate(px), "accepted packet from way into the future") + enc.seq.no = 2^32 + 42; + px = packet.clone(p) + enc:encapsulate(px); + assert(not dec:decapsulate(px), "accepted packet from way into the past") end diff --git a/src/lib/ipsec/seq_no_t.lua b/src/lib/ipsec/seq_no_t.lua index a1b6f13db3..7f4e46da6c 100644 --- a/src/lib/ipsec/seq_no_t.lua +++ b/src/lib/ipsec/seq_no_t.lua @@ -6,8 +6,6 @@ local ffi = require("ffi") local seq_no_t = ffi.typeof("union { uint64_t no; uint32_t no32[2]; }") local seq_no = {} -function seq_no:full () return self.no end - local low, high if ffi.abi("le") then low = 0; high = 1 elseif ffi.abi("be") then low = 1; high = 0 end diff --git a/src/lib/ipsec/track_seq_no.c b/src/lib/ipsec/track_seq_no.c index 135946442b..78ae1c9ead 100644 --- a/src/lib/ipsec/track_seq_no.c +++ 
b/src/lib/ipsec/track_seq_no.c @@ -1,14 +1,104 @@ +// See https://tools.ietf.org/html/rfc4303#page-38 + +#include <stdbool.h> #include <stdint.h> -// See https://tools.ietf.org/html/rfc4303#page-38 -// This is a only partial implementation that attempts to keep track of the -// ESN counter, but does not detect replayed packets. -uint32_t track_seq_no (uint32_t seq_no, uint32_t Tl, uint32_t Th, uint32_t W) { +/* LO32/HI32: Get lower/upper 32 bit of a 64 bit value. Should be completely + portable and endian-independent. It shouldn't be difficult for the compiler + to recognize this as division/multiplication by powers of two and hence + replace it with bit shifts, automagically in the right direction. */ +#define LO32(U64) ((uint32_t)((uint64_t)(U64) % 4294967296ull)) // 2**32 +#define HI32(U64) ((uint32_t)((uint64_t)(U64) / 4294967296ull)) // 2**32 + +/* MK64: Likewise, make a 64 bit value from two 32 bit values */ +#define MK64(L, H) ((uint64_t)(((uint32_t)(H)) * 4294967296ull + ((uint32_t)(L)))) + + +/* Set/clear the bit in our window that corresponds to sequence number `seq` */ +static inline void set_bit (bool on, uint64_t seq, + uint8_t* window, uint32_t W) { + uint32_t bitno = seq % W; + uint32_t blockno = bitno / 8; + bitno %= 8; + + /* First clear the bit, then maybe set it. This way we don't have to branch + on `on` */ + window[blockno] &= ~(1u << bitno); + window[blockno] |= (uint8_t)on << bitno; +} + +/* Get the bit in our window that corresponds to sequence number `seq` */ +static inline bool get_bit (uint64_t seq, + uint8_t *window, uint32_t W) { + uint32_t bitno = seq % W; + uint32_t blockno = bitno / 8; + bitno = bitno % 8; + + return window[blockno] & (1u << bitno); +} + +/* Advance the window so that the "head" bit corresponds to sequence + * number `seq`. Clear all bits for the new sequence numbers that are + * now considered in-window. 
+ */ +static void advance_window (uint64_t seq, + uint64_t T, uint8_t* window, uint32_t W) { + uint64_t diff = seq - T; + + /* For advances greater than the window size, don't clear more bits than the + window has */ + if (diff > W) diff = W; + + /* Clear all bits corresponding to the sequence numbers that used to be ahead + of, but are now inside our window since we haven't seen them yet */ + while (diff--) set_bit(0, seq--, window, W); +} + + +/* Determine whether a packet with the sequence number made from + * `seq_hi` and `seq_lo` (where `seq_hi` is inferred from our window + * state) could be a legitimate packet. + * + * "Could", because we can't really tell if the received packet with + * this sequence number is in fact valid (since we haven't yet + * integrity-checked it - the sequence number may be spoofed), but we + * can give an authoritative "no" if we have already seen and accepted + * a packet with this number. + * + * If our answer is NOT "no", the caller will, provided the packet was + * valid, use track_seq_no() for us to mark the sequence number as seen. + */ +int64_t check_seq_no (uint32_t seq_lo, + uint64_t T, uint8_t *window, uint32_t W) { + uint32_t Tl = LO32(T); + uint32_t Th = HI32(T); + uint32_t seq_hi; + uint64_t seq; + if (Tl >= W - 1) { // Case A - if (seq_no >= Tl - W + 1) return Th; - else return Th + 1; + if (seq_lo >= Tl - W + 1) seq_hi = Th; + else seq_hi = Th + 1; } else { // Case B - if (seq_no >= Tl - W + 1) return Th - 1; - else return Th; + if (seq_lo >= Tl - W + 1) seq_hi = Th - 1; + else seq_hi = Th; + } + seq = MK64(seq_lo, seq_hi); + if (seq <= T && get_bit(seq, window, W)) return -1; + else return seq_hi; +} + +/* Signal that the packet received with this sequence number was + * in fact valid -- we record that we have seen it so as to prevent + * future replays of it. 
+ */ +uint64_t track_seq_no (uint32_t seq_hi, uint32_t seq_lo, + uint64_t T, uint8_t *window, uint32_t W) { + uint64_t seq = MK64(seq_lo, seq_hi); + + if (seq > T) { + advance_window(seq, T, window, W); + T = seq; } + set_bit(1, seq, window, W); + return T; } diff --git a/src/lib/ipsec/track_seq_no.h b/src/lib/ipsec/track_seq_no.h index 8d5e5105b2..5f11c4d49e 100644 --- a/src/lib/ipsec/track_seq_no.h +++ b/src/lib/ipsec/track_seq_no.h @@ -1 +1,2 @@ -uint32_t track_seq_no (uint32_t, uint32_t, uint32_t, uint32_t); +uint64_t track_seq_no (uint32_t, uint32_t, uint64_t, uint8_t *, uint32_t); +int64_t check_seq_no (uint32_t, uint64_t, uint8_t *, uint32_t); diff --git a/src/program/snabbmark/snabbmark.lua b/src/program/snabbmark/snabbmark.lua index 5dad61aaff..b968de09c7 100644 --- a/src/program/snabbmark/snabbmark.lua +++ b/src/program/snabbmark/snabbmark.lua @@ -379,6 +379,8 @@ function esp (npackets, packet_size, mode, profile) for i = 1, npackets do plain = packet.clone(encapsulated) dec:decapsulate(plain) + dec.seq.no = 0 + dec.window[0] = 0 packet.free(plain) end local finish = C.get_monotonic_time() diff --git a/src/program/top/top.lua b/src/program/top/top.lua index 250af23892..8c1aec02d3 100644 --- a/src/program/top/top.lua +++ b/src/program/top/top.lua @@ -39,14 +39,20 @@ function select_snabb_instance (pid) if instance == pid then return pid end end print("No such Snabb instance: "..pid) - elseif #instances == 2 then - -- Two means one is us, so we pick the other. - local own_pid = tostring(S.getpid()) - if instances[1] == own_pid then return instances[2] - else return instances[1] end elseif #instances == 1 then print("No Snabb instance found.") - else print("Multple Snabb instances found. Select one.") end - os.exit(1) + else + local own_pid = tostring(S.getpid()) + if #instances == 2 then + -- Two means one is us, so we pick the other. + return instances[1] == own_pid and instances[2] or instances[1] + else + print("Multiple Snabb instances found. 
Select one:") + for _, instance in ipairs(instances) do + if instance ~= own_pid then print(instance) end + end + end + end + main.exit(1) end function list_shm (pid, object)