diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 4fa9991bd4a6..0e23101ee3d0 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -16,17 +16,25 @@ const enum_array = tls.enum_array; read_seq: u64, write_seq: u64, -/// The number of partially read bytes inside `partially_read_buffer`. -partially_read_len: u15, -/// The number of cleartext bytes from decoding `partially_read_buffer` which -/// have already been transferred via read() calls. This implementation will -/// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by -/// the read() API user is not large enough. -partial_cleartext_index: u15, +/// The starting index of cleartext bytes inside `partially_read_buffer`. +partial_cleartext_idx: u15, +/// The ending index of cleartext bytes inside `partially_read_buffer` as well +/// as the starting index of ciphertext bytes. +partial_ciphertext_idx: u15, +/// The ending index of ciphertext bytes inside `partially_read_buffer`. +partial_ciphertext_end: u15, +/// When this is true, the stream may still not be at the end because there +/// may be data in `partially_read_buffer`. +received_close_notify: bool, application_cipher: tls.ApplicationCipher, -eof: bool, /// The size is enough to contain exactly one TLSCiphertext record. -/// Contains encrypted bytes. +/// This buffer is segmented into four parts: +/// 0. unused +/// 1. cleartext +/// 2. ciphertext +/// 3. unused +/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and +/// `partial_ciphertext_end` describe the span of the segments. partially_read_buffer: [tls.max_ciphertext_record_len]u8, /// `host` is only borrowed during this function call. @@ -597,13 +605,14 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) }, }; var client: Client = .{ - .application_cipher = app_cipher, .read_seq = 0, .write_seq = 0, - .partial_cleartext_index = 0, + .partial_cleartext_idx = 0, + .partial_ciphertext_idx = 0, + .partial_ciphertext_end = @intCast(u15, len - end), + .received_close_notify = false, + .application_cipher = app_cipher, .partially_read_buffer = undefined, - .partially_read_len = @intCast(u15, len - end), - .eof = false, }; mem.copy(u8, &client.partially_read_buffer, handshake_buf[len..end]); return client; @@ -727,19 +736,17 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void { } } +pub fn eof(c: Client) bool { + return c.received_close_notify and c.partial_ciphertext_end == 0; +} + /// Returns the number of bytes read, calling the underlying read function the /// minimal number of times until the buffer has at least `len` bytes filled. /// If the number read is less than `len` it means the stream reached the end. /// Reaching the end of the stream is not an error condition. pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - assert(len <= buffer.len); - if (c.eof) return 0; - var index: usize = 0; - while (index < len) { - index += try c.readAdvanced(stream, buffer[index..]); - if (c.eof) break; - } - return index; + var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; + return readvAtLeast(c, stream, &iovecs, len); } pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { @@ -753,78 +760,180 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { return readAtLeast(c, stream, buffer, buffer.len); } -/// Returns number of bytes that have been read, populated inside `buffer`. A -/// return value of zero bytes does not mean end of stream. Instead, the `eof` -/// flag is set upon end of stream. The `eof` flag may be set after any call to +/// Returns the number of bytes read. If the number read is less than the space +/// provided it means the stream reached the end. Reaching the end of the +/// stream is not an error condition. +/// The `iovecs` parameter is mutable because this function needs to mutate the fields in +/// order to handle partial reads from the underlying stream layer. +pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { + return readvAtLeast(c, stream, iovecs); +} + +/// Returns the number of bytes read, calling the underlying read function the +/// minimal number of times until the iovecs have at least `len` bytes filled. +/// If the number read is less than `len` it means the stream reached the end. +/// Reaching the end of the stream is not an error condition. +/// The `iovecs` parameter is mutable because this function needs to mutate the fields in +/// order to handle partial reads from the underlying stream layer. +pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: usize) !usize { + if (c.eof()) return 0; + + var off_i: usize = 0; + var vec_i: usize = 0; + while (true) { + var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); + off_i += amt; + if (c.eof() or off_i >= len) return off_i; + while (amt >= iovecs[vec_i].iov_len) { + amt -= iovecs[vec_i].iov_len; + vec_i += 1; + } + iovecs[vec_i].iov_base += amt; + iovecs[vec_i].iov_len -= amt; + } +} + +/// Returns number of bytes that have been read, populated inside `iovecs`. A +/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` +/// for the end of stream. The `eof()` may be true after any call to /// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof` is `false`. -/// See `read` for a higher level function that has the same, familiar API -/// as other read functions, such as `std.fs.File.read`. -/// It is recommended to use a buffer size with length at least -/// `tls.max_ciphertext_len` bytes to avoid redundantly decrypting the same -/// encoded data. -pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { - assert(!c.eof); - const prev_len = c.partially_read_len; - // Ideally, this buffer would never be used. It is needed when `buffer` is too small - // to fit the cleartext, which may be as large as `max_ciphertext_len`. +/// function asserts that `eof()` is `false`. +/// See `readv` for a higher level function that has the same, familiar API as +/// other read functions, such as `std.fs.File.read`. +pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iovec) !usize { + var vp: VecPut = .{ .iovecs = iovecs }; + + // Give away the buffered cleartext we have, if any. + const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; + if (partial_cleartext.len > 0) { + const amt = @intCast(u15, vp.put(partial_cleartext)); + c.partial_cleartext_idx += amt; + if (amt < partial_cleartext.len) { + // We still have cleartext left so we cannot issue another read() call yet. + assert(vp.total == amt); + return amt; + } + if (c.received_close_notify) { + c.partial_ciphertext_end = 0; + assert(vp.total == amt); + return amt; + } + if (c.partial_ciphertext_end == c.partial_ciphertext_idx) { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = 0; + } + } + + assert(!c.received_close_notify); + + // Ideally, this buffer would never be used. It is needed when `iovecs` are + // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - // This buffer is typically used, except, as an optimization when a very large - // `buffer` is provided, we use half of it for buffering ciphertext and the - // other half for outputting cleartext. + // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; - const half_buffer_len = buffer.len / 2; - const out_in: struct { []u8, []u8 } = if (half_buffer_len >= in_stack_buffer.len) .{ - buffer[0..half_buffer_len], - buffer[half_buffer_len..], - } else .{ - buffer, - &in_stack_buffer, + // How many bytes left in the user's buffer. + const free_size = vp.freeSize(); + // The amount of the user's buffer that we need to repurpose for storing + // ciphertext. The end of the buffer will be used for such purposes. + const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; + // The amount of the user's buffer that will be used to give cleartext. The + // beginning of the buffer will be used for such purposes. + const cleartext_buf_len = free_size - ciphertext_buf_len; + const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; + + var ask_iovecs_buf: [2]std.os.iovec = .{ + .{ + .iov_base = first_iov.ptr, + .iov_len = first_iov.len, + }, + .{ + .iov_base = &in_stack_buffer, + .iov_len = in_stack_buffer.len, + }, }; - const out_buf = out_in[0]; - const in_buf = out_in[1]; - mem.copy(u8, in_buf, c.partially_read_buffer[0..prev_len]); - // Capacity of output buffer, in records, rounded up. - const buf_cap = (out_buf.len +| (max_ciphertext_len - 1)) / max_ciphertext_len; + // Cleartext capacity of output buffer, in records, rounded up. + const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len; const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len); const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); - const ask_slice = in_buf[prev_len..][0..@min(ask_len, in_buf.len - prev_len)]; - assert(ask_slice.len > 0); - const frag = frag: { - if (prev_len >= 5) { - const record_size = mem.readIntBig(u16, in_buf[3..][0..2]); - if (prev_len >= 5 + record_size) { - // We can use our buffered data without calling read(). - break :frag in_buf[0..prev_len]; + const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); + const actual_read_len = try stream.readv(ask_iovecs); + if (actual_read_len == 0) { + // This is either a truncation attack, or a bug in the server. + return error.TlsConnectionTruncated; + } + + // There might be more bytes inside `in_stack_buffer` that need to be processed, + // but at least frag0 will have one complete ciphertext record. + const frag0 = c.partially_read_buffer[0..@min(c.partially_read_buffer.len, actual_read_len)]; + var frag1 = in_stack_buffer[0 .. actual_read_len - frag0.len]; + // We need to decipher frag0 and frag1 but there may be a ciphertext record + // straddling the boundary. We can handle this with two memcpy() calls to + // assemble the straddling record in between handling the two sides. + var frag = frag0; + var in: usize = 0; + while (true) { + if (in == frag.len) { + // Perfect split. + if (frag.ptr == frag1.ptr) { + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return vp.total; } + frag = frag1; + in = 0; + continue; } - const actual_read_len = try stream.read(ask_slice); - if (actual_read_len == 0) { - // This is either a truncation attack, or a bug in the server. - return error.TlsConnectionTruncated; - } - break :frag in_buf[0 .. prev_len + actual_read_len]; - }; - var in: usize = 0; - var out: usize = 0; - while (true) { if (in + tls.ciphertext_record_header_len > frag.len) { - return finishRead(c, frag, in, out); + if (frag.ptr == frag1.ptr) + return finishRead(c, frag, in, vp.total); + + const first = frag[in..]; + + if (frag1.len < tls.ciphertext_record_header_len) + return finishRead2(c, first, frag1, vp.total); + + // A record straddles the two fragments. Copy into the now-empty first fragment. + const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); + const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); + const record_len = (record_len_byte_0 << 8) | record_len_byte_1; + if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; + + const second_len = record_len + tls.ciphertext_record_header_len - first.len; + if (frag1.len < second_len) + return finishRead2(c, first, frag1, vp.total); + + mem.copy(u8, frag[0..in], first); + mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag1 = frag1[second_len..]; + in = 0; + continue; } - const record_start = in; const ct = @intToEnum(tls.ContentType, frag[in]); in += 1; const legacy_version = mem.readIntBig(u16, frag[in..][0..2]); in += 2; _ = legacy_version; - const record_size = mem.readIntBig(u16, frag[in..][0..2]); + const record_len = mem.readIntBig(u16, frag[in..][0..2]); + if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; in += 2; - const end = in + record_size; + const end = in + record_len; if (end > frag.len) { - if (record_size > max_ciphertext_len) return error.TlsRecordOverflow; - return finishRead(c, frag, in, out); + if (frag.ptr == frag1.ptr) + return finishRead(c, frag, in, vp.total); + + // A record straddles the two fragments. Copy into the now-empty first fragment. + const first = frag[in..]; + const second_len = record_len + tls.ciphertext_record_header_len - first.len; + if (frag1.len < second_len) + return finishRead2(c, first, frag1, vp.total); + + mem.copy(u8, frag[0..in], first); + mem.copy(u8, frag[first.len..], frag1[0..second_len]); + frag1 = frag1[second_len..]; + in = 0; + continue; } switch (ct) { .alert => { @@ -836,18 +945,16 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); const ad = frag[in - 5 ..][0..5]; - const ciphertext_len = record_size - P.AEAD.tag_length; + const ciphertext_len = record_len - P.AEAD.tag_length; const ciphertext = frag[in..][0..ciphertext_len]; in += ciphertext_len; const auth_tag = frag[in..][0..P.AEAD.tag_length].*; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - // Here we use read_seq and then intentionally don't - // increment it until later when it is certain the same - // ciphertext does not need to be decrypted again. const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq)); const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; - const cleartext_buf = if (c.partial_cleartext_index == 0 and out + ciphertext.len <= out_buf.len) - out_buf[out..] + const out_buf = vp.peek(); + const cleartext_buf = if (ciphertext.len <= out_buf.len) + out_buf else &cleartext_stack_buffer; const cleartext = cleartext_buf[0..ciphertext.len]; @@ -856,22 +963,22 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { break :c cleartext; }, }; + c.read_seq += 1; const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]); switch (inner_ct) { .alert => { - c.read_seq += 1; - const level = @intToEnum(tls.AlertLevel, out_buf[out]); - const desc = @intToEnum(tls.AlertDescription, out_buf[out + 1]); + const level = @intToEnum(tls.AlertLevel, cleartext[0]); + const desc = @intToEnum(tls.AlertDescription, cleartext[1]); if (desc == .close_notify) { - c.eof = true; - return out; + c.received_close_notify = true; + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return vp.total; } std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) }); return error.TlsAlert; }, .handshake => { - c.read_seq += 1; var ct_i: usize = 0; while (true) { const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]); @@ -926,42 +1033,37 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { .application_data => { // Determine whether the output buffer or a stack // buffer was used for storing the cleartext. - if (c.partial_cleartext_index == 0 and - out + cleartext.len <= out_buf.len) - { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - out += cleartext.len - 1; - c.read_seq += 1; - } else { + if (cleartext.ptr == &cleartext_stack_buffer) { // Stack buffer was used, so we must copy to the output buffer. - const dest = out_buf[out..]; - const rest = cleartext[c.partial_cleartext_index..]; - const src = rest[0..@min(rest.len, dest.len)]; - mem.copy(u8, dest, src); - out += src.len; - c.partial_cleartext_index = @intCast( - @TypeOf(c.partial_cleartext_index), - c.partial_cleartext_index + src.len, - ); - if (c.partial_cleartext_index >= cleartext.len) { - c.partial_cleartext_index = 0; - c.read_seq += 1; + const msg = cleartext[0 .. cleartext.len - 1]; + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // We have already run out of room in iovecs. Continue + // appending to `partially_read_buffer`. + const dest = c.partially_read_buffer[c.partial_ciphertext_idx..]; + mem.copy(u8, dest, msg); + c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len); } else { - in = record_start; - return finishRead(c, frag, in, out); + const amt = vp.put(msg); + if (amt < msg.len) { + const rest = msg[amt..]; + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len); + mem.copy(u8, &c.partially_read_buffer, rest); + } } + } else { + // Output buffer was used directly which means no + // memory copying needs to occur, and we can move + // on to the next ciphertext record. + vp.next(cleartext.len - 1); } }, else => { - std.debug.print("inner content type: {d}\n", .{inner_ct}); return error.TlsUnexpectedMessage; }, } }, else => { - std.debug.print("unexpected ct: {any}\n", .{ct}); return error.TlsUnexpectedMessage; }, } @@ -971,11 +1073,43 @@ pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize { fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { const saved_buf = frag[in..]; - mem.copy(u8, &c.partially_read_buffer, saved_buf); - c.partially_read_len = @intCast(u15, saved_buf.len); + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // There is cleartext at the beginning already which we need to preserve. + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + saved_buf.len); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx..], saved_buf); + } else { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), saved_buf.len); + mem.copy(u8, &c.partially_read_buffer, saved_buf); + } return out; } +fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // There is cleartext at the beginning already which we need to preserve. + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), c.partial_ciphertext_idx + first.len + frag1.len); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx..], first); + mem.copy(u8, c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..], frag1); + } else { + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = 0; + c.partial_ciphertext_end = @intCast(@TypeOf(c.partial_ciphertext_end), first.len + frag1.len); + mem.copy(u8, &c.partially_read_buffer, first); + mem.copy(u8, c.partially_read_buffer[first.len..], frag1); + } + return out; +} + +fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { + if (index < s1.len) { + return s1[index]; + } else { + return s2[index - s1.len]; + } +} + fn hostMatchesCommonName(host: []const u8, common_name: []const u8) bool { if (mem.eql(u8, common_name, host)) { return true; // exact match @@ -1015,6 +1149,89 @@ fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { }; } +/// Abstraction for sending multiple byte buffers to a slice of iovecs. +const VecPut = struct { + iovecs: []const std.os.iovec, + idx: usize = 0, + off: usize = 0, + total: usize = 0, + + /// Returns the amount actually put which is always equal to bytes.len + /// unless the vectors ran out of space. + fn put(vp: *VecPut, bytes: []const u8) usize { + var bytes_i: usize = 0; + while (true) { + const v = vp.iovecs[vp.idx]; + const dest = v.iov_base[vp.off..v.iov_len]; + const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; + mem.copy(u8, dest, src); + bytes_i += src.len; + if (bytes_i >= bytes.len) { + vp.total += bytes_i; + return bytes_i; + } + vp.off += src.len; + if (vp.off >= v.iov_len) { + vp.off = 0; + vp.idx += 1; + if (vp.idx >= vp.iovecs.len) { + vp.total += bytes_i; + return bytes_i; + } + } + } + } + + /// Returns the next buffer that consecutive bytes can go into. + fn peek(vp: VecPut) []u8 { + if (vp.idx >= vp.iovecs.len) return &.{}; + const v = vp.iovecs[vp.idx]; + return v.iov_base[vp.off..v.iov_len]; + } + + // After writing to the result of peek(), one can call next() to + // advance the cursor. + fn next(vp: *VecPut, len: usize) void { + vp.total += len; + vp.off += len; + if (vp.off >= vp.iovecs[vp.idx].iov_len) { + vp.off = 0; + vp.idx += 1; + } + } + + fn freeSize(vp: VecPut) usize { + var total: usize = 0; + + total += vp.iovecs[vp.idx].iov_len - vp.off; + + if (vp.idx + 1 >= vp.iovecs.len) + return total; + + for (vp.iovecs[vp.idx + 1 ..]) |v| { + total += v.iov_len; + } + + return total; + } +}; + +/// Limit iovecs to a specific byte size. +fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec { + var vec_i: usize = 0; + var bytes_left: usize = len; + while (true) { + if (bytes_left >= iovecs[vec_i].iov_len) { + bytes_left -= iovecs[vec_i].iov_len; + vec_i += 1; + if (vec_i == iovecs.len or bytes_left == 0) return iovecs[0..vec_i]; + continue; + } + iovecs[vec_i].iov_len = bytes_left; + return iovecs[0..vec_i]; + } +} + /// The priority order here is chosen based on what crypto algorithms Zig has /// available in the standard library as well as what is faster. Following are /// a few data points on the relative performance of these algorithms. diff --git a/lib/std/net.zig b/lib/std/net.zig index 0112d5be8cf4..aa51176184e4 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1672,6 +1672,17 @@ pub const Stream = struct { } } + pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize { + if (builtin.os.tag == .windows) { + // TODO improve this to use ReadFileScatter + if (iovecs.len == 0) return @as(usize, 0); + const first = iovecs[0]; + return os.windows.ReadFile(s.handle, first.iov_base[0..first.iov_len], null, io.default_mode); + } + + return os.readv(s.handle, iovecs); + } + /// Returns the number of bytes read. If the number read is smaller than /// `buffer.len`, it means the stream reached the end. Reaching the end of /// a stream is not an error condition.