diff --git a/lib/compiler/fmt.zig b/lib/compiler/fmt.zig index 2fc04b7935a7..7f58a133a3da 100644 --- a/lib/compiler/fmt.zig +++ b/lib/compiler/fmt.zig @@ -322,8 +322,9 @@ fn fmtPathFile( return; if (check_mode) { - const stdout = std.io.getStdOut().writer(); - try stdout.print("{s}\n", .{file_path}); + const stdout = std.io.getStdOut(); + try stdout.writeAll(file_path); + try stdout.writeAll("\n"); fmt.any_error = true; } else { var af = try dir.atomicFile(sub_path, .{ .mode = stat.mode }); diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index cbd3d427418f..b794d8022792 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -628,6 +628,11 @@ test "basic" { try testing.expectEqual(@as(?u16, null), parsed.port); } +test "subdomain" { + const parsed = try parse("http://a.b.example.com"); + try testing.expectEqualStrings("a.b.example.com", parsed.host orelse return error.UnexpectedNull); +} + test "with port" { const parsed = try parse("http://example:1337/"); try testing.expectEqualStrings("http", parsed.scheme); diff --git a/lib/std/array_list.zig b/lib/std/array_list.zig index ff2307e8124d..05779c3a8f5a 100644 --- a/lib/std/array_list.zig +++ b/lib/std/array_list.zig @@ -344,7 +344,7 @@ pub fn ArrayListAligned(comptime T: type, comptime alignment: ?u29) type { @compileError("The Writer interface is only defined for ArrayList(u8) " ++ "but the given type is ArrayList(" ++ @typeName(T) ++ ")") else - std.io.Writer(*Self, Allocator.Error, appendWrite); + std.io.Writer(*Self, Allocator.Error, appendWritev); /// Initializes a Writer which will append to the list. pub fn writer(self: *Self) Writer { @@ -354,9 +354,13 @@ pub fn ArrayListAligned(comptime T: type, comptime alignment: ?u29) type { /// Same as `append` except it returns the number of bytes written, which is always the same /// as `m.len`. The purpose of this function existing is to match `std.io.Writer` API. /// Invalidates element pointers if additional memory is needed. - fn appendWrite(self: *Self, m: []const u8) Allocator.Error!usize { - try self.appendSlice(m); - return m.len; + fn appendWritev(self: *Self, iov: []const std.os.iovec_const) Allocator.Error!usize { + var written: usize = 0; + for (iov) |v| { + try self.appendSlice(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } /// Append a value to the list `n` times. @@ -930,7 +934,7 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ @compileError("The Writer interface is only defined for ArrayList(u8) " ++ "but the given type is ArrayList(" ++ @typeName(T) ++ ")") else - std.io.Writer(WriterContext, Allocator.Error, appendWrite); + std.io.Writer(WriterContext, Allocator.Error, appendWritev); /// Initializes a Writer which will append to the list. pub fn writer(self: *Self, allocator: Allocator) Writer { @@ -941,12 +945,16 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ /// which is always the same as `m.len`. The purpose of this function /// existing is to match `std.io.Writer` API. /// Invalidates element pointers if additional memory is needed. 
- fn appendWrite(context: WriterContext, m: []const u8) Allocator.Error!usize { - try context.self.appendSlice(context.allocator, m); - return m.len; + fn appendWritev(context: WriterContext, iov: []const std.os.iovec_const) Allocator.Error!usize { + var written: usize = 0; + for (iov) |v| { + try context.self.appendSlice(context.allocator, v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } - pub const FixedWriter = std.io.Writer(*Self, Allocator.Error, appendWriteFixed); + pub const FixedWriter = std.io.Writer(*Self, Allocator.Error, appendWritevFixed); /// Initializes a Writer which will append to the list but will return /// `error.OutOfMemory` rather than increasing capacity. @@ -955,13 +963,18 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ } /// The purpose of this function existing is to match `std.io.Writer` API. - fn appendWriteFixed(self: *Self, m: []const u8) error{OutOfMemory}!usize { - const available_capacity = self.capacity - self.items.len; - if (m.len > available_capacity) - return error.OutOfMemory; - - self.appendSliceAssumeCapacity(m); - return m.len; + fn appendWritevFixed(self: *Self, iov: []const std.os.iovec_const) error{OutOfMemory}!usize { + var written: usize = 0; + for (iov) |v| { + const m = v.iov_base[0..v.iov_len]; + const available_capacity = self.capacity - self.items.len; + if (m.len > available_capacity) + return error.OutOfMemory; + + self.appendSliceAssumeCapacity(m); + written += m.len; + } + return written; } /// Append a value to the list `n` times. diff --git a/lib/std/bounded_array.zig b/lib/std/bounded_array.zig index 9867754dcd0a..ac8cea8576dd 100644 --- a/lib/std/bounded_array.zig +++ b/lib/std/bounded_array.zig @@ -271,7 +271,7 @@ pub fn BoundedArrayAligned( @compileError("The Writer interface is only defined for BoundedArray(u8, ...) " ++ "but the given type is BoundedArray(" ++ @typeName(T) ++ ", ...)") else - std.io.Writer(*Self, error{Overflow}, appendWrite); + std.io.Writer(*Self, error{Overflow}, appendWritev); /// Initializes a writer which will write into the array. pub fn writer(self: *Self) Writer { @@ -280,9 +280,14 @@ pub fn BoundedArrayAligned( /// Same as `appendSlice` except it returns the number of bytes written, which is always the same /// as `m.len`. The purpose of this function existing is to match `std.io.Writer` API. 
- fn appendWrite(self: *Self, m: []const u8) error{Overflow}!usize { - try self.appendSlice(m); - return m.len; + fn appendWritev(self: *Self, iov: []const std.os.iovec_const) error{Overflow}!usize { + var written: usize = 0; + for (iov) |v| { + const m = v.iov_base[0..v.iov_len]; + try self.appendSlice(m); + written += m.len; + } + return written; } }; } diff --git a/lib/std/compress.zig b/lib/std/compress.zig index 200489c18ab5..242eb8aff9b9 100644 --- a/lib/std/compress.zig +++ b/lib/std/compress.zig @@ -19,12 +19,18 @@ pub fn HashedReader( hasher: HasherType, pub const Error = ReaderType.Error; - pub const Reader = std.io.Reader(*@This(), Error, read); + pub const Reader = std.io.Reader(*@This(), Error, readv); - pub fn read(self: *@This(), buf: []u8) Error!usize { - const amt = try self.child_reader.read(buf); - self.hasher.update(buf[0..amt]); - return amt; + pub fn readv(self: *@This(), iov: []const std.os.iovec) Error!usize { + const n_read = try self.child_reader.readv(iov); + var hashed_amt: usize = 0; + for (iov) |v| { + const to_hash = @min(n_read - hashed_amt, v.iov_len); + if (to_hash == 0) break; + self.hasher.update(v.iov_base[0..to_hash]); + hashed_amt += to_hash; + } + return n_read; } pub fn reader(self: *@This()) Reader { @@ -49,12 +55,18 @@ pub fn HashedWriter( hasher: HasherType, pub const Error = WriterType.Error; - pub const Writer = std.io.Writer(*@This(), Error, write); + pub const Writer = std.io.Writer(*@This(), Error, writev); - pub fn write(self: *@This(), buf: []const u8) Error!usize { - const amt = try self.child_writer.write(buf); - self.hasher.update(buf[0..amt]); - return amt; + pub fn writev(self: *@This(), iov: []const std.os.iovec_const) Error!usize { + const n_written = try self.child_writer.writev(iov); + var hashed_amt: usize = 0; + for (iov) |v| { + const to_hash = @min(n_written - hashed_amt, v.iov_len); + if (to_hash == 0) break; + self.hasher.update(v.iov_base[0..to_hash]); + hashed_amt += to_hash; + } + return n_written; } pub fn writer(self: *@This()) Writer { diff --git a/lib/std/compress/flate/deflate.zig b/lib/std/compress/flate/deflate.zig index 794ab02247c3..094dc847fc4c 100644 --- a/lib/std/compress/flate/deflate.zig +++ b/lib/std/compress/flate/deflate.zig @@ -354,16 +354,20 @@ fn Deflate(comptime container: Container, comptime WriterType: type, comptime Bl } // Writer interface - - pub const Writer = io.Writer(*Self, Error, write); + pub const Writer = io.Writer(*Self, Error, writev); pub const Error = BlockWriterType.Error; /// Write `input` of uncompressed data. /// See compress. 
- pub fn write(self: *Self, input: []const u8) !usize { - var fbs = io.fixedBufferStream(input); - try self.compress(fbs.reader()); - return input.len; + pub fn writev(self: *Self, iov: []const std.os.iovec_const) !usize { + var written: usize = 0; + for (iov) |v| { + const input = v.iov_base[0..v.iov_len]; + var fbs = io.fixedBufferStream(input); + try self.compress(fbs.reader()); + written += input.len; + } + return written; } pub fn writer(self: *Self) Writer { @@ -558,7 +562,7 @@ test "tokenization" { const cww = cw.writer(); var df = try Deflate(container, @TypeOf(cww), TestTokenWriter).init(cww, .{}); - _ = try df.write(c.data); + _ = try df.writer().write(c.data); try df.flush(); // df.token_writer.show(); @@ -579,6 +583,8 @@ const TestTokenWriter = struct { pos: usize = 0, actual: [128]Token = undefined, + pub const Error = error{}; + pub fn init(_: anytype) Self { return .{}; } diff --git a/lib/std/compress/flate/inflate.zig b/lib/std/compress/flate/inflate.zig index cf23961b2132..15752a4c9599 100644 --- a/lib/std/compress/flate/inflate.zig +++ b/lib/std/compress/flate/inflate.zig @@ -325,7 +325,7 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type, comp /// Returns decompressed data from internal sliding window buffer. /// Returned buffer can be any length between 0 and `limit` bytes. 0 /// returned bytes means end of stream reached. With limit=0 returns as - /// much data it can. It newer will be more than 65536 bytes, which is + /// much data it can. It never will be more than 65536 bytes, which is /// size of internal buffer. pub fn get(self: *Self, limit: usize) Error![]const u8 { while (true) { @@ -340,16 +340,19 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type, comp } // Reader interface - - pub const Reader = std.io.Reader(*Self, Error, read); + pub const Reader = std.io.Reader(*Self, Error, readv); /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. 
- pub fn read(self: *Self, buffer: []u8) Error!usize { - const out = try self.get(buffer.len); - @memcpy(buffer[0..out.len], out); - return out.len; + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + var read: usize = 0; + for (iov) |v| { + const out = try self.get(v.iov_len); + @memcpy(v.iov_base[0..out.len], out); + read += out.len; + } + return read; } pub fn reader(self: *Self) Reader { diff --git a/lib/std/compress/lzma.zig b/lib/std/compress/lzma.zig index ff05bc1c8beb..34eddb99f4ff 100644 --- a/lib/std/compress/lzma.zig +++ b/lib/std/compress/lzma.zig @@ -30,7 +30,7 @@ pub fn Decompress(comptime ReaderType: type) type { Allocator.Error || error{ CorruptInput, EndOfStream, Overflow }; - pub const Reader = std.io.Reader(*Self, Error, read); + pub const Reader = std.io.Reader(*Self, Error, readv); allocator: Allocator, in_reader: ReaderType, @@ -63,23 +63,28 @@ pub fn Decompress(comptime ReaderType: type) type { self.* = undefined; } - pub fn read(self: *Self, output: []u8) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const writer = self.to_read.writer(self.allocator); - while (self.to_read.items.len < output.len) { - switch (try self.state.process(self.allocator, self.in_reader, writer, &self.buffer, &self.decoder)) { - .continue_ => {}, - .finished => { - try self.buffer.finish(writer); - break; - }, + var n_read: usize = 0; + for (iov) |v| { + const output = v.iov_base[0..v.iov_len]; + while (self.to_read.items.len < output.len) { + switch (try self.state.process(self.allocator, self.in_reader, writer, &self.buffer, &self.decoder)) { + .continue_ => {}, + .finished => { + try self.buffer.finish(writer); + break; + }, + } } + const input = self.to_read.items; + const n = @min(input.len, output.len); + @memcpy(output[0..n], input[0..n]); + @memcpy(input[0 .. input.len - n], input[n..]); + self.to_read.shrinkRetainingCapacity(input.len - n); + n_read += n; } - const input = self.to_read.items; - const n = @min(input.len, output.len); - @memcpy(output[0..n], input[0..n]); - @memcpy(input[0 .. 
input.len - n], input[n..]); - self.to_read.shrinkRetainingCapacity(input.len - n); - return n; + return n_read; } }; } diff --git a/lib/std/compress/xz.zig b/lib/std/compress/xz.zig index e844c234ffc8..a0b1c7c33993 100644 --- a/lib/std/compress/xz.zig +++ b/lib/std/compress/xz.zig @@ -34,7 +34,7 @@ pub fn Decompress(comptime ReaderType: type) type { const Self = @This(); pub const Error = ReaderType.Error || block.Decoder(ReaderType).Error; - pub const Reader = std.io.Reader(*Self, Error, read); + pub const Reader = std.io.Reader(*Self, Error, readv); allocator: Allocator, block_decoder: block.Decoder(ReaderType), @@ -71,7 +71,9 @@ pub fn Decompress(comptime ReaderType: type) type { return .{ .context = self }; } - pub fn read(self: *Self, buffer: []u8) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; if (buffer.len == 0) return 0; diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index 9092a2d13083..0733250bc22d 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -50,7 +50,7 @@ pub fn Decompressor(comptime ReaderType: type) type { OutOfMemory, }; - pub const Reader = std.io.Reader(*Self, Error, read); + pub const Reader = std.io.Reader(*Self, Error, readv); pub fn init(source: ReaderType, options: DecompressorOptions) Self { return .{ @@ -105,7 +105,9 @@ pub fn Decompressor(comptime ReaderType: type) type { return .{ .context = self }; } - pub fn read(self: *Self, buffer: []u8) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; if (buffer.len == 0) return 0; var size: usize = 0; diff --git a/lib/std/compress/zstandard/readers.zig b/lib/std/compress/zstandard/readers.zig index f95573f77bbf..2235b46bf2de 100644 --- a/lib/std/compress/zstandard/readers.zig +++ b/lib/std/compress/zstandard/readers.zig @@ -4,7 +4,7 @@ pub const ReversedByteReader = struct { remaining_bytes: usize, bytes: []const u8, - const Reader = std.io.Reader(*ReversedByteReader, error{}, readFn); + const Reader = std.io.Reader(*ReversedByteReader, error{}, readvFn); pub fn init(bytes: []const u8) ReversedByteReader { return .{ @@ -17,7 +17,10 @@ pub const ReversedByteReader = struct { return .{ .context = self }; } - fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize { + fn readvFn(ctx: *ReversedByteReader, iov: []const std.os.iovec) !usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; + std.debug.assert(buffer.len > 0); if (ctx.remaining_bytes == 0) return 0; const byte_index = ctx.remaining_bytes - 1; buffer[0] = ctx.bytes[byte_index]; diff --git a/lib/std/crypto/25519/ed25519.zig b/lib/std/crypto/25519/ed25519.zig index d7b51271d29d..025a980a8e1d 100644 --- a/lib/std/crypto/25519/ed25519.zig +++ b/lib/std/crypto/25519/ed25519.zig @@ -21,8 +21,8 @@ pub const Ed25519 = struct { /// Length (in bytes) of optional random bytes, for non-deterministic signatures. pub const noise_length = 32; - const CompressedScalar = Curve.scalar.CompressedScalar; - const Scalar = Curve.scalar.Scalar; + pub const CompressedScalar = Curve.scalar.CompressedScalar; + pub const Scalar = Curve.scalar.Scalar; /// An Ed25519 secret key. 
pub const SecretKey = struct { @@ -73,7 +73,7 @@ pub const Ed25519 = struct { nonce: CompressedScalar, r_bytes: [Curve.encoded_length]u8, - fn init(scalar: CompressedScalar, nonce: CompressedScalar, public_key: PublicKey) (IdentityElementError || KeyMismatchError || NonCanonicalError || WeakPublicKeyError)!Signer { + pub fn init(scalar: CompressedScalar, nonce: CompressedScalar, public_key: PublicKey) (IdentityElementError || WeakPublicKeyError)!Signer { const r = try Curve.basePoint.mul(nonce); const r_bytes = r.toBytes(); diff --git a/lib/std/crypto/25519/x25519.zig b/lib/std/crypto/25519/x25519.zig index 8bd5101b3756..6d2377e1b58f 100644 --- a/lib/std/crypto/25519/x25519.zig +++ b/lib/std/crypto/25519/x25519.zig @@ -19,15 +19,19 @@ pub const X25519 = struct { pub const public_length = 32; /// Length (in bytes) of the output of the DH function. pub const shared_length = 32; - /// Seed (for key pair creation) length in bytes. - pub const seed_length = 32; + + pub const PublicKey = [public_length]u8; + pub const SecretKey = [secret_length]u8; /// An X25519 key pair. pub const KeyPair = struct { /// Public part. - public_key: [public_length]u8, + public_key: PublicKey, /// Secret part. - secret_key: [secret_length]u8, + secret_key: SecretKey, + + /// Seed (for key pair creation) length in bytes. + pub const seed_length = 32; /// Create a new key pair using an optional seed. pub fn create(seed: ?[seed_length]u8) IdentityElementError!KeyPair { diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index c3ac3e22aa47..bf2c74da1723 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -220,10 +220,6 @@ pub const Parsed = struct { return p.slice(p.pub_key_slice); } - pub fn pubKeySigAlgo(p: Parsed) []const u8 { - return p.slice(p.pub_key_signature_algorithm_slice); - } - pub fn message(p: Parsed) []const u8 { return p.slice(p.message_slice); } @@ -385,6 +381,7 @@ test "Parsed.checkHostName" { pub const ParseError = der.Element.ParseElementError || ParseVersionError || ParseTimeError || ParseEnumError || ParseBitStringError; +/// Parse a DER format certificate. 
pub fn parse(cert: Certificate) ParseError!Parsed { const cert_bytes = cert.buffer; const certificate = try der.Element.parse(cert_bytes, cert.index); @@ -704,7 +701,7 @@ fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) ParseEnu return E.map.get(oid_bytes) orelse return error.CertificateHasUnrecognizedObjectId; } -pub const ParseVersionError = error{ UnsupportedCertificateVersion, CertificateFieldHasInvalidLength }; +pub const ParseVersionError = error{ CertificateUnsupportedVersion, CertificateFieldHasInvalidLength }; pub fn parseVersion(bytes: []const u8, version_elem: der.Element) ParseVersionError!Version { if (@as(u8, @bitCast(version_elem.identifier)) != 0xa0) @@ -723,7 +720,7 @@ pub fn parseVersion(bytes: []const u8, version_elem: der.Element) ParseVersionEr return .v1; } - return error.UnsupportedCertificateVersion; + return error.CertificateUnsupportedVersion; } fn verifyRsa( @@ -734,11 +731,10 @@ fn verifyRsa( pub_key: []const u8, ) !void { if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; - const pk_components = try rsa.PublicKey.parseDer(pub_key); - const exponent = pk_components.exponent; - const modulus = pk_components.modulus; - if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; - if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; + + const public_key = rsa.PublicKey.fromDer(pub_key) catch return error.CertificateSignatureInvalid; + const modulus_len = public_key.n.bits() / 8; + if (sig.len != modulus_len) return error.CertificateSignatureInvalidLength; const hash_der = switch (Hash) { crypto.hash.Sha1 => [_]u8{ @@ -771,18 +767,17 @@ fn verifyRsa( var msg_hashed: [Hash.digest_length]u8 = undefined; Hash.hash(message, &msg_hashed, .{}); - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; - const em: [modulus_len]u8 = + switch (modulus_len) { + inline 128, 256, 512 => |mod_len| { + const ps_len = mod_len - (hash_der.len + msg_hashed.len) - 3; + const em: [mod_len]u8 = [2]u8{ 0, 1 } ++ ([1]u8{0xff} ** ps_len) ++ [1]u8{0} ++ hash_der ++ msg_hashed; - const public_key = rsa.PublicKey.fromBytes(exponent, modulus) catch return error.CertificateSignatureInvalid; - const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key) catch |err| switch (err) { + const em_dec = public_key.encrypt(mod_len, sig[0..mod_len].*) catch |err| switch (err) { error.MessageTooLong => unreachable, }; @@ -949,6 +944,7 @@ test { _ = Bundle; } +/// RFC8017 pub const rsa = struct { const max_modulus_bits = 4096; const Uint = std.crypto.ff.Uint(max_modulus_bits); @@ -964,12 +960,33 @@ pub const rsa = struct { pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type) !void { const mod_bits = public_key.n.bits(); - const em_dec = try encrypt(modulus_len, sig, public_key); + const em_dec = try public_key.encrypt(modulus_len, sig); - EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash) catch unreachable; + try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash); } - fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) !void { + pub fn sign( + comptime modulus_len: usize, + msg: []const u8, + comptime Hash: type, + private_key: SecretKey, + salt: [Hash.digest_length]u8, + ) ![modulus_len]u8 { + const mod_bits = modulus_len * 8; + + var out: [modulus_len]u8 = 
undefined; + const em = try EMSA_PSS_ENCODE(&out, msg, mod_bits - 1, Hash, salt); + + return try private_key.decrypt(modulus_len, em[0..modulus_len].*); + } + + fn EMSA_PSS_VERIFY( + msg: []const u8, + em: []const u8, + emBit: usize, + sLen: usize, + comptime Hash: type, + ) !void { // 1. If the length of M is greater than the input limitation for // the hash function (2^61 - 1 octets for SHA-1), output // "inconsistent" and stop. @@ -1103,17 +1120,112 @@ pub const rsa = struct { return out[0..len]; } + + fn EMSA_PSS_ENCODE( + out: []u8, + msg: []const u8, + emBit: usize, + comptime Hash: type, + salt: [Hash.digest_length]u8, + ) ![]u8 { + // 1. If the length of M is greater than the input limitation for + // the hash function (2^61 - 1 octets for SHA-1), output + // "inconsistent" and stop. + // All the cryptographic hash functions in the standard library have a limit of >= 2^61 - 1. + // Even then, this check is only there for paranoia. In the context of TLS certificates, emBit cannot exceed 4096. + if (emBit >= 1 << 61) return error.InvalidSignature; + // emLen = \ceil(emBits/8) + const emLen = ((emBit - 1) / 8) + 1; + + // 2. Let mHash = Hash(M), an octet string of length hLen. + var mHash: [Hash.digest_length]u8 = undefined; + Hash.hash(msg, &mHash, .{}); + + // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. + if (emLen < Hash.digest_length + salt.len + 2) { + return error.EncodingError; + } + + // 4. Generate a random octet string salt of length sLen; if sLen = + // 0, then salt is the empty string. + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + var m_p: [8 + Hash.digest_length + salt.len]u8 = undefined; + + @memcpy(m_p[0..8], &([_]u8{0} ** 8)); + @memcpy(m_p[8 .. 8 + mHash.len], &mHash); + @memcpy(m_p[(8 + Hash.digest_length)..], &salt); + + // 6. Let H = Hash(M'), an octet string of length hLen. + var hash: [Hash.digest_length]u8 = undefined; + Hash.hash(&m_p, &hash, .{}); + + // 7. Generate an octet string PS consisting of emLen - sLen - hLen + // - 2 zero octets. The length of PS may be 0. + const ps_len = emLen - salt.len - Hash.digest_length - 2; + + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + const mgf_len = emLen - Hash.digest_length - 1; + var db_buf: [512]u8 = undefined; + var db = db_buf[0 .. emLen - Hash.digest_length - 1]; + var i: usize = 0; + while (i < ps_len) : (i += 1) { + db[i] = 0x00; + } + db[i] = 0x01; + i += 1; + @memcpy(db[i..], &salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + var mgf_out_buf: [512]u8 = undefined; + const mgf_out = mgf_out_buf[0 .. ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length]; + const dbMask = try MGF1(Hash, mgf_out, &hash, mgf_len); + + // 10. Let maskedDB = DB \xor dbMask. + i = 0; + while (i < db.len) : (i += 1) { + db[i] = db[i] ^ dbMask[i]; + } + + // 11. Set the leftmost 8emLen - emBits bits of the leftmost octet + // in maskedDB to zero. + const zero_bits = emLen * 8 - emBit; + var mask: u8 = 0; + i = 0; + while (i < 8 - zero_bits) : (i += 1) { + mask = mask << 1; + mask += 1; + } + db[0] = db[0] & mask; + + // 12. Let EM = maskedDB || H || 0xbc. + i = 0; + @memcpy(out[0..db.len], db); + i += db.len; + @memcpy(out[i .. i + hash.len], &hash); + i += hash.len; + out[i] = 0xbc; + i += 1; + + // 13. Output EM.
+ return out[0..i]; + } }; pub const PublicKey = struct { + /// the RSA modulus, a positive integer n: Modulus, + /// public exponent e: Fe, - pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) !PublicKey { + pub fn fromBytes(mod: []const u8, exp: []const u8) !PublicKey { // Reject modulus below 512 bits. // 512-bit RSA was factored in 1999, so this limit barely means anything, // but establish some limit now to ratchet in what we can. - const _n = Modulus.fromBytes(modulus_bytes, .big) catch return error.CertificatePublicKeyInvalid; + const _n = Modulus.fromBytes(mod, .big) catch return error.CertificatePublicKeyInvalid; if (_n.bits() < 512) return error.CertificatePublicKeyInvalid; // Exponent must be odd and greater than 2. @@ -1122,44 +1234,88 @@ pub const rsa = struct { // unlikely that exponents larger than 32 bits are being used for anything // Windows commonly does. // [1] https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-rsapubkey - if (pub_bytes.len > 4) return error.CertificatePublicKeyInvalid; - const _e = Fe.fromBytes(_n, pub_bytes, .big) catch return error.CertificatePublicKeyInvalid; + if (exp.len > 4) return error.CertificatePublicKeyInvalid; + const _e = Fe.fromBytes(_n, exp, .big) catch return error.CertificatePublicKeyInvalid; if (!_e.isOdd()) return error.CertificatePublicKeyInvalid; const e_v = _e.toPrimitive(u32) catch return error.CertificatePublicKeyInvalid; if (e_v < 2) return error.CertificatePublicKeyInvalid; - return .{ - .n = _n, - .e = _e, - }; + return .{ .n = _n, .e = _e }; } - pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } { - const pub_key_seq = try der.Element.parse(pub_key, 0); - if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; - const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); + // RFC8017 Appendix A.1.1 + pub fn fromDer(bytes: []const u8) !PublicKey { + const seq = try der.Element.parse(bytes, 0); + if (seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try der.Element.parse(bytes, seq.slice.start); if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end); + const exponent_elem = try der.Element.parse(bytes, modulus_elem.slice.end); if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; // Skip over meaningless zeroes in the modulus. - const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; - const modulus_offset = for (modulus_raw, 0..) 
|byte, i| { - if (byte != 0) break i; - } else modulus_raw.len; - return .{ - .modulus = modulus_raw[modulus_offset..], - .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end], - }; + const modulus_raw = bytes[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = std.mem.indexOfNone(u8, modulus_raw, &[_]u8{0}) orelse modulus_raw.len; + const modulus = modulus_raw[modulus_offset..]; + const exponent = bytes[exponent_elem.slice.start..exponent_elem.slice.end]; + + return try fromBytes(modulus, exponent); + } + + pub fn encrypt(self: PublicKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { + const m = Fe.fromBytes(self.n, &msg, .big) catch return error.MessageTooLong; + const e = self.n.powPublic(m, self.e) catch unreachable; + var res: [modulus_len]u8 = undefined; + e.toBytes(&res, .big) catch unreachable; + return res; } }; - fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) ![modulus_len]u8 { - const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong; - const e = public_key.n.powPublic(m, public_key.e) catch unreachable; - var res: [modulus_len]u8 = undefined; - e.toBytes(&res, .big) catch unreachable; - return res; - } + pub const SecretKey = struct { + public: PublicKey, + /// private exponent + d: Fe, + + pub fn fromBytes(mod: []const u8, public: []const u8, private: []const u8) !SecretKey { + const _public = try PublicKey.fromBytes(mod, public); + + const _d = Fe.fromBytes(_public.n, private, .big) catch return error.CertificatePrivateKeyInvalid; + if (!_d.isOdd()) return error.CertificatePrivateKeyInvalid; + + return .{ .public = _public, .d = _d }; + } + + // RFC8017 Appendix A.1.2 + pub fn fromDer(bytes: []const u8) !SecretKey { + const seq = try der.Element.parse(bytes, 0); + if (seq.identifier.tag != .sequence) return error.PrivateKeyWrongDataType; + + // We're just interested in the first 3 fields which don't vary by version + const version_elem = try der.Element.parse(bytes, seq.slice.start); + if (version_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType; + + const modulus_elem = try der.Element.parse(bytes, version_elem.slice.end); + if (modulus_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType; + const pub_exponent_elem = try der.Element.parse(bytes, modulus_elem.slice.end); + if (pub_exponent_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType; + const priv_exponent_elem = try der.Element.parse(bytes, pub_exponent_elem.slice.end); + if (priv_exponent_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType; + // Skip over meaningless zeroes in the modulus. 
+ const modulus_raw = bytes[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = std.mem.indexOfNone(u8, modulus_raw, &[_]u8{0}) orelse modulus_raw.len; + const modulus = modulus_raw[modulus_offset..]; + const pub_exponent = bytes[pub_exponent_elem.slice.start..pub_exponent_elem.slice.end]; + const priv_exponent = bytes[priv_exponent_elem.slice.start..priv_exponent_elem.slice.end]; + + return try fromBytes(modulus, pub_exponent, priv_exponent); + } + + pub fn decrypt(self: SecretKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { + const m = Fe.fromBytes(self.public.n, &msg, .big) catch return error.MessageTooLong; + const e = self.public.n.pow(m, self.d) catch unreachable; + var res: [modulus_len]u8 = undefined; + try e.toBytes(&res, .big); + return res; + } + }; }; const use_vectors = @import("builtin").zig_backend != .stage2_x86_64; diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig index 70362470c3ef..16ac99bba9cf 100644 --- a/lib/std/crypto/ecdsa.zig +++ b/lib/std/crypto/ecdsa.zig @@ -17,7 +17,7 @@ pub const EcdsaP256Sha256 = Ecdsa(crypto.ecc.P256, crypto.hash.sha2.Sha256); pub const EcdsaP256Sha3_256 = Ecdsa(crypto.ecc.P256, crypto.hash.sha3.Sha3_256); /// ECDSA over P-384 with SHA-384. pub const EcdsaP384Sha384 = Ecdsa(crypto.ecc.P384, crypto.hash.sha2.Sha384); -/// ECDSA over P-384 with SHA3-384. +/// ECDSA over P-256 with SHA3-384. pub const EcdsaP256Sha3_384 = Ecdsa(crypto.ecc.P384, crypto.hash.sha3.Sha3_384); /// ECDSA over Secp256k1 with SHA-256. pub const EcdsaSecp256k1Sha256 = Ecdsa(crypto.ecc.Secp256k1, crypto.hash.sha2.Sha256); @@ -183,7 +183,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { secret_key: SecretKey, noise: ?[noise_length]u8, - fn init(secret_key: SecretKey, noise: ?[noise_length]u8) !Signer { + pub fn init(secret_key: SecretKey, noise: ?[noise_length]u8) Signer { return Signer{ .h = Hash.init(.{}), .secret_key = secret_key, diff --git a/lib/std/crypto/ml_kem.zig b/lib/std/crypto/ml_kem.zig index b2f38b82fe9f..43e7387cc68b 100644 --- a/lib/std/crypto/ml_kem.zig +++ b/lib/std/crypto/ml_kem.zig @@ -229,8 +229,6 @@ fn Kyber(comptime p: Params) type { pub const shared_length = common_shared_key_size; /// Length (in bytes) of a seed for deterministic encapsulation. pub const encaps_seed_length = common_encaps_seed_length; - /// Length (in bytes) of a seed for key generation. - pub const seed_length: usize = inner_seed_length + shared_length; /// Algorithm name. pub const name = p.name; @@ -377,6 +375,9 @@ fn Kyber(comptime p: Params) type { secret_key: SecretKey, public_key: PublicKey, + /// Length (in bytes) of a seed for key generation. + pub const seed_length: usize = inner_seed_length + shared_length; + /// Create a new key pair. /// If seed is null, a random seed will be generated. /// If a seed is provided, the key pair will be determinsitic. diff --git a/lib/std/crypto/pcurves/secp256k1.zig b/lib/std/crypto/pcurves/secp256k1.zig index 945abea931cd..f9cac38ebb26 100644 --- a/lib/std/crypto/pcurves/secp256k1.zig +++ b/lib/std/crypto/pcurves/secp256k1.zig @@ -167,7 +167,7 @@ pub const Secp256k1 = struct { /// Serialize a point using the uncompressed SEC-1 format. 
pub fn toUncompressedSec1(p: Secp256k1) [65]u8 { var out: [65]u8 = undefined; - out[0] = 4; + out[0] = 4; // uncompressed const xy = p.affineCoordinates(); out[1..33].* = xy.x.toBytes(.big); out[33..65].* = xy.y.toBytes(.big); diff --git a/lib/std/crypto/sha2.zig b/lib/std/crypto/sha2.zig index 31884c73818a..c5f151771c4f 100644 --- a/lib/std/crypto/sha2.zig +++ b/lib/std/crypto/sha2.zig @@ -392,11 +392,15 @@ fn Sha2x32(comptime params: Sha2Params32) type { } pub const Error = error{}; - pub const Writer = std.io.Writer(*Self, Error, write); + pub const Writer = std.io.Writer(*Self, Error, writev); - fn write(self: *Self, bytes: []const u8) Error!usize { - self.update(bytes); - return bytes.len; + fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + var written: usize = 0; + for (iov) |v| { + self.update(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } pub fn writer(self: *Self) Writer { diff --git a/lib/std/crypto/siphash.zig b/lib/std/crypto/siphash.zig index 5d1ac4f87469..1272a4f98b63 100644 --- a/lib/std/crypto/siphash.zig +++ b/lib/std/crypto/siphash.zig @@ -240,11 +240,15 @@ fn SipHash(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize) } pub const Error = error{}; - pub const Writer = std.io.Writer(*Self, Error, write); + pub const Writer = std.io.Writer(*Self, Error, writev); - fn write(self: *Self, bytes: []const u8) Error!usize { - self.update(bytes); - return bytes.len; + fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + var written: usize = 0; + for (iov) |v| { + self.update(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } pub fn writer(self: *Self) Writer { diff --git a/lib/std/crypto/testdata/cert.der b/lib/std/crypto/testdata/cert.der new file mode 100644 index 000000000000..edce2b88201a Binary files /dev/null and b/lib/std/crypto/testdata/cert.der differ diff --git a/lib/std/crypto/testdata/cert.pem b/lib/std/crypto/testdata/cert.pem new file mode 100644 index 000000000000..45f0184c7b49 --- /dev/null +++ b/lib/std/crypto/testdata/cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDITCCAgmgAwIBAgIIFVqSrcIEj5AwDQYJKoZIhvcNAQELBQAwIjELMAkGA1UE +BhMCVVMxEzARBgNVBAoTCkV4YW1wbGUgQ0EwHhcNMTgxMDA1MDEzODE3WhcNMTkx +MDA1MDEzODE3WjArMQswCQYDVQQGEwJVUzEcMBoGA1UEAxMTZXhhbXBsZS51bGZo +ZWltLm5ldDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMSANga650dr +CJQE7Ke2kQQ/95K8Ge77fXTXqA0AHntLOkrmD+jAcfxz5wJMDbz0vdEdOWu6cEZK +E+lK+D3z4QlZVHvJVftBLaN2UhHh89x3bKpTN27KOuy+w6q3OzHVbLZSnICYvMng +KBjiC/f4oDr9FwRQns55vZ858epp7EeXLoMPtcqV3pWh5gQi1e6+UnlUoee/iob2 +Rm0NnxaVGkz3oEaSWVwTUvJUnlr7Tr/XejeVAUTkwCaHTGU+QH19IwdEAfSE/9CP +eh+gUhDR9PDVznlwKTLiyr5wH9+ta0u3EQH0S61mahETD+Lugp5NAp3JHN1nFtu5 +BhiG7cG6lCECAwEAAaNSMFAwDgYDVR0PAQH/BAQDAgWgMB0GA1UdJQQWMBQGCCsG +AQUFBwMCBggrBgEFBQcDATAfBgNVHSMEGDAWgBSJT95bzGniUs8+owDfsZe4HeHB +RjANBgkqhkiG9w0BAQsFAAOCAQEAWRZFppouN3nk9t0nGrocC/1s11WZtefDblM+ +/zZZCEMkyeelBAedOeDUKYf/4+vdCcHPHZFEVYcLVx3Rm98dJPi7mhH+gP1ZK6A5 +jN4R4mUeYYzlmPqW5Tcu7z0kiv3hdGPrv6u45NGrUCpU7ABk6S94GWYNPyfPIJ5m +f85a4uSsmcfJOBj4slEHIt/tl/MuPpNJ1MZsnqY5bXREYqBrQsbVumiOrDoBe938 +jiz8rSfLadPM3KKAQURl0640jODzSrL7nGGDcTErGRBBZBwjfxGl1lyETwQEhJk4 +cSuVntaFvFxd1kXtGZCUc0ApJty0DjRpoVlB6OLMqEu2CEY2oA== +-----END CERTIFICATE----- diff --git a/lib/std/crypto/testdata/key.der b/lib/std/crypto/testdata/key.der new file mode 100644 index 000000000000..9e4f1334d162 Binary files /dev/null and b/lib/std/crypto/testdata/key.der differ diff --git a/lib/std/crypto/testdata/key.pem 
b/lib/std/crypto/testdata/key.pem new file mode 100644 index 000000000000..e3f06adc5a99 --- /dev/null +++ b/lib/std/crypto/testdata/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAxIA2BrrnR2sIlATsp7aRBD/3krwZ7vt9dNeoDQAee0s6SuYP +6MBx/HPnAkwNvPS90R05a7pwRkoT6Ur4PfPhCVlUe8lV+0Eto3ZSEeHz3HdsqlM3 +bso67L7Dqrc7MdVstlKcgJi8yeAoGOIL9/igOv0XBFCeznm9nznx6mnsR5cugw+1 +ypXelaHmBCLV7r5SeVSh57+KhvZGbQ2fFpUaTPegRpJZXBNS8lSeWvtOv9d6N5UB +ROTAJodMZT5AfX0jB0QB9IT/0I96H6BSENH08NXOeXApMuLKvnAf361rS7cRAfRL +rWZqERMP4u6Cnk0Cnckc3WcW27kGGIbtwbqUIQIDAQABAoIBAGF7OVIdZp8Hejn0 +N3L8HvT8xtUEe9kS6ioM0lGgvX5s035Uo4/T6LhUx0VcdXRH9eLHnLTUyN4V4cra +ZkxVsE3zAvZl60G6E+oDyLMWZOP6Wu4kWlub9597A5atT7BpMIVCdmFVZFLB4SJ3 +AXkC3nplFAYP+Lh1rJxRIrIn2g+pEeBboWbYA++oDNuMQffDZaokTkJ8Bn1JZYh0 +xEXKY8Bi2Egd5NMeZa1UFO6y8tUbZfwgVs6Enq5uOgtfayq79vZwyjj1kd29MBUD +8g8byV053ZKxbUOiOuUts97eb+fN3DIDRTcT2c+lXt/4C54M1FclJAbtYRK/qwsl +pYWKQAECgYEA4ZUbqQnTo1ICvj81ifGrz+H4LKQqe92Hbf/W51D/Umk2kP702W22 +HP4CvrJRtALThJIG9m2TwUjl/WAuZIBrhSAbIvc3Fcoa2HjdRp+sO5U1ueDq7d/S +Z+PxRI8cbLbRpEdIaoR46qr/2uWZ943PHMv9h4VHPYn1w8b94hwD6vkCgYEA3v87 +mFLzyM9ercnEv9zHMRlMZFQhlcUGQZvfb8BuJYl/WogyT6vRrUuM0QXULNEPlrin +mBQTqc1nCYbgkFFsD2VVt1qIyiAJsB9MD1LNV6YuvE7T2KOSadmsA4fa9PUqbr71 +hf3lTTq+LeR09LebO7WgSGYY+5YKVOEGpYMR1GkCgYEAxPVQmk3HKHEhjgRYdaG5 +lp9A9ZE8uruYVJWtiHgzBTxx9TV2iST+fd/We7PsHFTfY3+wbpcMDBXfIVRKDVwH +BMwchXH9+Ztlxx34bYJaegd0SmA0Hw9ugWEHNgoSEmWpM1s9wir5/ELjc7dGsFtz +uzvsl9fpdLSxDYgAAdzeGtkCgYBAzKIgrVox7DBzB8KojhtD5ToRnXD0+H/M6OKQ +srZPKhlb0V/tTtxrIx0UUEFLlKSXA6mPw6XDHfDnD86JoV9pSeUSlrhRI+Ysy6tq +eIE7CwthpPZiaYXORHZ7wCqcK/HcpJjsCs9rFbrV0yE5S3FMdIbTAvgXg44VBB7O +UbwIoQKBgDuY8gSrA5/A747wjjmsdRWK4DMTMEV4eCW1BEP7Tg7Cxd5n3xPJiYhr +nhLGN+mMnVIcv2zEMS0/eNZr1j/0BtEdx+3IC6Eq+ONY0anZ4Irt57/5QeKgKn/L +JPhfPySIPG4UmwE4gW8t79vfOKxnUu2fDD1ZXUYopan6EckACNH/ +-----END RSA PRIVATE KEY----- diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 7fff68471caa..ca14eaf97db8 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -1,59 +1,23 @@ -//! Plaintext: -//! * type: ContentType -//! * legacy_record_version: u16 = 0x0303, -//! * length: u16, -//! - The length (in bytes) of the following TLSPlaintext.fragment. The -//! length MUST NOT exceed 2^14 bytes. -//! * fragment: opaque -//! - the data being transmitted -//! -//! Ciphertext -//! * ContentType opaque_type = application_data; /* 23 */ -//! * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */ -//! * uint16 length; -//! * opaque encrypted_record[TLSCiphertext.length]; -//! -//! Handshake: -//! * type: HandshakeType -//! * length: u24 -//! * data: opaque -//! -//! ServerHello: -//! * ProtocolVersion legacy_version = 0x0303; -//! * Random random; -//! * opaque legacy_session_id_echo<0..32>; -//! * CipherSuite cipher_suite; -//! * uint8 legacy_compression_method = 0; -//! * Extension extensions<6..2^16-1>; -//! -//! Extension: -//! * ExtensionType extension_type; -//! 
* opaque extension_data<0..2^16-1>; - const std = @import("../std.zig"); +const builtin = @import("builtin"); +pub const Client = @import("tls/Client.zig"); +pub const Server = @import("tls/Server.zig"); +pub const Stream = @import("tls/Stream.zig"); + const Tls = @This(); const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; +const testing = std.testing; +const native_endian = builtin.cpu.arch.endian(); +pub const ServerOptions = Server.Options; +pub const ClientOptions = Client.Options; +const Allocator = std.mem.Allocator; -pub const Client = @import("tls/Client.zig"); - -pub const record_header_len = 5; -pub const max_cipertext_inner_record_len = 1 << 14; -pub const max_ciphertext_len = max_cipertext_inner_record_len + 256; -pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len; -pub const hello_retry_request_sequence = [32]u8{ - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -}; - -pub const close_notify_alert = [_]u8{ - @intFromEnum(AlertLevel.warning), - @intFromEnum(AlertDescription.close_notify), -}; - -pub const ProtocolVersion = enum(u16) { +pub const Version = enum(u16) { + tls_1_0 = 0x0301, + tls_1_1 = 0x0302, tls_1_2 = 0x0303, tls_1_3 = 0x0304, _, @@ -61,28 +25,194 @@ pub const ProtocolVersion = enum(u16) { pub const ContentType = enum(u8) { invalid = 0, - change_cipher_spec = 20, - alert = 21, - handshake = 22, - application_data = 23, + change_cipher_spec = 0x14, + alert = 0x15, + handshake = 0x16, + application_data = 0x17, + heartbeat = 0x18, _, }; +pub const Plaintext = struct { + type: ContentType, + version: Version = .tls_1_0, + len: u16, + + pub const size = @sizeOf(ContentType) + @sizeOf(Version) + @sizeOf(u16); + pub const max_length = 1 << 14; + + const Self = @This(); + + pub fn init(bytes: [size]u8) Self { + var stream = std.io.fixedBufferStream(&bytes); + var reader = stream.reader(); + const ty = reader.readInt(u8, .big) catch unreachable; + const version = reader.readInt(u16, .big) catch unreachable; + const len = reader.readInt(u16, .big) catch unreachable; + return .{ .type = @enumFromInt(ty), .version = @enumFromInt(version), .len = len }; + } +}; + pub const HandshakeType = enum(u8) { + /// Deprecated. + hello_request = 0, client_hello = 1, server_hello = 2, + /// Deprecated. + hello_verify_request = 3, new_session_ticket = 4, end_of_early_data = 5, + /// Deprecated. + hello_retry_request = 6, encrypted_extensions = 8, certificate = 11, + /// Deprecated. + server_key_exchange = 12, certificate_request = 13, + /// Deprecated. + server_hello_done = 14, certificate_verify = 15, + /// Deprecated. + client_key_exchange = 16, finished = 20, + /// Deprecated. + certificate_url = 21, + /// Deprecated. + certificate_status = 22, + /// Deprecated. + supplemental_data = 23, key_update = 24, message_hash = 254, _, }; +pub const Handshake = union(HandshakeType) { + hello_request: void, + client_hello: ClientHello, + server_hello: ServerHello, + /// Deprecated. + hello_verify_request: void, + new_session_ticket: void, + end_of_early_data: void, + /// Deprecated. + hello_retry_request: void, + encrypted_extensions: []const Extension, + certificate: Certificate, + /// Deprecated. + server_key_exchange: void, + certificate_request: void, + /// Deprecated. + server_hello_done: void, + certificate_verify: CertificateVerify, + /// Deprecated. 
+ client_key_exchange: void, + finished: []const u8, + /// Deprecated. + certificate_url: void, + /// Deprecated. + certificate_status: void, + /// Deprecated. + supplemental_data: void, + key_update: KeyUpdate, + message_hash: void, + + // If `HandshakeCipherT.encode` accepts iovecs for the message this can be moved + // to `Stream.writeFragment` and this type can be deleted. + pub fn write(self: @This(), stream: *Stream) !usize { + var res: usize = 0; + res += try stream.write(HandshakeType, self); + switch (self) { + .finished => |verification| { + res += try stream.writeArray(u24, u8, verification); + }, + inline else => |value| { + var len: usize = 0; + const T = @TypeOf(value); + switch (@typeInfo(T)) { + .Void => { + res += try stream.write(u24, @intCast(len)); + }, + .Pointer => |info| { + len += stream.arrayLength(u16, info.child, value); + res += try stream.write(u24, @intCast(len)); + res += try stream.writeArray(u16, info.child, value); + }, + .Struct => { + len += stream.length(T, value); + res += try stream.write(u24, @intCast(len)); + res += try stream.write(T, value); + }, + .Enum => |info| { + len += @bitSizeOf(info.tag_type) / 8; + res += try stream.write(u24, @intCast(len)); + res += try stream.write(T, value); + }, + else => |t| @compileError("implement writing " ++ @tagName(t)), + } + }, + } + return res; + } + + pub const Header = struct { + type: HandshakeType, + len: u24, + }; +}; + +pub const KeyUpdate = enum(u8) { + update_not_requested = 0, + update_requested = 1, + _, +}; + +/// A DER encoded certificate chain with the first entry being for this domain. +pub const Certificate = struct { + context: []const u8 = "", + entries: []const Entry = &.{}, + + pub const max_context_len = 255; + + pub const Entry = struct { + /// DER encoded + data: []const u8, + extensions: []const Extension = &.{}, + + pub const max_data_len = 1 << 24 - 1; + + pub fn write(self: @This(), stream: *Stream) !usize { + var res: usize = 0; + res += try stream.writeArray(u24, u8, self.data); + res += try stream.writeArray(u16, Extension, self.extensions); + return res; + } + }; + + const Self = @This(); + + pub fn write(self: Self, stream: *Stream) !usize { + var res: usize = 0; + res += try stream.writeArray(u8, u8, self.context); + res += try stream.writeArray(u24, Entry, self.entries); + return res; + } +}; + +pub const CertificateVerify = struct { + algorithm: SignatureScheme, + signature: []const u8, + + pub const max_signature_length = 1 << 16 - 1; + + pub fn write(self: @This(), stream: *Stream) !usize { + var res: usize = 0; + res += try stream.write(SignatureScheme, self.algorithm); + res += try stream.writeArray(u16, u8, self.signature); + return res; + } +}; + +// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml pub const ExtensionType = enum(u16) { /// RFC 6066 server_name = 0, @@ -90,8 +220,10 @@ pub const ExtensionType = enum(u16) { max_fragment_length = 1, /// RFC 6066 status_request = 5, - /// RFC 8422, 7919 + /// RFC 8422, 7919. 
renamed from "elliptic_curves" supported_groups = 10, + /// RFC 8422 S5.1.2 + ec_point_formats = 11, /// RFC 8446 signature_algorithms = 13, /// RFC 5764 @@ -108,6 +240,12 @@ pub const ExtensionType = enum(u16) { server_certificate_type = 20, /// RFC 7685 padding = 21, + /// RFC7366 + encrypt_then_mac = 22, + /// RFC 7627 + extended_master_secret = 23, + /// RFC 5077 + session_ticket = 35, /// RFC 8446 pre_shared_key = 41, /// RFC 8446 @@ -128,130 +266,233 @@ pub const ExtensionType = enum(u16) { signature_algorithms_cert = 50, /// RFC 8446 key_share = 51, - + /// Reserved for private use. + none = 65280, _, }; -pub const AlertLevel = enum(u8) { - warning = 1, - fatal = 2, - _, +/// Matching error set for Alert.Description. +pub const Error = error{ + TlsUnexpectedMessage, + TlsBadRecordMac, + TlsRecordOverflow, + TlsHandshakeFailure, + TlsBadCertificate, + TlsUnsupportedCertificate, + TlsCertificateRevoked, + TlsCertificateExpired, + TlsCertificateUnknown, + TlsIllegalParameter, + TlsUnknownCa, + TlsAccessDenied, + TlsDecodeError, + TlsDecryptError, + TlsProtocolVersion, + TlsInsufficientSecurity, + TlsInternalError, + TlsInappropriateFallback, + TlsMissingExtension, + TlsUnsupportedExtension, + TlsUnrecognizedName, + TlsBadCertificateStatusResponse, + TlsUnknownPskIdentity, + TlsCertificateRequired, + TlsNoApplicationProtocol, + TlsUnknown, }; -pub const AlertDescription = enum(u8) { - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, +pub const Alert = struct { + /// > In TLS 1.3, the severity is implicit in the type of alert being sent + /// > and the "level" field can safely be ignored. + level: Level, + description: Description, + + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, }; + pub const Description = enum(u8) { + /// Stream is closing. + close_notify = 0, + /// An inappropriate message (e.g., the wrong + /// handshake message, premature Application Data, etc.) was received. + /// This alert should never be observed in communication between + /// proper implementations. + unexpected_message = 10, + /// This alert is returned if a record is received which + /// cannot be deprotected. Because AEAD algorithms combine decryption + /// and verification, and also to avoid side-channel attacks, this + /// alert is used for all deprotection failures. This alert should + /// never be observed in communication between proper implementations, + /// except when messages were corrupted in the network. + bad_record_mac = 20, + /// A TLSCiphertext record was received that had a + /// length more than 2^14 + 256 bytes, or a record decrypted to a + /// TLSPlaintext record with more than 2^14 bytes (or some other + /// negotiated limit). This alert should never be observed in + /// communication between proper implementations, except when messages + /// were corrupted in the network. 
+ record_overflow = 22, + /// Receipt of a "handshake_failure" alert message + /// indicates that the sender was unable to negotiate an acceptable + /// set of security parameters given the options available. + handshake_failure = 40, + /// A certificate was corrupt, contained signatures + /// that did not verify correctly, etc. + bad_certificate = 42, + /// A certificate was of an unsupported type. + unsupported_certificate = 43, + /// A certificate was revoked by its signer. + certificate_revoked = 44, + /// A certificate has expired or is not currently valid. + certificate_expired = 45, + /// Some other (unspecified) issue arose in processing the certificate, + /// rendering it unacceptable. + certificate_unknown = 46, + /// A field in the handshake was incorrect or + /// inconsistent with other fields. This alert is used for errors + /// which conform to the formal protocol syntax but are otherwise + /// incorrect. + illegal_parameter = 47, + /// A valid certificate chain or partial chain was received, + /// but the certificate was not accepted because the CA certificate + /// could not be located or could not be matched with a known trust + /// anchor. + unknown_ca = 48, + /// A valid certificate or PSK was received, but when + /// access control was applied, the sender decided not to proceed with + /// negotiation. + access_denied = 49, + /// A message could not be decoded because some field was + /// out of the specified range or the length of the message was + /// incorrect. This alert is used for errors where the message does + /// not conform to the formal protocol syntax. This alert should + /// never be observed in communication between proper implementations, + /// except when messages were corrupted in the network. + decode_error = 50, + /// A handshake (not record layer) cryptographic + /// operation failed, including being unable to correctly verify a + /// signature or validate a Finished message or a PSK binder. + decrypt_error = 51, + /// The protocol version the peer has attempted to + /// negotiate is recognized but not supported (see Appendix D). + protocol_version = 70, + /// Returned instead of "handshake_failure" when + /// a negotiation has failed specifically because the server requires + /// parameters more secure than those supported by the client. + insufficient_security = 71, + /// An internal error unrelated to the peer or the + /// correctness of the protocol (such as a memory allocation failure) + /// makes it impossible to continue. + internal_error = 80, + /// Sent by a server in response to an invalid + /// connection retry attempt from a client (see [RFC7507]). + inappropriate_fallback = 86, + /// User cancelled handshake. + user_canceled = 90, + /// Sent by endpoints that receive a handshake + /// message not containing an extension that is mandatory to send for + /// the offered TLS version or other negotiated parameters. + missing_extension = 109, + /// Sent by endpoints receiving any handshake + /// message containing an extension known to be prohibited for + /// inclusion in the given handshake message, or including any + /// extensions in a ServerHello or Certificate not first offered in + /// the corresponding ClientHello or CertificateRequest. + unsupported_extension = 110, + /// Sent by servers when no server exists identified + /// by the name provided by the client via the "server_name" extension + /// (see [RFC6066]). 
+ unrecognized_name = 112, + /// Sent by clients when an invalid or + /// unacceptable OCSP response is provided by the server via the + /// "status_request" extension (see [RFC6066]). + bad_certificate_status_response = 113, + /// Sent by servers when PSK key establishment is + /// desired but no acceptable PSK identity is provided by the client. + /// Sending this alert is OPTIONAL; servers MAY instead choose to send + /// a "decrypt_error" alert to merely indicate an invalid PSK + /// identity. + unknown_psk_identity = 115, + /// Sent by servers when a client certificate is + /// desired but none was provided by the client. + certificate_required = 116, + /// Sent by servers when a client + /// "application_layer_protocol_negotiation" extension advertises only + /// protocols that the server does not support (see [RFC7301]). + no_application_protocol = 120, + _, - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, + pub fn toError(alert: @This()) Error { + return switch (alert) { + .close_notify, .user_canceled => unreachable, // not an error + .unexpected_message => Error.TlsUnexpectedMessage, + .bad_record_mac => Error.TlsBadRecordMac, + .record_overflow => Error.TlsRecordOverflow, + .handshake_failure => Error.TlsHandshakeFailure, + .bad_certificate => Error.TlsBadCertificate, + .unsupported_certificate => Error.TlsUnsupportedCertificate, + .certificate_revoked => Error.TlsCertificateRevoked, + .certificate_expired => Error.TlsCertificateExpired, + .certificate_unknown => Error.TlsCertificateUnknown, + .illegal_parameter => Error.TlsIllegalParameter, + .unknown_ca => Error.TlsUnknownCa, + .access_denied => Error.TlsAccessDenied, + .decode_error => Error.TlsDecodeError, + .decrypt_error => Error.TlsDecryptError, + .protocol_version => Error.TlsProtocolVersion, + .insufficient_security => Error.TlsInsufficientSecurity, + .internal_error => Error.TlsInternalError, + .inappropriate_fallback => Error.TlsInappropriateFallback, + .missing_extension => Error.TlsMissingExtension, + .unsupported_extension => Error.TlsUnsupportedExtension, + .unrecognized_name => Error.TlsUnrecognizedName, + .bad_certificate_status_response => Error.TlsBadCertificateStatusResponse, + .unknown_psk_identity => Error.TlsUnknownPskIdentity, + .certificate_required => Error.TlsCertificateRequired, + .no_application_protocol => Error.TlsNoApplicationProtocol, + _ => Error.TlsUnknown, + }; + } + }; - pub fn toError(alert: AlertDescription) Error!void { - return switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => error.TlsAlertUnexpectedMessage, - .bad_record_mac => error.TlsAlertBadRecordMac, - .record_overflow => error.TlsAlertRecordOverflow, - .handshake_failure => error.TlsAlertHandshakeFailure, - .bad_certificate => error.TlsAlertBadCertificate, - .unsupported_certificate => error.TlsAlertUnsupportedCertificate, - 
.certificate_revoked => error.TlsAlertCertificateRevoked, - .certificate_expired => error.TlsAlertCertificateExpired, - .certificate_unknown => error.TlsAlertCertificateUnknown, - .illegal_parameter => error.TlsAlertIllegalParameter, - .unknown_ca => error.TlsAlertUnknownCa, - .access_denied => error.TlsAlertAccessDenied, - .decode_error => error.TlsAlertDecodeError, - .decrypt_error => error.TlsAlertDecryptError, - .protocol_version => error.TlsAlertProtocolVersion, - .insufficient_security => error.TlsAlertInsufficientSecurity, - .internal_error => error.TlsAlertInternalError, - .inappropriate_fallback => error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => error.TlsAlertMissingExtension, - .unsupported_extension => error.TlsAlertUnsupportedExtension, - .unrecognized_name => error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, - .certificate_required => error.TlsAlertCertificateRequired, - .no_application_protocol => error.TlsAlertNoApplicationProtocol, - _ => error.TlsAlertUnknown, - }; + const Self = @This(); + + pub fn read(stream: *Stream) !Self { + const level = try stream.read(Level); + const description = try stream.read(Description); + return .{ .level = level, .description = description }; + } + + pub fn write(self: Self, stream: *Stream) !usize { + var res: usize = 0; + res += try stream.write(Level, self.level); + res += try stream.write(Description, self.description); + return res; + } }; +/// Scheme for certificate verification +/// +/// Note: This enum is named `SignatureScheme` because there is already a +/// `SignatureAlgorithm` type in TLS 1.2, which this replaces. pub const SignatureScheme = enum(u16) { - // RSASSA-PKCS1-v1_5 algorithms rsa_pkcs1_sha256 = 0x0401, rsa_pkcs1_sha384 = 0x0501, rsa_pkcs1_sha512 = 0x0601, - // ECDSA algorithms ecdsa_secp256r1_sha256 = 0x0403, ecdsa_secp384r1_sha384 = 0x0503, ecdsa_secp521r1_sha512 = 0x0603, - // RSASSA-PSS algorithms with public key OID rsaEncryption rsa_pss_rsae_sha256 = 0x0804, rsa_pss_rsae_sha384 = 0x0805, rsa_pss_rsae_sha512 = 0x0806, - // EdDSA algorithms ed25519 = 0x0807, ed448 = 0x0808, - // RSASSA-PSS algorithms with public key OID RSASSA-PSS rsa_pss_pss_sha256 = 0x0809, rsa_pss_pss_sha384 = 0x080a, rsa_pss_pss_sha512 = 0x080b, @@ -261,104 +502,678 @@ pub const SignatureScheme = enum(u16) { ecdsa_sha1 = 0x0203, _, + + pub fn Ecdsa(comptime self: @This()) type { + return switch (self) { + .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, + .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, + else => @compileError("bad scheme"), + }; + } + + pub fn Hash(comptime self: @This()) type { + return switch (self) { + .ecdsa_secp256r1_sha256, .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, + .ecdsa_secp384r1_sha384, .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, + .ecdsa_secp521r1_sha512, .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, + else => @compileError("bad scheme"), + }; + } + + pub fn Eddsa(comptime self: @This()) type { + return switch (self) { + .ed25519 => crypto.sign.Ed25519, + else => @compileError("bad scheme"), + }; + } +}; +pub const supported_signature_schemes = [_]SignatureScheme{ + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .ed25519, }; +/// Key exchange formats +/// +
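
The alert mapping above leans on Zig's non-exhaustive enums: unknown codes still round-trip through `@enumFromInt` and fall into the `_` prong of `toError`. A minimal standalone sketch of that pattern (the names below are illustrative, not the patch's types):

```zig
const std = @import("std");

// Minimal sketch: a non-exhaustive enum lets @enumFromInt accept alert codes
// this client does not know about, and the trailing `_` prong maps them to a
// catch-all error instead of hitting safety-checked undefined behavior.
const Description = enum(u8) {
    unexpected_message = 10,
    bad_record_mac = 20,
    handshake_failure = 40,
    _,

    fn toError(d: Description) anyerror {
        return switch (d) {
            .unexpected_message => error.TlsUnexpectedMessage,
            .bad_record_mac => error.TlsBadRecordMac,
            .handshake_failure => error.TlsHandshakeFailure,
            _ => error.TlsUnknown,
        };
    }
};

test "unknown alert codes map to a catch-all error" {
    const known: Description = @enumFromInt(20);
    try std.testing.expect(known.toError() == error.TlsBadRecordMac);

    const unknown: Description = @enumFromInt(200); // a code we never named
    try std.testing.expect(unknown.toError() == error.TlsUnknown);
}
```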
https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml pub const NamedGroup = enum(u16) { + // Use reserved value for invalid. + invalid = 0x0000, // Elliptic Curve Groups (ECDHE) secp256r1 = 0x0017, secp384r1 = 0x0018, secp521r1 = 0x0019, x25519 = 0x001D, - x448 = 0x001E, - - // Finite Field Groups (DHE) - ffdhe2048 = 0x0100, - ffdhe3072 = 0x0101, - ffdhe4096 = 0x0102, - ffdhe6144 = 0x0103, - ffdhe8192 = 0x0104, - // Hybrid post-quantum key agreements - x25519_kyber512d00 = 0xFE30, + // Hybrid post-quantum key agreements. Still in draft. x25519_kyber768d00 = 0x6399, _, }; +pub fn NamedGroupT(comptime named_group: NamedGroup) type { + return switch (named_group) { + .secp256r1 => crypto.sign.ecdsa.EcdsaP256Sha256, + .secp384r1 => crypto.sign.ecdsa.EcdsaP384Sha384, + .x25519 => crypto.dh.X25519, + else => |t| @compileError("unsupported named group " ++ @tagName(t)), + }; +} +pub const KeyPair = union(NamedGroup) { + invalid: void, + secp256r1: NamedGroupT(.secp256r1).KeyPair, + secp384r1: NamedGroupT(.secp384r1).KeyPair, + secp521r1: void, + x25519: NamedGroupT(.x25519).KeyPair, + x25519_kyber768d00: void, + + pub fn toKeyShare(self: @This()) KeyShare { + return switch (self) { + .secp256r1 => |k| .{ .secp256r1 = k.public_key }, + .secp384r1 => |k| .{ .secp384r1 = k.public_key }, + .x25519 => |k| .{ .x25519 = k.public_key }, + inline else => |_, t| @unionInit(KeyShare, @tagName(t), {}), + }; + } +}; +/// The public portion of a KeyPair. +pub const KeyShare = union(NamedGroup) { + invalid: void, + secp256r1: NamedGroupT(.secp256r1).PublicKey, + secp384r1: NamedGroupT(.secp384r1).PublicKey, + secp521r1: void, + x25519: NamedGroupT(.x25519).PublicKey, + x25519_kyber768d00: void, + + const Self = @This(); + + pub fn read(stream: *Stream) !Self { + std.debug.assert(!stream.is_client); + + var reader = stream.stream().reader(); + const group = try stream.read(NamedGroup); + const len = try stream.read(u16); + switch (group) { + inline .secp256r1, .secp384r1 => |k| { + const T = NamedGroupT(k).PublicKey; + var buf: [T.uncompressed_sec1_encoded_length]u8 = undefined; + try reader.readNoEof(&buf); + const val = T.fromSec1(&buf) catch return Error.TlsDecryptError; + return @unionInit(Self, @tagName(k), val); + }, + .x25519 => { + var res = Self{ .x25519 = undefined }; + try reader.readNoEof(&res.x25519); + return res; + }, + else => { + try reader.skipBytes(len, .{}); + }, + } + return .{ .invalid = {} }; + } + + pub fn write(self: Self, stream: *Stream) !usize { + var res: usize = 0; + res += try stream.write(NamedGroup, self); + const public = switch (self) { + .secp256r1 => |k| &k.toUncompressedSec1(), + .secp384r1 => |k| &k.toUncompressedSec1(), + .x25519 => |k| &k, + else => "", + }; + res += try stream.writeArray(u16, u8, public); + return res; + } +}; +/// In descending order of preference +pub const supported_groups = [_]NamedGroup{ + .secp256r1, + .secp384r1, + .x25519, +}; pub const CipherSuite = enum(u16) { - AES_128_GCM_SHA256 = 0x1301, - AES_256_GCM_SHA384 = 0x1302, - CHACHA20_POLY1305_SHA256 = 0x1303, - AES_128_CCM_SHA256 = 0x1304, - AES_128_CCM_8_SHA256 = 0x1305, - AEGIS_256_SHA512 = 0x1306, - AEGIS_128L_SHA256 = 0x1307, + aes_128_gcm_sha256 = 0x1301, + aes_256_gcm_sha384 = 0x1302, + chacha20_poly1305_sha256 = 0x1303, + aegis_256_sha512 = 0x1306, + aegis_128l_sha256 = 0x1307, _, + + pub fn Hash(comptime self: @This()) type { + return switch (self) { + .aes_128_gcm_sha256 => crypto.hash.sha2.Sha256, + .aes_256_gcm_sha384 => crypto.hash.sha2.Sha384, + 
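
For the `.x25519` member of `KeyShare`, the shared secret both sides derive is plain X25519 scalar multiplication; the `std.crypto.dh.X25519` calls below are the same ones the old client used. A small sketch with fixed seeds, used here only for reproducibility:

```zig
const std = @import("std");
const X25519 = std.crypto.dh.X25519;

// Sketch of what a key_share exchange boils down to for .x25519: each side
// sends its public key, and scalar multiplication with the local secret key
// yields the same shared secret on both ends.
test "x25519 key agreement produces one shared secret" {
    const client_kp = try X25519.KeyPair.create([_]u8{1} ** 32);
    const server_kp = try X25519.KeyPair.create([_]u8{2} ** 32);

    const client_shared = try X25519.scalarmult(client_kp.secret_key, server_kp.public_key);
    const server_shared = try X25519.scalarmult(server_kp.secret_key, client_kp.public_key);

    try std.testing.expectEqualSlices(u8, &client_shared, &server_shared);
}
```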
.chacha20_poly1305_sha256 => crypto.hash.sha2.Sha256, + .aegis_256_sha512 => crypto.hash.sha2.Sha512, + .aegis_128l_sha256 => crypto.hash.sha2.Sha256, + else => @compileError("unknown suite " ++ @tagName(self)), + }; + } + + pub fn Aead(comptime self: @This()) type { + return switch (self) { + .aes_128_gcm_sha256 => crypto.aead.aes_gcm.Aes128Gcm, + .aes_256_gcm_sha384 => crypto.aead.aes_gcm.Aes256Gcm, + .chacha20_poly1305_sha256 => crypto.aead.chacha_poly.ChaCha20Poly1305, + .aegis_256_sha512 => crypto.aead.aegis.Aegis256, + .aegis_128l_sha256 => crypto.aead.aegis.Aegis128L, + else => @compileError("unknown suite " ++ @tagName(self)), + }; + } }; -pub const CertificateType = enum(u8) { - X509 = 0, - RawPublicKey = 2, +pub const HandshakeCipher = union(CipherSuite) { + aes_128_gcm_sha256: HandshakeCipherT(.aes_128_gcm_sha256), + aes_256_gcm_sha384: HandshakeCipherT(.aes_256_gcm_sha384), + chacha20_poly1305_sha256: HandshakeCipherT(.chacha20_poly1305_sha256), + aegis_256_sha512: HandshakeCipherT(.aegis_256_sha512), + aegis_128l_sha256: HandshakeCipherT(.aegis_128l_sha256), + + const Self = @This(); + + pub fn init( + suite: CipherSuite, + shared_key: []const u8, + hello_hash: []const u8, + logger: KeyLogger, + ) Error!Self { + switch (suite) { + inline .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, + => |tag| { + const T = std.meta.TagPayloadByName(Self, @tagName(tag)); + const cipher = T.init(shared_key, hello_hash, logger); + return @unionInit(Self, @tagName(tag), cipher); + }, + _ => return Error.TlsIllegalParameter, + } + } +}; + +pub const ApplicationCipher = union(CipherSuite) { + aes_128_gcm_sha256: ApplicationCipherT(.aes_128_gcm_sha256), + aes_256_gcm_sha384: ApplicationCipherT(.aes_256_gcm_sha384), + chacha20_poly1305_sha256: ApplicationCipherT(.chacha20_poly1305_sha256), + aegis_256_sha512: ApplicationCipherT(.aegis_256_sha512), + aegis_128l_sha256: ApplicationCipherT(.aegis_128l_sha256), + + const Self = @This(); + + pub fn init( + handshake_cipher: HandshakeCipher, + handshake_hash: []const u8, + logger: KeyLogger, + ) Self { + switch (handshake_cipher) { + inline .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, + => |c, tag| { + const T = std.meta.TagPayloadByName(Self, @tagName(tag)); + const cipher = T.init(c.handshake_secret, handshake_hash, logger); + return @unionInit(Self, @tagName(tag), cipher); + }, + } + } + + pub fn print(self: Self) void { + switch (self) { + inline else => |v| v.print(), + } + } +}; + +/// RFC 8446 S4.1.2 +pub const ClientHello = struct { + /// Legacy field for TLS 1.2 middleboxes + version: Version = .tls_1_2, + random: [32]u8, + /// Legacy session resumption. Max len 32. 
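
`HandshakeCipher.init` and `ApplicationCipher.init` avoid writing one branch per suite by combining an `inline` switch prong with `std.meta.TagPayloadByName` and `@unionInit`. A toy sketch of that dispatch pattern, with a two-member union standing in for the real cipher unions:

```zig
const std = @import("std");

// Sketch of the comptime dispatch: the inline prong gives a comptime-known
// tag, so TagPayloadByName can name the payload type and @unionInit builds it.
const Suite = enum { sha256_based, sha384_based };

const AnyHash = union(Suite) {
    sha256_based: std.crypto.hash.sha2.Sha256,
    sha384_based: std.crypto.hash.sha2.Sha384,

    fn init(suite: Suite) AnyHash {
        switch (suite) {
            inline else => |tag| {
                const T = std.meta.TagPayloadByName(AnyHash, @tagName(tag));
                return @unionInit(AnyHash, @tagName(tag), T.init(.{}));
            },
        }
    }
};

test "one generic branch constructs every union member" {
    try std.testing.expect(AnyHash.init(.sha256_based) == .sha256_based);
    try std.testing.expect(AnyHash.init(.sha384_based) == .sha384_based);
}
```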
+ session_id: []const u8, + /// In descending order of preference + cipher_suites: []const CipherSuite, + // Legacy and insecure. Exactly one "null" (0) compression method is sent for compatibility; a receiver MUST error on anything else. + compression_methods: [1]u8 = .{0}, + // Certain extensions are mandatory for TLS 1.3 + extensions: []const Extension, + + pub const session_id_max_len = 32; + + const Self = @This(); + + pub fn write(self: Self, stream: *Stream) !usize { + var res: usize = 0; + res += try stream.write(Version, self.version); + res += try stream.writeAll(&self.random); + res += try stream.writeArray(u8, u8, self.session_id); + res += try stream.writeArray(u16, CipherSuite, self.cipher_suites); + res += try stream.writeArray(u8, u8, &self.compression_methods); + res += try stream.writeArray(u16, Extension, self.extensions); + return res; + } +}; + +pub const ServerHello = struct { + /// Legacy field for TLS 1.2 middleboxes + version: Version = .tls_1_2, + /// The server's random value, or `hello_retry_request` when the server asks for a new ClientHello. + random: [32]u8, + /// Legacy session resumption. Must echo the client's `session_id`. + session_id: []const u8, + cipher_suite: CipherSuite, + compression_method: u8 = 0, + /// Certain extensions are mandatory for TLS 1.3 + extensions: []const Extension, + + /// When `random` equals this it means the client should send a new `ClientHello`. + pub const hello_retry_request = [32]u8{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, + }; + + const Self = @This(); + + pub fn write(self: Self, stream: *Stream) !usize { + var res: usize = 0; + res += try stream.write(Version, self.version); + res += try stream.writeAll(&self.random); + res += try stream.writeArray(u8, u8, self.session_id); + res += try stream.write(CipherSuite, self.cipher_suite); + res += try stream.write(u8, self.compression_method); + res += try stream.writeArray(u16, Extension, self.extensions); + return res; + } +}; + +pub const EncryptedExtensions = struct { + extensions: []const Extension, + + const Self = @This(); + + pub fn write(self: Self, stream: *Stream) !usize { + return try stream.writeArray(u16, Extension, self.extensions); + } +}; + +pub const Extension = union(ExtensionType) { + // MUST NOT contain more than one name of the same name_type + server_name: []const ServerName, + max_fragment_length: void, + status_request: void, + supported_groups: []const NamedGroup, + ec_point_formats: []const EcPointFormat, + /// For CertificateVerify signatures + signature_algorithms: []const SignatureScheme, + use_srtp: void, + /// https://en.wikipedia.org/wiki/Heartbleed + heartbeat: void, + application_layer_protocol_negotiation: void, + signed_certificate_timestamp: void, + client_certificate_type: void, + server_certificate_type: void, + padding: void, + encrypt_then_mac: void, + extended_master_secret: void, + session_ticket: void, + pre_shared_key: void, + early_data: void, + supported_versions: []const Version, + cookie: void, + psk_key_exchange_modes: []const PskKeyExchangeMode, + certificate_authorities: void, + oid_filters: void, + post_handshake_auth: void, + /// For certificate signatures. + /// > Implementations which have the same policy in both cases MAY omit the + /// > "signature_algorithms_cert" extension.
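
`ClientHello.write` and `ServerHello.write` rely on `stream.writeArray`, which is assumed here to emit the usual TLS length-prefixed vector: a big-endian `u8` or `u16` length followed by the elements. A hand-rolled sketch of that wire shape using only `std.io`:

```zig
const std = @import("std");

// Sketch of TLS length-prefixed vectors, e.g. `opaque legacy_session_id<0..32>`
// uses a u8 length and `CipherSuite cipher_suites<2..2^16-2>` a u16 byte length.
test "length-prefixed vectors, by hand" {
    var buf: [64]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&buf);
    const w = fbs.writer();

    const session_id = [_]u8{ 0xaa, 0xbb };
    try w.writeByte(@intCast(session_id.len)); // u8 length prefix
    try w.writeAll(&session_id);

    const cipher_suites = [_]u16{ 0x1301, 0x1302 }; // aes_128_gcm_sha256, aes_256_gcm_sha384
    try w.writeInt(u16, @intCast(cipher_suites.len * 2), .big); // u16 byte-length prefix
    for (cipher_suites) |suite| try w.writeInt(u16, suite, .big);

    try std.testing.expectEqualSlices(u8, &[_]u8{
        0x02, 0xaa, 0xbb,
        0x00, 0x04, 0x13, 0x01, 0x13, 0x02,
    }, fbs.getWritten());
}
```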
+ signature_algorithms_cert: void, + key_share: []const KeyShare, + none: void, + + const Self = @This(); + + pub fn write(self: Self, stream: *Stream) !usize { + const PrefixLen = enum { zero, one, two }; + const prefix_len: PrefixLen = if (stream.is_client) switch (self) { + .supported_versions, .ec_point_formats, .psk_key_exchange_modes => .one, + .server_name, .supported_groups, .signature_algorithms, .key_share => .two, + else => .zero, + } else .zero; + + var res: usize = 0; + res += try stream.write(ExtensionType, self); + + switch (self) { + inline else => |items| { + const T = @TypeOf(items); + switch (@typeInfo(T)) { + .Void => { + res += try stream.write(u16, 0); + }, + .Pointer => |info| { + switch (prefix_len) { + inline else => |t| { + const PrefixT = switch (t) { + .zero => void, + .one => u8, + .two => u16, + }; + const len = stream.arrayLength(PrefixT, info.child, items); + res += try stream.write(u16, @intCast(len)); + res += try stream.writeArray(PrefixT, info.child, items); + }, + } + }, + else => |t| @compileError("unsupported type " ++ @typeName(T) ++ " for member " ++ @tagName(t)), + } + }, + } + return res; + } + + pub const Header = struct { + type: ExtensionType, + len: u16, + + pub fn read(stream: *Stream) @TypeOf(stream.*).ReadError!@This() { + const ty = try stream.read(ExtensionType); + const length = try stream.read(u16); + return .{ .type = ty, .len = length }; + } + }; +}; + +/// RFC 8446 S4.2.9 +pub const PskKeyExchangeMode = enum(u8) { + /// PSK-only key establishment. In this mode, the server + /// MUST NOT supply a "key_share" value. + ke = 1, + /// PSK with (EC)DHE key establishment. In this mode, the + /// client and server MUST supply "key_share" values as described in + /// Section 4.2.8. + dhe_ke = 2, _, }; -pub const KeyUpdateRequest = enum(u8) { - update_not_requested = 0, - update_requested = 1, +/// RFC 8446 S4.1.3 +pub const ServerName = struct { + type: NameType = .host_name, + host_name: []const u8, + + pub const NameType = enum(u8) { host_name = 0, _ }; + + pub fn write(self: @This(), stream: *Stream) !usize { + var res: usize = 0; + res += try stream.write(NameType, self.type); + res += try stream.writeArray(u16, u8, self.host_name); + return res; + } +}; + +pub const EcPointFormat = enum(u8) { + uncompressed = 0, + ansiX962_compressed_prime = 1, + ansiX962_compressed_char2 = 2, _, }; -pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type { +/// RFC 5246 S7.1 +pub const ChangeCipherSpec = enum(u8) { change_cipher_spec = 1, _ }; + +/// One of these potential hashes will be selected after receiving the other party's hello. +/// +/// We init them before sending any messages to avoid having to store our first message until the +/// other party's handshake message returns. This message is usually larger than +/// `@sizeOf(MultiHash)` = 560 +/// +/// A nice benefit is decreased latency on hosts where one round trip takes longer than calling +/// `update` with `active == .all`. +pub const MultiHash = struct { + sha256: sha2.Sha256 = sha2.Sha256.init(.{}), + sha384: sha2.Sha384 = sha2.Sha384.init(.{}), + sha512: sha2.Sha512 = sha2.Sha512.init(.{}), + /// Chosen during handshake. 
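
`Extension.write` distinguishes `void` members from slice members with comptime reflection over the payload type. A reduced sketch of that trick on a two-member union (illustrative names, not the real `Extension`):

```zig
const std = @import("std");

// Sketch: an inline else prong gives a comptime-known payload, so @typeInfo
// lets one branch handle void extensions (empty body) and another slices.
const Ext = union(enum) {
    post_handshake_auth: void,
    supported_versions: []const u16,
};

fn bodyLen(ext: Ext) usize {
    switch (ext) {
        inline else => |payload| {
            const T = @TypeOf(payload);
            return switch (@typeInfo(T)) {
                .Void => 0,
                .Pointer => payload.len * @sizeOf(std.meta.Child(T)),
                else => @compileError("unsupported payload " ++ @typeName(T)),
            };
        },
    }
}

test "void and slice payloads share one code path" {
    try std.testing.expectEqual(@as(usize, 0), bodyLen(.{ .post_handshake_auth = {} }));
    try std.testing.expectEqual(@as(usize, 4), bodyLen(.{ .supported_versions = &.{ 0x0303, 0x0304 } }));
}
```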
+ active: enum { all, sha256, sha384, sha512, none } = .all, + + const sha2 = crypto.hash.sha2; + pub const max_digest_len = sha2.Sha512.digest_length; + const Self = @This(); + + pub fn update(self: *Self, bytes: []const u8) void { + switch (self.active) { + .all => { + self.sha256.update(bytes); + self.sha384.update(bytes); + self.sha512.update(bytes); + }, + .sha256 => self.sha256.update(bytes), + .sha384 => self.sha384.update(bytes), + .sha512 => self.sha512.update(bytes), + .none => {}, + } + } + + pub fn setActive(self: *Self, cipher_suite: CipherSuite) void { + self.active = switch (cipher_suite) { + .aes_128_gcm_sha256, .chacha20_poly1305_sha256, .aegis_128l_sha256 => .sha256, + .aes_256_gcm_sha384 => .sha384, + .aegis_256_sha512 => .sha512, + _ => .all, + }; + } + + pub inline fn peek(self: Self) []const u8 { + return &switch (self.active) { + .all, .none => [_]u8{}, + .sha256 => self.sha256.peek(), + .sha384 => self.sha384.peek(), + .sha512 => self.sha512.peek(), + }; + } +}; + +fn HandshakeCipherT(comptime suite: CipherSuite) type { return struct { - pub const AEAD = AeadType; - pub const Hash = HashType; + pub const AEAD = suite.Aead(); + pub const Hash = suite.Hash(); pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + // Later used in ApplicationCipher.init handshake_secret: [Hkdf.prk_length]u8, - master_secret: [Hkdf.prk_length]u8, - client_handshake_key: [AEAD.key_length]u8, - server_handshake_key: [AEAD.key_length]u8, + // For encrypting/decrypting handshake messages + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + // For generating handshake finished messages client_finished_key: [Hmac.key_length]u8, server_finished_key: [Hmac.key_length]u8, - client_handshake_iv: [AEAD.nonce_length]u8, - server_handshake_iv: [AEAD.nonce_length]u8, - transcript_hash: Hash, + // Used as a nonce for encrypting/decrypting handshake messages + // iv = initialization vector + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + + // m0aR s3cUr1tY! + read_seq: usize = 0, + write_seq: usize = 0, + + const Self = @This(); + + pub fn init( + shared_key: []const u8, + hello_hash: []const u8, + logger: KeyLogger, + ) Self { + const zeroes = [1]u8{0} ** Hash.digest_length; + const early = Hkdf.extract(&[1]u8{0}, &zeroes); + const empty = emptyHash(Hash); + + const derived = hkdfExpandLabel(Hkdf, early, "derived", &empty, Hash.digest_length); + const handshake = Hkdf.extract(&derived, shared_key); + const client = hkdfExpandLabel(Hkdf, handshake, "c hs traffic", hello_hash, Hash.digest_length); + const server = hkdfExpandLabel(Hkdf, handshake, "s hs traffic", hello_hash, Hash.digest_length); + + // Not being able to log our secrets shouldn't prevent the handshake from continuing. 
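
The `MultiHash` idea is simply to keep every candidate transcript hash warm until the suite is negotiated, then consult only one. A sketch with two `std.crypto.hash.sha2` states fed the same bytes:

```zig
const std = @import("std");
const sha2 = std.crypto.hash.sha2;

// Sketch of the MultiHash idea: until the ServerHello picks a cipher suite we
// do not know which transcript hash TLS 1.3 needs, so every candidate is fed
// the same handshake bytes and the unused ones are simply dropped later.
test "parallel transcript hashes until the suite is known" {
    var h256 = sha2.Sha256.init(.{});
    var h384 = sha2.Sha384.init(.{});

    const client_hello_bytes = "pretend this is a serialized ClientHello";
    h256.update(client_hello_bytes);
    h384.update(client_hello_bytes);

    // Suppose the server selects aes_256_gcm_sha384: only the SHA-384
    // transcript is consulted from here on.
    const transcript = h384.peek();
    try std.testing.expect(transcript.len == sha2.Sha384.digest_length);
}
```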
+ logger.writeLine("CLIENT_HANDSHAKE_TRAFFIC_SECRET", &client) catch {}; + logger.writeLine("SERVER_HANDSHAKE_TRAFFIC_SECRET", &server) catch {}; + + return .{ + .handshake_secret = handshake, + .client_finished_key = hkdfExpandLabel(Hkdf, client, "finished", "", Hmac.key_length), + .server_finished_key = hkdfExpandLabel(Hkdf, server, "finished", "", Hmac.key_length), + .client_key = hkdfExpandLabel(Hkdf, client, "key", "", AEAD.key_length), + .server_key = hkdfExpandLabel(Hkdf, server, "key", "", AEAD.key_length), + .client_iv = hkdfExpandLabel(Hkdf, client, "iv", "", AEAD.nonce_length), + .server_iv = hkdfExpandLabel(Hkdf, server, "iv", "", AEAD.nonce_length), + }; + } + + pub fn encrypt( + self: *Self, + data: []const u8, + additional: []const u8, + is_client: bool, + out: []u8, + ) [AEAD.tag_length]u8 { + var res: [AEAD.tag_length]u8 = undefined; + const key = if (is_client) self.client_key else self.server_key; + const iv = if (is_client) self.client_iv else self.server_iv; + const nonce = nonce_for_len(AEAD.nonce_length, iv, self.write_seq); + AEAD.encrypt(out, &res, data, additional, nonce, key); + self.write_seq += 1; + return res; + } + + pub fn decrypt( + self: *Self, + data: []const u8, + additional: []const u8, + tag: [AEAD.tag_length]u8, + is_client: bool, + out: []u8, + ) Error!void { + const key = if (is_client) self.server_key else self.client_key; + const iv = if (is_client) self.server_iv else self.client_iv; + const nonce = nonce_for_len(AEAD.nonce_length, iv, self.read_seq); + AEAD.decrypt(out, data, tag, additional, nonce, key) catch return Error.TlsBadRecordMac; + self.read_seq += 1; + } }; } -pub const HandshakeCipher = union(enum) { - AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA512: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512), - AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), -}; - -pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type { +fn ApplicationCipherT(comptime suite: CipherSuite) type { return struct { - pub const AEAD = AeadType; - pub const Hash = HashType; + pub const AEAD = suite.Aead(); + pub const Hash = suite.Hash(); pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + // Used to derive new keys and iv's in key_update messages client_secret: [Hash.digest_length]u8, server_secret: [Hash.digest_length]u8, + // For encrypting/decrypting application data messages client_key: [AEAD.key_length]u8, server_key: [AEAD.key_length]u8, + // Used as a nonce for encrypting/decrypting application data messages + // iv = initialization vector client_iv: [AEAD.nonce_length]u8, server_iv: [AEAD.nonce_length]u8, + + // m0aR s3cUr1tY! 
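
Every secret in `HandshakeCipherT.init` (and `ApplicationCipherT.init` below) comes out of RFC 8446's HKDF-Expand-Label, which this module exposes as `hkdfExpandLabel`. A standalone sketch of that label encoding on top of `std.crypto.kdf.hkdf.HkdfSha256`; the test deliberately skips the early-secret/"derived" stage of the real schedule to keep the example short:

```zig
const std = @import("std");

// Sketch of RFC 8446 S7.1 HKDF-Expand-Label: the info string is a u16 output
// length, a length-prefixed "tls13 " ++ label, and a length-prefixed context.
fn expandLabelSha256(prk: [32]u8, comptime label: []const u8, context: []const u8, out: []u8) void {
    const Hkdf = std.crypto.kdf.hkdf.HkdfSha256;
    var info_buf: [2 + 1 + 6 + label.len + 1 + 64]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&info_buf);
    const w = fbs.writer();
    w.writeInt(u16, @intCast(out.len), .big) catch unreachable;
    w.writeByte(@intCast("tls13 ".len + label.len)) catch unreachable;
    w.writeAll("tls13 ") catch unreachable;
    w.writeAll(label) catch unreachable;
    w.writeByte(@intCast(context.len)) catch unreachable;
    w.writeAll(context) catch unreachable;
    Hkdf.expand(out, fbs.getWritten(), prk);
}

test "derive a toy traffic secret" {
    const Hkdf = std.crypto.kdf.hkdf.HkdfSha256;
    const shared_key = [_]u8{0x42} ** 32; // stand-in for the (EC)DHE output
    const hello_hash = [_]u8{0x24} ** 32; // stand-in for Hash(ClientHello..ServerHello)

    // NOTE: shortcut, the real schedule derives the handshake secret through
    // the early secret and a "derived" label first.
    const handshake_secret = Hkdf.extract(&[_]u8{0}, &shared_key);
    var client_secret: [32]u8 = undefined;
    expandLabelSha256(handshake_secret, "c hs traffic", &hello_hash, &client_secret);
    try std.testing.expect(!std.mem.eql(u8, &client_secret, &handshake_secret));
}
```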
+ read_seq: usize = 0, + write_seq: usize = 0, + + const Self = @This(); + + pub fn init( + handshake_secret: [Hkdf.prk_length]u8, + handshake_hash: []const u8, + logger: KeyLogger, + ) Self { + const zeroes = [1]u8{0} ** Hash.digest_length; + const empty_hash = emptyHash(Hash); + + const derived = hkdfExpandLabel(Hkdf, handshake_secret, "derived", &empty_hash, Hash.digest_length); + const master = Hkdf.extract(&derived, &zeroes); + const client = hkdfExpandLabel(Hkdf, master, "c ap traffic", handshake_hash, Hash.digest_length); + const server = hkdfExpandLabel(Hkdf, master, "s ap traffic", handshake_hash, Hash.digest_length); + + // Not being able to log our secrets shouldn't prevent the handshake from continuing. + logger.writeLine("CLIENT_TRAFFIC_SECRET_0", &client) catch {}; + logger.writeLine("SERVER_TRAFFIC_SECRET_0", &server) catch {}; + + return .{ + .client_secret = client, + .server_secret = server, + .client_key = hkdfExpandLabel(Hkdf, client, "key", "", AEAD.key_length), + .server_key = hkdfExpandLabel(Hkdf, server, "key", "", AEAD.key_length), + .client_iv = hkdfExpandLabel(Hkdf, client, "iv", "", AEAD.nonce_length), + .server_iv = hkdfExpandLabel(Hkdf, server, "iv", "", AEAD.nonce_length), + }; + } + + pub fn encrypt( + self: *Self, + data: []const u8, + additional: []const u8, + is_client: bool, + out: []u8, + ) [AEAD.tag_length]u8 { + var res: [AEAD.tag_length]u8 = undefined; + const key = if (is_client) self.client_key else self.server_key; + const iv = if (is_client) self.client_iv else self.server_iv; + const nonce = nonce_for_len(AEAD.nonce_length, iv, self.write_seq); + AEAD.encrypt(out, &res, data, additional, nonce, key); + self.write_seq += 1; + return res; + } + + pub fn decrypt( + self: *Self, + data: []const u8, + additional: []const u8, + tag: [AEAD.tag_length]u8, + is_client: bool, + out: []u8, + ) !void { + const key = if (is_client) self.server_key else self.client_key; + const iv = if (is_client) self.server_iv else self.client_iv; + const nonce = nonce_for_len(AEAD.nonce_length, iv, self.read_seq); + try AEAD.decrypt(out, data, tag, additional, nonce, key); + self.read_seq += 1; + } + + pub fn print(self: Self) void { + inline for (std.meta.fields(Self)) |f| debugPrint(f.name, @field(self, f.name)); + } }; } -/// Encryption parameters for application traffic. 
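
The `encrypt`/`decrypt` methods above are thin wrappers around the suite's AEAD with the record header passed as additional data, so tampering with the header invalidates the tag. A sketch with `Aes128Gcm` and fixed key material (values here are illustrative only):

```zig
const std = @import("std");
const Aes128Gcm = std.crypto.aead.aes_gcm.Aes128Gcm;

// Sketch: a TLS record is sealed with the derived key, the per-record nonce,
// and the 5-byte record header as additional data.
test "aead round trip with the record header as additional data" {
    const key = [_]u8{0x11} ** Aes128Gcm.key_length;
    const nonce = [_]u8{0x22} ** Aes128Gcm.nonce_length;
    const header = [_]u8{ 0x17, 0x03, 0x03, 0x00, 0x15 }; // application_data, legacy version, length
    const plaintext = "ping";

    var ciphertext: [plaintext.len]u8 = undefined;
    var tag: [Aes128Gcm.tag_length]u8 = undefined;
    Aes128Gcm.encrypt(&ciphertext, &tag, plaintext, &header, nonce, key);

    var decrypted: [plaintext.len]u8 = undefined;
    try Aes128Gcm.decrypt(&decrypted, &ciphertext, tag, &header, nonce, key);
    try std.testing.expectEqualStrings(plaintext, &decrypted);

    var bad_header = header;
    bad_header[0] = 0x16; // pretend it was a handshake record
    try std.testing.expectError(
        error.AuthenticationFailed,
        Aes128Gcm.decrypt(&decrypted, &ciphertext, tag, &bad_header, nonce, key),
    );
}
```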
-pub const ApplicationCipher = union(enum) { - AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA512: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512), - AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), -}; +fn nonce_for_len(len: comptime_int, iv: [len]u8, seq: usize) [len]u8 { + if (builtin.zig_backend == .stage2_x86_64 and len > comptime std.simd.suggestVectorLength(u8) orelse 1) { + var res = iv; + const operand = std.mem.readInt(u64, res[res.len - 8 ..], .big); + std.mem.writeInt(u64, res[res.len - 8 ..], operand ^ seq, .big); + return res; + } else { + const V = @Vector(len, u8); + const pad = [1]u8{0} ** (len - 8); + const big = switch (native_endian) { + .big => seq, + .little => @byteSwap(seq), + }; + const operand: V = pad ++ @as([8]u8, @bitCast(big)); + return @as(V, iv) ^ operand; + } +} pub fn hkdfExpandLabel( comptime Hkdf: type, @@ -399,163 +1214,283 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) return result; } -pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { - return int2(@intFromEnum(et)) ++ array(1, bytes); -} +/// Slice of stack allocated signature content from RFC 8446 S4.4.3 +pub inline fn sigContent(digest: []const u8) []const u8 { + const max_digest_len = MultiHash.max_digest_len; + var buf = [_]u8{0x20} ** 64 ++ "TLS 1.3, server CertificateVerify\x00".* ++ @as([max_digest_len]u8, undefined); + @memcpy(buf[buf.len - max_digest_len ..][0..digest.len], digest); -pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 { - comptime assert(bytes.len % elem_size == 0); - return int2(bytes.len) ++ bytes; + return buf[0 .. buf.len - (max_digest_len - digest.len)]; } -pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 { - assert(@sizeOf(E) == 2); - var result: [tags.len * 2]u8 = undefined; - for (tags, 0..) |elem, i| { - result[i * 2] = @as(u8, @truncate(@intFromEnum(elem) >> 8)); - result[i * 2 + 1] = @as(u8, @truncate(@intFromEnum(elem))); +/// Default suites used for client and server in descending order of preference. +/// The order is chosen based on what crypto algorithms Zig has available in +/// the standard library and their speed on x86_64-linux. 
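
`nonce_for_len` implements RFC 8446 S5.3: the per-record nonce is the static IV XORed with the left-padded, big-endian sequence number. A scalar sketch of the same rule, without the vector/backend special-casing:

```zig
const std = @import("std");

// Sketch of the per-record nonce rule: pad the 64-bit sequence number to the
// IV length and XOR it in, so no per-record counter state is stored.
fn recordNonce(comptime len: usize, iv: [len]u8, seq: u64) [len]u8 {
    var out = iv;
    var seq_bytes: [8]u8 = undefined;
    std.mem.writeInt(u64, &seq_bytes, seq, .big);
    for (seq_bytes, out[len - 8 ..]) |b, *o| o.* ^= b;
    return out;
}

test "consecutive records get distinct nonces" {
    const iv = [_]u8{0xab} ** 12; // AES-GCM nonce length
    const n0 = recordNonce(12, iv, 0);
    const n1 = recordNonce(12, iv, 1);
    try std.testing.expectEqualSlices(u8, &iv, &n0); // seq 0 leaves the IV unchanged
    try std.testing.expect(!std.mem.eql(u8, &n0, &n1));
}
```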
+/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast +/// aegis-128l: 15382 MiB/s +/// aegis-256: 9553 MiB/s +/// aes128-gcm: 3721 MiB/s +/// aes256-gcm: 3010 MiB/s +/// chacha20Poly1305: 597 MiB/s +/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline +/// aegis-128l: 629 MiB/s +/// chacha20Poly1305: 529 MiB/s +/// aegis-256: 461 MiB/s +/// aes128-gcm: 138 MiB/s +/// aes256-gcm: 120 MiB/s +pub const default_cipher_suites = + if (crypto.core.aes.has_hardware_support) + [_]CipherSuite{ + .aegis_128l_sha256, + .aegis_256_sha512, + .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, } - return array(2, result); -} - -pub inline fn int2(x: u16) [2]u8 { - return .{ - @as(u8, @truncate(x >> 8)), - @as(u8, @truncate(x)), +else + [_]CipherSuite{ + .chacha20_poly1305_sha256, + .aegis_128l_sha256, + .aegis_256_sha512, + .aes_128_gcm_sha256, + .aes_256_gcm_sha384, }; -} -pub inline fn int3(x: u24) [3]u8 { - return .{ - @as(u8, @truncate(x >> 16)), - @as(u8, @truncate(x >> 8)), - @as(u8, @truncate(x)), - }; -} +// Implements `StreamInterface` with a ring buffer +const TestStream = struct { + buffer: Buffer, -/// An abstraction to ensure that protocol-parsing code does not perform an -/// out-of-bounds read. -pub const Decoder = struct { - buf: []u8, - /// Points to the next byte in buffer that will be decoded. - idx: usize = 0, - /// Up to this point in `buf` we have already checked that `cap` is greater than it. - our_end: usize = 0, - /// Beyond this point in `buf` is extra tag-along bytes beyond the amount we - /// requested with `readAtLeast`. - their_end: usize = 0, - /// Points to the end within buffer that has been filled. Beyond this point - /// in buf is undefined bytes. - cap: usize = 0, - /// Debug helper to prevent illegal calls to read functions. - disable_reads: bool = false, - - pub fn fromTheirSlice(buf: []u8) Decoder { - return .{ - .buf = buf, - .their_end = buf.len, - .cap = buf.len, - .disable_reads = true, - }; + const Buffer = std.RingBuffer; + const Self = @This(); + + pub const ReadError = Buffer.Error; + pub const WriteError = Buffer.Error; + + pub fn init(allocator: std.mem.Allocator) !Self { + return Self{ .buffer = try Buffer.init(allocator, Plaintext.max_length) }; } - /// Use this function to increase `their_end`. - pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void { - assert(!d.disable_reads); - const existing_amt = d.cap - d.idx; - d.their_end = d.idx + their_amt; - if (their_amt <= existing_amt) return; - const request_amt = their_amt - existing_amt; - const dest = d.buf[d.cap..]; - if (request_amt > dest.len) return error.TlsRecordOverflow; - const actual_amt = try stream.readAtLeast(dest, request_amt); - if (actual_amt < request_amt) return error.TlsConnectionTruncated; - d.cap += actual_amt; - } - - /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`. - /// Use when `our_amt` is calculated by us, not by them. - pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void { - assert(!d.disable_reads); - try readAtLeast(d, stream, our_amt); - d.our_end = d.idx + our_amt; - } - - /// Use this function to increase `our_end`. - /// This should always be called with an amount provided by us, not them. 
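
`TestStream` drives both handshake ends through one in-memory `std.RingBuffer`, so the tests need no sockets. A minimal sketch of that loopback idea using only `writeSlice` and `read`:

```zig
const std = @import("std");

// Sketch of a loopback transport: bytes written by one side become the other
// side's input, all inside a single heap-allocated ring buffer.
test "ring buffer as an in-memory transport" {
    const allocator = std.testing.allocator;
    var rb = try std.RingBuffer.init(allocator, 64);
    defer rb.deinit(allocator);

    try rb.writeSlice("hello record");

    var out = std.ArrayList(u8).init(allocator);
    defer out.deinit();
    while (rb.read()) |byte| try out.append(byte);

    try std.testing.expectEqualStrings("hello record", out.items);
}
```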
- pub fn ensure(d: *Decoder, amt: usize) !void { - d.our_end = @max(d.idx + amt, d.our_end); - if (d.our_end > d.their_end) return error.TlsDecodeError; - } - - /// Use this function to increase `idx`. - pub fn decode(d: *Decoder, comptime T: type) T { - switch (@typeInfo(T)) { - .Int => |info| switch (info.bits) { - 8 => { - skip(d, 1); - return d.buf[d.idx - 1]; - }, - 16 => { - skip(d, 2); - const b0: u16 = d.buf[d.idx - 2]; - const b1: u16 = d.buf[d.idx - 1]; - return (b0 << 8) | b1; - }, - 24 => { - skip(d, 3); - const b0: u24 = d.buf[d.idx - 3]; - const b1: u24 = d.buf[d.idx - 2]; - const b2: u24 = d.buf[d.idx - 1]; - return (b0 << 16) | (b1 << 8) | b2; - }, - else => @compileError("unsupported int type: " ++ @typeName(T)), - }, - .Enum => |info| { - const int = d.decode(info.tag_type); - if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - return @as(T, @enumFromInt(int)); - }, - else => @compileError("unsupported type: " ++ @typeName(T)), + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + self.buffer.deinit(allocator); + } + + pub fn readv(self: *Self, iov: []const std.os.iovec) ReadError!usize { + const first = iov[0]; + try self.buffer.readFirst(first.iov_base[0..first.iov_len], first.iov_len); + return first.iov_len; + } + + pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { + var written: usize = 0; + for (iov) |v| { + try self.buffer.writeSlice(v.iov_base[0..v.iov_len]); + written += v.iov_len; } + return written; } - /// Use this function to increase `idx`. - pub fn array(d: *Decoder, comptime len: usize) *[len]u8 { - skip(d, len); - return d.buf[d.idx - len ..][0..len]; + pub fn peek(self: *Self, out: []u8) ReadError!void { + const read_index = self.buffer.read_index; + _ = try self.read(out); + self.buffer.read_index = read_index; } - /// Use this function to increase `idx`. - pub fn slice(d: *Decoder, len: usize) []u8 { - skip(d, len); - return d.buf[d.idx - len ..][0..len]; + pub fn close(self: *Self) void { + _ = self; } - /// Use this function to increase `idx`. - pub fn skip(d: *Decoder, amt: usize) void { - d.idx += amt; - assert(d.idx <= d.our_end); // insufficient ensured bytes + pub fn expect(self: *Self, expected: []const u8) !void { + var tmp_buf: [Plaintext.max_length]u8 = undefined; + const buf = tmp_buf[0..self.buffer.len()]; + try self.peek(buf); + + try std.testing.expectEqualSlices(u8, expected, buf); } - pub fn eof(d: Decoder) bool { - assert(d.our_end <= d.their_end); - assert(d.idx <= d.our_end); - return d.idx == d.their_end; + const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); + + pub fn stream(self: *Self) GenericStream { + return .{ .context = self }; } +}; - /// Provide the length they claim, and receive a sub-decoder specific to that slice. - /// The parent decoder is advanced to the end. 
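
`TestStream.writev` consumes `std.os.iovec_const` entries in order and reports the total byte count, which is the contract the rest of the code expects from a vectored write. A small sketch of that gathering behavior against a plain `std.io` writer:

```zig
const std = @import("std");

// Sketch of a vectored write: each iovec points at one slice, and the writer
// consumes them in order, returning the total number of bytes written.
fn writevTo(writer: anytype, iov: []const std.os.iovec_const) !usize {
    var written: usize = 0;
    for (iov) |v| {
        try writer.writeAll(v.iov_base[0..v.iov_len]);
        written += v.iov_len;
    }
    return written;
}

test "gathering two slices through iovecs" {
    var buf: [16]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&buf);

    const header: []const u8 = "head:";
    const body: []const u8 = "body";
    const iov = [_]std.os.iovec_const{
        .{ .iov_base = header.ptr, .iov_len = header.len },
        .{ .iov_base = body.ptr, .iov_len = body.len },
    };

    try std.testing.expectEqual(@as(usize, 9), try writevTo(fbs.writer(), &iov));
    try std.testing.expectEqualStrings("head:body", fbs.getWritten());
}
```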
- pub fn sub(d: *Decoder, their_len: usize) !Decoder { - const end = d.idx + their_len; - if (end > d.their_end) return error.TlsDecodeError; - const sub_buf = d.buf[d.idx..end]; - d.idx = end; - d.our_end = end; - return fromTheirSlice(sub_buf); +fn seededClientHandshake(allocator: Allocator, stream: std.io.AnyStream) !Client.Handshake { + const client_random: [32]u8 = ("client_random012" ** 2).*; + const session_id: [32]u8 = ("session_id012345" ** 2).*; + const client_key_seed: [16]u8 = "client_seed01234".*; + const key_pairs = try Client.Handshake.KeyPairs.initAdvanced( + client_key_seed ** 2, + client_key_seed ** 3, + client_key_seed ** 2, + ); + + return Client.Handshake{ + .tls_stream = .{ .inner_stream = stream, .is_client = true }, + .options = .{ + .host = "localhost", + .ca_bundle = null, + .allocator = allocator, + }, + .client_random = client_random, + .session_id = session_id, + .key_pairs = key_pairs, + }; +} + +fn seededServerHandshake(stream: std.io.AnyStream) !Server.Handshake { + const server_random: [32]u8 = ("server_random012" ** 2).*; + const server_keygen_seed: [48]u8 = ("server_seed01234" ** 3).*; + const server_sig_salt: [MultiHash.max_digest_len]u8 = ("server_sig_salt0" ** 4).*; + + const server_cert = @embedFile("./testdata/cert.der"); + const server_key = @embedFile("./testdata/key.der"); + const server_rsa = try crypto.Certificate.rsa.SecretKey.fromDer(server_key); + + // For debugging + const stdout = std.io.getStdOut(); + const key_log = stdout.writer().any(); + + return Server.Handshake{ + .tls_stream = .{ .inner_stream = stream, .is_client = false }, + .options = .{ + .cipher_suites = &[_]CipherSuite{.aes_256_gcm_sha384}, + .key_shares = &[_]NamedGroup{.x25519}, + .certificate = .{ .entries = &[_]Certificate.Entry{ + .{ .data = server_cert }, + } }, + .certificate_key = .{ .rsa = server_rsa }, + .key_log = key_log, + }, + .server_random = server_random, + .keygen_seed = server_keygen_seed, + .certificate_verify_salt = server_sig_salt, + }; +} + +test "tls client and server handshake, data, and close_notify" { + const allocator = testing.allocator; + + var inner_stream = try TestStream.init(allocator); + defer inner_stream.deinit(allocator); + const stream = inner_stream.stream(); + + var c_hs = try seededClientHandshake(allocator, stream.any()); + var s_hs = try seededServerHandshake(stream.any()); + + try c_hs.next(); + try testing.expectEqual(.recv_hello, c_hs.command); + + try s_hs.next(); // recv_hello + try testing.expectEqual(.send_hello, s_hs.command); + + try s_hs.next(); // send_hello + try testing.expectEqual(.send_change_cipher_spec, s_hs.command); + + try c_hs.next(); // recv_hello + try testing.expectEqual(.recv_encrypted_extensions, c_hs.command); + { + const s = s_hs.tls_stream.cipher.handshake.aes_256_gcm_sha384; + const c = c_hs.tls_stream.cipher.handshake.aes_256_gcm_sha384; + + try testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); + try testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); + try testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + try testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + try testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); + try testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + } + + try s_hs.next(); // send_change_cipher_spec + try testing.expectEqual(.send_encrypted_extensions, s_hs.command); + try s_hs.next(); // send_encrypted_extensions + try testing.expectEqual(.send_certificate, s_hs.command); + try 
s_hs.next(); // send_certificate + try testing.expectEqual(.send_certificate_verify, s_hs.command); + try s_hs.next(); // send_certificate_verify + try testing.expectEqual(.send_finished, s_hs.command); + try s_hs.next(); // send_finished + try testing.expectEqual(.recv_finished, s_hs.command); + + try c_hs.next(); // recv_encrypted_extensions + try testing.expectEqual(.recv_certificate_or_finished, c_hs.command); + try c_hs.next(); // recv_certificate_or_finished (certificate) + try testing.expectEqual(.recv_certificate_verify, c_hs.command); + try c_hs.next(); // recv_certificate_verify + try testing.expectEqual(.recv_finished, c_hs.command); + try c_hs.next(); // recv_finished + try testing.expectEqual(.send_change_cipher_spec, c_hs.command); + try c_hs.next(); // send_change_cipher_spec + try testing.expectEqual(.send_finished, c_hs.command); + try c_hs.next(); // send_finished + try testing.expectEqual(.none, c_hs.command); + + try s_hs.next(); // recv_finished + try testing.expectEqual(.none, s_hs.command); + { + const s = s_hs.tls_stream.cipher.application.aes_256_gcm_sha384; + const c = c_hs.tls_stream.cipher.application.aes_256_gcm_sha384; + + try testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + try testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + try testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + try testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); } - pub fn rest(d: Decoder) []u8 { - return d.buf[d.idx..d.cap]; + var client = try c_hs.handshake(); + var server = try s_hs.handshake(); + + try client.stream().writer().writeAll("ping"); + + var recv_ping: [4]u8 = undefined; + _ = try server.tls_stream.stream().reader().readAll(&recv_ping); + try testing.expectEqualStrings("ping", &recv_ping); + + server.tls_stream.close(); + try testing.expect(server.tls_stream.closed); + + _ = try client.stream().reader().discard(); + try testing.expect(client.tls_stream.closed); +} + +pub fn debugPrint(name: []const u8, slice: anytype) void { + std.debug.print("{s} ", .{name}); + if (@typeInfo(@TypeOf(slice)) == .Int) { + std.debug.print("{d} ", .{slice}); + } else { + for (slice) |c| std.debug.print("{x:0>2} ", .{c}); + } + std.debug.print("\n", .{}); +} + +/// https://www.ietf.org/archive/id/draft-thomson-tls-keylogfile-01.html +pub const KeyLogger = struct { + /// Copy of `options.key_log`. Needed for `key_update` messages. + writer: std.io.AnyWriter = std.io.null_writer.any(), + /// The value received in our `ClientHello` message. + /// Used as a session identifier in messages sent to `key_log`. + client_random: [32]u8 = undefined, + // For logging after `key_update` messages. 
+ server_update_n: u32 = 0, + client_update_n: u32 = 0, + + pub fn writeLine( + self: @This(), + label: []const u8, + secret: []const u8, + ) !void { + var w = self.writer; + + try w.writeAll(label); + try w.writeByte(' '); + for (self.client_random) |b| w.print("{x:0>2}", .{b}) catch {}; + try w.writeByte(' '); + for (secret) |b| w.print("{x:0>2}", .{b}) catch {}; + try w.writeByte('\n'); } }; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index f07cfe781031..3efb37e7faf5 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,1468 +1,647 @@ const std = @import("../../std.zig"); const tls = std.crypto.tls; -const Client = @This(); -const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; -const Certificate = std.crypto.Certificate; - -const max_ciphertext_len = tls.max_ciphertext_len; -const hkdfExpandLabel = tls.hkdfExpandLabel; -const int2 = tls.int2; -const int3 = tls.int3; -const array = tls.array; -const enum_array = tls.enum_array; - -read_seq: u64, -write_seq: u64, -/// The starting index of cleartext bytes inside `partially_read_buffer`. -partial_cleartext_idx: u15, -/// The ending index of cleartext bytes inside `partially_read_buffer` as well -/// as the starting index of ciphertext bytes. -partial_ciphertext_idx: u15, -/// The ending index of ciphertext bytes inside `partially_read_buffer`. -partial_ciphertext_end: u15, -/// When this is true, the stream may still not be at the end because there -/// may be data in `partially_read_buffer`. -received_close_notify: bool, -/// By default, reaching the end-of-stream when reading from the server will -/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify -/// message has been received. By setting this flag to `true`, instead, the -/// end-of-stream will be forwarded to the application layer above TLS. -/// This makes the application vulnerable to truncation attacks unless the -/// application layer itself verifies that the amount of data received equals -/// the amount of data expected, such as HTTP with the Content-Length header. -allow_truncation_attacks: bool = false, -application_cipher: tls.ApplicationCipher, -/// The size is enough to contain exactly one TLSCiphertext record. -/// This buffer is segmented into four parts: -/// 0. unused -/// 1. cleartext -/// 2. ciphertext -/// 3. unused -/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and -/// `partial_ciphertext_end` describe the span of the segments. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, - -/// This is an example of the type that is needed by the read and write -/// functions. It can have any fields but it must at least have these -/// functions. -/// -/// Note that `std.net.Stream` conforms to this interface. -/// -/// This declaration serves as documentation only. -pub const StreamInterface = struct { - /// Can be any error set. - pub const ReadError = error{}; - - /// Returns the number of bytes read. The number read may be less than the - /// buffer space provided. End-of-stream is indicated by a return value of 0. - /// - /// The `iovecs` parameter is mutable because so that function may to - /// mutate the fields in order to handle partial reads from the underlying - /// stream layer. - pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Can be any error set. 
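
`KeyLogger.writeLine` (and `debugPrint`) format secrets with `{x:0>2}`: two lowercase, zero-padded hex digits per byte and no separator, which is what SSLKEYLOGFILE consumers such as Wireshark expect. A tiny check of that formatting:

```zig
const std = @import("std");

// Each byte prints as exactly two lowercase hex digits, zero-padded.
test "hex formatting of secret bytes" {
    var buf: [8]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&buf);
    const secret = [_]u8{ 0x0f, 0xa0, 0x00, 0xff };
    for (secret) |b| try fbs.writer().print("{x:0>2}", .{b});
    try std.testing.expectEqualStrings("0fa000ff", fbs.getWritten());
}
```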
- pub const WriteError = error{}; - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided. A short read does not indicate end-of-stream. - pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided, indicating end-of-stream. - /// The `iovecs` parameter is mutable in case this function needs to mutate - /// the fields in order to handle partial writes from the underlying layer. - pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize { - // This can be implemented in terms of writev, or specialized if desired. - _ = .{ this, iovecs }; - @panic("unimplemented"); - } +const Certificate = crypto.Certificate; +const Allocator = std.mem.Allocator; + +tls_stream: tls.Stream, +key_logger: tls.KeyLogger, + +pub const Options = struct { + /// Certificate messages may be up to 2^24-1 bytes long. + /// Certificate verify messages may be up to 2^16-1 bytes long. + /// This is the allocator to use for them. + allocator: Allocator, + /// Trusted certificate authority bundle used to authenticate server certificates. + /// When null, server certificate and certificate_verify messages will be skipped. + ca_bundle: ?Certificate.Bundle, + /// Used to verify cerficate chain and for Server Name Indication. + host: []const u8, + /// List of cipher suites to advertise in order of descending preference. + cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, + /// Minimum version to support. + min_version: tls.Version = .tls_1_3, + /// By default, reaching the end-of-stream when reading from the server will + /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify + /// message has been received. By setting this flag to `true`, instead, the + /// end-of-stream will be forwarded to the application layer above TLS. + /// This makes the application vulnerable to truncation attacks unless the + /// application layer itself verifies that the amount of data received equals + /// the amount of data expected, such as HTTP with the Content-Length header. + allow_truncation_attacks: bool = false, + /// Writer to log shared secrets for traffic decryption in SSLKEYLOGFILE format. 
+ key_log: std.io.AnyWriter = std.io.null_writer.any(), }; -pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ - InsufficientEntropy, - DiskQuota, - LockViolation, - NotOpenForWriting, - TlsUnexpectedMessage, - TlsIllegalParameter, - TlsDecryptFailure, - TlsRecordOverflow, - TlsBadRecordMac, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - InvalidEncoding, - IdentityElement, - SignatureVerificationFailed, - TlsDecryptError, - TlsConnectionTruncated, - TlsDecodeError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - NonCanonical, - WeakPublicKey, - }; +const Client = @This(); + +/// Executes a TLSv1.3 handshake on `any_stream`. +pub fn init(any_stream: std.io.AnyStream, options: Options) !Client { + const hs = Handshake.init(any_stream, options); + return try hs.handshake(); } -/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which -/// must conform to `StreamInterface`. -/// -/// `host` is only borrowed during this function call. -pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client { - const host_len: u16 = @intCast(host.len); - - var random_buffer: [128]u8 = undefined; - crypto.random.bytes(&random_buffer); - const hello_rand = random_buffer[0..32].*; - const legacy_session_id = random_buffer[32..64].*; - const x25519_kp_seed = random_buffer[64..96].*; - const secp256r1_kp_seed = random_buffer[96..128].*; - - const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - }; - const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. 
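
A hypothetical usage sketch of the `Options` struct declared above: only `allocator`, `ca_bundle`, and `host` lack defaults, so a caller can lean on the default cipher suites, the TLS 1.3 minimum, and the null key-log writer. Whether downstream callers will construct it exactly like this is an assumption:

```zig
const std = @import("std");
const tls = std.crypto.tls;

// Sketch only: fills in the three required fields and keeps every default.
fn exampleOptions(allocator: std.mem.Allocator, bundle: std.crypto.Certificate.Bundle) tls.Client.Options {
    return .{
        .allocator = allocator,
        .ca_bundle = bundle, // pass null to skip certificate verification
        .host = "example.com",
    };
}
```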
- error.IdentityElement => return error.InsufficientEntropy, - }; - const kyber768_kp = crypto.kem.kyber_d00.Kyber768.KeyPair.create(null) catch {}; - - const extensions_payload = - tls.extension(.supported_versions, [_]u8{ - 0x02, // byte length of supported versions - 0x03, 0x04, // TLS 1.3 - }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ - .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - .ed25519, - })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ - .x25519_kyber768d00, - .secp256r1, - .x25519, - })) ++ tls.extension( - .key_share, - array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++ - array(1, x25519_kp.public_key) ++ - int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++ - array(1, secp256r1_kp.public_key.toUncompressedSec1()) ++ - int2(@intFromEnum(tls.NamedGroup.x25519_kyber768d00)) ++ - array(1, x25519_kp.public_key ++ kyber768_kp.public_key.toBytes())), - ) ++ - int2(@intFromEnum(tls.ExtensionType.server_name)) ++ - int2(host_len + 5) ++ // byte length of this extension payload - int2(host_len + 3) ++ // server_name_list byte count - [1]u8{0x00} ++ // name_type - int2(host_len); - - const extensions_header = - int2(@intCast(extensions_payload.len + host_len)) ++ - extensions_payload; - - const legacy_compression_methods = 0x0100; - - const client_hello = - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - hello_rand ++ - [1]u8{32} ++ legacy_session_id ++ - cipher_suites ++ - int2(legacy_compression_methods) ++ - extensions_header; - - const out_handshake = - [_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++ - int3(@intCast(client_hello.len + host_len)) ++ - client_hello; - - const plaintext_header = [_]u8{ - @intFromEnum(tls.ContentType.handshake), - 0x03, 0x01, // legacy_record_version - } ++ int2(@intCast(out_handshake.len + host_len)) ++ out_handshake; - - { - var iovecs = [_]std.os.iovec_const{ - .{ - .iov_base = &plaintext_header, - .iov_len = plaintext_header.len, - }, - .{ - .iov_base = host.ptr, - .iov_len = host.len, - }, - }; - try stream.writevAll(&iovecs); - } +pub const ReadError = anyerror; +pub const WriteError = anyerror; - const client_hello_bytes1 = plaintext_header[5..]; - - var handshake_cipher: tls.HandshakeCipher = undefined; - var handshake_buffer: [8000]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; - { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_record_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - const server_hello_fragment = d.buf[d.idx..][0..record_len]; - var ptd = try d.sub(record_len); - switch (ct) { - .alert => { - try ptd.ensure(2); - const level = ptd.decode(tls.AlertLevel); - const desc = ptd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - try ptd.ensure(4); - const handshake_type = ptd.decode(tls.HandshakeType); - if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; - const length = ptd.decode(u24); - var hsd = try ptd.sub(length); - try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2); - const legacy_version = hsd.decode(u16); - const random = hsd.array(32); - if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) { - // This is a 
HelloRetryRequest message. This client implementation - // does not expect to get one. - return error.TlsUnexpectedMessage; - } - const legacy_session_id_echo_len = hsd.decode(u8); - if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; - const legacy_session_id_echo = hsd.array(32); - if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) - return error.TlsIllegalParameter; - const cipher_suite_tag = hsd.decode(tls.CipherSuite); - hsd.skip(1); // legacy_compression_method - const extensions_size = hsd.decode(u16); - var all_extd = try hsd.sub(extensions_size); - var supported_version: u16 = 0; - var shared_key: []const u8 = undefined; - var have_shared_key = false; - while (!all_extd.eof()) { - try all_extd.ensure(2 + 2); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - var extd = try all_extd.sub(ext_size); - switch (et) { - .supported_versions => { - if (supported_version != 0) return error.TlsIllegalParameter; - try extd.ensure(2); - supported_version = extd.decode(u16); - }, - .key_share => { - if (have_shared_key) return error.TlsIllegalParameter; - have_shared_key = true; - try extd.ensure(4); - const named_group = extd.decode(tls.NamedGroup); - const key_size = extd.decode(u16); - try extd.ensure(key_size); - switch (named_group) { - .x25519_kyber768d00 => { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + crypto.kem.kyber_d00.Kyber768.ciphertext_length; - if (key_size != hksl) - return error.TlsIllegalParameter; - const server_ks = extd.array(hksl); - - shared_key = &((crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_ks[0..xksl].*, - ) catch return error.TlsDecryptFailure) ++ (kyber768_kp.secret_key.decaps( - server_ks[xksl..hksl], - ) catch return error.TlsDecryptFailure)); - }, - .x25519 => { - const ksl = crypto.dh.X25519.public_length; - if (key_size != ksl) return error.TlsIllegalParameter; - const server_pub_key = extd.array(ksl); - - shared_key = &(crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_pub_key.*, - ) catch return error.TlsDecryptFailure); - }, - .secp256r1 => { - const server_pub_key = extd.slice(key_size); - - const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; - const pk = PublicKey.fromSec1(server_pub_key) catch { - return error.TlsDecryptFailure; - }; - const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .big) catch { - return error.TlsDecryptFailure; - }; - shared_key = &mul.affineCoordinates().x.toBytes(.big); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => {}, - } - } - if (!have_shared_key) return error.TlsIllegalParameter; - - const tls_version = if (supported_version == 0) legacy_version else supported_version; - if (tls_version != @intFromEnum(tls.ProtocolVersion.tls_1_3)) - return error.TlsIllegalParameter; - - switch (cipher_suite_tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_256_SHA512, - .AEGIS_128L_SHA256, - => |tag| { - const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag)); - handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{ - .handshake_secret = undefined, - .master_secret = undefined, - .client_handshake_key = undefined, - .server_handshake_key = undefined, - .client_finished_key = undefined, - .server_finished_key = undefined, - .client_handshake_iv = undefined, - .server_handshake_iv = undefined, - .transcript_hash = P.Hash.init(.{}), - }); - const p = &@field(handshake_cipher, @tagName(tag)); - 
p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 - p.transcript_hash.update(host); // Client Hello part 2 - p.transcript_hash.update(server_hello_fragment); - const hello_hash = p.transcript_hash.peek(); - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(P.Hash); - const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); - p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, shared_key); - const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); - const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); - p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); - p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => return error.TlsUnexpectedMessage, - } - } +/// Reads next application_data message. +pub fn readv(self: *Client, buffers: []const std.os.iovec) ReadError!usize { + var s = &self.tls_stream; - // This is used for two purposes: - // * Detect whether a certificate is the first one presented, in which case - // we need to verify the host name. - // * Flip back and forth between the two cleartext buffers in order to keep - // the previous certificate in memory so that it can be verified by the - // next one. - var cert_index: usize = 0; - var read_seq: u64 = 0; - var prev_cert: Certificate.Parsed = undefined; - // Set to true once a trust chain has been established from the first - // certificate to a root CA. - const HandshakeState = enum { - /// In this state we expect only an encrypted_extensions message. - encrypted_extensions, - /// In this state we expect certificate messages. - certificate, - /// In this state we expect certificate or certificate_verify messages. - /// certificate messages are ignored since the trust chain is already - /// established. - trust_chain_established, - /// In this state, we expect only the finished message. 
- finished, - }; - var handshake_state: HandshakeState = .encrypted_extensions; - var cleartext_bufs: [2][8000]u8 = undefined; - var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; - var main_cert_pub_key_buf: [600]u8 = undefined; - var main_cert_pub_key_len: u16 = undefined; - const now_sec = std.time.timestamp(); - - while (true) { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const record_header = d.buf[d.idx..][0..5]; - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - var record_decoder = try d.sub(record_len); - switch (ct) { - .change_cipher_spec => { - try record_decoder.ensure(1); - if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter; - }, - .application_data => { - const cleartext_buf = &cleartext_bufs[cert_index % 2]; - - const cleartext = switch (handshake_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ciphertext_len = record_len - P.AEAD.tag_length; - try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length); - const ciphertext = record_decoder.slice(ciphertext_len); - if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = record_decoder.array(P.AEAD.tag_length).*; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.server_handshake_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ read_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(read_seq))); - break :nonce @as(V, p.server_handshake_iv) ^ operand; - }; - read_seq += 1; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch - return error.TlsBadRecordMac; - break :c cleartext; + while (s.view.len == 0 and !s.eof()) { + const inner_plaintext = try s.readInnerPlaintext(); + switch (inner_plaintext.type) { + .handshake => { + switch (inner_plaintext.handshake_type) { + // A multithreaded client could use these. + .new_session_ticket => { + try s.stream().reader().skipBytes(inner_plaintext.len, .{}); }, - }; - - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - if (inner_ct != .handshake) return error.TlsUnexpectedMessage; - - var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); - while (true) { - try ctd.ensure(4); - const handshake_type = ctd.decode(tls.HandshakeType); - const handshake_len = ctd.decode(u24); - var hsd = try ctd.sub(handshake_len); - const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; - const handshake = ctd.buf[ctd.idx - handshake_len .. 
ctd.idx]; - switch (handshake_type) { - .encrypted_extensions => { - if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - try hsd.ensure(2); - const total_ext_size = hsd.decode(u16); - var all_extd = try hsd.sub(total_ext_size); - while (!all_extd.eof()) { - try all_extd.ensure(4); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - const extd = try all_extd.sub(ext_size); - _ = extd; - switch (et) { - .server_name => {}, - else => {}, - } - } - }, - .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - switch (handshake_state) { - .certificate => {}, - .trust_chain_established => break :cert, - else => return error.TlsUnexpectedMessage, - } - try hsd.ensure(1 + 4); - const cert_req_ctx_len = hsd.decode(u8); - if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; - const certs_size = hsd.decode(u24); - var certs_decoder = try hsd.sub(certs_size); - while (!certs_decoder.eof()) { - try certs_decoder.ensure(3); - const cert_size = certs_decoder.decode(u24); - const certd = try certs_decoder.sub(cert_size); - - const subject_cert: Certificate = .{ - .buffer = certd.buf, - .index = @intCast(certd.idx), - }; - const subject = try subject_cert.parse(); - if (cert_index == 0) { - // Verify the host on the first certificate. - try subject.verifyHostName(host); - - // Keep track of the public key for the - // certificate_verify message later. - main_cert_pub_key_algo = subject.pub_key_algo; - const pub_key = subject.pubKey(); - if (pub_key.len > main_cert_pub_key_buf.len) - return error.CertificatePublicKeyInvalid; - @memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key); - main_cert_pub_key_len = @intCast(pub_key.len); - } else { - try prev_cert.verify(subject, now_sec); - } - - if (ca_bundle.verify(subject, now_sec)) |_| { - handshake_state = .trust_chain_established; - break :cert; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => |e| return e, - } - - prev_cert = subject; - cert_index += 1; - - try certs_decoder.ensure(2); - const total_ext_size = certs_decoder.decode(u16); - const all_extd = try certs_decoder.sub(total_ext_size); - _ = all_extd; - } - }, - .certificate_verify => { - switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, - .certificate => return error.TlsCertificateNotVerified, - else => return error.TlsUnexpectedMessage, - } - - try hsd.ensure(4); - const scheme = hsd.decode(tls.SignatureScheme); - const sig_len = hsd.decode(u16); - try hsd.ensure(sig_len); - const encoded_sig = hsd.slice(sig_len); - const max_digest_len = 64; - var verify_buffer: [64 + 34 + max_digest_len]u8 = - ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - @as([max_digest_len]u8, undefined); - - const verify_bytes = switch (handshake_cipher) { - inline else => |*p| v: { - const transcript_digest = p.transcript_hash.peek(); - verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; - p.transcript_hash.update(wrapped_handshake); - break :v verify_buffer[0 .. 
verify_buffer.len - max_digest_len + transcript_digest.len]; - }, - }; - const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - return error.TlsBadSignatureScheme; - const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = try Ecdsa.Signature.fromDer(encoded_sig); - const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); - try sig.verify(verify_bytes, key); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .rsaEncryption) - return error.TlsBadSignatureScheme; - - const Hash = SchemeHash(comptime_scheme); - const rsa = Certificate.rsa; - const components = try rsa.PublicKey.parseDer(main_cert_pub_key); - const exponent = components.exponent; - const modulus = components.modulus; - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus); - const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); - }, - else => { - return error.TlsBadRsaSignatureBitCount; - }, - } - }, - inline .ed25519 => |comptime_scheme| { - if (main_cert_pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; - const Eddsa = SchemeEddsa(comptime_scheme); - if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; - const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); - if (main_cert_pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; - const key = try Eddsa.PublicKey.fromBytes(main_cert_pub_key[0..Eddsa.PublicKey.encoded_length].*); - try sig.verify(verify_bytes, key); - }, - else => { - return error.TlsBadSignatureScheme; - }, - } - }, - .finished => { - if (handshake_state != .finished) return error.TlsUnexpectedMessage; - // This message is to trick buggy proxies into behaving correctly. 
- const client_change_cipher_spec_msg = [_]u8{ - @intFromEnum(tls.ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (handshake_cipher) { - inline else => |*p, tag| c: { + .key_update => { + switch (s.cipher.application) { + inline else => |*p| { + const P = @TypeOf(p.*); + p.server_secret = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); + p.server_key = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "key", "", P.AEAD.key_length); + p.server_iv = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "iv", "", P.AEAD.nonce_length); + p.read_seq = 0; + + var logger = &self.key_logger; + logger.server_update_n += 1; + logger.writer.print("SERVER_TRAFFIC_SECRET_{d}", .{logger.server_update_n}) catch {}; + logger.writeLine("", &p.server_secret) catch {}; + }, + } + const update = try s.read(tls.KeyUpdate); + if (update == .update_requested) { + switch (s.cipher.application) { + inline else => |*p| { const P = @TypeOf(p.*); - const finished_digest = p.transcript_hash.peek(); - p.transcript_hash.update(wrapped_handshake); - const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); - if (!mem.eql(u8, &expected_server_verify_data, handshake)) - return error.TlsDecryptError; - const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); - const out_cleartext = [_]u8{ - @intFromEnum(tls.HandshakeType.finished), - 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)}; - - const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - - var finished_msg = [_]u8{ - @intFromEnum(tls.ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ @as([wrapped_len]u8, undefined); - - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; - const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - - const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - var both_msgs_vec = [_]std.os.iovec_const{.{ - .iov_base = &both_msgs, - .iov_len = both_msgs.len, - }}; - try stream.writevAll(&both_msgs_vec); - - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ - .client_secret = client_secret, - .server_secret = server_secret, - .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), - .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), - .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), - .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); + p.client_secret = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); + p.client_key = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "key", "", P.AEAD.key_length); + p.client_iv = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "iv", "", P.AEAD.nonce_length); + p.write_seq = 0; + + var logger = &self.key_logger; + logger.client_update_n += 1; + 
logger.writer.print("CLIENT_TRAFFIC_SECRET_{d}", .{logger.client_update_n}) catch {}; + logger.writeLine("", &p.client_secret) catch {}; }, - }; - const leftover = d.rest(); - var client: Client = .{ - .read_seq = 0, - .write_seq = 0, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(leftover.len), - .received_close_notify = false, - .application_cipher = app_cipher, - .partially_read_buffer = undefined, - }; - @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - if (ctd.eof()) break; + } + } + }, + else => return s.writeError(.unexpected_message), } }, - else => { - return error.TlsUnexpectedMessage; - }, + .alert => {}, + .application_data => {}, + else => return s.writeError(.unexpected_message), } } + return try s.readv(buffers); } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. -pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { - return writeEnd(c, stream, bytes, false); -} +/// Writes application_data message and flushes stream. +pub fn writev(self: *Client, iov: []const std.os.iovec_const) WriteError!usize { + if (self.tls_stream.eof()) return 0; -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); - } + const res = try self.tls_stream.writev(iov); + try self.tls_stream.flush(); + return res; } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.writeEnd(stream, bytes[index..], end); - } +pub fn close(self: *Client) void { + self.tls_stream.close(); } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { - var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - var iovecs_buf: [6]std.os.iovec_const = undefined; - var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); - if (end) { - prepared.iovec_end += prepareCiphertextRecord( - c, - iovecs_buf[prepared.iovec_end..], - ciphertext_buf[prepared.ciphertext_end..], - &tls.close_notify_alert, - .alert, - ).iovec_end; - } +pub const GenericStream = std.io.GenericStream(*Client, ReadError, readv, WriteError, writev, close); - const iovec_end = prepared.iovec_end; - const overhead_len = prepared.overhead_len; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. 
- var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].iov_len) { - const encrypted_amt = iovecs_buf[i].iov_len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end) return total_amt; - // We also cannot return on a vector boundary if the final close_notify is - // not sent; otherwise the caller would not know to retry the call. - if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; - } - iovecs_buf[i].iov_base += amt; - iovecs_buf[i].iov_len -= amt; - } +pub fn stream(self: *Client) GenericStream { + return .{ .context = self }; } -fn prepareCiphertextRecord( - c: *Client, - iovecs: []std.os.iovec_const, - ciphertext_buf: []u8, - bytes: []const u8, - inner_content_type: tls.ContentType, -) struct { - iovec_end: usize, - ciphertext_end: usize, - /// How many bytes are taken up by overhead per record. - overhead_len: usize, -} { - // Due to the trailing inner content type byte in the ciphertext, we need - // an additional buffer for storing the cleartext into before encrypting. - var cleartext_buf: [max_ciphertext_len]u8 = undefined; - var ciphertext_end: usize = 0; - var iovec_end: usize = 0; - var bytes_i: usize = 0; - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; +pub const Handshake = struct { + tls_stream: tls.Stream, + /// Running hash of handshake messages for cryptographic functions + transcript_hash: tls.MultiHash = .{}, + options: Options, + + client_random: [32]u8, + /// Used in TLSv1.2. Always set for middlebox compatibility. + session_id: [32]u8, + /// One of these potential key pairs will be selected in `recv_hello` + /// to establish a shared secret for encryption. + key_pairs: KeyPairs, + + /// Next command to execute + command: Command = .send_hello, + + /// Server certificate to later verify. 
+ /// `cert.certificate.buffer` is allocated + cert: Certificate.Parsed = undefined, + + pub const KeyPairs = struct { + secp256r1: Secp256r1, + secp384r1: Secp384r1, + x25519: X25519, + + const X25519 = tls.NamedGroupT(.x25519).KeyPair; + const Secp256r1 = tls.NamedGroupT(.secp256r1).KeyPair; + const Secp384r1 = tls.NamedGroupT(.secp384r1).KeyPair; + + pub fn init() @This() { + var random_buffer: [ + Secp256r1.seed_length + + Secp384r1.seed_length + + X25519.seed_length + ]u8 = undefined; + while (true) { - const encrypted_content_len: u16 = @intCast(@min( - @min(bytes.len - bytes_i, tls.max_cipertext_inner_record_len), - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), - )); - if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, - .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, - }; - - @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type); - bytes_i += encrypted_content_len; - const ciphertext_len = encrypted_content_len + 1; - const cleartext = cleartext_buf[0..ciphertext_len]; - - const record_start = ciphertext_end; - const ad = ciphertext_buf[ciphertext_end..][0..5]; - ad.* = - [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(ciphertext_len + P.AEAD.tag_length); - ciphertext_end += ad.len; - const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; - ciphertext_end += ciphertext_len; - const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; - ciphertext_end += auth_tag.len; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.client_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq))); - break :nonce @as(V, p.client_iv) ^ operand; - }; - c.write_seq += 1; // TODO send key_update on overflow - P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .iov_base = record.ptr, - .iov_len = record.len, - }; - iovec_end += 1; - } - }, - } -} + crypto.random.bytes(&random_buffer); -pub fn eof(c: Client) bool { - return c.received_close_notify and - c.partial_cleartext_idx >= c.partial_ciphertext_idx and - c.partial_ciphertext_idx >= c.partial_ciphertext_end; -} + const split1 = Secp256r1.seed_length; + const split2 = split1 + Secp384r1.seed_length; -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the buffer has at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. 
-pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; - return readvAtLeast(c, stream, &iovecs, len); -} + return initAdvanced( + random_buffer[0..split1].*, + random_buffer[split1..split2].*, + random_buffer[split2..].*, + ) catch continue; + } + } -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, 1); -} + pub fn initAdvanced( + secp256r1_seed: [Secp256r1.seed_length]u8, + secp384r1_seed: [Secp384r1.seed_length]u8, + x25519_seed: [X25519.seed_length]u8, + ) !@This() { + return .{ + .secp256r1 = Secp256r1.create(secp256r1_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, + .secp384r1 = Secp384r1.create(secp384r1_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, + .x25519 = X25519.create(x25519_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, + }; + } + }; -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is smaller than -/// `buffer.len`, it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, buffer.len); -} + /// A command to send or receive a single message. + pub const Command = enum { + send_hello, + recv_hello, + recv_encrypted_extensions, + recv_certificate_or_finished, + recv_certificate_verify, + recv_finished, + send_change_cipher_spec, + send_finished, + none, + }; -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is less than the space -/// provided it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { - return readvAtLeast(c, stream, iovecs, 1); -} + /// Initializes members. Does NOT send any messages to `any_stream`. + pub fn init(any_stream: std.io.AnyStream, options: Options) Handshake { + const tls_stream = tls.Stream{ .stream = any_stream, .is_client = true }; + var res = Handshake{ + .tls_stream = tls_stream, + .options = options, + .random = undefined, + .session_id = undefined, + .key_pairs = undefined, + }; + res.init_random(); -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the iovecs have at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. 
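`KeyPairs.init` retries with fresh entropy because key generation fails only when a seed maps to an all-zero private key (`error.IdentityElement`), which is astronomically unlikely; `initAdvanced` is the deterministic entry point, which can be handy for reproducing a specific handshake in tests. A sketch with arbitrary fixed seeds (placeholders, not test vectors):

const tls = std.crypto.tls;

var seed_p256: [tls.NamedGroupT(.secp256r1).KeyPair.seed_length]u8 = undefined;
var seed_p384: [tls.NamedGroupT(.secp384r1).KeyPair.seed_length]u8 = undefined;
var seed_x25519: [tls.NamedGroupT(.x25519).KeyPair.seed_length]u8 = undefined;
@memset(&seed_p256, 0x01);
@memset(&seed_p384, 0x02);
@memset(&seed_x25519, 0x03);
const key_pairs = try tls.Client.Handshake.KeyPairs.initAdvanced(seed_p256, seed_p384, seed_x25519);
_ = key_pairs;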
-pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: usize) !usize { - if (c.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); - off_i += amt; - if (c.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].iov_len) { - amt -= iovecs[vec_i].iov_len; - vec_i += 1; - } - iovecs[vec_i].iov_base += amt; - iovecs[vec_i].iov_len -= amt; + return res; } -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns number of bytes that have been read, populated inside `iovecs`. A -/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` -/// for the end of stream. The `eof()` may be true after any call to -/// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof()` is `false`. -/// See `readv` for a higher level function that has the same, familiar API as -/// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; - - // Give away the buffered cleartext we have, if any. - const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; - if (partial_cleartext.len > 0) { - const amt: u15 = @intCast(vp.put(partial_cleartext)); - c.partial_cleartext_idx += amt; - - if (c.partial_cleartext_idx == c.partial_ciphertext_idx and - c.partial_ciphertext_end == c.partial_ciphertext_idx) - { - // The buffer is now empty. - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = 0; - } - if (c.received_close_notify) { - c.partial_ciphertext_end = 0; - assert(vp.total == amt); - return amt; - } else if (amt > 0) { - // We don't need more data, so don't call read. - assert(vp.total == amt); - return amt; - } + inline fn init_random(self: *Handshake) void { + self.key_pairs = KeyPairs.init(); + crypto.random.bytes(&self.client_random); + crypto.random.bytes(&self.session_id); } - assert(!c.received_close_notify); - - // Ideally, this buffer would never be used. It is needed when `iovecs` are - // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. - var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. - var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; - // How many bytes left in the user's buffer. - const free_size = vp.freeSize(); - // The amount of the user's buffer that we need to repurpose for storing - // ciphertext. The end of the buffer will be used for such purposes. - const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; - // The amount of the user's buffer that will be used to give cleartext. The - // beginning of the buffer will be used for such purposes. - const cleartext_buf_len = free_size - ciphertext_buf_len; - - // Recoup `partially_read_buffer space`. This is necessary because it is assumed - // below that `frag0` is big enough to hold at least one record. 
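`handshake()` simply drives the `command` state machine until it reaches `.none`, regenerating its randomness and starting over from `send_hello` if the peer resets the connection mid-handshake. The same machine can be stepped manually, for example to trace each flight while debugging; a sketch with placeholder `any_stream` and `opts`:

const tls = std.crypto.tls;

var hs = tls.Client.Handshake.init(any_stream, opts);
while (hs.command != .none) {
    std.log.debug("tls client handshake step: {s}", .{@tagName(hs.command)});
    try hs.next();
}
// handshake() additionally wraps the finished state in a Client carrying the
// key logger, so for normal use `try hs.handshake()` is all that is needed.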
- limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); - c.partial_ciphertext_end -= c.partial_ciphertext_idx; - c.partial_ciphertext_idx = 0; - c.partial_cleartext_idx = 0; - const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; - - var ask_iovecs_buf: [2]std.os.iovec = .{ - .{ - .iov_base = first_iov.ptr, - .iov_len = first_iov.len, - }, - .{ - .iov_base = &in_stack_buffer, - .iov_len = in_stack_buffer.len, - }, - }; + /// Establishes a TLS connection on `tls_stream` and returns a Client. + pub fn handshake(self: *Handshake) !Client { + while (self.command != .none) self.next() catch |err| switch (err) { + error.ConnectionResetByPeer => { + // Prevent reply attacks + self.command = .send_hello; + self.init_random(); + }, + else => return err, + }; - // Cleartext capacity of output buffer, in records. Minimum one full record. - const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1); - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); - const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); - const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); - const actual_read_len = try stream.readv(ask_iovecs); - if (actual_read_len == 0) { - // This is either a truncation attack, a bug in the server, or an - // intentional omission of the close_notify message due to truncation - // detection handled above the TLS layer. - if (c.allow_truncation_attacks) { - c.received_close_notify = true; - } else { - return error.TlsConnectionTruncated; - } + return Client{ + .tls_stream = self.tls_stream, + .key_logger = .{ + .writer = self.options.key_log, + .client_random = self.client_random, + }, + }; } - // There might be more bytes inside `in_stack_buffer` that need to be processed, - // but at least frag0 will have one complete ciphertext record. - const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); - const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; - var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; - // We need to decipher frag0 and frag1 but there may be a ciphertext record - // straddling the boundary. We can handle this with two memcpy() calls to - // assemble the straddling record in between handling the two sides. - var frag = frag0; - var in: usize = 0; - while (true) { - if (in == frag.len) { - // Perfect split. - if (frag.ptr == frag1.ptr) { - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - frag = frag1; - in = 0; - continue; - } - - if (in + tls.record_header_len > frag.len) { - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - const first = frag[in..]; + /// Sends or receives exactly ONE handshake message on `tls_stream`. + /// Sets `self.command` to next expected message. + pub fn next(self: *Handshake) !void { + var s = &self.tls_stream; + s.transcript_hash = &self.transcript_hash; - if (frag1.len < tls.record_header_len) - return finishRead2(c, first, frag1, vp.total); + self.command = switch (self.command) { + .send_hello => brk: { + try self.send_hello(); - // A record straddles the two fragments. Copy into the now-empty first fragment. 
- const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); - const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); - const record_len = (record_len_byte_0 << 8) | record_len_byte_1; - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; + break :brk .recv_hello; + }, + .recv_hello => brk: { + try s.expectInnerPlaintext(.handshake, .server_hello); + try self.recv_hello(); - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); + break :brk .recv_encrypted_extensions; + }, + .recv_encrypted_extensions => brk: { + try s.expectInnerPlaintext(.handshake, .encrypted_extensions); + try self.recv_encrypted_extensions(); - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const ct: tls.ContentType = @enumFromInt(frag[in]); - in += 1; - const legacy_version = mem.readInt(u16, frag[in..][0..2], .big); - in += 2; - _ = legacy_version; - const record_len = mem.readInt(u16, frag[in..][0..2], .big); - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - in += 2; - const end = in + record_len; - if (end > frag.len) { - // We need the record header on the next iteration of the loop. - in -= tls.record_header_len; - - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const first = frag[in..]; - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - switch (ct) { - .alert => { - if (in + 2 > frag.len) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(frag[in]); - const desc: tls.AlertDescription = @enumFromInt(frag[in + 1]); - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; + break :brk .recv_certificate_or_finished; }, - .application_data => { - const cleartext = switch (c.application_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ad = frag[in - 5 ..][0..5]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.server_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.read_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.read_seq))); - break :nonce @as(V, p.server_iv) ^ operand; - }; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, 
ad, nonce, p.server_key) catch - return error.TlsBadRecordMac; - break :c cleartext; + .recv_certificate_or_finished => brk: { + const digest = s.transcript_hash.?.peek(); + const inner_plaintext = try s.readInnerPlaintext(); + if (inner_plaintext.type != .handshake) return s.writeError(.unexpected_message); + switch (inner_plaintext.handshake_type) { + .certificate => { + self.cert = try self.recv_certificate(); + + break :brk .recv_certificate_verify; }, - }; - - c.read_seq = try std.math.add(u64, c.read_seq, 1); - - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - switch (inner_ct) { - .alert => { - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { - c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; + .finished => { + if (self.options.ca_bundle != null) + return self.tls_stream.writeError(.certificate_required); - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len - 1) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. - }, - .key_update => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); - p.server_secret = server_secret; - p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); - p.client_secret = client_secret; - p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len - 1) break; - } - }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == &cleartext_stack_buffer) { - // Stack buffer was used, so we must copy to the output buffer. - const msg = cleartext[0 .. cleartext.len - 1]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. 
- @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len], - msg, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + msg.len); - } else { - const amt = vp.put(msg); - if (amt < msg.len) { - const rest = msg[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); - } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len - 1); - } - }, - else => { - return error.TlsUnexpectedMessage; + try self.recv_finished(digest); + + break :brk .send_finished; }, + else => return self.tls_stream.writeError(.unexpected_message), } }, - else => { - return error.TlsUnexpectedMessage; + .recv_certificate_verify => brk: { + defer self.options.allocator.free(self.cert.certificate.buffer); + + const digest = s.transcript_hash.?.peek(); + try s.expectInnerPlaintext(.handshake, .certificate_verify); + try self.recv_certificate_verify(digest); + + break :brk .recv_finished; }, - } - in = end; - } -} + .recv_finished => brk: { + const digest = s.transcript_hash.?.peek(); + try s.expectInnerPlaintext(.handshake, .finished); + try self.recv_finished(digest); -fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { - const saved_buf = frag[in..]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(saved_buf.len); - @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf); - } - return out; -} + break :brk .send_change_cipher_spec; + }, + .send_change_cipher_spec => brk: { + try s.changeCipherSpec(); -/// Note that `first` usually overlaps with `c.partially_read_buffer`. -fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first); - @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1); - } - return out; -} + break :brk .send_finished; + }, + .send_finished => brk: { + try self.send_finished(); -fn limitedOverlapCopy(frag: []u8, in: usize) void { - const first = frag[in..]; - if (first.len <= in) { - // A single, non-overlapping memcpy suffices. - @memcpy(frag[0..first.len], first); - } else { - // One memcpy call would overlap, so just do this instead. 
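Taken together with the closing arms just below, this switch encodes the client's TLS 1.3 flight order. Summarized as a transition sketch (the server may go straight to Finished without sending a certificate, which is only accepted when no `ca_bundle` was configured):

// send_hello -> recv_hello -> recv_encrypted_extensions -> recv_certificate_or_finished
// if Certificate received: -> recv_certificate_verify -> recv_finished -> send_change_cipher_spec -> send_finished -> none
// if Finished received (no ca_bundle): -> send_finished -> none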
- std.mem.copyForwards(u8, frag, first); + break :brk .none; + }, + .none => .none, + }; } -} -fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { - if (index < s1.len) { - return s1[index]; - } else { - return s2[index - s1.len]; + pub fn send_hello(self: *Handshake) !void { + const hello = tls.ClientHello{ + .random = self.client_random, + .session_id = &self.session_id, + .cipher_suites = self.options.cipher_suites, + .extensions = &.{ + .{ .server_name = &[_]tls.ServerName{.{ .host_name = self.options.host }} }, + .{ .ec_point_formats = &[_]tls.EcPointFormat{.uncompressed} }, + .{ .supported_groups = &tls.supported_groups }, + .{ .signature_algorithms = &tls.supported_signature_schemes }, + .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, + .{ .key_share = &[_]tls.KeyShare{ + .{ .secp256r1 = self.key_pairs.secp256r1.public_key }, + .{ .secp384r1 = self.key_pairs.secp384r1.public_key }, + .{ .x25519 = self.key_pairs.x25519.public_key }, + } }, + }, + }; + + _ = try self.tls_stream.write(tls.Handshake, .{ .client_hello = hello }); + try self.tls_stream.flush(); } -} -const builtin = @import("builtin"); -const native_endian = builtin.cpu.arch.endian(); + pub fn recv_hello(self: *Handshake) !void { + var s = &self.tls_stream; + var r = s.stream().reader(); + + // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. + _ = try s.read(tls.Version); + var random: [32]u8 = undefined; + try r.readNoEof(&random); + if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { + // We already offered all our supported options and we aren't changing them. + return s.writeError(.unexpected_message); + } -inline fn big(x: anytype) @TypeOf(x) { - return switch (native_endian) { - .big => x, - .little => @byteSwap(x), - }; -} + var session_id_buf: [tls.ClientHello.session_id_max_len]u8 = undefined; + const session_id_len = try s.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) + return s.writeError(.illegal_parameter); + const session_id: []u8 = session_id_buf[0..session_id_len]; + try r.readNoEof(session_id); + if (!mem.eql(u8, session_id, &self.session_id)) + return s.writeError(.illegal_parameter); + + const cipher_suite = try s.read(tls.CipherSuite); + const compression_method = try s.read(u8); + if (compression_method != 0) return s.writeError(.illegal_parameter); + + var supported_version: ?tls.Version = null; + var shared_key: ?[]const u8 = null; + + var iter = try s.extensions(); + while (try iter.next()) |ext| { + switch (ext.type) { + .supported_versions => { + if (supported_version != null) return s.writeError(.illegal_parameter); + supported_version = try s.read(tls.Version); + }, + .key_share => { + if (shared_key != null) return s.writeError(.illegal_parameter); + const named_group = try s.read(tls.NamedGroup); + const key_size = try s.read(u16); + switch (named_group) { + .x25519 => { + const T = tls.NamedGroupT(.x25519); + const expected_len = T.public_length; + if (key_size != expected_len) return s.writeError(.illegal_parameter); + var server_ks: [expected_len]u8 = undefined; + try r.readNoEof(&server_ks); + + const mult = crypto.dh.X25519.scalarmult( + self.key_pairs.x25519.secret_key, + server_ks[0..expected_len].*, + ) catch return s.writeError(.illegal_parameter); + shared_key = &mult; + }, + inline .secp256r1, .secp384r1 => |t| { + const T = tls.NamedGroupT(t); + const expected_len = T.PublicKey.uncompressed_sec1_encoded_length; + if (key_size != expected_len) return s.writeError(.illegal_parameter); 
+ + var server_ks: [expected_len]u8 = undefined; + try r.readNoEof(&server_ks); + + const pk = T.PublicKey.fromSec1(&server_ks) catch + return s.writeError(.illegal_parameter); + const key_pair = @field(self.key_pairs, @tagName(t)); + const mult = pk.p.mulPublic(key_pair.secret_key.bytes, .big) catch + return s.writeError(.illegal_parameter); + shared_key = &mult.affineCoordinates().x.toBytes(.big); + }, + // Server sent us back unknown key. That's weird because we only request known ones, + // but we can keep iterating for another. + else => { + try r.skipBytes(key_size, .{}); + }, + } + }, + else => { + try r.skipBytes(ext.len, .{}); + }, + } + } -fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, - .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, - else => @compileError("bad scheme"), - }; -} + if (supported_version != tls.Version.tls_1_3) return s.writeError(.protocol_version); + if (shared_key == null) return s.writeError(.missing_extension); -fn SchemeHash(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, - .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, - .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, - else => @compileError("bad scheme"), - }; -} + s.transcript_hash.?.setActive(cipher_suite); + const hello_hash = s.transcript_hash.?.peek(); -fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .ed25519 => crypto.sign.Ed25519, - else => @compileError("bad scheme"), - }; -} + const handshake_cipher = tls.HandshakeCipher.init( + cipher_suite, + shared_key.?, + hello_hash, + self.logger(), + ) catch return s.writeError(.illegal_parameter); + s.cipher = .{ .handshake = handshake_cipher }; + } + + pub fn recv_encrypted_extensions(self: *Handshake) !void { + var s = &self.tls_stream; + var r = s.stream().reader(); + + var iter = try s.extensions(); + while (try iter.next()) |ext| { + try r.skipBytes(ext.len, .{}); + } + } -/// Abstraction for sending multiple byte buffers to a slice of iovecs. -const VecPut = struct { - iovecs: []const std.os.iovec, - idx: usize = 0, - off: usize = 0, - total: usize = 0, - - /// Returns the amount actually put which is always equal to bytes.len - /// unless the vectors ran out of space. - fn put(vp: *VecPut, bytes: []const u8) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var bytes_i: usize = 0; - while (true) { - const v = vp.iovecs[vp.idx]; - const dest = v.iov_base[vp.off..v.iov_len]; - const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; - @memcpy(dest[0..src.len], src); - bytes_i += src.len; - vp.off += src.len; - if (vp.off >= v.iov_len) { - vp.off = 0; - vp.idx += 1; - if (vp.idx >= vp.iovecs.len) { - vp.total += bytes_i; - return bytes_i; + /// Verifies trust chain if `options.ca_bundle` is specified. + /// + /// Caller owns allocated Certificate.Parsed.certificate. 
+ pub fn recv_certificate(self: *Handshake) !Certificate.Parsed { + var s = &self.tls_stream; + var r = s.stream().reader(); + const allocator = self.options.allocator; + const ca_bundle = self.options.ca_bundle; + const verify = ca_bundle != null; + + var context: [tls.Certificate.max_context_len]u8 = undefined; + const context_len = try s.read(u8); + if (context_len > tls.Certificate.max_context_len) return s.writeError(.decode_error); + try r.readNoEof(context[0..context_len]); + + var first: ?crypto.Certificate.Parsed = null; + errdefer if (first) |f| allocator.free(f.certificate.buffer); + var prev: Certificate.Parsed = undefined; + var verified = false; + const now_sec = std.time.timestamp(); + + var certs_iter = try s.iterator(u24, u24); + while (try certs_iter.next()) |cert_len| { + const is_first = first == null; + + if (verified) { + try r.skipBytes(cert_len, .{}); + } else { + if (cert_len > tls.Certificate.Entry.max_data_len) + return s.writeError(.decode_error); + const buf = allocator.alloc(u8, cert_len) catch + return s.writeError(.internal_error); + defer if (!is_first) allocator.free(buf); + errdefer allocator.free(buf); + try r.readNoEof(buf); + + const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; + const cur = cert.parse() catch return s.writeError(.bad_certificate); + if (first == null) { + if (verify) try cur.verifyHostName(self.options.host); + first = cur; + } else { + if (verify) try prev.verify(cur, now_sec); } + + if (ca_bundle) |b| { + if (b.verify(cur, now_sec)) |_| { + verified = true; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + error.CertificateExpired => return s.writeError(.certificate_expired), + else => return s.writeError(.bad_certificate), + } + } + + prev = cur; } - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } + + var ext_iter = try s.extensions(); + while (try ext_iter.next()) |ext| try r.skipBytes(ext.len, .{}); } - } + if (verify and !verified) return s.writeError(.bad_certificate); - /// Returns the next buffer that consecutive bytes can go into. - fn peek(vp: VecPut) []u8 { - if (vp.idx >= vp.iovecs.len) return &.{}; - const v = vp.iovecs[vp.idx]; - return v.iov_base[vp.off..v.iov_len]; + return if (first) |f| f else s.writeError(.bad_certificate); } - // After writing to the result of peek(), one can call next() to - // advance the cursor. 
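Chain verification in `recv_certificate` runs only when `options.ca_bundle` is non-null; otherwise the certificates are parsed but accepted without building a trust chain, and the host name check is skipped as well. A typical caller therefore loads the system roots first. A sketch, assuming `std.crypto.Certificate.Bundle.rescan` keeps its current behavior and using placeholder `gpa`/`any_stream`:

const tls = std.crypto.tls;

var ca_bundle: std.crypto.Certificate.Bundle = .{};
defer ca_bundle.deinit(gpa);
try ca_bundle.rescan(gpa); // scan the system's root certificate store

var hs = tls.Client.Handshake.init(any_stream, .{
    .host = "example.com",
    .allocator = gpa,
    .ca_bundle = ca_bundle,
});
var client = try hs.handshake();
defer client.close();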
- fn next(vp: *VecPut, len: usize) void { - vp.total += len; - vp.off += len; - if (vp.off >= vp.iovecs[vp.idx].iov_len) { - vp.off = 0; - vp.idx += 1; + pub fn recv_certificate_verify(self: *Handshake, digest: []const u8) !void { + var s = &self.tls_stream; + var r = s.stream().reader(); + const allocator = self.options.allocator; + const cert = self.cert; + + const sig_content = tls.sigContent(digest); + + const scheme = try s.read(tls.SignatureScheme); + const len = try s.read(u16); + if (len > tls.CertificateVerify.max_signature_length) + return s.writeError(.decode_error); + const sig_bytes = allocator.alloc(u8, len) catch + return s.writeError(.internal_error); + defer allocator.free(sig_bytes); + try r.readNoEof(sig_bytes); + + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (cert.pub_key_algo != .X9_62_id_ecPublicKey) + return s.writeError(.bad_certificate); + const Ecdsa = comptime_scheme.Ecdsa(); + const sig = Ecdsa.Signature.fromDer(sig_bytes) catch + return s.writeError(.decode_error); + const key = Ecdsa.PublicKey.fromSec1(cert.pubKey()) catch + return s.writeError(.decode_error); + sig.verify(sig_content, key) catch return s.writeError(.bad_certificate); + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + if (cert.pub_key_algo != .rsaEncryption) + return s.writeError(.bad_certificate); + + const Hash = comptime_scheme.Hash(); + const rsa = Certificate.rsa; + const key = rsa.PublicKey.fromDer(cert.pubKey()) catch + return s.writeError(.bad_certificate); + switch (key.n.bits() / 8) { + inline 128, 256, 512 => |modulus_len| { + const sig = rsa.PSSSignature.fromBytes(modulus_len, sig_bytes); + rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash) catch + return s.writeError(.decode_error); + }, + else => { + return s.writeError(.bad_certificate); + }, + } + }, + inline .ed25519 => |comptime_scheme| { + if (cert.pub_key_algo != .curveEd25519) + return s.writeError(.bad_certificate); + const Eddsa = comptime_scheme.Eddsa(); + if (sig_content.len != Eddsa.Signature.encoded_length) + return s.writeError(.decode_error); + const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); + if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) + return s.writeError(.decode_error); + const key = Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*) catch + return s.writeError(.bad_certificate); + sig.verify(sig_content, key) catch return s.writeError(.bad_certificate); + }, + else => { + return s.writeError(.bad_certificate); + }, } } - fn freeSize(vp: VecPut) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var total: usize = 0; - total += vp.iovecs[vp.idx].iov_len - vp.off; - if (vp.idx + 1 >= vp.iovecs.len) return total; - for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len; - return total; - } -}; + pub fn recv_finished(self: *Handshake, digest: []const u8) !void { + var s = &self.tls_stream; + var r = s.stream().reader(); + const cipher = s.cipher.handshake; -/// Limit iovecs to a specific byte size. -fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec { - var bytes_left: usize = len; - for (iovecs, 0..) |*iovec, vec_i| { - if (bytes_left <= iovec.iov_len) { - iovec.iov_len = bytes_left; - return iovecs[0 .. 
vec_i + 1]; + switch (cipher) { + inline else => |p| { + const P = @TypeOf(p); + const expected = &tls.hmac(P.Hmac, digest, p.server_finished_key); + + var actual: [expected.len]u8 = undefined; + try r.readNoEof(&actual); + if (!mem.eql(u8, expected, &actual)) return s.writeError(.decode_error); + }, } - bytes_left -= iovec.iov_len; } - return iovecs; -} -/// The priority order here is chosen based on what crypto algorithms Zig has -/// available in the standard library as well as what is faster. Following are -/// a few data points on the relative performance of these algorithms. -/// -/// Measurement taken with 0.11.0-dev.810+c2f5848fe -/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -/// aegis-128l: 15382 MiB/s -/// aegis-256: 9553 MiB/s -/// aes128-gcm: 3721 MiB/s -/// aes256-gcm: 3010 MiB/s -/// chacha20Poly1305: 597 MiB/s -/// -/// Measurement taken with 0.11.0-dev.810+c2f5848fe -/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline -/// aegis-128l: 629 MiB/s -/// chacha20Poly1305: 529 MiB/s -/// aegis-256: 461 MiB/s -/// aes128-gcm: 138 MiB/s -/// aes256-gcm: 120 MiB/s -const cipher_suites = if (crypto.core.aes.has_hardware_support) - enum_array(tls.CipherSuite, &.{ - .AEGIS_128L_SHA256, - .AEGIS_256_SHA512, - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - }) -else - enum_array(tls.CipherSuite, &.{ - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - .AEGIS_256_SHA512, - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - }); - -test { - _ = StreamInterface; -} + pub fn send_finished(self: *Handshake) !void { + var s = &self.tls_stream; + + const handshake_hash = s.transcript_hash.?.peek(); + + const verify_data = switch (s.cipher.handshake) { + inline .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, + => |v| brk: { + const T = @TypeOf(v); + const secret = v.client_finished_key; + const transcript_hash = s.transcript_hash.?.peek(); + + break :brk &tls.hmac(T.Hmac, transcript_hash, secret); + }, + else => return s.writeError(.decrypt_error), + }; + s.content_type = .handshake; + _ = try s.write(tls.Handshake, .{ .finished = verify_data }); + try s.flush(); + + const application_cipher = tls.ApplicationCipher.init( + s.cipher.handshake, + handshake_hash, + self.logger(), + ); + s.cipher = .{ .application = application_cipher }; + s.content_type = .application_data; + s.transcript_hash = null; + } + + fn logger(self: *Handshake) tls.KeyLogger { + return tls.KeyLogger{ + .client_random = self.client_random, + .writer = self.options.key_log, + }; + } +}; diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig new file mode 100644 index 000000000000..4935e368071b --- /dev/null +++ b/lib/std/crypto/tls/Server.zig @@ -0,0 +1,559 @@ +const std = @import("../../std.zig"); +const tls = std.crypto.tls; +const net = std.net; +const mem = std.mem; +const crypto = std.crypto; +const io = std.io; +const assert = std.debug.assert; +const Certificate = std.crypto.Certificate; +const Allocator = std.mem.Allocator; + +tls_stream: tls.Stream, +key_logger: tls.KeyLogger, + +pub const Options = struct { + /// List of potential cipher suites in descending order of preference. + cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, + /// Types of shared keys to accept from client. 
+    key_shares: []const tls.NamedGroup = &tls.supported_groups,
+    /// Certificate(s) to send in `send_certificate` messages.
+    /// The first entry will be used for verification.
+    certificate: tls.Certificate = .{},
+    /// Secret key corresponding to `certificate.entries[0]` to use for certificate verification.
+    certificate_key: CertificateKey = .none,
+    /// Writer to log shared secrets for traffic decryption in SSLKEYLOGFILE format.
+    key_log: std.io.AnyWriter = std.io.null_writer.any(),
+
+    pub const CertificateKey = union(enum) {
+        none: void,
+        rsa: crypto.Certificate.rsa.SecretKey,
+        ecdsa256: tls.NamedGroupT(.secp256r1).SecretKey,
+        ecdsa384: tls.NamedGroupT(.secp384r1).SecretKey,
+        ed25519: crypto.sign.Ed25519.SecretKey,
+    };
+};
+
+const Server = @This();
+
+/// Initiates a TLS handshake and establishes a TLSv1.3 session.
+pub fn init(any_stream: std.io.AnyStream, options: Options) !Server {
+    var hs = Handshake.init(any_stream, options);
+    return try hs.handshake();
+}
+
+pub const ReadError = anyerror;
+pub const WriteError = anyerror;
+
+/// Reads the next application_data message.
+pub fn readv(self: *Server, buffers: []const std.os.iovec) ReadError!usize {
+    var s = &self.tls_stream;
+
+    if (s.eof()) return 0;
+
+    while (s.view.len == 0) {
+        const inner_plaintext = try s.readInnerPlaintext();
+        switch (inner_plaintext.type) {
+            .application_data => {},
+            .alert => {},
+            else => return s.writeError(.unexpected_message),
+        }
+    }
+    return try self.tls_stream.readv(buffers);
+}
+
+pub fn writev(self: *Server, iov: []const std.os.iovec_const) WriteError!usize {
+    if (self.tls_stream.eof()) return 0;
+
+    const res = try self.tls_stream.writev(iov);
+    try self.tls_stream.flush();
+    return res;
+}
+
+pub fn close(self: *Server) void {
+    self.tls_stream.close();
+}
+
+pub const GenericStream = std.io.GenericStream(*Server, ReadError, readv, WriteError, writev, close);
+
+pub fn stream(self: *Server) GenericStream {
+    return .{ .context = self };
+}
+
+pub const Handshake = struct {
+    tls_stream: tls.Stream,
+    /// Running hash of handshake messages for cryptographic functions.
+    transcript_hash: tls.MultiHash = .{},
+    options: Options,
+    cert: ?Certificate.Parsed = null,
+
+    server_random: [32]u8,
+    keygen_seed: [tls.NamedGroupT(.secp384r1).KeyPair.seed_length]u8,
+    certificate_verify_salt: [tls.MultiHash.max_digest_len]u8,
+
+    /// Next command to execute.
+    command: Command = .recv_hello,
+
+    /// Defined after `recv_hello`.
+    client_hello: ClientHello = undefined,
+    /// Used to establish a shared secret. Defined after `recv_hello`.
+    key_pair: tls.KeyPair = undefined,
+
+    pub const ClientHello = struct {
+        random: [32]u8,
+        session_id_len: u8,
+        session_id: [32]u8,
+        cipher_suite: tls.CipherSuite,
+        key_share: tls.KeyShare,
+        sig_scheme: tls.SignatureScheme,
+    };
+
+    /// A command to send or receive a single message.
+    pub const Command = enum {
+        recv_hello,
+        send_hello,
+        send_change_cipher_spec,
+        send_encrypted_extensions,
+        send_certificate,
+        send_certificate_verify,
+        send_finished,
+        recv_finished,
+        none,
+    };
+
+    /// Initializes members. Does NOT send any messages to `any_stream`.
+ pub fn init(any_stream: std.io.AnyStream, options: Options) Handshake { + const tls_stream = tls.Stream{ .stream = any_stream, .is_client = false }; + var res = Handshake{ + .tls_stream = tls_stream, + .options = options, + .random = undefined, + .session_id = undefined, + .key_pairs = undefined, + }; + res.init_random(); + + const certificates = options.certificate.entries; + // Verify that the certificate key matches the root certificate. + // This allows failing fast now instead of every client failing signature verification later. + if (certificates.len > 0) { + const cert_buf = Certificate{ .buffer = options.certificate.entries[0].data, .index = 0 }; + res.cert = try cert_buf.parse(); + const expected: std.meta.Tag(Options.CertificateKey) = switch (res.cert.pub_key_algo) { + .rsaEncryption => .rsa, + .X9_62_id_ecPublicKey => |curve| switch (curve) { + .X9_62_prime256v1 => .ecdsa256, + .secp384r1 => .ecdsa384, + else => return error.UnsupportedCertificateSignature, + }, + .curveEd25519 => .ed25519, + }; + if (expected != options.certificate_key) return error.CertificateKeyMismatch; + + // TODO: test private key matches cert public key + // const test_msg = "hello"; + // switch (options.certificate_key) { + // .rsa => |key| { + // switch (key.n.bits() / 8) { + // inline 128, 256, 512 => |modulus_len| { + // const enc = key.public.encrypt(modulus_len, test_msg); + // const dec = key.decrypt(modulus_len, enc) catch return error.CertificateKeyMismatch; + // if (!std.mem.eql(u8, test_msg, dec)) return error.CertificateKeyMismatch; + // }, + // else => return error.CertificateKeyMismatch, + // } + // }, + // inline .ecdsa256, .ecdsa384 => |comptime_scheme| { + // }, + // .ed25519: crypto.sign.Ed25519.SecretKey, + // } + } + + return res; + } + + inline fn init_random(self: *Handshake) void { + crypto.random.bytes(&self.server_random); + crypto.random.bytes(&self.keygen_seed); + crypto.random.bytes(&self.certificate_verify_salt); + } + + /// Executes handshake command and returns next one. 
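The `Command` values above are executed one at a time by `next`, and `handshake` further down simply loops until `.none`. As a rough usage sketch, assuming this file is exposed as `std.crypto.tls.Server` and that `any_stream` and `options` are prepared by the caller, the handshake can also be stepped manually:

    var hs = tls.Server.Handshake.init(any_stream, options);
    while (hs.command != .none) {
        // Each call sends or receives exactly one handshake message.
        try hs.next();
    }
    // `hs.handshake()` wraps this loop and returns the finished `Server`.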
+ pub fn next(self: *Handshake) !void { + var s = &self.tls_stream; + s.transcript_hash = &self.transcript_hash; + + self.command = switch (self.command) { + .recv_hello => brk: { + self.client_hello = try self.recv_hello(); + + break: brk .send_hello; + }, + .send_hello => brk: { + try self.send_hello(); + + // > if the client sends a non-empty session ID, + // > the server MUST send the change_cipher_spec + if (self.client_hello.session_id_len > 0) break :brk .send_change_cipher_spec; + + break :brk .send_encrypted_extensions; + }, + .send_change_cipher_spec => brk: { + try s.changeCipherSpec(); + + break :brk .send_encrypted_extensions; + }, + .send_encrypted_extensions => brk: { + try self.send_encrypted_extensions(); + + break :brk .send_certificate; + }, + .send_certificate => brk: { + try self.send_certificate(); + + break :brk .send_certificate_verify; + }, + .send_certificate_verify => brk: { + try self.send_certificate_verify(); + break :brk .send_finished; + }, + .send_finished => brk: { + try self.send_finished(); + break :brk .recv_finished; + }, + .recv_finished => brk: { + try self.recv_finished(); + break :brk .none; + }, + .none => .none, + }; + } + + pub fn recv_hello(self: *Handshake) !ClientHello { + var s = &self.tls_stream; + var reader = s.stream().reader(); + + try s.expectInnerPlaintext(.handshake, .client_hello); + + _ = try s.read(tls.Version); + var client_random: [32]u8 = undefined; + try reader.readNoEof(&client_random); + + var session_id: [tls.ClientHello.session_id_max_len]u8 = undefined; + const session_id_len = try s.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) + return s.writeError(.illegal_parameter); + try reader.readNoEof(session_id[0..session_id_len]); + + const cipher_suite: tls.CipherSuite = brk: { + var cipher_suite_iter = try s.iterator(u16, tls.CipherSuite); + var res: ?tls.CipherSuite = null; + while (try cipher_suite_iter.next()) |suite| { + for (self.options.cipher_suites) |cs| { + if (cs == suite and res == null) res = cs; + } + } + if (res == null) return s.writeError(.illegal_parameter); + break :brk res.?; + }; + s.transcript_hash.?.setActive(cipher_suite); + + { + var compression_methods: [2]u8 = undefined; + try reader.readNoEof(&compression_methods); + if (!std.mem.eql(u8, &compression_methods, &[_]u8{ 1, 0 })) + return s.writeError(.illegal_parameter); + } + + var tls_version: ?tls.Version = null; + var key_share: ?tls.KeyShare = null; + var ec_point_format: ?tls.EcPointFormat = null; + var sig_scheme: ?tls.SignatureScheme = null; + + var extension_iter = try s.extensions(); + while (try extension_iter.next()) |ext| { + switch (ext.type) { + .supported_versions => { + if (tls_version != null) return s.writeError(.illegal_parameter); + var versions_iter = try s.iterator(u8, tls.Version); + while (try versions_iter.next()) |v| { + if (v == .tls_1_3) tls_version = v; + } + }, + // TODO: use supported_groups instead + .key_share => { + if (key_share != null) return s.writeError(.illegal_parameter); + + var key_share_iter = try s.iterator(u16, tls.KeyShare); + while (try key_share_iter.next()) |ks| { + for (self.options.key_shares) |k| { + if (ks == k and key_share == null) key_share = ks; + } + } + if (key_share == null) return s.writeError(.decode_error); + }, + .ec_point_formats => { + var format_iter = try s.iterator(u8, tls.EcPointFormat); + while (try format_iter.next()) |f| { + if (f == .uncompressed) ec_point_format = .uncompressed; + } + if (ec_point_format == null) return s.writeError(.decode_error); + }, 
+ .signature_algorithms => { + const acceptable = switch (self.options.certificate_key) { + .none => &[_]tls.SignatureScheme{}, // should be all of them + .rsa => &[_]tls.SignatureScheme{ + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha256, + }, + .ecdsa256 => &[_]tls.SignatureScheme{.ecdsa_secp256r1_sha256}, + .ecdsa384 => &[_]tls.SignatureScheme{.ecdsa_secp384r1_sha384}, + .ed25519 => &[_]tls.SignatureScheme{.ed25519}, + }; + var algos_iter = try s.iterator(u16, tls.SignatureScheme); + while (try algos_iter.next()) |algo| { + if (self.options.certificate_key == .none) sig_scheme = algo; + for (acceptable) |a| { + if (algo == a and sig_scheme == null) sig_scheme = algo; + } + } + if (sig_scheme == null) return s.writeError(.decode_error); + }, + else => { + try reader.skipBytes(ext.len, .{}); + }, + } + } + + if (tls_version != .tls_1_3) return s.writeError(.protocol_version); + if (key_share == null) return s.writeError(.missing_extension); + if (ec_point_format == null) return s.writeError(.missing_extension); + if (sig_scheme == null) return s.writeError(.missing_extension); + + self.key_pair = switch (key_share.?) { + inline .secp256r1, + .secp384r1, + .x25519, + => |_, tag| brk: { + const T = tls.NamedGroupT(tag).KeyPair; + const pair = T.create(self.keygen_seed[0..T.seed_length].*) catch unreachable; + break :brk @unionInit(tls.KeyPair, @tagName(tag), pair); + }, + else => return s.writeError(.decode_error), + }; + + return .{ + .random = client_random, + .session_id_len = session_id_len, + .session_id = session_id, + .cipher_suite = cipher_suite, + .key_share = key_share.?, + .sig_scheme = sig_scheme.?, + }; + } + + pub fn send_hello(self: *Handshake) !void { + var s = &self.tls_stream; + const key_pair = self.key_pair; + const client_hello = self.client_hello; + + const hello = tls.ServerHello{ + .random = self.server_random, + .session_id = &client_hello.session_id, + .cipher_suite = client_hello.cipher_suite, + .extensions = &.{ + .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, + .{ .key_share = &[_]tls.KeyShare{key_pair.toKeyShare()} }, + }, + }; + s.version = .tls_1_2; + _ = try s.write(tls.Handshake, .{ .server_hello = hello }); + try s.flush(); + + const shared_key = switch (client_hello.key_share) { + .x25519 => |ks| brk: { + const shared_point = tls.NamedGroupT(.x25519).scalarmult( + key_pair.x25519.secret_key, + ks, + ) catch return s.writeError(.decrypt_error); + break :brk &shared_point; + }, + inline .secp256r1, .secp384r1 => |ks, tag| brk: { + const key = @field(key_pair, @tagName(tag)); + const mul = ks.p.mulPublic(key.secret_key.bytes, .big) catch + return s.writeError(.decrypt_error); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + else => return s.writeError(.illegal_parameter), + }; + + const hello_hash = s.transcript_hash.?.peek(); + const handshake_cipher = tls.HandshakeCipher.init( + client_hello.cipher_suite, + shared_key, + hello_hash, + self.logger(), + ) catch + return s.writeError(.illegal_parameter); + s.cipher = .{ .handshake = handshake_cipher }; + } + + pub fn send_encrypted_extensions(self: *Handshake) !void { + var s = &self.tls_stream; + _ = try s.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); + try s.flush(); + } + + pub fn send_certificate(self: *Handshake) !void { + var s = &self.tls_stream; + _ = try self.tls_stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); + try s.flush(); + } + + pub fn send_certificate_verify(self: *Handshake) !void { + var s = &self.tls_stream; + const salt = 
self.certificate_verify_salt; + const scheme = self.client_hello.sig_scheme; + + const digest = s.transcript_hash.?.peek(); + const sig_content = tls.sigContent(digest); + + const signature: []const u8 = switch (scheme) { + inline .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384 => |comptime_scheme| brk: { + const Ecdsa = comptime_scheme.Ecdsa(); + const key = switch (comptime_scheme) { + .ecdsa_secp256r1_sha256 => self.options.certificate_key.ecdsa256, + .ecdsa_secp384r1_sha384 => self.options.certificate_key.ecdsa384, + else => unreachable, + }; + + var signer = Ecdsa.Signer.init(key, salt[0..Ecdsa.noise_length].*); + signer.update(sig_content); + const sig = signer.finalize() catch return s.writeError(.internal_error); + break :brk &sig.toBytes(); + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| brk: { + const Hash = comptime_scheme.Hash(); + const key = self.options.certificate_key.rsa; + + switch (key.public.n.bits() / 8) { + inline 128, 256, 512 => |modulus_length| { + const sig = Certificate.rsa.PSSSignature.sign( + modulus_length, + sig_content, + Hash, + key, + salt[0..Hash.digest_length].*, + ) catch return s.writeError(.bad_certificate); + break :brk &sig; + }, + else => return s.writeError(.bad_certificate), + } + }, + .ed25519 => brk: { + const Ed25519 = crypto.sign.Ed25519; + const key = self.options.certificate_key.ed25519; + + const pub_key = brk2: { + const cert_buf = Certificate{ .buffer = self.options.certificate.entries[0].data, .index = 0 }; + const cert = try cert_buf.parse(); + const expected_len = Ed25519.PublicKey.encoded_length; + if (cert.pubKey().len != expected_len) return s.writeError(.bad_certificate); + break :brk2 Ed25519.PublicKey.fromBytes(cert.pubKey()[0..expected_len].*) catch + return s.writeError(.bad_certificate); + }; + const nonce: Ed25519.CompressedScalar = salt[0..Ed25519.noise_length].*; + + const key_pair = Ed25519.KeyPair{ .public_key = pub_key, .secret_key = key }; + const sig = key_pair.sign(sig_content, nonce) catch return s.writeError(.internal_error); + break :brk &sig.toBytes(); + }, + else => { + return s.writeError(.bad_certificate); + }, + }; + + _ = try self.tls_stream.write(tls.Handshake, .{ .certificate_verify = tls.CertificateVerify{ + .algorithm = scheme, + .signature = signature, + } }); + try s.flush(); + } + + pub fn send_finished(self: *Handshake) !void { + var s = &self.tls_stream; + const verify_data = switch (s.cipher.handshake) { + inline else => |v| brk: { + const T = @TypeOf(v); + const secret = v.server_finished_key; + const transcript_hash = s.transcript_hash.?.peek(); + + break :brk &tls.hmac(T.Hmac, transcript_hash, secret); + }, + }; + _ = try s.write(tls.Handshake, .{ .finished = verify_data }); + try s.flush(); + } + + pub fn recv_finished(self: *Handshake) !void { + var s = &self.tls_stream; + var reader = s.stream().reader(); + + const handshake_hash = s.transcript_hash.?.peek(); + + const application_cipher = tls.ApplicationCipher.init( + s.cipher.handshake, + handshake_hash, + self.logger(), + ); + + const expected = switch (s.cipher.handshake) { + inline else => |p| brk: { + const P = @TypeOf(p); + const digest = s.transcript_hash.?.peek(); + break :brk &tls.hmac(P.Hmac, digest, p.client_finished_key); + }, + }; + + try s.expectInnerPlaintext(.handshake, .finished); + const actual = s.view; + try reader.skipBytes(s.view.len, .{}); + + if (!mem.eql(u8, expected, actual)) return s.writeError(.decode_error); + + s.content_type = .application_data; + 
s.handshake_type = null;
+        s.cipher = .{ .application = application_cipher };
+        s.transcript_hash = null;
+    }
+
+    /// Establishes a TLS connection on `tls_stream` and returns a `Server`.
+    pub fn handshake(self: *Handshake) !Server {
+        while (self.command != .none) self.next() catch |err| switch (err) {
+            error.ConnectionResetByPeer => {
+                // Prevent replay attacks
+                self.command = .send_hello;
+                self.init_random();
+            },
+            else => return err,
+        };
+
+        return Server{
+            .tls_stream = self.tls_stream,
+            .key_logger = .{
+                .writer = self.options.key_log,
+                .client_random = self.client_hello.random,
+            },
+        };
+    }
+
+    fn logger(self: *Handshake) tls.KeyLogger {
+        return tls.KeyLogger{
+            .client_random = self.client_hello.random,
+            .writer = self.options.key_log,
+        };
+    }
+};
+
diff --git a/lib/std/crypto/tls/Stream.zig b/lib/std/crypto/tls/Stream.zig
new file mode 100644
index 000000000000..140e1802cfac
--- /dev/null
+++ b/lib/std/crypto/tls/Stream.zig
@@ -0,0 +1,539 @@
+//! Abstraction over the TLS record layer (RFC 8446, Section 5).
+//!
+//! After writing, callers must `flush` before reading.
+//!
+//! Handles:
+//! * Fragmentation
+//! * Encryption and decryption of handshake and application data messages
+//! * Reading and writing length-prefixed arrays
+//! * Reading and writing TLS types
+//! * Alerts
+const std = @import("../../std.zig");
+const tls = std.crypto.tls;
+
+inner_stream: std.io.AnyStream,
+/// Used for both reading and writing.
+/// Stores plaintext or briefly ciphertext, but not Plaintext headers.
+buffer: [fragment_size]u8 = undefined,
+/// Unread or unwritten view of `buffer`. May contain multiple handshake messages.
+view: []const u8 = "",
+
+/// When sending, this is the record type that will be flushed.
+/// When receiving, this is the next fragment's expected record type.
+content_type: ContentType = .handshake,
+/// When sending, this is the version written on flushed records.
+version: Version = .tls_1_0,
+/// When receiving a handshake message, this is its expected type.
+handshake_type: ?HandshakeType = .client_hello,
+
+/// Used to encrypt and decrypt messages.
+cipher: Cipher = .none,
+
+/// True when we send or receive a close_notify alert.
+closed: bool = false,
+
+/// True if we're being used as a client. This changes:
+/// * Certain shared struct formats (like Extension)
+/// * Which cipher members are used for encryption/decryption
+is_client: bool,
+
+/// When > 0, writes are discarded. Used to discover prefix lengths.
+nocommit: u32 = 0,
+
+/// Client and server implementations can set this to cause
+/// sent and received handshake messages to update the hash.
+transcript_hash: ?*MultiHash = null, + +const Self = @This(); +const ContentType = tls.ContentType; +const Version = tls.Version; +const HandshakeType = tls.HandshakeType; +const MultiHash = tls.MultiHash; +const Plaintext = tls.Plaintext; +const HandshakeCipher = tls.HandshakeCipher; +const ApplicationCipher = tls.ApplicationCipher; +const Alert = tls.Alert; +const Extension = tls.Extension; + +const fragment_size = Plaintext.max_length; + +const Cipher = union(enum) { + none: void, + application: ApplicationCipher, + handshake: HandshakeCipher, +}; + +pub const ReadError = anyerror || tls.Error || error{EndOfStream}; +pub const WriteError = anyerror || error{TlsEncodeError}; + +fn ciphertextOverhead(self: Self) usize { + return switch (self.cipher) { + inline .application, .handshake => |c| switch (c) { + inline else => |t| @TypeOf(t).AEAD.tag_length + @sizeOf(ContentType), + }, + else => 0, + }; +} + +fn maxFragmentSize(self: Self) usize { + return self.buffer.len - self.ciphertextOverhead(); +} + +const EncryptionMethod = enum { none, handshake, application }; +fn encryptionMethod(self: Self, content_type: ContentType) EncryptionMethod { + switch (content_type) { + .alert, .change_cipher_spec => {}, + else => { + if (self.cipher == .application) return .application; + if (self.cipher == .handshake) return .handshake; + }, + } + return .none; +} + +pub fn flush(self: *Self) WriteError!void { + if (self.view.len == 0) return; + if (self.transcript_hash) |t| { + if (self.content_type == .handshake) t.update(self.view); + } + + var plaintext = Plaintext{ + .type = self.content_type, + .version = self.version, + .len = @intCast(self.view.len), + }; + + var header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); + var aead: []const u8 = ""; + switch (self.cipher) { + .none => {}, + inline .application, .handshake => |*cipher| { + plaintext.type = .application_data; + plaintext.len += @intCast(self.ciphertextOverhead()); + header = Encoder.encode(Plaintext, plaintext); + switch (cipher.*) { + inline else => |*c| { + std.debug.assert(self.view.ptr == &self.buffer); + self.buffer[self.view.len] = @intFromEnum(self.content_type); + self.view = self.buffer[0 .. self.view.len + 1]; + aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); + }, + } + }, + } + + // TODO: contiguous buffer management + try self.inner_stream.writer().writeAll(&header); + try self.inner_stream.writer().writeAll(self.view); + try self.inner_stream.writer().writeAll(aead); + self.view = self.buffer[0..0]; +} + +/// Flush a change cipher spec message to the underlying stream. +pub fn changeCipherSpec(self: *Self) !void { + self.version = .tls_1_2; + + const plaintext = Plaintext{ + .type = .change_cipher_spec, + .version = self.version, + .len = 1, + }; + const msg = [_]u8{1}; + const header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); + // TODO: contiguous buffer management + try self.inner_stream.writer().writeAll(&header); + try self.inner_stream.writer().writeAll(&msg); +} + +/// Write an alert to stream and call `close_notify` after. Returns Zig error. 
+pub fn writeError(self: *Self, err: Alert.Description) tls.Error {
+    const alert = Alert{ .level = .fatal, .description = err };
+
+    self.view = self.buffer[0..0];
+    self.content_type = .alert;
+    _ = self.write(Alert, alert) catch {};
+    self.flush() catch {};
+
+    self.close();
+    return err.toError();
+}
+
+pub fn close(self: *Self) void {
+    const alert = Alert{ .level = .fatal, .description = .close_notify };
+    _ = self.write(Alert, alert) catch {};
+    self.content_type = .alert;
+    self.flush() catch {};
+    self.closed = true;
+}
+
+/// Write bytes to `stream`, potentially flushing once `self.buffer` is full.
+pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize {
+    const first = iov[0];
+    const bytes = first.iov_base[0..first.iov_len];
+    if (self.nocommit > 0) return bytes.len;
+
+    const available = self.buffer.len - self.view.len;
+    const to_consume = bytes[0..@min(available, bytes.len)];
+
+    @memcpy(self.buffer[self.view.len..][0..to_consume.len], to_consume);
+    self.view = self.buffer[0 .. self.view.len + to_consume.len];
+
+    if (self.view.len == self.buffer.len) try self.flush();
+
+    return to_consume.len;
+}
+
+pub fn writeArray(self: *Self, comptime PrefixT: type, comptime T: type, values: []const T) !usize {
+    var res: usize = 0;
+    for (values) |v| res += self.length(T, v);
+
+    if (PrefixT != void) {
+        if (res > std.math.maxInt(PrefixT)) {
+            self.close();
+            return error.TlsEncodeError; // Prefix length overflow
+        }
+        res += try self.write(PrefixT, @intCast(res));
+    }
+
+    for (values) |v| _ = try self.write(T, v);
+
+    return res;
+}
+
+/// Returns the number of bytes written. Convenient for encoding struct types in tls.zig.
+pub fn writeAll(self: *Self, bytes: []const u8) !usize {
+    try self.stream().writer().writeAll(bytes);
+    return bytes.len;
+}
+
+pub fn write(self: *Self, comptime T: type, value: T) !usize {
+    switch (@typeInfo(T)) {
+        .Int, .Enum => {
+            const encoded = Encoder.encode(T, value);
+            try self.stream().writer().writeAll(&encoded);
+            return encoded.len;
+        },
+        .Struct, .Union => {
+            return try T.write(value, self);
+        },
+        .Void => return 0,
+        else => @compileError("cannot write " ++ @typeName(T)),
+    }
+}
+
+pub fn length(self: *Self, comptime T: type, value: T) usize {
+    if (T == void) return 0;
+    self.nocommit += 1;
+    defer self.nocommit -= 1;
+    return self.write(T, value) catch unreachable;
+}
+
+pub fn arrayLength(
+    self: *Self,
+    comptime PrefixT: type,
+    comptime T: type,
+    values: []const T,
+) usize {
+    var res: usize = if (PrefixT == void) 0 else @divExact(@typeInfo(PrefixT).Int.bits, 8);
+    for (values) |v| res += self.length(T, v);
+    return res;
+}
+
+/// Reads bytes from `view`, potentially reading more fragments from the underlying `stream`.
+///
+/// A return value of 0 indicates EOF.
+pub fn readv(self: *Self, iov: []const std.os.iovec) ReadError!usize {
+    // > Any data received after a closure alert has been received MUST be ignored.
+ if (self.eof()) return 0; + + if (self.view.len == 0) try self.expectInnerPlaintext(self.content_type, self.handshake_type); + + var bytes_read: usize = 0; + + for (iov) |b| { + var bytes_read_buffer: usize = 0; + while (bytes_read_buffer != b.iov_len and !self.eof()) { + const to_read = @min(b.iov_len, self.view.len); + if (to_read == 0) return bytes_read; + + @memcpy(b.iov_base[0..to_read], self.view[0..to_read]); + + self.view = self.view[to_read..]; + bytes_read_buffer += to_read; + bytes_read += bytes_read_buffer; + } + } + + return bytes_read; +} + +/// Reads plaintext from `stream` into `buffer` and updates `view`. +/// Skips non-fatal alert and change_cipher_spec messages. +/// Will decrypt according to `encryptionMethod` if receiving application_data message. +pub fn readPlaintext(self: *Self) !Plaintext { + std.debug.assert(self.view.len == 0); // last read should have completed + var plaintext_bytes: [Plaintext.size]u8 = undefined; + var n_read: usize = 0; + + while (true) { + n_read = try self.inner_stream.reader().readAll(&plaintext_bytes); + if (n_read != plaintext_bytes.len) return self.writeError(.decode_error); + + var res = Plaintext.init(plaintext_bytes); + if (res.len > Plaintext.max_length) return self.writeError(.record_overflow); + + self.view = self.buffer[0..res.len]; + n_read = try self.inner_stream.reader().readAll(@constCast(self.view)); + if (n_read != res.len) return self.writeError(.decode_error); + + const encryption_method = self.encryptionMethod(res.type); + if (encryption_method != .none) { + if (res.len < self.ciphertextOverhead()) return self.writeError(.decode_error); + + switch (self.cipher) { + inline .handshake, .application => |*cipher| switch (cipher.*) { + inline else => |*c| { + const C = @TypeOf(c.*); + const tag_len = C.AEAD.tag_length; + + const ciphertext = self.view[0 .. self.view.len - tag_len]; + const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; + const out: []u8 = @constCast(self.view[0..ciphertext.len]); + c.decrypt(ciphertext, &plaintext_bytes, tag, self.is_client, out) catch + return self.writeError(.bad_record_mac); + + const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); + if (padding_start) |s| { + res.type = @enumFromInt(self.view[s]); + self.view = self.view[0..s]; + } else { + return self.writeError(.decode_error); + } + }, + }, + else => unreachable, + } + } + + switch (res.type) { + .alert => { + const level = try self.read(Alert.Level); + const description = try self.read(Alert.Description); + std.log.debug("TLS alert {} {}", .{ level, description }); + + switch (description) { + .close_notify => { + self.closed = true; + return res; + }, + .certificate_revoked, .certificate_unknown, .certificate_expired, .certificate_required => {}, + else => { + return self.writeError(.unexpected_message); + }, + } + }, + // > An implementation may receive an unencrypted record of type + // > change_cipher_spec consisting of the single byte value 0x01 at any + // > time after the first ClientHello message has been sent or received + // > and before the peer's Finished message has been received and MUST + // > simply drop it without further processing. 
+ .change_cipher_spec => { + if (!std.mem.eql(u8, self.view, &[_]u8{1})) { + return self.writeError(.unexpected_message); + } + }, + else => { + return res; + }, + } + } +} + +pub fn readInnerPlaintext(self: *Self) !InnerPlaintext { + var res: InnerPlaintext = .{ + .type = self.content_type, + .handshake_type = if (self.handshake_type) |h| h else undefined, + .len = 0, + }; + if (self.closed) return res; + + if (self.view.len == 0) { + const plaintext = try self.readPlaintext(); + if (self.closed) return res; + + res.type = plaintext.type; + res.len = plaintext.len; + + self.content_type = res.type; + } + + if (res.type == .handshake) { + if (self.transcript_hash) |t| t.update(self.view[0..4]); + res.handshake_type = try self.read(HandshakeType); + res.len = try self.read(u24); + if (self.transcript_hash) |t| t.update(self.view[0..res.len]); + + self.handshake_type = res.handshake_type; + } + + return res; +} + +pub fn expectInnerPlaintext( + self: *Self, + expected_content: ContentType, + expected_handshake: ?HandshakeType, +) !void { + const inner_plaintext = try self.readInnerPlaintext(); + if (expected_content != inner_plaintext.type) return self.writeError(.unexpected_message); + if (expected_handshake) |expected| { + if (expected != inner_plaintext.handshake_type) return self.writeError(.decode_error); + } +} + +pub fn read(self: *Self, comptime T: type) !T { + comptime std.debug.assert(@sizeOf(T) < fragment_size); + switch (@typeInfo(T)) { + .Int => return self.stream().reader().readInt(T, .big) catch |err| switch (err) { + error.EndOfStream => return self.writeError(.decode_error), + else => |e| return e, + }, + .Enum => |info| { + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + const int = try self.read(info.tag_type); + return @enumFromInt(int); + }, + else => { + return T.read(self) catch |err| switch (err) { + error.TlsUnexpectedMessage => return self.writeError(.unexpected_message), + error.TlsBadRecordMac => return self.writeError(.bad_record_mac), + error.TlsRecordOverflow => return self.writeError(.record_overflow), + error.TlsHandshakeFailure => return self.writeError(.handshake_failure), + error.TlsBadCertificate => return self.writeError(.bad_certificate), + error.TlsUnsupportedCertificate => return self.writeError(.unsupported_certificate), + error.TlsCertificateRevoked => return self.writeError(.certificate_revoked), + error.TlsCertificateExpired => return self.writeError(.certificate_expired), + error.TlsCertificateUnknown => return self.writeError(.certificate_unknown), + error.TlsIllegalParameter => return self.writeError(.illegal_parameter), + error.TlsUnknownCa => return self.writeError(.unknown_ca), + error.TlsAccessDenied => return self.writeError(.access_denied), + error.TlsDecodeError => return self.writeError(.decode_error), + error.TlsDecryptError => return self.writeError(.decrypt_error), + error.TlsProtocolVersion => return self.writeError(.protocol_version), + error.TlsInsufficientSecurity => return self.writeError(.insufficient_security), + error.TlsInternalError => return self.writeError(.internal_error), + error.TlsInappropriateFallback => return self.writeError(.inappropriate_fallback), + error.TlsMissingExtension => return self.writeError(.missing_extension), + error.TlsUnsupportedExtension => return self.writeError(.unsupported_extension), + error.TlsUnrecognizedName => return self.writeError(.unrecognized_name), + error.TlsBadCertificateStatusResponse => return self.writeError(.bad_certificate_status_response), + 
error.TlsUnknownPskIdentity => return self.writeError(.unknown_psk_identity), + error.TlsCertificateRequired => return self.writeError(.certificate_required), + error.TlsNoApplicationProtocol => return self.writeError(.no_application_protocol), + error.TlsUnknown => |e| { + self.close(); + return e; + }, + else => return self.writeError(.decode_error), + }; + }, + } +} + +fn Iterator(comptime T: type) type { + return struct { + stream: *Self, + end: usize, + + pub fn next(self: *@This()) !?T { + const cur_offset = self.stream.buffer.len - self.stream.view.len; + if (cur_offset > self.end) return null; + return try self.stream.read(T); + } + }; +} + +pub fn iterator(self: *Self, comptime Len: type, comptime Tag: type) !Iterator(Tag) { + const offset = self.buffer.len - self.view.len; + const len = try self.read(Len); + return Iterator(Tag){ + .stream = self, + .end = offset + len, + }; +} + +pub fn extensions(self: *Self) !Iterator(Extension.Header) { + return self.iterator(u16, Extension.Header); +} + +pub fn eof(self: Self) bool { + return self.closed and self.view.len == 0; +} + +pub const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); + +pub fn stream(self: *Self) GenericStream { + return .{ .context = self }; +} + +const Encoder = struct { + fn RetType(comptime T: type) type { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => return [1]u8, + 16 => return [2]u8, + 24 => return [3]u8, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return RetType(info.tag_type); + }, + .Struct => |info| { + var len: usize = 0; + inline for (info.fields) |f| len += @typeInfo(RetType(f.type)).Array.len; + return [len]u8; + }, + else => @compileError("don't know how to encode " ++ @tagName(T)), + } + } + fn encode(comptime T: type, value: T) RetType(T) { + return switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => .{value}, + 16 => .{ + @as(u8, @truncate(value >> 8)), + @as(u8, @truncate(value)), + }, + 24 => .{ + @as(u8, @truncate(value >> 16)), + @as(u8, @truncate(value >> 8)), + @as(u8, @truncate(value)), + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| encode(info.tag_type, @intFromEnum(value)), + .Struct => |info| brk: { + const Ret = RetType(T); + + var offset: usize = 0; + var res: Ret = undefined; + inline for (info.fields) |f| { + const encoded = encode(f.type, @field(value, f.name)); + @memcpy(res[offset..][0..encoded.len], &encoded); + offset += encoded.len; + } + + break :brk res; + }, + else => @compileError("cannot encode type " ++ @typeName(T)), + }; + } +}; + +const InnerPlaintext = struct { + type: ContentType, + handshake_type: HandshakeType, + len: u24, +}; diff --git a/lib/std/fifo.zig b/lib/std/fifo.zig index a26086700258..6fe588bcfeea 100644 --- a/lib/std/fifo.zig +++ b/lib/std/fifo.zig @@ -38,8 +38,8 @@ pub fn LinearFifo( count: usize, const Self = @This(); - pub const Reader = std.io.Reader(*Self, error{}, readFn); - pub const Writer = std.io.Writer(*Self, error{OutOfMemory}, appendWrite); + pub const Reader = std.io.Reader(*Self, error{}, readvFn); + pub const Writer = std.io.Writer(*Self, error{OutOfMemory}, appendWritev); // Type of Self argument for slice operations. 
// If buffer is inline (Static) then we need to ensure we haven't @@ -232,8 +232,15 @@ pub fn LinearFifo( /// Same as `read` except it returns an error union /// The purpose of this function existing is to match `std.io.Reader` API. - fn readFn(self: *Self, dest: []u8) error{}!usize { - return self.read(dest); + fn readvFn(self: *Self, iov: []const std.os.iovec) error{}!usize { + var n_read: usize = 0; + for (iov) |v| { + const n = self.read(v.iov_base[0..v.iov_len]); + if (n == 0) return n_read; + + n_read += n; + } + return n_read; } pub fn reader(self: *Self) Reader { @@ -321,9 +328,13 @@ pub fn LinearFifo( /// Same as `write` except it returns the number of bytes written, which is always the same /// as `bytes.len`. The purpose of this function existing is to match `std.io.Writer` API. - fn appendWrite(self: *Self, bytes: []const u8) error{OutOfMemory}!usize { - try self.write(bytes); - return bytes.len; + fn appendWritev(self: *Self, iov: []const std.os.iovec_const) error{OutOfMemory}!usize { + var written: usize = 0; + for (iov) |v| { + try self.write(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } pub fn writer(self: *Self) Writer { diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index 669f1b72e33f..61fa536e36a0 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1441,13 +1441,13 @@ fn writeFileAllSendfile(self: File, in_file: File, args: WriteFileOptions) posix } } -pub const Reader = io.Reader(File, ReadError, read); +pub const Reader = io.Reader(File, ReadError, readv); pub fn reader(file: File) Reader { return .{ .context = file }; } -pub const Writer = io.Writer(File, WriteError, write); +pub const Writer = io.Writer(File, WriteError, writev); pub fn writer(file: File) Writer { return .{ .context = file }; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 339afdb96e91..95f3302be420 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1,8 +1,6 @@ -//! HTTP(S) Client implementation. +//! Blocking HTTP(S) client //! //! Connections are opened in a thread-safe manner, but individual Requests are not. -//! -//! TLS support may be disabled via `std.options.http_disable_tls`. const std = @import("../std.zig"); const builtin = @import("builtin"); @@ -17,18 +15,12 @@ const use_vectors = builtin.zig_backend != .stage2_x86_64; const Client = @This(); const proto = @import("protocol.zig"); - -pub const disable_tls = std.options.http_disable_tls; +const tls = std.crypto.tls; /// Used for all client allocations. Must be thread-safe. allocator: Allocator, -ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, -ca_bundle_mutex: std.Thread.Mutex = .{}, - -/// When this is `true`, the next time this client performs an HTTPS request, -/// it will first rescan the system for root certificates. -next_https_rescan_certs: bool = true, +tls_options: TlsOptions = .{}, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, @@ -42,6 +34,17 @@ http_proxy: ?*Proxy = null, /// Pointer to externally-owned memory. https_proxy: ?*Proxy = null, +/// tls.Client.Options minus ones that we set +pub const TlsOptions = struct { + /// Client takes ownership of this field. If empty, will rescan on init. + /// + /// Trusted certificate authority bundle used to authenticate server certificates. + /// When null, server certificate and certificate_verify messages will be skipped (INSECURE). 
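With the rescan-on-first-request logic removed, callers now construct the client through `init` (added further down in this diff), which fills an empty `ca_bundle`. A minimal sketch of building a client with these options (`allocator` assumed to be provided by the caller):

    var client = try std.http.Client.init(.{
        .allocator = allocator,
        // Default `.{}` bundle is rescanned by `init`; setting `ca_bundle`
        // to `null` skips server verification entirely (INSECURE).
        .tls_options = .{},
    });
    defer client.deinit();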
+ ca_bundle: ?std.crypto.Certificate.Bundle = .{}, + /// List of cipher suites to advertise in order of descending preference. + cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, +}; + /// A set of linked lists of connections that can be reused. pub const ConnectionPool = struct { mutex: std.Thread.Mutex = .{}, @@ -149,8 +152,6 @@ pub const ConnectionPool = struct { pool.mutex.lock(); defer pool.mutex.unlock(); - const next = pool.free.first; - _ = next; while (pool.free_len > new_size) { const popped = pool.free.popFirst() orelse unreachable; pool.free_len -= 1; @@ -190,9 +191,10 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { - stream: net.Stream, - /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *std.crypto.tls.Client else void, + /// Underlying socket + socket: net.Socket, + /// TLS client. + tls: tls.Client, /// The protocol that this connection is using. protocol: Protocol, @@ -215,33 +217,21 @@ pub const Connection = struct { read_buf: [buffer_size]u8 = undefined, write_buf: [buffer_size]u8 = undefined, - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + /// Want to be greater than max HTTP headers length. + pub const buffer_size = 4096; const BufferSize = std.math.IntFittingRange(0, buffer_size); pub const Protocol = enum { plain, tls }; - pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - return conn.tls_client.readv(conn.stream, buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; - - switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } + pub inline fn stream(conn: *Connection) std.io.AnyStream { + return switch (conn.protocol) { + .plain => conn.socket.stream().any(), + .tls => conn.tls.any().any(), }; } pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); - } - - return conn.stream.readv(buffers) catch |err| switch (err) { + return conn.stream().readv(buffers) catch |err| switch (err) { error.ConnectionTimedOut => return error.ConnectionTimedOut, error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, else => return error.UnexpectedReadFailure, @@ -303,48 +293,36 @@ pub const Connection = struct { return nread; } + pub fn readv(conn: *Connection, iov: []const std.os.iovec) ReadError!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; + return try conn.read(buffer); + } + pub const ReadError = error{ TlsFailure, - TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure, EndOfStream, }; - pub const Reader = std.io.Reader(*Connection, ReadError, read); + pub const Reader = std.io.Reader(*Connection, ReadError, readv); pub fn reader(conn: *Connection) Reader { return Reader{ .context = conn }; } - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch 
(err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); - } - - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - /// Writes the given buffer to the connection. pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { if (conn.write_buf.len - conn.write_end < buffer.len) { try conn.flush(); if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); + conn.stream().writer().writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; return buffer.len; } } @@ -355,6 +333,12 @@ pub const Connection = struct { return buffer.len; } + pub fn writev(conn: *Connection, iov: []const std.os.iovec_const) WriteError!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; + return try conn.write(buffer); + } + /// Returns a buffer to be filled with exactly len bytes to write to the connection. pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { if (conn.write_buf.len - conn.write_end < len) try conn.flush(); @@ -366,32 +350,26 @@ pub const Connection = struct { pub fn flush(conn: *Connection) WriteError!void { if (conn.write_end == 0) return; - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); + try conn.stream().writer().writeAll(conn.write_buf[0..conn.write_end]); conn.write_end = 0; + + if (conn.protocol == .tls) try conn.tls.stream.flush(); } - pub const WriteError = error{ + pub const WriteError = anyerror || error{ ConnectionResetByPeer, UnexpectedWriteFailure, }; - pub const Writer = std.io.Writer(*Connection, WriteError, write); + pub const Writer = std.io.Writer(*Connection, WriteError, writev); pub fn writer(conn: *Connection) Writer { return Writer{ .context = conn }; } - /// Closes the connection. + /// Closes the connection and deinitializes members. pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - // try to cleanly close the TLS connection, for any server that cares. 
- _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; - allocator.destroy(conn.tls_client); - } - - conn.stream.close(); + conn.stream().close(); allocator.free(conn.host); } }; @@ -919,14 +897,16 @@ pub const Request = struct { const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); + const TransferReader = std.io.Reader(*Request, TransferReadError, transferReadv); fn transferReader(req: *Request) TransferReader { return .{ .context = req }; } - fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { + fn transferReadv(req: *Request, iov: []const std.os.iovec) TransferReadError!usize { if (req.response.parser.done) return 0; + const first = iov[0]; + const buf = first.iov_base[0..first.iov_len]; var index: usize = 0; while (index == 0) { @@ -1031,7 +1011,11 @@ pub const Request = struct { // skip the body of the redirect response, this will at least // leave the connection in a known good state. req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary + var buf: [0]u8 = undefined; + // we're skipping, no buffer is necessary + const iovecs = [_]std.os.iovec{.{ .iov_base = &buf, .iov_len = 0 }}; + const n_read = try req.transferReadv(&iovecs); + assert(n_read == 0); if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; @@ -1115,20 +1099,20 @@ pub const Request = struct { pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers }; - pub const Reader = std.io.Reader(*Request, ReadError, read); + pub const Reader = std.io.Reader(*Request, ReadError, readv); pub fn reader(req: *Request) Reader { return .{ .context = req }; } /// Reads data from the response body. Must be called after `wait`. - pub fn read(req: *Request, buffer: []u8) ReadError!usize { + pub fn readv(req: *Request, iov: []const std.os.iovec) ReadError!usize { const out_index = switch (req.response.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, + .deflate => |*deflate| deflate.readv(iov) catch return error.DecompressionFailure, + .gzip => |*gzip| gzip.readv(iov) catch return error.DecompressionFailure, // https://github.com/ziglang/zig/issues/18937 //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), + else => try req.transferReadv(iov), }; if (out_index > 0) return out_index; @@ -1142,20 +1126,9 @@ pub const Request = struct { return 0; } - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - pub const Writer = std.io.Writer(*Request, WriteError, write); + pub const Writer = std.io.Writer(*Request, WriteError, writev); pub fn writer(req: *Request) Writer { return .{ .context = req }; @@ -1163,21 +1136,28 @@ pub const Request = struct { /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. /// Must be called after `send` and before `finish`. 
- pub fn write(req: *Request, bytes: []const u8) WriteError!usize { + pub fn writev(req: *Request, iov: []const std.os.iovec_const) WriteError!usize { + var iov_len: usize = 0; + for (iov) |v| iov_len += v.iov_len; + switch (req.transfer_encoding) { .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); + var w = req.connection.?.writer(); + + if (iov_len > 0) { + try w.print("{x}\r\n", .{iov_len}); + for (iov) |v| try w.writeAll(v.iov_base[0..v.iov_len]); + try w.writeAll("\r\n"); } - return bytes.len; + return iov_len; }, .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; + const cwriter = req.connection.?.writer(); - const amt = try req.connection.?.write(bytes); + if (len.* < iov_len) return error.MessageTooLong; + + const amt = try cwriter.writev(iov); len.* -= amt; return amt; }, @@ -1185,15 +1165,6 @@ pub const Request = struct { } } - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } - } - pub const FinishError = WriteError || error{MessageNotCompleted}; /// Finish the body of a request. This notifies the server that you have no more data to send. @@ -1217,6 +1188,17 @@ pub const Proxy = struct { supports_connect: bool, }; +/// Initializes ca_bundle if it's not null and empty. +pub fn init(client: Client) !Client { + var copy = client; + + if (copy.tls_options.ca_bundle) |*bundle| { + if (bundle.bytes.items.len == 0) try bundle.rescan(copy.allocator); + } + + return copy; +} + /// Release all associated resources with the client. /// /// All pending requests must be de-initialized and all active connections released @@ -1226,8 +1208,7 @@ pub fn deinit(client: *Client) void { client.connection_pool.deinit(client.allocator); - if (!disable_tls) - client.ca_bundle.deinit(client.allocator); + if (client.tls_options.ca_bundle) |*bundle| bundle.deinit(client.allocator); client.* = undefined; } @@ -1327,7 +1308,17 @@ pub const basic_authorization = struct { } }; -pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; +pub const ConnectTcpError = Allocator.Error || tls.Client.ReadError || tls.Client.WriteError || error{ + ConnectionRefused, + NetworkUnreachable, + ConnectionTimedOut, + ConnectionResetByPeer, + TemporaryNameServerFailure, + NameServerFailure, + UnknownHostName, + HostLacksNetworkAddresses, + UnexpectedConnectFailure, +}; /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. 
/// @@ -1340,14 +1331,11 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec })) |node| return node; - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - const conn = try client.allocator.create(ConnectionPool.Node); errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; - const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { + const socket = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { error.ConnectionRefused => return error.ConnectionRefused, error.NetworkUnreachable => return error.NetworkUnreachable, error.ConnectionTimedOut => return error.ConnectionTimedOut, @@ -1358,12 +1346,11 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses, else => return error.UnexpectedConnectFailure, }; - errdefer stream.close(); - - conn.data = .{ - .stream = stream, - .tls_client = undefined, + errdefer socket.close(); + conn.data = Connection{ + .socket = socket, + .tls = undefined, .protocol = protocol, .host = try client.allocator.dupe(u8, host), .port = port, @@ -1371,15 +1358,15 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec errdefer client.allocator.free(conn.data.host); if (protocol == .tls) { - if (disable_tls) unreachable; - - conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); - errdefer client.allocator.destroy(conn.data.tls_client); - - conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - conn.data.tls_client.allow_truncation_attacks = true; + conn.data.tls = try tls.Client.init(conn.data.socket.stream().any(), .{ + .ca_bundle = client.tls_options.ca_bundle, + .cipher_suites = client.tls_options.cipher_suites, + .host = host, + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. 
+ .allow_truncation_attacks = true, + .allocator = client.allocator, + }); } client.connection_pool.addUsed(conn); @@ -1408,8 +1395,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti errdefer stream.close(); conn.data = .{ - .stream = stream, - .tls_client = undefined, + .stream = stream.stream(), .protocol = .plain, .host = try client.allocator.dupe(u8, path), @@ -1551,7 +1537,6 @@ pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendE UnsupportedUrlScheme, UriMissingHost, - CertificateBundleLoadFailure, UnsupportedTransferEncoding, }; @@ -1642,18 +1627,6 @@ pub fn open( const host = uri.host orelse return error.UriMissingHost; - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { - if (disable_tls) unreachable; - - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); - } - } - const conn = options.connection orelse try client.connect(host, port, protocol); var req: Request = .{ @@ -1753,7 +1726,7 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { try req.send(.{ .raw_uri = options.raw_uri }); - if (options.payload) |payload| try req.writeAll(payload); + if (options.payload) |payload| try req.writer().writeAll(payload); try req.finish(); try req.wait(); @@ -1763,7 +1736,7 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { // Take advantage of request internals to discard the response body // and make the connection available for another request. req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. + assert(try req.transferReadv(&.{}) == 0); // No buffer is necessary when skipping. }, .dynamic => |list| { const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 5290241b6e04..8fb150e3adf0 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,6 +1,17 @@ -//! Blocking HTTP server implementation. +//! Blocking HTTP(s) server +//! //! Handles a single connection's lifecycle. 
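A minimal usage sketch for this server (not part of the patch): `connection` is assumed to be a freshly accepted `std.net.Server.Connection`, the `(connection, read_buffer)` shape of `init` follows the existing `std.http.Server` API, and `RespondOptions` is left at its defaults.

    var read_buffer: [8192]u8 = undefined;
    var server = std.http.Server.init(connection, &read_buffer);
    var request = try server.receiveHead();
    // `respond` writes the status line, headers, and body in one call.
    try request.respond("Hello from std.http.Server\n", .{});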
+const std = @import("../std.zig"); +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = std.Uri; +const assert = std.debug.assert; +const testing = std.testing; + +const Server = @This(); + connection: net.Server.Connection, /// Keeps track of whether the Server is ready to accept a new request on the /// same connection, and makes invalid API usage cause assertion failures @@ -89,7 +100,7 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request { const buf = s.read_buffer[s.read_buffer_len..]; if (buf.len == 0) return error.HttpHeadersOversize; - const read_n = s.connection.stream.read(buf) catch + const read_n = s.connection.stream().reader().read(buf) catch return error.HttpHeadersUnreadable; if (read_n == 0) { if (s.read_buffer_len > 0) { @@ -391,7 +402,7 @@ pub const Request = struct { request: *Request, content: []const u8, options: RespondOptions, - ) Response.WriteError!void { + ) !void { const max_extra_headers = 25; assert(options.status != .@"continue"); assert(options.extra_headers.len <= max_extra_headers); @@ -418,7 +429,7 @@ pub const Request = struct { h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - try request.server.connection.stream.writeAll(h.items); + try request.server.connection.stream().writer().writeAll(h.items); return; } h.fixedWriter().print("{s} {d} {s}\r\n", .{ @@ -524,7 +535,7 @@ pub const Request = struct { } } - try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); + try request.server.connection.stream().writer().writevAll(iovecs[0..iovecs_len]); } pub const RespondStreamingOptions = struct { @@ -604,7 +615,7 @@ pub const Request = struct { }; return .{ - .stream = request.server.connection.stream, + .connection = request.server.connection, .send_buffer = options.send_buffer, .send_buffer_start = 0, .send_buffer_end = h.items.len, @@ -619,12 +630,14 @@ pub const Request = struct { }; } - pub const ReadError = net.Stream.ReadError || error{ + pub const ReadError = anyerror || net.Socket.ReadError || error{ HttpChunkInvalid, HttpHeadersOversize, }; - fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { + fn read_cl(context: *const anyopaque, iov: []const std.os.iovec) ReadError!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; const request: *Request = @constCast(@alignCast(@ptrCast(context))); const s = request.server; @@ -648,11 +661,13 @@ pub const Request = struct { const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; if (available.len > 0) return available; s.next_request_start = head_end; - s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); + s.read_buffer_len = head_end + try s.connection.stream().reader().read(s.read_buffer[head_end..]); return s.read_buffer[head_end..s.read_buffer_len]; } - fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { + fn readv_chunked(context: *const anyopaque, iov: []const std.os.iovec) ReadError!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; const request: *Request = @constCast(@alignCast(@ptrCast(context))); const s = request.server; @@ -710,7 +725,7 @@ pub const Request = struct { const buf = s.read_buffer[s.read_buffer_len..]; if (buf.len == 0) return error.HttpHeadersOversize; - const read_n = try s.connection.stream.read(buf); + const read_n = try 
s.connection.stream().reader().read(buf); s.read_buffer_len += read_n; const bytes = buf[0..read_n]; const end = hp.feed(bytes); @@ -752,7 +767,7 @@ pub const Request = struct { /// request's expect field to `null`. /// /// Asserts that this function is only called once. - pub fn reader(request: *Request) ReaderError!std.io.AnyReader { + pub fn reader(request: *Request) !std.io.AnyReader { const s = request.server; assert(s.state == .received_head); s.state = .receiving_body; @@ -760,7 +775,7 @@ pub const Request = struct { if (request.head.expect) |expect| { if (mem.eql(u8, expect, "100-continue")) { - try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + try request.server.connection.stream().writer().writeAll("HTTP/1.1 100 Continue\r\n\r\n"); request.head.expect = null; } else { return error.HttpExpectationFailed; @@ -771,7 +786,7 @@ pub const Request = struct { .chunked => { request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; return .{ - .readFn = read_chunked, + .readvFn = readv_chunked, .context = request, }; }, @@ -780,7 +795,7 @@ pub const Request = struct { .remaining_content_length = request.head.content_length orelse 0, }; return .{ - .readFn = read_cl, + .readvFn = read_cl, .context = request, }; }, @@ -821,7 +836,7 @@ pub const Request = struct { }; pub const Response = struct { - stream: net.Stream, + connection: net.Server.Connection, send_buffer: []u8, /// Index of the first byte in `send_buffer`. /// This is 0 unless a short write happens in `write`. @@ -845,14 +860,12 @@ pub const Response = struct { chunked, }; - pub const WriteError = net.Stream.WriteError; - /// When using content-length, asserts that the amount of data sent matches /// the value sent in the header, then calls `flush`. /// Otherwise, transfer-encoding: chunked is being used, and it writes the /// end-of-stream message, then flushes the stream to the system. /// Respects the value of `elide_body` to omit all data after the headers. - pub fn end(r: *Response) WriteError!void { + pub fn end(r: *Response) !void { switch (r.transfer_encoding) { .content_length => |len| { assert(len == 0); // Trips when end() called before all bytes written. @@ -877,26 +890,18 @@ pub const Response = struct { /// flushes the stream to the system. /// Respects the value of `elide_body` to omit all data after the headers. /// Asserts there are at most 25 trailers. - pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { + pub fn endChunked(r: *Response, options: EndChunkedOptions) !void { assert(r.transfer_encoding == .chunked); try flush_chunked(r, options.trailers); r.* = undefined; } - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - /// May return 0, which does not indicate end of stream. The caller decides - /// when the end of stream occurs by calling `end`. 
- pub fn write(r: *Response, bytes: []const u8) WriteError!usize { - switch (r.transfer_encoding) { - .content_length, .none => return write_cl(r, bytes), - .chunked => return write_chunked(r, bytes), - } - } - - fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { + fn write_cl(context: *const anyopaque, iov: []const std.os.iovec_const) !usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); + const first = iov[0]; + const bytes = first.iov_base[0..first.iov_len]; + var trash: u64 = std.math.maxInt(u64); const len = switch (r.transfer_encoding) { .content_length => |*len| len, @@ -910,7 +915,7 @@ pub const Response = struct { if (bytes.len + r.send_buffer_end > r.send_buffer.len) { const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - var iovecs: [2]std.posix.iovec_const = .{ + var iovecs: [2]std.os.iovec_const = .{ .{ .iov_base = r.send_buffer.ptr + r.send_buffer_start, .iov_len = send_buffer_len, @@ -920,7 +925,7 @@ pub const Response = struct { .iov_len = bytes.len, }, }; - const n = try r.stream.writev(&iovecs); + const n = try r.connection.stream().writev(&iovecs); if (n >= send_buffer_len) { // It was enough to reset the buffer. @@ -944,10 +949,13 @@ pub const Response = struct { return bytes.len; } - fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { + fn write_chunked(context: *const anyopaque, iov: []const std.os.iovec_const) !usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); assert(r.transfer_encoding == .chunked); + const first = iov[0]; + const bytes = first.iov_base[0..first.iov_len]; + if (r.elide_body) return bytes.len; @@ -981,7 +989,7 @@ pub const Response = struct { }; // TODO make this writev instead of writevAll, which involves // complicating the logic of this function. - try r.stream.writevAll(&iovecs); + try r.connection.stream().writer().writevAll(&iovecs); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -995,32 +1003,23 @@ pub const Response = struct { return bytes.len; } - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(r, bytes[index..]); - } - } - /// Sends all buffered data to the client. /// This is redundant after calling `end`. /// Respects the value of `elide_body` to omit all data after the headers. - pub fn flush(r: *Response) WriteError!void { + pub fn flush(r: *Response) !void { switch (r.transfer_encoding) { .none, .content_length => return flush_cl(r), .chunked => return flush_chunked(r, null), } } - fn flush_cl(r: *Response) WriteError!void { - try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); + fn flush_cl(r: *Response) !void { + try r.connection.stream().writer().writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); r.send_buffer_start = 0; r.send_buffer_end = 0; } - fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { + fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) !void { const max_trailers = 25; if (end_trailers) |trailers| assert(trailers.len <= max_trailers); assert(r.transfer_encoding == .chunked); @@ -1028,7 +1027,7 @@ pub const Response = struct { const http_headers = r.send_buffer[r.send_buffer_start .. 
r.send_buffer_end - r.chunk_len]; if (r.elide_body) { - try r.stream.writeAll(http_headers); + try r.connection.stream().writer().writeAll(http_headers); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1109,7 +1108,7 @@ pub const Response = struct { iovecs_len += 1; } - try r.stream.writevAll(iovecs[0..iovecs_len]); + try r.connection.stream().writer().writevAll(iovecs[0..iovecs_len]); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1117,7 +1116,7 @@ pub const Response = struct { pub fn writer(r: *Response) std.io.AnyWriter { return .{ - .writeFn = switch (r.transfer_encoding) { + .writevFn = switch (r.transfer_encoding) { .none, .content_length => write_cl, .chunked => write_chunked, }, @@ -1136,13 +1135,3 @@ fn rebase(s: *Server, index: usize) void { } s.read_buffer_len = index + leftover.len; } - -const std = @import("../std.zig"); -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const assert = std.debug.assert; -const testing = std.testing; - -const Server = @This(); diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 78511f435d67..e1f70b82ab93 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -281,7 +281,7 @@ const MockBufferedConnection = struct { pub fn fill(conn: *MockBufferedConnection) ReadError!void { if (conn.end != conn.start) return; - const nread = try conn.conn.read(conn.buf[0..]); + const nread = try conn.conn.reader().read(conn.buf[0..]); if (nread == 0) return error.EndOfStream; conn.start = 0; conn.end = @as(u16, @truncate(nread)); @@ -313,7 +313,7 @@ const MockBufferedConnection = struct { if (left > conn.buf.len) { // skip the buffer if the output is large enough - return conn.conn.read(buffer[out_index..]); + return conn.conn.reader().read(buffer[out_index..]); } try conn.fill(); diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index e2aa810d580d..917c08c63e51 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -14,8 +14,8 @@ test "trailers" { var header_buffer: [1024]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); @@ -33,9 +33,9 @@ test "trailers" { var response = request.respondStreaming(.{ .send_buffer = &send_buffer, }); - try response.writeAll("Hello, "); + try response.writer().writeAll("Hello, "); try response.flush(); - try response.writeAll("World!\n"); + try response.writer().writeAll("World!\n"); try response.flush(); try response.endChunked(.{ .trailers = &.{ @@ -96,8 +96,8 @@ test "HTTP server handles a chunked transfer coding request" { const test_server = try createTestServer(struct { fn run(net_server: *std.net.Server) !void { var header_buffer: [8192]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); var request = try server.receiveHead(); @@ -133,9 +133,10 @@ test "HTTP server handles a chunked transfer coding request" { "\r\n"; const gpa = std.testing.allocator; - const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const socket = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = socket.stream(); defer stream.close(); - try 
stream.writeAll(request_bytes); + try stream.writer().writeAll(request_bytes); const expected_response = "HTTP/1.1 200 OK\r\n" ++ @@ -155,8 +156,8 @@ test "echo content server" { var read_buffer: [1024]u8 = undefined; accept: while (true) { - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var http_server = http.Server.init(conn, &read_buffer); @@ -237,8 +238,8 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { var header_buffer: [1000]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); @@ -256,7 +257,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { for (0..500) |i| { var buf: [30]u8 = undefined; const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); - try response.writeAll(line); + try response.writer().writeAll(line); total += line.len; } try expectEqual(7390, total); @@ -269,9 +270,11 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { const request_bytes = "GET /foo HTTP/1.1\r\n\r\n"; const gpa = std.testing.allocator; - const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const socket = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = socket.stream(); defer stream.close(); - try stream.writeAll(request_bytes); + try stream.writer().writeAll(request_bytes); + std.debug.print("requested\n", .{}); const response = try stream.reader().readAllAlloc(gpa, 8192); defer gpa.free(response); @@ -301,8 +304,8 @@ test "receiving arbitrary http headers from the client" { var read_buffer: [666]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &read_buffer); try expectEqual(.ready, server.state); @@ -332,9 +335,10 @@ test "receiving arbitrary http headers from the client" { "aoeu: asdf \r\n" ++ "\r\n"; const gpa = std.testing.allocator; - const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const socket = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = socket.stream(); defer stream.close(); - try stream.writeAll(request_bytes); + try stream.writer().writeAll(request_bytes); const response = try stream.reader().readAllAlloc(gpa, 8192); defer gpa.free(response); @@ -361,8 +365,8 @@ test "general client/server API coverage" { fn run(net_server: *std.net.Server) anyerror!void { var client_header_buffer: [1024]u8 = undefined; outer: while (global.handle_new_requests) { - var connection = try net_server.accept(); - defer connection.stream.close(); + var connection = try net_server.accept(null); + defer connection.stream().close(); var http_server = http.Server.init(connection, &client_header_buffer); @@ -831,7 +835,7 @@ test "general client/server API coverage" { // connection has been kept alive try expect(client.http_proxy != null or client.connection_pool.free_len == 1); - { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** + { const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", 
.{port}); defer gpa.free(location); const uri = try std.Uri.parse(location); @@ -877,8 +881,8 @@ test "Server streams both reading and writing" { const test_server = try createTestServer(struct { fn run(net_server: *std.net.Server) anyerror!void { var header_buffer: [1024]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); var request = try server.receiveHead(); @@ -925,8 +929,8 @@ test "Server streams both reading and writing" { try req.send(.{}); try req.wait(); - try req.writeAll("one "); - try req.writeAll("fish"); + try req.writer().writeAll("one "); + try req.writer().writeAll("fish"); try req.finish(); @@ -957,8 +961,8 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .{ .content_length = 14 }; try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + try req.writer().writeAll("Hello, "); + try req.writer().writeAll("World!\n"); try req.finish(); try req.wait(); @@ -991,8 +995,8 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + try req.writer().writeAll("Hello, "); + try req.writer().writeAll("World!\n"); try req.finish(); try req.wait(); @@ -1045,8 +1049,8 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + try req.writer().writeAll("Hello, "); + try req.writer().writeAll("World!\n"); try req.finish(); try req.wait(); @@ -1121,8 +1125,8 @@ test "redirect to different connection" { fn run(net_server: *std.net.Server) anyerror!void { var header_buffer: [888]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); var request = try server.receiveHead(); @@ -1142,8 +1146,8 @@ test "redirect to different connection" { var header_buffer: [999]u8 = undefined; var send_buffer: [100]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); const new_loc = try std.fmt.bufPrint(&send_buffer, "http://127.0.0.1:{d}/ok", .{ global.other_port.?, diff --git a/lib/std/io.zig b/lib/std/io.zig index df220e24898a..9225f40f3ba3 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -11,6 +11,8 @@ const mem = std.mem; const meta = std.meta; const File = std.fs.File; const Allocator = std.mem.Allocator; +const iovec = std.os.iovec; +const iovec_const = std.os.iovec_const; fn getStdOutHandle() os.fd_t { if (builtin.os.tag == .windows) { @@ -78,18 +80,20 @@ pub fn GenericReader( /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. 
- comptime readFn: fn (context: Context, buffer: []u8) ReadError!usize, + comptime readvFn: fn (context: Context, iov: []const iovec) ReadError!usize, ) type { return struct { context: Context, pub const Error = ReadError; - pub const NoEofError = ReadError || error{ - EndOfStream, - }; + pub const NoEofError = ReadError || error{EndOfStream}; + + pub inline fn readv(self: Self, iov: []const iovec) Error!usize { + return readvFn(self.context, iov); + } pub inline fn read(self: Self, buffer: []u8) Error!usize { - return readFn(self.context, buffer); + return @errorCast(self.any().read(buffer)); } pub inline fn readAll(self: Self, buffer: []u8) Error!usize { @@ -283,18 +287,24 @@ pub fn GenericReader( return @errorCast(self.any().readEnum(Enum, endian)); } + /// Reads the stream until the end, ignoring all the data. + /// Returns the number of bytes discarded. + pub inline fn discard(self: Self) anyerror!u64 { + return @errorCast(self.any().discard()); + } + pub inline fn any(self: *const Self) AnyReader { return .{ .context = @ptrCast(&self.context), - .readFn = typeErasedReadFn, + .readvFn = typeErasedReadvFn, }; } const Self = @This(); - fn typeErasedReadFn(context: *const anyopaque, buffer: []u8) anyerror!usize { + fn typeErasedReadvFn(context: *const anyopaque, iov: []const iovec) anyerror!usize { const ptr: *const Context = @alignCast(@ptrCast(context)); - return readFn(ptr.*, buffer); + return readvFn(ptr.*, iov); } }; } @@ -302,7 +312,7 @@ pub fn GenericReader( pub fn GenericWriter( comptime Context: type, comptime WriteError: type, - comptime writeFn: fn (context: Context, bytes: []const u8) WriteError!usize, + comptime writevFn: fn (context: Context, iov: []const iovec_const) WriteError!usize, ) type { return struct { context: Context, @@ -310,8 +320,16 @@ pub fn GenericWriter( const Self = @This(); pub const Error = WriteError; + pub inline fn writev(self: Self, iov: []const iovec_const) Error!usize { + return writevFn(self.context, iov); + } + + pub inline fn writevAll(self: Self, iov: []iovec_const) Error!void { + return @errorCast(self.any().writevAll(iov)); + } + pub inline fn write(self: Self, bytes: []const u8) Error!usize { - return writeFn(self.context, bytes); + return @errorCast(self.any().write(bytes)); } pub inline fn writeAll(self: Self, bytes: []const u8) Error!void { @@ -345,13 +363,60 @@ pub fn GenericWriter( pub inline fn any(self: *const Self) AnyWriter { return .{ .context = @ptrCast(&self.context), - .writeFn = typeErasedWriteFn, + .writevFn = typeErasedWritevFn, + }; + } + + fn typeErasedWritevFn(context: *const anyopaque, iov: []const iovec_const) anyerror!usize { + const ptr: *const Context = @alignCast(@ptrCast(context)); + return writevFn(ptr.*, iov); + } + }; +} + +pub fn GenericStream( + comptime Context: type, + comptime ReadError: type, + /// Returns the number of bytes read. It may be less than buffer.len. + /// If the number of bytes read is 0, it means end of stream. + /// End of stream is not an error condition. 
+ comptime readvFn: fn (context: Context, iov: []const iovec) ReadError!usize, + comptime WriteError: type, + comptime writevFn: fn (context: Context, iov: []const iovec_const) WriteError!usize, + comptime closeFn: fn (context: Context) void, +) type { + return struct { + context: Context, + + const ReaderType = GenericReader(Context, ReadError, readvFn); + const WriterType = GenericWriter(Context, WriteError, writevFn); + + const Self = @This(); + + pub inline fn reader(self: *const Self) ReaderType { + return .{ .context = self.context }; + } + + pub inline fn writer(self: *const Self) WriterType { + return .{ .context = self.context }; + } + + pub inline fn close(self: *const Self) void { + closeFn(self.context); + } + + pub inline fn any(self: *const Self) AnyStream { + return .{ + .context = @ptrCast(&self.context), + .readvFn = self.reader().any().readvFn, + .writevFn = self.writer().any().writevFn, + .closeFn = typeErasedCloseFn, }; } - fn typeErasedWriteFn(context: *const anyopaque, bytes: []const u8) anyerror!usize { + fn typeErasedCloseFn(context: *const anyopaque) void { const ptr: *const Context = @alignCast(@ptrCast(context)); - return writeFn(ptr.*, bytes); + return closeFn(ptr.*); } }; } @@ -365,6 +430,7 @@ pub const Writer = GenericWriter; pub const AnyReader = @import("io/Reader.zig"); pub const AnyWriter = @import("io/Writer.zig"); +pub const AnyStream = @import("io/Stream.zig"); pub const SeekableStream = @import("io/seekable_stream.zig").SeekableStream; @@ -416,10 +482,12 @@ pub const tty = @import("io/tty.zig"); /// A Writer that doesn't write to anything. pub const null_writer = @as(NullWriter, .{ .context = {} }); -const NullWriter = Writer(void, error{}, dummyWrite); -fn dummyWrite(context: void, data: []const u8) error{}!usize { +const NullWriter = Writer(void, error{}, dummyWritev); +fn dummyWritev(context: void, iov: []const std.os.iovec_const) error{}!usize { _ = context; - return data.len; + var written: usize = 0; + for (iov) |v| written += v.iov_len; + return written; } test "null_writer" { @@ -695,6 +763,7 @@ pub fn PollFiles(comptime StreamEnum: type) type { test { _ = AnyReader; _ = AnyWriter; + _ = AnyStream; _ = @import("io/bit_reader.zig"); _ = @import("io/bit_writer.zig"); _ = @import("io/buffered_atomic_file.zig"); diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index a769fe4c0421..92fd16758094 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -1,20 +1,37 @@ +const std = @import("../std.zig"); +const Self = @This(); +const math = std.math; +const assert = std.debug.assert; +const mem = std.mem; +const testing = std.testing; +const native_endian = @import("builtin").target.cpu.arch.endian(); +const iovec = std.os.iovec; + context: *const anyopaque, -readFn: *const fn (context: *const anyopaque, buffer: []u8) anyerror!usize, +readvFn: *const fn (context: *const anyopaque, iov: []const iovec) anyerror!usize, pub const Error = anyerror; +/// Returns the number of bytes read. It may be less than buffer.len. +/// If the number of bytes read is 0, it means end of stream. +/// End of stream is not an error condition. +pub fn readv(self: Self, iov: []const iovec) anyerror!usize { + return self.readvFn(self.context, iov); +} + /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. 
pub fn read(self: Self, buffer: []u8) anyerror!usize { - return self.readFn(self.context, buffer); + var iov = [_]iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; + return self.readv(&iov); } /// Returns the number of bytes read. If the number read is smaller than `buffer.len`, it /// means the stream reached the end. Reaching the end of a stream is not an error /// condition. pub fn readAll(self: Self, buffer: []u8) anyerror!usize { - return readAtLeast(self, buffer, buffer.len); + return self.readAtLeast(buffer, buffer.len); } /// Returns the number of bytes read, calling the underlying read @@ -372,14 +389,6 @@ pub fn discard(self: Self) anyerror!u64 { } } -const std = @import("../std.zig"); -const Self = @This(); -const math = std.math; -const assert = std.debug.assert; -const mem = std.mem; -const testing = std.testing; -const native_endian = @import("builtin").target.cpu.arch.endian(); - test { _ = @import("Reader/test.zig"); } diff --git a/lib/std/io/Stream.zig b/lib/std/io/Stream.zig new file mode 100644 index 000000000000..cafcda5bb64d --- /dev/null +++ b/lib/std/io/Stream.zig @@ -0,0 +1,37 @@ +const std = @import("../std.zig"); +const assert = std.debug.assert; +const mem = std.mem; +const os = std.os; +const iovec = os.iovec; +const iovec_const = os.iovec_const; + +context: *const anyopaque, +readvFn: *const fn (context: *const anyopaque, iov: []const iovec) anyerror!usize, +writevFn: *const fn (context: *const anyopaque, iov: []const iovec_const) anyerror!usize, +closeFn: *const fn (context: *const anyopaque) void, + +const Self = @This(); +pub const Error = anyerror; + +pub fn writev(self: Self, iov: []const iovec_const) anyerror!usize { + return self.writevFn(self.context, iov); +} + +/// Returns the number of bytes read. It may be less than buffer.len. +/// If the number of bytes read is 0, it means end of stream. +/// End of stream is not an error condition. +pub fn readv(self: Self, iov: []const iovec) anyerror!usize { + return self.readvFn(self.context, iov); +} + +pub fn reader(self: Self) std.io.AnyReader { + return .{ .context = self.context, .readvFn = self.readvFn }; +} + +pub fn writer(self: Self) std.io.AnyWriter { + return .{ .context = self.context, .writevFn = self.writevFn }; +} + +pub fn close(self: Self) void { + return self.closeFn(self.context); +} diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index dfcae48b1eb8..a8015b59ddfa 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -1,15 +1,41 @@ const std = @import("../std.zig"); const assert = std.debug.assert; const mem = std.mem; +const iovec_const = std.os.iovec_const; context: *const anyopaque, -writeFn: *const fn (context: *const anyopaque, bytes: []const u8) anyerror!usize, +writevFn: *const fn (context: *const anyopaque, iov: []const iovec_const) anyerror!usize, const Self = @This(); pub const Error = anyerror; +pub fn writev(self: Self, iov: []const iovec_const) anyerror!usize { + return self.writevFn(self.context, iov); +} + +/// The `iovecs` parameter is mutable because this function needs to mutate the fields in +/// order to handle partial writes from the underlying OS layer. +/// See https://github.com/ziglang/zig/issues/7699 +/// See equivalent function: `std.fs.File.writevAll`. 
+pub fn writevAll(self: Self, iovecs: []iovec_const) anyerror!void { + if (iovecs.len == 0) return; + + var i: usize = 0; + while (true) { + var amt = try self.writev(iovecs[i..]); + while (amt >= iovecs[i].iov_len) { + amt -= iovecs[i].iov_len; + i += 1; + if (i >= iovecs.len) return; + } + iovecs[i].iov_base += amt; + iovecs[i].iov_len -= amt; + } +} + pub fn write(self: Self, bytes: []const u8) anyerror!usize { - return self.writeFn(self.context, bytes); + var iov = [_]iovec_const{.{ .iov_base = bytes.ptr, .iov_len = bytes.len }}; + return self.writev(&iov); } pub fn writeAll(self: Self, bytes: []const u8) anyerror!void { diff --git a/lib/std/io/buffered_reader.zig b/lib/std/io/buffered_reader.zig index ca132202a7df..1e0b78b7597b 100644 --- a/lib/std/io/buffered_reader.zig +++ b/lib/std/io/buffered_reader.zig @@ -12,11 +12,13 @@ pub fn BufferedReader(comptime buffer_size: usize, comptime ReaderType: type) ty end: usize = 0, pub const Error = ReaderType.Error; - pub const Reader = io.Reader(*Self, Error, read); + pub const Reader = io.Reader(*Self, Error, readv); const Self = @This(); - pub fn read(self: *Self, dest: []u8) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const first = iov[0]; + const dest = first.iov_base[0..first.iov_len]; var dest_index: usize = 0; while (dest_index < dest.len) { @@ -60,7 +62,7 @@ test "OneByte" { const Error = error{NoError}; const Self = @This(); - const Reader = io.Reader(*Self, Error, read); + const Reader = io.Reader(*Self, Error, readv); fn init(str: []const u8) Self { return Self{ @@ -69,7 +71,9 @@ test "OneByte" { }; } - fn read(self: *Self, dest: []u8) Error!usize { + fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const first = iov[0]; + const dest = first.iov_base[0..first.iov_len]; if (self.str.len <= self.curr or dest.len == 0) return 0; @@ -135,11 +139,11 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [4]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out < block @@ -148,13 +152,13 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [3]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "012"); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "301"); - const n = try test_buf_reader.read(&out_buf); + const n = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, out_buf[0..n], "23"); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out > block @@ -163,11 +167,11 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [5]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "01230"); - const n = try test_buf_reader.read(&out_buf); + const n = try test_buf_reader.reader().read(&out_buf); 
try testing.expectEqualSlices(u8, out_buf[0..n], "123"); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out == 0 @@ -176,7 +180,7 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [0]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, ""); } @@ -186,10 +190,10 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [4]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } } diff --git a/lib/std/io/buffered_tee.zig b/lib/std/io/buffered_tee.zig index d5748c3a5276..7fbb64f34575 100644 --- a/lib/std/io/buffered_tee.zig +++ b/lib/std/io/buffered_tee.zig @@ -36,11 +36,13 @@ pub fn BufferedTee( wp: usize = 0, // writer pointer; data is sent to the output up to this position pub const Error = InputReaderType.Error || OutputWriterType.Error; - pub const Reader = io.Reader(*Self, Error, read); + pub const Reader = io.Reader(*Self, Error, readv); const Self = @This(); - pub fn read(self: *Self, dest: []u8) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const first = iov[0]; + const dest = first.iov_base[0..first.iov_len]; var dest_index: usize = 0; while (dest_index < dest.len) { @@ -153,7 +155,7 @@ test "OneByte" { const Error = error{NoError}; const Self = @This(); - const Reader = io.Reader(*Self, Error, read); + const Reader = io.Reader(*Self, Error, readv); fn init(str: []const u8) Self { return Self{ @@ -162,7 +164,9 @@ test "OneByte" { }; } - fn read(self: *Self, dest: []u8) Error!usize { + fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const first = iov[0]; + const dest = first.iov_base[0..first.iov_len]; if (self.str.len <= self.curr or dest.len == 0) return 0; @@ -226,11 +230,11 @@ test "Block" { .output = io.null_writer, }; var out_buf: [4]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out < block @@ -240,13 +244,13 @@ test "Block" { .output = io.null_writer, }; var out_buf: [3]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "012"); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "301"); - const n = try test_buf_reader.read(&out_buf); + const n = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, out_buf[0..n], "23"); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out 
> block @@ -256,11 +260,11 @@ test "Block" { .output = io.null_writer, }; var out_buf: [5]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "01230"); - const n = try test_buf_reader.read(&out_buf); + const n = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, out_buf[0..n], "123"); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out == 0 @@ -270,7 +274,7 @@ test "Block" { .output = io.null_writer, }; var out_buf: [0]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, ""); } @@ -281,11 +285,11 @@ test "Block" { .output = io.null_writer, }; var out_buf: [4]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } } @@ -301,7 +305,7 @@ test "with zero lookahead" { var buf: [16]u8 = undefined; var read_len: usize = 0; for (0..buf.len) |i| { - const n = try bt.read(buf[0..i]); + const n = try bt.reader().read(buf[0..i]); try testing.expectEqual(i, n); read_len += i; try testing.expectEqual(read_len, out.items.len); @@ -321,7 +325,7 @@ test "with lookahead" { var read_len: usize = 0; for (1..buf.len) |i| { - const n = try bt.read(buf[0..i]); + const n = try bt.reader().read(buf[0..i]); try testing.expectEqual(i, n); read_len += i; const out_len = if (read_len < lookahead) 0 else read_len - lookahead; @@ -342,14 +346,14 @@ test "internal state" { var bt = bufferedTee(8, 4, in.reader(), out.writer()); var buf: [16]u8 = undefined; - var n = try bt.read(buf[0..3]); + var n = try bt.reader().read(buf[0..3]); try testing.expectEqual(3, n); try testing.expectEqualSlices(u8, data[0..3], buf[0..n]); try testing.expectEqual(8, bt.tail); try testing.expectEqual(3, bt.rp); try testing.expectEqual(0, out.items.len); - n = try bt.read(buf[0..6]); + n = try bt.reader().read(buf[0..6]); try testing.expectEqual(6, n); try testing.expectEqualSlices(u8, data[3..9], buf[0..n]); try testing.expectEqual(8, bt.tail); @@ -357,7 +361,7 @@ test "internal state" { try testing.expectEqualSlices(u8, data[4..12], &bt.buf); try testing.expectEqual(5, out.items.len); - n = try bt.read(buf[0..9]); + n = try bt.reader().read(buf[0..9]); try testing.expectEqual(9, n); try testing.expectEqualSlices(u8, data[9..18], buf[0..n]); try testing.expectEqual(8, bt.tail); @@ -369,7 +373,7 @@ test "internal state" { try testing.expectEqual(18, out.items.len); bt.putBack(4); - n = try bt.read(buf[0..4]); + n = try bt.reader().read(buf[0..4]); try testing.expectEqual(4, n); try testing.expectEqualSlices(u8, data[14..18], buf[0..n]); diff --git a/lib/std/io/buffered_writer.zig b/lib/std/io/buffered_writer.zig index 906d6cce4926..6daaac8427e0 100644 --- a/lib/std/io/buffered_writer.zig +++ b/lib/std/io/buffered_writer.zig @@ -10,7 +10,7 @@ pub fn BufferedWriter(comptime buffer_size: usize, comptime WriterType: type) ty end: usize = 0, pub const Error = WriterType.Error; - pub const Writer = io.Writer(*Self, Error, write); + pub 
const Writer = io.Writer(*Self, Error, writev); const Self = @This(); @@ -23,17 +23,22 @@ pub fn BufferedWriter(comptime buffer_size: usize, comptime WriterType: type) ty return .{ .context = self }; } - pub fn write(self: *Self, bytes: []const u8) Error!usize { - if (self.end + bytes.len > self.buf.len) { - try self.flush(); - if (bytes.len > self.buf.len) - return self.unbuffered_writer.write(bytes); + pub fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + var written: usize = 0; + for (iov) |v| { + const bytes = v.iov_base[0..v.iov_len]; + if (self.end + bytes.len > self.buf.len) { + try self.flush(); + if (bytes.len > self.buf.len) + return self.unbuffered_writer.write(bytes); + } + + const new_end = self.end + bytes.len; + @memcpy(self.buf[self.end..new_end], bytes); + self.end = new_end; + written += bytes.len; } - - const new_end = self.end + bytes.len; - @memcpy(self.buf[self.end..new_end], bytes); - self.end = new_end; - return bytes.len; + return written; } }; } diff --git a/lib/std/io/counting_reader.zig b/lib/std/io/counting_reader.zig index 2ff9b8a08fe3..30a99bac9dcf 100644 --- a/lib/std/io/counting_reader.zig +++ b/lib/std/io/counting_reader.zig @@ -9,15 +9,17 @@ pub fn CountingReader(comptime ReaderType: anytype) type { bytes_read: u64 = 0, pub const Error = ReaderType.Error; - pub const Reader = io.Reader(*@This(), Error, read); + pub const Reader = io.Reader(*@This(), Error, readv); - pub fn read(self: *@This(), buf: []u8) Error!usize { - const amt = try self.child_reader.read(buf); + const Self = @This(); + + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const amt = try self.child_reader.readv(iov); self.bytes_read += amt; return amt; } - pub fn reader(self: *@This()) Reader { + pub fn reader(self: *Self) Reader { return .{ .context = self }; } }; diff --git a/lib/std/io/counting_writer.zig b/lib/std/io/counting_writer.zig index 9043e1a47c17..a663423c3348 100644 --- a/lib/std/io/counting_writer.zig +++ b/lib/std/io/counting_writer.zig @@ -9,12 +9,12 @@ pub fn CountingWriter(comptime WriterType: type) type { child_stream: WriterType, pub const Error = WriterType.Error; - pub const Writer = io.Writer(*Self, Error, write); + pub const Writer = io.Writer(*Self, Error, writev); const Self = @This(); - pub fn write(self: *Self, bytes: []const u8) Error!usize { - const amt = try self.child_stream.write(bytes); + pub fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + const amt = try self.child_stream.writev(iov); self.bytes_written += amt; return amt; } diff --git a/lib/std/io/fixed_buffer_stream.zig b/lib/std/io/fixed_buffer_stream.zig index 14e5e5de43eb..6a48ba32df98 100644 --- a/lib/std/io/fixed_buffer_stream.zig +++ b/lib/std/io/fixed_buffer_stream.zig @@ -17,8 +17,8 @@ pub fn FixedBufferStream(comptime Buffer: type) type { pub const SeekError = error{}; pub const GetSeekPosError = error{}; - pub const Reader = io.Reader(*Self, ReadError, read); - pub const Writer = io.Writer(*Self, WriteError, write); + pub const Reader = io.Reader(*Self, ReadError, readv); + pub const Writer = io.Writer(*Self, WriteError, writev); pub const SeekableStream = io.SeekableStream( *Self, @@ -44,31 +44,39 @@ pub fn FixedBufferStream(comptime Buffer: type) type { return .{ .context = self }; } - pub fn read(self: *Self, dest: []u8) ReadError!usize { - const size = @min(dest.len, self.buffer.len - self.pos); - const end = self.pos + size; + pub fn readv(self: *Self, iov: []const std.os.iovec) ReadError!usize { + var read: usize = 0; + 
for (iov) |v| { + const size = @min(v.iov_len, self.buffer.len - self.pos); + const end = self.pos + size; - @memcpy(dest[0..size], self.buffer[self.pos..end]); - self.pos = end; + @memcpy(v.iov_base[0..size], self.buffer[self.pos..end]); + self.pos = end; + read += size; + } - return size; + return read; } /// If the returned number of bytes written is less than requested, the /// buffer is full. Returns `error.NoSpaceLeft` when no bytes would be written. /// Note: `error.NoSpaceLeft` matches the corresponding error from /// `std.fs.File.WriteError`. - pub fn write(self: *Self, bytes: []const u8) WriteError!usize { - if (bytes.len == 0) return 0; - if (self.pos >= self.buffer.len) return error.NoSpaceLeft; - - const n = @min(self.buffer.len - self.pos, bytes.len); - @memcpy(self.buffer[self.pos..][0..n], bytes[0..n]); - self.pos += n; - - if (n == 0) return error.NoSpaceLeft; + pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { + var written: usize = 0; + for (iov) |v| { + if (v.iov_len == 0) continue; + if (self.pos >= self.buffer.len) return error.NoSpaceLeft; + + const n = @min(self.buffer.len - self.pos, v.iov_len); + @memcpy(self.buffer[self.pos..][0..n], v.iov_base[0..n]); + self.pos += n; + + if (n == 0) return error.NoSpaceLeft; + written += n; + } - return n; + return written; } pub fn seekTo(self: *Self, pos: u64) SeekError!void { diff --git a/lib/std/io/limited_reader.zig b/lib/std/io/limited_reader.zig index d7e250388139..7b14557284bf 100644 --- a/lib/std/io/limited_reader.zig +++ b/lib/std/io/limited_reader.zig @@ -9,7 +9,7 @@ pub fn LimitedReader(comptime ReaderType: type) type { bytes_left: u64, pub const Error = ReaderType.Error; - pub const Reader = io.Reader(*Self, Error, read); + pub const Reader = io.Reader(*Self, Error, readv); const Self = @This(); @@ -20,6 +20,12 @@ pub fn LimitedReader(comptime ReaderType: type) type { return n; } + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const first = iov[0]; + const buf = first.iov_base[0..first.iov_len]; + return try self.read(buf); + } + pub fn reader(self: *Self) Reader { return .{ .context = self }; } diff --git a/lib/std/io/multi_writer.zig b/lib/std/io/multi_writer.zig index 9cd4600e634c..3ae7c98931c2 100644 --- a/lib/std/io/multi_writer.zig +++ b/lib/std/io/multi_writer.zig @@ -15,16 +15,16 @@ pub fn MultiWriter(comptime Writers: type) type { streams: Writers, pub const Error = ErrSet; - pub const Writer = io.Writer(*Self, Error, write); + pub const Writer = io.Writer(*Self, Error, writev); pub fn writer(self: *Self) Writer { return .{ .context = self }; } - pub fn write(self: *Self, bytes: []const u8) Error!usize { - inline for (self.streams) |stream| - try stream.writeAll(bytes); - return bytes.len; + pub fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + var written: usize = 0; + inline for (self.streams) |stream| written = try stream.writev(iov); + return written; } }; } diff --git a/lib/std/io/stream_source.zig b/lib/std/io/stream_source.zig index 6e06af8204e0..5234908f0a7c 100644 --- a/lib/std/io/stream_source.zig +++ b/lib/std/io/stream_source.zig @@ -26,8 +26,8 @@ pub const StreamSource = union(enum) { pub const SeekError = io.FixedBufferStream([]u8).SeekError || (if (has_file) std.fs.File.SeekError else error{}); pub const GetSeekPosError = io.FixedBufferStream([]u8).GetSeekPosError || (if (has_file) std.fs.File.GetSeekPosError else error{}); - pub const Reader = io.Reader(*StreamSource, ReadError, read); - pub const Writer = 
io.Writer(*StreamSource, WriteError, write); + pub const Reader = io.Reader(*StreamSource, ReadError, readv); + pub const Writer = io.Writer(*StreamSource, WriteError, writev); pub const SeekableStream = io.SeekableStream( *StreamSource, SeekError, @@ -38,19 +38,19 @@ pub const StreamSource = union(enum) { getEndPos, ); - pub fn read(self: *StreamSource, dest: []u8) ReadError!usize { + pub fn readv(self: *StreamSource, iov: []const std.os.iovec) ReadError!usize { switch (self.*) { - .buffer => |*x| return x.read(dest), - .const_buffer => |*x| return x.read(dest), - .file => |x| if (!has_file) unreachable else return x.read(dest), + .buffer => |*x| return x.readv(iov), + .const_buffer => |*x| return x.readv(iov), + .file => |x| if (!has_file) unreachable else return x.readv(iov), } } - pub fn write(self: *StreamSource, bytes: []const u8) WriteError!usize { + pub fn writev(self: *StreamSource, iov: []const std.os.iovec_const) WriteError!usize { switch (self.*) { - .buffer => |*x| return x.write(bytes), + .buffer => |*x| return x.writev(iov), .const_buffer => return error.AccessDenied, - .file => |x| if (!has_file) unreachable else return x.write(bytes), + .file => |x| if (!has_file) unreachable else return x.writev(iov), } } diff --git a/lib/std/json/stringify_test.zig b/lib/std/json/stringify_test.zig index c87e400a8445..a00f6c114669 100644 --- a/lib/std/json/stringify_test.zig +++ b/lib/std/json/stringify_test.zig @@ -298,7 +298,7 @@ test "stringify tuple" { fn testStringify(expected: []const u8, value: anytype, options: StringifyOptions) !void { const ValidationWriter = struct { const Self = @This(); - pub const Writer = std.io.Writer(*Self, Error, write); + pub const Writer = std.io.Writer(*Self, Error, writev); pub const Error = error{ TooMuchData, DifferentData, @@ -314,7 +314,9 @@ fn testStringify(expected: []const u8, value: anytype, options: StringifyOptions return .{ .context = self }; } - fn write(self: *Self, bytes: []const u8) Error!usize { + fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + const first = iov[0]; + const bytes = first.iov_base[0..first.iov_len]; if (self.expected_remaining.len < bytes.len) { std.debug.print( \\====== expected this output: ========= diff --git a/lib/std/net.zig b/lib/std/net.zig index e68adc4207ff..23f832ba062e 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -10,6 +10,7 @@ const posix = std.posix; const fs = std.fs; const io = std.io; const native_endian = builtin.target.cpu.arch.endian(); +const tls = std.crypto.tls; // Windows 10 added support for unix sockets in build 17063, redstone 4 is the // first release to support them. 
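Editorial aside, not part of the patch: the std.io changes earlier in this diff replace the byte-oriented read/write callbacks of GenericReader, GenericWriter, AnyReader, and AnyWriter with vectored readv/writev callbacks, and layer GenericStream/AnyStream on top. The sketch below shows what a downstream writer looks like against those new signatures; CountingSink and its test are invented for illustration and assume the API exactly as defined in this diff.

const std = @import("std");

// Hypothetical example type, not part of this change set.
const CountingSink = struct {
    bytes_written: usize = 0,

    const Self = @This();
    pub const Error = error{};
    // The third parameter is now a writev-style callback taking iovecs.
    pub const Writer = std.io.Writer(*Self, Error, writev);

    fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize {
        // Count every byte handed to us; a real sink would copy or send them.
        var n: usize = 0;
        for (iov) |v| n += v.iov_len;
        self.bytes_written += n;
        return n;
    }

    pub fn writer(self: *Self) Writer {
        return .{ .context = self };
    }
};

test "counting sink uses the vectored writer interface" {
    var sink: CountingSink = .{};

    // Plain writes are adapted to a single-element iovec internally.
    try sink.writer().writeAll("hello");

    // Callers can also hand several buffers to the sink in one call.
    const a = "Hello, ";
    const b = "World!\n";
    const iov = [_]std.os.iovec_const{
        .{ .iov_base = a, .iov_len = a.len },
        .{ .iov_base = b, .iov_len = b.len },
    };
    _ = try sink.writer().writev(&iov);

    try std.testing.expectEqual(@as(usize, 5 + a.len + b.len), sink.bytes_written);
}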
@@ -708,26 +709,26 @@ pub const Ip6Address = extern struct { } }; -pub fn connectUnixSocket(path: []const u8) !Stream { +pub fn connectUnixSocket(path: []const u8) !Socket { const opt_non_block = 0; const sockfd = try os.socket( os.AF.UNIX, os.SOCK.STREAM | os.SOCK.CLOEXEC | opt_non_block, 0, ); - errdefer Stream.close(.{ .handle = sockfd }); + errdefer Socket.close(.{ .handle = sockfd }); var addr = try std.net.Address.initUnix(path); try os.connect(sockfd, &addr.any, addr.getOsSockLen()); - return Stream{ .handle = sockfd }; + return Socket{ .handle = sockfd }; } fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 { if (builtin.target.os.tag == .linux) { var ifr: os.ifreq = undefined; const sockfd = try os.socket(os.AF.UNIX, os.SOCK.DGRAM | os.SOCK.CLOEXEC, 0); - defer Stream.close(.{ .handle = sockfd }); + defer Socket.close(.{ .handle = sockfd }); @memcpy(ifr.ifrn.name[0..name.len], name); ifr.ifrn.name[name.len] = 0; @@ -772,7 +773,7 @@ pub const AddressList = struct { pub const TcpConnectToHostError = GetAddressListError || TcpConnectToAddressError; /// All memory allocated with `allocator` will be freed before this function returns. -pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Stream { +pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Socket { const list = try getAddressList(allocator, name, port); defer list.deinit(); @@ -791,16 +792,16 @@ pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) T pub const TcpConnectToAddressError = std.os.SocketError || std.os.ConnectError; -pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { +pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Socket { const nonblock = 0; const sock_flags = os.SOCK.STREAM | nonblock | (if (builtin.target.os.tag == .windows) 0 else os.SOCK.CLOEXEC); const sockfd = try os.socket(address.any.family, sock_flags, os.IPPROTO.TCP); - errdefer Stream.close(.{ .handle = sockfd }); + errdefer Socket.close(.{ .handle = sockfd }); try os.connect(sockfd, &address.any, address.getOsSockLen()); - return Stream{ .handle = sockfd }; + return Socket{ .handle = sockfd }; } const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || std.os.SetSockOptError || error{ @@ -1126,7 +1127,7 @@ fn linuxLookupName( var prefixlen: i32 = 0; const sock_flags = os.SOCK.DGRAM | os.SOCK.CLOEXEC; if (os.socket(addr.addr.any.family, sock_flags, os.IPPROTO.UDP)) |fd| syscalls: { - defer Stream.close(.{ .handle = fd }); + defer Socket.close(.{ .handle = fd }); os.connect(fd, da, dalen) catch break :syscalls; key |= DAS_USABLE; os.getsockname(fd, sa, &salen) catch break :syscalls; @@ -1611,7 +1612,7 @@ fn resMSendRc( }, else => |e| return e, }; - defer Stream.close(.{ .handle = fd }); + defer Socket.close(.{ .handle = fd }); // Past this point, there are no errors. Each individual query will // yield either no reply (indicated by zero length) or an answer @@ -1786,12 +1787,12 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) } } -pub const Stream = struct { +pub const Socket = struct { /// Underlying platform-defined type which may or may not be /// interchangeable with a file system file descriptor. 
handle: posix.socket_t, - pub fn close(s: Stream) void { + pub fn close(s: Socket) void { switch (builtin.os.tag) { .windows => std.os.windows.closesocket(s.handle) catch unreachable, else => posix.close(s.handle), @@ -1800,27 +1801,13 @@ pub const Stream = struct { pub const ReadError = os.ReadError; pub const WriteError = os.WriteError; + pub const GenericStream = io.GenericStream(Socket, ReadError, readv, WriteError, writev, close); - pub const Reader = io.Reader(Stream, ReadError, read); - pub const Writer = io.Writer(Stream, WriteError, write); - - pub fn reader(self: Stream) Reader { + pub fn stream(self: Socket) GenericStream { return .{ .context = self }; } - pub fn writer(self: Stream) Writer { - return .{ .context = self }; - } - - pub fn read(self: Stream, buffer: []u8) ReadError!usize { - if (builtin.os.tag == .windows) { - return os.windows.ReadFile(self.handle, buffer, null); - } - - return os.read(self.handle, buffer); - } - - pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize { + pub fn readv(s: Socket, iovecs: []const os.iovec) ReadError!usize { if (builtin.os.tag == .windows) { // TODO improve this to use ReadFileScatter if (iovecs.len == 0) return @as(usize, 0); @@ -1831,81 +1818,31 @@ pub const Stream = struct { return os.readv(s.handle, iovecs); } - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. Reaching the end of - /// a stream is not an error condition. - pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { - return readAtLeast(s, buffer, buffer.len); - } - - /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. Reaching the end of the stream is not an error - /// condition. - pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - var index: usize = 0; - while (index < len) { - const amt = try s.read(buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - - /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's - /// file system thread instead of non-blocking. It needs to be reworked to properly - /// use non-blocking I/O. - pub fn write(self: Stream, buffer: []const u8) WriteError!usize { - if (builtin.os.tag == .windows) { - return os.windows.WriteFile(self.handle, buffer, null); - } - - return os.write(self.handle, buffer); - } - - pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.write(bytes[index..]); - } - } - /// See https://github.com/ziglang/zig/issues/7699 /// See equivalent function: `std.fs.File.writev`. - pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize { + pub fn writev(self: Socket, iovecs: []const os.iovec_const) WriteError!usize { return os.writev(self.handle, iovecs); } - - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial writes from the underlying OS layer. - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writevAll`. 
- pub fn writevAll(self: Stream, iovecs: []os.iovec_const) WriteError!void { - if (iovecs.len == 0) return; - - var i: usize = 0; - while (true) { - var amt = try self.writev(iovecs[i..]); - while (amt >= iovecs[i].iov_len) { - amt -= iovecs[i].iov_len; - i += 1; - if (i >= iovecs.len) return; - } - iovecs[i].iov_base += amt; - iovecs[i].iov_len -= amt; - } - } }; pub const Server = struct { listen_address: Address, - stream: std.net.Stream, + stream: Socket, pub const Connection = struct { - stream: std.net.Stream, address: Address, + protocol: Protocol, + socket: Socket, + tls: tls.Server, + + pub const Protocol = enum { plain, tls }; + + pub inline fn stream(conn: *Connection) std.io.AnyStream { + return switch (conn.protocol) { + .plain => conn.socket.stream().any(), + .tls => conn.tls.any().any(), + }; + } }; pub fn deinit(s: *Server) void { @@ -1913,17 +1850,24 @@ pub const Server = struct { s.* = undefined; } - pub const AcceptError = posix.AcceptError; - - /// Blocks until a client connects to the server. The returned `Connection` has - /// an open stream. - pub fn accept(s: *Server) AcceptError!Connection { + /// Blocks until a client connects to the server. + /// If tls_options are supplied, will await a client handshake. + /// The returned `Connection` has an open stream. + pub fn accept(s: *Server, tls_options: ?tls.Server.Options) !Connection { var accepted_addr: Address = undefined; var addr_len: posix.socklen_t = @sizeOf(Address); const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC); + const socket = Socket{ .handle = fd }; + const protocol: Connection.Protocol = if (tls_options == null) .plain else .tls; + const _tls: tls.Server = if (tls_options) |options| + try tls.Server.init(socket.stream().any(), options) + else + undefined; return .{ - .stream = .{ .handle = fd }, .address = accepted_addr, + .protocol = protocol, + .socket = socket, + .tls = _tls, }; } }; @@ -1931,6 +1875,6 @@ pub const Server = struct { test { _ = @import("net/test.zig"); _ = Server; - _ = Stream; + _ = Socket; _ = Address; } diff --git a/lib/std/net/test.zig b/lib/std/net/test.zig index 3e316c545643..1f98adc38de3 100644 --- a/lib/std/net/test.zig +++ b/lib/std/net/test.zig @@ -189,17 +189,17 @@ test "listen on a port, send bytes, receive bytes" { const socket = try net.tcpConnectToAddress(server_address); defer socket.close(); - _ = try socket.writer().writeAll("Hello world!"); + _ = try socket.stream().writer().writeAll("Hello world!"); } }; const t = try std.Thread.spawn(.{}, S.clientFn, .{server.listen_address}); defer t.join(); - var client = try server.accept(); - defer client.stream.close(); + var client = try server.accept(null); + defer client.stream().close(); var buf: [16]u8 = undefined; - const n = try client.stream.reader().read(&buf); + const n = try client.stream().reader().read(&buf); try testing.expectEqual(@as(usize, 12), n); try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); @@ -247,7 +247,7 @@ fn testClient(addr: net.Address) anyerror!void { fn testServer(server: *net.Server) anyerror!void { if (builtin.os.tag == .wasi) return error.SkipZigTest; - var client = try server.accept(); + var client = try server.accept(null); const stream = client.stream.writer(); try stream.print("hello from server\n", .{}); @@ -280,17 +280,17 @@ test "listen on a unix socket, send bytes, receive bytes" { const socket = try net.connectUnixSocket(path); defer socket.close(); - _ = try socket.writer().writeAll("Hello world!"); + _ = try 
diff --git a/lib/std/os/test.zig b/lib/std/os/test.zig
index d9f05d94ef93..b93a5d6d9439 100644
--- a/lib/std/os/test.zig
+++ b/lib/std/os/test.zig
@@ -819,7 +819,7 @@ test "shutdown socket" {
         error.SocketNotConnected => {},
         else => |e| return e,
     };
-    std.net.Stream.close(.{ .handle = sock });
+    std.net.Socket.close(.{ .handle = sock });
 }
 
 test "sigaction" {
diff --git a/lib/std/std.zig b/lib/std/std.zig
index 557b320c244e..27920085177a 100644
--- a/lib/std/std.zig
+++ b/lib/std/std.zig
@@ -151,13 +151,6 @@ pub const Options = struct {
     /// it like any other error.
     keep_sigpipe: bool = false,
 
-    /// By default, std.http.Client will support HTTPS connections. Set this option to `true` to
-    /// disable TLS support.
-    ///
-    /// This will likely reduce the size of the binary, but it will also make it impossible to
-    /// make a HTTPS connection.
-    http_disable_tls: bool = false,
-
     side_channels_mitigations: crypto.SideChannelsMitigations = crypto.default_side_channels_mitigations,
 };
diff --git a/lib/std/tar.zig b/lib/std/tar.zig
index 13da27ca846d..70bdfaad1cb3 100644
--- a/lib/std/tar.zig
+++ b/lib/std/tar.zig
@@ -269,7 +269,7 @@ pub const FileKind = enum {
     file,
 };
 
-/// Iteartor over entries in the tar file represented by reader.
+/// Iterator over entries in the tar file represented by `reader`.
 pub fn Iterator(comptime ReaderType: type) type {
     return struct {
         reader: ReaderType,
@@ -295,17 +295,25 @@ pub fn Iterator(comptime ReaderType: type) type {
             unread_bytes: *u64,
             parent_reader: ReaderType,
 
-            pub const Reader = std.io.Reader(File, ReaderType.Error, File.read);
+            pub const Reader = std.io.Reader(File, ReaderType.Error, readv);
 
             pub fn reader(self: File) Reader {
                 return .{ .context = self };
            }
 
-            pub fn read(self: File, dest: []u8) ReaderType.Error!usize {
-                const buf = dest[0..@min(dest.len, self.unread_bytes.*)];
-                const n = try self.parent_reader.read(buf);
-                self.unread_bytes.* -= n;
-                return n;
+            pub fn readv(self: File, iov: []const std.os.iovec) ReaderType.Error!usize {
+                var n_read: usize = 0;
+                for (iov) |v| {
+                    const dest = v.iov_base[0..v.iov_len];
+                    const buf = dest[0..@min(dest.len, self.unread_bytes.*)];
+                    const n = try self.parent_reader.read(buf);
+                    self.unread_bytes.* -= n;
+                    n_read += n;
+                    // Stop on a short read so an earlier iovec is never left
+                    // partially filled before data lands in a later one.
+                    if (n < v.iov_len) break;
+                }
+                return n_read;
             }
 
             // Writes file content to writer.
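`tar.Iterator.File.readv` above follows the vectored read contract this diff adopts throughout: fill the supplied iovecs in order, report the total byte count, and stop at a short fill. A self-contained sketch of the same contract over an in-memory slice (the `SliceReader` type is illustrative only):

```zig
const std = @import("std");

// Illustrative reader implementing the vectored interface assumed by this
// diff: copy into iovecs in order, return the total count, stop on a short fill.
const SliceReader = struct {
    remaining: []const u8,

    pub const Reader = std.io.Reader(*SliceReader, error{}, readv);

    pub fn reader(self: *SliceReader) Reader {
        return .{ .context = self };
    }

    fn readv(self: *SliceReader, iov: []const std.os.iovec) error{}!usize {
        var total: usize = 0;
        for (iov) |v| {
            const n = @min(v.iov_len, self.remaining.len);
            @memcpy(v.iov_base[0..n], self.remaining[0..n]);
            self.remaining = self.remaining[n..];
            total += n;
            if (n < v.iov_len) break;
        }
        return total;
    }
};
```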
diff --git a/lib/std/zig/render.zig b/lib/std/zig/render.zig
index c6a6f3ce710d..98b2b2fc2d3e 100644
--- a/lib/std/zig/render.zig
+++ b/lib/std/zig/render.zig
@@ -3329,7 +3329,7 @@ fn AutoIndentingStream(comptime UnderlyingWriter: type) type {
     return struct {
         const Self = @This();
         pub const WriteError = UnderlyingWriter.Error;
-        pub const Writer = std.io.Writer(*Self, WriteError, write);
+        pub const Writer = std.io.Writer(*Self, WriteError, writev);
 
         underlying_writer: UnderlyingWriter,
 
@@ -3355,12 +3355,15 @@ fn AutoIndentingStream(comptime UnderlyingWriter: type) type {
             return .{ .context = self };
         }
 
-        pub fn write(self: *Self, bytes: []const u8) WriteError!usize {
-            if (bytes.len == 0)
-                return @as(usize, 0);
+        pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize {
+            var n_written: usize = 0;
+            for (iov) |v| {
+                if (v.iov_len == 0) return n_written;
 
-            try self.applyIndent();
-            return self.writeNoIndent(bytes);
+                try self.applyIndent();
+                n_written += try self.writeNoIndent(v.iov_base[0..v.iov_len]);
+            }
+            return n_written;
        }
 
         // Change the indent delta without changing the final indentation level
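`AutoIndentingStream` above is one of several writers in this diff that now implement the vectored callback directly. The same shape reduced to a byte-counting sink, as a sketch (the `CountingSink` type is illustrative, not part of the change):

```zig
const std = @import("std");

// Illustrative sink: a std.io.Writer built on the vectored callback, mirroring
// the structure AutoIndentingStream.writev uses above.
const CountingSink = struct {
    count: usize = 0,

    pub const Writer = std.io.Writer(*CountingSink, error{}, writev);

    pub fn writer(self: *CountingSink) Writer {
        return .{ .context = self };
    }

    fn writev(self: *CountingSink, iov: []const std.os.iovec_const) error{}!usize {
        var n: usize = 0;
        for (iov) |v| n += v.iov_len;
        self.count += n;
        return n;
    }
};
```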
diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig
index 93a6868603c6..70ba9164a924 100644
--- a/src/Package/Fetch.zig
+++ b/src/Package/Fetch.zig
@@ -807,16 +807,16 @@ const Resource = union(enum) {
     fn reader(resource: *Resource) std.io.AnyReader {
         return .{
             .context = resource,
-            .readFn = read,
+            .readvFn = readv,
         };
     }
 
-    fn read(context: *const anyopaque, buffer: []u8) anyerror!usize {
+    fn readv(context: *const anyopaque, iov: []const std.os.iovec) anyerror!usize {
         const resource: *Resource = @constCast(@ptrCast(@alignCast(context)));
         switch (resource.*) {
-            .file => |*f| return f.read(buffer),
-            .http_request => |*r| return r.read(buffer),
-            .git => |*g| return g.fetch_stream.read(buffer),
+            .file => |*f| return f.readv(iov),
+            .http_request => |*r| return r.readv(iov),
+            .git => |*g| return g.fetch_stream.reader().readv(iov),
             .dir => unreachable,
         }
     }
@@ -1105,14 +1105,14 @@ fn unpackResource(
         .tar => try unpackTarball(f, tmp_directory.handle, resource.reader()),
         .@"tar.gz" => {
             const reader = resource.reader();
-            var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader);
+            var br = std.io.bufferedReaderSize(4096, reader);
             var dcp = std.compress.gzip.decompressor(br.reader());
             try unpackTarball(f, tmp_directory.handle, dcp.reader());
         },
         .@"tar.xz" => {
             const gpa = f.arena.child_allocator;
             const reader = resource.reader();
-            var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader);
+            var br = std.io.bufferedReaderSize(4096, reader);
             var dcp = std.compress.xz.decompress(gpa, br.reader()) catch |err| {
                 return f.fail(f.location_tok, try eb.printString(
                     "unable to decompress tarball: {s}",
@@ -1126,7 +1126,7 @@ fn unpackResource(
             const window_size = std.compress.zstd.DecompressorOptions.default_window_buffer_len;
             const window_buffer = try f.arena.allocator().create([window_size]u8);
             const reader = resource.reader();
-            var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader);
+            var br = std.io.bufferedReaderSize(4096, reader);
             var dcp = std.compress.zstd.decompressor(br.reader(), .{
                 .window_buffer = window_buffer,
             });
diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig
index 36652bd88c55..7e97a35c6c65 100644
--- a/src/Package/Fetch/git.zig
+++ b/src/Package/Fetch/git.zig
@@ -10,6 +10,8 @@ const testing = std.testing;
 const Allocator = mem.Allocator;
 const Sha1 = std.crypto.hash.Sha1;
 const assert = std.debug.assert;
+const hashedWriter = std.compress.hashedWriter;
+const hashedReader = std.compress.hashedReader;
 
 pub const oid_length = Sha1.digest_length;
 pub const fmt_oid_length = 2 * oid_length;
@@ -667,7 +669,7 @@ pub const Session = struct {
         errdefer request.deinit();
         request.transfer_encoding = .{ .content_length = body.items.len };
         try request.send(.{});
-        try request.writeAll(body.items);
+        try request.writer().writeAll(body.items);
         try request.finish();
         try request.wait();
 
@@ -772,7 +774,7 @@ pub const Session = struct {
         errdefer request.deinit();
         request.transfer_encoding = .{ .content_length = body.items.len };
         try request.send(.{});
-        try request.writeAll(body.items);
+        try request.writer().writeAll(body.items);
         try request.finish();
         try request.wait();
 
@@ -819,7 +821,7 @@ pub const Session = struct {
             ProtocolError,
             UnexpectedPacket,
         };
-        pub const Reader = std.io.Reader(*FetchStream, ReadError, read);
+        pub const Reader = std.io.Reader(*FetchStream, ReadError, readv);
 
         const StreamCode = enum(u8) {
             pack_data = 1,
@@ -857,6 +859,12 @@ pub const Session = struct {
                 return size;
             }
         };
+
+        pub fn readv(stream: *FetchStream, iov: []const std.os.iovec) !usize {
+            const first = iov[0];
+            const buf = first.iov_base[0..first.iov_len];
+            return try stream.read(buf);
+        }
     };
 
     const PackHeader = struct {
@@ -1113,7 +1121,7 @@ fn indexPackFirstPass(
 ) ![Sha1.digest_length]u8 {
     var pack_buffered_reader = std.io.bufferedReader(pack.reader());
     var pack_counting_reader = std.io.countingReader(pack_buffered_reader.reader());
-    var pack_hashed_reader = std.compress.hashedReader(pack_counting_reader.reader(), Sha1.init(.{}));
+    var pack_hashed_reader = hashedReader(pack_counting_reader.reader(), Sha1.init(.{}));
     const pack_reader = pack_hashed_reader.reader();
 
     const pack_header = try PackHeader.read(pack_reader);
@@ -1121,7 +1129,7 @@ fn indexPackFirstPass(
     var current_entry: u32 = 0;
     while (current_entry < pack_header.total_objects) : (current_entry += 1) {
         const entry_offset = pack_counting_reader.bytes_read;
-        var entry_crc32_reader = std.compress.hashedReader(pack_reader, std.hash.Crc32.init());
+        var entry_crc32_reader = hashedReader(pack_reader, std.hash.Crc32.init());
         const entry_header = try EntryHeader.read(entry_crc32_reader.reader());
         switch (entry_header) {
             .commit, .tree, .blob, .tag => |object| {
@@ -1325,36 +1333,6 @@ fn expandDelta(base_object: anytype, delta_reader: anytype, writer: anytype) !vo
     }
 }
 
-fn HashedWriter(
-    comptime WriterType: anytype,
-    comptime HasherType: anytype,
-) type {
-    return struct {
-        child_writer: WriterType,
-        hasher: HasherType,
-
-        const Error = WriterType.Error;
-        const Writer = std.io.Writer(*@This(), Error, write);
-
-        fn write(hashed_writer: *@This(), buf: []const u8) Error!usize {
-            const amt = try hashed_writer.child_writer.write(buf);
-            hashed_writer.hasher.update(buf);
-            return amt;
-        }
-
-        fn writer(hashed_writer: *@This()) Writer {
-            return .{ .context = hashed_writer };
-        }
-    };
-}
-
-fn hashedWriter(
-    writer: anytype,
-    hasher: anytype,
-) HashedWriter(@TypeOf(writer), @TypeOf(hasher)) {
-    return .{ .child_writer = writer, .hasher = hasher };
-}
-
 test "packfile indexing and checkout" {
     // To verify the contents of this packfile without using the code in this
     // file:
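git.zig now pulls `hashedReader`/`hashedWriter` from `std.compress` instead of keeping a local copy. A small usage sketch, hashing bytes as they pass through to an ArrayList writer; it assumes the `std.compress.hashedWriter` helper referenced above exposes a `writer()` method and a `hasher` field like the removed local `HashedWriter` did:

```zig
const std = @import("std");

// Illustrative: hash data while writing it, the same way indexPackFirstPass
// hashes the pack stream while reading it.
fn sha1OfWrittenBytes(
    allocator: std.mem.Allocator,
    chunks: []const []const u8,
) ![std.crypto.hash.Sha1.digest_length]u8 {
    var out = std.ArrayList(u8).init(allocator);
    defer out.deinit();

    var hashed = std.compress.hashedWriter(out.writer(), std.crypto.hash.Sha1.init(.{}));
    for (chunks) |chunk| try hashed.writer().writeAll(chunk);

    var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
    hashed.hasher.final(&digest);
    return digest;
}
```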
diff --git a/src/codegen/c.zig b/src/codegen/c.zig
index 10795529f976..66a37a46e03c 100644
--- a/src/codegen/c.zig
+++ b/src/codegen/c.zig
@@ -7500,12 +7500,12 @@ const ArrayListWriter = ErrorOnlyGenericWriter(std.ArrayList(u8).Writer.Error);
 fn arrayListWriter(list: *std.ArrayList(u8)) ArrayListWriter {
     return .{ .context = .{
         .context = list,
-        .writeFn = struct {
-            fn write(context: *const anyopaque, bytes: []const u8) anyerror!usize {
+        .writevFn = struct {
+            fn writev(context: *const anyopaque, iov: []const std.os.iovec_const) anyerror!usize {
                 const l: *std.ArrayList(u8) = @alignCast(@constCast(@ptrCast(context)));
-                return l.writer().write(bytes);
+                return l.writer().writev(iov);
             }
-        }.write,
+        }.writev,
     } };
 }
 
@@ -7524,25 +7524,28 @@ fn IndentWriter(comptime UnderlyingWriter: type) type {
         pub fn writer(self: *Self) Writer {
             return .{ .context = .{
                 .context = self,
-                .writeFn = writeAny,
+                .writevFn = writevAny,
             } };
         }
 
-        pub fn write(self: *Self, bytes: []const u8) Error!usize {
-            if (bytes.len == 0) return @as(usize, 0);
-
-            const current_indent = self.indent_count * Self.indent_delta;
-            if (self.current_line_empty and current_indent > 0) {
-                try self.underlying_writer.writeByteNTimes(' ', current_indent);
+        pub fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize {
+            var written: usize = 0;
+            for (iov) |v| {
+                const bytes = v.iov_base[0..v.iov_len];
+                const current_indent = self.indent_count * Self.indent_delta;
+                if (self.current_line_empty and current_indent > 0) {
+                    try self.underlying_writer.writeByteNTimes(' ', current_indent);
+                }
+                self.current_line_empty = false;
+                written += try self.writeNoIndent(bytes);
             }
-            self.current_line_empty = false;
-            return self.writeNoIndent(bytes);
+            return written;
         }
 
-        fn writeAny(context: *const anyopaque, bytes: []const u8) anyerror!usize {
+        fn writevAny(context: *const anyopaque, iov: []const std.os.iovec_const) anyerror!usize {
             const self: *Self = @alignCast(@constCast(@ptrCast(context)));
-            return self.write(bytes);
+            return self.writev(iov);
         }
 
         pub fn insertNewline(self: *Self) Error!void {
@@ -7576,10 +7579,10 @@ fn IndentWriter(comptime UnderlyingWriter: type) type {
 /// maintaining ease of error handling.
 fn ErrorOnlyGenericWriter(comptime Error: type) type {
     return std.io.GenericWriter(std.io.AnyWriter, Error, struct {
-        fn write(context: std.io.AnyWriter, bytes: []const u8) Error!usize {
-            return @errorCast(context.write(bytes));
+        fn writev(context: std.io.AnyWriter, iov: []const std.os.iovec_const) Error!usize {
+            return @errorCast(context.writev(iov));
         }
-    }.write);
+    }.writev);
 }
 
 fn toCIntBits(zig_bits: u32) ?u32 {
diff --git a/src/main.zig b/src/main.zig
index db76f7605ca6..04a4efe1e528 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -3373,13 +3373,13 @@ fn buildOutputType(
         });
         defer server.deinit();
 
-        const conn = try server.accept();
-        defer conn.stream.close();
+        var conn = try server.accept(null);
+        defer conn.stream().close();
 
         try serve(
             comp,
-            .{ .handle = conn.stream.handle },
-            .{ .handle = conn.stream.handle },
+            .{ .handle = conn.socket.handle },
+            .{ .handle = conn.socket.handle },
             test_exec_args.items,
             self_exe_path,
            arg_mode,