From 82a5cc421b97ca5fb7b31db67755e8bdbf718378 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Thu, 7 Mar 2024 10:49:30 -0500 Subject: [PATCH 01/17] send+recv until server cert, dirty decoding --- TODO | 3 + lib/std/Uri.zig | 5 + lib/std/crypto/25519/x25519.zig | 12 +- lib/std/crypto/kyber_d00.zig | 5 +- lib/std/crypto/pcurves/secp256k1.zig | 2 +- lib/std/crypto/tls.zig | 2111 ++++++++++++++++++++++---- lib/std/crypto/tls/Client.zig | 1969 ++++++++---------------- lib/std/crypto/tls/Server.zig | 325 ++++ lib/std/http/Client.zig | 24 +- 9 files changed, 2777 insertions(+), 1679 deletions(-) create mode 100644 TODO create mode 100644 lib/std/crypto/tls/Server.zig diff --git a/TODO b/TODO new file mode 100644 index 000000000000..a5b21b7cfc26 --- /dev/null +++ b/TODO @@ -0,0 +1,3 @@ +Refactoring TLS client to handle fragments + send messages with structs +Adding TLS server for testing +Using https://tls13.xargs.org/ for unit tests diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index cbd3d427418f..b794d8022792 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -628,6 +628,11 @@ test "basic" { try testing.expectEqual(@as(?u16, null), parsed.port); } +test "subdomain" { + const parsed = try parse("http://a.b.example.com"); + try testing.expectEqualStrings("a.b.example.com", parsed.host orelse return error.UnexpectedNull); +} + test "with port" { const parsed = try parse("http://example:1337/"); try testing.expectEqualStrings("http", parsed.scheme); diff --git a/lib/std/crypto/25519/x25519.zig b/lib/std/crypto/25519/x25519.zig index 8bd5101b3756..6d2377e1b58f 100644 --- a/lib/std/crypto/25519/x25519.zig +++ b/lib/std/crypto/25519/x25519.zig @@ -19,15 +19,19 @@ pub const X25519 = struct { pub const public_length = 32; /// Length (in bytes) of the output of the DH function. pub const shared_length = 32; - /// Seed (for key pair creation) length in bytes. - pub const seed_length = 32; + + pub const PublicKey = [public_length]u8; + pub const SecretKey = [secret_length]u8; /// An X25519 key pair. pub const KeyPair = struct { /// Public part. - public_key: [public_length]u8, + public_key: PublicKey, /// Secret part. - secret_key: [secret_length]u8, + secret_key: SecretKey, + + /// Seed (for key pair creation) length in bytes. + pub const seed_length = 32; /// Create a new key pair using an optional seed. pub fn create(seed: ?[seed_length]u8) IdentityElementError!KeyPair { diff --git a/lib/std/crypto/kyber_d00.zig b/lib/std/crypto/kyber_d00.zig index 00246bbf402a..6c7c4ea2c701 100644 --- a/lib/std/crypto/kyber_d00.zig +++ b/lib/std/crypto/kyber_d00.zig @@ -186,8 +186,6 @@ fn Kyber(comptime p: Params) type { pub const shared_length = common_shared_key_size; /// Length (in bytes) of a seed for deterministic encapsulation. pub const encaps_seed_length = common_encaps_seed_length; - /// Length (in bytes) of a seed for key generation. - pub const seed_length: usize = inner_seed_length + shared_length; /// Algorithm name. pub const name = p.name; @@ -335,6 +333,9 @@ fn Kyber(comptime p: Params) type { secret_key: SecretKey, public_key: PublicKey, + /// Length (in bytes) of a seed for key generation. + pub const seed_length: usize = inner_seed_length + shared_length; + /// Create a new key pair. /// If seed is null, a random seed will be generated. /// If a seed is provided, the key pair will be determinsitic. diff --git a/lib/std/crypto/pcurves/secp256k1.zig b/lib/std/crypto/pcurves/secp256k1.zig index 945abea931cd..f9cac38ebb26 100644 --- a/lib/std/crypto/pcurves/secp256k1.zig +++ b/lib/std/crypto/pcurves/secp256k1.zig @@ -167,7 +167,7 @@ pub const Secp256k1 = struct { /// Serialize a point using the uncompressed SEC-1 format. pub fn toUncompressedSec1(p: Secp256k1) [65]u8 { var out: [65]u8 = undefined; - out[0] = 4; + out[0] = 4; // uncompressed const xy = p.affineCoordinates(); out[1..33].* = xy.x.toBytes(.big); out[33..65].* = xy.y.toBytes(.big); diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 7fff68471caa..0d42fc3fa954 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -1,59 +1,20 @@ -//! Plaintext: -//! * type: ContentType -//! * legacy_record_version: u16 = 0x0303, -//! * length: u16, -//! - The length (in bytes) of the following TLSPlaintext.fragment. The -//! length MUST NOT exceed 2^14 bytes. -//! * fragment: opaque -//! - the data being transmitted -//! -//! Ciphertext -//! * ContentType opaque_type = application_data; /* 23 */ -//! * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */ -//! * uint16 length; -//! * opaque encrypted_record[TLSCiphertext.length]; -//! -//! Handshake: -//! * type: HandshakeType -//! * length: u24 -//! * data: opaque -//! -//! ServerHello: -//! * ProtocolVersion legacy_version = 0x0303; -//! * Random random; -//! * opaque legacy_session_id_echo<0..32>; -//! * CipherSuite cipher_suite; -//! * uint8 legacy_compression_method = 0; -//! * Extension extensions<6..2^16-1>; -//! -//! Extension: -//! * ExtensionType extension_type; -//! * opaque extension_data<0..2^16-1>; - const std = @import("../std.zig"); +const builtin = @import("builtin"); +const client_mod = @import("tls/Client.zig"); +const server_mod = @import("tls/Server.zig"); const Tls = @This(); const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; +const native_endian = builtin.cpu.arch.endian(); -pub const Client = @import("tls/Client.zig"); - -pub const record_header_len = 5; -pub const max_cipertext_inner_record_len = 1 << 14; -pub const max_ciphertext_len = max_cipertext_inner_record_len + 256; -pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len; -pub const hello_retry_request_sequence = [32]u8{ - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -}; - -pub const close_notify_alert = [_]u8{ - @intFromEnum(AlertLevel.warning), - @intFromEnum(AlertDescription.close_notify), -}; +pub const Client = client_mod.Client; +pub const Server = server_mod.Server; -pub const ProtocolVersion = enum(u16) { +pub const Version = enum(u16) { + tls_1_0 = 0x0301, + tls_1_1 = 0x0302, tls_1_2 = 0x0303, tls_1_3 = 0x0304, _, @@ -61,28 +22,144 @@ pub const ProtocolVersion = enum(u16) { pub const ContentType = enum(u8) { invalid = 0, - change_cipher_spec = 20, - alert = 21, - handshake = 22, - application_data = 23, + change_cipher_spec = 0x14, + alert = 0x15, + handshake = 0x16, + application_data = 0x17, + heartbeat = 0x18, _, }; +/// Also called a Record. +pub const Plaintext = struct { + type: ContentType, + version: Version = .tls_1_0, + /// > The length MUST NOT exceed 2^14 bytes. + /// > An endpoint that receives a record that exceeds this length MUST terminate the connection + /// > with a "record_overflow" alert. + length: u16, + // `length` bytes follow which may contain a Message or a partial Message. + + pub const max_length = 1 << 14; +}; + +pub const Handshake = struct { + type: HandshakeType, + length: u24, + // `length` bytes follow +}; + +fn fieldsLen(comptime T: type) comptime_int { + var res: comptime_int = 0; + inline for (std.meta.fields(T)) |f| res += @sizeOf(f.type); + return res; +} + pub const HandshakeType = enum(u8) { + // hello_request = 0, client_hello = 1, server_hello = 2, + // hello_verify_request = 3, new_session_ticket = 4, end_of_early_data = 5, + // hello_retry_request = 3, encrypted_extensions = 8, certificate = 11, + // server_key_exchange = 12, certificate_request = 13, + // server_hello_done = 14, certificate_verify = 15, + // client_key_exchange = 16, finished = 20, + // certificate_url = 21, + // certificate_status = 22, + // supplemental_data = 23, key_update = 24, message_hash = 254, _, }; +pub const NewSessionTicket = struct { + ticket_lifetime: u32, + ticket_age_add: u32, + /// max len 255 + ticket_nonce: []const u8, + /// Should have at least one + ticket: []const u8, + extensions: []const Extension, +}; + +pub const CertificateRequest = struct { + /// Max len 255 + context: []const u8, + /// At least 2 + extensions: []const Extension, +}; + +pub const Certificate = struct { + /// Max len 255 + context: []const u8 = "", + entries: []const Entry, + + pub const Entry = struct { + /// Either ASN1_subjectPublicKeyInfo or cert_data based on CertificateType + data: []const u8, + extensions: []const Extension = &.{}, + + pub fn len(self: @This(), is_client: bool) usize { + return 3 + + self.data.len + + @sizeOf(u16) + slice_len(Extension, self.extensions, is_client); + } + + pub fn write(self: @This(), stream: anytype) !void { + try stream.write(u24, @intCast(self.data.len)); + try stream.writeAll(self.data); + try stream.writeArray(2, Extension, self.extensions); + } + }; + + const Self = @This(); + + pub fn write(self: Self, stream: anytype) !void { + std.debug.assert(!stream.is_client); + try stream.write(HandshakeType, .certificate); + + const entries_len = slice_len(Entry, self.entries, stream.is_client); + std.debug.print("entries_len {d}\n", .{ entries_len }); + + const length: usize = @sizeOf(u8) + 3 + entries_len; + try stream.write(u24, @intCast(length)); + + // TODO: handle this being sent in response to certificate request + std.debug.assert(self.context.len == 0); + try stream.write(u8, 0); + + try stream.write(u24, @intCast(entries_len)); + for (self.entries) |e| try stream.write(Entry, e); + } +}; + +pub const CertificateVerify = struct { + algorithm: SignatureScheme, + signature: []const u8, +}; + +pub const Finished = struct { + verify_data: []const u8, +}; + +pub const KeyUpdate = struct { + request: Request, + + pub const Request = enum(u8) { + update_not_requested = 0, + update_requested = 1, + _, + }; +}; + +// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml pub const ExtensionType = enum(u16) { /// RFC 6066 server_name = 0, @@ -90,8 +167,10 @@ pub const ExtensionType = enum(u16) { max_fragment_length = 1, /// RFC 6066 status_request = 5, - /// RFC 8422, 7919 + /// RFC 8422, 7919. renamed from "elliptic_curves" supported_groups = 10, + /// RFC 8422 S5.1.2 + ec_point_formats = 11, /// RFC 8446 signature_algorithms = 13, /// RFC 5764 @@ -108,6 +187,12 @@ pub const ExtensionType = enum(u16) { server_certificate_type = 20, /// RFC 7685 padding = 21, + /// RFC7366 + encrypt_then_mac = 22, + /// RFC 7627 + extended_master_secret = 23, + /// RFC 5077 + session_ticket = 35, /// RFC 8446 pre_shared_key = 41, /// RFC 8446 @@ -128,109 +213,127 @@ pub const ExtensionType = enum(u16) { signature_algorithms_cert = 50, /// RFC 8446 key_share = 51, - + /// Reserved for private use. + none = 65280, _, }; -pub const AlertLevel = enum(u8) { - warning = 1, - fatal = 2, - _, -}; +pub const Alert = packed struct { + /// > In TLS 1.3, the severity is implicit in the type of alert being sent + /// > and the "level" field can safely be ignored. + level: Level, + description: Description, -pub const AlertDescription = enum(u8) { - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, }; + pub const Description = enum(u8) { + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, - pub fn toError(alert: AlertDescription) Error!void { - return switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => error.TlsAlertUnexpectedMessage, - .bad_record_mac => error.TlsAlertBadRecordMac, - .record_overflow => error.TlsAlertRecordOverflow, - .handshake_failure => error.TlsAlertHandshakeFailure, - .bad_certificate => error.TlsAlertBadCertificate, - .unsupported_certificate => error.TlsAlertUnsupportedCertificate, - .certificate_revoked => error.TlsAlertCertificateRevoked, - .certificate_expired => error.TlsAlertCertificateExpired, - .certificate_unknown => error.TlsAlertCertificateUnknown, - .illegal_parameter => error.TlsAlertIllegalParameter, - .unknown_ca => error.TlsAlertUnknownCa, - .access_denied => error.TlsAlertAccessDenied, - .decode_error => error.TlsAlertDecodeError, - .decrypt_error => error.TlsAlertDecryptError, - .protocol_version => error.TlsAlertProtocolVersion, - .insufficient_security => error.TlsAlertInsufficientSecurity, - .internal_error => error.TlsAlertInternalError, - .inappropriate_fallback => error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => error.TlsAlertMissingExtension, - .unsupported_extension => error.TlsAlertUnsupportedExtension, - .unrecognized_name => error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, - .certificate_required => error.TlsAlertCertificateRequired, - .no_application_protocol => error.TlsAlertNoApplicationProtocol, - _ => error.TlsAlertUnknown, - }; + pub fn toError(alert: @This()) Error { + return switch (alert) { + .close_notify => {}, // not an error + .unexpected_message => error.TlsAlertUnexpectedMessage, + .bad_record_mac => error.TlsAlertBadRecordMac, + .record_overflow => error.TlsAlertRecordOverflow, + .handshake_failure => error.TlsAlertHandshakeFailure, + .bad_certificate => error.TlsAlertBadCertificate, + .unsupported_certificate => error.TlsAlertUnsupportedCertificate, + .certificate_revoked => error.TlsAlertCertificateRevoked, + .certificate_expired => error.TlsAlertCertificateExpired, + .certificate_unknown => error.TlsAlertCertificateUnknown, + .illegal_parameter => error.TlsAlertIllegalParameter, + .unknown_ca => error.TlsAlertUnknownCa, + .access_denied => error.TlsAlertAccessDenied, + .decode_error => error.TlsAlertDecodeError, + .decrypt_error => error.TlsAlertDecryptError, + .protocol_version => error.TlsAlertProtocolVersion, + .insufficient_security => error.TlsAlertInsufficientSecurity, + .internal_error => error.TlsAlertInternalError, + .inappropriate_fallback => error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => error.TlsAlertMissingExtension, + .unsupported_extension => error.TlsAlertUnsupportedExtension, + .unrecognized_name => error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, + .certificate_required => error.TlsAlertCertificateRequired, + .no_application_protocol => error.TlsAlertNoApplicationProtocol, + _ => error.TlsAlertUnknown, + }; + } + }; + + pub const len = @sizeOf(Level) + @sizeOf(Description); + + pub fn write(self: @This(), stream: anytype) !void { + try stream.write(Level, self.level); + try stream.write(Description, self.description); } }; +/// Scheme for certificate verification +/// +/// Note: This enum is named `SignatureScheme` because there is already a +/// `SignatureAlgorithm` type in TLS 1.2, which this replaces. pub const SignatureScheme = enum(u16) { // RSASSA-PKCS1-v1_5 algorithms rsa_pkcs1_sha256 = 0x0401, @@ -261,8 +364,41 @@ pub const SignatureScheme = enum(u16) { ecdsa_sha1 = 0x0203, _, + + fn Ecdsa(comptime self: @This()) type { + return switch (self) { + .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, + .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, + else => @compileError("bad scheme"), + }; + } + + fn Hash(comptime self: @This()) type { + return switch (self) { + .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, + .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, + .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, + else => @compileError("bad scheme"), + }; + } + + fn Eddsa(comptime self: @This()) type { + return switch (self) { + .ed25519 => crypto.sign.Ed25519, + else => @compileError("bad scheme"), + }; + } +}; +pub const supported_signature_schemes = [_]SignatureScheme{ + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .ed25519, }; +/// Key exchange formats pub const NamedGroup = enum(u16) { // Elliptic Curve Groups (ECDHE) secp256r1 = 0x0017, @@ -278,40 +414,1092 @@ pub const NamedGroup = enum(u16) { ffdhe6144 = 0x0103, ffdhe8192 = 0x0104, - // Hybrid post-quantum key agreements - x25519_kyber512d00 = 0xFE30, + // Hybrid post-quantum key agreements. Still in draft. x25519_kyber768d00 = 0x6399, _, }; +pub fn NamedGroupT(comptime named_group: NamedGroup) type { + return switch (named_group) { + .secp256r1 => crypto.sign.ecdsa.EcdsaP256Sha256, + .secp384r1 => crypto.sign.ecdsa.EcdsaP384Sha384, + .x25519 => crypto.dh.X25519, + .x25519_kyber768d00 => X25519Kyber768Draft, + else => |t| @compileError("unsupported named group " ++ @tagName(t)), + }; +} +// Hybrid share, see https://www.ietf.org/archive/id/draft-ietf-tls-hybrid-design-05.html +pub const X25519Kyber768Draft = struct { + pub const X25519 = NamedGroupT(.x25519); + pub const Kyber768 = crypto.kem.kyber_d00.Kyber768; + pub const KeyPair = struct { + x25519: X25519.KeyPair, + kyber768d00: Kyber768.KeyPair, + }; + pub const PublicKey = struct { + x25519: X25519.PublicKey, + kyber768d00: Kyber768.PublicKey, + + pub const bytes_length = X25519.public_length + Kyber768.PublicKey.bytes_length; + + pub fn toBytes(self: @This()) [bytes_length]u8 { + return self.x25519 ++ self.kyber768d00.toBytes(); + } + + pub fn ciphertext(self: @This()) [X25519.public_length + Kyber768.ciphertext_length]u8 { + return self.x25519 ++ self.kyber768d00.encaps(null).ciphertext; + } + }; +}; +pub const KeyPair = union(NamedGroup) { + secp256r1: NamedGroupT(.secp256r1).KeyPair, + secp384r1: NamedGroupT(.secp384r1).KeyPair, + secp521r1: void, + x25519: NamedGroupT(.x25519).KeyPair, + x448: void, + + ffdhe2048: void, + ffdhe3072: void, + ffdhe4096: void, + ffdhe6144: void, + ffdhe8192: void, + + x25519_kyber768d00: NamedGroupT(.x25519_kyber768d00).KeyPair, + + pub fn toKeyShare(self: @This()) KeyShare { + return switch (self) { + .x25519_kyber768d00 => |k| .{ .x25519_kyber768d00 = X25519Kyber768Draft.PublicKey{ + .x25519 = k.x25519.public_key, + .kyber768d00 = k.kyber768d00.public_key, + } }, + .secp256r1 => |k| .{ .secp256r1 = k.public_key }, + .secp384r1 => |k| .{ .secp384r1 = k.public_key }, + .x25519 => |k| .{ .x25519 = k.public_key }, + inline else => |_, t| @unionInit(KeyShare, @tagName(t), {}), + }; + } +}; +/// The public key portion of a KeyPair. Saves bytes. +pub const KeyShare = union(NamedGroup) { + secp256r1: NamedGroupT(.secp256r1).PublicKey, + secp384r1: NamedGroupT(.secp384r1).PublicKey, + secp521r1: void, + x25519: NamedGroupT(.x25519).PublicKey, + x448: void, + + ffdhe2048: void, + ffdhe3072: void, + ffdhe4096: void, + ffdhe6144: void, + ffdhe8192: void, + + x25519_kyber768d00: NamedGroupT(.x25519_kyber768d00).PublicKey, + + pub const max_len = NamedGroupT(.x25519_kyber768d00).PublicKey.bytes_length; + + const Self = @This(); + + pub fn write(self: Self, stream: anytype) !void { + try stream.write(NamedGroup, self); + const public = switch (self) { + .x25519_kyber768d00 => |k| if (stream.is_client) &k.toBytes() else &k.ciphertext(), + .secp256r1 => |k| &k.toUncompressedSec1(), + .secp384r1 => |k| &k.toUncompressedSec1(), + .x25519 => |k| &k, + else => "", + }; + try stream.writeArray(2, u8, public); + } + + pub fn keyLen(self: Self, is_client: bool) usize { + return switch (self) { + .x25519_kyber768d00 => |k| if (is_client) @TypeOf(k).bytes_length else X25519Kyber768Draft.Kyber768.ciphertext_length, + .secp256r1 => |k| @TypeOf(k).uncompressed_sec1_encoded_length, + .secp384r1 => |k| @TypeOf(k).uncompressed_sec1_encoded_length, + .x25519 => |k| k.len, + else => 0, + }; + } + + pub fn len(self: Self, is_client: bool) usize { + return @sizeOf(NamedGroup) + @sizeOf(u16) + self.keyLen(is_client); + } + + pub const Header = struct { + group: NamedGroup, + len: u16, + + pub fn read(stream: anytype) !@This() { + const group = try stream.read(NamedGroup); + const length = try stream.read(u16); + return .{ .group = group, .len = length }; + } + }; +}; +/// In descending order of preference +pub const supported_groups = [_]NamedGroup{ + .x25519_kyber768d00, + .secp256r1, + .secp384r1, + .x25519, +}; pub const CipherSuite = enum(u16) { - AES_128_GCM_SHA256 = 0x1301, - AES_256_GCM_SHA384 = 0x1302, - CHACHA20_POLY1305_SHA256 = 0x1303, - AES_128_CCM_SHA256 = 0x1304, - AES_128_CCM_8_SHA256 = 0x1305, - AEGIS_256_SHA512 = 0x1306, - AEGIS_128L_SHA256 = 0x1307, + aes_128_gcm_sha256 = 0x1301, + aes_256_gcm_sha384 = 0x1302, + chacha20_poly1305_sha256 = 0x1303, + aegis_256_sha512 = 0x1306, + aegis_128l_sha256 = 0x1307, + empty_renegotiation_info_scsv = 0x00ff, _, + + pub fn Hash(comptime self: @This()) type { + return switch (self) { + .aes_128_gcm_sha256 => crypto.hash.sha2.Sha256, + .aes_256_gcm_sha384 => crypto.hash.sha2.Sha384, + .chacha20_poly1305_sha256 => crypto.hash.sha2.Sha256, + .aegis_256_sha512 => crypto.hash.sha2.Sha512, + .aegis_128l_sha256 => crypto.hash.sha2.Sha256, + else => @compileError("unknown suite " ++ @tagName(self)), + }; + } + + pub fn Aead(comptime self: @This()) type { + return switch (self) { + .aes_128_gcm_sha256 => crypto.aead.aes_gcm.Aes128Gcm, + .aes_256_gcm_sha384 => crypto.aead.aes_gcm.Aes256Gcm, + .chacha20_poly1305_sha256 => crypto.aead.chacha_poly.ChaCha20Poly1305, + .aegis_256_sha512 => crypto.aead.aegis.Aegis256, + .aegis_128l_sha256 => crypto.aead.aegis.Aegis128L, + else => @compileError("unknown suite " ++ @tagName(self)), + }; + } +}; + +pub const HandshakeCipher = union(CipherSuite) { + aes_128_gcm_sha256: HandshakeCipherT(.aes_128_gcm_sha256), + aes_256_gcm_sha384: HandshakeCipherT(.aes_256_gcm_sha384), + chacha20_poly1305_sha256: HandshakeCipherT(.chacha20_poly1305_sha256), + aegis_256_sha512: HandshakeCipherT(.aegis_256_sha512), + aegis_128l_sha256: HandshakeCipherT(.aegis_128l_sha256), + empty_renegotiation_info_scsv: void, + + const Self = @This(); + + pub fn init(suite: CipherSuite, shared_key: []const u8, hello_hash: []const u8) Self { + debugPrint("hello_hash", hello_hash); + debugPrint("shared_key", shared_key); + switch (suite) { + else => unreachable, + inline .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, + => |tag| { + var res = @unionInit(Self, @tagName(tag), .{ + .handshake_secret = undefined, + .master_secret = undefined, + .client_handshake_key = undefined, + .server_handshake_key = undefined, + .client_finished_key = undefined, + .server_finished_key = undefined, + .client_handshake_iv = undefined, + .server_handshake_iv = undefined, + }); + const P = std.meta.TagPayloadByName(Self, @tagName(tag)); + const p = &@field(res, @tagName(tag)); + + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = emptyHash(P.Hash); + + const derived_secret = hkdfExpandLabel( + P.Hkdf, + early_secret, + "derived", + &empty_hash, + P.Hash.digest_length, + ); + p.handshake_secret = P.Hkdf.extract(&derived_secret, shared_key); + const ap_derived_secret = hkdfExpandLabel( + P.Hkdf, + p.handshake_secret, + "derived", + &empty_hash, + P.Hash.digest_length, + ); + p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); + const client_secret = hkdfExpandLabel( + P.Hkdf, + p.handshake_secret, + "c hs traffic", + hello_hash, + P.Hash.digest_length, + ); + const server_secret = hkdfExpandLabel( + P.Hkdf, + p.handshake_secret, + "s hs traffic", + hello_hash, + P.Hash.digest_length, + ); + p.client_finished_key = hkdfExpandLabel( + P.Hkdf, + client_secret, + "finished", + "", + P.Hmac.key_length, + ); + p.server_finished_key = hkdfExpandLabel( + P.Hkdf, + server_secret, + "finished", + "", + P.Hmac.key_length, + ); + p.client_handshake_key = hkdfExpandLabel( + P.Hkdf, + client_secret, + "key", + "", + P.AEAD.key_length, + ); + p.server_handshake_key = hkdfExpandLabel( + P.Hkdf, + server_secret, + "key", + "", + P.AEAD.key_length, + ); + p.client_handshake_iv = hkdfExpandLabel( + P.Hkdf, + client_secret, + "iv", + "", + P.AEAD.nonce_length, + ); + p.server_handshake_iv = hkdfExpandLabel( + P.Hkdf, + server_secret, + "iv", + "", + P.AEAD.nonce_length, + ); + + return res; + }, + } + } + + pub fn print(self: Self) void { + switch (self) { + .empty_renegotiation_info_scsv => {}, + inline else => |v| v.print(), + } + } +}; + +pub const ApplicationCipher = union(CipherSuite) { + aes_128_gcm_sha256: ApplicationCipherT(.aes_128_gcm_sha256), + aes_256_gcm_sha384: ApplicationCipherT(.aes_256_gcm_sha384), + chacha20_poly1305_sha256: ApplicationCipherT(.chacha20_poly1305_sha256), + aegis_256_sha512: ApplicationCipherT(.aegis_256_sha512), + aegis_128l_sha256: ApplicationCipherT(.aegis_128l_sha256), + empty_renegotiation_info_scsv: void, +}; + +/// RFC 8446 S4.1.2 +pub const ClientHello = struct { + /// Legacy field for TLS 1.2 middleboxes + version: Version = .tls_1_2, + random: [32]u8, + /// Legacy session resumption + session_id: []const u8, + /// In descending order of preference + cipher_suites: []const CipherSuite, + // Legacy and unsecure, requires at least 1 for compat, MUST error if anything else + compression_methods: [1]u8 = .{0}, + // Certain extensions are mandatory for TLS 1.3 + extensions: []const Extension, + + const Self = @This(); + + pub fn write(self: Self, stream: anytype) !void { + try stream.write(HandshakeType, .client_hello); + + var length = @sizeOf(Version) + + self.random.len + + @sizeOf(u8) + self.session_id.len + + @sizeOf(u16) + self.cipher_suites.len * @sizeOf(CipherSuite) + + @sizeOf(u8) + self.compression_methods.len; + length += @sizeOf(u16); + for (self.extensions) |e| length += e.len(true); + try stream.write(u24, @intCast(length)); + + try stream.write(Version, self.version); + try stream.writeAll(&self.random); + try stream.writeArray(1, u8, self.session_id); + try stream.writeArray(2, CipherSuite, self.cipher_suites); + try stream.writeArray(1, u8, &self.compression_methods); + try stream.writeArray(2, Extension, self.extensions); + } +}; + +pub const ServerHello = struct { + /// Legacy field for TLS 1.2 middleboxes + version: Version = .tls_1_2, + /// Should be an echo of the sent `client_random`. + random: [32]u8, + /// Legacy session resumption + session_id: []const u8, + cipher_suite: CipherSuite, + compression_method: u8 = 0, + /// Certain extensions are mandatory for TLS 1.3 + extensions: []const Extension, + + /// When `random` equals this it means the client should resend the `ClientHello`. + pub const hello_retry_request = [32]u8{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, + }; + + const Self = @This(); + + pub fn len(self: Self, _: bool) usize { + var res = @sizeOf(Version) + + self.random.len + + @sizeOf(u8) + self.session_id.len + + @sizeOf(CipherSuite) + + @sizeOf(u8); + res += @sizeOf(u16); + for (self.extensions) |e| res += e.len(false); + return res; + } + + pub fn write(self: Self, stream: anytype) !void { + try stream.write(HandshakeType, .server_hello); + const length = self.len(false); + try stream.write(u24, @intCast(length)); + try stream.write(Version, self.version); + try stream.writeAll(&self.random); + try stream.writeArray(1, u8, self.session_id); + try stream.write(CipherSuite, self.cipher_suite); + try stream.write(u8, self.compression_method); + try stream.writeArray(2, Extension, self.extensions); + } +}; + +pub const EncryptedExtensions = struct { + extensions: []const Extension, + + const Self = @This(); + + pub fn len(self: Self, _: bool) usize { + var res: usize = @sizeOf(HandshakeType) + @sizeOf(u24); + for (self.extensions) |e| res += e.len(false); + return res; + } + + pub fn write(self: Self, stream: anytype) !void { + try stream.write(HandshakeType, .encrypted_extensions); + const ext_len = @sizeOf(u16) + slice_len(Extension, self.extensions, false); + try stream.write(u24, @intCast(ext_len)); + try stream.writeArray(2, Extension, self.extensions); + } +}; + +pub const Extension = union(ExtensionType) { + // MUST NOT contain more than one name of the same name_type + server_name: []const ServerName, + max_fragment_length: void, + status_request: void, + supported_groups: []const NamedGroup, + ec_point_formats: []const EcPointFormat, + signature_algorithms: []const SignatureScheme, + use_srtp: void, + /// https://en.wikipedia.org/wiki/Heartbleed + heartbeat: void, + application_layer_protocol_negotiation: void, + signed_certificate_timestamp: void, + client_certificate_type: void, + server_certificate_type: void, + padding: void, + encrypt_then_mac: void, + extended_master_secret: void, + session_ticket: void, + pre_shared_key: void, + early_data: void, + supported_versions: []const Version, + cookie: void, + psk_key_exchange_modes: []const PskKeyExchangeMode, + certificate_authorities: void, + oid_filters: void, + post_handshake_auth: void, + signature_algorithms_cert: void, + key_share: []const KeyShare, + none: void, + + const Self = @This(); + + pub fn len(self: Self, is_client: bool) usize { + var res: usize = @sizeOf(ExtensionType) + @sizeOf(u16); + switch (self) { + inline else => |items| { + const T = @TypeOf(items); + if (T == void) return res; + if (is_client) { + res += switch (self) { + .supported_versions, .ec_point_formats, .psk_key_exchange_modes => 1, + .server_name, .supported_groups, .signature_algorithms, .key_share => 2, + else => 0, + }; + } + res += slice_len(@typeInfo(T).Pointer.child, items, is_client); + }, + } + return res; + } + + pub fn write(self: Self, stream: anytype) !void { + const prefix_len = stream.is_client; + switch (self) { + inline else => |items, tag| { + const T = @TypeOf(items); + try stream.write(ExtensionType, tag); + const length = self.len(prefix_len) - @sizeOf(ExtensionType) - @sizeOf(u16); + try stream.write(u16, @intCast(length)); + if (T == void) return; + }, + } + switch (self) { + inline .supported_versions, + .ec_point_formats, + .psk_key_exchange_modes, + => |items| { + const T = @typeInfo(@TypeOf(items)).Pointer.child; + if (prefix_len) { + try stream.writeArray(1, T, items); + } else { + try stream.writeArray(0, T, items); + } + }, + inline .server_name, + .supported_groups, + .signature_algorithms, + .key_share, + => |items| { + const T = @typeInfo(@TypeOf(items)).Pointer.child; + if (prefix_len) { + try stream.writeArray(2, T, items); + } else { + try stream.writeArray(0, T, items); + } + }, + inline else => |_, tag| { + @panic("unsupported extension " ++ @tagName(tag)); + }, + } + } + + pub const Header = struct { + type: ExtensionType, + len: u16, + + pub fn read(stream: anytype) !@This() { + const ty = try stream.read(ExtensionType); + const length = try stream.read(u16); + return .{ .type = ty, .len = length }; + } + }; }; -pub const CertificateType = enum(u8) { - X509 = 0, - RawPublicKey = 2, +/// RFC 8446 S4.2.9 +pub const PskKeyExchangeMode = enum(u8) { + /// PSK-only key establishment. In this mode, the server + /// MUST NOT supply a "key_share" value. + ke = 1, + /// PSK with (EC)DHE key establishment. In this mode, the + /// client and server MUST supply "key_share" values as described in + /// Section 4.2.8. + dhe_ke = 2, _, }; -pub const KeyUpdateRequest = enum(u8) { - update_not_requested = 0, - update_requested = 1, +pub const UncompressedPointRepresentation = struct { + form: u8 = 4, // uncompressed + x: []const u8, + y: []const u8, +}; + +/// RFC 8446 S4.1.3 +pub const ServerName = struct { + type: NameType = .host_name, + host_name: []const u8, + + pub const NameType = enum(u8) { host_name = 0, _ }; + + pub fn len(self: @This(), _: bool) usize { + return @sizeOf(NameType) + 2 + self.host_name.len; + } + + pub fn write(self: @This(), stream: anytype) !void { + try stream.write(NameType, self.type); + try stream.writeArray(2, u8, self.host_name); + } +}; + +pub const EcPointFormat = enum(u8) { + uncompressed = 0, + ansiX962_compressed_prime = 1, + ansiX962_compressed_char2 = 2, _, }; -pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type { +/// RFC 5246 S7.1 +pub const ChangeCipherSpec = enum(u8) { change_cipher_spec = 1, _ }; + +/// This is an example of the type that is needed by the read and write +/// functions. It can have any fields but it must at least have these +/// functions. +/// +/// Note that `std.net.Stream` conforms to this interface. +/// +/// This declaration serves as documentation only. +pub const StreamInterface = struct { + /// Can be any error set. + pub const ReadError = error{}; + + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. Reaching the end of + /// a stream is not an error condition. + pub fn readAll(this: @This(), buffer: []u8) ReadError!usize { + _ = .{ this, buffer }; + @panic("unimplemented"); + } + + /// Can be any error set. + pub const WriteError = error{}; + + /// Returns the number of bytes written, which may be less than the buffer space provided. + pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize { + _ = .{ this, iovecs }; + @panic("unimplemented"); + } + + /// The `iovecs` parameter is mutable in case this function needs to mutate + /// the fields in order to handle partial writes from the underlying layer. + pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!void { + // This can be implemented in terms of writev, or specialized if desired. + _ = .{ this, iovecs }; + @panic("unimplemented"); + } +}; + +/// Abstraction over TLS record layer that fragments all messages. +/// It also encrypts and decrypts .application_data messages. +/// This makes it suitable for both clients and servers. +/// +/// StreamType MUST satisfy `StreamInterface`. +/// StreamType MUST satisfy: +/// * fn hash(self: @This(), bytes: []const u8) void +/// * fn peek(self: @This()) [_]u8 +/// Cannot read and write at the same time. +pub fn Stream( + comptime fragment_size: usize, + comptime StreamType: type, + comptime TranscriptHash: type, +) type { + // TODO: Support RFC 6066 MaxFragmentLength and give fragment_size option to Client+Server. + if (fragment_size > std.math.maxInt(u16)) @compileError("choose a smaller fragment_size"); + return struct { - pub const AEAD = AeadType; - pub const Hash = HashType; + stream: *StreamType, + /// > For concreteness, the transcript hash is always taken from the + /// > following sequence of handshake messages, starting at the first + /// > ClientHello and including only those messages that were sent: + /// > ClientHello, HelloRetryRequest, ClientHello, ServerHello, + /// > EncryptedExtensions, server CertificateRequest, server Certificate, + /// > server CertificateVerify, server Finished, EndOfEarlyData, client + /// > Certificate, client CertificateVerify, client Finished. + transcript_hash: TranscriptHash, + /// Used for both reading and writing. Cannot be doing both at the same time. Must be twice + /// fragment size to handle `readAll(fragment_size)`. In practice this is only approachable for + /// the SNI hostname which may be up to 8KB. + buffer: [fragment_size * 2]u8 = undefined, + /// Unflushed part of `buffer`. + view: []const u8 = "", + + /// When sending this is the record type that will be sent. + /// If a cipher is in use it will be encrypted in `inner_content_type`. + content_type: ContentType = .handshake, + + /// When receiving fragments this is the next expected fragment type. + handshake_type: ?HandshakeType = null, + + /// Used to encrypt and decrypt .application_data messages until application_cipher is not null. + handshake_cipher: ?HandshakeCipher = null, + /// Used to encrypt and decrypt .application_data messages. + application_cipher: ?ApplicationCipher = null, + + /// True when we send or receive `close_notify` + closed: bool = false, + + /// Version to send out in record headers + version: Version = .tls_1_0, + + /// True if we're being used as a client. Certain shared struct types serialize differently + /// based on this. + is_client: bool, + + const Self = @This(); + + pub const ReadError = StreamType.ReadError || error{ + EndOfStream, + InsufficientEntropy, + DiskQuota, + LockViolation, + NotOpenForWriting, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsDecryptFailure, + TlsRecordOverflow, + TlsBadRecordMac, + CertificateFieldHasInvalidLength, + CertificateHostMismatch, + CertificatePublicKeyInvalid, + CertificateExpired, + CertificateFieldHasWrongDataType, + CertificateIssuerMismatch, + CertificateNotYetValid, + CertificateSignatureAlgorithmMismatch, + CertificateSignatureAlgorithmUnsupported, + CertificateSignatureInvalid, + CertificateSignatureInvalidLength, + CertificateSignatureNamedCurveUnsupported, + CertificateSignatureUnsupportedBitCount, + TlsCertificateNotVerified, + TlsBadSignatureScheme, + TlsBadRsaSignatureBitCount, + InvalidEncoding, + IdentityElement, + SignatureVerificationFailed, + TlsDecryptError, + TlsConnectionTruncated, + TlsDecodeError, + UnsupportedCertificateVersion, + CertificateTimeInvalid, + CertificateHasUnrecognizedObjectId, + CertificateHasInvalidBitString, + MessageTooLong, + NegativeIntoUnsigned, + TargetTooSmall, + BufferTooSmall, + InvalidSignature, + NotSquare, + NonCanonical, + WeakPublicKey, + }; + pub const WriteError = StreamType.WriteError; + // pub const Reader = std.io.Reader(*Self, Error, read); + // pub const Writer = std.io.Writer(*Self, WriteError, write); + + fn ciphertextOverhead(self: Self) usize { + if (self.application_cipher) |a| { + switch (a) { + .empty_renegotiation_info_scsv => {}, + inline else => |c| return @TypeOf(c).AEAD.tag_length + @sizeOf(ContentType), + } + } + if (self.handshake_cipher) |a| { + switch (a) { + .empty_renegotiation_info_scsv => {}, + inline else => |c| return @TypeOf(c).AEAD.tag_length + @sizeOf(ContentType), + } + } + return 0; + } + + fn maxFragmentSize(self: Self) usize { + return fragment_size - self.ciphertextOverhead(); + } + + pub fn flush(self: *Self) WriteError!void { + const aead_overhead = self.ciphertextOverhead(); + const plaintext = Plaintext{ + .type = if (aead_overhead > 0) .application_data else self.content_type, + .version = self.version, + .length = @intCast(self.view.len + aead_overhead), + }; + const header = Encoder.encode(Plaintext, plaintext); + + var aead: []const u8 = ""; + if (self.application_cipher) |*a| { + switch (a.*) { + .empty_renegotiation_info_scsv => {}, + inline else => |*c| { + std.debug.assert(self.view.ptr == &self.buffer); + self.buffer[self.view.len] = @intFromEnum(self.content_type); + self.view = self.buffer[0..self.view.len + 1]; + aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); + }, + } + } else if (self.handshake_cipher) |*a| { + switch (a.*) { + .empty_renegotiation_info_scsv => {}, + inline else => |*c| { + std.debug.assert(self.view.ptr == &self.buffer); + self.buffer[self.view.len] = @intFromEnum(self.content_type); + self.view = self.buffer[0..self.view.len + 1]; + aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); + }, + } + } else { + switch (plaintext.type) { + .change_cipher_spec, .alert => {}, + else => self.transcript_hash.update(self.view), + } + } + + var iovecs = [_]std.os.iovec_const{ + .{ .iov_base = &header, .iov_len = header.len }, + .{ .iov_base = self.view.ptr, .iov_len = self.view.len }, + .{ .iov_base = aead.ptr, .iov_len = aead.len }, + }; + try self.stream.writevAll(&iovecs); + self.view = self.buffer[0..0]; + } + + pub fn writeAll(self: *Self, bytes: []const u8) WriteError!void { + if (bytes.len > self.maxFragmentSize()) { + @panic("cannot flush bytes because have to append 1 byte of record type if encrypted"); + // try self.flush(); + // self.view = bytes; + // try self.flush(); + } else { + if (self.view.len + bytes.len >= self.maxFragmentSize()) { + // TODO: optimization copy bytes before flushing. + try self.flush(); + } + @memcpy(self.buffer[self.view.len..][0..bytes.len], bytes); + self.view = self.buffer[0 .. self.view.len + bytes.len]; + } + } + + pub fn write(self: *Self, comptime T: type, value: T) WriteError!void { + switch (@typeInfo(T)) { + .Int, .Enum => try self.writeAll(&Encoder.encode(T, value)), + else => { + if (@hasDecl(T, "write")) { + try value.write(self); + } else { + @compileError("expected fn write(stream: anytype) on type " ++ @typeName(T)); + } + }, + } + } + + pub fn writeArray(self: *Self, len_bytes: comptime_int, comptime T: type, values: []const T) !void { + if (len_bytes != 0) { + const len = slice_len(T, values, self.is_client); + switch (len_bytes) { + 1 => try self.write(u8, @intCast(len)), + 2 => try self.write(u16, @intCast(len)), + 3 => try self.write(u24, @intCast(len)), + else => @compileError("unsupported prefix len"), + } + } + for (values) |v| try self.write(T, v); + } + + /// Returns slice that is valid until next `readAll` call. + pub fn readAll(self: *Self, len: usize) ReadError![]const u8 { + if (len >= self.maxFragmentSize()) { + // Only workaround is to use an allocator for self.buffer + return error.TlsRecordOverflow; + } else { + if (len <= self.view.len) { + defer self.view = self.view[len..]; + return self.view[0..len]; + } else { + // We need another fragment. + // Copy last (hopefully small) portion of buffer to start. It may alias. + std.mem.copyForwards(u8, &self.buffer, self.view); + self.view = self.buffer[0..self.view.len]; + try self.readFragment(self.handshake_type); + return try self.readAll(len); + } + } + } + + /// Read plaintext fragment from `self.stream` into `self.buffer`. Checks message has correct + /// `content_type`. + pub fn readFragment(self: *Self, handshake_type: ?HandshakeType) ReadError!void { + self.handshake_type = handshake_type; + var plaintext_header: [fieldsLen(Plaintext)]u8 = undefined; + var n_read: usize = 0; + + var ty: ContentType = .invalid; + var length: u16 = 0; + + while (true) { + n_read = try self.stream.readAll(&plaintext_header); + if (n_read != plaintext_header.len) return error.TlsConnectionTruncated; + self.view = &plaintext_header; + ty = try self.read(ContentType); + _ = try self.read(Version); + length = try self.read(u16); + + switch (ty) { + .alert => { + const level = try self.read(Alert.Level); + const description = try self.read(Alert.Description); + std.log.debug("TLS alert {} {}", .{ level, description }); + + return error.TlsUnexpectedMessage; + }, + // An implementation may receive an unencrypted record of type + // change_cipher_spec consisting of the single byte value 0x01 at any + // time after the first ClientHello message has been sent or received + // and before the peer's Finished message has been received and MUST + // simply drop it without further processing. + .change_cipher_spec => { + if (self.application_cipher != null) return error.TlsUnexpectedMessage; + var next_byte: [1]u8 = undefined; + n_read = try self.stream.readAll(&next_byte); + if (length != 1 or n_read != 1 or next_byte[0] != 1) return error.TlsIllegalParameter; + }, + else => break, + } + } + if (ty != self.content_type) return error.TlsDecodeError; + + if (length > self.maxFragmentSize()) return error.TlsRecordOverflow; + if (self.view.len > self.maxFragmentSize()) return error.TlsDecodeError; // Should have read more before calling readFragment again. + + const dest = self.buffer[self.view.len..][0..length]; + n_read = try self.stream.readAll(dest); + if (n_read != length) return error.TlsConnectionTruncated; + + self.view = self.buffer[0 .. self.view.len + length]; + + if (ty == .application_data and self.handshake_cipher != null) { + switch (self.handshake_cipher.?) { + .empty_renegotiation_info_scsv => {}, + inline else => |*p| { + const P = @TypeOf(p.*); + const tag_len = P.AEAD.tag_length; + + const ciphertext = self.view[0..self.view.len - tag_len]; + debugPrint("ciphertext", ciphertext); + const tag = self.view[self.view.len - tag_len..][0..tag_len].*; + debugPrint("tag", tag); + + try p.decrypt( + ciphertext, + &plaintext_header, + tag, + self.is_client, + self.buffer[0..ciphertext.len], + ); + self.view = self.buffer[0..ciphertext.len]; + }, + } + + self.transcript_hash.update(self.view); + + const handshake_ty = try self.read(HandshakeType); + if (handshake_type == null or handshake_ty != handshake_type) return error.TlsDecodeError; + const handshake_len = try self.read(u24); + std.debug.print("handshake_len {d}\n", .{ handshake_len }); + } else { + self.transcript_hash.update(self.view); + if (handshake_type) |expected| { + const actual = try self.read(HandshakeType); + if (actual != expected) return error.TlsDecodeError; + } + } + } + + pub fn read(self: *Self, comptime T: type) ReadError!T { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => { + const byte = try self.readAll(1); + return byte[0]; + }, + 16 => { + const bytes = try self.readAll(2); + const b0: u16 = bytes[0]; + const b1: u16 = bytes[1]; + return (b0 << 8) | b1; + }, + 24 => { + const bytes = try self.readAll(3); + const b0: u24 = bytes[0]; + const b1: u24 = bytes[1]; + const b2: u24 = bytes[2]; + return (b0 << 16) | (b1 << 8) | b2; + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + const int = try self.read(info.tag_type); + return @enumFromInt(int); + }, + else => { + if (@hasDecl(T, "read")) { + return try T.read(self); + } else { + @compileError("expected fn read(stream: anytype): @This() on type " ++ @typeName(T)); + } + }, + } + } + + /// Read a u8 prefixed array. Valid until next `read`. + pub fn readSmallArray(self: *Self, comptime T: type) ReadError![]align(1) const T { + if (std.math.maxInt(u8) > self.maxFragmentSize()) @panic("increase fragment_size"); + const len = try self.read(u8); + const old_view = self.view; + var bytes = try self.readAll(len); + if (@sizeOf(T) > 1) { + self.view = old_view; + for (0..len / @sizeOf(T)) |i| { + const val_bytes = @constCast(bytes[i * @sizeOf(T) ..][0..@sizeOf(T)]); + var val = try self.read(T); + @memcpy(val_bytes, std.mem.asBytes(&val)); + } + } + return std.mem.bytesAsSlice(T, bytes); + } + + fn Iterator(comptime T: type) type { + return struct { + stream: *Self, + expected_len: usize, + start: usize, + + pub fn next(self: *@This()) !?T { + const cur = @intFromPtr(self.stream.view.ptr) - @intFromPtr(&self.stream.buffer); + const len = cur - self.start; + if (len > self.expected_len) return error.TlsUnexpectedMessage; // overread + if (len == self.expected_len) return null; + + return try self.stream.read(T); + } + }; + } + + pub fn iterator(self: *Self, comptime Tag: type) !Iterator(Tag) { + const expected_len = try self.read(u16); + const start = @intFromPtr(self.view.ptr) - @intFromPtr(&self.buffer); + return Iterator(Tag){ + .stream = self, + .expected_len = expected_len, + .start = start, + }; + } + + pub fn extensions(self: *Self) !Iterator(Extension.Header) { + return self.iterator(Extension.Header); + } + + pub fn eof(self: Self) bool { + return self.closed and self.view.len == 0; + } + }; +} + +fn slice_len(comptime T: type, values: []const T, is_client: bool) usize { + switch (@typeInfo(T)) { + .Int, .Enum => return @sizeOf(T) * values.len, + else => { + var len: usize = 0; + for (values) |v| len += v.len(is_client); + return len; + }, + } +} + +const Encoder = struct { + fn RetType(comptime T: type) type { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => return [1]u8, + 16 => return [2]u8, + 24 => return [3]u8, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return RetType(info.tag_type); + }, + .Struct => |info| { + var len: usize = 0; + inline for (info.fields) |f| len += @typeInfo(RetType(f.type)).Array.len; + return [len]u8; + }, + else => @compileError("don't know how to encode " ++ @tagName(T)), + } + } + fn encode(comptime T: type, value: T) RetType(T) { + return switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => .{value}, + 16 => .{ + @as(u8, @truncate(value >> 8)), + @as(u8, @truncate(value)), + }, + 24 => .{ + @as(u8, @truncate(value >> 16)), + @as(u8, @truncate(value >> 8)), + @as(u8, @truncate(value)), + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| encode(info.tag_type, @intFromEnum(value)), + .Struct => |info| brk: { + const Ret = RetType(T); + + var offset: usize = 0; + var res: Ret = undefined; + inline for (info.fields) |f| { + const encoded = encode(f.type, @field(value, f.name)); + @memcpy(res[offset..][0..encoded.len], &encoded); + offset += encoded.len; + } + + break :brk res; + }, + else => @compileError("cannot encode type " ++ @typeName(T)), + }; + } +}; + +fn nonce_for_len(len: comptime_int, iv: [len]u8, seq: usize) [len]u8 { + if (builtin.zig_backend == .stage2_x86_64 and len > comptime std.simd.suggestVectorLength(u8) orelse 1) { + var res = iv; + const operand = std.mem.readInt(u64, res[res.len - 8 ..], .big); + std.mem.writeInt(u64, res[res.len - 8 ..], operand ^ seq, .big); + return res; + } else { + const V = @Vector(len, u8); + const pad = [1]u8{0} ** (len - 8); + const big = switch (native_endian) { + .big => seq, + .little => @byteSwap(seq), + }; + const operand: V = pad ++ @as([8]u8, @bitCast(big)); + return @as(V, iv) ^ operand; + } +} + +fn HandshakeCipherT(comptime suite: CipherSuite) type { + return struct { + pub const AEAD = suite.Aead(); + pub const Hash = suite.Hash(); pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); @@ -323,22 +1511,51 @@ pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type { server_finished_key: [Hmac.key_length]u8, client_handshake_iv: [AEAD.nonce_length]u8, server_handshake_iv: [AEAD.nonce_length]u8, - transcript_hash: Hash, + seq: usize = 0, + + const Self = @This(); + + pub fn nonce(self: Self) [AEAD.nonce_length]u8 { + return nonce_for_len(AEAD.nonce_length, self.server_handshake_iv, self.seq); + } + + fn encrypt( + self: *Self, + data: []const u8, + additional: []const u8, + is_client: bool, + out: []u8, + ) [AEAD.tag_length]u8 { + var res: [AEAD.tag_length]u8 = undefined; + const key = if (is_client) self.client_handshake_key else self.server_handshake_key; + AEAD.encrypt(out, &res, data, additional, self.nonce(), key); + self.seq += 1; + return res; + } + + fn decrypt( + self: *Self, + data: []const u8, + additional: []const u8, + tag: [AEAD.tag_length]u8, + is_client: bool, + out: []u8, + ) !void { + const key = if (is_client) self.server_handshake_key else self.client_handshake_key; + AEAD.decrypt(out, data, tag, additional, self.nonce(), key) catch return error.TlsBadRecordMac; + self.seq += 1; + } + + pub fn print(self: Self) void { + inline for (std.meta.fields(Self)) |f| debugPrint(f.name, @field(self, f.name)); + } }; } -pub const HandshakeCipher = union(enum) { - AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA512: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512), - AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), -}; - -pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type { +fn ApplicationCipherT(comptime suite: CipherSuite) type { return struct { - pub const AEAD = AeadType; - pub const Hash = HashType; + pub const AEAD = suite.Aead(); + pub const Hash = suite.Hash(); pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); @@ -348,18 +1565,35 @@ pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type server_key: [AEAD.key_length]u8, client_iv: [AEAD.nonce_length]u8, server_iv: [AEAD.nonce_length]u8, + seq: usize = 0, + + const Self = @This(); + + pub fn client_nonce(self: Self) [AEAD.nonce_length]u8 { + return nonce_for_len(AEAD.nonce_length, self.client_iv, self.seq); + } + + pub fn server_nonce(self: Self) [AEAD.nonce_length]u8 { + return nonce_for_len(AEAD.nonce_length, self.server_iv, self.seq); + } + + fn encrypt( + self: *Self, + data: []const u8, + additional: []const u8, + is_client: bool, + out: []u8, + ) [AEAD.tag_length]u8 { + var res: [AEAD.tag_length]u8 = undefined; + const nonce = if (is_client) self.client_nonce() else self.server_nonce(); + const key = if (is_client) self.client_key else self.server_key; + AEAD.encrypt(out, &res, data, additional, nonce, key); + self.seq += 1; + return res; + } }; } -/// Encryption parameters for application traffic. -pub const ApplicationCipher = union(enum) { - AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA512: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512), - AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), -}; - pub fn hkdfExpandLabel( comptime Hkdf: type, key: [Hkdf.prk_length]u8, @@ -399,163 +1633,404 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) return result; } -pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { - return int2(@intFromEnum(et)) ++ array(1, bytes); -} +// Implements `StreamInterface` with a ring buffer +const TestStream = struct { + buffer: Buffer, -pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 { - comptime assert(bytes.len % elem_size == 0); - return int2(bytes.len) ++ bytes; -} + const Buffer = std.RingBuffer; + pub const ReadError = Buffer.Error; + pub const WriteError = Buffer.Error; + const Self = @This(); -pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 { - assert(@sizeOf(E) == 2); - var result: [tags.len * 2]u8 = undefined; - for (tags, 0..) |elem, i| { - result[i * 2] = @as(u8, @truncate(@intFromEnum(elem) >> 8)); - result[i * 2 + 1] = @as(u8, @truncate(@intFromEnum(elem))); + pub fn init(allocator: std.mem.Allocator) !Self { + return Self{ .buffer = try Buffer.init(allocator, Plaintext.max_length) }; } - return array(2, result); -} -pub inline fn int2(x: u16) [2]u8 { - return .{ - @as(u8, @truncate(x >> 8)), - @as(u8, @truncate(x)), - }; -} + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + self.buffer.deinit(allocator); + } -pub inline fn int3(x: u24) [3]u8 { - return .{ - @as(u8, @truncate(x >> 16)), - @as(u8, @truncate(x >> 8)), - @as(u8, @truncate(x)), - }; -} + pub fn readAll(self: *Self, buffer: []u8) ReadError!usize { + try self.buffer.readFirst(buffer, buffer.len); + return buffer.len; + } -/// An abstraction to ensure that protocol-parsing code does not perform an -/// out-of-bounds read. -pub const Decoder = struct { - buf: []u8, - /// Points to the next byte in buffer that will be decoded. - idx: usize = 0, - /// Up to this point in `buf` we have already checked that `cap` is greater than it. - our_end: usize = 0, - /// Beyond this point in `buf` is extra tag-along bytes beyond the amount we - /// requested with `readAtLeast`. - their_end: usize = 0, - /// Points to the end within buffer that has been filled. Beyond this point - /// in buf is undefined bytes. - cap: usize = 0, - /// Debug helper to prevent illegal calls to read functions. - disable_reads: bool = false, - - pub fn fromTheirSlice(buf: []u8) Decoder { - return .{ - .buf = buf, - .their_end = buf.len, - .cap = buf.len, - .disable_reads = true, - }; + pub fn writev(self: *Self, iovecs: []const std.os.iovec_const) WriteError!usize { + var res: usize = 0; + for (iovecs) |i| { + const slice = i.iov_base[0..i.iov_len]; + try self.buffer.writeSlice(slice); + res += i.iov_len; + } + return res; } - /// Use this function to increase `their_end`. - pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void { - assert(!d.disable_reads); - const existing_amt = d.cap - d.idx; - d.their_end = d.idx + their_amt; - if (their_amt <= existing_amt) return; - const request_amt = their_amt - existing_amt; - const dest = d.buf[d.cap..]; - if (request_amt > dest.len) return error.TlsRecordOverflow; - const actual_amt = try stream.readAtLeast(dest, request_amt); - if (actual_amt < request_amt) return error.TlsConnectionTruncated; - d.cap += actual_amt; + pub fn writevAll(self: *Self, iovecs: []std.os.iovec_const) WriteError!void { + _ = try self.writev(iovecs); } - /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`. - /// Use when `our_amt` is calculated by us, not by them. - pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void { - assert(!d.disable_reads); - try readAtLeast(d, stream, our_amt); - d.our_end = d.idx + our_amt; + pub fn peek(self: *Self, out: []u8) ReadError!void { + const read_index = self.buffer.read_index; + _ = try self.readAll(out); + self.buffer.read_index = read_index; } +}; - /// Use this function to increase `our_end`. - /// This should always be called with an amount provided by us, not them. - pub fn ensure(d: *Decoder, amt: usize) !void { - d.our_end = @max(d.idx + amt, d.our_end); - if (d.our_end > d.their_end) return error.TlsDecodeError; +/// Default suites used for client and server in descending order of preference. +/// The order is chosen based on what crypto algorithms Zig has available in +/// the standard library and their speed on x86_64-linux. +/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast +/// aegis-128l: 15382 MiB/s +/// aegis-256: 9553 MiB/s +/// aes128-gcm: 3721 MiB/s +/// aes256-gcm: 3010 MiB/s +/// chacha20Poly1305: 597 MiB/s +/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline +/// aegis-128l: 629 MiB/s +/// chacha20Poly1305: 529 MiB/s +/// aegis-256: 461 MiB/s +/// aes128-gcm: 138 MiB/s +/// aes256-gcm: 120 MiB/s +pub const default_cipher_suites = + if (crypto.core.aes.has_hardware_support) + [_]CipherSuite{ + .aegis_128l_sha256, + .aegis_256_sha512, + .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, } +else + [_]CipherSuite{ + .chacha20_poly1305_sha256, + .aegis_128l_sha256, + .aegis_256_sha512, + .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + }; - /// Use this function to increase `idx`. - pub fn decode(d: *Decoder, comptime T: type) T { - switch (@typeInfo(T)) { - .Int => |info| switch (info.bits) { - 8 => { - skip(d, 1); - return d.buf[d.idx - 1]; - }, - 16 => { - skip(d, 2); - const b0: u16 = d.buf[d.idx - 2]; - const b1: u16 = d.buf[d.idx - 1]; - return (b0 << 8) | b1; - }, - 24 => { - skip(d, 3); - const b0: u24 = d.buf[d.idx - 3]; - const b1: u24 = d.buf[d.idx - 2]; - const b2: u24 = d.buf[d.idx - 1]; - return (b0 << 16) | (b1 << 8) | b2; - }, - else => @compileError("unsupported int type: " ++ @typeName(T)), +const TestHasher = struct { + fn update(self: *@This(), bytes: []const u8) void { + _ = .{ self, bytes }; + } + fn peek(self: @This()) []const u8 { + _ = .{self}; + return ""; + } +}; + +test "tls client and server handshake, data, and close_notify" { + const allocator = std.testing.allocator; + + var inner_stream = try TestStream.init(allocator); + defer inner_stream.deinit(allocator); + + const host = "example.ulfheim.net"; + var client = Client(@TypeOf(inner_stream)){ + .stream = Stream(Plaintext.max_length, TestStream, client_mod.MultiHash){ + .stream = &inner_stream, + .transcript_hash = .{}, + .is_client = true, + }, + .options = .{ .host = host, .ca_bundle = null }, + }; + + const server_cert = [_]u8{ +0x30, 0x82, 0x03, 0x21, 0x30, 0x82, 0x02, 0x09, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x08, 0x15, 0x5a, 0x92, 0xad, 0xc2, 0x04, 0x8f, 0x90, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x30, 0x22, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x30, 0x2b, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x1c, 0x30, 0x1a, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xc4, 0x80, 0x36, 0x06, 0xba, 0xe7, 0x47, 0x6b, 0x08, 0x94, 0x04, 0xec, 0xa7, 0xb6, 0x91, 0x04, 0x3f, 0xf7, 0x92, 0xbc, 0x19, 0xee, 0xfb, 0x7d, 0x74, 0xd7, 0xa8, 0x0d, 0x00, 0x1e, 0x7b, 0x4b, 0x3a, 0x4a, 0xe6, 0x0f, 0xe8, 0xc0, 0x71, 0xfc, 0x73, 0xe7, 0x02, 0x4c, 0x0d, 0xbc, 0xf4, 0xbd, 0xd1, 0x1d, 0x39, 0x6b, 0xba, 0x70, 0x46, 0x4a, 0x13, 0xe9, 0x4a, 0xf8, 0x3d, 0xf3, 0xe1, 0x09, 0x59, 0x54, 0x7b, 0xc9, 0x55, 0xfb, 0x41, 0x2d, 0xa3, 0x76, 0x52, 0x11, 0xe1, 0xf3, 0xdc, 0x77, 0x6c, 0xaa, 0x53, 0x37, 0x6e, 0xca, 0x3a, 0xec, 0xbe, 0xc3, 0xaa, 0xb7, 0x3b, 0x31, 0xd5, 0x6c, 0xb6, 0x52, 0x9c, 0x80, 0x98, 0xbc, 0xc9, 0xe0, 0x28, 0x18, 0xe2, 0x0b, 0xf7, 0xf8, 0xa0, 0x3a, 0xfd, 0x17, 0x04, 0x50, 0x9e, 0xce, 0x79, 0xbd, 0x9f, 0x39, 0xf1, 0xea, 0x69, 0xec, 0x47, 0x97, 0x2e, 0x83, 0x0f, 0xb5, 0xca, 0x95, 0xde, 0x95, 0xa1, 0xe6, 0x04, 0x22, 0xd5, 0xee, 0xbe, 0x52, 0x79, 0x54, 0xa1, 0xe7, 0xbf, 0x8a, 0x86, 0xf6, 0x46, 0x6d, 0x0d, 0x9f, 0x16, 0x95, 0x1a, 0x4c, 0xf7, 0xa0, 0x46, 0x92, 0x59, 0x5c, 0x13, 0x52, 0xf2, 0x54, 0x9e, 0x5a, 0xfb, 0x4e, 0xbf, 0xd7, 0x7a, 0x37, 0x95, 0x01, 0x44, 0xe4, 0xc0, 0x26, 0x87, 0x4c, 0x65, 0x3e, 0x40, 0x7d, 0x7d, 0x23, 0x07, 0x44, 0x01, 0xf4, 0x84, 0xff, 0xd0, 0x8f, 0x7a, 0x1f, 0xa0, 0x52, 0x10, 0xd1, 0xf4, 0xf0, 0xd5, 0xce, 0x79, 0x70, 0x29, 0x32, 0xe2, 0xca, 0xbe, 0x70, 0x1f, 0xdf, 0xad, 0x6b, 0x4b, 0xb7, 0x11, 0x01, 0xf4, 0x4b, 0xad, 0x66, 0x6a, 0x11, 0x13, 0x0f, 0xe2, 0xee, 0x82, 0x9e, 0x4d, 0x02, 0x9d, 0xc9, 0x1c, 0xdd, 0x67, 0x16, 0xdb, 0xb9, 0x06, 0x18, 0x86, 0xed, 0xc1, 0xba, 0x94, 0x21, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x52, 0x30, 0x50, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x89, 0x4f, 0xde, 0x5b, 0xcc, 0x69, 0xe2, 0x52, 0xcf, 0x3e, 0xa3, 0x00, 0xdf, 0xb1, 0x97, 0xb8, 0x1d, 0xe1, 0xc1, 0x46, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x59, 0x16, 0x45, 0xa6, 0x9a, 0x2e, 0x37, 0x79, 0xe4, 0xf6, 0xdd, 0x27, 0x1a, 0xba, 0x1c, 0x0b, 0xfd, 0x6c, 0xd7, 0x55, 0x99, 0xb5, 0xe7, 0xc3, 0x6e, 0x53, 0x3e, 0xff, 0x36, 0x59, 0x08, 0x43, 0x24, 0xc9, 0xe7, 0xa5, 0x04, 0x07, 0x9d, 0x39, 0xe0, 0xd4, 0x29, 0x87, 0xff, 0xe3, 0xeb, 0xdd, 0x09, 0xc1, 0xcf, 0x1d, 0x91, 0x44, 0x55, 0x87, 0x0b, 0x57, 0x1d, 0xd1, 0x9b, 0xdf, 0x1d, 0x24, 0xf8, 0xbb, 0x9a, 0x11, 0xfe, 0x80, 0xfd, 0x59, 0x2b, 0xa0, 0x39, 0x8c, 0xde, 0x11, 0xe2, 0x65, 0x1e, 0x61, 0x8c, 0xe5, 0x98, 0xfa, 0x96, 0xe5, 0x37, 0x2e, 0xef, 0x3d, 0x24, 0x8a, 0xfd, 0xe1, 0x74, 0x63, 0xeb, 0xbf, 0xab, 0xb8, 0xe4, 0xd1, 0xab, 0x50, 0x2a, 0x54, 0xec, 0x00, 0x64, 0xe9, 0x2f, 0x78, 0x19, 0x66, 0x0d, 0x3f, 0x27, 0xcf, 0x20, 0x9e, 0x66, 0x7f, 0xce, 0x5a, 0xe2, 0xe4, 0xac, 0x99, 0xc7, 0xc9, 0x38, 0x18, 0xf8, 0xb2, 0x51, 0x07, 0x22, 0xdf, 0xed, 0x97, 0xf3, 0x2e, 0x3e, 0x93, 0x49, 0xd4, 0xc6, 0x6c, 0x9e, 0xa6, 0x39, 0x6d, 0x74, 0x44, 0x62, 0xa0, 0x6b, 0x42, 0xc6, 0xd5, 0xba, 0x68, 0x8e, 0xac, 0x3a, 0x01, 0x7b, 0xdd, 0xfc, 0x8e, 0x2c, 0xfc, 0xad, 0x27, 0xcb, 0x69, 0xd3, 0xcc, 0xdc, 0xa2, 0x80, 0x41, 0x44, 0x65, 0xd3, 0xae, 0x34, 0x8c, 0xe0, 0xf3, 0x4a, 0xb2, 0xfb, 0x9c, 0x61, 0x83, 0x71, 0x31, 0x2b, 0x19, 0x10, 0x41, 0x64, 0x1c, 0x23, 0x7f, 0x11, 0xa5, 0xd6, 0x5c, 0x84, 0x4f, 0x04, 0x04, 0x84, 0x99, 0x38, 0x71, 0x2b, 0x95, 0x9e, 0xd6, 0x85, 0xbc, 0x5c, 0x5d, 0xd6, 0x45, 0xed, 0x19, 0x90, 0x94, 0x73, 0x40, 0x29, 0x26, 0xdc, 0xb4, 0x0e, 0x34, 0x69, 0xa1, 0x59, 0x41, 0xe8, 0xe2, 0xcc, 0xa8, 0x4b, 0xb6, 0x08, 0x46, 0x36, 0xa0 + }; + var server = Server(@TypeOf(inner_stream)){ + .stream = Stream(Plaintext.max_length, TestStream, server_mod.TranscriptHash){ + .stream = &inner_stream, + .transcript_hash = server_mod.TranscriptHash.init(.{}), + .is_client = false, + }, + .options = .{ + // force this to use https://tls13.xargs.org/ as unit test for "server hello" onwards + .cipher_suites = &[_]CipherSuite{.aes_256_gcm_sha384}, + .certificate = .{ + .entries = &[_]Certificate.Entry{ + .{ .data = &server_cert }, + } + } + }, + }; + + const session_id = [_]u8{ 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff }; + const client_random = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f }; + const server_random = [_]u8{ 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f }; + const client_x25519_seed = [_]u8{ 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f }; + const server_x25519_seed = [_]u8{ 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf }; + + const key_pairs = try client_mod.KeyPairs.initAdvanced( + client_random, + session_id, + client_x25519_seed ++ client_x25519_seed, + client_x25519_seed, + client_x25519_seed, + ); + { + // To get the same `hello_hash` as https://tls13.xargs.org/ just for + // this test we send a mostly falsified client hello. + // It doesn't matter because the server will be TLS 1.3 and only support .x25519 + const hello = ClientHello{ + .random = key_pairs.hello_rand, + .session_id = &key_pairs.session_id, + .cipher_suites = &[_]CipherSuite{ + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aes_128_gcm_sha256, + .empty_renegotiation_info_scsv, }, - .Enum => |info| { - const int = d.decode(info.tag_type); - if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - return @as(T, @enumFromInt(int)); + .extensions = &.{ + .{ .server_name = &[_]ServerName{.{ .host_name = client.options.host }} }, + .{ .ec_point_formats = &[_]EcPointFormat{ + .uncompressed, + .ansiX962_compressed_prime, + .ansiX962_compressed_char2, + } }, + .{ .supported_groups = &[_]NamedGroup{ + .x25519, + .secp256r1, + .x448, + .secp521r1, + .secp384r1, + .ffdhe2048, + .ffdhe3072, + .ffdhe4096, + .ffdhe6144, + .ffdhe8192, + } }, + .{ .session_ticket = {} }, + .{ .encrypt_then_mac = {} }, + .{ .extended_master_secret = {} }, + .{ .signature_algorithms = &[_]SignatureScheme{ + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + .ecdsa_secp521r1_sha512, + .ed25519, + .ed448, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + } }, + .{ .supported_versions = &[_]Version{.tls_1_3} }, + .{ .psk_key_exchange_modes = &[_]PskKeyExchangeMode{.ke} }, + .{ .key_share = &[_]KeyShare{ + .{ .x25519 = key_pairs.x25519.public_key }, + } }, }, - else => @compileError("unsupported type: " ++ @typeName(T)), - } - } + }; - /// Use this function to increase `idx`. - pub fn array(d: *Decoder, comptime len: usize) *[len]u8 { - skip(d, len); - return d.buf[d.idx - len ..][0..len]; + client.stream.version = .tls_1_0; + try client.stream.write(ClientHello, hello); + try client.stream.flush(); } - /// Use this function to increase `idx`. - pub fn slice(d: *Decoder, len: usize) []u8 { - skip(d, len); - return d.buf[d.idx - len ..][0..len]; - } + var tmp_buf: [Plaintext.max_length]u8 = undefined; + { + const buf = tmp_buf[0..inner_stream.buffer.len()]; + try inner_stream.peek(buf); - /// Use this function to increase `idx`. - pub fn skip(d: *Decoder, amt: usize) void { - d.idx += amt; - assert(d.idx <= d.our_end); // insufficient ensured bytes + const expected = [_]u8{ + 0x16, // handshake + 0x03, 0x01, // tls 1.0 (lie for compat) + 0x00, 0xf8, // handshake len + 0x01, // client hello + 0x00, 0x00, 0xf4, // client hello len + 0x03, 0x03, // tls 1.2 (lie for compat) + } ++ client_random ++ + [_]u8{session_id.len} ++ session_id ++ + [_]u8{ + 0x00, 0x08, // cipher suite len + 0x13, 0x02, // aes_256_gcm_sha384 + 0x13, 0x03, // chacha20_poly1305_sha256 + 0x13, 0x01, // aes_128_gcm_sha256 + 0x00, 0xff, // empty_renegotiation_info_scsv + 0x01, // compression methods len + 0x00, // none + 0x00, 0xa3, // extensions len + 0x00, 0x00, // server name ext + 0x00, 0x18, // server name len + 0x00, 0x16, // list entry len + 0x00, // dns hostname + } ++ + Encoder.encode(u16, @intCast(host.len)) ++ host ++ + [_]u8{ + 0x00, 0x0b, // ec point formats + 0x00, 0x04, // ext len + 0x03, // format type len + 0x00, // uncompresed + 0x01, // ansiX962_compressed_prime + 0x02, // ansiX962_compressed_char2 + 0x00, 0x0a, // supported groups + 0x00, 0x16, // ext len + 0x00, 0x14, // supported groups len + 0x00, 0x1d, // x25519 + 0x00, 0x17, // secp256r1 + 0x00, 0x1e, // x448 + 0x00, 0x19, // secp521r1 + 0x00, 0x18, // secp384r1 + 0x01, 0x00, // ffdhe2048 + 0x01, 0x01, // ffdhe3072 + 0x01, 0x02, // ffdhe4096 + 0x01, 0x03, // ffdhe6144 + 0x01, 0x04, // ffdhe8192 + 0x00, 0x23, // session ticket + 0x00, 0x00, // ext len + 0x00, 0x16, // encrypt then mac + 0x00, 0x00, // ext len + 0x00, 0x17, // extended master secrets + 0x00, 0x00, // ext len + 0x00, 0x0d, // signature algos + 0x00, 0x1e, // ext len + 0x00, 0x1c, // algos len + 0x04, 0x03, // ecdsa_secp256r1_sha256 + 0x05, 0x03, // ecdsa_secp384r1_sha384 + 0x06, 0x03, // ecdsa_secp521r1_sha512 + 0x08, 0x07, // ed25519 + 0x08, 0x08, // ed448 + 0x08, 0x09, // rsa_pss_pss_sha256 + 0x08, 0x0a, // rsa_pss_pss_sha384 + 0x08, 0x0b, // rsa_pss_pss_sha512 + 0x08, 0x04, // rsa_pss_rsae_sha256 + 0x08, 0x05, // rsa_pss_rsae_sha384 + 0x08, 0x06, // rsa_pss_rsae_sha512 + 0x04, 0x01, // rsa_pkcs1_sha256 + 0x05, 0x01, // rsa_pkcs1_sha384 + 0x06, 0x01, // rsa_pkcs1_sha512 + 0x00, 0x2b, // supported versions + 0x00, 0x03, // ext len + 0x02, // supported versions len + 0x03, 0x04, // tls 1.3 (not lying anymore!) + 0x00, 0x2d, // psk key exchange modes + 0x00, 0x02, // ext len + 0x01, // psk key exchange modes len + 0x01, // PSK with (EC)DHE key establishment + 0x00, 0x33, // key share + 0x00, 0x26, // ext len + 0x00, 0x24, // key shares len + 0x00, 0x1d, // curve 25519 + 0x00, 0x20, // key len + } ++ key_pairs.x25519.public_key; + try std.testing.expectEqualSlices(u8, expected, buf); } - pub fn eof(d: Decoder) bool { - assert(d.our_end <= d.their_end); - assert(d.idx <= d.our_end); - return d.idx == d.their_end; + const client_hello = try server.recv_hello(); + try std.testing.expectEqualSlices(u8, &client_random, &client_hello.random); + try std.testing.expectEqualSlices(u8, &session_id, &client_hello.session_id); + const server_key_pair = server_mod.KeyPair{ + .random = server_random, + .pair = .{ .x25519 = crypto.dh.X25519.KeyPair.create(server_x25519_seed) catch unreachable }, + }; + + try server.send_hello(client_hello, server_key_pair); + { + const buf = tmp_buf[0..inner_stream.buffer.len()]; + try inner_stream.peek(buf); + + const expected = [_]u8{ + 0x16, // handshake + 0x03, 0x03, // tls 1.2 + 0x00, 0x7a, // Handshake len + 0x02, // server hello + 0x00, 0x00, 0x76, // server hello len + 0x03, 0x03, // tls 1.2 + } ++ server_key_pair.random ++ [_]u8{session_id.len} ++ session_id ++ + [_]u8{ + 0x13, 0x02, // aes_256_gcm_sha384 + 0x00, // compression method + 0x00, 0x2e, // extensions len + 0x00, 0x2b, // supported versions + 0x00, 0x02, // ext len + 0x03, 0x04, // tls 1.3 + 0x00, 0x33, // key share + 0x00, 0x24, // ext len + 0x00, 0x1d, // x25519 + 0x00, 0x20, // key len + 0x9f, 0xd7, + 0xad, 0x6d, + 0xcf, 0xf4, + 0x29, 0x8d, + 0xd3, 0xf9, + 0x6d, 0x5b, + 0x1b, 0x2a, + 0xf9, 0x10, + 0xa0, 0x53, + 0x5b, 0x14, + 0x88, 0xd7, + 0xf8, 0xfa, + 0xbb, 0x34, + 0x9a, 0x98, + 0x28, 0x80, + 0xb6, 0x15, // key + } ++ + [_]u8{ + 0x14, // ChangeCipherSpec + 0x03, 0x03, // tls 1.2 + 0x00, 0x01, // Handshake len + 0x01, // .change_cipher_spec + } ++ [_]u8{ + 0x17, // application data (lie for tls 1.2 compat) + 0x03, 0x03, // tls 1.2 + 0x00, 0x17, // application data len + 0x6b, 0xe0, 0x2f, 0x9d, 0xa7, 0xc2, // encrypted data (empty EncryptedExtensions message) + 0xdc, // encrypted data type (handshake) + 0x9d, 0xde, 0xf5, 0x6f, 0x24, 0x68, 0xb9, 0x0a, // auth tag + 0xdf, 0xa2, 0x51, 0x01, 0xab, 0x03, 0x44, 0xae, // auth tag + } ++ [_]u8{ + 0x17, // application data (lie for tls 1.2 compat) + 0x03, 0x03, // tls 1.2 + 0x03, 0x43, // application data len +0xba, 0xf0, 0x0a, 0x9b, 0xe5, 0x0f, 0x3f, 0x23, 0x07, 0xe7, 0x26, 0xed, 0xcb, 0xda, 0xcb, 0xe4, 0xb1, 0x86, 0x16, 0x44, 0x9d, 0x46, 0xc6, 0x20, 0x7a, 0xf6, 0xe9, 0x95, 0x3e, 0xe5, 0xd2, 0x41, 0x1b, 0xa6, 0x5d, 0x31, 0xfe, 0xaf, 0x4f, 0x78, 0x76, 0x4f, 0x2d, 0x69, 0x39, 0x87, 0x18, 0x6c, 0xc0, 0x13, 0x29, 0xc1, 0x87, 0xa5, 0xe4, 0x60, 0x8e, 0x8d, 0x27, 0xb3, 0x18, 0xe9, 0x8d, 0xd9, 0x47, 0x69, 0xf7, 0x73, 0x9c, 0xe6, 0x76, 0x83, 0x92, 0xca, 0xca, 0x8d, 0xcc, 0x59, 0x7d, 0x77, 0xec, 0x0d, 0x12, 0x72, 0x23, 0x37, 0x85, 0xf6, 0xe6, 0x9d, 0x6f, 0x43, 0xef, 0xfa, 0x8e, 0x79, 0x05, 0xed, 0xfd, 0xc4, 0x03, 0x7e, 0xee, 0x59, 0x33, 0xe9, 0x90, 0xa7, 0x97, 0x2f, 0x20, 0x69, 0x13, 0xa3, 0x1e, 0x8d, 0x04, 0x93, 0x13, 0x66, 0xd3, 0xd8, 0xbc, 0xd6, 0xa4, 0xa4, 0xd6, 0x47, 0xdd, 0x4b, 0xd8, 0x0b, 0x0f, 0xf8, 0x63, 0xce, 0x35, 0x54, 0x83, 0x3d, 0x74, 0x4c, 0xf0, 0xe0, 0xb9, 0xc0, 0x7c, 0xae, 0x72, 0x6d, 0xd2, 0x3f, 0x99, 0x53, 0xdf, 0x1f, 0x1c, 0xe3, 0xac, 0xeb, 0x3b, 0x72, 0x30, 0x87, 0x1e, 0x92, 0x31, 0x0c, 0xfb, 0x2b, 0x09, 0x84, 0x86, 0xf4, 0x35, 0x38, 0xf8, 0xe8, 0x2d, 0x84, 0x04, 0xe5, 0xc6, 0xc2, 0x5f, 0x66, 0xa6, 0x2e, 0xbe, 0x3c, 0x5f, 0x26, 0x23, 0x26, 0x40, 0xe2, 0x0a, 0x76, 0x91, 0x75, 0xef, 0x83, 0x48, 0x3c, 0xd8, 0x1e, 0x6c, 0xb1, 0x6e, 0x78, 0xdf, 0xad, 0x4c, 0x1b, 0x71, 0x4b, 0x04, 0xb4, 0x5f, 0x6a, 0xc8, 0xd1, 0x06, 0x5a, 0xd1, 0x8c, 0x13, 0x45, 0x1c, 0x90, 0x55, 0xc4, 0x7d, 0xa3, 0x00, 0xf9, 0x35, 0x36, 0xea, 0x56, 0xf5, 0x31, 0x98, 0x6d, 0x64, 0x92, 0x77, 0x53, 0x93, 0xc4, 0xcc, 0xb0, 0x95, 0x46, 0x70, 0x92, 0xa0, 0xec, 0x0b, 0x43, 0xed, 0x7a, 0x06, 0x87, 0xcb, 0x47, 0x0c, 0xe3, 0x50, 0x91, 0x7b, 0x0a, 0xc3, 0x0c, 0x6e, 0x5c, 0x24, 0x72, 0x5a, 0x78, 0xc4, 0x5f, 0x9f, 0x5f, 0x29, 0xb6, 0x62, 0x68, 0x67, 0xf6, 0xf7, 0x9c, 0xe0, 0x54, 0x27, 0x35, 0x47, 0xb3, 0x6d, 0xf0, 0x30, 0xbd, 0x24, 0xaf, 0x10, 0xd6, 0x32, 0xdb, 0xa5, 0x4f, 0xc4, 0xe8, 0x90, 0xbd, 0x05, 0x86, 0x92, 0x8c, 0x02, 0x06, 0xca, 0x2e, 0x28, 0xe4, 0x4e, 0x22, 0x7a, 0x2d, 0x50, 0x63, 0x19, 0x59, 0x35, 0xdf, 0x38, 0xda, 0x89, 0x36, 0x09, 0x2e, 0xef, 0x01, 0xe8, 0x4c, 0xad, 0x2e, 0x49, 0xd6, 0x2e, 0x47, 0x0a, 0x6c, 0x77, 0x45, 0xf6, 0x25, 0xec, 0x39, 0xe4, 0xfc, 0x23, 0x32, 0x9c, 0x79, 0xd1, 0x17, 0x28, 0x76, 0x80, 0x7c, 0x36, 0xd7, 0x36, 0xba, 0x42, 0xbb, 0x69, 0xb0, 0x04, 0xff, 0x55, 0xf9, 0x38, 0x50, 0xdc, 0x33, 0xc1, 0xf9, 0x8a, 0xbb, 0x92, 0x85, 0x83, 0x24, 0xc7, 0x6f, 0xf1, 0xeb, 0x08, 0x5d, 0xb3, 0xc1, 0xfc, 0x50, 0xf7, 0x4e, 0xc0, 0x44, 0x42, 0xe6, 0x22, 0x97, 0x3e, 0xa7, 0x07, 0x43, 0x41, 0x87, 0x94, 0xc3, 0x88, 0x14, 0x0b, 0xb4, 0x92, 0xd6, 0x29, 0x4a, 0x05, 0x40, 0xe5, 0xa5, 0x9c, 0xfa, 0xe6, 0x0b, 0xa0, 0xf1, 0x48, 0x99, 0xfc, 0xa7, 0x13, 0x33, 0x31, 0x5e, 0xa0, 0x83, 0xa6, 0x8e, 0x1d, 0x7c, 0x1e, 0x4c, 0xdc, 0x2f, 0x56, 0xbc, 0xd6, 0x11, 0x96, 0x81, 0xa4, 0xad, 0xbc, 0x1b, 0xbf, 0x42, 0xaf, 0xd8, 0x06, 0xc3, 0xcb, 0xd4, 0x2a, 0x07, 0x6f, 0x54, 0x5d, 0xee, 0x4e, 0x11, 0x8d, 0x0b, 0x39, 0x67, 0x54, 0xbe, 0x2b, 0x04, 0x2a, 0x68, 0x5d, 0xd4, 0x72, 0x7e, 0x89, 0xc0, 0x38, 0x6a, 0x94, 0xd3, 0xcd, 0x6e, 0xcb, 0x98, 0x20, 0xe9, 0xd4, 0x9a, 0xfe, 0xed, 0x66, 0xc4, 0x7e, 0x6f, 0xc2, 0x43, 0xea, 0xbe, 0xbb, 0xcb, 0x0b, 0x02, 0x45, 0x38, 0x77, 0xf5, 0xac, 0x5d, 0xbf, 0xbd, 0xf8, 0xdb, 0x10, 0x52, 0xa3, 0xc9, 0x94, 0xb2, 0x24, 0xcd, 0x9a, 0xaa, 0xf5, 0x6b, 0x02, 0x6b, 0xb9, 0xef, 0xa2, 0xe0, 0x13, 0x02, 0xb3, 0x64, 0x01, 0xab, 0x64, 0x94, 0xe7, 0x01, 0x8d, 0x6e, 0x5b, 0x57, 0x3b, 0xd3, 0x8b, 0xce, 0xf0, 0x23, 0xb1, 0xfc, 0x92, 0x94, 0x6b, 0xbc, 0xa0, 0x20, 0x9c, 0xa5, 0xfa, 0x92, 0x6b, 0x49, 0x70, 0xb1, 0x00, 0x91, 0x03, 0x64, 0x5c, 0xb1, 0xfc, 0xfe, 0x55, 0x23, 0x11, 0xff, 0x73, 0x05, 0x58, 0x98, 0x43, 0x70, 0x03, 0x8f, 0xd2, 0xcc, 0xe2, 0xa9, 0x1f, 0xc7, 0x4d, 0x6f, 0x3e, 0x3e, 0xa9, 0xf8, 0x43, 0xee, 0xd3, 0x56, 0xf6, 0xf8, 0x2d, 0x35, 0xd0, 0x3b, 0xc2, 0x4b, 0x81, 0xb5, 0x8c, 0xeb, 0x1a, 0x43, 0xec, 0x94, 0x37, 0xe6, 0xf1, 0xe5, 0x0e, 0xb6, 0xf5, 0x55, 0xe3, 0x21, 0xfd, 0x67, 0xc8, 0x33, 0x2e, 0xb1, 0xb8, 0x32, 0xaa, 0x8d, 0x79, 0x5a, 0x27, 0xd4, 0x79, 0xc6, 0xe2, 0x7d, 0x5a, 0x61, 0x03, 0x46, 0x83, 0x89, 0x19, 0x03, 0xf6, 0x64, 0x21, 0xd0, 0x94, 0xe1, 0xb0, 0x0a, 0x9a, 0x13, 0x8d, 0x86, 0x1e, 0x6f, 0x78, 0xa2, 0x0a, 0xd3, 0xe1, 0x58, 0x00, 0x54, 0xd2, 0xe3, 0x05, 0x25, 0x3c, 0x71, 0x3a, 0x02, 0xfe, 0x1e, 0x28, 0xde, 0xee, 0x73, 0x36, 0x24, 0x6f, 0x6a, 0xe3, 0x43, 0x31, 0x80, 0x6b, 0x46, 0xb4, 0x7b, 0x83, 0x3c, 0x39, 0xb9, 0xd3, 0x1c, 0xd3, 0x00, 0xc2, 0xa6, 0xed, 0x83, 0x13, 0x99, 0x77, 0x6d, 0x07, 0xf5, 0x70, 0xea, 0xf0, 0x05, 0x9a, 0x2c, 0x68, 0xa5, 0xf3, 0xae, 0x16, 0xb6, 0x17, 0x40, 0x4a, 0xf7, 0xb7, 0x23, 0x1a, 0x4d, 0x94, 0x27, 0x58, 0xfc, 0x02, 0x0b, 0x3f, 0x23, 0xee, 0x8c, 0x15, 0xe3, 0x60, 0x44, 0xcf, 0xd6, 0x7c, 0xd6, 0x40, 0x99, 0x3b, 0x16, 0x20, 0x75, 0x97, 0xfb, 0xf3, 0x85, 0xea, 0x7a, 0x4d, 0x99, 0xe8, 0xd4, 0x56, 0xff, 0x83, 0xd4, 0x1f, 0x7b, 0x8b, 0x4f, 0x06, 0x9b, 0x02, 0x8a, 0x2a, 0x63, 0xa9, 0x19, 0xa7, 0x0e, 0x3a, 0x10, 0xe3, 0x08, // encrypted cert + 0x41, // encrypted data type (Certificate) + 0x58, 0xfa, 0xa5, 0xba, 0xfa, 0x30, 0x18, 0x6c, // auth tag + 0x6b, 0x2f, 0x23, 0x8e, 0xb5, 0x30, 0xc7, 0x3e, // auth tag + } + ; + try std.testing.expectEqualSlices(u8, &expected, buf); } - /// Provide the length they claim, and receive a sub-decoder specific to that slice. - /// The parent decoder is advanced to the end. - pub fn sub(d: *Decoder, their_len: usize) !Decoder { - const end = d.idx + their_len; - if (end > d.their_end) return error.TlsDecodeError; - const sub_buf = d.buf[d.idx..end]; - d.idx = end; - d.our_end = end; - return fromTheirSlice(sub_buf); + try client.recv_hello(key_pairs); + + // Test that all 8 shared keys are identical + // If one of these isn't true, check that the earlier transcript_hashes match. + { + const s = server.stream.handshake_cipher.?.aes_256_gcm_sha384; + const c = client.stream.handshake_cipher.?.aes_256_gcm_sha384; + + try std.testing.expectEqualSlices(u8, &s.handshake_secret, &c.handshake_secret); + try std.testing.expectEqualSlices(u8, &s.master_secret, &c.master_secret); + try std.testing.expectEqualSlices(u8, &s.server_handshake_key, &c.server_handshake_key); + try std.testing.expectEqualSlices(u8, &s.client_handshake_key, &c.client_handshake_key); + try std.testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); + try std.testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); + try std.testing.expectEqualSlices(u8, &s.server_handshake_iv, &c.server_handshake_iv); + try std.testing.expectEqualSlices(u8, &s.client_handshake_iv, &c.client_handshake_iv); } +} + +test { + _ = StreamInterface; +} - pub fn rest(d: Decoder) []u8 { - return d.buf[d.idx..d.cap]; +pub fn debugPrint(name: []const u8, slice: anytype) void { + std.debug.print("{s} ", .{name}); + if (@typeInfo(@TypeOf(slice)) == .Int) { + std.debug.print("{d} ", .{slice}); + } else { + for (slice) |c| std.debug.print("{x:0>2} ", .{c}); } -}; + std.debug.print("\n", .{}); +} diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index f07cfe781031..be3d83ee266a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,1354 +1,677 @@ const std = @import("../../std.zig"); const tls = std.crypto.tls; -const Client = @This(); const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; const Certificate = std.crypto.Certificate; -const max_ciphertext_len = tls.max_ciphertext_len; -const hkdfExpandLabel = tls.hkdfExpandLabel; -const int2 = tls.int2; -const int3 = tls.int3; -const array = tls.array; -const enum_array = tls.enum_array; - -read_seq: u64, -write_seq: u64, -/// The starting index of cleartext bytes inside `partially_read_buffer`. -partial_cleartext_idx: u15, -/// The ending index of cleartext bytes inside `partially_read_buffer` as well -/// as the starting index of ciphertext bytes. -partial_ciphertext_idx: u15, -/// The ending index of ciphertext bytes inside `partially_read_buffer`. -partial_ciphertext_end: u15, -/// When this is true, the stream may still not be at the end because there -/// may be data in `partially_read_buffer`. -received_close_notify: bool, -/// By default, reaching the end-of-stream when reading from the server will -/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify -/// message has been received. By setting this flag to `true`, instead, the -/// end-of-stream will be forwarded to the application layer above TLS. -/// This makes the application vulnerable to truncation attacks unless the -/// application layer itself verifies that the amount of data received equals -/// the amount of data expected, such as HTTP with the Content-Length header. -allow_truncation_attacks: bool = false, -application_cipher: tls.ApplicationCipher, -/// The size is enough to contain exactly one TLSCiphertext record. -/// This buffer is segmented into four parts: -/// 0. unused -/// 1. cleartext -/// 2. ciphertext -/// 3. unused -/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and -/// `partial_ciphertext_end` describe the span of the segments. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, - -/// This is an example of the type that is needed by the read and write -/// functions. It can have any fields but it must at least have these -/// functions. -/// -/// Note that `std.net.Stream` conforms to this interface. -/// -/// This declaration serves as documentation only. -pub const StreamInterface = struct { - /// Can be any error set. - pub const ReadError = error{}; - - /// Returns the number of bytes read. The number read may be less than the - /// buffer space provided. End-of-stream is indicated by a return value of 0. - /// - /// The `iovecs` parameter is mutable because so that function may to - /// mutate the fields in order to handle partial reads from the underlying - /// stream layer. - pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Can be any error set. - pub const WriteError = error{}; - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided. A short read does not indicate end-of-stream. - pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided, indicating end-of-stream. - /// The `iovecs` parameter is mutable in case this function needs to mutate - /// the fields in order to handle partial writes from the underlying layer. - pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize { - // This can be implemented in terms of writev, or specialized if desired. - _ = .{ this, iovecs }; - @panic("unimplemented"); - } -}; - -pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ - InsufficientEntropy, - DiskQuota, - LockViolation, - NotOpenForWriting, - TlsUnexpectedMessage, - TlsIllegalParameter, - TlsDecryptFailure, - TlsRecordOverflow, - TlsBadRecordMac, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - InvalidEncoding, - IdentityElement, - SignatureVerificationFailed, - TlsDecryptError, - TlsConnectionTruncated, - TlsDecodeError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - NonCanonical, - WeakPublicKey, - }; -} - -/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which -/// must conform to `StreamInterface`. -/// -/// `host` is only borrowed during this function call. -pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client { - const host_len: u16 = @intCast(host.len); - - var random_buffer: [128]u8 = undefined; - crypto.random.bytes(&random_buffer); - const hello_rand = random_buffer[0..32].*; - const legacy_session_id = random_buffer[32..64].*; - const x25519_kp_seed = random_buffer[64..96].*; - const secp256r1_kp_seed = random_buffer[96..128].*; - - const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - }; - const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - }; - const kyber768_kp = crypto.kem.kyber_d00.Kyber768.KeyPair.create(null) catch {}; - - const extensions_payload = - tls.extension(.supported_versions, [_]u8{ - 0x02, // byte length of supported versions - 0x03, 0x04, // TLS 1.3 - }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ - .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - .ed25519, - })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ - .x25519_kyber768d00, - .secp256r1, - .x25519, - })) ++ tls.extension( - .key_share, - array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++ - array(1, x25519_kp.public_key) ++ - int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++ - array(1, secp256r1_kp.public_key.toUncompressedSec1()) ++ - int2(@intFromEnum(tls.NamedGroup.x25519_kyber768d00)) ++ - array(1, x25519_kp.public_key ++ kyber768_kp.public_key.toBytes())), - ) ++ - int2(@intFromEnum(tls.ExtensionType.server_name)) ++ - int2(host_len + 5) ++ // byte length of this extension payload - int2(host_len + 3) ++ // server_name_list byte count - [1]u8{0x00} ++ // name_type - int2(host_len); - - const extensions_header = - int2(@intCast(extensions_payload.len + host_len)) ++ - extensions_payload; - - const legacy_compression_methods = 0x0100; - - const client_hello = - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - hello_rand ++ - [1]u8{32} ++ legacy_session_id ++ - cipher_suites ++ - int2(legacy_compression_methods) ++ - extensions_header; - - const out_handshake = - [_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++ - int3(@intCast(client_hello.len + host_len)) ++ - client_hello; - - const plaintext_header = [_]u8{ - @intFromEnum(tls.ContentType.handshake), - 0x03, 0x01, // legacy_record_version - } ++ int2(@intCast(out_handshake.len + host_len)) ++ out_handshake; - - { - var iovecs = [_]std.os.iovec_const{ - .{ - .iov_base = &plaintext_header, - .iov_len = plaintext_header.len, - }, - .{ - .iov_base = host.ptr, - .iov_len = host.len, - }, - }; - try stream.writevAll(&iovecs); - } - - const client_hello_bytes1 = plaintext_header[5..]; - - var handshake_cipher: tls.HandshakeCipher = undefined; - var handshake_buffer: [8000]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; - { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_record_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - const server_hello_fragment = d.buf[d.idx..][0..record_len]; - var ptd = try d.sub(record_len); - switch (ct) { - .alert => { - try ptd.ensure(2); - const level = ptd.decode(tls.AlertLevel); - const desc = ptd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - try ptd.ensure(4); - const handshake_type = ptd.decode(tls.HandshakeType); - if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; - const length = ptd.decode(u24); - var hsd = try ptd.sub(length); - try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2); - const legacy_version = hsd.decode(u16); - const random = hsd.array(32); - if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) { - // This is a HelloRetryRequest message. This client implementation - // does not expect to get one. - return error.TlsUnexpectedMessage; - } - const legacy_session_id_echo_len = hsd.decode(u8); - if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; - const legacy_session_id_echo = hsd.array(32); - if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) - return error.TlsIllegalParameter; - const cipher_suite_tag = hsd.decode(tls.CipherSuite); - hsd.skip(1); // legacy_compression_method - const extensions_size = hsd.decode(u16); - var all_extd = try hsd.sub(extensions_size); - var supported_version: u16 = 0; - var shared_key: []const u8 = undefined; - var have_shared_key = false; - while (!all_extd.eof()) { - try all_extd.ensure(2 + 2); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - var extd = try all_extd.sub(ext_size); - switch (et) { - .supported_versions => { - if (supported_version != 0) return error.TlsIllegalParameter; - try extd.ensure(2); - supported_version = extd.decode(u16); - }, - .key_share => { - if (have_shared_key) return error.TlsIllegalParameter; - have_shared_key = true; - try extd.ensure(4); - const named_group = extd.decode(tls.NamedGroup); - const key_size = extd.decode(u16); - try extd.ensure(key_size); - switch (named_group) { - .x25519_kyber768d00 => { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + crypto.kem.kyber_d00.Kyber768.ciphertext_length; - if (key_size != hksl) - return error.TlsIllegalParameter; - const server_ks = extd.array(hksl); - - shared_key = &((crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_ks[0..xksl].*, - ) catch return error.TlsDecryptFailure) ++ (kyber768_kp.secret_key.decaps( - server_ks[xksl..hksl], - ) catch return error.TlsDecryptFailure)); - }, - .x25519 => { - const ksl = crypto.dh.X25519.public_length; - if (key_size != ksl) return error.TlsIllegalParameter; - const server_pub_key = extd.array(ksl); - - shared_key = &(crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_pub_key.*, - ) catch return error.TlsDecryptFailure); - }, - .secp256r1 => { - const server_pub_key = extd.slice(key_size); - - const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; - const pk = PublicKey.fromSec1(server_pub_key) catch { - return error.TlsDecryptFailure; - }; - const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .big) catch { - return error.TlsDecryptFailure; - }; - shared_key = &mul.affineCoordinates().x.toBytes(.big); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => {}, - } - } - if (!have_shared_key) return error.TlsIllegalParameter; - - const tls_version = if (supported_version == 0) legacy_version else supported_version; - if (tls_version != @intFromEnum(tls.ProtocolVersion.tls_1_3)) - return error.TlsIllegalParameter; - - switch (cipher_suite_tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_256_SHA512, - .AEGIS_128L_SHA256, - => |tag| { - const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag)); - handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{ - .handshake_secret = undefined, - .master_secret = undefined, - .client_handshake_key = undefined, - .server_handshake_key = undefined, - .client_finished_key = undefined, - .server_finished_key = undefined, - .client_handshake_iv = undefined, - .server_handshake_iv = undefined, - .transcript_hash = P.Hash.init(.{}), - }); - const p = &@field(handshake_cipher, @tagName(tag)); - p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 - p.transcript_hash.update(host); // Client Hello part 2 - p.transcript_hash.update(server_hello_fragment); - const hello_hash = p.transcript_hash.peek(); - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(P.Hash); - const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); - p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, shared_key); - const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); - const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); - p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); - p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => return error.TlsUnexpectedMessage, +pub const TranscriptHash = MultiHash; + +/// `StreamType` must conform to `tls.StreamInterface`. +pub fn Client(comptime StreamType: type) type { + return struct { + stream: tls.Stream(tls.Plaintext.max_length, StreamType, TranscriptHash), + options: Options, + + const Self = @This(); + + /// Initiates a TLS handshake and establishes a TLSv1.3 session + pub fn init(stream: *StreamType, options: Options) !Self { + var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType, TranscriptHash){ + .stream = stream, + .transcript_hash = .{}, + .is_client = true, + }; + var res = Self{ .stream = stream_, .options = options }; + { + const key_pairs = try KeyPairs.init(); + try res.send_hello(key_pairs); + try res.recv_hello(key_pairs); + } + _ = &stream_; + + return res; + + // var cert_index: usize = 0; + // var read_seq: u64 = 0; + // var prev_cert: Certificate.Parsed = undefined; + // // Set to true once a trust chain has been established from the first + // // certificate to a root CA. + // const HandshakeState = enum { + // /// In this state we expect an encrypted_extensions message. + // encrypted_extensions, + // /// In this state we expect certificate messages. + // certificate, + // /// In this state we expect certificate or certificate_verify messages. + // /// Certificate messages are ignored since the trust chain is already established. + // trust_chain_established, + // /// In this state, we expect only the finished message. + // finished, + // }; + // var handshake_state: HandshakeState = .encrypted_extensions; + // var cleartext_bufs: [2][tls.max_cipertext_inner_record_len]u8 = undefined; + // var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; + // var main_cert_pub_key_buf: [600]u8 = undefined; + // var main_cert_pub_key_len: u16 = undefined; + // const now_sec = std.time.timestamp(); + + // const cleartext_buf = &cleartext_bufs[cert_index % 2]; + // const cleartext = try handshake_cipher.cleartext(record, read_seq, cleartext_buf); + + // const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); + // if (inner_ct != .handshake) return error.TlsUnexpectedMessage; + + // var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); + // while (true) { + // try ctd.ensure(4); + // const handshake_type = ctd.decode(tls.HandshakeType); + // const handshake_len = ctd.decode(u24); + // var hsd = try ctd.sub(handshake_len); + // const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; + // const handshake_buf = ctd.buf[ctd.idx - handshake_len .. ctd.idx]; + // switch (handshake_type) { + // .encrypted_extensions => { + // if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; + // handshake_state = .certificate; + // switch (handshake_cipher) { + // inline else => |*p| p.transcript_hash.update(wrapped_handshake), + // } + // try hsd.ensure(2); + // const total_ext_size = hsd.decode(u16); + // var all_extd = try hsd.sub(total_ext_size); + // while (!all_extd.eof()) { + // try all_extd.ensure(4); + // const et = all_extd.decode(tls.ExtensionType); + // const ext_size = all_extd.decode(u16); + // const extd = try all_extd.sub(ext_size); + // _ = extd; + // switch (et) { + // .server_name => {}, + // else => {}, + // } + // } + // }, + // .certificate => cert: { + // switch (handshake_cipher) { + // inline else => |*p| p.transcript_hash.update(wrapped_handshake), + // } + // switch (handshake_state) { + // .certificate => {}, + // .trust_chain_established => break :cert, + // else => return error.TlsUnexpectedMessage, + // } + // try hsd.ensure(1 + 4); + // const cert_req_ctx_len = hsd.decode(u8); + // if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; + // const certs_size = hsd.decode(u24); + // var certs_decoder = try hsd.sub(certs_size); + // while (!certs_decoder.eof()) { + // try certs_decoder.ensure(3); + // const cert_size = certs_decoder.decode(u24); + // const certd = try certs_decoder.sub(cert_size); + + // const subject_cert: Certificate = .{ + // .buffer = certd.buf, + // .index = @intCast(certd.idx), + // }; + // const subject = try subject_cert.parse(); + // if (cert_index == 0) { + // // Verify the host on the first certificate. + // try subject.verifyHostName(options.host); + + // // Keep track of the public key for the + // // certificate_verify message later. + // main_cert_pub_key_algo = subject.pub_key_algo; + // const pub_key = subject.pubKey(); + // if (pub_key.len > main_cert_pub_key_buf.len) + // return error.CertificatePublicKeyInvalid; + // @memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key); + // main_cert_pub_key_len = @intCast(pub_key.len); + // } else { + // try prev_cert.verify(subject, now_sec); + // } + + // if (options.ca_bundle.verify(subject, now_sec)) |_| { + // handshake_state = .trust_chain_established; + // break :cert; + // } else |err| switch (err) { + // error.CertificateIssuerNotFound => {}, + // else => |e| return e, + // } + + // prev_cert = subject; + // cert_index += 1; + + // try certs_decoder.ensure(2); + // const total_ext_size = certs_decoder.decode(u16); + // const all_extd = try certs_decoder.sub(total_ext_size); + // _ = all_extd; + // } + // }, + // .certificate_verify => { + // switch (handshake_state) { + // .trust_chain_established => handshake_state = .finished, + // .certificate => return error.TlsCertificateNotVerified, + // else => return error.TlsUnexpectedMessage, + // } + + // try hsd.ensure(4); + // const scheme = hsd.decode(tls.SignatureScheme); + // const sig_len = hsd.decode(u16); + // try hsd.ensure(sig_len); + // const encoded_sig = hsd.slice(sig_len); + // const max_digest_len = 64; + // var verify_buffer: [64 + 34 + max_digest_len]u8 = + // ([1]u8{0x20} ** 64) ++ + // "TLS 1.3, server CertificateVerify\x00".* ++ + // @as([max_digest_len]u8, undefined); + + // const verify_bytes = switch (handshake_cipher) { + // inline else => |*p| v: { + // const transcript_digest = p.transcript_hash.peek(); + // verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; + // p.transcript_hash.update(wrapped_handshake); + // break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; + // }, + // }; + // const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; + + // switch (scheme) { + // inline .ecdsa_secp256r1_sha256, + // .ecdsa_secp384r1_sha384, + // => |comptime_scheme| { + // if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) + // return error.TlsBadSignatureScheme; + // const Ecdsa = SchemeEcdsa(comptime_scheme); + // const sig = try Ecdsa.Signature.fromDer(encoded_sig); + // const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); + // try sig.verify(verify_bytes, key); + // }, + // inline .rsa_pss_rsae_sha256, + // .rsa_pss_rsae_sha384, + // .rsa_pss_rsae_sha512, + // => |comptime_scheme| { + // if (main_cert_pub_key_algo != .rsaEncryption) + // return error.TlsBadSignatureScheme; + + // const Hash = SchemeHash(comptime_scheme); + // const rsa = Certificate.rsa; + // const components = try rsa.PublicKey.parseDer(main_cert_pub_key); + // const exponent = components.exponent; + // const modulus = components.modulus; + // switch (modulus.len) { + // inline 128, 256, 512 => |modulus_len| { + // const key = try rsa.PublicKey.fromBytes(exponent, modulus); + // const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); + // try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); + // }, + // else => { + // return error.TlsBadRsaSignatureBitCount; + // }, + // } + // }, + // inline .ed25519 => |comptime_scheme| { + // if (main_cert_pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; + // const Eddsa = SchemeEddsa(comptime_scheme); + // if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; + // const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); + // if (main_cert_pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; + // const key = try Eddsa.PublicKey.fromBytes(main_cert_pub_key[0..Eddsa.PublicKey.encoded_length].*); + // try sig.verify(verify_bytes, key); + // }, + // else => { + // return error.TlsBadSignatureScheme; + // }, + // } + // }, + // .finished => { + // if (handshake_state != .finished) return error.TlsUnexpectedMessage; + // // This message is to trick buggy proxies into behaving correctly. + // const client_change_cipher_spec_msg = [_]u8{ + // @intFromEnum(tls.ContentType.change_cipher_spec), + // 0x03, 0x03, // legacy protocol version + // 0x00, 0x01, // length + // 0x01, + // }; + // const app_cipher = switch (handshake_cipher) { + // inline else => |*p, tag| c: { + // const P = @TypeOf(p.*); + // const finished_digest = p.transcript_hash.peek(); + // p.transcript_hash.update(wrapped_handshake); + // const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); + // if (!mem.eql(u8, &expected_server_verify_data, handshake_buf)) + // return error.TlsDecryptError; + // const handshake_hash = p.transcript_hash.finalResult(); + // const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); + // const out_cleartext = [_]u8{ + // @intFromEnum(tls.HandshakeType.finished), + // 0, 0, verify_data.len, // length + // } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)}; + + // const wrapped_len = out_cleartext.len + P.AEAD.tag_length; + + // var finished_msg = [_]u8{ + // @intFromEnum(tls.ContentType.application_data), + // 0x03, 0x03, // legacy protocol version + // 0, wrapped_len, // byte length of encrypted record + // } ++ @as([wrapped_len]u8, undefined); + + // const ad = finished_msg[0..5]; + // const ciphertext = finished_msg[5..][0..out_cleartext.len]; + // const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; + // const nonce = p.client_handshake_iv; + // P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + + // const both_msgs = client_change_cipher_spec_msg ++ finished_msg; + // var both_msgs_vec = [_]std.os.iovec_const{.{ + // .iov_base = &both_msgs, + // .iov_len = both_msgs.len, + // }}; + // try stream.writevAll(&both_msgs_vec); + + // const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + // const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + // break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ + // .client_secret = client_secret, + // .server_secret = server_secret, + // .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), + // .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), + // .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), + // .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), + // }); + // }, + // }; + // const leftover = decoder.rest(); + // var client: Client = .{ + // .read_seq = 0, + // .write_seq = 0, + // .partial_cleartext_idx = 0, + // .partial_ciphertext_idx = 0, + // .partial_ciphertext_end = @intCast(leftover.len), + // .received_close_notify = false, + // .application_cipher = app_cipher, + // .partially_read_buffer = undefined, + // }; + // @memcpy(client.partially_read_buffer[0..leftover.len], leftover); + // return client; + // }, + // else => { + // return error.TlsUnexpectedMessage; + // }, + // } + // if (ctd.eof()) break; + // } } - } - - // This is used for two purposes: - // * Detect whether a certificate is the first one presented, in which case - // we need to verify the host name. - // * Flip back and forth between the two cleartext buffers in order to keep - // the previous certificate in memory so that it can be verified by the - // next one. - var cert_index: usize = 0; - var read_seq: u64 = 0; - var prev_cert: Certificate.Parsed = undefined; - // Set to true once a trust chain has been established from the first - // certificate to a root CA. - const HandshakeState = enum { - /// In this state we expect only an encrypted_extensions message. - encrypted_extensions, - /// In this state we expect certificate messages. - certificate, - /// In this state we expect certificate or certificate_verify messages. - /// certificate messages are ignored since the trust chain is already - /// established. - trust_chain_established, - /// In this state, we expect only the finished message. - finished, - }; - var handshake_state: HandshakeState = .encrypted_extensions; - var cleartext_bufs: [2][8000]u8 = undefined; - var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; - var main_cert_pub_key_buf: [600]u8 = undefined; - var main_cert_pub_key_len: u16 = undefined; - const now_sec = std.time.timestamp(); - - while (true) { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const record_header = d.buf[d.idx..][0..5]; - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - var record_decoder = try d.sub(record_len); - switch (ct) { - .change_cipher_spec => { - try record_decoder.ensure(1); - if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter; - }, - .application_data => { - const cleartext_buf = &cleartext_bufs[cert_index % 2]; - - const cleartext = switch (handshake_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ciphertext_len = record_len - P.AEAD.tag_length; - try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length); - const ciphertext = record_decoder.slice(ciphertext_len); - if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = record_decoder.array(P.AEAD.tag_length).*; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.server_handshake_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ read_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(read_seq))); - break :nonce @as(V, p.server_handshake_iv) ^ operand; - }; - read_seq += 1; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch - return error.TlsBadRecordMac; - break :c cleartext; - }, - }; - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - if (inner_ct != .handshake) return error.TlsUnexpectedMessage; - - var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); - while (true) { - try ctd.ensure(4); - const handshake_type = ctd.decode(tls.HandshakeType); - const handshake_len = ctd.decode(u24); - var hsd = try ctd.sub(handshake_len); - const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; - const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx]; - switch (handshake_type) { - .encrypted_extensions => { - if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - try hsd.ensure(2); - const total_ext_size = hsd.decode(u16); - var all_extd = try hsd.sub(total_ext_size); - while (!all_extd.eof()) { - try all_extd.ensure(4); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - const extd = try all_extd.sub(ext_size); - _ = extd; - switch (et) { - .server_name => {}, - else => {}, - } - } - }, - .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - switch (handshake_state) { - .certificate => {}, - .trust_chain_established => break :cert, - else => return error.TlsUnexpectedMessage, - } - try hsd.ensure(1 + 4); - const cert_req_ctx_len = hsd.decode(u8); - if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; - const certs_size = hsd.decode(u24); - var certs_decoder = try hsd.sub(certs_size); - while (!certs_decoder.eof()) { - try certs_decoder.ensure(3); - const cert_size = certs_decoder.decode(u24); - const certd = try certs_decoder.sub(cert_size); - - const subject_cert: Certificate = .{ - .buffer = certd.buf, - .index = @intCast(certd.idx), - }; - const subject = try subject_cert.parse(); - if (cert_index == 0) { - // Verify the host on the first certificate. - try subject.verifyHostName(host); - - // Keep track of the public key for the - // certificate_verify message later. - main_cert_pub_key_algo = subject.pub_key_algo; - const pub_key = subject.pubKey(); - if (pub_key.len > main_cert_pub_key_buf.len) - return error.CertificatePublicKeyInvalid; - @memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key); - main_cert_pub_key_len = @intCast(pub_key.len); - } else { - try prev_cert.verify(subject, now_sec); - } - - if (ca_bundle.verify(subject, now_sec)) |_| { - handshake_state = .trust_chain_established; - break :cert; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => |e| return e, - } - - prev_cert = subject; - cert_index += 1; - - try certs_decoder.ensure(2); - const total_ext_size = certs_decoder.decode(u16); - const all_extd = try certs_decoder.sub(total_ext_size); - _ = all_extd; - } - }, - .certificate_verify => { - switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, - .certificate => return error.TlsCertificateNotVerified, - else => return error.TlsUnexpectedMessage, - } - - try hsd.ensure(4); - const scheme = hsd.decode(tls.SignatureScheme); - const sig_len = hsd.decode(u16); - try hsd.ensure(sig_len); - const encoded_sig = hsd.slice(sig_len); - const max_digest_len = 64; - var verify_buffer: [64 + 34 + max_digest_len]u8 = - ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - @as([max_digest_len]u8, undefined); - - const verify_bytes = switch (handshake_cipher) { - inline else => |*p| v: { - const transcript_digest = p.transcript_hash.peek(); - verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; - p.transcript_hash.update(wrapped_handshake); - break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; - }, - }; - const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - return error.TlsBadSignatureScheme; - const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = try Ecdsa.Signature.fromDer(encoded_sig); - const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); - try sig.verify(verify_bytes, key); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .rsaEncryption) - return error.TlsBadSignatureScheme; - - const Hash = SchemeHash(comptime_scheme); - const rsa = Certificate.rsa; - const components = try rsa.PublicKey.parseDer(main_cert_pub_key); - const exponent = components.exponent; - const modulus = components.modulus; - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus); - const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); - }, - else => { - return error.TlsBadRsaSignatureBitCount; - }, - } - }, - inline .ed25519 => |comptime_scheme| { - if (main_cert_pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; - const Eddsa = SchemeEddsa(comptime_scheme); - if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; - const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); - if (main_cert_pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; - const key = try Eddsa.PublicKey.fromBytes(main_cert_pub_key[0..Eddsa.PublicKey.encoded_length].*); - try sig.verify(verify_bytes, key); - }, - else => { - return error.TlsBadSignatureScheme; - }, - } - }, - .finished => { - if (handshake_state != .finished) return error.TlsUnexpectedMessage; - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = [_]u8{ - @intFromEnum(tls.ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (handshake_cipher) { - inline else => |*p, tag| c: { - const P = @TypeOf(p.*); - const finished_digest = p.transcript_hash.peek(); - p.transcript_hash.update(wrapped_handshake); - const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); - if (!mem.eql(u8, &expected_server_verify_data, handshake)) - return error.TlsDecryptError; - const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); - const out_cleartext = [_]u8{ - @intFromEnum(tls.HandshakeType.finished), - 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)}; - - const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - - var finished_msg = [_]u8{ - @intFromEnum(tls.ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ @as([wrapped_len]u8, undefined); - - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; - const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - - const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - var both_msgs_vec = [_]std.os.iovec_const{.{ - .iov_base = &both_msgs, - .iov_len = both_msgs.len, - }}; - try stream.writevAll(&both_msgs_vec); - - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ - .client_secret = client_secret, - .server_secret = server_secret, - .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), - .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), - .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), - .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); - }, - }; - const leftover = d.rest(); - var client: Client = .{ - .read_seq = 0, - .write_seq = 0, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(leftover.len), - .received_close_notify = false, - .application_cipher = app_cipher, - .partially_read_buffer = undefined, - }; - @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - if (ctd.eof()) break; - } - }, - else => { - return error.TlsUnexpectedMessage; - }, + /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. + /// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. + pub fn write(self: *Self, bytes: []const u8) !usize { + return self.writeEnd(bytes, false); } - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. -pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { - return writeEnd(c, stream, bytes, false); -} -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.writeEnd(stream, bytes[index..], end); - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { - var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - var iovecs_buf: [6]std.os.iovec_const = undefined; - var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); - if (end) { - prepared.iovec_end += prepareCiphertextRecord( - c, - iovecs_buf[prepared.iovec_end..], - ciphertext_buf[prepared.ciphertext_end..], - &tls.close_notify_alert, - .alert, - ).iovec_end; - } - - const iovec_end = prepared.iovec_end; - const overhead_len = prepared.overhead_len; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. - var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].iov_len) { - const encrypted_amt = iovecs_buf[i].iov_len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end) return total_amt; - // We also cannot return on a vector boundary if the final close_notify is - // not sent; otherwise the caller would not know to retry the call. - if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; + /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. + pub fn writeAll(self: *Self, bytes: []const u8) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.write(bytes[index..]); + } } - iovecs_buf[i].iov_base += amt; - iovecs_buf[i].iov_len -= amt; - } -} - -fn prepareCiphertextRecord( - c: *Client, - iovecs: []std.os.iovec_const, - ciphertext_buf: []u8, - bytes: []const u8, - inner_content_type: tls.ContentType, -) struct { - iovec_end: usize, - ciphertext_end: usize, - /// How many bytes are taken up by overhead per record. - overhead_len: usize, -} { - // Due to the trailing inner content type byte in the ciphertext, we need - // an additional buffer for storing the cleartext into before encrypting. - var cleartext_buf: [max_ciphertext_len]u8 = undefined; - var ciphertext_end: usize = 0; - var iovec_end: usize = 0; - var bytes_i: usize = 0; - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; - while (true) { - const encrypted_content_len: u16 = @intCast(@min( - @min(bytes.len - bytes_i, tls.max_cipertext_inner_record_len), - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), - )); - if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, - .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, - }; - @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type); - bytes_i += encrypted_content_len; - const ciphertext_len = encrypted_content_len + 1; - const cleartext = cleartext_buf[0..ciphertext_len]; - - const record_start = ciphertext_end; - const ad = ciphertext_buf[ciphertext_end..][0..5]; - ad.* = - [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(ciphertext_len + P.AEAD.tag_length); - ciphertext_end += ad.len; - const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; - ciphertext_end += ciphertext_len; - const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; - ciphertext_end += auth_tag.len; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.client_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq))); - break :nonce @as(V, p.client_iv) ^ operand; - }; - c.write_seq += 1; // TODO send key_update on overflow - P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); + /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. + /// If `end` is true, then this function additionally sends a `close_notify` alert, + /// which is necessary for the server to distinguish between a properly finished + /// TLS session, or a truncation attack. + pub fn writeAllEnd(self: *Self, bytes: []const u8, end: bool) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.writeEnd(bytes[index..], end); + } + } - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .iov_base = record.ptr, - .iov_len = record.len, + /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. + /// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. + /// If `end` is true, then this function additionally sends a `close_notify` alert, + /// which is necessary for the server to distinguish between a properly finished + /// TLS session, or a truncation attack. + pub fn writeEnd(self: *Self, bytes: []const u8, end: bool) !usize { + try self.stream.writeAll(bytes); + if (end) { + const alert = tls.Alert{ + .level = .fatal, + .description = .close_notify, }; - iovec_end += 1; + try self.stream.write(tls.Alert, alert); + try self.stream.flush(); } - }, - } -} - -pub fn eof(c: Client) bool { - return c.received_close_notify and - c.partial_cleartext_idx >= c.partial_ciphertext_idx and - c.partial_ciphertext_idx >= c.partial_ciphertext_end; -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the buffer has at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; - return readvAtLeast(c, stream, &iovecs, len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is smaller than -/// `buffer.len`, it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, buffer.len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is less than the space -/// provided it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize { - return readvAtLeast(c, stream, iovecs, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the iovecs have at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: usize) !usize { - if (c.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); - off_i += amt; - if (c.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].iov_len) { - amt -= iovecs[vec_i].iov_len; - vec_i += 1; + return bytes.len; } - iovecs[vec_i].iov_base += amt; - iovecs[vec_i].iov_len -= amt; - } -} -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns number of bytes that have been read, populated inside `iovecs`. A -/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` -/// for the end of stream. The `eof()` may be true after any call to -/// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof()` is `false`. -/// See `readv` for a higher level function that has the same, familiar API as -/// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; - - // Give away the buffered cleartext we have, if any. - const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; - if (partial_cleartext.len > 0) { - const amt: u15 = @intCast(vp.put(partial_cleartext)); - c.partial_cleartext_idx += amt; - - if (c.partial_cleartext_idx == c.partial_ciphertext_idx and - c.partial_ciphertext_end == c.partial_ciphertext_idx) - { - // The buffer is now empty. - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = 0; + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns the number of bytes read, calling the underlying read function the + /// minimal number of times until the buffer has at least `len` bytes filled. + /// If the number read is less than `len` it means the stream reached the end. + /// Reaching the end of the stream is not an error condition. + pub fn readAtLeast(self: *Self, buffer: []u8, len: usize) !usize { + var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; + return self.readvAtLeast(&iovecs, len); } - if (c.received_close_notify) { - c.partial_ciphertext_end = 0; - assert(vp.total == amt); - return amt; - } else if (amt > 0) { - // We don't need more data, so don't call read. - assert(vp.total == amt); - return amt; + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + pub fn read(self: *Self, buffer: []u8) !usize { + return self.readAtLeast(buffer, 1); } - } - assert(!c.received_close_notify); - - // Ideally, this buffer would never be used. It is needed when `iovecs` are - // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. - var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. - var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; - // How many bytes left in the user's buffer. - const free_size = vp.freeSize(); - // The amount of the user's buffer that we need to repurpose for storing - // ciphertext. The end of the buffer will be used for such purposes. - const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; - // The amount of the user's buffer that will be used to give cleartext. The - // beginning of the buffer will be used for such purposes. - const cleartext_buf_len = free_size - ciphertext_buf_len; - - // Recoup `partially_read_buffer space`. This is necessary because it is assumed - // below that `frag0` is big enough to hold at least one record. - limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); - c.partial_ciphertext_end -= c.partial_ciphertext_idx; - c.partial_ciphertext_idx = 0; - c.partial_cleartext_idx = 0; - const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; - - var ask_iovecs_buf: [2]std.os.iovec = .{ - .{ - .iov_base = first_iov.ptr, - .iov_len = first_iov.len, - }, - .{ - .iov_base = &in_stack_buffer, - .iov_len = in_stack_buffer.len, - }, - }; + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. Reaching the end of the + /// stream is not an error condition. + pub fn readAll(self: *Self, buffer: []u8) !usize { + return self.readAtLeast(buffer, buffer.len); + } - // Cleartext capacity of output buffer, in records. Minimum one full record. - const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1); - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); - const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); - const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); - const actual_read_len = try stream.readv(ask_iovecs); - if (actual_read_len == 0) { - // This is either a truncation attack, a bug in the server, or an - // intentional omission of the close_notify message due to truncation - // detection handled above the TLS layer. - if (c.allow_truncation_attacks) { - c.received_close_notify = true; - } else { - return error.TlsConnectionTruncated; + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns the number of bytes read. If the number read is less than the space + /// provided it means the stream reached the end. Reaching the end of the + /// stream is not an error condition. + /// The `iovecs` parameter is mutable because this function needs to mutate the fields in + /// order to handle partial reads from the underlying stream layer. + pub fn readv(self: *Self, iovecs: []std.os.iovec) !usize { + return self.readvAtLeast(iovecs, 1); } - } - // There might be more bytes inside `in_stack_buffer` that need to be processed, - // but at least frag0 will have one complete ciphertext record. - const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); - const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; - var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; - // We need to decipher frag0 and frag1 but there may be a ciphertext record - // straddling the boundary. We can handle this with two memcpy() calls to - // assemble the straddling record in between handling the two sides. - var frag = frag0; - var in: usize = 0; - while (true) { - if (in == frag.len) { - // Perfect split. - if (frag.ptr == frag1.ptr) { - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns the number of bytes read, calling the underlying read function the + /// minimal number of times until the iovecs have at least `len` bytes filled. + /// If the number read is less than `len` it means the stream reached the end. + /// Reaching the end of the stream is not an error condition. + /// The `iovecs` parameter is mutable because this function needs to mutate the fields in + /// order to handle partial reads from the underlying stream layer. + pub fn readvAtLeast(self: *Self, iovecs: []std.os.iovec, len: usize) !usize { + if (self.eof()) return 0; + + var off_i: usize = 0; + var vec_i: usize = 0; + while (true) { + var amt = try self.readvAdvanced(iovecs[vec_i..]); + off_i += amt; + if (self.eof() or off_i >= len) return off_i; + while (amt >= iovecs[vec_i].iov_len) { + amt -= iovecs[vec_i].iov_len; + vec_i += 1; + } + iovecs[vec_i].iov_base += amt; + iovecs[vec_i].iov_len -= amt; } - frag = frag1; - in = 0; - continue; } - if (in + tls.record_header_len > frag.len) { - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - const first = frag[in..]; - - if (frag1.len < tls.record_header_len) - return finishRead2(c, first, frag1, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); - const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); - const record_len = (record_len_byte_0 << 8) | record_len_byte_1; - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns number of bytes that have been read, populated inside `iovecs`. A + /// return value of zero bytes does not mean end of stream. Instead, check the `eof()` + /// for the end of stream. The `eof()` may be true after any call to + /// `read`, including when greater than zero bytes are returned, and this + /// function asserts that `eof()` is `false`. + /// See `readv` for a higher level function that has the same, familiar API as + /// other read functions, such as `std.fs.File.read`. + pub fn readvAdvanced(self: *Self, iovecs: []const std.os.iovec) !usize { + _ = .{ self, iovecs }; + return 0; } - const ct: tls.ContentType = @enumFromInt(frag[in]); - in += 1; - const legacy_version = mem.readInt(u16, frag[in..][0..2], .big); - in += 2; - _ = legacy_version; - const record_len = mem.readInt(u16, frag[in..][0..2], .big); - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - in += 2; - const end = in + record_len; - if (end > frag.len) { - // We need the record header on the next iteration of the loop. - in -= tls.record_header_len; - - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const first = frag[in..]; - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; + + pub fn eof(self: *Self) bool { + return self.stream.eof(); } - switch (ct) { - .alert => { - if (in + 2 > frag.len) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(frag[in]); - const desc: tls.AlertDescription = @enumFromInt(frag[in + 1]); - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .application_data => { - const cleartext = switch (c.application_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ad = frag[in - 5 ..][0..5]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.server_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.read_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.read_seq))); - break :nonce @as(V, p.server_iv) ^ operand; - }; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch - return error.TlsBadRecordMac; - break :c cleartext; - }, - }; - c.read_seq = try std.math.add(u64, c.read_seq, 1); - - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - switch (inner_ct) { - .alert => { - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { - c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; + pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { + const hello = tls.ClientHello{ + .random = key_pairs.hello_rand, + .session_id = &key_pairs.session_id, + .cipher_suites = self.options.cipher_suites, + .extensions = &.{ + .{ .server_name = &[_]tls.ServerName{.{ .host_name = self.options.host }} }, + .{ .ec_point_formats = &[_]tls.EcPointFormat{.uncompressed} }, + .{ .supported_groups = &tls.supported_groups }, + .{ .signature_algorithms = &tls.supported_signature_schemes }, + .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, + .{ .key_share = &[_]tls.KeyShare{ + .{ .x25519_kyber768d00 = .{ + .x25119 = key_pairs.x25519.public_key, + .kyber768d00 = key_pairs.kyber768d00.public_key, + } }, + .{ .secp256r1 = key_pairs.secp256r1.public_key }, + .{ .x25519 = key_pairs.x25519.public_key }, + } }, + }, + }; + + self.stream.version = .tls_1_0; + try self.stream.write(tls.ClientHello, hello); + try self.stream.flush(); + } - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len - 1) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. - }, - .key_update => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); - p.server_secret = server_secret; - p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); - p.client_secret = client_secret; - p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len - 1) break; - } + pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { + try self.stream.readFragment(.server_hello); + + // TODO: check spec to see if we should verify this + _ = try self.stream.read(u24); + // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. + _ = try self.stream.read(tls.Version); + const random = try self.stream.readAll(32); + if (mem.eql(u8, random, &tls.ServerHello.hello_retry_request)) return error.TlsUnexpectedMessage; // `ClientHello` failed and we don't know how to rephrase it. + const legacy_session_id = try self.stream.readSmallArray(u8); + if (!mem.eql(u8, legacy_session_id, &key_pairs.session_id)) return error.TlsIllegalParameter; + const cipher_suite = try self.stream.read(tls.CipherSuite); + const compression_method = try self.stream.read(u8); + if (compression_method != 0) return error.TlsIllegalParameter; + + var supported_version: ?tls.Version = null; + var shared_key: ?[]const u8 = null; + + var iter = try self.stream.extensions(); + while (try iter.next()) |ext| { + switch (ext.type) { + .supported_versions => { + if (supported_version != null) return error.TlsIllegalParameter; + supported_version = try self.stream.read(tls.Version); }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == &cleartext_stack_buffer) { - // Stack buffer was used, so we must copy to the output buffer. - const msg = cleartext[0 .. cleartext.len - 1]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. - @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len], - msg, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + msg.len); - } else { - const amt = vp.put(msg); - if (amt < msg.len) { - const rest = msg[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); - } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len - 1); + .key_share => { + if (shared_key != null) return error.TlsIllegalParameter; + const named_group = try self.stream.read(tls.NamedGroup); + const key_size = try self.stream.read(u16); + switch (named_group) { + .x25519_kyber768d00 => { + const xksl = crypto.dh.X25519.public_length; + const hksl = xksl + crypto.kem.kyber_d00.Kyber768.ciphertext_length; + if (key_size != hksl) + return error.TlsIllegalParameter; + const server_ks = try self.stream.readAll(hksl); + + shared_key = &((crypto.dh.X25519.scalarmult( + key_pairs.x25519.secret_key, + server_ks[0..xksl].*, + ) catch return error.TlsDecryptFailure) ++ (key_pairs.kyber768d00.secret_key.decaps( + server_ks[xksl..hksl], + ) catch return error.TlsDecryptFailure)); + }, + .x25519 => { + const ksl = crypto.dh.X25519.public_length; + if (key_size != ksl) return error.TlsIllegalParameter; + const server_pub_key = try self.stream.readAll(ksl); + + shared_key = &(crypto.dh.X25519.scalarmult( + key_pairs.x25519.secret_key, + server_pub_key[0..ksl].*, + ) catch return error.TlsDecryptFailure); + }, + .secp256r1 => { + const server_pub_key = try self.stream.readAll(key_size); + + const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; + const pk = PublicKey.fromSec1(server_pub_key) catch { + return error.TlsDecryptFailure; + }; + const mul = pk.p.mulPublic(key_pairs.secp256r1.secret_key.bytes, .big) catch { + return error.TlsDecryptFailure; + }; + shared_key = &mul.affineCoordinates().x.toBytes(.big); + }, + else => { + return error.TlsIllegalParameter; + }, } }, else => { - return error.TlsUnexpectedMessage; + _ = try self.stream.readAll(ext.len); }, } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - in = end; - } -} + } -fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { - const saved_buf = frag[in..]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(saved_buf.len); - @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf); - } - return out; -} + if (supported_version != tls.Version.tls_1_3) return error.TlsIllegalParameter; + if (shared_key == null) return error.TlsIllegalParameter; + + self.stream.transcript_hash.active = switch (cipher_suite) { + .aes_128_gcm_sha256, .chacha20_poly1305_sha256, .aegis_128l_sha256 => .sha256, + .aes_256_gcm_sha384 => .sha384, + .aegis_256_sha512 => .sha512, + else => return error.TlsIllegalParameter, + }; + const hello_hash = self.stream.transcript_hash.peek(); + self.stream.handshake_cipher = tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash); + self.stream.content_type = .application_data; + self.stream.handshake_cipher.?.print(); + + try self.stream.readFragment(.encrypted_extensions); + iter = try self.stream.extensions(); + while (try iter.next()) |ext| { + _ = try self.stream.readAll(ext.len); + } -/// Note that `first` usually overlaps with `c.partially_read_buffer`. -fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first); - @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1); - } - return out; + // CertificateRequest* + // Certificate* + // CertificateVerify* + // Finished + } + }; } -fn limitedOverlapCopy(frag: []u8, in: usize) void { - const first = frag[in..]; - if (first.len <= in) { - // A single, non-overlapping memcpy suffices. - @memcpy(frag[0..first.len], first); - } else { - // One memcpy call would overlap, so just do this instead. - std.mem.copyForwards(u8, frag, first); - } -} +pub const Options = struct { + /// Used to verify certificate chain. If null will **dangerously** skip certificate verification. + ca_bundle: ?Certificate.Bundle, + /// Used to verify cerficate chain and for Server Name Indication. + host: []const u8, + /// List of potential cipher suites in order of descending preference. + cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, + /// By default, reaching the end-of-stream when reading from the server will + /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify + /// message has been received. By setting this flag to `true`, instead, the + /// end-of-stream will be forwarded to the application layer above TLS. + /// This makes the application vulnerable to truncation attacks unless the + /// application layer itself verifies that the amount of data received equals + /// the amount of data expected, such as HTTP with the Content-Length header. + allow_truncation_attacks: bool = false, +}; -fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { - if (index < s1.len) { - return s1[index]; - } else { - return s2[index - s1.len]; +/// One of these potential hashes will be selected during the handshake as the transcript hash. +/// We init them before sending a single message to avoid having to store the `ClientHello` until +/// receiving `ServerHello`. +/// A nice benefit is decreased latency on hosts where one round trip takes longer than calling +/// `update` on each hashes. +pub const MultiHash = struct { + sha256: sha2.Sha256 = sha2.Sha256.init(.{}), + sha384: sha2.Sha384 = sha2.Sha384.init(.{}), + sha512: sha2.Sha512 = sha2.Sha512.init(.{}), + /// Chosen during handshake. + active: enum { all, sha256, sha384, sha512 } = .all, + + const sha2 = crypto.hash.sha2; + const Self = @This(); + + pub fn update(self: *Self, bytes: []const u8) void { + switch (self.active) { + .all => { + self.sha256.update(bytes); + self.sha384.update(bytes); + self.sha512.update(bytes); + }, + .sha256 => self.sha256.update(bytes), + .sha384 => self.sha384.update(bytes), + .sha512 => self.sha512.update(bytes), + } } -} - -const builtin = @import("builtin"); -const native_endian = builtin.cpu.arch.endian(); -inline fn big(x: anytype) @TypeOf(x) { - return switch (native_endian) { - .big => x, - .little => @byteSwap(x), - }; -} + pub fn peek(self: Self) []const u8 { + return &switch (self.active) { + .all => [_]u8{}, + .sha256 => self.sha256.peek(), + .sha384 => self.sha384.peek(), + .sha512 => self.sha512.peek(), + }; + } +}; -fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, - .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, - else => @compileError("bad scheme"), - }; -} +/// One of these potential key pairs will be selected during the handshake. +pub const KeyPairs = struct { + hello_rand: [hello_rand_length]u8, + session_id: [session_id_length]u8, + kyber768d00: Kyber768, + secp256r1: Secp256r1, + x25519: X25519, + + const Self = @This(); + + const hello_rand_length = 32; + const session_id_length = 32; + const X25519 = tls.NamedGroupT(.x25519).KeyPair; + const Secp256r1 = tls.NamedGroupT(.secp256r1).KeyPair; + const Kyber768 = tls.NamedGroupT(.x25519_kyber768d00).Kyber768.KeyPair; + + pub fn init() Self { + var random_buffer: [ + hello_rand_length + + session_id_length + + Kyber768.seed_length + + Secp256r1.seed_length + + X25519.seed_length + ]u8 = undefined; -fn SchemeHash(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, - .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, - .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, - else => @compileError("bad scheme"), - }; -} + while (true) { + crypto.random.bytes(&random_buffer); + + const split1 = hello_rand_length; + const split2 = split1 + session_id_length; + const split3 = split2 + Kyber768.seed_length; + const split4 = split3 + Secp256r1.seed_length; + + return initAdvanced( + random_buffer[0..split1].*, + random_buffer[split1..split2].*, + random_buffer[split2..split3].*, + random_buffer[split3..split4].*, + random_buffer[split4..].*, + ) catch continue; + } + } -fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .ed25519 => crypto.sign.Ed25519, - else => @compileError("bad scheme"), - }; -} + pub fn initAdvanced( + hello_rand: [hello_rand_length]u8, + session_id: [session_id_length]u8, + kyber_768_seed: [Kyber768.seed_length]u8, + secp256r1_seed: [Secp256r1.seed_length]u8, + x25519_seed: [X25519.seed_length]u8, + ) !Self { + return Self{ + .kyber768d00 = Kyber768.create(kyber_768_seed) catch {}, + .secp256r1 = Secp256r1.create(secp256r1_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, + .x25519 = X25519.create(x25519_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, + .hello_rand = hello_rand, + .session_id = session_id, + }; + } +}; /// Abstraction for sending multiple byte buffers to a slice of iovecs. const VecPut = struct { @@ -1424,45 +747,3 @@ fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec { } return iovecs; } - -/// The priority order here is chosen based on what crypto algorithms Zig has -/// available in the standard library as well as what is faster. Following are -/// a few data points on the relative performance of these algorithms. -/// -/// Measurement taken with 0.11.0-dev.810+c2f5848fe -/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -/// aegis-128l: 15382 MiB/s -/// aegis-256: 9553 MiB/s -/// aes128-gcm: 3721 MiB/s -/// aes256-gcm: 3010 MiB/s -/// chacha20Poly1305: 597 MiB/s -/// -/// Measurement taken with 0.11.0-dev.810+c2f5848fe -/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline -/// aegis-128l: 629 MiB/s -/// chacha20Poly1305: 529 MiB/s -/// aegis-256: 461 MiB/s -/// aes128-gcm: 138 MiB/s -/// aes256-gcm: 120 MiB/s -const cipher_suites = if (crypto.core.aes.has_hardware_support) - enum_array(tls.CipherSuite, &.{ - .AEGIS_128L_SHA256, - .AEGIS_256_SHA512, - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - }) -else - enum_array(tls.CipherSuite, &.{ - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - .AEGIS_256_SHA512, - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - }); - -test { - _ = StreamInterface; -} diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig new file mode 100644 index 000000000000..ceb4b989fb3e --- /dev/null +++ b/lib/std/crypto/tls/Server.zig @@ -0,0 +1,325 @@ +const std = @import("../../std.zig"); +const tls = std.crypto.tls; +const net = std.net; +const mem = std.mem; +const crypto = std.crypto; +const assert = std.debug.assert; +const Certificate = std.crypto.Certificate; + +pub const TranscriptHash = std.crypto.hash.sha2.Sha384; + +/// `StreamType` must conform to `tls.StreamInterface`. +pub fn Server(comptime StreamType: type) type { + return struct { + stream: tls.Stream(tls.Plaintext.max_length, StreamType, TranscriptHash), + options: Options, + + const Self = @This(); + + /// Initiates a TLS handshake and establishes a TLSv1.3 session + pub fn init(stream: *StreamType, options: Options) !Self { + var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType, TranscriptHash){ + .stream = stream, + .transcript_hash = TranscriptHash.init(.{}), + .is_client = false, + }; + var res = Self{ .stream = stream_, .options = options }; + const client_hello = try res.recv_hello(&stream_); + _ = client_hello; + // { + // var random_buffer: [32]u8 = undefined; + // crypto.random.bytes(&random_buffer); + // const key_pair = crypto.dh.X25519.KeyPair.create(random_buffer) catch |err| switch (err) { + // error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + // }; + // try res.send_hello(key_pair); + // } + + return res; + } + + /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. + /// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. + pub fn write(self: *Self, bytes: []const u8) !usize { + return self.writeEnd(bytes, false); + } + + /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. + pub fn writeAll(self: *Self, bytes: []const u8) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.write(bytes[index..]); + } + } + + /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. + /// If `end` is true, then this function additionally sends a `close_notify` alert, + /// which is necessary for the server to distinguish between a properly finished + /// TLS session, or a truncation attack. + pub fn writeAllEnd(self: *Self, bytes: []const u8, end: bool) !void { + var index: usize = 0; + while (index < bytes.len) { + index += try self.writeEnd(bytes[index..], end); + } + } + + /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. + /// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. + /// If `end` is true, then this function additionally sends a `close_notify` alert, + /// which is necessary for the server to distinguish between a properly finished + /// TLS session, or a truncation attack. + pub fn writeEnd(self: *Self, bytes: []const u8, end: bool) !usize { + try self.stream.writeAll(bytes); + if (end) { + const alert = tls.Alert{ + .level = .fatal, + .description = .close_notify, + }; + try self.stream.write(tls.Alert, alert); + try self.stream.flush(); + } + return bytes.len; + } + + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns the number of bytes read, calling the underlying read function the + /// minimal number of times until the buffer has at least `len` bytes filled. + /// If the number read is less than `len` it means the stream reached the end. + /// Reaching the end of the stream is not an error condition. + pub fn readAtLeast(self: *Self, buffer: []u8, len: usize) !usize { + var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; + return self.readvAtLeast(&iovecs, len); + } + + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + pub fn read(self: *Self, buffer: []u8) !usize { + return self.readAtLeast(buffer, 1); + } + + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. Reaching the end of the + /// stream is not an error condition. + pub fn readAll(self: *Self, buffer: []u8) !usize { + return self.readAtLeast(buffer, buffer.len); + } + + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns the number of bytes read. If the number read is less than the space + /// provided it means the stream reached the end. Reaching the end of the + /// stream is not an error condition. + /// The `iovecs` parameter is mutable because this function needs to mutate the fields in + /// order to handle partial reads from the underlying stream layer. + pub fn readv(self: *Self, iovecs: []std.os.iovec) !usize { + return self.readvAtLeast(iovecs, 1); + } + + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns the number of bytes read, calling the underlying read function the + /// minimal number of times until the iovecs have at least `len` bytes filled. + /// If the number read is less than `len` it means the stream reached the end. + /// Reaching the end of the stream is not an error condition. + /// The `iovecs` parameter is mutable because this function needs to mutate the fields in + /// order to handle partial reads from the underlying stream layer. + pub fn readvAtLeast(self: *Self, iovecs: []std.os.iovec, len: usize) !usize { + if (self.eof()) return 0; + + var off_i: usize = 0; + var vec_i: usize = 0; + while (true) { + var amt = try self.readvAdvanced(iovecs[vec_i..]); + off_i += amt; + if (self.eof() or off_i >= len) return off_i; + while (amt >= iovecs[vec_i].iov_len) { + amt -= iovecs[vec_i].iov_len; + vec_i += 1; + } + iovecs[vec_i].iov_base += amt; + iovecs[vec_i].iov_len -= amt; + } + } + + /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. + /// Returns number of bytes that have been read, populated inside `iovecs`. A + /// return value of zero bytes does not mean end of stream. Instead, check the `eof()` + /// for the end of stream. The `eof()` may be true after any call to + /// `read`, including when greater than zero bytes are returned, and this + /// function asserts that `eof()` is `false`. + /// See `readv` for a higher level function that has the same, familiar API as + /// other read functions, such as `std.fs.File.read`. + pub fn readvAdvanced(self: *Self, iovecs: []const std.os.iovec) !usize { + _ = .{ self, iovecs }; + return 0; + } + + pub fn eof(self: *Self) bool { + return self.stream.eof(); + } + + const ClientHello = struct { + random: [32]u8, + session_id: [32]u8, + cipher_suite: tls.CipherSuite, + key_share: tls.KeyShare, + }; + + pub fn recv_hello(self: *Self) !ClientHello { + try self.stream.readFragment(.client_hello); + + // TODO: verify this + const msg_len = try self.stream.read(u24); + std.debug.print("msg_len {d}\n", .{msg_len}); + // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. + _ = try self.stream.read(tls.Version); + const client_random = try self.stream.readAll(32); + const session_id = try self.stream.readSmallArray(u8); + if (session_id.len != 32) return error.TlsUnexpectedMessage; + + var selected_suite: ?tls.CipherSuite = null; + + var cipher_suite_iter = try self.stream.iterator(tls.CipherSuite); + while (try cipher_suite_iter.next()) |suite| { + if (selected_suite == null) brk: { + for (self.options.cipher_suites) |s| { + if (s == suite) { + selected_suite = s; + break :brk; + } + } + } + } + + if (selected_suite == null) return error.TlsUnexpectedMessage; + + const compression_methods = try self.stream.readAll(2); + if (!std.mem.eql(u8, compression_methods, &[_]u8{ 1, 0 })) return error.TlsUnexpectedMessage; + + var tls_version: ?tls.Version = null; + var key_share: ?tls.KeyShare = null; + var ec_point_format: ?tls.EcPointFormat = null; + + var extension_iter = try self.stream.extensions(); + while (try extension_iter.next()) |ext| { + switch (ext.type) { + .supported_versions => { + if (tls_version != null) return error.TlsUnexpectedMessage; + const versions = try self.stream.readSmallArray(tls.Version); + for (versions) |v| { + std.debug.print("version {}\n", .{v}); + if (v == .tls_1_3) tls_version = v; + } + }, + // TODO: use supported_groups instead + .key_share => { + if (key_share != null) return error.TlsUnexpectedMessage; + + var key_share_iter = try self.stream.iterator(tls.KeyShare.Header); + while (try key_share_iter.next()) |ks| { + const key = try self.stream.readAll(ks.len); + if (ks.group == .x25519) { + key_share = .{ .x25519 = undefined }; + if (ks.len != key_share.?.keyLen(true)) return error.TlsUnexpectedMessage; + @memcpy(&key_share.?.x25519, key); + } + } + }, + .ec_point_formats => { + const formats = try self.stream.readSmallArray(tls.EcPointFormat); + for (formats) |f| { + if (f == .uncompressed) ec_point_format = .uncompressed; + } + }, + else => { + _ = try self.stream.readAll(ext.len); + }, + } + } + + if (tls_version == null) return error.TlsUnexpectedMessage; + if (key_share == null) return error.TlsUnexpectedMessage; + if (ec_point_format == null) return error.TlsUnexpectedMessage; + + return .{ + .random = client_random[0..32].*, + .session_id = session_id[0..32].*, + .cipher_suite = selected_suite.?, + .key_share = key_share.?, + }; + } + + /// `key_pair`'s active member MUST match `client_hello.key_share` + pub fn send_hello(self: *Self, client_hello: ClientHello, key_pair: KeyPair) !void { + const hello = tls.ServerHello{ + .random = key_pair.random, + .session_id = &client_hello.session_id, + .cipher_suite = client_hello.cipher_suite, + .extensions = &.{ + .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, + .{ .key_share = &[_]tls.KeyShare{key_pair.pair.toKeyShare()} }, + }, + }; + self.stream.version = .tls_1_2; + try self.stream.write(tls.ServerHello, hello); + try self.stream.flush(); + + self.stream.content_type = .change_cipher_spec; + try self.stream.write(tls.ChangeCipherSpec, .change_cipher_spec); + try self.stream.flush(); + + const shared_key = switch (client_hello.key_share) { + .x25519_kyber768d00 => |ks| brk: { + const T = tls.NamedGroupT(.x25519_kyber768d00); + const pair: tls.X25519Kyber768Draft.KeyPair = key_pair.pair.x25519_kyber768d00; + const shared_point = T.X25519.scalarmult( + ks.x25519, + pair.x25519.secret_key, + ) catch return error.TlsDecryptFailure; + // pair.kyber768d00.secret_key + // ks.kyber768d00 + const encaps = ks.kyber768d00.encaps(null).ciphertext; + + break :brk &(shared_point ++ encaps); + }, + .x25519 => |ks| brk: { + const shared_point = tls.NamedGroupT(.x25519).scalarmult( + key_pair.pair.x25519.secret_key, + ks, + ) catch return error.TlsDecryptFailure; + break :brk &shared_point; + }, + .secp256r1 => |ks| brk: { + const mul = ks.p.mulPublic( + key_pair.pair.secp256r1.secret_key.bytes, + .big, + ) catch + return error.TlsDecryptFailure; + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + else => return error.TlsIllegalParameter, + }; + + const hello_hash = self.stream.transcript_hash.peek(); + self.stream.handshake_cipher = tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, &hello_hash); + self.stream.handshake_cipher.?.print(); + + const extensions = tls.EncryptedExtensions{ .extensions = &.{} }; + self.stream.content_type = .handshake; + try self.stream.write(tls.EncryptedExtensions, extensions); + try self.stream.flush(); + + try self.stream.write(tls.Certificate, self.options.certificate); + try self.stream.flush(); + } + }; +} + +pub const Options = struct { + /// List of potential cipher suites in order of descending preference. + cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, + certificate: tls.Certificate, +}; + +pub const KeyPair = struct { + random: [32]u8, + pair: tls.KeyPair, +}; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 339afdb96e91..31a15a0eb5b7 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -19,6 +19,7 @@ const Client = @This(); const proto = @import("protocol.zig"); pub const disable_tls = std.options.http_disable_tls; +const TlsClient = if (disable_tls) void else std.crypto.tls.Client(net.Stream); /// Used for all client allocations. Must be thread-safe. allocator: Allocator, @@ -192,7 +193,7 @@ pub const ConnectionPool = struct { pub const Connection = struct { stream: net.Stream, /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *std.crypto.tls.Client else void, + tls_client: *TlsClient, /// The protocol that this connection is using. protocol: Protocol, @@ -215,13 +216,13 @@ pub const Connection = struct { read_buf: [buffer_size]u8 = undefined, write_buf: [buffer_size]u8 = undefined, - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; + pub const buffer_size = @sizeOf(std.crypto.tls.Plaintext) + std.crypto.tls.Plaintext.max_length; const BufferSize = std.math.IntFittingRange(0, buffer_size); pub const Protocol = enum { plain, tls }; pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - return conn.tls_client.readv(conn.stream, buffers) catch |err| { + return conn.tls_client.readv(buffers) catch |err| { // https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; @@ -319,7 +320,7 @@ pub const Connection = struct { } pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { + return conn.tls_client.writeAll(buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; @@ -387,7 +388,7 @@ pub const Connection = struct { if (disable_tls) unreachable; // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + _ = conn.tls_client.writeEnd("", true) catch {}; allocator.destroy(conn.tls_client); } @@ -1373,13 +1374,16 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec if (protocol == .tls) { if (disable_tls) unreachable; - conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); + conn.data.tls_client = try client.allocator.create(TlsClient); errdefer client.allocator.destroy(conn.data.tls_client); - conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - conn.data.tls_client.allow_truncation_attacks = true; + conn.data.tls_client.* = TlsClient.init(&conn.data.stream, .{ + .ca_bundle = client.ca_bundle, + .host = host, + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. + .allow_truncation_attacks = true, + }) catch return error.TlsInitializationFailed; } client.connection_pool.addUsed(conn); From 496fee054d1fa6fd50c26113d0b1dbb86969d534 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Thu, 7 Mar 2024 19:44:22 -0500 Subject: [PATCH 02/17] write with len, stream write pushback for large writes, partial certs and server finished --- TODO | 7 +- lib/std/crypto/Certificate.zig | 1 + lib/std/crypto/testdata/server.der | Bin 0 -> 805 bytes lib/std/crypto/tls.zig | 833 +++++++++++++++++------------ lib/std/crypto/tls/Client.zig | 11 +- lib/std/crypto/tls/Server.zig | 86 ++- 6 files changed, 565 insertions(+), 373 deletions(-) create mode 100644 lib/std/crypto/testdata/server.der diff --git a/TODO b/TODO index a5b21b7cfc26..7ebc5937c919 100644 --- a/TODO +++ b/TODO @@ -1,3 +1,4 @@ -Refactoring TLS client to handle fragments + send messages with structs -Adding TLS server for testing -Using https://tls13.xargs.org/ for unit tests +store multiple fragments in buffer for less syscalls +remove @panic's +server dynamic transcript hash based on client cipher suites +send alert on error diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index c3ac3e22aa47..0d932fcafb98 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -385,6 +385,7 @@ test "Parsed.checkHostName" { pub const ParseError = der.Element.ParseElementError || ParseVersionError || ParseTimeError || ParseEnumError || ParseBitStringError; +/// Parse a DER format certificate. pub fn parse(cert: Certificate) ParseError!Parsed { const cert_bytes = cert.buffer; const certificate = try der.Element.parse(cert_bytes, cert.index); diff --git a/lib/std/crypto/testdata/server.der b/lib/std/crypto/testdata/server.der new file mode 100644 index 0000000000000000000000000000000000000000..edce2b88201a81edfc084f4686e49c17cc4aaed2 GIT binary patch literal 805 zcmXqLVpcS0V&YuD%*4pV#333rY3(7F{s{)WY@Awc9&O)w85y}*84Q#RxeYkkm_u3E zgqcEv4TTK^K^!h&F4v00+=84`1!qSCIdNV?3qu0~Qv*X|3q$iLAlDL!YoLv4hKzv} z$P8v-;Z%qjdZjsO8L64MdU>fO22G4g$ZlX{WngY%G+&!CP z3d@`2+a|Kuf1k8R^4;&+lItsY8RV+Ht-PM`zc^6%r}#OO5AU8YdoRjbX74I+^AdjP z^~3h_L(a&M>XV_r9d#F%1qnX@m1y<_vw1g{98{=y*G8?GZv+*@Ad^%hAe!(zpL$=TQ2W> zv8hr%-xs(|ii{Bs`V=xR>bKwi>s97c8C{+nP;2)|wR5PgRc3c#{L=FOLVuO~f*^s5 zUp`ztS6QHG^yt*S0{Q!Ev%R+qGJf%1o0cUg%>U?J(>!0MxhG}rri7O{yr>8e!YM7jwcsa2WW-7VMuwY zUm=;sYp;G@VP0DOxu{1^*33M9(n8|Lra*S3`){Xz*0Y=JdF5Enyk(ZTB`!$|vYn1y z-IdX|#)`4}?w>xLKWo)bXI?&YXHkQrOX}ryCOr>6du{qXC$YKEP+L;KF-1nXUU2EP zm==E)mX?_oh1yf+U2EMF6MN0|t>lC$#SWTkcee1EWG;+!eDUba3h!+kZe|MrE_6KX literal 0 HcmV?d00001 diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 0d42fc3fa954..c312993cf2e4 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -43,42 +43,103 @@ pub const Plaintext = struct { pub const max_length = 1 << 14; }; -pub const Handshake = struct { - type: HandshakeType, - length: u24, - // `length` bytes follow -}; - -fn fieldsLen(comptime T: type) comptime_int { - var res: comptime_int = 0; - inline for (std.meta.fields(T)) |f| res += @sizeOf(f.type); - return res; -} - pub const HandshakeType = enum(u8) { - // hello_request = 0, + /// Deprecated. + hello_request = 0, client_hello = 1, server_hello = 2, - // hello_verify_request = 3, + /// Deprecated. + hello_verify_request = 3, new_session_ticket = 4, end_of_early_data = 5, - // hello_retry_request = 3, + /// Deprecated. + hello_retry_request = 6, encrypted_extensions = 8, certificate = 11, - // server_key_exchange = 12, + /// Deprecated. + server_key_exchange = 12, certificate_request = 13, - // server_hello_done = 14, + /// Deprecated. + server_hello_done = 14, certificate_verify = 15, - // client_key_exchange = 16, + /// Deprecated. + client_key_exchange = 16, finished = 20, - // certificate_url = 21, - // certificate_status = 22, - // supplemental_data = 23, + /// Deprecated. + certificate_url = 21, + /// Deprecated. + certificate_status = 22, + /// Deprecated. + supplemental_data = 23, key_update = 24, message_hash = 254, _, }; +pub const Handshake = union(HandshakeType) { + hello_request: void, + client_hello: ClientHello, + server_hello: ServerHello, + /// Deprecated. + hello_verify_request: void, + new_session_ticket: NewSessionTicket, + end_of_early_data: void, + /// Deprecated. + hello_retry_request: void, + encrypted_extensions: []const Extension, + certificate: Certificate, + /// Deprecated. + server_key_exchange: void, + certificate_request: CertificateRequest, + /// Deprecated. + server_hello_done: void, + certificate_verify: CertificateVerify, + /// Deprecated. + client_key_exchange: void, + finished: []const u8, + /// Deprecated. + certificate_url: void, + /// Deprecated. + certificate_status: void, + /// Deprecated. + supplemental_data: void, + key_update: KeyUpdate, + message_hash: void, + + // If `HandshakeCipherT.encode` accepts iovecs for the message this can be moved + // to `Stream.writeFragment` and this type can be deleted. + pub fn write(self: @This(), stream: anytype) !usize { + var res: usize = 0; + res += try stream.write(HandshakeType, self); + switch (self) { + .finished => |verification| { + res += try stream.writeArray(3, u8, verification); + }, + inline else => |value| { + var len: usize = 0; + const T = @TypeOf(value); + switch (@typeInfo(T)) { + .Void => { + res += try stream.write(u24, @intCast(len)); + }, + .Pointer => |info| { + len += try stream.arrayLength(2, info.child, value); + res += try stream.write(u24, @intCast(len)); + res += try stream.writeArray(2, info.child, value); + }, + .Struct => { + len += try stream.length(T, value); + res += try stream.write(u24, @intCast(len)); + res += try stream.write(T, value); + }, + else => |t| @compileError("implement writing " ++ @tagName(t)), + } + } + } + return res; + } +}; + pub const NewSessionTicket = struct { ticket_lifetime: u32, ticket_age_add: u32, @@ -87,6 +148,11 @@ pub const NewSessionTicket = struct { /// Should have at least one ticket: []const u8, extensions: []const Extension, + + pub fn write(self: @This(), stream: anytype) !usize { + _ = .{ self, stream }; + @panic("TODO"); + } }; pub const CertificateRequest = struct { @@ -94,6 +160,11 @@ pub const CertificateRequest = struct { context: []const u8, /// At least 2 extensions: []const Extension, + + pub fn write(self: @This(), stream: anytype) !usize { + _ = .{ self, stream }; + @panic("TODO"); + } }; pub const Certificate = struct { @@ -106,47 +177,35 @@ pub const Certificate = struct { data: []const u8, extensions: []const Extension = &.{}, - pub fn len(self: @This(), is_client: bool) usize { - return 3 - + self.data.len - + @sizeOf(u16) + slice_len(Extension, self.extensions, is_client); - } - - pub fn write(self: @This(), stream: anytype) !void { - try stream.write(u24, @intCast(self.data.len)); - try stream.writeAll(self.data); - try stream.writeArray(2, Extension, self.extensions); + pub fn write(self: @This(), stream: anytype) !usize { + var res: usize = 0; + res += try stream.writeArray(3, u8, self.data); + res += try stream.writeArray(2, Extension, self.extensions); + return res; } }; const Self = @This(); - pub fn write(self: Self, stream: anytype) !void { - std.debug.assert(!stream.is_client); - try stream.write(HandshakeType, .certificate); - - const entries_len = slice_len(Entry, self.entries, stream.is_client); - std.debug.print("entries_len {d}\n", .{ entries_len }); - - const length: usize = @sizeOf(u8) + 3 + entries_len; - try stream.write(u24, @intCast(length)); - - // TODO: handle this being sent in response to certificate request - std.debug.assert(self.context.len == 0); - try stream.write(u8, 0); - - try stream.write(u24, @intCast(entries_len)); - for (self.entries) |e| try stream.write(Entry, e); + pub fn write(self: Self, stream: anytype) !usize { + var res: usize = 0; + res += try stream.writeArray(1, u8, self.context); + res += try stream.writeArray(3, Entry, self.entries); + return res; } }; pub const CertificateVerify = struct { algorithm: SignatureScheme, + /// Max len 2^16 - 1 signature: []const u8, -}; -pub const Finished = struct { - verify_data: []const u8, + pub fn write(self: @This(), stream: anytype) !usize { + var res: usize = 0; + res += try stream.write(SignatureScheme, self.algorithm); + res += try stream.writeArray(2, u8, self.signature); + return res; + } }; pub const KeyUpdate = struct { @@ -157,6 +216,11 @@ pub const KeyUpdate = struct { update_requested = 1, _, }; + + pub fn write(self: @This(), stream: anytype) !usize { + _ = .{ self, stream }; + @panic("TODO"); + } }; // https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml @@ -218,7 +282,7 @@ pub const ExtensionType = enum(u16) { _, }; -pub const Alert = packed struct { +pub const Alert = struct { /// > In TLS 1.3, the severity is implicit in the type of alert being sent /// > and the "level" field can safely be ignored. level: Level, @@ -322,11 +386,19 @@ pub const Alert = packed struct { } }; - pub const len = @sizeOf(Level) + @sizeOf(Description); + const Self = @This(); + + pub fn read(stream: anytype) Self { + const level = try stream.read(Level); + const description = try stream.read(Description); + return .{ .level = level, .description = description }; + } - pub fn write(self: @This(), stream: anytype) !void { - try stream.write(Level, self.level); - try stream.write(Description, self.description); + pub fn write(self: Self, stream: anytype) !usize { + var res: usize = 0; + res += try stream.write(Level, self.level); + res += try stream.write(Description, self.description); + return res; } }; @@ -400,6 +472,7 @@ pub const supported_signature_schemes = [_]SignatureScheme{ /// Key exchange formats pub const NamedGroup = enum(u16) { + invalid = 0x0000, // Elliptic Curve Groups (ECDHE) secp256r1 = 0x0017, secp384r1 = 0x0018, @@ -452,6 +525,7 @@ pub const X25519Kyber768Draft = struct { }; }; pub const KeyPair = union(NamedGroup) { + invalid: void, secp256r1: NamedGroupT(.secp256r1).KeyPair, secp384r1: NamedGroupT(.secp384r1).KeyPair, secp521r1: void, @@ -481,6 +555,7 @@ pub const KeyPair = union(NamedGroup) { }; /// The public key portion of a KeyPair. Saves bytes. pub const KeyShare = union(NamedGroup) { + invalid: void, secp256r1: NamedGroupT(.secp256r1).PublicKey, secp384r1: NamedGroupT(.secp384r1).PublicKey, secp521r1: void, @@ -495,12 +570,33 @@ pub const KeyShare = union(NamedGroup) { x25519_kyber768d00: NamedGroupT(.x25519_kyber768d00).PublicKey, - pub const max_len = NamedGroupT(.x25519_kyber768d00).PublicKey.bytes_length; - const Self = @This(); - pub fn write(self: Self, stream: anytype) !void { - try stream.write(NamedGroup, self); + pub fn read(stream: anytype) !Self { + const group = try stream.read(NamedGroup); + const len = try stream.read(u16); + const key = try stream.readAll(len); + switch (group) { + // .x25519_kyber768d00 => { + // const expected_len = if (stream.is_client) @TypeOf(k).bytes_length else X25519Kyber768Draft.Kyber768.ciphertext_length; + // }, + inline .secp256r1, .secp384r1 => |k| { + return @unionInit(Self, @tagName(k), try NamedGroupT(k).PublicKey.fromSec1(key)); + }, + .x25519 => { + var res = Self{ .x25519 = undefined }; + if (res.x25519.len != key.len) return error.TlsDecodeError; + @memcpy(&res.x25519, key); + return res; + }, + else => {}, + } + return .{ .invalid = {} }; + } + + pub fn write(self: Self, stream: anytype) !usize { + var res: usize = 0; + res += try stream.write(NamedGroup, self); const public = switch (self) { .x25519_kyber768d00 => |k| if (stream.is_client) &k.toBytes() else &k.ciphertext(), .secp256r1 => |k| &k.toUncompressedSec1(), @@ -508,33 +604,9 @@ pub const KeyShare = union(NamedGroup) { .x25519 => |k| &k, else => "", }; - try stream.writeArray(2, u8, public); - } - - pub fn keyLen(self: Self, is_client: bool) usize { - return switch (self) { - .x25519_kyber768d00 => |k| if (is_client) @TypeOf(k).bytes_length else X25519Kyber768Draft.Kyber768.ciphertext_length, - .secp256r1 => |k| @TypeOf(k).uncompressed_sec1_encoded_length, - .secp384r1 => |k| @TypeOf(k).uncompressed_sec1_encoded_length, - .x25519 => |k| k.len, - else => 0, - }; - } - - pub fn len(self: Self, is_client: bool) usize { - return @sizeOf(NamedGroup) + @sizeOf(u16) + self.keyLen(is_client); + res += try stream.writeArray(2, u8, public); + return res; } - - pub const Header = struct { - group: NamedGroup, - len: u16, - - pub fn read(stream: anytype) !@This() { - const group = try stream.read(NamedGroup); - const length = try stream.read(u16); - return .{ .group = group, .len = length }; - } - }; }; /// In descending order of preference pub const supported_groups = [_]NamedGroup{ @@ -714,7 +786,7 @@ pub const ClientHello = struct { /// Legacy field for TLS 1.2 middleboxes version: Version = .tls_1_2, random: [32]u8, - /// Legacy session resumption + /// Legacy session resumption. Max len 32. session_id: []const u8, /// In descending order of preference cipher_suites: []const CipherSuite, @@ -723,26 +795,19 @@ pub const ClientHello = struct { // Certain extensions are mandatory for TLS 1.3 extensions: []const Extension, + pub const session_id_max_len = 32; + const Self = @This(); - pub fn write(self: Self, stream: anytype) !void { - try stream.write(HandshakeType, .client_hello); - - var length = @sizeOf(Version) + - self.random.len + - @sizeOf(u8) + self.session_id.len + - @sizeOf(u16) + self.cipher_suites.len * @sizeOf(CipherSuite) + - @sizeOf(u8) + self.compression_methods.len; - length += @sizeOf(u16); - for (self.extensions) |e| length += e.len(true); - try stream.write(u24, @intCast(length)); - - try stream.write(Version, self.version); - try stream.writeAll(&self.random); - try stream.writeArray(1, u8, self.session_id); - try stream.writeArray(2, CipherSuite, self.cipher_suites); - try stream.writeArray(1, u8, &self.compression_methods); - try stream.writeArray(2, Extension, self.extensions); + pub fn write(self: Self, stream: anytype) !usize { + var res: usize = 0; + res += try stream.write(Version, self.version); + res += try stream.writeAll(&self.random); + res += try stream.writeArray(1, u8, self.session_id); + res += try stream.writeArray(2, CipherSuite, self.cipher_suites); + res += try stream.writeArray(1, u8, &self.compression_methods); + res += try stream.writeArray(2, Extension, self.extensions); + return res; } }; @@ -766,28 +831,16 @@ pub const ServerHello = struct { const Self = @This(); - pub fn len(self: Self, _: bool) usize { - var res = @sizeOf(Version) + - self.random.len + - @sizeOf(u8) + self.session_id.len + - @sizeOf(CipherSuite) + - @sizeOf(u8); - res += @sizeOf(u16); - for (self.extensions) |e| res += e.len(false); + pub fn write(self: Self, stream: anytype) !usize { + var res: usize = 0; + res += try stream.write(Version, self.version); + res += try stream.writeAll(&self.random); + res += try stream.writeArray(1, u8, self.session_id); + res += try stream.write(CipherSuite, self.cipher_suite); + res += try stream.write(u8, self.compression_method); + res += try stream.writeArray(2, Extension, self.extensions); return res; } - - pub fn write(self: Self, stream: anytype) !void { - try stream.write(HandshakeType, .server_hello); - const length = self.len(false); - try stream.write(u24, @intCast(length)); - try stream.write(Version, self.version); - try stream.writeAll(&self.random); - try stream.writeArray(1, u8, self.session_id); - try stream.write(CipherSuite, self.cipher_suite); - try stream.write(u8, self.compression_method); - try stream.writeArray(2, Extension, self.extensions); - } }; pub const EncryptedExtensions = struct { @@ -795,17 +848,8 @@ pub const EncryptedExtensions = struct { const Self = @This(); - pub fn len(self: Self, _: bool) usize { - var res: usize = @sizeOf(HandshakeType) + @sizeOf(u24); - for (self.extensions) |e| res += e.len(false); - return res; - } - - pub fn write(self: Self, stream: anytype) !void { - try stream.write(HandshakeType, .encrypted_extensions); - const ext_len = @sizeOf(u16) + slice_len(Extension, self.extensions, false); - try stream.write(u24, @intCast(ext_len)); - try stream.writeArray(2, Extension, self.extensions); + pub fn write(self: Self, stream: anytype) !usize { + return try stream.writeArray(2, Extension, self.extensions); } }; @@ -842,64 +886,37 @@ pub const Extension = union(ExtensionType) { const Self = @This(); - pub fn len(self: Self, is_client: bool) usize { - var res: usize = @sizeOf(ExtensionType) + @sizeOf(u16); - switch (self) { - inline else => |items| { - const T = @TypeOf(items); - if (T == void) return res; - if (is_client) { - res += switch (self) { - .supported_versions, .ec_point_formats, .psk_key_exchange_modes => 1, - .server_name, .supported_groups, .signature_algorithms, .key_share => 2, - else => 0, - }; - } - res += slice_len(@typeInfo(T).Pointer.child, items, is_client); - }, - } - return res; - } - - pub fn write(self: Self, stream: anytype) !void { - const prefix_len = stream.is_client; - switch (self) { - inline else => |items, tag| { - const T = @TypeOf(items); - try stream.write(ExtensionType, tag); - const length = self.len(prefix_len) - @sizeOf(ExtensionType) - @sizeOf(u16); - try stream.write(u16, @intCast(length)); - if (T == void) return; - }, - } - switch (self) { - inline .supported_versions, + pub fn write(self: Self, stream: anytype) !usize { + const prefix_len: u8 = if (stream.is_client) switch (self) { + .supported_versions, .ec_point_formats, - .psk_key_exchange_modes, - => |items| { - const T = @typeInfo(@TypeOf(items)).Pointer.child; - if (prefix_len) { - try stream.writeArray(1, T, items); - } else { - try stream.writeArray(0, T, items); - } - }, - inline .server_name, + .psk_key_exchange_modes => 1, + .server_name, .supported_groups, .signature_algorithms, - .key_share, - => |items| { - const T = @typeInfo(@TypeOf(items)).Pointer.child; - if (prefix_len) { - try stream.writeArray(2, T, items); - } else { - try stream.writeArray(0, T, items); + .key_share => 2, + else => 0, + } else 0; + + var res: usize = 0; + res += try stream.write(ExtensionType, self); + + switch (self) { + inline else => |items| { + switch (@typeInfo(@TypeOf(items))) { + .Void => { + res += try stream.write(u16, 0); + }, + .Pointer => |info| { + const len = try stream.arrayLength(prefix_len, info.child, items); + res += try stream.write(u16, @intCast(len)); + res += try stream.writeArray(prefix_len, info.child, items); + }, + else => |t| @compileError("implement writing " ++ @tagName(t)), } }, - inline else => |_, tag| { - @panic("unsupported extension " ++ @tagName(tag)); - }, } + return res; } pub const Header = struct { @@ -926,12 +943,6 @@ pub const PskKeyExchangeMode = enum(u8) { _, }; -pub const UncompressedPointRepresentation = struct { - form: u8 = 4, // uncompressed - x: []const u8, - y: []const u8, -}; - /// RFC 8446 S4.1.3 pub const ServerName = struct { type: NameType = .host_name, @@ -939,13 +950,11 @@ pub const ServerName = struct { pub const NameType = enum(u8) { host_name = 0, _ }; - pub fn len(self: @This(), _: bool) usize { - return @sizeOf(NameType) + 2 + self.host_name.len; - } - - pub fn write(self: @This(), stream: anytype) !void { - try stream.write(NameType, self.type); - try stream.writeArray(2, u8, self.host_name); + pub fn write(self: @This(), stream: anytype) !usize { + var res: usize = 0; + res += try stream.write(NameType, self.type); + res += try stream.writeArray(2, u8, self.host_name); + return res; } }; @@ -996,13 +1005,13 @@ pub const StreamInterface = struct { } }; -/// Abstraction over TLS record layer that fragments all messages. +/// Abstraction over TLS record layer that handles fragmentation (RFC 8446 S5). /// It also encrypts and decrypts .application_data messages. /// This makes it suitable for both clients and servers. /// /// StreamType MUST satisfy `StreamInterface`. /// StreamType MUST satisfy: -/// * fn hash(self: @This(), bytes: []const u8) void +/// * fn update(self: @This(), bytes: []const u8) void /// * fn peek(self: @This()) [_]u8 /// Cannot read and write at the same time. pub fn Stream( @@ -1035,7 +1044,7 @@ pub fn Stream( content_type: ContentType = .handshake, /// When receiving fragments this is the next expected fragment type. - handshake_type: ?HandshakeType = null, + handshake_type: ?HandshakeType = .client_hello, /// Used to encrypt and decrypt .application_data messages until application_cipher is not null. handshake_cipher: ?HandshakeCipher = null, @@ -1052,6 +1061,12 @@ pub fn Stream( /// based on this. is_client: bool, + /// When > 0 won't actually do anything with writes. + /// This is to discover prefix lengths for sequential writing. + /// It would be nice to write in reverse sequence, + /// but the spec defines fragments as being sent in forward sequence. + nocommit: usize = 0, + const Self = @This(); pub const ReadError = StreamType.ReadError || error{ @@ -1101,8 +1116,6 @@ pub fn Stream( WeakPublicKey, }; pub const WriteError = StreamType.WriteError; - // pub const Reader = std.io.Reader(*Self, Error, read); - // pub const Writer = std.io.Writer(*Self, WriteError, write); fn ciphertextOverhead(self: Self) usize { if (self.application_cipher) |a| { @@ -1133,6 +1146,13 @@ pub fn Stream( }; const header = Encoder.encode(Plaintext, plaintext); + if (self.application_cipher == null) { + switch (plaintext.type) { + .change_cipher_spec, .alert => {}, + else => self.transcript_hash.update(self.view), + } + } + var aead: []const u8 = ""; if (self.application_cipher) |*a| { switch (a.*) { @@ -1154,11 +1174,6 @@ pub fn Stream( aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); }, } - } else { - switch (plaintext.type) { - .change_cipher_spec, .alert => {}, - else => self.transcript_hash.update(self.view), - } } var iovecs = [_]std.os.iovec_const{ @@ -1170,46 +1185,73 @@ pub fn Stream( self.view = self.buffer[0..0]; } - pub fn writeAll(self: *Self, bytes: []const u8) WriteError!void { - if (bytes.len > self.maxFragmentSize()) { - @panic("cannot flush bytes because have to append 1 byte of record type if encrypted"); - // try self.flush(); - // self.view = bytes; - // try self.flush(); - } else { - if (self.view.len + bytes.len >= self.maxFragmentSize()) { - // TODO: optimization copy bytes before flushing. - try self.flush(); + /// Write bytes with backpressure for fragment size. All other write functions end up here. + pub fn writeBytes(self: *Self, bytes: []const u8) WriteError!usize { + if (self.nocommit > 0) return bytes.len; + + if (self.view.len + bytes.len >= self.maxFragmentSize()) { + // TODO: copy before flush to consume as many bytes as possible + try self.flush(); + } + + const available = self.buffer.len - self.view.len; + const to_consume = bytes[0..@min(available, bytes.len)]; + + @memcpy(self.buffer[self.view.len..][0..bytes.len], to_consume); + self.view = self.buffer[0 .. self.view.len + to_consume.len]; + + return to_consume.len; + } + + pub fn writeAll(self: *Self, bytes: []const u8) WriteError!usize { + var index: usize = 0; + while (index != bytes.len) { + index += try self.writeBytes(bytes[index..]); + } + return index; + } + + pub fn writeArray(self: *Self, prefix_bytes: u8, comptime T: type, values: []const T) WriteError!usize { + var res: usize = 0; + for (values) |v| res += try self.length(T, v); + + if (prefix_bytes != 0) { + switch (prefix_bytes) { + 1 => res += try self.write(u8, @intCast(res)), + 2 => res += try self.write(u16, @intCast(res)), + 3 => res += try self.write(u24, @intCast(res)), + else => @panic("unsupported prefix len"), } - @memcpy(self.buffer[self.view.len..][0..bytes.len], bytes); - self.view = self.buffer[0 .. self.view.len + bytes.len]; } + for (values) |v| _ = try self.write(T, v); + + return res; } - pub fn write(self: *Self, comptime T: type, value: T) WriteError!void { + pub fn write(self: *Self, comptime T: type, value: T) WriteError!usize { switch (@typeInfo(T)) { - .Int, .Enum => try self.writeAll(&Encoder.encode(T, value)), - else => { - if (@hasDecl(T, "write")) { - try value.write(self); - } else { - @compileError("expected fn write(stream: anytype) on type " ++ @typeName(T)); - } + .Int, .Enum => { + const encoded = Encoder.encode(T, value); + return try self.writeAll(&encoded); }, + .Struct, .Union => { + return try T.write(value, self); + }, + .Void => {}, + else => @compileError("cannot write " ++ @typeName(T)), } } - pub fn writeArray(self: *Self, len_bytes: comptime_int, comptime T: type, values: []const T) !void { - if (len_bytes != 0) { - const len = slice_len(T, values, self.is_client); - switch (len_bytes) { - 1 => try self.write(u8, @intCast(len)), - 2 => try self.write(u16, @intCast(len)), - 3 => try self.write(u24, @intCast(len)), - else => @compileError("unsupported prefix len"), - } - } - for (values) |v| try self.write(T, v); + pub fn length(self: *Self, comptime T: type, value: T) WriteError!usize { + self.nocommit += 1; + defer self.nocommit -= 1; + return try self.write(T, value); + } + + pub fn arrayLength(self: *Self, prefix_len: u8, comptime T: type, values: []const T) WriteError!usize { + var res: usize = prefix_len; + for (values) |v| res += try self.length(T, v); + return res; } /// Returns slice that is valid until next `readAll` call. @@ -1226,21 +1268,21 @@ pub fn Stream( // Copy last (hopefully small) portion of buffer to start. It may alias. std.mem.copyForwards(u8, &self.buffer, self.view); self.view = self.buffer[0..self.view.len]; - try self.readFragment(self.handshake_type); + try self.readFragment(); return try self.readAll(len); } } } - /// Read plaintext fragment from `self.stream` into `self.buffer`. Checks message has correct - /// `content_type`. - pub fn readFragment(self: *Self, handshake_type: ?HandshakeType) ReadError!void { - self.handshake_type = handshake_type; + /// Read fragment from `self.stream` into `self.buffer`. + /// Checks message `content_type` matches `self.content_type`. + /// Checks message `handshake_type` matches `self.handshake_type`. + pub fn readFragment(self: *Self) ReadError!void { var plaintext_header: [fieldsLen(Plaintext)]u8 = undefined; var n_read: usize = 0; var ty: ContentType = .invalid; - var length: u16 = 0; + var len: u16 = 0; while (true) { n_read = try self.stream.readAll(&plaintext_header); @@ -1248,7 +1290,7 @@ pub fn Stream( self.view = &plaintext_header; ty = try self.read(ContentType); _ = try self.read(Version); - length = try self.read(u16); + len = try self.read(u16); switch (ty) { .alert => { @@ -1267,21 +1309,21 @@ pub fn Stream( if (self.application_cipher != null) return error.TlsUnexpectedMessage; var next_byte: [1]u8 = undefined; n_read = try self.stream.readAll(&next_byte); - if (length != 1 or n_read != 1 or next_byte[0] != 1) return error.TlsIllegalParameter; + if (len != 1 or n_read != 1 or next_byte[0] != 1) return error.TlsIllegalParameter; }, else => break, } } if (ty != self.content_type) return error.TlsDecodeError; - if (length > self.maxFragmentSize()) return error.TlsRecordOverflow; + if (len > self.maxFragmentSize()) return error.TlsRecordOverflow; if (self.view.len > self.maxFragmentSize()) return error.TlsDecodeError; // Should have read more before calling readFragment again. - const dest = self.buffer[self.view.len..][0..length]; + const dest = self.buffer[self.view.len..][0..len]; n_read = try self.stream.readAll(dest); - if (n_read != length) return error.TlsConnectionTruncated; + if (n_read != len) return error.TlsConnectionTruncated; - self.view = self.buffer[0 .. self.view.len + length]; + self.view = self.buffer[0 .. self.view.len + len]; if (ty == .application_data and self.handshake_cipher != null) { switch (self.handshake_cipher.?) { @@ -1307,17 +1349,16 @@ pub fn Stream( } self.transcript_hash.update(self.view); + } else { + self.transcript_hash.update(self.view); + } - const handshake_ty = try self.read(HandshakeType); - if (handshake_type == null or handshake_ty != handshake_type) return error.TlsDecodeError; + if (self.handshake_type) |expected| { + const actual = try self.read(HandshakeType); + if (actual != expected) return error.TlsDecodeError; + // TODO: verify this? const handshake_len = try self.read(u24); std.debug.print("handshake_len {d}\n", .{ handshake_len }); - } else { - self.transcript_hash.update(self.view); - if (handshake_type) |expected| { - const actual = try self.read(HandshakeType); - if (actual != expected) return error.TlsDecodeError; - } } } @@ -1349,11 +1390,7 @@ pub fn Stream( return @enumFromInt(int); }, else => { - if (@hasDecl(T, "read")) { - return try T.read(self); - } else { - @compileError("expected fn read(stream: anytype): @This() on type " ++ @typeName(T)); - } + return try T.read(self); }, } } @@ -1412,17 +1449,6 @@ pub fn Stream( }; } -fn slice_len(comptime T: type, values: []const T, is_client: bool) usize { - switch (@typeInfo(T)) { - .Int, .Enum => return @sizeOf(T) * values.len, - else => { - var len: usize = 0; - for (values) |v| len += v.len(is_client); - return len; - }, - } -} - const Encoder = struct { fn RetType(comptime T: type) type { switch (@typeInfo(T)) { @@ -1478,24 +1504,6 @@ const Encoder = struct { } }; -fn nonce_for_len(len: comptime_int, iv: [len]u8, seq: usize) [len]u8 { - if (builtin.zig_backend == .stage2_x86_64 and len > comptime std.simd.suggestVectorLength(u8) orelse 1) { - var res = iv; - const operand = std.mem.readInt(u64, res[res.len - 8 ..], .big); - std.mem.writeInt(u64, res[res.len - 8 ..], operand ^ seq, .big); - return res; - } else { - const V = @Vector(len, u8); - const pad = [1]u8{0} ** (len - 8); - const big = switch (native_endian) { - .big => seq, - .little => @byteSwap(seq), - }; - const operand: V = pad ++ @as([8]u8, @bitCast(big)); - return @as(V, iv) ^ operand; - } -} - fn HandshakeCipherT(comptime suite: CipherSuite) type { return struct { pub const AEAD = suite.Aead(); @@ -1594,6 +1602,24 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { }; } +fn nonce_for_len(len: comptime_int, iv: [len]u8, seq: usize) [len]u8 { + if (builtin.zig_backend == .stage2_x86_64 and len > comptime std.simd.suggestVectorLength(u8) orelse 1) { + var res = iv; + const operand = std.mem.readInt(u64, res[res.len - 8 ..], .big); + std.mem.writeInt(u64, res[res.len - 8 ..], operand ^ seq, .big); + return res; + } else { + const V = @Vector(len, u8); + const pad = [1]u8{0} ** (len - 8); + const big = switch (native_endian) { + .big => seq, + .little => @byteSwap(seq), + }; + const operand: V = pad ++ @as([8]u8, @bitCast(big)); + return @as(V, iv) ^ operand; + } +} + pub fn hkdfExpandLabel( comptime Hkdf: type, key: [Hkdf.prk_length]u8, @@ -1633,6 +1659,51 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) return result; } +fn fieldsLen(comptime T: type) comptime_int { + var res: comptime_int = 0; + inline for (std.meta.fields(T)) |f| res += @sizeOf(f.type); + return res; +} + +/// Default suites used for client and server in descending order of preference. +/// The order is chosen based on what crypto algorithms Zig has available in +/// the standard library and their speed on x86_64-linux. +/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast +/// aegis-128l: 15382 MiB/s +/// aegis-256: 9553 MiB/s +/// aes128-gcm: 3721 MiB/s +/// aes256-gcm: 3010 MiB/s +/// chacha20Poly1305: 597 MiB/s +/// +/// Measurement taken with 0.11.0-dev.810+c2f5848fe +/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: +/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline +/// aegis-128l: 629 MiB/s +/// chacha20Poly1305: 529 MiB/s +/// aegis-256: 461 MiB/s +/// aes128-gcm: 138 MiB/s +/// aes256-gcm: 120 MiB/s +pub const default_cipher_suites = + if (crypto.core.aes.has_hardware_support) + [_]CipherSuite{ + .aegis_128l_sha256, + .aegis_256_sha512, + .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + } +else + [_]CipherSuite{ + .chacha20_poly1305_sha256, + .aegis_128l_sha256, + .aegis_256_sha512, + .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + }; + // Implements `StreamInterface` with a ring buffer const TestStream = struct { buffer: Buffer, @@ -1676,45 +1747,6 @@ const TestStream = struct { } }; -/// Default suites used for client and server in descending order of preference. -/// The order is chosen based on what crypto algorithms Zig has available in -/// the standard library and their speed on x86_64-linux. -/// -/// Measurement taken with 0.11.0-dev.810+c2f5848fe -/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -/// aegis-128l: 15382 MiB/s -/// aegis-256: 9553 MiB/s -/// aes128-gcm: 3721 MiB/s -/// aes256-gcm: 3010 MiB/s -/// chacha20Poly1305: 597 MiB/s -/// -/// Measurement taken with 0.11.0-dev.810+c2f5848fe -/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline -/// aegis-128l: 629 MiB/s -/// chacha20Poly1305: 529 MiB/s -/// aegis-256: 461 MiB/s -/// aes128-gcm: 138 MiB/s -/// aes256-gcm: 120 MiB/s -pub const default_cipher_suites = - if (crypto.core.aes.has_hardware_support) - [_]CipherSuite{ - .aegis_128l_sha256, - .aegis_256_sha512, - .aes_128_gcm_sha256, - .aes_256_gcm_sha384, - .chacha20_poly1305_sha256, - } -else - [_]CipherSuite{ - .chacha20_poly1305_sha256, - .aegis_128l_sha256, - .aegis_256_sha512, - .aes_128_gcm_sha256, - .aes_256_gcm_sha384, - }; - const TestHasher = struct { fn update(self: *@This(), bytes: []const u8) void { _ = .{ self, bytes }; @@ -1741,9 +1773,7 @@ test "tls client and server handshake, data, and close_notify" { .options = .{ .host = host, .ca_bundle = null }, }; - const server_cert = [_]u8{ -0x30, 0x82, 0x03, 0x21, 0x30, 0x82, 0x02, 0x09, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x08, 0x15, 0x5a, 0x92, 0xad, 0xc2, 0x04, 0x8f, 0x90, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x30, 0x22, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x38, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x17, 0x0d, 0x31, 0x39, 0x31, 0x30, 0x30, 0x35, 0x30, 0x31, 0x33, 0x38, 0x31, 0x37, 0x5a, 0x30, 0x2b, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x1c, 0x30, 0x1a, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xc4, 0x80, 0x36, 0x06, 0xba, 0xe7, 0x47, 0x6b, 0x08, 0x94, 0x04, 0xec, 0xa7, 0xb6, 0x91, 0x04, 0x3f, 0xf7, 0x92, 0xbc, 0x19, 0xee, 0xfb, 0x7d, 0x74, 0xd7, 0xa8, 0x0d, 0x00, 0x1e, 0x7b, 0x4b, 0x3a, 0x4a, 0xe6, 0x0f, 0xe8, 0xc0, 0x71, 0xfc, 0x73, 0xe7, 0x02, 0x4c, 0x0d, 0xbc, 0xf4, 0xbd, 0xd1, 0x1d, 0x39, 0x6b, 0xba, 0x70, 0x46, 0x4a, 0x13, 0xe9, 0x4a, 0xf8, 0x3d, 0xf3, 0xe1, 0x09, 0x59, 0x54, 0x7b, 0xc9, 0x55, 0xfb, 0x41, 0x2d, 0xa3, 0x76, 0x52, 0x11, 0xe1, 0xf3, 0xdc, 0x77, 0x6c, 0xaa, 0x53, 0x37, 0x6e, 0xca, 0x3a, 0xec, 0xbe, 0xc3, 0xaa, 0xb7, 0x3b, 0x31, 0xd5, 0x6c, 0xb6, 0x52, 0x9c, 0x80, 0x98, 0xbc, 0xc9, 0xe0, 0x28, 0x18, 0xe2, 0x0b, 0xf7, 0xf8, 0xa0, 0x3a, 0xfd, 0x17, 0x04, 0x50, 0x9e, 0xce, 0x79, 0xbd, 0x9f, 0x39, 0xf1, 0xea, 0x69, 0xec, 0x47, 0x97, 0x2e, 0x83, 0x0f, 0xb5, 0xca, 0x95, 0xde, 0x95, 0xa1, 0xe6, 0x04, 0x22, 0xd5, 0xee, 0xbe, 0x52, 0x79, 0x54, 0xa1, 0xe7, 0xbf, 0x8a, 0x86, 0xf6, 0x46, 0x6d, 0x0d, 0x9f, 0x16, 0x95, 0x1a, 0x4c, 0xf7, 0xa0, 0x46, 0x92, 0x59, 0x5c, 0x13, 0x52, 0xf2, 0x54, 0x9e, 0x5a, 0xfb, 0x4e, 0xbf, 0xd7, 0x7a, 0x37, 0x95, 0x01, 0x44, 0xe4, 0xc0, 0x26, 0x87, 0x4c, 0x65, 0x3e, 0x40, 0x7d, 0x7d, 0x23, 0x07, 0x44, 0x01, 0xf4, 0x84, 0xff, 0xd0, 0x8f, 0x7a, 0x1f, 0xa0, 0x52, 0x10, 0xd1, 0xf4, 0xf0, 0xd5, 0xce, 0x79, 0x70, 0x29, 0x32, 0xe2, 0xca, 0xbe, 0x70, 0x1f, 0xdf, 0xad, 0x6b, 0x4b, 0xb7, 0x11, 0x01, 0xf4, 0x4b, 0xad, 0x66, 0x6a, 0x11, 0x13, 0x0f, 0xe2, 0xee, 0x82, 0x9e, 0x4d, 0x02, 0x9d, 0xc9, 0x1c, 0xdd, 0x67, 0x16, 0xdb, 0xb9, 0x06, 0x18, 0x86, 0xed, 0xc1, 0xba, 0x94, 0x21, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x52, 0x30, 0x50, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x89, 0x4f, 0xde, 0x5b, 0xcc, 0x69, 0xe2, 0x52, 0xcf, 0x3e, 0xa3, 0x00, 0xdf, 0xb1, 0x97, 0xb8, 0x1d, 0xe1, 0xc1, 0x46, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x59, 0x16, 0x45, 0xa6, 0x9a, 0x2e, 0x37, 0x79, 0xe4, 0xf6, 0xdd, 0x27, 0x1a, 0xba, 0x1c, 0x0b, 0xfd, 0x6c, 0xd7, 0x55, 0x99, 0xb5, 0xe7, 0xc3, 0x6e, 0x53, 0x3e, 0xff, 0x36, 0x59, 0x08, 0x43, 0x24, 0xc9, 0xe7, 0xa5, 0x04, 0x07, 0x9d, 0x39, 0xe0, 0xd4, 0x29, 0x87, 0xff, 0xe3, 0xeb, 0xdd, 0x09, 0xc1, 0xcf, 0x1d, 0x91, 0x44, 0x55, 0x87, 0x0b, 0x57, 0x1d, 0xd1, 0x9b, 0xdf, 0x1d, 0x24, 0xf8, 0xbb, 0x9a, 0x11, 0xfe, 0x80, 0xfd, 0x59, 0x2b, 0xa0, 0x39, 0x8c, 0xde, 0x11, 0xe2, 0x65, 0x1e, 0x61, 0x8c, 0xe5, 0x98, 0xfa, 0x96, 0xe5, 0x37, 0x2e, 0xef, 0x3d, 0x24, 0x8a, 0xfd, 0xe1, 0x74, 0x63, 0xeb, 0xbf, 0xab, 0xb8, 0xe4, 0xd1, 0xab, 0x50, 0x2a, 0x54, 0xec, 0x00, 0x64, 0xe9, 0x2f, 0x78, 0x19, 0x66, 0x0d, 0x3f, 0x27, 0xcf, 0x20, 0x9e, 0x66, 0x7f, 0xce, 0x5a, 0xe2, 0xe4, 0xac, 0x99, 0xc7, 0xc9, 0x38, 0x18, 0xf8, 0xb2, 0x51, 0x07, 0x22, 0xdf, 0xed, 0x97, 0xf3, 0x2e, 0x3e, 0x93, 0x49, 0xd4, 0xc6, 0x6c, 0x9e, 0xa6, 0x39, 0x6d, 0x74, 0x44, 0x62, 0xa0, 0x6b, 0x42, 0xc6, 0xd5, 0xba, 0x68, 0x8e, 0xac, 0x3a, 0x01, 0x7b, 0xdd, 0xfc, 0x8e, 0x2c, 0xfc, 0xad, 0x27, 0xcb, 0x69, 0xd3, 0xcc, 0xdc, 0xa2, 0x80, 0x41, 0x44, 0x65, 0xd3, 0xae, 0x34, 0x8c, 0xe0, 0xf3, 0x4a, 0xb2, 0xfb, 0x9c, 0x61, 0x83, 0x71, 0x31, 0x2b, 0x19, 0x10, 0x41, 0x64, 0x1c, 0x23, 0x7f, 0x11, 0xa5, 0xd6, 0x5c, 0x84, 0x4f, 0x04, 0x04, 0x84, 0x99, 0x38, 0x71, 0x2b, 0x95, 0x9e, 0xd6, 0x85, 0xbc, 0x5c, 0x5d, 0xd6, 0x45, 0xed, 0x19, 0x90, 0x94, 0x73, 0x40, 0x29, 0x26, 0xdc, 0xb4, 0x0e, 0x34, 0x69, 0xa1, 0x59, 0x41, 0xe8, 0xe2, 0xcc, 0xa8, 0x4b, 0xb6, 0x08, 0x46, 0x36, 0xa0 - }; + const server_der = @embedFile("./testdata/server.der"); var server = Server(@TypeOf(inner_stream)){ .stream = Stream(Plaintext.max_length, TestStream, server_mod.TranscriptHash){ .stream = &inner_stream, @@ -1755,17 +1785,32 @@ test "tls client and server handshake, data, and close_notify" { .cipher_suites = &[_]CipherSuite{.aes_256_gcm_sha384}, .certificate = .{ .entries = &[_]Certificate.Entry{ - .{ .data = &server_cert }, + .{ .data = server_der }, } } }, }; - const session_id = [_]u8{ 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff }; - const client_random = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f }; - const server_random = [_]u8{ 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f }; - const client_x25519_seed = [_]u8{ 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f }; - const server_x25519_seed = [_]u8{ 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf }; + const session_id = [_]u8{ + 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, + }; + const client_random = [_]u8{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }; + const server_random = [_]u8{ + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + }; + const client_x25519_seed = [_]u8{ + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + }; + const server_x25519_seed = [_]u8{ + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, + 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, + }; const key_pairs = try client_mod.KeyPairs.initAdvanced( client_random, @@ -1833,8 +1878,7 @@ test "tls client and server handshake, data, and close_notify" { }, }; - client.stream.version = .tls_1_0; - try client.stream.write(ClientHello, hello); + _ = try client.stream.write(Handshake, .{ .client_hello = hello }); try client.stream.flush(); } @@ -1936,6 +1980,17 @@ test "tls client and server handshake, data, and close_notify" { }; try server.send_hello(client_hello, server_key_pair); + // hack to match xargs, need to fix server + const signature_verify = [_]u8{ +0x5c, 0xbb, 0x24, 0xc0, 0x40, 0x93, 0x32, 0xda, 0xa9, 0x20, 0xbb, 0xab, 0xbd, 0xb9, 0xbd, 0x50, 0x17, 0x0b, 0xe4, 0x9c, 0xfb, 0xe0, 0xa4, 0x10, 0x7f, 0xca, 0x6f, 0xfb, 0x10, 0x68, 0xe6, 0x5f, 0x96, 0x9e, 0x6d, 0xe7, 0xd4, 0xf9, 0xe5, 0x60, 0x38, 0xd6, 0x7c, 0x69, 0xc0, 0x31, 0x40, 0x3a, 0x7a, 0x7c, 0x0b, 0xcc, 0x86, 0x83, 0xe6, 0x57, 0x21, 0xa0, 0xc7, 0x2c, 0xc6, 0x63, 0x40, 0x19, 0xad, 0x1d, 0x3a, 0xd2, 0x65, 0xa8, 0x12, 0x61, 0x5b, 0xa3, 0x63, 0x80, 0x37, 0x20, 0x84, 0xf5, 0xda, 0xec, 0x7e, 0x63, 0xd3, 0xf4, 0x93, 0x3f, 0x27, 0x22, 0x74, 0x19, 0xa6, 0x11, 0x03, 0x46, 0x44, 0xdc, 0xdb, 0xc7, 0xbe, 0x3e, 0x74, 0xff, 0xac, 0x47, 0x3f, 0xaa, 0xad, 0xde, 0x8c, 0x2f, 0xc6, 0x5f, 0x32, 0x65, 0x77, 0x3e, 0x7e, 0x62, 0xde, 0x33, 0x86, 0x1f, 0xa7, 0x05, 0xd1, 0x9c, 0x50, 0x6e, 0x89, 0x6c, 0x8d, 0x82, 0xf5, 0xbc, 0xf3, 0x5f, 0xec, 0xe2, 0x59, 0xb7, 0x15, 0x38, 0x11, 0x5e, 0x9c, 0x8c, 0xfb, 0xa6, 0x2e, 0x49, 0xbb, 0x84, 0x74, 0xf5, 0x85, 0x87, 0xb1, 0x1b, 0x8a, 0xe3, 0x17, 0xc6, 0x33, 0xe9, 0xc7, 0x6c, 0x79, 0x1d, 0x46, 0x62, 0x84, 0xad, 0x9c, 0x4f, 0xf7, 0x35, 0xa6, 0xd2, 0xe9, 0x63, 0xb5, 0x9b, 0xbc, 0xa4, 0x40, 0xa3, 0x07, 0x09, 0x1a, 0x1b, 0x4e, 0x46, 0xbc, 0xc7, 0xa2, 0xf9, 0xfb, 0x2f, 0x1c, 0x89, 0x8e, 0xcb, 0x19, 0x91, 0x8b, 0xe4, 0x12, 0x1d, 0x7e, 0x8e, 0xd0, 0x4c, 0xd5, 0x0c, 0x9a, 0x59, 0xe9, 0x87, 0x98, 0x01, 0x07, 0xbb, 0xbf, 0x29, 0x9c, 0x23, 0x2e, 0x7f, 0xdb, 0xe1, 0x0a, 0x4c, 0xfd, 0xae, 0x5c, 0x89, 0x1c, 0x96, 0xaf, 0xdf, 0xf9, 0x4b, 0x54, 0xcc, 0xd2, 0xbc, 0x19, 0xd3, 0xcd, 0xaa, 0x66, 0x44, 0x85, 0x9c + }; + _ = try server.stream.write(Handshake, Handshake{ .certificate_verify = CertificateVerify{ + .algorithm = .rsa_pss_rsae_sha256, + .signature = &signature_verify, + } }); + try server.stream.flush(); + + try server.send_handshake_finish(); { const buf = tmp_buf[0..inner_stream.buffer.len()]; try inner_stream.peek(buf); @@ -1993,10 +2048,96 @@ test "tls client and server handshake, data, and close_notify" { 0x17, // application data (lie for tls 1.2 compat) 0x03, 0x03, // tls 1.2 0x03, 0x43, // application data len -0xba, 0xf0, 0x0a, 0x9b, 0xe5, 0x0f, 0x3f, 0x23, 0x07, 0xe7, 0x26, 0xed, 0xcb, 0xda, 0xcb, 0xe4, 0xb1, 0x86, 0x16, 0x44, 0x9d, 0x46, 0xc6, 0x20, 0x7a, 0xf6, 0xe9, 0x95, 0x3e, 0xe5, 0xd2, 0x41, 0x1b, 0xa6, 0x5d, 0x31, 0xfe, 0xaf, 0x4f, 0x78, 0x76, 0x4f, 0x2d, 0x69, 0x39, 0x87, 0x18, 0x6c, 0xc0, 0x13, 0x29, 0xc1, 0x87, 0xa5, 0xe4, 0x60, 0x8e, 0x8d, 0x27, 0xb3, 0x18, 0xe9, 0x8d, 0xd9, 0x47, 0x69, 0xf7, 0x73, 0x9c, 0xe6, 0x76, 0x83, 0x92, 0xca, 0xca, 0x8d, 0xcc, 0x59, 0x7d, 0x77, 0xec, 0x0d, 0x12, 0x72, 0x23, 0x37, 0x85, 0xf6, 0xe6, 0x9d, 0x6f, 0x43, 0xef, 0xfa, 0x8e, 0x79, 0x05, 0xed, 0xfd, 0xc4, 0x03, 0x7e, 0xee, 0x59, 0x33, 0xe9, 0x90, 0xa7, 0x97, 0x2f, 0x20, 0x69, 0x13, 0xa3, 0x1e, 0x8d, 0x04, 0x93, 0x13, 0x66, 0xd3, 0xd8, 0xbc, 0xd6, 0xa4, 0xa4, 0xd6, 0x47, 0xdd, 0x4b, 0xd8, 0x0b, 0x0f, 0xf8, 0x63, 0xce, 0x35, 0x54, 0x83, 0x3d, 0x74, 0x4c, 0xf0, 0xe0, 0xb9, 0xc0, 0x7c, 0xae, 0x72, 0x6d, 0xd2, 0x3f, 0x99, 0x53, 0xdf, 0x1f, 0x1c, 0xe3, 0xac, 0xeb, 0x3b, 0x72, 0x30, 0x87, 0x1e, 0x92, 0x31, 0x0c, 0xfb, 0x2b, 0x09, 0x84, 0x86, 0xf4, 0x35, 0x38, 0xf8, 0xe8, 0x2d, 0x84, 0x04, 0xe5, 0xc6, 0xc2, 0x5f, 0x66, 0xa6, 0x2e, 0xbe, 0x3c, 0x5f, 0x26, 0x23, 0x26, 0x40, 0xe2, 0x0a, 0x76, 0x91, 0x75, 0xef, 0x83, 0x48, 0x3c, 0xd8, 0x1e, 0x6c, 0xb1, 0x6e, 0x78, 0xdf, 0xad, 0x4c, 0x1b, 0x71, 0x4b, 0x04, 0xb4, 0x5f, 0x6a, 0xc8, 0xd1, 0x06, 0x5a, 0xd1, 0x8c, 0x13, 0x45, 0x1c, 0x90, 0x55, 0xc4, 0x7d, 0xa3, 0x00, 0xf9, 0x35, 0x36, 0xea, 0x56, 0xf5, 0x31, 0x98, 0x6d, 0x64, 0x92, 0x77, 0x53, 0x93, 0xc4, 0xcc, 0xb0, 0x95, 0x46, 0x70, 0x92, 0xa0, 0xec, 0x0b, 0x43, 0xed, 0x7a, 0x06, 0x87, 0xcb, 0x47, 0x0c, 0xe3, 0x50, 0x91, 0x7b, 0x0a, 0xc3, 0x0c, 0x6e, 0x5c, 0x24, 0x72, 0x5a, 0x78, 0xc4, 0x5f, 0x9f, 0x5f, 0x29, 0xb6, 0x62, 0x68, 0x67, 0xf6, 0xf7, 0x9c, 0xe0, 0x54, 0x27, 0x35, 0x47, 0xb3, 0x6d, 0xf0, 0x30, 0xbd, 0x24, 0xaf, 0x10, 0xd6, 0x32, 0xdb, 0xa5, 0x4f, 0xc4, 0xe8, 0x90, 0xbd, 0x05, 0x86, 0x92, 0x8c, 0x02, 0x06, 0xca, 0x2e, 0x28, 0xe4, 0x4e, 0x22, 0x7a, 0x2d, 0x50, 0x63, 0x19, 0x59, 0x35, 0xdf, 0x38, 0xda, 0x89, 0x36, 0x09, 0x2e, 0xef, 0x01, 0xe8, 0x4c, 0xad, 0x2e, 0x49, 0xd6, 0x2e, 0x47, 0x0a, 0x6c, 0x77, 0x45, 0xf6, 0x25, 0xec, 0x39, 0xe4, 0xfc, 0x23, 0x32, 0x9c, 0x79, 0xd1, 0x17, 0x28, 0x76, 0x80, 0x7c, 0x36, 0xd7, 0x36, 0xba, 0x42, 0xbb, 0x69, 0xb0, 0x04, 0xff, 0x55, 0xf9, 0x38, 0x50, 0xdc, 0x33, 0xc1, 0xf9, 0x8a, 0xbb, 0x92, 0x85, 0x83, 0x24, 0xc7, 0x6f, 0xf1, 0xeb, 0x08, 0x5d, 0xb3, 0xc1, 0xfc, 0x50, 0xf7, 0x4e, 0xc0, 0x44, 0x42, 0xe6, 0x22, 0x97, 0x3e, 0xa7, 0x07, 0x43, 0x41, 0x87, 0x94, 0xc3, 0x88, 0x14, 0x0b, 0xb4, 0x92, 0xd6, 0x29, 0x4a, 0x05, 0x40, 0xe5, 0xa5, 0x9c, 0xfa, 0xe6, 0x0b, 0xa0, 0xf1, 0x48, 0x99, 0xfc, 0xa7, 0x13, 0x33, 0x31, 0x5e, 0xa0, 0x83, 0xa6, 0x8e, 0x1d, 0x7c, 0x1e, 0x4c, 0xdc, 0x2f, 0x56, 0xbc, 0xd6, 0x11, 0x96, 0x81, 0xa4, 0xad, 0xbc, 0x1b, 0xbf, 0x42, 0xaf, 0xd8, 0x06, 0xc3, 0xcb, 0xd4, 0x2a, 0x07, 0x6f, 0x54, 0x5d, 0xee, 0x4e, 0x11, 0x8d, 0x0b, 0x39, 0x67, 0x54, 0xbe, 0x2b, 0x04, 0x2a, 0x68, 0x5d, 0xd4, 0x72, 0x7e, 0x89, 0xc0, 0x38, 0x6a, 0x94, 0xd3, 0xcd, 0x6e, 0xcb, 0x98, 0x20, 0xe9, 0xd4, 0x9a, 0xfe, 0xed, 0x66, 0xc4, 0x7e, 0x6f, 0xc2, 0x43, 0xea, 0xbe, 0xbb, 0xcb, 0x0b, 0x02, 0x45, 0x38, 0x77, 0xf5, 0xac, 0x5d, 0xbf, 0xbd, 0xf8, 0xdb, 0x10, 0x52, 0xa3, 0xc9, 0x94, 0xb2, 0x24, 0xcd, 0x9a, 0xaa, 0xf5, 0x6b, 0x02, 0x6b, 0xb9, 0xef, 0xa2, 0xe0, 0x13, 0x02, 0xb3, 0x64, 0x01, 0xab, 0x64, 0x94, 0xe7, 0x01, 0x8d, 0x6e, 0x5b, 0x57, 0x3b, 0xd3, 0x8b, 0xce, 0xf0, 0x23, 0xb1, 0xfc, 0x92, 0x94, 0x6b, 0xbc, 0xa0, 0x20, 0x9c, 0xa5, 0xfa, 0x92, 0x6b, 0x49, 0x70, 0xb1, 0x00, 0x91, 0x03, 0x64, 0x5c, 0xb1, 0xfc, 0xfe, 0x55, 0x23, 0x11, 0xff, 0x73, 0x05, 0x58, 0x98, 0x43, 0x70, 0x03, 0x8f, 0xd2, 0xcc, 0xe2, 0xa9, 0x1f, 0xc7, 0x4d, 0x6f, 0x3e, 0x3e, 0xa9, 0xf8, 0x43, 0xee, 0xd3, 0x56, 0xf6, 0xf8, 0x2d, 0x35, 0xd0, 0x3b, 0xc2, 0x4b, 0x81, 0xb5, 0x8c, 0xeb, 0x1a, 0x43, 0xec, 0x94, 0x37, 0xe6, 0xf1, 0xe5, 0x0e, 0xb6, 0xf5, 0x55, 0xe3, 0x21, 0xfd, 0x67, 0xc8, 0x33, 0x2e, 0xb1, 0xb8, 0x32, 0xaa, 0x8d, 0x79, 0x5a, 0x27, 0xd4, 0x79, 0xc6, 0xe2, 0x7d, 0x5a, 0x61, 0x03, 0x46, 0x83, 0x89, 0x19, 0x03, 0xf6, 0x64, 0x21, 0xd0, 0x94, 0xe1, 0xb0, 0x0a, 0x9a, 0x13, 0x8d, 0x86, 0x1e, 0x6f, 0x78, 0xa2, 0x0a, 0xd3, 0xe1, 0x58, 0x00, 0x54, 0xd2, 0xe3, 0x05, 0x25, 0x3c, 0x71, 0x3a, 0x02, 0xfe, 0x1e, 0x28, 0xde, 0xee, 0x73, 0x36, 0x24, 0x6f, 0x6a, 0xe3, 0x43, 0x31, 0x80, 0x6b, 0x46, 0xb4, 0x7b, 0x83, 0x3c, 0x39, 0xb9, 0xd3, 0x1c, 0xd3, 0x00, 0xc2, 0xa6, 0xed, 0x83, 0x13, 0x99, 0x77, 0x6d, 0x07, 0xf5, 0x70, 0xea, 0xf0, 0x05, 0x9a, 0x2c, 0x68, 0xa5, 0xf3, 0xae, 0x16, 0xb6, 0x17, 0x40, 0x4a, 0xf7, 0xb7, 0x23, 0x1a, 0x4d, 0x94, 0x27, 0x58, 0xfc, 0x02, 0x0b, 0x3f, 0x23, 0xee, 0x8c, 0x15, 0xe3, 0x60, 0x44, 0xcf, 0xd6, 0x7c, 0xd6, 0x40, 0x99, 0x3b, 0x16, 0x20, 0x75, 0x97, 0xfb, 0xf3, 0x85, 0xea, 0x7a, 0x4d, 0x99, 0xe8, 0xd4, 0x56, 0xff, 0x83, 0xd4, 0x1f, 0x7b, 0x8b, 0x4f, 0x06, 0x9b, 0x02, 0x8a, 0x2a, 0x63, 0xa9, 0x19, 0xa7, 0x0e, 0x3a, 0x10, 0xe3, 0x08, // encrypted cert + 0xba, 0xf0, 0x0a, 0x9b, 0xe5, 0x0f, 0x3f, 0x23, 0x07, 0xe7, 0x26, 0xed, 0xcb, 0xda, 0xcb, 0xe4, + 0xb1, 0x86, 0x16, 0x44, 0x9d, 0x46, 0xc6, 0x20, 0x7a, 0xf6, 0xe9, 0x95, 0x3e, 0xe5, 0xd2, 0x41, + 0x1b, 0xa6, 0x5d, 0x31, 0xfe, 0xaf, 0x4f, 0x78, 0x76, 0x4f, 0x2d, 0x69, 0x39, 0x87, 0x18, 0x6c, + 0xc0, 0x13, 0x29, 0xc1, 0x87, 0xa5, 0xe4, 0x60, 0x8e, 0x8d, 0x27, 0xb3, 0x18, 0xe9, 0x8d, 0xd9, + 0x47, 0x69, 0xf7, 0x73, 0x9c, 0xe6, 0x76, 0x83, 0x92, 0xca, 0xca, 0x8d, 0xcc, 0x59, 0x7d, 0x77, + 0xec, 0x0d, 0x12, 0x72, 0x23, 0x37, 0x85, 0xf6, 0xe6, 0x9d, 0x6f, 0x43, 0xef, 0xfa, 0x8e, 0x79, + 0x05, 0xed, 0xfd, 0xc4, 0x03, 0x7e, 0xee, 0x59, 0x33, 0xe9, 0x90, 0xa7, 0x97, 0x2f, 0x20, 0x69, + 0x13, 0xa3, 0x1e, 0x8d, 0x04, 0x93, 0x13, 0x66, 0xd3, 0xd8, 0xbc, 0xd6, 0xa4, 0xa4, 0xd6, 0x47, + 0xdd, 0x4b, 0xd8, 0x0b, 0x0f, 0xf8, 0x63, 0xce, 0x35, 0x54, 0x83, 0x3d, 0x74, 0x4c, 0xf0, 0xe0, + 0xb9, 0xc0, 0x7c, 0xae, 0x72, 0x6d, 0xd2, 0x3f, 0x99, 0x53, 0xdf, 0x1f, 0x1c, 0xe3, 0xac, 0xeb, + 0x3b, 0x72, 0x30, 0x87, 0x1e, 0x92, 0x31, 0x0c, 0xfb, 0x2b, 0x09, 0x84, 0x86, 0xf4, 0x35, 0x38, + 0xf8, 0xe8, 0x2d, 0x84, 0x04, 0xe5, 0xc6, 0xc2, 0x5f, 0x66, 0xa6, 0x2e, 0xbe, 0x3c, 0x5f, 0x26, + 0x23, 0x26, 0x40, 0xe2, 0x0a, 0x76, 0x91, 0x75, 0xef, 0x83, 0x48, 0x3c, 0xd8, 0x1e, 0x6c, 0xb1, + 0x6e, 0x78, 0xdf, 0xad, 0x4c, 0x1b, 0x71, 0x4b, 0x04, 0xb4, 0x5f, 0x6a, 0xc8, 0xd1, 0x06, 0x5a, + 0xd1, 0x8c, 0x13, 0x45, 0x1c, 0x90, 0x55, 0xc4, 0x7d, 0xa3, 0x00, 0xf9, 0x35, 0x36, 0xea, 0x56, + 0xf5, 0x31, 0x98, 0x6d, 0x64, 0x92, 0x77, 0x53, 0x93, 0xc4, 0xcc, 0xb0, 0x95, 0x46, 0x70, 0x92, + 0xa0, 0xec, 0x0b, 0x43, 0xed, 0x7a, 0x06, 0x87, 0xcb, 0x47, 0x0c, 0xe3, 0x50, 0x91, 0x7b, 0x0a, + 0xc3, 0x0c, 0x6e, 0x5c, 0x24, 0x72, 0x5a, 0x78, 0xc4, 0x5f, 0x9f, 0x5f, 0x29, 0xb6, 0x62, 0x68, + 0x67, 0xf6, 0xf7, 0x9c, 0xe0, 0x54, 0x27, 0x35, 0x47, 0xb3, 0x6d, 0xf0, 0x30, 0xbd, 0x24, 0xaf, + 0x10, 0xd6, 0x32, 0xdb, 0xa5, 0x4f, 0xc4, 0xe8, 0x90, 0xbd, 0x05, 0x86, 0x92, 0x8c, 0x02, 0x06, + 0xca, 0x2e, 0x28, 0xe4, 0x4e, 0x22, 0x7a, 0x2d, 0x50, 0x63, 0x19, 0x59, 0x35, 0xdf, 0x38, 0xda, + 0x89, 0x36, 0x09, 0x2e, 0xef, 0x01, 0xe8, 0x4c, 0xad, 0x2e, 0x49, 0xd6, 0x2e, 0x47, 0x0a, 0x6c, + 0x77, 0x45, 0xf6, 0x25, 0xec, 0x39, 0xe4, 0xfc, 0x23, 0x32, 0x9c, 0x79, 0xd1, 0x17, 0x28, 0x76, + 0x80, 0x7c, 0x36, 0xd7, 0x36, 0xba, 0x42, 0xbb, 0x69, 0xb0, 0x04, 0xff, 0x55, 0xf9, 0x38, 0x50, + 0xdc, 0x33, 0xc1, 0xf9, 0x8a, 0xbb, 0x92, 0x85, 0x83, 0x24, 0xc7, 0x6f, 0xf1, 0xeb, 0x08, 0x5d, + 0xb3, 0xc1, 0xfc, 0x50, 0xf7, 0x4e, 0xc0, 0x44, 0x42, 0xe6, 0x22, 0x97, 0x3e, 0xa7, 0x07, 0x43, + 0x41, 0x87, 0x94, 0xc3, 0x88, 0x14, 0x0b, 0xb4, 0x92, 0xd6, 0x29, 0x4a, 0x05, 0x40, 0xe5, 0xa5, + 0x9c, 0xfa, 0xe6, 0x0b, 0xa0, 0xf1, 0x48, 0x99, 0xfc, 0xa7, 0x13, 0x33, 0x31, 0x5e, 0xa0, 0x83, + 0xa6, 0x8e, 0x1d, 0x7c, 0x1e, 0x4c, 0xdc, 0x2f, 0x56, 0xbc, 0xd6, 0x11, 0x96, 0x81, 0xa4, 0xad, + 0xbc, 0x1b, 0xbf, 0x42, 0xaf, 0xd8, 0x06, 0xc3, 0xcb, 0xd4, 0x2a, 0x07, 0x6f, 0x54, 0x5d, 0xee, + 0x4e, 0x11, 0x8d, 0x0b, 0x39, 0x67, 0x54, 0xbe, 0x2b, 0x04, 0x2a, 0x68, 0x5d, 0xd4, 0x72, 0x7e, + 0x89, 0xc0, 0x38, 0x6a, 0x94, 0xd3, 0xcd, 0x6e, 0xcb, 0x98, 0x20, 0xe9, 0xd4, 0x9a, 0xfe, 0xed, + 0x66, 0xc4, 0x7e, 0x6f, 0xc2, 0x43, 0xea, 0xbe, 0xbb, 0xcb, 0x0b, 0x02, 0x45, 0x38, 0x77, 0xf5, + 0xac, 0x5d, 0xbf, 0xbd, 0xf8, 0xdb, 0x10, 0x52, 0xa3, 0xc9, 0x94, 0xb2, 0x24, 0xcd, 0x9a, 0xaa, + 0xf5, 0x6b, 0x02, 0x6b, 0xb9, 0xef, 0xa2, 0xe0, 0x13, 0x02, 0xb3, 0x64, 0x01, 0xab, 0x64, 0x94, + 0xe7, 0x01, 0x8d, 0x6e, 0x5b, 0x57, 0x3b, 0xd3, 0x8b, 0xce, 0xf0, 0x23, 0xb1, 0xfc, 0x92, 0x94, + 0x6b, 0xbc, 0xa0, 0x20, 0x9c, 0xa5, 0xfa, 0x92, 0x6b, 0x49, 0x70, 0xb1, 0x00, 0x91, 0x03, 0x64, + 0x5c, 0xb1, 0xfc, 0xfe, 0x55, 0x23, 0x11, 0xff, 0x73, 0x05, 0x58, 0x98, 0x43, 0x70, 0x03, 0x8f, + 0xd2, 0xcc, 0xe2, 0xa9, 0x1f, 0xc7, 0x4d, 0x6f, 0x3e, 0x3e, 0xa9, 0xf8, 0x43, 0xee, 0xd3, 0x56, + 0xf6, 0xf8, 0x2d, 0x35, 0xd0, 0x3b, 0xc2, 0x4b, 0x81, 0xb5, 0x8c, 0xeb, 0x1a, 0x43, 0xec, 0x94, + 0x37, 0xe6, 0xf1, 0xe5, 0x0e, 0xb6, 0xf5, 0x55, 0xe3, 0x21, 0xfd, 0x67, 0xc8, 0x33, 0x2e, 0xb1, + 0xb8, 0x32, 0xaa, 0x8d, 0x79, 0x5a, 0x27, 0xd4, 0x79, 0xc6, 0xe2, 0x7d, 0x5a, 0x61, 0x03, 0x46, + 0x83, 0x89, 0x19, 0x03, 0xf6, 0x64, 0x21, 0xd0, 0x94, 0xe1, 0xb0, 0x0a, 0x9a, 0x13, 0x8d, 0x86, + 0x1e, 0x6f, 0x78, 0xa2, 0x0a, 0xd3, 0xe1, 0x58, 0x00, 0x54, 0xd2, 0xe3, 0x05, 0x25, 0x3c, 0x71, + 0x3a, 0x02, 0xfe, 0x1e, 0x28, 0xde, 0xee, 0x73, 0x36, 0x24, 0x6f, 0x6a, 0xe3, 0x43, 0x31, 0x80, + 0x6b, 0x46, 0xb4, 0x7b, 0x83, 0x3c, 0x39, 0xb9, 0xd3, 0x1c, 0xd3, 0x00, 0xc2, 0xa6, 0xed, 0x83, + 0x13, 0x99, 0x77, 0x6d, 0x07, 0xf5, 0x70, 0xea, 0xf0, 0x05, 0x9a, 0x2c, 0x68, 0xa5, 0xf3, 0xae, + 0x16, 0xb6, 0x17, 0x40, 0x4a, 0xf7, 0xb7, 0x23, 0x1a, 0x4d, 0x94, 0x27, 0x58, 0xfc, 0x02, 0x0b, + 0x3f, 0x23, 0xee, 0x8c, 0x15, 0xe3, 0x60, 0x44, 0xcf, 0xd6, 0x7c, 0xd6, 0x40, 0x99, 0x3b, 0x16, + 0x20, 0x75, 0x97, 0xfb, 0xf3, 0x85, 0xea, 0x7a, 0x4d, 0x99, 0xe8, 0xd4, 0x56, 0xff, 0x83, 0xd4, + 0x1f, 0x7b, 0x8b, 0x4f, 0x06, 0x9b, 0x02, 0x8a, 0x2a, 0x63, 0xa9, 0x19, 0xa7, 0x0e, 0x3a, 0x10, + 0xe3, 0x08, // encrypted cert 0x41, // encrypted data type (Certificate) 0x58, 0xfa, 0xa5, 0xba, 0xfa, 0x30, 0x18, 0x6c, // auth tag 0x6b, 0x2f, 0x23, 0x8e, 0xb5, 0x30, 0xc7, 0x3e, // auth tag + } ++ [_]u8{ + 0x17, // application data (lie for tls 1.2 compat) + 0x03, 0x03, // tls 1.2 + 0x01, 0x19, // application data len + 0x73, 0x71, 0x9f, 0xce, 0x07, 0xec, 0x2f, 0x6d, 0x3b, 0xba, 0x02, 0x92, 0xa0, 0xd4, 0x0b, 0x27, + 0x70, 0xc0, 0x6a, 0x27, 0x17, 0x99, 0xa5, 0x33, 0x14, 0xf6, 0xf7, 0x7f, 0xc9, 0x5c, 0x5f, 0xe7, + 0xb9, 0xa4, 0x32, 0x9f, 0xd9, 0x54, 0x8c, 0x67, 0x0e, 0xbe, 0xea, 0x2f, 0x2d, 0x5c, 0x35, 0x1d, + 0xd9, 0x35, 0x6e, 0xf2, 0xdc, 0xd5, 0x2e, 0xb1, 0x37, 0xbd, 0x3a, 0x67, 0x65, 0x22, 0xf8, 0xcd, + 0x0f, 0xb7, 0x56, 0x07, 0x89, 0xad, 0x7b, 0x0e, 0x3c, 0xab, 0xa2, 0xe3, 0x7e, 0x6b, 0x41, 0x99, + 0xc6, 0x79, 0x3b, 0x33, 0x46, 0xed, 0x46, 0xcf, 0x74, 0x0a, 0x9f, 0xa1, 0xfe, 0xc4, 0x14, 0xdc, + 0x71, 0x5c, 0x41, 0x5c, 0x60, 0xe5, 0x75, 0x70, 0x3c, 0xe6, 0xa3, 0x4b, 0x70, 0xb5, 0x19, 0x1a, + 0xa6, 0xa6, 0x1a, 0x18, 0xfa, 0xff, 0x21, 0x6c, 0x68, 0x7a, 0xd8, 0xd1, 0x7e, 0x12, 0xa7, 0xe9, + 0x99, 0x15, 0xa6, 0x11, 0xbf, 0xc1, 0xa2, 0xbe, 0xfc, 0x15, 0xe6, 0xe9, 0x4d, 0x78, 0x46, 0x42, + 0xe6, 0x82, 0xfd, 0x17, 0x38, 0x2a, 0x34, 0x8c, 0x30, 0x10, 0x56, 0xb9, 0x40, 0xc9, 0x84, 0x72, + 0x00, 0x40, 0x8b, 0xec, 0x56, 0xc8, 0x1e, 0xa3, 0xd7, 0x21, 0x7a, 0xb8, 0xe8, 0x5a, 0x88, 0x71, + 0x53, 0x95, 0x89, 0x9c, 0x90, 0x58, 0x7f, 0x72, 0xe8, 0xdd, 0xd7, 0x4b, 0x26, 0xd8, 0xed, 0xc1, + 0xc7, 0xc8, 0x37, 0xd9, 0xf2, 0xeb, 0xbc, 0x26, 0x09, 0x62, 0x21, 0x90, 0x38, 0xb0, 0x56, 0x54, + 0xa6, 0x3a, 0x0b, 0x12, 0x99, 0x9b, 0x4a, 0x83, 0x06, 0xa3, 0xdd, 0xcc, 0x0e, 0x17, 0xc5, 0x3b, + 0xa8, 0xf9, 0xc8, 0x03, 0x63, 0xf7, 0x84, 0x13, 0x54, 0xd2, 0x91, 0xb4, 0xac, 0xe0, 0xc0, 0xf3, + 0x30, 0xc0, 0xfc, 0xd5, 0xaa, 0x9d, 0xee, 0xf9, 0x69, 0xae, 0x8a, 0xb2, 0xd9, 0x8d, 0xa8, 0x8e, + 0xbb, 0x6e, 0xa8, 0x0a, 0x3a, 0x11, 0xf0, 0x0e, // encrypted signature_verify + 0xa2, // encrypted data type (SignatureVerify) + 0x96, 0xa3, 0x23, 0x23, 0x67, 0xff, 0x07, 0x5e, // auth tag + 0x1c, 0x66, 0xdd, 0x9c, 0xbe, 0xdc, 0x47, 0x13, // auth tag + } ++ [_]u8{ + 0x17, // application data (lie for tls 1.2 compat) + 0x03, 0x03, // tls 1.2 + 0x00, 0x45, // application data len + 0x10, 0x61, 0xde, 0x27, 0xe5, 0x1c, 0x2c, 0x9f, 0x34, 0x29, 0x11, 0x80, 0x6f, 0x28, 0x2b, 0x71, + 0x0c, 0x10, 0x63, 0x2c, 0xa5, 0x00, 0x67, 0x55, 0x88, 0x0d, 0xbf, 0x70, 0x06, 0x00, 0x2d, 0x0e, + 0x84, 0xfe, 0xd9, 0xad, 0xf2, 0x7a, 0x43, 0xb5, 0x19, 0x23, 0x03, 0xe4, 0xdf, 0x5c, 0x28, 0x5d, + 0x58, 0xe3, 0xc7, 0x62, + 0x24, // encrypted data type (finished) + 0x07, 0x84, 0x40, 0xc0, 0x74, 0x23, 0x74, 0x74, // auth tag + 0x4a, 0xec, 0xf2, 0x8c, 0xf3, 0x18, 0x2f, 0xd0, // auth tag } ; try std.testing.expectEqualSlices(u8, &expected, buf); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index be3d83ee266a..3d43657cbf1e 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -444,16 +444,14 @@ pub fn Client(comptime StreamType: type) type { }, }; - self.stream.version = .tls_1_0; - try self.stream.write(tls.ClientHello, hello); + try self.stream.write(tls.Handshake, .{ .client_hello = hello }); try self.stream.flush(); } pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { - try self.stream.readFragment(.server_hello); + self.stream.handshake_type = .server_hello; + try self.stream.readFragment(); - // TODO: check spec to see if we should verify this - _ = try self.stream.read(u24); // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. _ = try self.stream.read(tls.Version); const random = try self.stream.readAll(32); @@ -540,7 +538,8 @@ pub fn Client(comptime StreamType: type) type { self.stream.content_type = .application_data; self.stream.handshake_cipher.?.print(); - try self.stream.readFragment(.encrypted_extensions); + self.stream.handshake_type = .encrypted_extensions; + try self.stream.readFragment(); iter = try self.stream.extensions(); while (try iter.next()) |ext| { _ = try self.stream.readAll(ext.len); diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index ceb4b989fb3e..90c850b68a80 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -161,19 +161,16 @@ pub fn Server(comptime StreamType: type) type { session_id: [32]u8, cipher_suite: tls.CipherSuite, key_share: tls.KeyShare, + sig_scheme: ?tls.SignatureScheme, }; pub fn recv_hello(self: *Self) !ClientHello { - try self.stream.readFragment(.client_hello); + try self.stream.readFragment(); - // TODO: verify this - const msg_len = try self.stream.read(u24); - std.debug.print("msg_len {d}\n", .{msg_len}); - // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. _ = try self.stream.read(tls.Version); const client_random = try self.stream.readAll(32); const session_id = try self.stream.readSmallArray(u8); - if (session_id.len != 32) return error.TlsUnexpectedMessage; + if (session_id.len > tls.ClientHello.session_id_max_len) return error.TlsUnexpectedMessage; var selected_suite: ?tls.CipherSuite = null; @@ -197,6 +194,7 @@ pub fn Server(comptime StreamType: type) type { var tls_version: ?tls.Version = null; var key_share: ?tls.KeyShare = null; var ec_point_format: ?tls.EcPointFormat = null; + var sig_scheme: ?tls.SignatureScheme = null; var extension_iter = try self.stream.extensions(); while (try extension_iter.next()) |ext| { @@ -213,13 +211,11 @@ pub fn Server(comptime StreamType: type) type { .key_share => { if (key_share != null) return error.TlsUnexpectedMessage; - var key_share_iter = try self.stream.iterator(tls.KeyShare.Header); + var key_share_iter = try self.stream.iterator(tls.KeyShare); while (try key_share_iter.next()) |ks| { - const key = try self.stream.readAll(ks.len); - if (ks.group == .x25519) { - key_share = .{ .x25519 = undefined }; - if (ks.len != key_share.?.keyLen(true)) return error.TlsUnexpectedMessage; - @memcpy(&key_share.?.x25519, key); + switch (ks) { + .x25519 => key_share = ks, + else => {}, } } }, @@ -229,6 +225,12 @@ pub fn Server(comptime StreamType: type) type { if (f == .uncompressed) ec_point_format = .uncompressed; } }, + .signature_algorithms => { + var algos_iter = try self.stream.iterator(tls.SignatureScheme); + while (try algos_iter.next()) |algo| { + if (algo == .rsa_pss_rsae_sha256) sig_scheme = algo; + } + }, else => { _ = try self.stream.readAll(ext.len); }, @@ -244,6 +246,7 @@ pub fn Server(comptime StreamType: type) type { .session_id = session_id[0..32].*, .cipher_suite = selected_suite.?, .key_share = key_share.?, + .sig_scheme = sig_scheme, }; } @@ -259,11 +262,11 @@ pub fn Server(comptime StreamType: type) type { }, }; self.stream.version = .tls_1_2; - try self.stream.write(tls.ServerHello, hello); + _ = try self.stream.write(tls.Handshake, .{ .server_hello = hello }); try self.stream.flush(); self.stream.content_type = .change_cipher_spec; - try self.stream.write(tls.ChangeCipherSpec, .change_cipher_spec); + _ = try self.stream.write(tls.ChangeCipherSpec, .change_cipher_spec); try self.stream.flush(); const shared_key = switch (client_hello.key_share) { @@ -302,19 +305,66 @@ pub fn Server(comptime StreamType: type) type { self.stream.handshake_cipher = tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, &hello_hash); self.stream.handshake_cipher.?.print(); - const extensions = tls.EncryptedExtensions{ .extensions = &.{} }; self.stream.content_type = .handshake; - try self.stream.write(tls.EncryptedExtensions, extensions); + _ = try self.stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); try self.stream.flush(); - try self.stream.write(tls.Certificate, self.options.certificate); + _ = try self.stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); + try self.stream.flush(); + + // RFC 8446 S4.4.3 + // const signature_content = [_]u8{0x20} ** 64 + // ++ "TLS 1.3, server CertificateVerify\x00".* + // ++ self.stream.transcript_hash.peek() + // ; + + // const cert = Certificate{ .buffer = self.options.certificate.entries[0].data, .index = 0 }; + // const parsed = try cert.parse(); + // const pub_key = parsed.pubKey(); + + // switch (client_hello.sig_scheme) { + // .rsa_pss_rsae_sha256 => { + // const rsa = Certificate.rsa; + // const components = try rsa.PublicKey.parseDer(pub_key); + // const exponent = components.exponent; + // const modulus = components.modulus; + // switch (modulus.len) { + // inline 128, 256, 512 => |modulus_len| { + // const key = try rsa.PublicKey.fromBytes(exponent, modulus); + // const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); + // try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); + // }, + // else => { + // return error.TlsBadRsaSignatureBitCount; + // }, + // } + // }, + // else => {} + // } + } + + pub fn send_handshake_finish(self: *Self) !void { + const secret = self.stream.handshake_cipher.?.aes_256_gcm_sha384.server_finished_key; + const transcript_hash = self.stream.transcript_hash.peek(); + tls.debugPrint("peek", transcript_hash); + const verify = switch (self.stream.handshake_cipher.?) { + inline + .aes_256_gcm_sha384, + => |v| brk: { + const T = @TypeOf(v); + break :brk tls.hmac(T.Hmac, &transcript_hash, secret); + }, + else => return error.TlsDecryptFailure, + }; + tls.debugPrint("verify", verify); + _ = try self.stream.write(tls.Handshake, .{ .finished = &verify }); try self.stream.flush(); } }; } pub const Options = struct { - /// List of potential cipher suites in order of descending preference. + /// List of potential cipher suites in descending order of preference. cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, certificate: tls.Certificate, }; From ba9facf550516e74f99198fee7277c9da0ad4168 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Mon, 11 Mar 2024 20:40:34 -0400 Subject: [PATCH 03/17] todo 1-5 --- TODO | 18 +- lib/std/crypto/Certificate.zig | 4 +- lib/std/crypto/tls.zig | 1206 +++++++++++++++++++------------- lib/std/crypto/tls/Client.zig | 602 ++++------------ lib/std/crypto/tls/Server.zig | 239 ++----- 5 files changed, 925 insertions(+), 1144 deletions(-) diff --git a/TODO b/TODO index 7ebc5937c919..431f6659a5c3 100644 --- a/TODO +++ b/TODO @@ -1,4 +1,14 @@ -store multiple fragments in buffer for less syscalls -remove @panic's -server dynamic transcript hash based on client cipher suites -send alert on error +[x] 1. single transcript hash type. move out of stream. +[x] 2. read backpressure, smaller stream buffer +[x] 3. Client recv_hello secp256r1 key share +[x] 4. remove @panic's +[x] 5. better errors than spammy TlsDecodeError. map new errors to TLS alerts. send alert on error. +6. verify certs and sigs +7. KeyShare kyber read +8. StreamInterface `readv` instead of `readAll` + +1. benchmark +2. store multiple fragments in buffer for less syscalls +3. streaming encode + decode +4. store handshake info (transcript_hash, handshake_type, handshake_cipher) somewhere temporary + diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 0d932fcafb98..decd59571696 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -705,7 +705,7 @@ fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) ParseEnu return E.map.get(oid_bytes) orelse return error.CertificateHasUnrecognizedObjectId; } -pub const ParseVersionError = error{ UnsupportedCertificateVersion, CertificateFieldHasInvalidLength }; +pub const ParseVersionError = error{ CertificateUnsupportedVersion, CertificateFieldHasInvalidLength }; pub fn parseVersion(bytes: []const u8, version_elem: der.Element) ParseVersionError!Version { if (@as(u8, @bitCast(version_elem.identifier)) != 0xa0) @@ -724,7 +724,7 @@ pub fn parseVersion(bytes: []const u8, version_elem: der.Element) ParseVersionEr return .v1; } - return error.UnsupportedCertificateVersion; + return error.CertificateUnsupportedVersion; } fn verifyRsa( diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index c312993cf2e4..67311da13ef7 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -82,7 +82,7 @@ pub const Handshake = union(HandshakeType) { server_hello: ServerHello, /// Deprecated. hello_verify_request: void, - new_session_ticket: NewSessionTicket, + new_session_ticket: void, end_of_early_data: void, /// Deprecated. hello_retry_request: void, @@ -90,7 +90,7 @@ pub const Handshake = union(HandshakeType) { certificate: Certificate, /// Deprecated. server_key_exchange: void, - certificate_request: CertificateRequest, + certificate_request: void, /// Deprecated. server_hello_done: void, certificate_verify: CertificateVerify, @@ -103,7 +103,7 @@ pub const Handshake = union(HandshakeType) { certificate_status: void, /// Deprecated. supplemental_data: void, - key_update: KeyUpdate, + key_update: void, message_hash: void, // If `HandshakeCipherT.encode` accepts iovecs for the message this can be moved @@ -113,7 +113,7 @@ pub const Handshake = union(HandshakeType) { res += try stream.write(HandshakeType, self); switch (self) { .finished => |verification| { - res += try stream.writeArray(3, u8, verification); + res += try stream.writeArray(u24, u8, verification); }, inline else => |value| { var len: usize = 0; @@ -123,12 +123,12 @@ pub const Handshake = union(HandshakeType) { res += try stream.write(u24, @intCast(len)); }, .Pointer => |info| { - len += try stream.arrayLength(2, info.child, value); + len += stream.arrayLength(u16, info.child, value); res += try stream.write(u24, @intCast(len)); - res += try stream.writeArray(2, info.child, value); + res += try stream.writeArray(u16, info.child, value); }, .Struct => { - len += try stream.length(T, value); + len += stream.length(T, value); res += try stream.write(u24, @intCast(len)); res += try stream.write(T, value); }, @@ -140,47 +140,22 @@ pub const Handshake = union(HandshakeType) { } }; -pub const NewSessionTicket = struct { - ticket_lifetime: u32, - ticket_age_add: u32, - /// max len 255 - ticket_nonce: []const u8, - /// Should have at least one - ticket: []const u8, - extensions: []const Extension, - - pub fn write(self: @This(), stream: anytype) !usize { - _ = .{ self, stream }; - @panic("TODO"); - } -}; - -pub const CertificateRequest = struct { - /// Max len 255 - context: []const u8, - /// At least 2 - extensions: []const Extension, - - pub fn write(self: @This(), stream: anytype) !usize { - _ = .{ self, stream }; - @panic("TODO"); - } -}; - pub const Certificate = struct { - /// Max len 255 context: []const u8 = "", entries: []const Entry, + pub const max_context_len = 255; + pub const Entry = struct { - /// Either ASN1_subjectPublicKeyInfo or cert_data based on CertificateType + /// Either ASN1_subjectPublicKeyInfo or cert_data based on CertificateType. + /// Max len 2^24-1 data: []const u8, extensions: []const Extension = &.{}, pub fn write(self: @This(), stream: anytype) !usize { var res: usize = 0; - res += try stream.writeArray(3, u8, self.data); - res += try stream.writeArray(2, Extension, self.extensions); + res += try stream.writeArray(u24, u8, self.data); + res += try stream.writeArray(u16, Extension, self.extensions); return res; } }; @@ -189,8 +164,8 @@ pub const Certificate = struct { pub fn write(self: Self, stream: anytype) !usize { var res: usize = 0; - res += try stream.writeArray(1, u8, self.context); - res += try stream.writeArray(3, Entry, self.entries); + res += try stream.writeArray(u8, u8, self.context); + res += try stream.writeArray(u24, Entry, self.entries); return res; } }; @@ -203,26 +178,11 @@ pub const CertificateVerify = struct { pub fn write(self: @This(), stream: anytype) !usize { var res: usize = 0; res += try stream.write(SignatureScheme, self.algorithm); - res += try stream.writeArray(2, u8, self.signature); + res += try stream.writeArray(u16, u8, self.signature); return res; } }; -pub const KeyUpdate = struct { - request: Request, - - pub const Request = enum(u8) { - update_not_requested = 0, - update_requested = 1, - _, - }; - - pub fn write(self: @This(), stream: anytype) !usize { - _ = .{ self, stream }; - @panic("TODO"); - } -}; - // https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml pub const ExtensionType = enum(u16) { /// RFC 6066 @@ -282,6 +242,114 @@ pub const ExtensionType = enum(u16) { _, }; +pub const Error = error{ + /// An inappropriate message (e.g., the wrong + /// handshake message, premature Application Data, etc.) was received. + /// This alert should never be observed in communication between + /// proper implementations. + TlsUnexpectedMessage, + /// This alert is returned if a record is received which + /// cannot be deprotected. Because AEAD algorithms combine decryption + /// and verification, and also to avoid side-channel attacks, this + /// alert is used for all deprotection failures. This alert should + /// never be observed in communication between proper implementations, + /// except when messages were corrupted in the network. + TlsBadRecordMac, + /// A TLSCiphertext record was received that had a + /// length more than 2^14 + 256 bytes, or a record decrypted to a + /// TLSPlaintext record with more than 2^14 bytes (or some other + /// negotiated limit). This alert should never be observed in + /// communication between proper implementations, except when messages + /// were corrupted in the network. + TlsRecordOverflow, + /// Receipt of a "handshake_failure" alert message + /// indicates that the sender was unable to negotiate an acceptable + /// set of security parameters given the options available. + TlsHandshakeFailure, + /// A certificate was corrupt, contained signatures + /// that did not verify correctly, etc. + TlsBadCertificate, + /// A certificate was of an unsupported type. + TlsUnsupportedCertificate, + /// A certificate was revoked by its signer. + TlsCertificateRevoked, + /// A certificate has expired or is not currently valid. + TlsCertificateExpired, + /// Some other (unspecified) issue arose in processing the certificate, rendering it unacceptable. + TlsCertificateUnknown, + /// A field in the handshake was incorrect or + /// inconsistent with other fields. This alert is used for errors + /// which conform to the formal protocol syntax but are otherwise + /// incorrect. + TlsIllegalParameter, + /// A valid certificate chain or partial chain was received, + /// but the certificate was not accepted because the CA certificate + /// could not be located or could not be matched with a known trust + /// anchor. + TlsUnknownCa, + /// A valid certificate or PSK was received, but when + /// access control was applied, the sender decided not to proceed with + /// negotiation. + TlsAccessDenied, + /// A message could not be decoded because some field was + /// out of the specified range or the length of the message was + /// incorrect. This alert is used for errors where the message does + /// not conform to the formal protocol syntax. This alert should + /// never be observed in communication between proper implementations, + /// except when messages were corrupted in the network. + TlsDecodeError, + /// A handshake (not record layer) cryptographic + /// operation failed, including being unable to correctly verify a + /// signature or validate a Finished message or a PSK binder. + TlsDecryptError, + /// The protocol version the peer has attempted to + /// negotiate is recognized but not supported (see Appendix D). + TlsProtocolVersion, + /// Returned instead of "handshake_failure" when + /// a negotiation has failed specifically because the server requires + /// parameters more secure than those supported by the client. + TlsInsufficientSecurity, + /// An internal error unrelated to the peer or the + /// correctness of the protocol (such as a memory allocation failure) + /// makes it impossible to continue. + TlsInternalError, + /// Sent by a server in response to an invalid + /// connection retry attempt from a client (see [RFC7507]). + TlsInappropriateFallback, + /// Sent by endpoints that receive a handshake + /// message not containing an extension that is mandatory to send for + /// the offered TLS version or other negotiated parameters. + TlsMissingExtension, + /// Sent by endpoints receiving any handshake + /// message containing an extension known to be prohibited for + /// inclusion in the given handshake message, or including any + /// extensions in a ServerHello or Certificate not first offered in + /// the corresponding ClientHello or CertificateRequest. + TlsUnsupportedExtension, + /// Sent by servers when no server exists identified + /// by the name provided by the client via the "server_name" extension + /// (see [RFC6066]). + TlsUnrecognizedName, + /// Sent by clients when an invalid or + /// unacceptable OCSP response is provided by the server via the + /// "status_request" extension (see [RFC6066]). + TlsBadCertificateStatusResponse, + /// Sent by servers when PSK key establishment is + /// desired but no acceptable PSK identity is provided by the client. + /// Sending this alert is OPTIONAL; servers MAY instead choose to send + /// a "decrypt_error" alert to merely indicate an invalid PSK + /// identity. + TlsUnknownPskIdentity, + /// Sent by servers when a client certificate is + /// desired but none was provided by the client. + TlsCertificateRequired, + /// Sent by servers when a client + /// "application_layer_protocol_negotiation" extension advertises only + /// protocols that the server does not support (see [RFC7301]). + TlsNoApplicationProtocol, + TlsUnknown, +}; + pub const Alert = struct { /// > In TLS 1.3, the severity is implicit in the type of alert being sent /// > and the "level" field can safely be ignored. @@ -294,35 +362,6 @@ pub const Alert = struct { _, }; pub const Description = enum(u8) { - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, - }; - close_notify = 0, unexpected_message = 10, bad_record_mac = 20, @@ -354,34 +393,33 @@ pub const Alert = struct { pub fn toError(alert: @This()) Error { return switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => error.TlsAlertUnexpectedMessage, - .bad_record_mac => error.TlsAlertBadRecordMac, - .record_overflow => error.TlsAlertRecordOverflow, - .handshake_failure => error.TlsAlertHandshakeFailure, - .bad_certificate => error.TlsAlertBadCertificate, - .unsupported_certificate => error.TlsAlertUnsupportedCertificate, - .certificate_revoked => error.TlsAlertCertificateRevoked, - .certificate_expired => error.TlsAlertCertificateExpired, - .certificate_unknown => error.TlsAlertCertificateUnknown, - .illegal_parameter => error.TlsAlertIllegalParameter, - .unknown_ca => error.TlsAlertUnknownCa, - .access_denied => error.TlsAlertAccessDenied, - .decode_error => error.TlsAlertDecodeError, - .decrypt_error => error.TlsAlertDecryptError, - .protocol_version => error.TlsAlertProtocolVersion, - .insufficient_security => error.TlsAlertInsufficientSecurity, - .internal_error => error.TlsAlertInternalError, - .inappropriate_fallback => error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => error.TlsAlertMissingExtension, - .unsupported_extension => error.TlsAlertUnsupportedExtension, - .unrecognized_name => error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, - .certificate_required => error.TlsAlertCertificateRequired, - .no_application_protocol => error.TlsAlertNoApplicationProtocol, - _ => error.TlsAlertUnknown, + .close_notify, .user_canceled => unreachable, // not an error + .unexpected_message => Error.TlsUnexpectedMessage, + .bad_record_mac => Error.TlsBadRecordMac, + .record_overflow => Error.TlsRecordOverflow, + .handshake_failure => Error.TlsHandshakeFailure, + .bad_certificate => Error.TlsBadCertificate, + .unsupported_certificate => Error.TlsUnsupportedCertificate, + .certificate_revoked => Error.TlsCertificateRevoked, + .certificate_expired => Error.TlsCertificateExpired, + .certificate_unknown => Error.TlsCertificateUnknown, + .illegal_parameter => Error.TlsIllegalParameter, + .unknown_ca => Error.TlsUnknownCa, + .access_denied => Error.TlsAccessDenied, + .decode_error => Error.TlsDecodeError, + .decrypt_error => Error.TlsDecryptError, + .protocol_version => Error.TlsProtocolVersion, + .insufficient_security => Error.TlsInsufficientSecurity, + .internal_error => Error.TlsInternalError, + .inappropriate_fallback => Error.TlsInappropriateFallback, + .missing_extension => Error.TlsMissingExtension, + .unsupported_extension => Error.TlsUnsupportedExtension, + .unrecognized_name => Error.TlsUnrecognizedName, + .bad_certificate_status_response => Error.TlsBadCertificateStatusResponse, + .unknown_psk_identity => Error.TlsUnknownPskIdentity, + .certificate_required => Error.TlsCertificateRequired, + .no_application_protocol => Error.TlsNoApplicationProtocol, + _ => Error.TlsUnknown, }; } }; @@ -553,7 +591,7 @@ pub const KeyPair = union(NamedGroup) { }; } }; -/// The public key portion of a KeyPair. Saves bytes. +/// The public portion of a KeyPair. pub const KeyShare = union(NamedGroup) { invalid: void, secp256r1: NamedGroupT(.secp256r1).PublicKey, @@ -573,20 +611,24 @@ pub const KeyShare = union(NamedGroup) { const Self = @This(); pub fn read(stream: anytype) !Self { + var reader = stream.reader(); const group = try stream.read(NamedGroup); const len = try stream.read(u16); - const key = try stream.readAll(len); switch (group) { // .x25519_kyber768d00 => { // const expected_len = if (stream.is_client) @TypeOf(k).bytes_length else X25519Kyber768Draft.Kyber768.ciphertext_length; // }, inline .secp256r1, .secp384r1 => |k| { - return @unionInit(Self, @tagName(k), try NamedGroupT(k).PublicKey.fromSec1(key)); + const T = NamedGroupT(k).PublicKey; + var buf: [T.compressed_sec1_encoded_length]u8 = undefined; + try reader.readNoEof(&buf); + const val = T.fromSec1(&buf) catch return Error.TlsDecryptError; + return @unionInit(Self, @tagName(k), val); }, .x25519 => { var res = Self{ .x25519 = undefined }; - if (res.x25519.len != key.len) return error.TlsDecodeError; - @memcpy(&res.x25519, key); + if (res.x25519.len != len) return Error.TlsDecodeError; + try reader.readNoEof(&res.x25519); return res; }, else => {}, @@ -604,7 +646,7 @@ pub const KeyShare = union(NamedGroup) { .x25519 => |k| &k, else => "", }; - res += try stream.writeArray(2, u8, public); + res += try stream.writeArray(u16, u8, public); return res; } }; @@ -658,11 +700,8 @@ pub const HandshakeCipher = union(CipherSuite) { const Self = @This(); - pub fn init(suite: CipherSuite, shared_key: []const u8, hello_hash: []const u8) Self { - debugPrint("hello_hash", hello_hash); - debugPrint("shared_key", shared_key); + pub fn init(suite: CipherSuite, shared_key: []const u8, hello_hash: []const u8) !Self { switch (suite) { - else => unreachable, inline .aes_128_gcm_sha256, .aes_256_gcm_sha384, .chacha20_poly1305_sha256, @@ -672,12 +711,12 @@ pub const HandshakeCipher = union(CipherSuite) { var res = @unionInit(Self, @tagName(tag), .{ .handshake_secret = undefined, .master_secret = undefined, - .client_handshake_key = undefined, - .server_handshake_key = undefined, .client_finished_key = undefined, .server_finished_key = undefined, - .client_handshake_iv = undefined, - .server_handshake_iv = undefined, + .client_key = undefined, + .server_key = undefined, + .client_iv = undefined, + .server_iv = undefined, }); const P = std.meta.TagPayloadByName(Self, @tagName(tag)); const p = &@field(res, @tagName(tag)); @@ -686,81 +725,23 @@ pub const HandshakeCipher = union(CipherSuite) { const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); const empty_hash = emptyHash(P.Hash); - const derived_secret = hkdfExpandLabel( - P.Hkdf, - early_secret, - "derived", - &empty_hash, - P.Hash.digest_length, - ); + const derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); p.handshake_secret = P.Hkdf.extract(&derived_secret, shared_key); - const ap_derived_secret = hkdfExpandLabel( - P.Hkdf, - p.handshake_secret, - "derived", - &empty_hash, - P.Hash.digest_length, - ); + const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); - const client_secret = hkdfExpandLabel( - P.Hkdf, - p.handshake_secret, - "c hs traffic", - hello_hash, - P.Hash.digest_length, - ); - const server_secret = hkdfExpandLabel( - P.Hkdf, - p.handshake_secret, - "s hs traffic", - hello_hash, - P.Hash.digest_length, - ); - p.client_finished_key = hkdfExpandLabel( - P.Hkdf, - client_secret, - "finished", - "", - P.Hmac.key_length, - ); - p.server_finished_key = hkdfExpandLabel( - P.Hkdf, - server_secret, - "finished", - "", - P.Hmac.key_length, - ); - p.client_handshake_key = hkdfExpandLabel( - P.Hkdf, - client_secret, - "key", - "", - P.AEAD.key_length, - ); - p.server_handshake_key = hkdfExpandLabel( - P.Hkdf, - server_secret, - "key", - "", - P.AEAD.key_length, - ); - p.client_handshake_iv = hkdfExpandLabel( - P.Hkdf, - client_secret, - "iv", - "", - P.AEAD.nonce_length, - ); - p.server_handshake_iv = hkdfExpandLabel( - P.Hkdf, - server_secret, - "iv", - "", - P.AEAD.nonce_length, - ); + const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", hello_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", hello_hash, P.Hash.digest_length); + p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); + p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); + p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); return res; }, + .empty_renegotiation_info_scsv => return .{ .empty_renegotiation_info_scsv = {} }, + _ => return error.TlsIllegalParameter, } } @@ -779,6 +760,52 @@ pub const ApplicationCipher = union(CipherSuite) { aegis_256_sha512: ApplicationCipherT(.aegis_256_sha512), aegis_128l_sha256: ApplicationCipherT(.aegis_128l_sha256), empty_renegotiation_info_scsv: void, + + const Self = @This(); + + pub fn init(handshake_cipher: HandshakeCipher, handshake_hash: []const u8) Self { + switch (handshake_cipher) { + inline .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, + => |c, tag| { + var res = @unionInit(Self, @tagName(tag), .{ + .client_secret = undefined, + .server_secret = undefined, + .client_key = undefined, + .server_key = undefined, + .client_iv = undefined, + .server_iv = undefined, + }); + const P = std.meta.TagPayloadByName(Self, @tagName(tag)); + const p = &@field(res, @tagName(tag)); + + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const empty_hash = emptyHash(P.Hash); + + const derived_secret = hkdfExpandLabel(P.Hkdf, c.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); + const master_secret = P.Hkdf.extract(&derived_secret, &zeroes); + p.client_secret = hkdfExpandLabel(P.Hkdf, master_secret, "c ap traffic", handshake_hash, P.Hash.digest_length); + p.server_secret = hkdfExpandLabel(P.Hkdf, master_secret, "s ap traffic", handshake_hash, P.Hash.digest_length); + p.client_key = hkdfExpandLabel(P.Hkdf, p.client_secret, "key", "", P.AEAD.key_length); + p.server_key = hkdfExpandLabel(P.Hkdf, p.server_secret, "key", "", P.AEAD.key_length); + p.client_iv = hkdfExpandLabel(P.Hkdf, p.client_secret, "iv", "", P.AEAD.nonce_length); + p.server_iv = hkdfExpandLabel(P.Hkdf, p.server_secret, "iv", "", P.AEAD.nonce_length); + + return res; + }, + .empty_renegotiation_info_scsv => unreachable, + } + } + + pub fn print(self: Self) void { + switch (self) { + .empty_renegotiation_info_scsv => {}, + inline else => |v| v.print(), + } + } }; /// RFC 8446 S4.1.2 @@ -803,10 +830,10 @@ pub const ClientHello = struct { var res: usize = 0; res += try stream.write(Version, self.version); res += try stream.writeAll(&self.random); - res += try stream.writeArray(1, u8, self.session_id); - res += try stream.writeArray(2, CipherSuite, self.cipher_suites); - res += try stream.writeArray(1, u8, &self.compression_methods); - res += try stream.writeArray(2, Extension, self.extensions); + res += try stream.writeArray(u8, u8, self.session_id); + res += try stream.writeArray(u16, CipherSuite, self.cipher_suites); + res += try stream.writeArray(u8, u8, &self.compression_methods); + res += try stream.writeArray(u16, Extension, self.extensions); return res; } }; @@ -835,10 +862,10 @@ pub const ServerHello = struct { var res: usize = 0; res += try stream.write(Version, self.version); res += try stream.writeAll(&self.random); - res += try stream.writeArray(1, u8, self.session_id); + res += try stream.writeArray(u8, u8, self.session_id); res += try stream.write(CipherSuite, self.cipher_suite); res += try stream.write(u8, self.compression_method); - res += try stream.writeArray(2, Extension, self.extensions); + res += try stream.writeArray(u16, Extension, self.extensions); return res; } }; @@ -849,7 +876,7 @@ pub const EncryptedExtensions = struct { const Self = @This(); pub fn write(self: Self, stream: anytype) !usize { - return try stream.writeArray(2, Extension, self.extensions); + return try stream.writeArray(u16, Extension, self.extensions); } }; @@ -887,32 +914,43 @@ pub const Extension = union(ExtensionType) { const Self = @This(); pub fn write(self: Self, stream: anytype) !usize { - const prefix_len: u8 = if (stream.is_client) switch (self) { + const PrefixLen = enum { zero, one, two }; + const prefix_len: PrefixLen = if (stream.is_client) switch (self) { .supported_versions, .ec_point_formats, - .psk_key_exchange_modes => 1, + .psk_key_exchange_modes => .one, .server_name, .supported_groups, .signature_algorithms, - .key_share => 2, - else => 0, - } else 0; + .key_share => .two, + else => .zero, + } else .zero; var res: usize = 0; res += try stream.write(ExtensionType, self); switch (self) { inline else => |items| { - switch (@typeInfo(@TypeOf(items))) { + const T = @TypeOf(items); + switch (@typeInfo(T)) { .Void => { res += try stream.write(u16, 0); }, .Pointer => |info| { - const len = try stream.arrayLength(prefix_len, info.child, items); - res += try stream.write(u16, @intCast(len)); - res += try stream.writeArray(prefix_len, info.child, items); + switch (prefix_len) { + inline else => |t| { + const PrefixT = switch (t) { + .zero => void, + .one => u8, + .two => u16, + }; + const len = stream.arrayLength(PrefixT, info.child, items); + res += try stream.write(u16, @intCast(len)); + res += try stream.writeArray(PrefixT, info.child, items); + } + } }, - else => |t| @compileError("implement writing " ++ @tagName(t)), + else => |t| @compileError("unsupported type " ++ @typeName(T) ++ " for member " ++ @tagName(t)), } }, } @@ -923,7 +961,7 @@ pub const Extension = union(ExtensionType) { type: ExtensionType, len: u16, - pub fn read(stream: anytype) !@This() { + pub fn read(stream: anytype) @TypeOf(stream.*).ReadError!@This() { const ty = try stream.read(ExtensionType); const length = try stream.read(u16); return .{ .type = ty, .len = length }; @@ -953,7 +991,7 @@ pub const ServerName = struct { pub fn write(self: @This(), stream: anytype) !usize { var res: usize = 0; res += try stream.write(NameType, self.type); - res += try stream.writeArray(2, u8, self.host_name); + res += try stream.writeArray(u16, u8, self.host_name); return res; } }; @@ -1005,20 +1043,15 @@ pub const StreamInterface = struct { } }; -/// Abstraction over TLS record layer that handles fragmentation (RFC 8446 S5). -/// It also encrypts and decrypts .application_data messages. -/// This makes it suitable for both clients and servers. -/// -/// StreamType MUST satisfy `StreamInterface`. -/// StreamType MUST satisfy: -/// * fn update(self: @This(), bytes: []const u8) void -/// * fn peek(self: @This()) [_]u8 +/// Abstraction over TLS record layer (RFC 8446 S5). StreamType MUST satisfy `StreamInterface`. /// Cannot read and write at the same time. -pub fn Stream( - comptime fragment_size: usize, - comptime StreamType: type, - comptime TranscriptHash: type, -) type { +/// +/// Handles: +/// * Fragmentation +/// * Encryption and decryption of handshake and application data messages +/// * Reading and writing prefix length arrays +/// * TLS Alerts +pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { // TODO: Support RFC 6066 MaxFragmentLength and give fragment_size option to Client+Server. if (fragment_size > std.math.maxInt(u16)) @compileError("choose a smaller fragment_size"); @@ -1031,16 +1064,14 @@ pub fn Stream( /// > EncryptedExtensions, server CertificateRequest, server Certificate, /// > server CertificateVerify, server Finished, EndOfEarlyData, client /// > Certificate, client CertificateVerify, client Finished. - transcript_hash: TranscriptHash, - /// Used for both reading and writing. Cannot be doing both at the same time. Must be twice - /// fragment size to handle `readAll(fragment_size)`. In practice this is only approachable for - /// the SNI hostname which may be up to 8KB. - buffer: [fragment_size * 2]u8 = undefined, - /// Unflushed part of `buffer`. + transcript_hash: MultiHash = .{}, + /// Used for both reading and writing. Cannot be doing both at the same time. + buffer: [fragment_size]u8 = undefined, + /// Unread or unwritten view of `buffer`. view: []const u8 = "", /// When sending this is the record type that will be sent. - /// If a cipher is in use it will be encrypted in `inner_content_type`. + /// When receiving this is the next fragment's expected record type. content_type: ContentType = .handshake, /// When receiving fragments this is the next expected fragment type. @@ -1051,71 +1082,27 @@ pub fn Stream( /// Used to encrypt and decrypt .application_data messages. application_cipher: ?ApplicationCipher = null, - /// True when we send or receive `close_notify` + /// True when we send or receive a close_notify alert. closed: bool = false, /// Version to send out in record headers version: Version = .tls_1_0, - /// True if we're being used as a client. Certain shared struct types serialize differently - /// based on this. + /// True if we're being used as a client. This changes: + /// * Certain shared struct formats (like Extension) + /// * Which keys are used for encoding/decoding handshake and application messages. is_client: bool, /// When > 0 won't actually do anything with writes. - /// This is to discover prefix lengths for sequential writing. - /// It would be nice to write in reverse sequence, - /// but the spec defines fragments as being sent in forward sequence. + /// This is to discover prefix lengths for record level spec-adherant sequential writing. nocommit: usize = 0, const Self = @This(); - pub const ReadError = StreamType.ReadError || error{ - EndOfStream, - InsufficientEntropy, - DiskQuota, - LockViolation, - NotOpenForWriting, - TlsUnexpectedMessage, - TlsIllegalParameter, - TlsDecryptFailure, - TlsRecordOverflow, - TlsBadRecordMac, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - InvalidEncoding, - IdentityElement, - SignatureVerificationFailed, - TlsDecryptError, - TlsConnectionTruncated, - TlsDecodeError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - NonCanonical, - WeakPublicKey, + pub const ReadError = StreamType.ReadError || Error || error{EndOfStream}; + pub const WriteError = StreamType.WriteError || error{ + TlsEncodeError, }; - pub const WriteError = StreamType.WriteError; fn ciphertextOverhead(self: Self) usize { if (self.application_cipher) |a| { @@ -1137,42 +1124,69 @@ pub fn Stream( return fragment_size - self.ciphertextOverhead(); } + const EncryptionMethod = enum { none, handshake, application }; + fn encryptionMethod(self: Self) EncryptionMethod { + switch (self.content_type) { + .change_cipher_spec => {}, + .handshake => { + if (self.handshake_cipher != null) return .handshake; + }, + else => { + if (self.application_cipher != null) return .application; + }, + } + + return .none; + } + pub fn flush(self: *Self) WriteError!void { - const aead_overhead = self.ciphertextOverhead(); - const plaintext = Plaintext{ - .type = if (aead_overhead > 0) .application_data else self.content_type, + if (self.view.len == 0) return; + var plaintext = Plaintext{ + .type = self.content_type, .version = self.version, - .length = @intCast(self.view.len + aead_overhead), + .length = @intCast(self.view.len), }; - const header = Encoder.encode(Plaintext, plaintext); if (self.application_cipher == null) { - switch (plaintext.type) { + switch (self.content_type) { .change_cipher_spec, .alert => {}, else => self.transcript_hash.update(self.view), } } + var header: [fieldsLen(Plaintext)]u8 = undefined; var aead: []const u8 = ""; - if (self.application_cipher) |*a| { - switch (a.*) { - .empty_renegotiation_info_scsv => {}, - inline else => |*c| { - std.debug.assert(self.view.ptr == &self.buffer); - self.buffer[self.view.len] = @intFromEnum(self.content_type); - self.view = self.buffer[0..self.view.len + 1]; - aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); - }, - } - } else if (self.handshake_cipher) |*a| { - switch (a.*) { - .empty_renegotiation_info_scsv => {}, - inline else => |*c| { - std.debug.assert(self.view.ptr == &self.buffer); - self.buffer[self.view.len] = @intFromEnum(self.content_type); - self.view = self.buffer[0..self.view.len + 1]; - aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); - }, + switch (self.encryptionMethod()) { + .none => { + header = Encoder.encode(Plaintext, plaintext); + }, + .handshake => { + plaintext.type = .application_data; + plaintext.length += @intCast(self.ciphertextOverhead()); + header = Encoder.encode(Plaintext, plaintext); + if (self.handshake_cipher) |*a| switch (a.*) { + .empty_renegotiation_info_scsv => {}, + inline else => |*c| { + std.debug.assert(self.view.ptr == &self.buffer); + self.buffer[self.view.len] = @intFromEnum(self.content_type); + self.view = self.buffer[0..self.view.len + 1]; + aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); + }, + }; + }, + .application => { + plaintext.type = .application_data; + plaintext.length += @intCast(self.ciphertextOverhead()); + header = Encoder.encode(Plaintext, plaintext); + if (self.application_cipher) |*a| switch (a.*) { + .empty_renegotiation_info_scsv => {}, + inline else => |*c| { + std.debug.assert(self.view.ptr == &self.buffer); + self.buffer[self.view.len] = @intFromEnum(self.content_type); + self.view = self.buffer[0..self.view.len + 1]; + aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); + }, + }; } } @@ -1185,21 +1199,39 @@ pub fn Stream( self.view = self.buffer[0..0]; } - /// Write bytes with backpressure for fragment size. All other write functions end up here. + /// Write an alert to stream and call `close_notify` after. Returns Zig error. + pub fn writeError(self: *Self, err: Alert.Description) Error { + const alert = Alert{ .level = .fatal, .description = err }; + + self.view = self.buffer[0..0]; + self.content_type = .alert; + _ = self.write(Alert, alert) catch {}; + self.flush() catch {}; + + self.close(); + return err.toError(); + } + + pub fn close(self: *Self) void { + const alert = Alert{ .level = .fatal, .description = .close_notify }; + _ = self.write(Alert, alert) catch {}; + self.content_type = .alert; + self.flush() catch {}; + self.closed = true; + } + + /// Write bytes to `stream`, potentially flushing once `self.buffer` is full. pub fn writeBytes(self: *Self, bytes: []const u8) WriteError!usize { if (self.nocommit > 0) return bytes.len; - if (self.view.len + bytes.len >= self.maxFragmentSize()) { - // TODO: copy before flush to consume as many bytes as possible - try self.flush(); - } - const available = self.buffer.len - self.view.len; const to_consume = bytes[0..@min(available, bytes.len)]; @memcpy(self.buffer[self.view.len..][0..bytes.len], to_consume); self.view = self.buffer[0 .. self.view.len + to_consume.len]; + if (self.view.len == self.buffer.len) try self.flush(); + return to_consume.len; } @@ -1211,18 +1243,18 @@ pub fn Stream( return index; } - pub fn writeArray(self: *Self, prefix_bytes: u8, comptime T: type, values: []const T) WriteError!usize { + pub fn writeArray(self: *Self, comptime PrefixT: type, comptime T: type, values: []const T) WriteError!usize { var res: usize = 0; - for (values) |v| res += try self.length(T, v); - - if (prefix_bytes != 0) { - switch (prefix_bytes) { - 1 => res += try self.write(u8, @intCast(res)), - 2 => res += try self.write(u16, @intCast(res)), - 3 => res += try self.write(u24, @intCast(res)), - else => @panic("unsupported prefix len"), + for (values) |v| res += self.length(T, v); + + if (PrefixT != void) { + if (res > std.math.maxInt(PrefixT)) { + self.close(); + return error.TlsEncodeError; // Prefix length overflow } + res += try self.write(PrefixT, @intCast(res)); } + for (values) |v| _ = try self.write(T, v); return res; @@ -1237,152 +1269,148 @@ pub fn Stream( .Struct, .Union => { return try T.write(value, self); }, - .Void => {}, + .Void => return 0, else => @compileError("cannot write " ++ @typeName(T)), } } - pub fn length(self: *Self, comptime T: type, value: T) WriteError!usize { + pub fn length(self: *Self, comptime T: type, value: T) usize { + if (T == void) return 0; self.nocommit += 1; defer self.nocommit -= 1; - return try self.write(T, value); + return self.write(T, value) catch unreachable; } - pub fn arrayLength(self: *Self, prefix_len: u8, comptime T: type, values: []const T) WriteError!usize { - var res: usize = prefix_len; - for (values) |v| res += try self.length(T, v); + pub fn arrayLength(self: *Self, comptime PrefixT: type, comptime T: type, values: []const T,) usize { + var res: usize = if (PrefixT == void) 0 else @divExact(@typeInfo(PrefixT).Int.bits, 8); + for (values) |v| res += self.length(T, v); return res; } - /// Returns slice that is valid until next `readAll` call. - pub fn readAll(self: *Self, len: usize) ReadError![]const u8 { - if (len >= self.maxFragmentSize()) { - // Only workaround is to use an allocator for self.buffer - return error.TlsRecordOverflow; - } else { - if (len <= self.view.len) { - defer self.view = self.view[len..]; - return self.view[0..len]; - } else { - // We need another fragment. - // Copy last (hopefully small) portion of buffer to start. It may alias. - std.mem.copyForwards(u8, &self.buffer, self.view); - self.view = self.buffer[0..self.view.len]; - try self.readFragment(); - return try self.readAll(len); - } + /// Reads bytes from `view`, potentially reading more fragments from `stream`. + /// A return value of 0 indicates EOF. + pub fn readBytes(self: *Self, buf: []u8) ReadError!usize { + // > Any data received after a closure alert has been received MUST be ignored. + if (self.eof()) return 0; + + var bytes_read: usize = 0; + while (bytes_read != buf.len) { + if (self.view.len == 0) try self.expectFragment(self.content_type, self.handshake_type); + + const to_read = @min(buf.len, self.view.len); + @memcpy(buf[0..to_read], self.view[0..to_read]); + + self.view = self.view[to_read..]; + bytes_read += to_read; } + + return bytes_read; } - /// Read fragment from `self.stream` into `self.buffer`. - /// Checks message `content_type` matches `self.content_type`. - /// Checks message `handshake_type` matches `self.handshake_type`. - pub fn readFragment(self: *Self) ReadError!void { + /// Read fragment from `stream` into `buffer` and updates `self.view`. Returns message type. + pub fn readFragment(self: *Self) ReadError!ContentType { + std.debug.assert(self.view.len == 0); // last read should have completed var plaintext_header: [fieldsLen(Plaintext)]u8 = undefined; var n_read: usize = 0; - var ty: ContentType = .invalid; + var res: ContentType = .invalid; var len: u16 = 0; while (true) { n_read = try self.stream.readAll(&plaintext_header); - if (n_read != plaintext_header.len) return error.TlsConnectionTruncated; - self.view = &plaintext_header; - ty = try self.read(ContentType); - _ = try self.read(Version); - len = try self.read(u16); + if (n_read != plaintext_header.len) return self.writeError(.decode_error); + + // Take advantage of our `read` parsing code by setting view outside `self.buffer`. + { + self.view = &plaintext_header; + errdefer self.view = self.buffer[0..0]; + res = try self.read(ContentType); + _ = try self.read(Version); + len = try self.read(u16); + if (len > self.maxFragmentSize()) return self.writeError(.record_overflow); + } + + self.view = self.buffer[0..len]; + n_read = try self.stream.readAll(@constCast(self.view)); + if (n_read != len) return self.writeError(.decode_error); + + const encryption_method = if (res == .application_data) self.encryptionMethod() else .none; + switch (encryption_method) { + .none => {}, + inline .handshake, .application => |t| { + switch (if (comptime t == .handshake) self.handshake_cipher.? else self.application_cipher.?) { + .empty_renegotiation_info_scsv => {}, + inline else => |*p| { + const P = @TypeOf(p.*); + const tag_len = P.AEAD.tag_length; + + const ciphertext = self.view[0..self.view.len - tag_len]; + const tag = self.view[self.view.len - tag_len..][0..tag_len].*; + const out: []u8 = @constCast(self.view[0..ciphertext.len]); + try p.decrypt(ciphertext, &plaintext_header, tag, self.is_client, out); + const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); + if (padding_start) |s| { + res = @enumFromInt(self.view[s]); + self.view = self.view[0..s]; + } else { + return self.writeError(.decode_error); + } + }, + } + }, + } - switch (ty) { + switch (res) { .alert => { const level = try self.read(Alert.Level); const description = try self.read(Alert.Description); std.log.debug("TLS alert {} {}", .{ level, description }); - return error.TlsUnexpectedMessage; + if (description == .close_notify) { + self.closed = true; + return res; + } + if (level == .fatal) return self.writeError(.unexpected_message); + continue; }, - // An implementation may receive an unencrypted record of type - // change_cipher_spec consisting of the single byte value 0x01 at any - // time after the first ClientHello message has been sent or received - // and before the peer's Finished message has been received and MUST - // simply drop it without further processing. + // > An implementation may receive an unencrypted record of type + // > change_cipher_spec consisting of the single byte value 0x01 at any + // > time after the first ClientHello message has been sent or received + // > and before the peer's Finished message has been received and MUST + // > simply drop it without further processing. .change_cipher_spec => { - if (self.application_cipher != null) return error.TlsUnexpectedMessage; - var next_byte: [1]u8 = undefined; - n_read = try self.stream.readAll(&next_byte); - if (len != 1 or n_read != 1 or next_byte[0] != 1) return error.TlsIllegalParameter; - }, - else => break, - } - } - if (ty != self.content_type) return error.TlsDecodeError; - - if (len > self.maxFragmentSize()) return error.TlsRecordOverflow; - if (self.view.len > self.maxFragmentSize()) return error.TlsDecodeError; // Should have read more before calling readFragment again. - - const dest = self.buffer[self.view.len..][0..len]; - n_read = try self.stream.readAll(dest); - if (n_read != len) return error.TlsConnectionTruncated; - - self.view = self.buffer[0 .. self.view.len + len]; - - if (ty == .application_data and self.handshake_cipher != null) { - switch (self.handshake_cipher.?) { - .empty_renegotiation_info_scsv => {}, - inline else => |*p| { - const P = @TypeOf(p.*); - const tag_len = P.AEAD.tag_length; - - const ciphertext = self.view[0..self.view.len - tag_len]; - debugPrint("ciphertext", ciphertext); - const tag = self.view[self.view.len - tag_len..][0..tag_len].*; - debugPrint("tag", tag); - - try p.decrypt( - ciphertext, - &plaintext_header, - tag, - self.is_client, - self.buffer[0..ciphertext.len], - ); - self.view = self.buffer[0..ciphertext.len]; + if (!std.mem.eql(u8, self.view, &[_]u8{1})) return self.writeError(.unexpected_message); + continue; }, + else => {}, } self.transcript_hash.update(self.view); - } else { - self.transcript_hash.update(self.view); + + return res; } + } - if (self.handshake_type) |expected| { - const actual = try self.read(HandshakeType); - if (actual != expected) return error.TlsDecodeError; + pub fn expectFragment(self: *Self, expected_content: ContentType, expected_handshake: ?HandshakeType,) ReadError!void { + const actual_content = try self.readFragment(); + if (expected_content != actual_content) { + std.debug.print("expected {} got {}\n", .{ expected_content, actual_content }); + return self.writeError(.decode_error); + } + if (expected_handshake) |expected| { + const actual_handshake = try self.read(HandshakeType); + if (actual_handshake != expected) return self.writeError(.decode_error); // TODO: verify this? - const handshake_len = try self.read(u24); - std.debug.print("handshake_len {d}\n", .{ handshake_len }); + _ = try self.read(u24); } } pub fn read(self: *Self, comptime T: type) ReadError!T { + comptime std.debug.assert(@sizeOf(T) < fragment_size); switch (@typeInfo(T)) { - .Int => |info| switch (info.bits) { - 8 => { - const byte = try self.readAll(1); - return byte[0]; - }, - 16 => { - const bytes = try self.readAll(2); - const b0: u16 = bytes[0]; - const b1: u16 = bytes[1]; - return (b0 << 8) | b1; - }, - 24 => { - const bytes = try self.readAll(3); - const b0: u24 = bytes[0]; - const b1: u24 = bytes[1]; - const b2: u24 = bytes[2]; - return (b0 << 16) | (b1 << 8) | b2; - }, - else => @compileError("unsupported int type: " ++ @typeName(T)), + .Int => return self.reader().readInt(T, .big) catch |err| switch (err) { + error.EndOfStream => return self.writeError(.decode_error), + else => |e| return e, }, .Enum => |info| { if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); @@ -1390,65 +1418,135 @@ pub fn Stream( return @enumFromInt(int); }, else => { - return try T.read(self); + return T.read(self) catch |err| switch (err) { + error.EndOfStream, error.Full, error.ReadLengthInvalid => return self.writeError(.decode_error), + error.TlsUnexpectedMessage => return self.writeError(.unexpected_message), + error.TlsBadRecordMac => return self.writeError(.bad_record_mac), + error.TlsRecordOverflow => return self.writeError(.record_overflow), + error.TlsHandshakeFailure => return self.writeError(.handshake_failure), + error.TlsBadCertificate => return self.writeError(.bad_certificate), + error.TlsUnsupportedCertificate => return self.writeError(.unsupported_certificate), + error.TlsCertificateRevoked => return self.writeError(.certificate_revoked), + error.TlsCertificateExpired => return self.writeError(.certificate_expired), + error.TlsCertificateUnknown => return self.writeError(.certificate_unknown), + error.TlsIllegalParameter => return self.writeError(.illegal_parameter), + error.TlsUnknownCa => return self.writeError(.unknown_ca), + error.TlsAccessDenied => return self.writeError(.access_denied), + error.TlsDecodeError => return self.writeError(.decode_error), + error.TlsDecryptError => return self.writeError(.decrypt_error), + error.TlsProtocolVersion => return self.writeError(.protocol_version), + error.TlsInsufficientSecurity => return self.writeError(.insufficient_security), + error.TlsInternalError => return self.writeError(.internal_error), + error.TlsInappropriateFallback => return self.writeError(.inappropriate_fallback), + error.TlsMissingExtension => return self.writeError(.missing_extension), + error.TlsUnsupportedExtension => return self.writeError(.unsupported_extension), + error.TlsUnrecognizedName => return self.writeError(.unrecognized_name), + error.TlsBadCertificateStatusResponse => return self.writeError(.bad_certificate_status_response), + error.TlsUnknownPskIdentity => return self.writeError(.unknown_psk_identity), + error.TlsCertificateRequired => return self.writeError(.certificate_required), + error.TlsNoApplicationProtocol => return self.writeError(.no_application_protocol), + error.TlsUnknown => |e| { + self.close(); + return e; + }, + }; }, } } - /// Read a u8 prefixed array. Valid until next `read`. - pub fn readSmallArray(self: *Self, comptime T: type) ReadError![]align(1) const T { - if (std.math.maxInt(u8) > self.maxFragmentSize()) @panic("increase fragment_size"); - const len = try self.read(u8); - const old_view = self.view; - var bytes = try self.readAll(len); - if (@sizeOf(T) > 1) { - self.view = old_view; - for (0..len / @sizeOf(T)) |i| { - const val_bytes = @constCast(bytes[i * @sizeOf(T) ..][0..@sizeOf(T)]); - var val = try self.read(T); - @memcpy(val_bytes, std.mem.asBytes(&val)); - } - } - return std.mem.bytesAsSlice(T, bytes); - } - fn Iterator(comptime T: type) type { return struct { stream: *Self, - expected_len: usize, - start: usize, - - pub fn next(self: *@This()) !?T { - const cur = @intFromPtr(self.stream.view.ptr) - @intFromPtr(&self.stream.buffer); - const len = cur - self.start; - if (len > self.expected_len) return error.TlsUnexpectedMessage; // overread - if (len == self.expected_len) return null; + end: usize, + pub fn next(self: *@This()) ReadError!?T { + const cur_offset = self.stream.buffer.len - self.stream.view.len; + if (cur_offset > self.end) return null; return try self.stream.read(T); } }; } - pub fn iterator(self: *Self, comptime Tag: type) !Iterator(Tag) { - const expected_len = try self.read(u16); - const start = @intFromPtr(self.view.ptr) - @intFromPtr(&self.buffer); + pub fn iterator(self: *Self, comptime Len: type, comptime Tag: type) ReadError!Iterator(Tag) { + const offset = self.buffer.len - self.view.len; + const len = try self.read(Len); return Iterator(Tag){ .stream = self, - .expected_len = expected_len, - .start = start, + .end = offset + len, }; } - pub fn extensions(self: *Self) !Iterator(Extension.Header) { - return self.iterator(Extension.Header); + pub fn extensions(self: *Self) ReadError!Iterator(Extension.Header) { + return self.iterator(u16, Extension.Header); } pub fn eof(self: Self) bool { return self.closed and self.view.len == 0; } + + pub const Reader = std.io.Reader(*Self, ReadError, readBytes); + pub const Writer = std.io.Writer(*Self, WriteError, writeBytes); + + pub fn reader(self: *Self) Reader { + return .{ .context = self }; + } + + pub fn writer(self: *Self) Writer { + return .{ .context = self }; + } }; } +/// One of these potential hashes will be selected after receiving the other party's hello. +/// +/// We init them before sending any messages to avoid having to store our first message until the +/// other party's handshake message returns. This message is usually larger than +/// `@sizeOf(MultiHash)` = 560 +/// +/// A nice benefit is decreased latency on hosts where one round trip takes longer than calling +/// `update` on each hashes. +pub const MultiHash = struct { + sha256: sha2.Sha256 = sha2.Sha256.init(.{}), + sha384: sha2.Sha384 = sha2.Sha384.init(.{}), + sha512: sha2.Sha512 = sha2.Sha512.init(.{}), + /// Chosen during handshake. + active: enum { all, sha256, sha384, sha512 } = .all, + + const sha2 = crypto.hash.sha2; + const Self = @This(); + + pub fn update(self: *Self, bytes: []const u8) void { + switch (self.active) { + .all => { + self.sha256.update(bytes); + self.sha384.update(bytes); + self.sha512.update(bytes); + }, + .sha256 => self.sha256.update(bytes), + .sha384 => self.sha384.update(bytes), + .sha512 => self.sha512.update(bytes), + } + } + + pub fn setActive(self: *Self, cipher_suite: CipherSuite) Error!void { + self.active = switch (cipher_suite) { + .aes_128_gcm_sha256, .chacha20_poly1305_sha256, .aegis_128l_sha256 => .sha256, + .aes_256_gcm_sha384 => .sha384, + .aegis_256_sha512 => .sha512, + else => return Error.TlsIllegalParameter, + }; + } + + pub inline fn peek(self: Self) []const u8 { + return &switch (self.active) { + .all => [_]u8{}, + .sha256 => self.sha256.peek(), + .sha384 => self.sha384.peek(), + .sha512 => self.sha512.peek(), + }; + } +}; + const Encoder = struct { fn RetType(comptime T: type) type { switch (@typeInfo(T)) { @@ -1513,20 +1611,17 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { handshake_secret: [Hkdf.prk_length]u8, master_secret: [Hkdf.prk_length]u8, - client_handshake_key: [AEAD.key_length]u8, - server_handshake_key: [AEAD.key_length]u8, client_finished_key: [Hmac.key_length]u8, server_finished_key: [Hmac.key_length]u8, - client_handshake_iv: [AEAD.nonce_length]u8, - server_handshake_iv: [AEAD.nonce_length]u8, - seq: usize = 0, + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + read_seq: usize = 0, + write_seq: usize = 0, const Self = @This(); - pub fn nonce(self: Self) [AEAD.nonce_length]u8 { - return nonce_for_len(AEAD.nonce_length, self.server_handshake_iv, self.seq); - } - fn encrypt( self: *Self, data: []const u8, @@ -1535,9 +1630,11 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { out: []u8, ) [AEAD.tag_length]u8 { var res: [AEAD.tag_length]u8 = undefined; - const key = if (is_client) self.client_handshake_key else self.server_handshake_key; - AEAD.encrypt(out, &res, data, additional, self.nonce(), key); - self.seq += 1; + const key = if (is_client) self.client_key else self.server_key; + const iv = if (is_client) self.client_iv else self.server_iv; + const nonce = nonce_for_len(AEAD.nonce_length, iv, self.write_seq); + AEAD.encrypt(out, &res, data, additional, nonce, key); + self.write_seq += 1; return res; } @@ -1548,10 +1645,12 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { tag: [AEAD.tag_length]u8, is_client: bool, out: []u8, - ) !void { - const key = if (is_client) self.server_handshake_key else self.client_handshake_key; - AEAD.decrypt(out, data, tag, additional, self.nonce(), key) catch return error.TlsBadRecordMac; - self.seq += 1; + ) Error!void { + const key = if (is_client) self.server_key else self.client_key; + const iv = if (is_client) self.server_iv else self.client_iv; + const nonce = nonce_for_len(AEAD.nonce_length, iv, self.read_seq); + AEAD.decrypt(out, data, tag, additional, nonce, key) catch return Error.TlsBadRecordMac; + self.read_seq += 1; } pub fn print(self: Self) void { @@ -1573,18 +1672,11 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { server_key: [AEAD.key_length]u8, client_iv: [AEAD.nonce_length]u8, server_iv: [AEAD.nonce_length]u8, - seq: usize = 0, + read_seq: usize = 0, + write_seq: usize = 0, const Self = @This(); - pub fn client_nonce(self: Self) [AEAD.nonce_length]u8 { - return nonce_for_len(AEAD.nonce_length, self.client_iv, self.seq); - } - - pub fn server_nonce(self: Self) [AEAD.nonce_length]u8 { - return nonce_for_len(AEAD.nonce_length, self.server_iv, self.seq); - } - fn encrypt( self: *Self, data: []const u8, @@ -1593,12 +1685,32 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { out: []u8, ) [AEAD.tag_length]u8 { var res: [AEAD.tag_length]u8 = undefined; - const nonce = if (is_client) self.client_nonce() else self.server_nonce(); const key = if (is_client) self.client_key else self.server_key; + const iv = if (is_client) self.client_iv else self.server_iv; + const nonce = nonce_for_len(AEAD.nonce_length, iv, self.write_seq); AEAD.encrypt(out, &res, data, additional, nonce, key); - self.seq += 1; + self.write_seq += 1; return res; } + + fn decrypt( + self: *Self, + data: []const u8, + additional: []const u8, + tag: [AEAD.tag_length]u8, + is_client: bool, + out: []u8, + ) Error!void { + const key = if (is_client) self.server_key else self.client_key; + const iv = if (is_client) self.server_iv else self.client_iv; + const nonce = nonce_for_len(AEAD.nonce_length, iv, self.read_seq); + AEAD.decrypt(out, data, tag, additional, nonce, key) catch return Error.TlsBadRecordMac; + self.read_seq += 1; + } + + pub fn print(self: Self) void { + inline for (std.meta.fields(Self)) |f| debugPrint(f.name, @field(self, f.name)); + } }; } @@ -1765,9 +1877,8 @@ test "tls client and server handshake, data, and close_notify" { const host = "example.ulfheim.net"; var client = Client(@TypeOf(inner_stream)){ - .stream = Stream(Plaintext.max_length, TestStream, client_mod.MultiHash){ + .stream = Stream(Plaintext.max_length, TestStream){ .stream = &inner_stream, - .transcript_hash = .{}, .is_client = true, }, .options = .{ .host = host, .ca_bundle = null }, @@ -1775,9 +1886,8 @@ test "tls client and server handshake, data, and close_notify" { const server_der = @embedFile("./testdata/server.der"); var server = Server(@TypeOf(inner_stream)){ - .stream = Stream(Plaintext.max_length, TestStream, server_mod.TranscriptHash){ + .stream = Stream(Plaintext.max_length, TestStream){ .stream = &inner_stream, - .transcript_hash = server_mod.TranscriptHash.init(.{}), .is_client = false, }, .options = .{ @@ -1817,6 +1927,7 @@ test "tls client and server handshake, data, and close_notify" { session_id, client_x25519_seed ++ client_x25519_seed, client_x25519_seed, + client_x25519_seed ++ [_]u8{0} ** (48-32), client_x25519_seed, ); { @@ -2034,7 +2145,7 @@ test "tls client and server handshake, data, and close_notify" { [_]u8{ 0x14, // ChangeCipherSpec 0x03, 0x03, // tls 1.2 - 0x00, 0x01, // Handshake len + 0x00, 0x01, // len 0x01, // .change_cipher_spec } ++ [_]u8{ 0x17, // application data (lie for tls 1.2 compat) @@ -2145,21 +2256,108 @@ test "tls client and server handshake, data, and close_notify" { try client.recv_hello(key_pairs); - // Test that all 8 shared keys are identical - // If one of these isn't true, check that the earlier transcript_hashes match. + // Test that ALL shared keys are identical { const s = server.stream.handshake_cipher.?.aes_256_gcm_sha384; const c = client.stream.handshake_cipher.?.aes_256_gcm_sha384; try std.testing.expectEqualSlices(u8, &s.handshake_secret, &c.handshake_secret); try std.testing.expectEqualSlices(u8, &s.master_secret, &c.master_secret); - try std.testing.expectEqualSlices(u8, &s.server_handshake_key, &c.server_handshake_key); - try std.testing.expectEqualSlices(u8, &s.client_handshake_key, &c.client_handshake_key); try std.testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); try std.testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); - try std.testing.expectEqualSlices(u8, &s.server_handshake_iv, &c.server_handshake_iv); - try std.testing.expectEqualSlices(u8, &s.client_handshake_iv, &c.client_handshake_iv); + try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); + try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + const client_iv = [_]u8{ 0x42,0x56,0xd2,0xe0,0xe8,0x8b,0xab,0xdd,0x05,0xeb,0x2f,0x27 }; + try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); + } + { + const s = server.stream.application_cipher.?.aes_256_gcm_sha384; + const c = client.stream.application_cipher.?.aes_256_gcm_sha384; + + try std.testing.expectEqualSlices(u8, &s.client_secret, &c.client_secret); + try std.testing.expectEqualSlices(u8, &s.server_secret, &c.server_secret); + try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); + const client_iv = [_]u8{ 0xbb,0x00,0x79,0x56,0xf4,0x74,0xb2,0x5d,0xe9,0x02,0x43,0x2f, }; + try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); + } + + try client.send_finished(); + { + const buf = tmp_buf[0..inner_stream.buffer.len()]; + try inner_stream.peek(buf); + + const expected = [_]u8{ + 0x14, // ChangeCipherSpec + 0x03, 0x03, // tls 1.2 + 0x00, 0x01, // len + 0x01, // .change_cipher_spec + } + ++ [_]u8{ + 0x17, // app data (lie for TLS 1.2) + 0x03, 0x03, // tls 1.2 + 0x00, 0x45, // len + 0x9f, 0xf9, 0xb0, 0x63, 0x17, 0x51, 0x77, 0x32, 0x2a, 0x46, 0xdd, 0x98, 0x96, 0xf3, 0xc3, 0xbb, + 0x82, 0x0a, 0xb5, 0x17, 0x43, 0xeb, 0xc2, 0x5f, 0xda, 0xdd, 0x53, 0x45, 0x4b, 0x73, 0xde, 0xb5, + 0x4c, 0xc7, 0x24, 0x8d, 0x41, 0x1a, 0x18, 0xbc, 0xcf, 0x65, 0x7a, 0x96, 0x08, 0x24, 0xe9, 0xa1, + 0x93, 0x64, 0x83, 0x7c, // encrypted data + 0x35, // handshake + 0x0a, 0x69, 0xa8, 0x8d, 0x4b, 0xf6, 0x35, 0xc8, // auth tag + 0x5e, 0xb8, 0x74, 0xae, 0xbc, 0x9d, 0xfd, 0xe8, // auth tag + } + ; + try std.testing.expectEqualSlices(u8, &expected, buf); + } + try server.recv_finish(); + + _ = try client.stream.writer().writeAll("ping"); + try client.stream.flush(); + { + const buf = tmp_buf[0..inner_stream.buffer.len()]; + try inner_stream.peek(buf); + + const expected = [_]u8{ + 0x17, // app data (FOR REAL THIS TIME) + 0x03, 0x03, // tls 1.2 + 0x00, 0x15, // len + 0x82, 0x81, 0x39, 0xcb, // ping + 0x7b, // app data (exciting!) + 0x73, 0xaa, 0xab, 0xf5, 0xb8, 0x2f, 0xbf, 0x9a, // auth tag + 0x29, 0x61, 0xbc, 0xde, 0x10, 0x03, 0x8a, 0x32, // auth tag + } + ; + try std.testing.expectEqualSlices(u8, &expected, buf); + } + + var recv_ping: [4]u8 = undefined; + _ = try server.stream.reader().readAll(&recv_ping); + try std.testing.expectEqualStrings("ping", &recv_ping); + + server.stream.close(); + try std.testing.expect(server.stream.closed); + { + const buf = tmp_buf[0..inner_stream.buffer.len()]; + try inner_stream.peek(buf); + + const expected = [_]u8{ + 0x17, // app data (lie to encrypt) + 0x03, 0x03, // tls 1.2 + 0x00, 0x13, // len + 0x3e, 0x2d, // alert + 0x99, // encrypted message type + 0x26, 0xbb, 0xfe, 0x1f, 0x46, 0xfb, 0x4e, 0xe2, // auth tag + 0x75, 0x1e, 0x53, 0xbf, 0xfc, 0x7e, 0x65, 0x16, // auth tag + } + ; + try std.testing.expectEqualSlices(u8, &expected, buf); } + + _ = try client.stream.readFragment(); + try std.testing.expect(client.stream.closed); } test { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 3d43657cbf1e..c215011a75b6 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -6,21 +6,18 @@ const crypto = std.crypto; const assert = std.debug.assert; const Certificate = std.crypto.Certificate; -pub const TranscriptHash = MultiHash; - /// `StreamType` must conform to `tls.StreamInterface`. pub fn Client(comptime StreamType: type) type { return struct { - stream: tls.Stream(tls.Plaintext.max_length, StreamType, TranscriptHash), + stream: tls.Stream(tls.Plaintext.max_length, StreamType), options: Options, const Self = @This(); /// Initiates a TLS handshake and establishes a TLSv1.3 session pub fn init(stream: *StreamType, options: Options) !Self { - var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType, TranscriptHash){ + var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType){ .stream = stream, - .transcript_hash = .{}, .is_client = true, }; var res = Self{ .stream = stream_, .options = options }; @@ -32,394 +29,6 @@ pub fn Client(comptime StreamType: type) type { _ = &stream_; return res; - - // var cert_index: usize = 0; - // var read_seq: u64 = 0; - // var prev_cert: Certificate.Parsed = undefined; - // // Set to true once a trust chain has been established from the first - // // certificate to a root CA. - // const HandshakeState = enum { - // /// In this state we expect an encrypted_extensions message. - // encrypted_extensions, - // /// In this state we expect certificate messages. - // certificate, - // /// In this state we expect certificate or certificate_verify messages. - // /// Certificate messages are ignored since the trust chain is already established. - // trust_chain_established, - // /// In this state, we expect only the finished message. - // finished, - // }; - // var handshake_state: HandshakeState = .encrypted_extensions; - // var cleartext_bufs: [2][tls.max_cipertext_inner_record_len]u8 = undefined; - // var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; - // var main_cert_pub_key_buf: [600]u8 = undefined; - // var main_cert_pub_key_len: u16 = undefined; - // const now_sec = std.time.timestamp(); - - // const cleartext_buf = &cleartext_bufs[cert_index % 2]; - // const cleartext = try handshake_cipher.cleartext(record, read_seq, cleartext_buf); - - // const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - // if (inner_ct != .handshake) return error.TlsUnexpectedMessage; - - // var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); - // while (true) { - // try ctd.ensure(4); - // const handshake_type = ctd.decode(tls.HandshakeType); - // const handshake_len = ctd.decode(u24); - // var hsd = try ctd.sub(handshake_len); - // const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; - // const handshake_buf = ctd.buf[ctd.idx - handshake_len .. ctd.idx]; - // switch (handshake_type) { - // .encrypted_extensions => { - // if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - // handshake_state = .certificate; - // switch (handshake_cipher) { - // inline else => |*p| p.transcript_hash.update(wrapped_handshake), - // } - // try hsd.ensure(2); - // const total_ext_size = hsd.decode(u16); - // var all_extd = try hsd.sub(total_ext_size); - // while (!all_extd.eof()) { - // try all_extd.ensure(4); - // const et = all_extd.decode(tls.ExtensionType); - // const ext_size = all_extd.decode(u16); - // const extd = try all_extd.sub(ext_size); - // _ = extd; - // switch (et) { - // .server_name => {}, - // else => {}, - // } - // } - // }, - // .certificate => cert: { - // switch (handshake_cipher) { - // inline else => |*p| p.transcript_hash.update(wrapped_handshake), - // } - // switch (handshake_state) { - // .certificate => {}, - // .trust_chain_established => break :cert, - // else => return error.TlsUnexpectedMessage, - // } - // try hsd.ensure(1 + 4); - // const cert_req_ctx_len = hsd.decode(u8); - // if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; - // const certs_size = hsd.decode(u24); - // var certs_decoder = try hsd.sub(certs_size); - // while (!certs_decoder.eof()) { - // try certs_decoder.ensure(3); - // const cert_size = certs_decoder.decode(u24); - // const certd = try certs_decoder.sub(cert_size); - - // const subject_cert: Certificate = .{ - // .buffer = certd.buf, - // .index = @intCast(certd.idx), - // }; - // const subject = try subject_cert.parse(); - // if (cert_index == 0) { - // // Verify the host on the first certificate. - // try subject.verifyHostName(options.host); - - // // Keep track of the public key for the - // // certificate_verify message later. - // main_cert_pub_key_algo = subject.pub_key_algo; - // const pub_key = subject.pubKey(); - // if (pub_key.len > main_cert_pub_key_buf.len) - // return error.CertificatePublicKeyInvalid; - // @memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key); - // main_cert_pub_key_len = @intCast(pub_key.len); - // } else { - // try prev_cert.verify(subject, now_sec); - // } - - // if (options.ca_bundle.verify(subject, now_sec)) |_| { - // handshake_state = .trust_chain_established; - // break :cert; - // } else |err| switch (err) { - // error.CertificateIssuerNotFound => {}, - // else => |e| return e, - // } - - // prev_cert = subject; - // cert_index += 1; - - // try certs_decoder.ensure(2); - // const total_ext_size = certs_decoder.decode(u16); - // const all_extd = try certs_decoder.sub(total_ext_size); - // _ = all_extd; - // } - // }, - // .certificate_verify => { - // switch (handshake_state) { - // .trust_chain_established => handshake_state = .finished, - // .certificate => return error.TlsCertificateNotVerified, - // else => return error.TlsUnexpectedMessage, - // } - - // try hsd.ensure(4); - // const scheme = hsd.decode(tls.SignatureScheme); - // const sig_len = hsd.decode(u16); - // try hsd.ensure(sig_len); - // const encoded_sig = hsd.slice(sig_len); - // const max_digest_len = 64; - // var verify_buffer: [64 + 34 + max_digest_len]u8 = - // ([1]u8{0x20} ** 64) ++ - // "TLS 1.3, server CertificateVerify\x00".* ++ - // @as([max_digest_len]u8, undefined); - - // const verify_bytes = switch (handshake_cipher) { - // inline else => |*p| v: { - // const transcript_digest = p.transcript_hash.peek(); - // verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; - // p.transcript_hash.update(wrapped_handshake); - // break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; - // }, - // }; - // const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - - // switch (scheme) { - // inline .ecdsa_secp256r1_sha256, - // .ecdsa_secp384r1_sha384, - // => |comptime_scheme| { - // if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - // return error.TlsBadSignatureScheme; - // const Ecdsa = SchemeEcdsa(comptime_scheme); - // const sig = try Ecdsa.Signature.fromDer(encoded_sig); - // const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); - // try sig.verify(verify_bytes, key); - // }, - // inline .rsa_pss_rsae_sha256, - // .rsa_pss_rsae_sha384, - // .rsa_pss_rsae_sha512, - // => |comptime_scheme| { - // if (main_cert_pub_key_algo != .rsaEncryption) - // return error.TlsBadSignatureScheme; - - // const Hash = SchemeHash(comptime_scheme); - // const rsa = Certificate.rsa; - // const components = try rsa.PublicKey.parseDer(main_cert_pub_key); - // const exponent = components.exponent; - // const modulus = components.modulus; - // switch (modulus.len) { - // inline 128, 256, 512 => |modulus_len| { - // const key = try rsa.PublicKey.fromBytes(exponent, modulus); - // const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - // try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); - // }, - // else => { - // return error.TlsBadRsaSignatureBitCount; - // }, - // } - // }, - // inline .ed25519 => |comptime_scheme| { - // if (main_cert_pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; - // const Eddsa = SchemeEddsa(comptime_scheme); - // if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; - // const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); - // if (main_cert_pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; - // const key = try Eddsa.PublicKey.fromBytes(main_cert_pub_key[0..Eddsa.PublicKey.encoded_length].*); - // try sig.verify(verify_bytes, key); - // }, - // else => { - // return error.TlsBadSignatureScheme; - // }, - // } - // }, - // .finished => { - // if (handshake_state != .finished) return error.TlsUnexpectedMessage; - // // This message is to trick buggy proxies into behaving correctly. - // const client_change_cipher_spec_msg = [_]u8{ - // @intFromEnum(tls.ContentType.change_cipher_spec), - // 0x03, 0x03, // legacy protocol version - // 0x00, 0x01, // length - // 0x01, - // }; - // const app_cipher = switch (handshake_cipher) { - // inline else => |*p, tag| c: { - // const P = @TypeOf(p.*); - // const finished_digest = p.transcript_hash.peek(); - // p.transcript_hash.update(wrapped_handshake); - // const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); - // if (!mem.eql(u8, &expected_server_verify_data, handshake_buf)) - // return error.TlsDecryptError; - // const handshake_hash = p.transcript_hash.finalResult(); - // const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); - // const out_cleartext = [_]u8{ - // @intFromEnum(tls.HandshakeType.finished), - // 0, 0, verify_data.len, // length - // } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)}; - - // const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - - // var finished_msg = [_]u8{ - // @intFromEnum(tls.ContentType.application_data), - // 0x03, 0x03, // legacy protocol version - // 0, wrapped_len, // byte length of encrypted record - // } ++ @as([wrapped_len]u8, undefined); - - // const ad = finished_msg[0..5]; - // const ciphertext = finished_msg[5..][0..out_cleartext.len]; - // const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - // const nonce = p.client_handshake_iv; - // P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - - // const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - // var both_msgs_vec = [_]std.os.iovec_const{.{ - // .iov_base = &both_msgs, - // .iov_len = both_msgs.len, - // }}; - // try stream.writevAll(&both_msgs_vec); - - // const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - // const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - // break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ - // .client_secret = client_secret, - // .server_secret = server_secret, - // .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), - // .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), - // .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), - // .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - // }); - // }, - // }; - // const leftover = decoder.rest(); - // var client: Client = .{ - // .read_seq = 0, - // .write_seq = 0, - // .partial_cleartext_idx = 0, - // .partial_ciphertext_idx = 0, - // .partial_ciphertext_end = @intCast(leftover.len), - // .received_close_notify = false, - // .application_cipher = app_cipher, - // .partially_read_buffer = undefined, - // }; - // @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - // return client; - // }, - // else => { - // return error.TlsUnexpectedMessage; - // }, - // } - // if (ctd.eof()) break; - // } - } - - /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. - /// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. - pub fn write(self: *Self, bytes: []const u8) !usize { - return self.writeEnd(bytes, false); - } - - /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. - pub fn writeAll(self: *Self, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.write(bytes[index..]); - } - } - - /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. - /// If `end` is true, then this function additionally sends a `close_notify` alert, - /// which is necessary for the server to distinguish between a properly finished - /// TLS session, or a truncation attack. - pub fn writeAllEnd(self: *Self, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.writeEnd(bytes[index..], end); - } - } - - /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. - /// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. - /// If `end` is true, then this function additionally sends a `close_notify` alert, - /// which is necessary for the server to distinguish between a properly finished - /// TLS session, or a truncation attack. - pub fn writeEnd(self: *Self, bytes: []const u8, end: bool) !usize { - try self.stream.writeAll(bytes); - if (end) { - const alert = tls.Alert{ - .level = .fatal, - .description = .close_notify, - }; - try self.stream.write(tls.Alert, alert); - try self.stream.flush(); - } - return bytes.len; - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns the number of bytes read, calling the underlying read function the - /// minimal number of times until the buffer has at least `len` bytes filled. - /// If the number read is less than `len` it means the stream reached the end. - /// Reaching the end of the stream is not an error condition. - pub fn readAtLeast(self: *Self, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; - return self.readvAtLeast(&iovecs, len); - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - pub fn read(self: *Self, buffer: []u8) !usize { - return self.readAtLeast(buffer, 1); - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. Reaching the end of the - /// stream is not an error condition. - pub fn readAll(self: *Self, buffer: []u8) !usize { - return self.readAtLeast(buffer, buffer.len); - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns the number of bytes read. If the number read is less than the space - /// provided it means the stream reached the end. Reaching the end of the - /// stream is not an error condition. - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial reads from the underlying stream layer. - pub fn readv(self: *Self, iovecs: []std.os.iovec) !usize { - return self.readvAtLeast(iovecs, 1); - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns the number of bytes read, calling the underlying read function the - /// minimal number of times until the iovecs have at least `len` bytes filled. - /// If the number read is less than `len` it means the stream reached the end. - /// Reaching the end of the stream is not an error condition. - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial reads from the underlying stream layer. - pub fn readvAtLeast(self: *Self, iovecs: []std.os.iovec, len: usize) !usize { - if (self.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try self.readvAdvanced(iovecs[vec_i..]); - off_i += amt; - if (self.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].iov_len) { - amt -= iovecs[vec_i].iov_len; - vec_i += 1; - } - iovecs[vec_i].iov_base += amt; - iovecs[vec_i].iov_len -= amt; - } - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns number of bytes that have been read, populated inside `iovecs`. A - /// return value of zero bytes does not mean end of stream. Instead, check the `eof()` - /// for the end of stream. The `eof()` may be true after any call to - /// `read`, including when greater than zero bytes are returned, and this - /// function asserts that `eof()` is `false`. - /// See `readv` for a higher level function that has the same, familiar API as - /// other read functions, such as `std.fs.File.read`. - pub fn readvAdvanced(self: *Self, iovecs: []const std.os.iovec) !usize { - _ = .{ self, iovecs }; - return 0; - } - - pub fn eof(self: *Self) bool { - return self.stream.eof(); } pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { @@ -449,15 +58,25 @@ pub fn Client(comptime StreamType: type) type { } pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { - self.stream.handshake_type = .server_hello; - try self.stream.readFragment(); + try self.stream.expectFragment(.handshake, .server_hello); + var reader = self.stream.reader(); // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. _ = try self.stream.read(tls.Version); - const random = try self.stream.readAll(32); - if (mem.eql(u8, random, &tls.ServerHello.hello_retry_request)) return error.TlsUnexpectedMessage; // `ClientHello` failed and we don't know how to rephrase it. - const legacy_session_id = try self.stream.readSmallArray(u8); - if (!mem.eql(u8, legacy_session_id, &key_pairs.session_id)) return error.TlsIllegalParameter; + var random: [32]u8 = undefined; + try reader.readNoEof(&random); + if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { + // We already offered all our supported options and we aren't changing them. + return error.TlsUnexpectedMessage; + } + + var session_id_buf: [tls.ClientHello.session_id_max_len]u8 = undefined; + const session_id_len = try self.stream.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) return error.TlsUnexpectedMessage; + const session_id: []u8 = session_id_buf[0..session_id_len]; + try reader.readNoEof(session_id); + if (!mem.eql(u8, session_id, &key_pairs.session_id)) return error.TlsIllegalParameter; + const cipher_suite = try self.stream.read(tls.CipherSuite); const compression_method = try self.stream.read(u8); if (compression_method != 0) return error.TlsIllegalParameter; @@ -478,37 +97,46 @@ pub fn Client(comptime StreamType: type) type { const key_size = try self.stream.read(u16); switch (named_group) { .x25519_kyber768d00 => { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + crypto.kem.kyber_d00.Kyber768.ciphertext_length; - if (key_size != hksl) + const T = tls.NamedGroupT(.x25519_kyber768d00); + const x25519_len = T.X25519.public_length; + const expected_len = x25519_len + T.Kyber768.ciphertext_length; + if (key_size != expected_len) return error.TlsIllegalParameter; - const server_ks = try self.stream.readAll(hksl); + var server_ks: [expected_len]u8 = undefined; + try reader.readNoEof(&server_ks); - shared_key = &((crypto.dh.X25519.scalarmult( + shared_key = &((T.X25519.scalarmult( key_pairs.x25519.secret_key, - server_ks[0..xksl].*, + server_ks[0..x25519_len].*, ) catch return error.TlsDecryptFailure) ++ (key_pairs.kyber768d00.secret_key.decaps( - server_ks[xksl..hksl], + server_ks[x25519_len..expected_len], ) catch return error.TlsDecryptFailure)); }, .x25519 => { - const ksl = crypto.dh.X25519.public_length; - if (key_size != ksl) return error.TlsIllegalParameter; - const server_pub_key = try self.stream.readAll(ksl); + const T = tls.NamedGroupT(.x25519); + const expected_len = T.public_length; + if (key_size != expected_len) return error.TlsIllegalParameter; + var server_ks: [expected_len]u8 = undefined; + try reader.readNoEof(&server_ks); shared_key = &(crypto.dh.X25519.scalarmult( key_pairs.x25519.secret_key, - server_pub_key[0..ksl].*, + server_ks[0..expected_len].*, ) catch return error.TlsDecryptFailure); }, - .secp256r1 => { - const server_pub_key = try self.stream.readAll(key_size); + inline .secp256r1, .secp384r1 => |t| { + const T = tls.NamedGroupT(t); + const expected_len = T.PublicKey.compressed_sec1_encoded_length; + if (key_size != expected_len) return error.TlsIllegalParameter; + + var server_ks: [expected_len]u8 = undefined; + try reader.readNoEof(&server_ks); - const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; - const pk = PublicKey.fromSec1(server_pub_key) catch { + const pk = T.PublicKey.fromSec1(&server_ks) catch { return error.TlsDecryptFailure; }; - const mul = pk.p.mulPublic(key_pairs.secp256r1.secret_key.bytes, .big) catch { + const key_pair = @field(key_pairs, @tagName(t)); + const mul = pk.p.mulPublic(key_pair.secret_key.bytes, .big) catch { return error.TlsDecryptFailure; }; shared_key = &mul.affineCoordinates().x.toBytes(.big); @@ -519,7 +147,7 @@ pub fn Client(comptime StreamType: type) type { } }, else => { - _ = try self.stream.readAll(ext.len); + try reader.skipBytes(ext.len, .{}); }, } } @@ -527,28 +155,92 @@ pub fn Client(comptime StreamType: type) type { if (supported_version != tls.Version.tls_1_3) return error.TlsIllegalParameter; if (shared_key == null) return error.TlsIllegalParameter; - self.stream.transcript_hash.active = switch (cipher_suite) { - .aes_128_gcm_sha256, .chacha20_poly1305_sha256, .aegis_128l_sha256 => .sha256, - .aes_256_gcm_sha384 => .sha384, - .aegis_256_sha512 => .sha512, - else => return error.TlsIllegalParameter, - }; + try self.stream.transcript_hash.setActive(cipher_suite); const hello_hash = self.stream.transcript_hash.peek(); - self.stream.handshake_cipher = tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash); - self.stream.content_type = .application_data; - self.stream.handshake_cipher.?.print(); + self.stream.handshake_cipher = try tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash); - self.stream.handshake_type = .encrypted_extensions; - try self.stream.readFragment(); - iter = try self.stream.extensions(); - while (try iter.next()) |ext| { - _ = try self.stream.readAll(ext.len); + { + try self.stream.expectFragment(.handshake, .encrypted_extensions); + iter = try self.stream.extensions(); + while (try iter.next()) |ext| { + try reader.skipBytes(ext.len, .{}); + } } // CertificateRequest* // Certificate* // CertificateVerify* - // Finished + { + try self.stream.expectFragment(.handshake, .certificate); + + var context: [tls.Certificate.max_context_len]u8 = undefined; + const context_len = try self.stream.read(u8); + try reader.readNoEof(context[0..context_len]); + + var certs_iter = try self.stream.iterator(u24, u24); + while (try certs_iter.next()) |cert_len| { + try reader.skipBytes(cert_len, .{}); + var ext_iter = try self.stream.extensions(); + while (try ext_iter.next()) |ext| { + switch (ext.type) { + else => { + try reader.skipBytes(ext.len, .{}); + }, + } + } + } + } + + { + try self.stream.expectFragment(.handshake, .certificate_verify); + + const scheme = try self.stream.read(tls.SignatureScheme); + const len = try self.stream.read(u16); + try reader.skipBytes(len, .{}); + + // TODO: verify + _ = .{ scheme }; + } + + { + try self.stream.expectFragment(.handshake, .finished); + + var verify_data: [48]u8 = undefined; + try reader.readNoEof(&verify_data); + + // TODO: verify + _ = .{ verify_data }; + } + + self.stream.application_cipher = tls.ApplicationCipher.init( + self.stream.handshake_cipher.?, + self.stream.transcript_hash.peek(), + ); + } + + pub fn send_finished(self: *Self) !void { + self.stream.version = .tls_1_2; + self.stream.content_type = .change_cipher_spec; + _ = try self.stream.write(tls.ChangeCipherSpec, .change_cipher_spec); + try self.stream.flush(); + + const verify_data = switch (self.stream.handshake_cipher.?) { + inline + .aes_256_gcm_sha384, + => |v| brk: { + const T = @TypeOf(v); + const secret = v.client_finished_key; + const transcript_hash = self.stream.transcript_hash.peek(); + + break :brk tls.hmac(T.Hmac, transcript_hash, secret); + }, + else => return error.TlsDecryptFailure, + }; + self.stream.content_type = .handshake; + _ = try self.stream.write(tls.Handshake, .{ .finished = &verify_data }); + try self.stream.flush(); + + self.stream.content_type = .application_data; } }; } @@ -570,43 +262,6 @@ pub const Options = struct { allow_truncation_attacks: bool = false, }; -/// One of these potential hashes will be selected during the handshake as the transcript hash. -/// We init them before sending a single message to avoid having to store the `ClientHello` until -/// receiving `ServerHello`. -/// A nice benefit is decreased latency on hosts where one round trip takes longer than calling -/// `update` on each hashes. -pub const MultiHash = struct { - sha256: sha2.Sha256 = sha2.Sha256.init(.{}), - sha384: sha2.Sha384 = sha2.Sha384.init(.{}), - sha512: sha2.Sha512 = sha2.Sha512.init(.{}), - /// Chosen during handshake. - active: enum { all, sha256, sha384, sha512 } = .all, - - const sha2 = crypto.hash.sha2; - const Self = @This(); - - pub fn update(self: *Self, bytes: []const u8) void { - switch (self.active) { - .all => { - self.sha256.update(bytes); - self.sha384.update(bytes); - self.sha512.update(bytes); - }, - .sha256 => self.sha256.update(bytes), - .sha384 => self.sha384.update(bytes), - .sha512 => self.sha512.update(bytes), - } - } - - pub fn peek(self: Self) []const u8 { - return &switch (self.active) { - .all => [_]u8{}, - .sha256 => self.sha256.peek(), - .sha384 => self.sha384.peek(), - .sha512 => self.sha512.peek(), - }; - } -}; /// One of these potential key pairs will be selected during the handshake. pub const KeyPairs = struct { @@ -614,6 +269,7 @@ pub const KeyPairs = struct { session_id: [session_id_length]u8, kyber768d00: Kyber768, secp256r1: Secp256r1, + secp384r1: Secp384r1, x25519: X25519, const Self = @This(); @@ -622,6 +278,7 @@ pub const KeyPairs = struct { const session_id_length = 32; const X25519 = tls.NamedGroupT(.x25519).KeyPair; const Secp256r1 = tls.NamedGroupT(.secp256r1).KeyPair; + const Secp384r1 = tls.NamedGroupT(.secp384r1).KeyPair; const Kyber768 = tls.NamedGroupT(.x25519_kyber768d00).Kyber768.KeyPair; pub fn init() Self { @@ -630,6 +287,7 @@ pub const KeyPairs = struct { session_id_length + Kyber768.seed_length + Secp256r1.seed_length + + Secp384r1.seed_length + X25519.seed_length ]u8 = undefined; @@ -640,13 +298,15 @@ pub const KeyPairs = struct { const split2 = split1 + session_id_length; const split3 = split2 + Kyber768.seed_length; const split4 = split3 + Secp256r1.seed_length; + const split5 = split3 + Secp384r1.seed_length; return initAdvanced( random_buffer[0..split1].*, random_buffer[split1..split2].*, random_buffer[split2..split3].*, random_buffer[split3..split4].*, - random_buffer[split4..].*, + random_buffer[split4..split5].*, + random_buffer[split5..].*, ) catch continue; } } @@ -656,6 +316,7 @@ pub const KeyPairs = struct { session_id: [session_id_length]u8, kyber_768_seed: [Kyber768.seed_length]u8, secp256r1_seed: [Secp256r1.seed_length]u8, + secp384r1_seed: [Secp384r1.seed_length]u8, x25519_seed: [X25519.seed_length]u8, ) !Self { return Self{ @@ -663,6 +324,9 @@ pub const KeyPairs = struct { .secp256r1 = Secp256r1.create(secp256r1_seed) catch |err| switch (err) { error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. }, + .secp384r1 = Secp384r1.create(secp384r1_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, .x25519 = X25519.create(x25519_seed) catch |err| switch (err) { error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. }, diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index 90c850b68a80..dd3546e0879d 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -3,24 +3,26 @@ const tls = std.crypto.tls; const net = std.net; const mem = std.mem; const crypto = std.crypto; +const io = std.io; const assert = std.debug.assert; const Certificate = std.crypto.Certificate; - -pub const TranscriptHash = std.crypto.hash.sha2.Sha384; +const Allocator = std.mem.Allocator; /// `StreamType` must conform to `tls.StreamInterface`. pub fn Server(comptime StreamType: type) type { return struct { - stream: tls.Stream(tls.Plaintext.max_length, StreamType, TranscriptHash), + stream: Stream, options: Options, + /// Only used during handshake for messages larger than tls.Plaintext.max_length. + // allocator: Allocator, + const Stream = tls.Stream(tls.Plaintext.max_length, StreamType); const Self = @This(); /// Initiates a TLS handshake and establishes a TLSv1.3 session pub fn init(stream: *StreamType, options: Options) !Self { - var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType, TranscriptHash){ + var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType){ .stream = stream, - .transcript_hash = TranscriptHash.init(.{}), .is_client = false, }; var res = Self{ .stream = stream_, .options = options }; @@ -38,126 +40,9 @@ pub fn Server(comptime StreamType: type) type { return res; } - /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. - /// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. - pub fn write(self: *Self, bytes: []const u8) !usize { - return self.writeEnd(bytes, false); - } - - /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. - pub fn writeAll(self: *Self, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.write(bytes[index..]); - } - } - - /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. - /// If `end` is true, then this function additionally sends a `close_notify` alert, - /// which is necessary for the server to distinguish between a properly finished - /// TLS session, or a truncation attack. - pub fn writeAllEnd(self: *Self, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.writeEnd(bytes[index..], end); - } - } - - /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. - /// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. - /// If `end` is true, then this function additionally sends a `close_notify` alert, - /// which is necessary for the server to distinguish between a properly finished - /// TLS session, or a truncation attack. - pub fn writeEnd(self: *Self, bytes: []const u8, end: bool) !usize { - try self.stream.writeAll(bytes); - if (end) { - const alert = tls.Alert{ - .level = .fatal, - .description = .close_notify, - }; - try self.stream.write(tls.Alert, alert); - try self.stream.flush(); - } - return bytes.len; - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns the number of bytes read, calling the underlying read function the - /// minimal number of times until the buffer has at least `len` bytes filled. - /// If the number read is less than `len` it means the stream reached the end. - /// Reaching the end of the stream is not an error condition. - pub fn readAtLeast(self: *Self, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.os.iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; - return self.readvAtLeast(&iovecs, len); - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - pub fn read(self: *Self, buffer: []u8) !usize { - return self.readAtLeast(buffer, 1); - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. Reaching the end of the - /// stream is not an error condition. - pub fn readAll(self: *Self, buffer: []u8) !usize { - return self.readAtLeast(buffer, buffer.len); - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns the number of bytes read. If the number read is less than the space - /// provided it means the stream reached the end. Reaching the end of the - /// stream is not an error condition. - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial reads from the underlying stream layer. - pub fn readv(self: *Self, iovecs: []std.os.iovec) !usize { - return self.readvAtLeast(iovecs, 1); - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns the number of bytes read, calling the underlying read function the - /// minimal number of times until the iovecs have at least `len` bytes filled. - /// If the number read is less than `len` it means the stream reached the end. - /// Reaching the end of the stream is not an error condition. - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial reads from the underlying stream layer. - pub fn readvAtLeast(self: *Self, iovecs: []std.os.iovec, len: usize) !usize { - if (self.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try self.readvAdvanced(iovecs[vec_i..]); - off_i += amt; - if (self.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].iov_len) { - amt -= iovecs[vec_i].iov_len; - vec_i += 1; - } - iovecs[vec_i].iov_base += amt; - iovecs[vec_i].iov_len -= amt; - } - } - - /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. - /// Returns number of bytes that have been read, populated inside `iovecs`. A - /// return value of zero bytes does not mean end of stream. Instead, check the `eof()` - /// for the end of stream. The `eof()` may be true after any call to - /// `read`, including when greater than zero bytes are returned, and this - /// function asserts that `eof()` is `false`. - /// See `readv` for a higher level function that has the same, familiar API as - /// other read functions, such as `std.fs.File.read`. - pub fn readvAdvanced(self: *Self, iovecs: []const std.os.iovec) !usize { - _ = .{ self, iovecs }; - return 0; - } - - pub fn eof(self: *Self) bool { - return self.stream.eof(); - } - const ClientHello = struct { random: [32]u8, + session_id_len: u8, session_id: [32]u8, cipher_suite: tls.CipherSuite, key_share: tls.KeyShare, @@ -165,31 +50,36 @@ pub fn Server(comptime StreamType: type) type { }; pub fn recv_hello(self: *Self) !ClientHello { - try self.stream.readFragment(); + try self.stream.expectFragment(.handshake, .client_hello); + var reader = self.stream.reader(); _ = try self.stream.read(tls.Version); - const client_random = try self.stream.readAll(32); - const session_id = try self.stream.readSmallArray(u8); - if (session_id.len > tls.ClientHello.session_id_max_len) return error.TlsUnexpectedMessage; - - var selected_suite: ?tls.CipherSuite = null; - - var cipher_suite_iter = try self.stream.iterator(tls.CipherSuite); - while (try cipher_suite_iter.next()) |suite| { - if (selected_suite == null) brk: { + var client_random: [32]u8 = undefined; + try reader.readNoEof(&client_random); + + var session_id: [tls.ClientHello.session_id_max_len]u8 = undefined; + const session_id_len = try self.stream.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) return error.TlsUnexpectedMessage; + try reader.readNoEof(session_id[0..session_id_len]); + + const cipher_suite: tls.CipherSuite = brk: { + var cipher_suite_iter = try self.stream.iterator(u16, tls.CipherSuite); + var res: ?tls.CipherSuite = null; + while (try cipher_suite_iter.next()) |suite| { for (self.options.cipher_suites) |s| { - if (s == suite) { - selected_suite = s; - break :brk; - } + if (s == suite and res == null) res = s; } } - } - - if (selected_suite == null) return error.TlsUnexpectedMessage; + if (res == null) return error.TlsUnexpectedMessage; + break :brk res.?; + }; + try self.stream.transcript_hash.setActive(cipher_suite); - const compression_methods = try self.stream.readAll(2); - if (!std.mem.eql(u8, compression_methods, &[_]u8{ 1, 0 })) return error.TlsUnexpectedMessage; + { + var compression_methods: [2]u8 = undefined; + try reader.readNoEof(&compression_methods); + if (!std.mem.eql(u8, &compression_methods, &[_]u8{ 1, 0 })) return error.TlsUnexpectedMessage; + } var tls_version: ?tls.Version = null; var key_share: ?tls.KeyShare = null; @@ -201,9 +91,8 @@ pub fn Server(comptime StreamType: type) type { switch (ext.type) { .supported_versions => { if (tls_version != null) return error.TlsUnexpectedMessage; - const versions = try self.stream.readSmallArray(tls.Version); - for (versions) |v| { - std.debug.print("version {}\n", .{v}); + var versions_iter = try self.stream.iterator(u8, tls.Version); + while (try versions_iter.next()) |v| { if (v == .tls_1_3) tls_version = v; } }, @@ -211,7 +100,7 @@ pub fn Server(comptime StreamType: type) type { .key_share => { if (key_share != null) return error.TlsUnexpectedMessage; - var key_share_iter = try self.stream.iterator(tls.KeyShare); + var key_share_iter = try self.stream.iterator(u16, tls.KeyShare); while (try key_share_iter.next()) |ks| { switch (ks) { .x25519 => key_share = ks, @@ -220,19 +109,19 @@ pub fn Server(comptime StreamType: type) type { } }, .ec_point_formats => { - const formats = try self.stream.readSmallArray(tls.EcPointFormat); - for (formats) |f| { + var format_iter = try self.stream.iterator(u8, tls.EcPointFormat); + while (try format_iter.next()) |f| { if (f == .uncompressed) ec_point_format = .uncompressed; } }, .signature_algorithms => { - var algos_iter = try self.stream.iterator(tls.SignatureScheme); + var algos_iter = try self.stream.iterator(u16, tls.SignatureScheme); while (try algos_iter.next()) |algo| { if (algo == .rsa_pss_rsae_sha256) sig_scheme = algo; } }, else => { - _ = try self.stream.readAll(ext.len); + try reader.skipBytes(ext.len, .{}); }, } } @@ -242,9 +131,10 @@ pub fn Server(comptime StreamType: type) type { if (ec_point_format == null) return error.TlsUnexpectedMessage; return .{ - .random = client_random[0..32].*, - .session_id = session_id[0..32].*, - .cipher_suite = selected_suite.?, + .random = client_random, + .session_id_len = session_id_len, + .session_id = session_id, + .cipher_suite = cipher_suite, .key_share = key_share.?, .sig_scheme = sig_scheme, }; @@ -265,9 +155,12 @@ pub fn Server(comptime StreamType: type) type { _ = try self.stream.write(tls.Handshake, .{ .server_hello = hello }); try self.stream.flush(); - self.stream.content_type = .change_cipher_spec; - _ = try self.stream.write(tls.ChangeCipherSpec, .change_cipher_spec); - try self.stream.flush(); + // > if the client sends a non-empty session ID, the server MUST send the change_cipher_spec + if (hello.session_id.len > 0) { + self.stream.content_type = .change_cipher_spec; + _ = try self.stream.write(tls.ChangeCipherSpec, .change_cipher_spec); + try self.stream.flush(); + } const shared_key = switch (client_hello.key_share) { .x25519_kyber768d00 => |ks| brk: { @@ -302,8 +195,7 @@ pub fn Server(comptime StreamType: type) type { }; const hello_hash = self.stream.transcript_hash.peek(); - self.stream.handshake_cipher = tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, &hello_hash); - self.stream.handshake_cipher.?.print(); + self.stream.handshake_cipher = try tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, hello_hash); self.stream.content_type = .handshake; _ = try self.stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); @@ -344,21 +236,38 @@ pub fn Server(comptime StreamType: type) type { } pub fn send_handshake_finish(self: *Self) !void { - const secret = self.stream.handshake_cipher.?.aes_256_gcm_sha384.server_finished_key; - const transcript_hash = self.stream.transcript_hash.peek(); - tls.debugPrint("peek", transcript_hash); - const verify = switch (self.stream.handshake_cipher.?) { + const verify_data = switch (self.stream.handshake_cipher.?) { inline .aes_256_gcm_sha384, => |v| brk: { const T = @TypeOf(v); - break :brk tls.hmac(T.Hmac, &transcript_hash, secret); + const secret = v.server_finished_key; + const transcript_hash = self.stream.transcript_hash.peek(); + + break :brk tls.hmac(T.Hmac, transcript_hash, secret); }, else => return error.TlsDecryptFailure, }; - tls.debugPrint("verify", verify); - _ = try self.stream.write(tls.Handshake, .{ .finished = &verify }); + _ = try self.stream.write(tls.Handshake, .{ .finished = &verify_data }); try self.stream.flush(); + + self.stream.application_cipher = tls.ApplicationCipher.init( + self.stream.handshake_cipher.?, + self.stream.transcript_hash.peek(), + ); + } + + pub fn recv_finish(self: *Self) !void { + try self.stream.expectFragment(.handshake, .finished); + var reader = self.stream.reader(); + + var verify_data: [48]u8 = undefined; + try reader.readNoEof(&verify_data); + + // TODO: verify + _ = .{ verify_data }; + self.stream.content_type = .application_data; + self.stream.handshake_type = null; } }; } From 5bd77310b83cda1b3fe4d640254e1d223a40d589 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Tue, 12 Mar 2024 17:56:06 -0400 Subject: [PATCH 04/17] verify sigs, allow skipping certs --- TODO | 2 +- lib/std/crypto/Certificate.zig | 2 +- lib/std/crypto/tls.zig | 248 +++++++++++------- lib/std/crypto/tls/Client.zig | 457 ++++++++++++++++++++------------- lib/std/crypto/tls/Server.zig | 161 ++++++------ 5 files changed, 503 insertions(+), 367 deletions(-) diff --git a/TODO b/TODO index 431f6659a5c3..68ee2dc306b1 100644 --- a/TODO +++ b/TODO @@ -3,7 +3,7 @@ [x] 3. Client recv_hello secp256r1 key share [x] 4. remove @panic's [x] 5. better errors than spammy TlsDecodeError. map new errors to TLS alerts. send alert on error. -6. verify certs and sigs +6. verify certs and [x] sigs 7. KeyShare kyber read 8. StreamInterface `readv` instead of `readAll` diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index decd59571696..a94f3cc3e4f6 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -967,7 +967,7 @@ pub const rsa = struct { const mod_bits = public_key.n.bits(); const em_dec = try encrypt(modulus_len, sig, public_key); - EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash) catch unreachable; + try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash); } fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) !void { diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 67311da13ef7..1692a50dee3c 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -138,6 +138,11 @@ pub const Handshake = union(HandshakeType) { } return res; } + + pub const Header = struct { + type: HandshakeType, + len: u24, + }; }; pub const Certificate = struct { @@ -148,10 +153,11 @@ pub const Certificate = struct { pub const Entry = struct { /// Either ASN1_subjectPublicKeyInfo or cert_data based on CertificateType. - /// Max len 2^24-1 data: []const u8, extensions: []const Extension = &.{}, + pub const max_data_len = 1 << 24 - 1; + pub fn write(self: @This(), stream: anytype) !usize { var res: usize = 0; res += try stream.writeArray(u24, u8, self.data); @@ -172,9 +178,10 @@ pub const Certificate = struct { pub const CertificateVerify = struct { algorithm: SignatureScheme, - /// Max len 2^16 - 1 signature: []const u8, + pub const max_signature_length = 1 << 16 - 1; + pub fn write(self: @This(), stream: anytype) !usize { var res: usize = 0; res += try stream.write(SignatureScheme, self.algorithm); @@ -242,110 +249,32 @@ pub const ExtensionType = enum(u16) { _, }; +/// Matching error set for Alert.Description. pub const Error = error{ - /// An inappropriate message (e.g., the wrong - /// handshake message, premature Application Data, etc.) was received. - /// This alert should never be observed in communication between - /// proper implementations. TlsUnexpectedMessage, - /// This alert is returned if a record is received which - /// cannot be deprotected. Because AEAD algorithms combine decryption - /// and verification, and also to avoid side-channel attacks, this - /// alert is used for all deprotection failures. This alert should - /// never be observed in communication between proper implementations, - /// except when messages were corrupted in the network. TlsBadRecordMac, - /// A TLSCiphertext record was received that had a - /// length more than 2^14 + 256 bytes, or a record decrypted to a - /// TLSPlaintext record with more than 2^14 bytes (or some other - /// negotiated limit). This alert should never be observed in - /// communication between proper implementations, except when messages - /// were corrupted in the network. TlsRecordOverflow, - /// Receipt of a "handshake_failure" alert message - /// indicates that the sender was unable to negotiate an acceptable - /// set of security parameters given the options available. TlsHandshakeFailure, - /// A certificate was corrupt, contained signatures - /// that did not verify correctly, etc. TlsBadCertificate, - /// A certificate was of an unsupported type. TlsUnsupportedCertificate, - /// A certificate was revoked by its signer. TlsCertificateRevoked, - /// A certificate has expired or is not currently valid. TlsCertificateExpired, - /// Some other (unspecified) issue arose in processing the certificate, rendering it unacceptable. TlsCertificateUnknown, - /// A field in the handshake was incorrect or - /// inconsistent with other fields. This alert is used for errors - /// which conform to the formal protocol syntax but are otherwise - /// incorrect. TlsIllegalParameter, - /// A valid certificate chain or partial chain was received, - /// but the certificate was not accepted because the CA certificate - /// could not be located or could not be matched with a known trust - /// anchor. TlsUnknownCa, - /// A valid certificate or PSK was received, but when - /// access control was applied, the sender decided not to proceed with - /// negotiation. TlsAccessDenied, - /// A message could not be decoded because some field was - /// out of the specified range or the length of the message was - /// incorrect. This alert is used for errors where the message does - /// not conform to the formal protocol syntax. This alert should - /// never be observed in communication between proper implementations, - /// except when messages were corrupted in the network. TlsDecodeError, - /// A handshake (not record layer) cryptographic - /// operation failed, including being unable to correctly verify a - /// signature or validate a Finished message or a PSK binder. TlsDecryptError, - /// The protocol version the peer has attempted to - /// negotiate is recognized but not supported (see Appendix D). TlsProtocolVersion, - /// Returned instead of "handshake_failure" when - /// a negotiation has failed specifically because the server requires - /// parameters more secure than those supported by the client. TlsInsufficientSecurity, - /// An internal error unrelated to the peer or the - /// correctness of the protocol (such as a memory allocation failure) - /// makes it impossible to continue. TlsInternalError, - /// Sent by a server in response to an invalid - /// connection retry attempt from a client (see [RFC7507]). TlsInappropriateFallback, - /// Sent by endpoints that receive a handshake - /// message not containing an extension that is mandatory to send for - /// the offered TLS version or other negotiated parameters. TlsMissingExtension, - /// Sent by endpoints receiving any handshake - /// message containing an extension known to be prohibited for - /// inclusion in the given handshake message, or including any - /// extensions in a ServerHello or Certificate not first offered in - /// the corresponding ClientHello or CertificateRequest. TlsUnsupportedExtension, - /// Sent by servers when no server exists identified - /// by the name provided by the client via the "server_name" extension - /// (see [RFC6066]). TlsUnrecognizedName, - /// Sent by clients when an invalid or - /// unacceptable OCSP response is provided by the server via the - /// "status_request" extension (see [RFC6066]). TlsBadCertificateStatusResponse, - /// Sent by servers when PSK key establishment is - /// desired but no acceptable PSK identity is provided by the client. - /// Sending this alert is OPTIONAL; servers MAY instead choose to send - /// a "decrypt_error" alert to merely indicate an invalid PSK - /// identity. TlsUnknownPskIdentity, - /// Sent by servers when a client certificate is - /// desired but none was provided by the client. TlsCertificateRequired, - /// Sent by servers when a client - /// "application_layer_protocol_negotiation" extension advertises only - /// protocols that the server does not support (see [RFC7301]). TlsNoApplicationProtocol, TlsUnknown, }; @@ -362,32 +291,114 @@ pub const Alert = struct { _, }; pub const Description = enum(u8) { + /// Stream is closing. close_notify = 0, + /// An inappropriate message (e.g., the wrong + /// handshake message, premature Application Data, etc.) was received. + /// This alert should never be observed in communication between + /// proper implementations. unexpected_message = 10, + /// This alert is returned if a record is received which + /// cannot be deprotected. Because AEAD algorithms combine decryption + /// and verification, and also to avoid side-channel attacks, this + /// alert is used for all deprotection failures. This alert should + /// never be observed in communication between proper implementations, + /// except when messages were corrupted in the network. bad_record_mac = 20, + /// A TLSCiphertext record was received that had a + /// length more than 2^14 + 256 bytes, or a record decrypted to a + /// TLSPlaintext record with more than 2^14 bytes (or some other + /// negotiated limit). This alert should never be observed in + /// communication between proper implementations, except when messages + /// were corrupted in the network. record_overflow = 22, + /// Receipt of a "handshake_failure" alert message + /// indicates that the sender was unable to negotiate an acceptable + /// set of security parameters given the options available. handshake_failure = 40, + /// A certificate was corrupt, contained signatures + /// that did not verify correctly, etc. bad_certificate = 42, + /// A certificate was of an unsupported type. unsupported_certificate = 43, + /// A certificate was revoked by its signer. certificate_revoked = 44, + /// A certificate has expired or is not currently valid. certificate_expired = 45, + /// Some other (unspecified) issue arose in processing the certificate, + /// rendering it unacceptable. certificate_unknown = 46, + /// A field in the handshake was incorrect or + /// inconsistent with other fields. This alert is used for errors + /// which conform to the formal protocol syntax but are otherwise + /// incorrect. illegal_parameter = 47, + /// A valid certificate chain or partial chain was received, + /// but the certificate was not accepted because the CA certificate + /// could not be located or could not be matched with a known trust + /// anchor. unknown_ca = 48, + /// A valid certificate or PSK was received, but when + /// access control was applied, the sender decided not to proceed with + /// negotiation. access_denied = 49, + /// A message could not be decoded because some field was + /// out of the specified range or the length of the message was + /// incorrect. This alert is used for errors where the message does + /// not conform to the formal protocol syntax. This alert should + /// never be observed in communication between proper implementations, + /// except when messages were corrupted in the network. decode_error = 50, + /// A handshake (not record layer) cryptographic + /// operation failed, including being unable to correctly verify a + /// signature or validate a Finished message or a PSK binder. decrypt_error = 51, + /// The protocol version the peer has attempted to + /// negotiate is recognized but not supported (see Appendix D). protocol_version = 70, + /// Returned instead of "handshake_failure" when + /// a negotiation has failed specifically because the server requires + /// parameters more secure than those supported by the client. insufficient_security = 71, + /// An internal error unrelated to the peer or the + /// correctness of the protocol (such as a memory allocation failure) + /// makes it impossible to continue. internal_error = 80, + /// Sent by a server in response to an invalid + /// connection retry attempt from a client (see [RFC7507]). inappropriate_fallback = 86, + /// User cancelled handshake. user_canceled = 90, + /// Sent by endpoints that receive a handshake + /// message not containing an extension that is mandatory to send for + /// the offered TLS version or other negotiated parameters. missing_extension = 109, + /// Sent by endpoints receiving any handshake + /// message containing an extension known to be prohibited for + /// inclusion in the given handshake message, or including any + /// extensions in a ServerHello or Certificate not first offered in + /// the corresponding ClientHello or CertificateRequest. unsupported_extension = 110, + /// Sent by servers when no server exists identified + /// by the name provided by the client via the "server_name" extension + /// (see [RFC6066]). unrecognized_name = 112, + /// Sent by clients when an invalid or + /// unacceptable OCSP response is provided by the server via the + /// "status_request" extension (see [RFC6066]). bad_certificate_status_response = 113, + /// Sent by servers when PSK key establishment is + /// desired but no acceptable PSK identity is provided by the client. + /// Sending this alert is OPTIONAL; servers MAY instead choose to send + /// a "decrypt_error" alert to merely indicate an invalid PSK + /// identity. unknown_psk_identity = 115, + /// Sent by servers when a client certificate is + /// desired but none was provided by the client. certificate_required = 116, + /// Sent by servers when a client + /// "application_layer_protocol_negotiation" extension advertises only + /// protocols that the server does not support (see [RFC7301]). no_application_protocol = 120, _, @@ -700,7 +711,7 @@ pub const HandshakeCipher = union(CipherSuite) { const Self = @This(); - pub fn init(suite: CipherSuite, shared_key: []const u8, hello_hash: []const u8) !Self { + pub fn init(suite: CipherSuite, shared_key: []const u8, hello_hash: []const u8) Error!Self { switch (suite) { inline .aes_128_gcm_sha256, .aes_256_gcm_sha384, @@ -741,7 +752,7 @@ pub const HandshakeCipher = union(CipherSuite) { return res; }, .empty_renegotiation_info_scsv => return .{ .empty_renegotiation_info_scsv = {} }, - _ => return error.TlsIllegalParameter, + _ => return Error.TlsIllegalParameter, } } @@ -1050,7 +1061,7 @@ pub const StreamInterface = struct { /// * Fragmentation /// * Encryption and decryption of handshake and application data messages /// * Reading and writing prefix length arrays -/// * TLS Alerts +/// * Alerts pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { // TODO: Support RFC 6066 MaxFragmentLength and give fragment_size option to Client+Server. if (fragment_size > std.math.maxInt(u16)) @compileError("choose a smaller fragment_size"); @@ -1150,7 +1161,9 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { if (self.application_cipher == null) { switch (self.content_type) { .change_cipher_spec, .alert => {}, - else => self.transcript_hash.update(self.view), + else => { + self.transcript_hash.update(self.view); + } } } @@ -1209,7 +1222,8 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { self.flush() catch {}; self.close(); - return err.toError(); + @panic("writeError"); + // return err.toError(); } pub fn close(self: *Self) void { @@ -1347,7 +1361,8 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { const ciphertext = self.view[0..self.view.len - tag_len]; const tag = self.view[self.view.len - tag_len..][0..tag_len].*; const out: []u8 = @constCast(self.view[0..ciphertext.len]); - try p.decrypt(ciphertext, &plaintext_header, tag, self.is_client, out); + p.decrypt(ciphertext, &plaintext_header, tag, self.is_client, out) catch + return self.writeError(.bad_record_mac); const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); if (padding_start) |s| { res = @enumFromInt(self.view[s]); @@ -1395,16 +1410,23 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { const actual_content = try self.readFragment(); if (expected_content != actual_content) { std.debug.print("expected {} got {}\n", .{ expected_content, actual_content }); - return self.writeError(.decode_error); } if (expected_handshake) |expected| { const actual_handshake = try self.read(HandshakeType); if (actual_handshake != expected) return self.writeError(.decode_error); - // TODO: verify this? - _ = try self.read(u24); + const stated_len = try self.read(u24); + if (stated_len != self.view.len) return self.writeError(.decode_error); } } + pub fn expectHandshake(self: *Self) ReadError!Handshake.Header { + try self.expectFragment(.handshake, null); + const ty = try self.read(HandshakeType); + const len = try self.read(u24); + if (self.view.len != len) return self.writeError(.decode_error); + return .{ .type = ty, .len = len }; + } + pub fn read(self: *Self, comptime T: type) ReadError!T { comptime std.debug.assert(@sizeOf(T) < fragment_size); switch (@typeInfo(T)) { @@ -1513,6 +1535,7 @@ pub const MultiHash = struct { active: enum { all, sha256, sha384, sha512 } = .all, const sha2 = crypto.hash.sha2; + pub const max_digest_len = sha2.Sha512.digest_length; const Self = @This(); pub fn update(self: *Self, bytes: []const u8) void { @@ -1528,12 +1551,13 @@ pub const MultiHash = struct { } } - pub fn setActive(self: *Self, cipher_suite: CipherSuite) Error!void { + pub fn setActive(self: *Self, cipher_suite: CipherSuite) void { self.active = switch (cipher_suite) { .aes_128_gcm_sha256, .chacha20_poly1305_sha256, .aegis_128l_sha256 => .sha256, .aes_256_gcm_sha384 => .sha384, .aegis_256_sha512 => .sha512, - else => return Error.TlsIllegalParameter, + .empty_renegotiation_info_scsv => .all, + _ => .all, }; } @@ -1700,11 +1724,11 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { tag: [AEAD.tag_length]u8, is_client: bool, out: []u8, - ) Error!void { + ) !void { const key = if (is_client) self.server_key else self.client_key; const iv = if (is_client) self.server_iv else self.client_iv; const nonce = nonce_for_len(AEAD.nonce_length, iv, self.read_seq); - AEAD.decrypt(out, data, tag, additional, nonce, key) catch return Error.TlsBadRecordMac; + try AEAD.decrypt(out, data, tag, additional, nonce, key); self.read_seq += 1; } @@ -1771,6 +1795,19 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) return result; } +/// Slice of stack allocated signature content from RFC 8446 S4.4.3 +pub inline fn sigContent(digest: []const u8) []const u8 { + const max_digest_len = MultiHash.max_digest_len; + var buf = [_]u8{0x20} ** 64 + ++ "TLS 1.3, server CertificateVerify\x00".* + ++ @as([max_digest_len]u8, undefined) + ; + @memcpy(buf[buf.len - max_digest_len..][0..digest.len], digest); + + return buf[0..buf.len - (max_digest_len - digest.len)]; +} + + fn fieldsLen(comptime T: type) comptime_int { var res: comptime_int = 0; inline for (std.meta.fields(T)) |f| res += @sizeOf(f.type); @@ -1881,7 +1918,7 @@ test "tls client and server handshake, data, and close_notify" { .stream = &inner_stream, .is_client = true, }, - .options = .{ .host = host, .ca_bundle = null }, + .options = .{ .host = host, .ca_bundle = null, .allocator = allocator }, }; const server_der = @embedFile("./testdata/server.der"); @@ -2101,7 +2138,7 @@ test "tls client and server handshake, data, and close_notify" { } }); try server.stream.flush(); - try server.send_handshake_finish(); + try server.send_finished(); { const buf = tmp_buf[0..inner_stream.buffer.len()]; try inner_stream.peek(buf); @@ -2254,9 +2291,8 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &expected, buf); } + try client.stream.expectFragment(.handshake, .server_hello); try client.recv_hello(key_pairs); - - // Test that ALL shared keys are identical { const s = server.stream.handshake_cipher.?.aes_256_gcm_sha384; const c = client.stream.handshake_cipher.?.aes_256_gcm_sha384; @@ -2272,6 +2308,20 @@ test "tls client and server handshake, data, and close_notify" { const client_iv = [_]u8{ 0x42,0x56,0xd2,0xe0,0xe8,0x8b,0xab,0xdd,0x05,0xeb,0x2f,0x27 }; try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } + try client.stream.expectFragment(.handshake, .encrypted_extensions); + try client.recv_encrypted_extensions(); + try client.stream.expectFragment(.handshake, .certificate); + const cert = try client.recv_certificate(); + defer allocator.free(cert.certificate.buffer); + + var digest = client.stream.transcript_hash.peek(); + try client.stream.expectFragment(.handshake, .certificate_verify); + try client.recv_certificate_verify(digest, cert); + + digest = client.stream.transcript_hash.peek(); + try client.stream.expectFragment(.handshake, .finished); + try client.recv_finished(digest); + { const s = server.stream.application_cipher.?.aes_256_gcm_sha384; const c = client.stream.application_cipher.?.aes_256_gcm_sha384; @@ -2312,7 +2362,7 @@ test "tls client and server handshake, data, and close_notify" { ; try std.testing.expectEqualSlices(u8, &expected, buf); } - try server.recv_finish(); + try server.recv_finished(); _ = try client.stream.writer().writeAll("ping"); try client.stream.flush(); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index c215011a75b6..526c0fe952b2 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -12,25 +12,74 @@ pub fn Client(comptime StreamType: type) type { stream: tls.Stream(tls.Plaintext.max_length, StreamType), options: Options, + state: State = .start, + + const State = enum { + start, + recv_encrypted_extensions, + recv_finished, + sent_finished, + }; const Self = @This(); /// Initiates a TLS handshake and establishes a TLSv1.3 session pub fn init(stream: *StreamType, options: Options) !Self { - var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType){ + const stream_ = tls.Stream(tls.Plaintext.max_length, StreamType){ .stream = stream, .is_client = true, }; var res = Self{ .stream = stream_, .options = options }; - { - const key_pairs = try KeyPairs.init(); - try res.send_hello(key_pairs); - try res.recv_hello(key_pairs); - } - _ = &stream_; + + while (res.state != .sent_finished) try res.advance(); return res; } + /// Advance to next handshake state. + pub fn advance(self: *Self) !void { + var stream = &self.stream; + switch (self.state) { + .start => { + const key_pairs = KeyPairs.init(); + try self.send_hello(key_pairs); + + try stream.expectFragment(.handshake, .server_hello); + try self.recv_hello(key_pairs); + + try stream.expectFragment(.handshake, .encrypted_extensions); + try self.recv_encrypted_extensions(); + + self.state = .recv_encrypted_extensions; + }, + .recv_encrypted_extensions => { + var digest = stream.transcript_hash.owned(); + const header = try stream.expectHandshake(); + switch (header.type) { + .certificate => { + const parsed = try self.recv_certificate(); + defer self.options.allocator.free(parsed.certificate.buffer); + try self.recv_certificate_verify(parsed); + + digest = stream.transcript_hash.owned(); + try stream.expectFragment(.handshake, .finished); + try self.recv_finished(digest); + self.state = .recv_finished; + }, + .finished => { + try self.recv_finished(digest); + self.state = .recv_finished; + }, + else => return self.stream.writeError(.unexpected_message), + } + }, + .recv_finished => { + try self.send_finished(); + self.state = .sent_finished; + }, + .sent_finished => {}, + } + } + pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { const hello = tls.ClientHello{ .random = key_pairs.hello_rand, @@ -58,91 +107,95 @@ pub fn Client(comptime StreamType: type) type { } pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { - try self.stream.expectFragment(.handshake, .server_hello); - var reader = self.stream.reader(); + var stream = &self.stream; + var reader = stream.reader(); // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. - _ = try self.stream.read(tls.Version); + _ = try stream.read(tls.Version); var random: [32]u8 = undefined; try reader.readNoEof(&random); if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { // We already offered all our supported options and we aren't changing them. - return error.TlsUnexpectedMessage; + return stream.writeError(.unexpected_message); } var session_id_buf: [tls.ClientHello.session_id_max_len]u8 = undefined; - const session_id_len = try self.stream.read(u8); - if (session_id_len > tls.ClientHello.session_id_max_len) return error.TlsUnexpectedMessage; + const session_id_len = try stream.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) + return stream.writeError(.illegal_parameter); const session_id: []u8 = session_id_buf[0..session_id_len]; try reader.readNoEof(session_id); - if (!mem.eql(u8, session_id, &key_pairs.session_id)) return error.TlsIllegalParameter; + if (!mem.eql(u8, session_id, &key_pairs.session_id)) + return stream.writeError(.illegal_parameter); - const cipher_suite = try self.stream.read(tls.CipherSuite); - const compression_method = try self.stream.read(u8); - if (compression_method != 0) return error.TlsIllegalParameter; + const cipher_suite = try stream.read(tls.CipherSuite); + const compression_method = try stream.read(u8); + if (compression_method != 0) return stream.writeError(.illegal_parameter); var supported_version: ?tls.Version = null; var shared_key: ?[]const u8 = null; - var iter = try self.stream.extensions(); + var iter = try stream.extensions(); while (try iter.next()) |ext| { switch (ext.type) { .supported_versions => { - if (supported_version != null) return error.TlsIllegalParameter; - supported_version = try self.stream.read(tls.Version); + if (supported_version != null) return stream.writeError(.illegal_parameter); + supported_version = try stream.read(tls.Version); }, .key_share => { - if (shared_key != null) return error.TlsIllegalParameter; - const named_group = try self.stream.read(tls.NamedGroup); - const key_size = try self.stream.read(u16); + if (shared_key != null) return stream.writeError(.illegal_parameter); + const named_group = try stream.read(tls.NamedGroup); + const key_size = try stream.read(u16); switch (named_group) { .x25519_kyber768d00 => { const T = tls.NamedGroupT(.x25519_kyber768d00); const x25519_len = T.X25519.public_length; const expected_len = x25519_len + T.Kyber768.ciphertext_length; - if (key_size != expected_len) - return error.TlsIllegalParameter; + if (key_size != expected_len) return stream.writeError(.illegal_parameter); var server_ks: [expected_len]u8 = undefined; try reader.readNoEof(&server_ks); - shared_key = &((T.X25519.scalarmult( + const mult = T.X25519.scalarmult( key_pairs.x25519.secret_key, server_ks[0..x25519_len].*, - ) catch return error.TlsDecryptFailure) ++ (key_pairs.kyber768d00.secret_key.decaps( + ) catch return stream.writeError(.decrypt_error); + const decaps = key_pairs.kyber768d00.secret_key.decaps( server_ks[x25519_len..expected_len], - ) catch return error.TlsDecryptFailure)); + ) catch return stream.writeError(.decrypt_error); + shared_key = &(mult ++ decaps); }, .x25519 => { const T = tls.NamedGroupT(.x25519); const expected_len = T.public_length; - if (key_size != expected_len) return error.TlsIllegalParameter; + if (key_size != expected_len) return stream.writeError(.illegal_parameter); var server_ks: [expected_len]u8 = undefined; try reader.readNoEof(&server_ks); - shared_key = &(crypto.dh.X25519.scalarmult( + const mult = crypto.dh.X25519.scalarmult( key_pairs.x25519.secret_key, server_ks[0..expected_len].*, - ) catch return error.TlsDecryptFailure); + ) catch return stream.writeError(.illegal_parameter); + shared_key = &mult; }, inline .secp256r1, .secp384r1 => |t| { const T = tls.NamedGroupT(t); const expected_len = T.PublicKey.compressed_sec1_encoded_length; - if (key_size != expected_len) return error.TlsIllegalParameter; + if (key_size != expected_len) return stream.writeError(.illegal_parameter); var server_ks: [expected_len]u8 = undefined; try reader.readNoEof(&server_ks); - const pk = T.PublicKey.fromSec1(&server_ks) catch { - return error.TlsDecryptFailure; - }; + const pk = T.PublicKey.fromSec1(&server_ks) catch + return stream.writeError(.illegal_parameter); const key_pair = @field(key_pairs, @tagName(t)); - const mul = pk.p.mulPublic(key_pair.secret_key.bytes, .big) catch { - return error.TlsDecryptFailure; - }; - shared_key = &mul.affineCoordinates().x.toBytes(.big); + const mult = pk.p.mulPublic(key_pair.secret_key.bytes, .big) catch + return stream.writeError(.illegal_parameter); + shared_key = &mult.affineCoordinates().x.toBytes(.big); }, + // Server sent us back unknown key. That's weird because we only request known ones, + // but we can try for another. else => { - return error.TlsIllegalParameter; + try reader.skipBytes(key_size, .{}); }, } }, @@ -152,101 +205,195 @@ pub fn Client(comptime StreamType: type) type { } } - if (supported_version != tls.Version.tls_1_3) return error.TlsIllegalParameter; - if (shared_key == null) return error.TlsIllegalParameter; + if (supported_version != tls.Version.tls_1_3) return stream.writeError(.protocol_version); + if (shared_key == null) return stream.writeError(.missing_extension); - try self.stream.transcript_hash.setActive(cipher_suite); - const hello_hash = self.stream.transcript_hash.peek(); - self.stream.handshake_cipher = try tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash); + stream.transcript_hash.setActive(cipher_suite); + const hello_hash = stream.transcript_hash.peek(); + stream.handshake_cipher = tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash) catch return stream.writeError(.illegal_parameter); + } - { - try self.stream.expectFragment(.handshake, .encrypted_extensions); - iter = try self.stream.extensions(); - while (try iter.next()) |ext| { - try reader.skipBytes(ext.len, .{}); - } + /// Currently skipped. + pub fn recv_encrypted_extensions(self: *Self) !void { + var stream = &self.stream; + var reader = stream.reader(); + + var iter = try stream.extensions(); + while (try iter.next()) |ext| { + try reader.skipBytes(ext.len, .{}); } + } - // CertificateRequest* - // Certificate* - // CertificateVerify* - { - try self.stream.expectFragment(.handshake, .certificate); - - var context: [tls.Certificate.max_context_len]u8 = undefined; - const context_len = try self.stream.read(u8); - try reader.readNoEof(context[0..context_len]); - - var certs_iter = try self.stream.iterator(u24, u24); - while (try certs_iter.next()) |cert_len| { - try reader.skipBytes(cert_len, .{}); - var ext_iter = try self.stream.extensions(); - while (try ext_iter.next()) |ext| { - switch (ext.type) { - else => { - try reader.skipBytes(ext.len, .{}); - }, - } + /// Allocates `server_cert`. + pub fn recv_certificate(self: *Self) !crypto.Certificate.Parsed { + var stream = &self.stream; + var reader = stream.reader(); + const allocator = self.options.allocator; + + var context: [tls.Certificate.max_context_len]u8 = undefined; + const context_len = try stream.read(u8); + if (context_len > tls.Certificate.max_context_len) return stream.writeError(.decode_error); + try reader.readNoEof(context[0..context_len]); + + var res: ?crypto.Certificate.Parsed = null; + + var certs_iter = try stream.iterator(u24, u24); + while (try certs_iter.next()) |cert_len| { + if (cert_len > tls.Certificate.Entry.max_data_len) + return stream.writeError(.decode_error); + const buf = allocator.alloc(u8, cert_len) catch + return stream.writeError(.internal_error); + errdefer allocator.free(buf); + try reader.readNoEof(buf); + + const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; + res = cert.parse() catch + return stream.writeError(.bad_certificate); + + var ext_iter = try stream.extensions(); + while (try ext_iter.next()) |ext| { + switch (ext.type) { + else => { + try reader.skipBytes(ext.len, .{}); + }, } } } - { - try self.stream.expectFragment(.handshake, .certificate_verify); - - const scheme = try self.stream.read(tls.SignatureScheme); - const len = try self.stream.read(u16); - try reader.skipBytes(len, .{}); + return if (res) |r| r else stream.writeError(.bad_certificate); + } - // TODO: verify - _ = .{ scheme }; + /// Deallocates `server_cert` + pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: crypto.Certificate.Parsed,) !void { + var stream = &self.stream; + var reader = stream.reader(); + const allocator = self.options.allocator; + + const sig_content = tls.sigContent(digest); + + const scheme = try stream.read(tls.SignatureScheme); + const len = try stream.read(u16); + if (len > tls.CertificateVerify.max_signature_length) + return stream.writeError(.decode_error); + const sig_bytes = allocator.alloc(u8, len) catch + return stream.writeError(.internal_error); + defer allocator.free(sig_bytes); + try reader.readNoEof(sig_bytes); + + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (cert.pub_key_algo != .X9_62_id_ecPublicKey) + return stream.writeError(.bad_certificate); + const Ecdsa = SchemeEcdsa(comptime_scheme); + const sig = try Ecdsa.Signature.fromDer(sig_bytes); + const key = try Ecdsa.PublicKey.fromSec1(cert.pubKey()); + try sig.verify(sig_content, key); + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + if (cert.pub_key_algo != .rsaEncryption) + return stream.writeError(.bad_certificate); + + const Hash = SchemeHash(comptime_scheme); + const rsa = Certificate.rsa; + const components = try rsa.PublicKey.parseDer(cert.pubKey()); + const exponent = components.exponent; + const modulus = components.modulus; + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const key = try rsa.PublicKey.fromBytes(exponent, modulus); + const sig = rsa.PSSSignature.fromBytes(modulus_len, sig_bytes); + try rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash); + }, + else => { + return error.TlsBadRsaSignatureBitCount; + }, + } + }, + inline .ed25519 => |comptime_scheme| { + if (cert.pub_key_algo != .curveEd25519) + return stream.writeError(.bad_certificate); + const Eddsa = SchemeEddsa(comptime_scheme); + if (sig_content.len != Eddsa.Signature.encoded_length) + return stream.writeError(.decode_error); + const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); + if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) + return stream.writeError(.decode_error); + const key = try Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*); + try sig.verify(sig_content, key); + }, + else => { + return error.TlsBadSignatureScheme; + }, } + } - { - try self.stream.expectFragment(.handshake, .finished); + pub fn recv_finished(self: *Self, digest: []const u8) !void { + var stream = &self.stream; + var reader = stream.reader(); + const cipher = stream.handshake_cipher.?; - var verify_data: [48]u8 = undefined; - try reader.readNoEof(&verify_data); + const expected = switch (cipher) { + .empty_renegotiation_info_scsv => return stream.writeError(.decode_error), + inline else => |p| brk: { + const P = @TypeOf(p); + break :brk &tls.hmac(P.Hmac, digest, p.server_finished_key); + } + }; - // TODO: verify - _ = .{ verify_data }; - } + // This message's stated length is in the handshake header, which `expectFragment` skips + // over. Cheat and rip it out of the view. + const actual = stream.view; + try reader.skipBytes(stream.view.len, .{}); + + if (!mem.eql(u8, expected, actual)) return stream.writeError(.decode_error); - self.stream.application_cipher = tls.ApplicationCipher.init( - self.stream.handshake_cipher.?, - self.stream.transcript_hash.peek(), + stream.application_cipher = tls.ApplicationCipher.init( + stream.handshake_cipher.?, + stream.transcript_hash.peek(), ); } pub fn send_finished(self: *Self) !void { - self.stream.version = .tls_1_2; - self.stream.content_type = .change_cipher_spec; - _ = try self.stream.write(tls.ChangeCipherSpec, .change_cipher_spec); - try self.stream.flush(); + var stream = &self.stream; + + stream.version = .tls_1_2; + stream.content_type = .change_cipher_spec; + _ = try stream.write(tls.ChangeCipherSpec, .change_cipher_spec); + try stream.flush(); - const verify_data = switch (self.stream.handshake_cipher.?) { + const verify_data = switch (stream.handshake_cipher.?) { inline + .aes_128_gcm_sha256, .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, => |v| brk: { const T = @TypeOf(v); const secret = v.client_finished_key; - const transcript_hash = self.stream.transcript_hash.peek(); + const transcript_hash = stream.transcript_hash.peek(); - break :brk tls.hmac(T.Hmac, transcript_hash, secret); + break :brk &tls.hmac(T.Hmac, transcript_hash, secret); }, - else => return error.TlsDecryptFailure, + else => return stream.writeError(.decrypt_error), }; - self.stream.content_type = .handshake; - _ = try self.stream.write(tls.Handshake, .{ .finished = &verify_data }); - try self.stream.flush(); + stream.content_type = .handshake; + _ = try stream.write(tls.Handshake, .{ .finished = verify_data }); + try stream.flush(); - self.stream.content_type = .application_data; + stream.content_type = .application_data; } }; } pub const Options = struct { - /// Used to verify certificate chain. If null will **dangerously** skip certificate verification. + /// Trusted certificate authority bundle used to authenticate server certificates. + /// When null, server certificate and certificate_verify messages will be skipped. ca_bundle: ?Certificate.Bundle, /// Used to verify cerficate chain and for Server Name Indication. host: []const u8, @@ -260,9 +407,12 @@ pub const Options = struct { /// application layer itself verifies that the amount of data received equals /// the amount of data expected, such as HTTP with the Content-Length header. allow_truncation_attacks: bool = false, + /// Certificate messages may be up to 2^24-1 bytes long. + /// Certificate verify messages may be up to 2^16-1 bytes long. + /// This is the allocator used for them. + allocator: std.mem.Allocator, }; - /// One of these potential key pairs will be selected during the handshake. pub const KeyPairs = struct { hello_rand: [hello_rand_length]u8, @@ -284,11 +434,11 @@ pub const KeyPairs = struct { pub fn init() Self { var random_buffer: [ hello_rand_length + - session_id_length + - Kyber768.seed_length + - Secp256r1.seed_length + - Secp384r1.seed_length + - X25519.seed_length + session_id_length + + Kyber768.seed_length + + Secp256r1.seed_length + + Secp384r1.seed_length + + X25519.seed_length ]u8 = undefined; while (true) { @@ -336,77 +486,26 @@ pub const KeyPairs = struct { } }; -/// Abstraction for sending multiple byte buffers to a slice of iovecs. -const VecPut = struct { - iovecs: []const std.os.iovec, - idx: usize = 0, - off: usize = 0, - total: usize = 0, - - /// Returns the amount actually put which is always equal to bytes.len - /// unless the vectors ran out of space. - fn put(vp: *VecPut, bytes: []const u8) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var bytes_i: usize = 0; - while (true) { - const v = vp.iovecs[vp.idx]; - const dest = v.iov_base[vp.off..v.iov_len]; - const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; - @memcpy(dest[0..src.len], src); - bytes_i += src.len; - vp.off += src.len; - if (vp.off >= v.iov_len) { - vp.off = 0; - vp.idx += 1; - if (vp.idx >= vp.iovecs.len) { - vp.total += bytes_i; - return bytes_i; - } - } - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } - } - } - - /// Returns the next buffer that consecutive bytes can go into. - fn peek(vp: VecPut) []u8 { - if (vp.idx >= vp.iovecs.len) return &.{}; - const v = vp.iovecs[vp.idx]; - return v.iov_base[vp.off..v.iov_len]; - } - - // After writing to the result of peek(), one can call next() to - // advance the cursor. - fn next(vp: *VecPut, len: usize) void { - vp.total += len; - vp.off += len; - if (vp.off >= vp.iovecs[vp.idx].iov_len) { - vp.off = 0; - vp.idx += 1; - } - } +fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { + return switch (scheme) { + .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, + .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, + else => @compileError("bad scheme"), + }; +} - fn freeSize(vp: VecPut) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var total: usize = 0; - total += vp.iovecs[vp.idx].iov_len - vp.off; - if (vp.idx + 1 >= vp.iovecs.len) return total; - for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len; - return total; - } -}; +fn SchemeHash(comptime scheme: tls.SignatureScheme) type { + return switch (scheme) { + .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, + .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, + .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, + else => @compileError("bad scheme"), + }; +} -/// Limit iovecs to a specific byte size. -fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec { - var bytes_left: usize = len; - for (iovecs, 0..) |*iovec, vec_i| { - if (bytes_left <= iovec.iov_len) { - iovec.iov_len = bytes_left; - return iovecs[0 .. vec_i + 1]; - } - bytes_left -= iovec.iov_len; - } - return iovecs; +fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { + return switch (scheme) { + .ed25519 => crypto.sign.Ed25519, + else => @compileError("bad scheme"), + }; } diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index dd3546e0879d..c0561363b32c 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -50,35 +50,39 @@ pub fn Server(comptime StreamType: type) type { }; pub fn recv_hello(self: *Self) !ClientHello { - try self.stream.expectFragment(.handshake, .client_hello); - var reader = self.stream.reader(); + var stream = &self.stream; + var reader = stream.reader(); - _ = try self.stream.read(tls.Version); + try stream.expectFragment(.handshake, .client_hello); + + _ = try stream.read(tls.Version); var client_random: [32]u8 = undefined; try reader.readNoEof(&client_random); var session_id: [tls.ClientHello.session_id_max_len]u8 = undefined; - const session_id_len = try self.stream.read(u8); - if (session_id_len > tls.ClientHello.session_id_max_len) return error.TlsUnexpectedMessage; + const session_id_len = try stream.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) + return stream.writeError(.illegal_parameter); try reader.readNoEof(session_id[0..session_id_len]); const cipher_suite: tls.CipherSuite = brk: { - var cipher_suite_iter = try self.stream.iterator(u16, tls.CipherSuite); + var cipher_suite_iter = try stream.iterator(u16, tls.CipherSuite); var res: ?tls.CipherSuite = null; while (try cipher_suite_iter.next()) |suite| { for (self.options.cipher_suites) |s| { if (s == suite and res == null) res = s; } } - if (res == null) return error.TlsUnexpectedMessage; + if (res == null) return stream.writeError(.illegal_parameter); break :brk res.?; }; - try self.stream.transcript_hash.setActive(cipher_suite); + stream.transcript_hash.setActive(cipher_suite); { var compression_methods: [2]u8 = undefined; try reader.readNoEof(&compression_methods); - if (!std.mem.eql(u8, &compression_methods, &[_]u8{ 1, 0 })) return error.TlsUnexpectedMessage; + if (!std.mem.eql(u8, &compression_methods, &[_]u8{ 1, 0 })) + return stream.writeError(.illegal_parameter); } var tls_version: ?tls.Version = null; @@ -86,21 +90,21 @@ pub fn Server(comptime StreamType: type) type { var ec_point_format: ?tls.EcPointFormat = null; var sig_scheme: ?tls.SignatureScheme = null; - var extension_iter = try self.stream.extensions(); + var extension_iter = try stream.extensions(); while (try extension_iter.next()) |ext| { switch (ext.type) { .supported_versions => { - if (tls_version != null) return error.TlsUnexpectedMessage; - var versions_iter = try self.stream.iterator(u8, tls.Version); + if (tls_version != null) return stream.writeError(.illegal_parameter); + var versions_iter = try stream.iterator(u8, tls.Version); while (try versions_iter.next()) |v| { if (v == .tls_1_3) tls_version = v; } }, // TODO: use supported_groups instead .key_share => { - if (key_share != null) return error.TlsUnexpectedMessage; + if (key_share != null) return stream.writeError(.illegal_parameter); - var key_share_iter = try self.stream.iterator(u16, tls.KeyShare); + var key_share_iter = try stream.iterator(u16, tls.KeyShare); while (try key_share_iter.next()) |ks| { switch (ks) { .x25519 => key_share = ks, @@ -109,13 +113,13 @@ pub fn Server(comptime StreamType: type) type { } }, .ec_point_formats => { - var format_iter = try self.stream.iterator(u8, tls.EcPointFormat); + var format_iter = try stream.iterator(u8, tls.EcPointFormat); while (try format_iter.next()) |f| { if (f == .uncompressed) ec_point_format = .uncompressed; } }, .signature_algorithms => { - var algos_iter = try self.stream.iterator(u16, tls.SignatureScheme); + var algos_iter = try stream.iterator(u16, tls.SignatureScheme); while (try algos_iter.next()) |algo| { if (algo == .rsa_pss_rsae_sha256) sig_scheme = algo; } @@ -126,9 +130,9 @@ pub fn Server(comptime StreamType: type) type { } } - if (tls_version == null) return error.TlsUnexpectedMessage; - if (key_share == null) return error.TlsUnexpectedMessage; - if (ec_point_format == null) return error.TlsUnexpectedMessage; + if (tls_version == null) return stream.writeError(.protocol_version); + if (key_share == null) return stream.writeError(.missing_extension); + if (ec_point_format == null) return stream.writeError(.missing_extension); return .{ .random = client_random, @@ -142,6 +146,8 @@ pub fn Server(comptime StreamType: type) type { /// `key_pair`'s active member MUST match `client_hello.key_share` pub fn send_hello(self: *Self, client_hello: ClientHello, key_pair: KeyPair) !void { + var stream = &self.stream; + const hello = tls.ServerHello{ .random = key_pair.random, .session_id = &client_hello.session_id, @@ -151,15 +157,15 @@ pub fn Server(comptime StreamType: type) type { .{ .key_share = &[_]tls.KeyShare{key_pair.pair.toKeyShare()} }, }, }; - self.stream.version = .tls_1_2; - _ = try self.stream.write(tls.Handshake, .{ .server_hello = hello }); - try self.stream.flush(); + stream.version = .tls_1_2; + _ = try stream.write(tls.Handshake, .{ .server_hello = hello }); + try stream.flush(); // > if the client sends a non-empty session ID, the server MUST send the change_cipher_spec if (hello.session_id.len > 0) { - self.stream.content_type = .change_cipher_spec; - _ = try self.stream.write(tls.ChangeCipherSpec, .change_cipher_spec); - try self.stream.flush(); + stream.content_type = .change_cipher_spec; + _ = try stream.write(tls.ChangeCipherSpec, .change_cipher_spec); + try stream.flush(); } const shared_key = switch (client_hello.key_share) { @@ -169,7 +175,7 @@ pub fn Server(comptime StreamType: type) type { const shared_point = T.X25519.scalarmult( ks.x25519, pair.x25519.secret_key, - ) catch return error.TlsDecryptFailure; + ) catch return stream.writeError(.decrypt_error); // pair.kyber768d00.secret_key // ks.kyber768d00 const encaps = ks.kyber768d00.encaps(null).ciphertext; @@ -180,94 +186,75 @@ pub fn Server(comptime StreamType: type) type { const shared_point = tls.NamedGroupT(.x25519).scalarmult( key_pair.pair.x25519.secret_key, ks, - ) catch return error.TlsDecryptFailure; + ) catch return stream.writeError(.decrypt_error); break :brk &shared_point; }, .secp256r1 => |ks| brk: { const mul = ks.p.mulPublic( key_pair.pair.secp256r1.secret_key.bytes, .big, - ) catch - return error.TlsDecryptFailure; + ) catch return stream.writeError(.decrypt_error); break :brk &mul.affineCoordinates().x.toBytes(.big); }, - else => return error.TlsIllegalParameter, + else => return stream.writeError(.illegal_parameter), }; - const hello_hash = self.stream.transcript_hash.peek(); - self.stream.handshake_cipher = try tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, hello_hash); - - self.stream.content_type = .handshake; - _ = try self.stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); - try self.stream.flush(); - - _ = try self.stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); - try self.stream.flush(); - - // RFC 8446 S4.4.3 - // const signature_content = [_]u8{0x20} ** 64 - // ++ "TLS 1.3, server CertificateVerify\x00".* - // ++ self.stream.transcript_hash.peek() - // ; - - // const cert = Certificate{ .buffer = self.options.certificate.entries[0].data, .index = 0 }; - // const parsed = try cert.parse(); - // const pub_key = parsed.pubKey(); - - // switch (client_hello.sig_scheme) { - // .rsa_pss_rsae_sha256 => { - // const rsa = Certificate.rsa; - // const components = try rsa.PublicKey.parseDer(pub_key); - // const exponent = components.exponent; - // const modulus = components.modulus; - // switch (modulus.len) { - // inline 128, 256, 512 => |modulus_len| { - // const key = try rsa.PublicKey.fromBytes(exponent, modulus); - // const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - // try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); - // }, - // else => { - // return error.TlsBadRsaSignatureBitCount; - // }, - // } - // }, - // else => {} - // } + const hello_hash = stream.transcript_hash.peek(); + stream.handshake_cipher = tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, hello_hash) catch return stream.writeError(.illegal_parameter); + + stream.content_type = .handshake; + _ = try stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); + try stream.flush(); + + _ = try stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); + try stream.flush(); } - pub fn send_handshake_finish(self: *Self) !void { - const verify_data = switch (self.stream.handshake_cipher.?) { + pub fn send_finished(self: *Self) !void { + var stream = &self.stream; + const verify_data = switch (stream.handshake_cipher.?) { inline .aes_256_gcm_sha384, => |v| brk: { const T = @TypeOf(v); const secret = v.server_finished_key; - const transcript_hash = self.stream.transcript_hash.peek(); + const transcript_hash = stream.transcript_hash.peek(); break :brk tls.hmac(T.Hmac, transcript_hash, secret); }, - else => return error.TlsDecryptFailure, + else => return stream.writeError(.illegal_parameter), }; - _ = try self.stream.write(tls.Handshake, .{ .finished = &verify_data }); - try self.stream.flush(); + _ = try stream.write(tls.Handshake, .{ .finished = &verify_data }); + try stream.flush(); - self.stream.application_cipher = tls.ApplicationCipher.init( - self.stream.handshake_cipher.?, - self.stream.transcript_hash.peek(), + stream.application_cipher = tls.ApplicationCipher.init( + stream.handshake_cipher.?, + stream.transcript_hash.peek(), ); } - pub fn recv_finish(self: *Self) !void { - try self.stream.expectFragment(.handshake, .finished); - var reader = self.stream.reader(); + pub fn recv_finished(self: *Self) !void { + var stream = &self.stream; + var reader = stream.reader(); + const cipher = stream.handshake_cipher.?; + + const expected = switch (cipher) { + .empty_renegotiation_info_scsv => return stream.writeError(.decode_error), + inline else => |p| brk: { + const P = @TypeOf(p); + const digest = stream.transcript_hash.peek(); + break :brk &tls.hmac(P.Hmac, digest, p.client_finished_key); + } + }; + + try stream.expectFragment(.handshake, .finished); + const actual = stream.view; + try reader.skipBytes(stream.view.len, .{}); - var verify_data: [48]u8 = undefined; - try reader.readNoEof(&verify_data); + if (!mem.eql(u8, expected, actual)) return stream.writeError(.decode_error); - // TODO: verify - _ = .{ verify_data }; - self.stream.content_type = .application_data; - self.stream.handshake_type = null; + stream.content_type = .application_data; + stream.handshake_type = null; } }; } From d7e3a3f3d63b9576f77699afcdbe988f84b95acd Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Tue, 12 Mar 2024 18:16:40 -0400 Subject: [PATCH 05/17] pretty up test --- lib/std/crypto/tls.zig | 406 +++++++++++++++++++---------------------- 1 file changed, 190 insertions(+), 216 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 1692a50dee3c..9fdcb4ce98eb 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -1894,6 +1894,14 @@ const TestStream = struct { _ = try self.readAll(out); self.buffer.read_index = read_index; } + + pub fn expect(self: *Self, expected: []const u8) !void { + var tmp_buf: [Plaintext.max_length]u8 = undefined; + const buf = tmp_buf[0..self.buffer.len()]; + try self.peek(buf); + + try std.testing.expectEqualSlices(u8, expected, buf); + } }; const TestHasher = struct { @@ -2030,12 +2038,7 @@ test "tls client and server handshake, data, and close_notify" { try client.stream.flush(); } - var tmp_buf: [Plaintext.max_length]u8 = undefined; - { - const buf = tmp_buf[0..inner_stream.buffer.len()]; - try inner_stream.peek(buf); - - const expected = [_]u8{ + try inner_stream.expect([_]u8{ 0x16, // handshake 0x03, 0x01, // tls 1.0 (lie for compat) 0x00, 0xf8, // handshake len @@ -2115,9 +2118,8 @@ test "tls client and server handshake, data, and close_notify" { 0x00, 0x24, // key shares len 0x00, 0x1d, // curve 25519 0x00, 0x20, // key len - } ++ key_pairs.x25519.public_key; - try std.testing.expectEqualSlices(u8, expected, buf); - } + } ++ key_pairs.x25519.public_key + ); const client_hello = try server.recv_hello(); try std.testing.expectEqualSlices(u8, &client_random, &client_hello.random); @@ -2139,157 +2141,150 @@ test "tls client and server handshake, data, and close_notify" { try server.stream.flush(); try server.send_finished(); - { - const buf = tmp_buf[0..inner_stream.buffer.len()]; - try inner_stream.peek(buf); - - const expected = [_]u8{ - 0x16, // handshake - 0x03, 0x03, // tls 1.2 - 0x00, 0x7a, // Handshake len - 0x02, // server hello - 0x00, 0x00, 0x76, // server hello len - 0x03, 0x03, // tls 1.2 - } ++ server_key_pair.random ++ [_]u8{session_id.len} ++ session_id ++ - [_]u8{ - 0x13, 0x02, // aes_256_gcm_sha384 - 0x00, // compression method - 0x00, 0x2e, // extensions len - 0x00, 0x2b, // supported versions - 0x00, 0x02, // ext len - 0x03, 0x04, // tls 1.3 - 0x00, 0x33, // key share - 0x00, 0x24, // ext len - 0x00, 0x1d, // x25519 - 0x00, 0x20, // key len - 0x9f, 0xd7, - 0xad, 0x6d, - 0xcf, 0xf4, - 0x29, 0x8d, - 0xd3, 0xf9, - 0x6d, 0x5b, - 0x1b, 0x2a, - 0xf9, 0x10, - 0xa0, 0x53, - 0x5b, 0x14, - 0x88, 0xd7, - 0xf8, 0xfa, - 0xbb, 0x34, - 0x9a, 0x98, - 0x28, 0x80, - 0xb6, 0x15, // key - } ++ - [_]u8{ - 0x14, // ChangeCipherSpec - 0x03, 0x03, // tls 1.2 - 0x00, 0x01, // len - 0x01, // .change_cipher_spec - } ++ [_]u8{ - 0x17, // application data (lie for tls 1.2 compat) - 0x03, 0x03, // tls 1.2 - 0x00, 0x17, // application data len - 0x6b, 0xe0, 0x2f, 0x9d, 0xa7, 0xc2, // encrypted data (empty EncryptedExtensions message) - 0xdc, // encrypted data type (handshake) - 0x9d, 0xde, 0xf5, 0x6f, 0x24, 0x68, 0xb9, 0x0a, // auth tag - 0xdf, 0xa2, 0x51, 0x01, 0xab, 0x03, 0x44, 0xae, // auth tag - } ++ [_]u8{ - 0x17, // application data (lie for tls 1.2 compat) - 0x03, 0x03, // tls 1.2 - 0x03, 0x43, // application data len - 0xba, 0xf0, 0x0a, 0x9b, 0xe5, 0x0f, 0x3f, 0x23, 0x07, 0xe7, 0x26, 0xed, 0xcb, 0xda, 0xcb, 0xe4, - 0xb1, 0x86, 0x16, 0x44, 0x9d, 0x46, 0xc6, 0x20, 0x7a, 0xf6, 0xe9, 0x95, 0x3e, 0xe5, 0xd2, 0x41, - 0x1b, 0xa6, 0x5d, 0x31, 0xfe, 0xaf, 0x4f, 0x78, 0x76, 0x4f, 0x2d, 0x69, 0x39, 0x87, 0x18, 0x6c, - 0xc0, 0x13, 0x29, 0xc1, 0x87, 0xa5, 0xe4, 0x60, 0x8e, 0x8d, 0x27, 0xb3, 0x18, 0xe9, 0x8d, 0xd9, - 0x47, 0x69, 0xf7, 0x73, 0x9c, 0xe6, 0x76, 0x83, 0x92, 0xca, 0xca, 0x8d, 0xcc, 0x59, 0x7d, 0x77, - 0xec, 0x0d, 0x12, 0x72, 0x23, 0x37, 0x85, 0xf6, 0xe6, 0x9d, 0x6f, 0x43, 0xef, 0xfa, 0x8e, 0x79, - 0x05, 0xed, 0xfd, 0xc4, 0x03, 0x7e, 0xee, 0x59, 0x33, 0xe9, 0x90, 0xa7, 0x97, 0x2f, 0x20, 0x69, - 0x13, 0xa3, 0x1e, 0x8d, 0x04, 0x93, 0x13, 0x66, 0xd3, 0xd8, 0xbc, 0xd6, 0xa4, 0xa4, 0xd6, 0x47, - 0xdd, 0x4b, 0xd8, 0x0b, 0x0f, 0xf8, 0x63, 0xce, 0x35, 0x54, 0x83, 0x3d, 0x74, 0x4c, 0xf0, 0xe0, - 0xb9, 0xc0, 0x7c, 0xae, 0x72, 0x6d, 0xd2, 0x3f, 0x99, 0x53, 0xdf, 0x1f, 0x1c, 0xe3, 0xac, 0xeb, - 0x3b, 0x72, 0x30, 0x87, 0x1e, 0x92, 0x31, 0x0c, 0xfb, 0x2b, 0x09, 0x84, 0x86, 0xf4, 0x35, 0x38, - 0xf8, 0xe8, 0x2d, 0x84, 0x04, 0xe5, 0xc6, 0xc2, 0x5f, 0x66, 0xa6, 0x2e, 0xbe, 0x3c, 0x5f, 0x26, - 0x23, 0x26, 0x40, 0xe2, 0x0a, 0x76, 0x91, 0x75, 0xef, 0x83, 0x48, 0x3c, 0xd8, 0x1e, 0x6c, 0xb1, - 0x6e, 0x78, 0xdf, 0xad, 0x4c, 0x1b, 0x71, 0x4b, 0x04, 0xb4, 0x5f, 0x6a, 0xc8, 0xd1, 0x06, 0x5a, - 0xd1, 0x8c, 0x13, 0x45, 0x1c, 0x90, 0x55, 0xc4, 0x7d, 0xa3, 0x00, 0xf9, 0x35, 0x36, 0xea, 0x56, - 0xf5, 0x31, 0x98, 0x6d, 0x64, 0x92, 0x77, 0x53, 0x93, 0xc4, 0xcc, 0xb0, 0x95, 0x46, 0x70, 0x92, - 0xa0, 0xec, 0x0b, 0x43, 0xed, 0x7a, 0x06, 0x87, 0xcb, 0x47, 0x0c, 0xe3, 0x50, 0x91, 0x7b, 0x0a, - 0xc3, 0x0c, 0x6e, 0x5c, 0x24, 0x72, 0x5a, 0x78, 0xc4, 0x5f, 0x9f, 0x5f, 0x29, 0xb6, 0x62, 0x68, - 0x67, 0xf6, 0xf7, 0x9c, 0xe0, 0x54, 0x27, 0x35, 0x47, 0xb3, 0x6d, 0xf0, 0x30, 0xbd, 0x24, 0xaf, - 0x10, 0xd6, 0x32, 0xdb, 0xa5, 0x4f, 0xc4, 0xe8, 0x90, 0xbd, 0x05, 0x86, 0x92, 0x8c, 0x02, 0x06, - 0xca, 0x2e, 0x28, 0xe4, 0x4e, 0x22, 0x7a, 0x2d, 0x50, 0x63, 0x19, 0x59, 0x35, 0xdf, 0x38, 0xda, - 0x89, 0x36, 0x09, 0x2e, 0xef, 0x01, 0xe8, 0x4c, 0xad, 0x2e, 0x49, 0xd6, 0x2e, 0x47, 0x0a, 0x6c, - 0x77, 0x45, 0xf6, 0x25, 0xec, 0x39, 0xe4, 0xfc, 0x23, 0x32, 0x9c, 0x79, 0xd1, 0x17, 0x28, 0x76, - 0x80, 0x7c, 0x36, 0xd7, 0x36, 0xba, 0x42, 0xbb, 0x69, 0xb0, 0x04, 0xff, 0x55, 0xf9, 0x38, 0x50, - 0xdc, 0x33, 0xc1, 0xf9, 0x8a, 0xbb, 0x92, 0x85, 0x83, 0x24, 0xc7, 0x6f, 0xf1, 0xeb, 0x08, 0x5d, - 0xb3, 0xc1, 0xfc, 0x50, 0xf7, 0x4e, 0xc0, 0x44, 0x42, 0xe6, 0x22, 0x97, 0x3e, 0xa7, 0x07, 0x43, - 0x41, 0x87, 0x94, 0xc3, 0x88, 0x14, 0x0b, 0xb4, 0x92, 0xd6, 0x29, 0x4a, 0x05, 0x40, 0xe5, 0xa5, - 0x9c, 0xfa, 0xe6, 0x0b, 0xa0, 0xf1, 0x48, 0x99, 0xfc, 0xa7, 0x13, 0x33, 0x31, 0x5e, 0xa0, 0x83, - 0xa6, 0x8e, 0x1d, 0x7c, 0x1e, 0x4c, 0xdc, 0x2f, 0x56, 0xbc, 0xd6, 0x11, 0x96, 0x81, 0xa4, 0xad, - 0xbc, 0x1b, 0xbf, 0x42, 0xaf, 0xd8, 0x06, 0xc3, 0xcb, 0xd4, 0x2a, 0x07, 0x6f, 0x54, 0x5d, 0xee, - 0x4e, 0x11, 0x8d, 0x0b, 0x39, 0x67, 0x54, 0xbe, 0x2b, 0x04, 0x2a, 0x68, 0x5d, 0xd4, 0x72, 0x7e, - 0x89, 0xc0, 0x38, 0x6a, 0x94, 0xd3, 0xcd, 0x6e, 0xcb, 0x98, 0x20, 0xe9, 0xd4, 0x9a, 0xfe, 0xed, - 0x66, 0xc4, 0x7e, 0x6f, 0xc2, 0x43, 0xea, 0xbe, 0xbb, 0xcb, 0x0b, 0x02, 0x45, 0x38, 0x77, 0xf5, - 0xac, 0x5d, 0xbf, 0xbd, 0xf8, 0xdb, 0x10, 0x52, 0xa3, 0xc9, 0x94, 0xb2, 0x24, 0xcd, 0x9a, 0xaa, - 0xf5, 0x6b, 0x02, 0x6b, 0xb9, 0xef, 0xa2, 0xe0, 0x13, 0x02, 0xb3, 0x64, 0x01, 0xab, 0x64, 0x94, - 0xe7, 0x01, 0x8d, 0x6e, 0x5b, 0x57, 0x3b, 0xd3, 0x8b, 0xce, 0xf0, 0x23, 0xb1, 0xfc, 0x92, 0x94, - 0x6b, 0xbc, 0xa0, 0x20, 0x9c, 0xa5, 0xfa, 0x92, 0x6b, 0x49, 0x70, 0xb1, 0x00, 0x91, 0x03, 0x64, - 0x5c, 0xb1, 0xfc, 0xfe, 0x55, 0x23, 0x11, 0xff, 0x73, 0x05, 0x58, 0x98, 0x43, 0x70, 0x03, 0x8f, - 0xd2, 0xcc, 0xe2, 0xa9, 0x1f, 0xc7, 0x4d, 0x6f, 0x3e, 0x3e, 0xa9, 0xf8, 0x43, 0xee, 0xd3, 0x56, - 0xf6, 0xf8, 0x2d, 0x35, 0xd0, 0x3b, 0xc2, 0x4b, 0x81, 0xb5, 0x8c, 0xeb, 0x1a, 0x43, 0xec, 0x94, - 0x37, 0xe6, 0xf1, 0xe5, 0x0e, 0xb6, 0xf5, 0x55, 0xe3, 0x21, 0xfd, 0x67, 0xc8, 0x33, 0x2e, 0xb1, - 0xb8, 0x32, 0xaa, 0x8d, 0x79, 0x5a, 0x27, 0xd4, 0x79, 0xc6, 0xe2, 0x7d, 0x5a, 0x61, 0x03, 0x46, - 0x83, 0x89, 0x19, 0x03, 0xf6, 0x64, 0x21, 0xd0, 0x94, 0xe1, 0xb0, 0x0a, 0x9a, 0x13, 0x8d, 0x86, - 0x1e, 0x6f, 0x78, 0xa2, 0x0a, 0xd3, 0xe1, 0x58, 0x00, 0x54, 0xd2, 0xe3, 0x05, 0x25, 0x3c, 0x71, - 0x3a, 0x02, 0xfe, 0x1e, 0x28, 0xde, 0xee, 0x73, 0x36, 0x24, 0x6f, 0x6a, 0xe3, 0x43, 0x31, 0x80, - 0x6b, 0x46, 0xb4, 0x7b, 0x83, 0x3c, 0x39, 0xb9, 0xd3, 0x1c, 0xd3, 0x00, 0xc2, 0xa6, 0xed, 0x83, - 0x13, 0x99, 0x77, 0x6d, 0x07, 0xf5, 0x70, 0xea, 0xf0, 0x05, 0x9a, 0x2c, 0x68, 0xa5, 0xf3, 0xae, - 0x16, 0xb6, 0x17, 0x40, 0x4a, 0xf7, 0xb7, 0x23, 0x1a, 0x4d, 0x94, 0x27, 0x58, 0xfc, 0x02, 0x0b, - 0x3f, 0x23, 0xee, 0x8c, 0x15, 0xe3, 0x60, 0x44, 0xcf, 0xd6, 0x7c, 0xd6, 0x40, 0x99, 0x3b, 0x16, - 0x20, 0x75, 0x97, 0xfb, 0xf3, 0x85, 0xea, 0x7a, 0x4d, 0x99, 0xe8, 0xd4, 0x56, 0xff, 0x83, 0xd4, - 0x1f, 0x7b, 0x8b, 0x4f, 0x06, 0x9b, 0x02, 0x8a, 0x2a, 0x63, 0xa9, 0x19, 0xa7, 0x0e, 0x3a, 0x10, - 0xe3, 0x08, // encrypted cert - 0x41, // encrypted data type (Certificate) - 0x58, 0xfa, 0xa5, 0xba, 0xfa, 0x30, 0x18, 0x6c, // auth tag - 0x6b, 0x2f, 0x23, 0x8e, 0xb5, 0x30, 0xc7, 0x3e, // auth tag - } ++ [_]u8{ - 0x17, // application data (lie for tls 1.2 compat) - 0x03, 0x03, // tls 1.2 - 0x01, 0x19, // application data len - 0x73, 0x71, 0x9f, 0xce, 0x07, 0xec, 0x2f, 0x6d, 0x3b, 0xba, 0x02, 0x92, 0xa0, 0xd4, 0x0b, 0x27, - 0x70, 0xc0, 0x6a, 0x27, 0x17, 0x99, 0xa5, 0x33, 0x14, 0xf6, 0xf7, 0x7f, 0xc9, 0x5c, 0x5f, 0xe7, - 0xb9, 0xa4, 0x32, 0x9f, 0xd9, 0x54, 0x8c, 0x67, 0x0e, 0xbe, 0xea, 0x2f, 0x2d, 0x5c, 0x35, 0x1d, - 0xd9, 0x35, 0x6e, 0xf2, 0xdc, 0xd5, 0x2e, 0xb1, 0x37, 0xbd, 0x3a, 0x67, 0x65, 0x22, 0xf8, 0xcd, - 0x0f, 0xb7, 0x56, 0x07, 0x89, 0xad, 0x7b, 0x0e, 0x3c, 0xab, 0xa2, 0xe3, 0x7e, 0x6b, 0x41, 0x99, - 0xc6, 0x79, 0x3b, 0x33, 0x46, 0xed, 0x46, 0xcf, 0x74, 0x0a, 0x9f, 0xa1, 0xfe, 0xc4, 0x14, 0xdc, - 0x71, 0x5c, 0x41, 0x5c, 0x60, 0xe5, 0x75, 0x70, 0x3c, 0xe6, 0xa3, 0x4b, 0x70, 0xb5, 0x19, 0x1a, - 0xa6, 0xa6, 0x1a, 0x18, 0xfa, 0xff, 0x21, 0x6c, 0x68, 0x7a, 0xd8, 0xd1, 0x7e, 0x12, 0xa7, 0xe9, - 0x99, 0x15, 0xa6, 0x11, 0xbf, 0xc1, 0xa2, 0xbe, 0xfc, 0x15, 0xe6, 0xe9, 0x4d, 0x78, 0x46, 0x42, - 0xe6, 0x82, 0xfd, 0x17, 0x38, 0x2a, 0x34, 0x8c, 0x30, 0x10, 0x56, 0xb9, 0x40, 0xc9, 0x84, 0x72, - 0x00, 0x40, 0x8b, 0xec, 0x56, 0xc8, 0x1e, 0xa3, 0xd7, 0x21, 0x7a, 0xb8, 0xe8, 0x5a, 0x88, 0x71, - 0x53, 0x95, 0x89, 0x9c, 0x90, 0x58, 0x7f, 0x72, 0xe8, 0xdd, 0xd7, 0x4b, 0x26, 0xd8, 0xed, 0xc1, - 0xc7, 0xc8, 0x37, 0xd9, 0xf2, 0xeb, 0xbc, 0x26, 0x09, 0x62, 0x21, 0x90, 0x38, 0xb0, 0x56, 0x54, - 0xa6, 0x3a, 0x0b, 0x12, 0x99, 0x9b, 0x4a, 0x83, 0x06, 0xa3, 0xdd, 0xcc, 0x0e, 0x17, 0xc5, 0x3b, - 0xa8, 0xf9, 0xc8, 0x03, 0x63, 0xf7, 0x84, 0x13, 0x54, 0xd2, 0x91, 0xb4, 0xac, 0xe0, 0xc0, 0xf3, - 0x30, 0xc0, 0xfc, 0xd5, 0xaa, 0x9d, 0xee, 0xf9, 0x69, 0xae, 0x8a, 0xb2, 0xd9, 0x8d, 0xa8, 0x8e, - 0xbb, 0x6e, 0xa8, 0x0a, 0x3a, 0x11, 0xf0, 0x0e, // encrypted signature_verify - 0xa2, // encrypted data type (SignatureVerify) - 0x96, 0xa3, 0x23, 0x23, 0x67, 0xff, 0x07, 0x5e, // auth tag - 0x1c, 0x66, 0xdd, 0x9c, 0xbe, 0xdc, 0x47, 0x13, // auth tag - } ++ [_]u8{ - 0x17, // application data (lie for tls 1.2 compat) - 0x03, 0x03, // tls 1.2 - 0x00, 0x45, // application data len - 0x10, 0x61, 0xde, 0x27, 0xe5, 0x1c, 0x2c, 0x9f, 0x34, 0x29, 0x11, 0x80, 0x6f, 0x28, 0x2b, 0x71, - 0x0c, 0x10, 0x63, 0x2c, 0xa5, 0x00, 0x67, 0x55, 0x88, 0x0d, 0xbf, 0x70, 0x06, 0x00, 0x2d, 0x0e, - 0x84, 0xfe, 0xd9, 0xad, 0xf2, 0x7a, 0x43, 0xb5, 0x19, 0x23, 0x03, 0xe4, 0xdf, 0x5c, 0x28, 0x5d, - 0x58, 0xe3, 0xc7, 0x62, - 0x24, // encrypted data type (finished) - 0x07, 0x84, 0x40, 0xc0, 0x74, 0x23, 0x74, 0x74, // auth tag - 0x4a, 0xec, 0xf2, 0x8c, 0xf3, 0x18, 0x2f, 0xd0, // auth tag - } - ; - try std.testing.expectEqualSlices(u8, &expected, buf); - } + try inner_stream.expect(&([_]u8{ + 0x16, // handshake + 0x03, 0x03, // tls 1.2 + 0x00, 0x7a, // Handshake len + 0x02, // server hello + 0x00, 0x00, 0x76, // server hello len + 0x03, 0x03, // tls 1.2 + } ++ server_key_pair.random ++ [_]u8{session_id.len} ++ session_id ++ + [_]u8{ + 0x13, 0x02, // aes_256_gcm_sha384 + 0x00, // compression method + 0x00, 0x2e, // extensions len + 0x00, 0x2b, // supported versions + 0x00, 0x02, // ext len + 0x03, 0x04, // tls 1.3 + 0x00, 0x33, // key share + 0x00, 0x24, // ext len + 0x00, 0x1d, // x25519 + 0x00, 0x20, // key len + 0x9f, 0xd7, + 0xad, 0x6d, + 0xcf, 0xf4, + 0x29, 0x8d, + 0xd3, 0xf9, + 0x6d, 0x5b, + 0x1b, 0x2a, + 0xf9, 0x10, + 0xa0, 0x53, + 0x5b, 0x14, + 0x88, 0xd7, + 0xf8, 0xfa, + 0xbb, 0x34, + 0x9a, 0x98, + 0x28, 0x80, + 0xb6, 0x15, // key + } ++ + [_]u8{ + 0x14, // ChangeCipherSpec + 0x03, 0x03, // tls 1.2 + 0x00, 0x01, // len + 0x01, // .change_cipher_spec + } ++ [_]u8{ + 0x17, // application data (lie for tls 1.2 compat) + 0x03, 0x03, // tls 1.2 + 0x00, 0x17, // application data len + 0x6b, 0xe0, 0x2f, 0x9d, 0xa7, 0xc2, // encrypted data (empty EncryptedExtensions message) + 0xdc, // encrypted data type (handshake) + 0x9d, 0xde, 0xf5, 0x6f, 0x24, 0x68, 0xb9, 0x0a, // auth tag + 0xdf, 0xa2, 0x51, 0x01, 0xab, 0x03, 0x44, 0xae, // auth tag + } ++ [_]u8{ + 0x17, // application data (lie for tls 1.2 compat) + 0x03, 0x03, // tls 1.2 + 0x03, 0x43, // application data len + 0xba, 0xf0, 0x0a, 0x9b, 0xe5, 0x0f, 0x3f, 0x23, 0x07, 0xe7, 0x26, 0xed, 0xcb, 0xda, 0xcb, 0xe4, + 0xb1, 0x86, 0x16, 0x44, 0x9d, 0x46, 0xc6, 0x20, 0x7a, 0xf6, 0xe9, 0x95, 0x3e, 0xe5, 0xd2, 0x41, + 0x1b, 0xa6, 0x5d, 0x31, 0xfe, 0xaf, 0x4f, 0x78, 0x76, 0x4f, 0x2d, 0x69, 0x39, 0x87, 0x18, 0x6c, + 0xc0, 0x13, 0x29, 0xc1, 0x87, 0xa5, 0xe4, 0x60, 0x8e, 0x8d, 0x27, 0xb3, 0x18, 0xe9, 0x8d, 0xd9, + 0x47, 0x69, 0xf7, 0x73, 0x9c, 0xe6, 0x76, 0x83, 0x92, 0xca, 0xca, 0x8d, 0xcc, 0x59, 0x7d, 0x77, + 0xec, 0x0d, 0x12, 0x72, 0x23, 0x37, 0x85, 0xf6, 0xe6, 0x9d, 0x6f, 0x43, 0xef, 0xfa, 0x8e, 0x79, + 0x05, 0xed, 0xfd, 0xc4, 0x03, 0x7e, 0xee, 0x59, 0x33, 0xe9, 0x90, 0xa7, 0x97, 0x2f, 0x20, 0x69, + 0x13, 0xa3, 0x1e, 0x8d, 0x04, 0x93, 0x13, 0x66, 0xd3, 0xd8, 0xbc, 0xd6, 0xa4, 0xa4, 0xd6, 0x47, + 0xdd, 0x4b, 0xd8, 0x0b, 0x0f, 0xf8, 0x63, 0xce, 0x35, 0x54, 0x83, 0x3d, 0x74, 0x4c, 0xf0, 0xe0, + 0xb9, 0xc0, 0x7c, 0xae, 0x72, 0x6d, 0xd2, 0x3f, 0x99, 0x53, 0xdf, 0x1f, 0x1c, 0xe3, 0xac, 0xeb, + 0x3b, 0x72, 0x30, 0x87, 0x1e, 0x92, 0x31, 0x0c, 0xfb, 0x2b, 0x09, 0x84, 0x86, 0xf4, 0x35, 0x38, + 0xf8, 0xe8, 0x2d, 0x84, 0x04, 0xe5, 0xc6, 0xc2, 0x5f, 0x66, 0xa6, 0x2e, 0xbe, 0x3c, 0x5f, 0x26, + 0x23, 0x26, 0x40, 0xe2, 0x0a, 0x76, 0x91, 0x75, 0xef, 0x83, 0x48, 0x3c, 0xd8, 0x1e, 0x6c, 0xb1, + 0x6e, 0x78, 0xdf, 0xad, 0x4c, 0x1b, 0x71, 0x4b, 0x04, 0xb4, 0x5f, 0x6a, 0xc8, 0xd1, 0x06, 0x5a, + 0xd1, 0x8c, 0x13, 0x45, 0x1c, 0x90, 0x55, 0xc4, 0x7d, 0xa3, 0x00, 0xf9, 0x35, 0x36, 0xea, 0x56, + 0xf5, 0x31, 0x98, 0x6d, 0x64, 0x92, 0x77, 0x53, 0x93, 0xc4, 0xcc, 0xb0, 0x95, 0x46, 0x70, 0x92, + 0xa0, 0xec, 0x0b, 0x43, 0xed, 0x7a, 0x06, 0x87, 0xcb, 0x47, 0x0c, 0xe3, 0x50, 0x91, 0x7b, 0x0a, + 0xc3, 0x0c, 0x6e, 0x5c, 0x24, 0x72, 0x5a, 0x78, 0xc4, 0x5f, 0x9f, 0x5f, 0x29, 0xb6, 0x62, 0x68, + 0x67, 0xf6, 0xf7, 0x9c, 0xe0, 0x54, 0x27, 0x35, 0x47, 0xb3, 0x6d, 0xf0, 0x30, 0xbd, 0x24, 0xaf, + 0x10, 0xd6, 0x32, 0xdb, 0xa5, 0x4f, 0xc4, 0xe8, 0x90, 0xbd, 0x05, 0x86, 0x92, 0x8c, 0x02, 0x06, + 0xca, 0x2e, 0x28, 0xe4, 0x4e, 0x22, 0x7a, 0x2d, 0x50, 0x63, 0x19, 0x59, 0x35, 0xdf, 0x38, 0xda, + 0x89, 0x36, 0x09, 0x2e, 0xef, 0x01, 0xe8, 0x4c, 0xad, 0x2e, 0x49, 0xd6, 0x2e, 0x47, 0x0a, 0x6c, + 0x77, 0x45, 0xf6, 0x25, 0xec, 0x39, 0xe4, 0xfc, 0x23, 0x32, 0x9c, 0x79, 0xd1, 0x17, 0x28, 0x76, + 0x80, 0x7c, 0x36, 0xd7, 0x36, 0xba, 0x42, 0xbb, 0x69, 0xb0, 0x04, 0xff, 0x55, 0xf9, 0x38, 0x50, + 0xdc, 0x33, 0xc1, 0xf9, 0x8a, 0xbb, 0x92, 0x85, 0x83, 0x24, 0xc7, 0x6f, 0xf1, 0xeb, 0x08, 0x5d, + 0xb3, 0xc1, 0xfc, 0x50, 0xf7, 0x4e, 0xc0, 0x44, 0x42, 0xe6, 0x22, 0x97, 0x3e, 0xa7, 0x07, 0x43, + 0x41, 0x87, 0x94, 0xc3, 0x88, 0x14, 0x0b, 0xb4, 0x92, 0xd6, 0x29, 0x4a, 0x05, 0x40, 0xe5, 0xa5, + 0x9c, 0xfa, 0xe6, 0x0b, 0xa0, 0xf1, 0x48, 0x99, 0xfc, 0xa7, 0x13, 0x33, 0x31, 0x5e, 0xa0, 0x83, + 0xa6, 0x8e, 0x1d, 0x7c, 0x1e, 0x4c, 0xdc, 0x2f, 0x56, 0xbc, 0xd6, 0x11, 0x96, 0x81, 0xa4, 0xad, + 0xbc, 0x1b, 0xbf, 0x42, 0xaf, 0xd8, 0x06, 0xc3, 0xcb, 0xd4, 0x2a, 0x07, 0x6f, 0x54, 0x5d, 0xee, + 0x4e, 0x11, 0x8d, 0x0b, 0x39, 0x67, 0x54, 0xbe, 0x2b, 0x04, 0x2a, 0x68, 0x5d, 0xd4, 0x72, 0x7e, + 0x89, 0xc0, 0x38, 0x6a, 0x94, 0xd3, 0xcd, 0x6e, 0xcb, 0x98, 0x20, 0xe9, 0xd4, 0x9a, 0xfe, 0xed, + 0x66, 0xc4, 0x7e, 0x6f, 0xc2, 0x43, 0xea, 0xbe, 0xbb, 0xcb, 0x0b, 0x02, 0x45, 0x38, 0x77, 0xf5, + 0xac, 0x5d, 0xbf, 0xbd, 0xf8, 0xdb, 0x10, 0x52, 0xa3, 0xc9, 0x94, 0xb2, 0x24, 0xcd, 0x9a, 0xaa, + 0xf5, 0x6b, 0x02, 0x6b, 0xb9, 0xef, 0xa2, 0xe0, 0x13, 0x02, 0xb3, 0x64, 0x01, 0xab, 0x64, 0x94, + 0xe7, 0x01, 0x8d, 0x6e, 0x5b, 0x57, 0x3b, 0xd3, 0x8b, 0xce, 0xf0, 0x23, 0xb1, 0xfc, 0x92, 0x94, + 0x6b, 0xbc, 0xa0, 0x20, 0x9c, 0xa5, 0xfa, 0x92, 0x6b, 0x49, 0x70, 0xb1, 0x00, 0x91, 0x03, 0x64, + 0x5c, 0xb1, 0xfc, 0xfe, 0x55, 0x23, 0x11, 0xff, 0x73, 0x05, 0x58, 0x98, 0x43, 0x70, 0x03, 0x8f, + 0xd2, 0xcc, 0xe2, 0xa9, 0x1f, 0xc7, 0x4d, 0x6f, 0x3e, 0x3e, 0xa9, 0xf8, 0x43, 0xee, 0xd3, 0x56, + 0xf6, 0xf8, 0x2d, 0x35, 0xd0, 0x3b, 0xc2, 0x4b, 0x81, 0xb5, 0x8c, 0xeb, 0x1a, 0x43, 0xec, 0x94, + 0x37, 0xe6, 0xf1, 0xe5, 0x0e, 0xb6, 0xf5, 0x55, 0xe3, 0x21, 0xfd, 0x67, 0xc8, 0x33, 0x2e, 0xb1, + 0xb8, 0x32, 0xaa, 0x8d, 0x79, 0x5a, 0x27, 0xd4, 0x79, 0xc6, 0xe2, 0x7d, 0x5a, 0x61, 0x03, 0x46, + 0x83, 0x89, 0x19, 0x03, 0xf6, 0x64, 0x21, 0xd0, 0x94, 0xe1, 0xb0, 0x0a, 0x9a, 0x13, 0x8d, 0x86, + 0x1e, 0x6f, 0x78, 0xa2, 0x0a, 0xd3, 0xe1, 0x58, 0x00, 0x54, 0xd2, 0xe3, 0x05, 0x25, 0x3c, 0x71, + 0x3a, 0x02, 0xfe, 0x1e, 0x28, 0xde, 0xee, 0x73, 0x36, 0x24, 0x6f, 0x6a, 0xe3, 0x43, 0x31, 0x80, + 0x6b, 0x46, 0xb4, 0x7b, 0x83, 0x3c, 0x39, 0xb9, 0xd3, 0x1c, 0xd3, 0x00, 0xc2, 0xa6, 0xed, 0x83, + 0x13, 0x99, 0x77, 0x6d, 0x07, 0xf5, 0x70, 0xea, 0xf0, 0x05, 0x9a, 0x2c, 0x68, 0xa5, 0xf3, 0xae, + 0x16, 0xb6, 0x17, 0x40, 0x4a, 0xf7, 0xb7, 0x23, 0x1a, 0x4d, 0x94, 0x27, 0x58, 0xfc, 0x02, 0x0b, + 0x3f, 0x23, 0xee, 0x8c, 0x15, 0xe3, 0x60, 0x44, 0xcf, 0xd6, 0x7c, 0xd6, 0x40, 0x99, 0x3b, 0x16, + 0x20, 0x75, 0x97, 0xfb, 0xf3, 0x85, 0xea, 0x7a, 0x4d, 0x99, 0xe8, 0xd4, 0x56, 0xff, 0x83, 0xd4, + 0x1f, 0x7b, 0x8b, 0x4f, 0x06, 0x9b, 0x02, 0x8a, 0x2a, 0x63, 0xa9, 0x19, 0xa7, 0x0e, 0x3a, 0x10, + 0xe3, 0x08, // encrypted cert + 0x41, // encrypted data type (Certificate) + 0x58, 0xfa, 0xa5, 0xba, 0xfa, 0x30, 0x18, 0x6c, // auth tag + 0x6b, 0x2f, 0x23, 0x8e, 0xb5, 0x30, 0xc7, 0x3e, // auth tag + } ++ [_]u8{ + 0x17, // application data (lie for tls 1.2 compat) + 0x03, 0x03, // tls 1.2 + 0x01, 0x19, // application data len + 0x73, 0x71, 0x9f, 0xce, 0x07, 0xec, 0x2f, 0x6d, 0x3b, 0xba, 0x02, 0x92, 0xa0, 0xd4, 0x0b, 0x27, + 0x70, 0xc0, 0x6a, 0x27, 0x17, 0x99, 0xa5, 0x33, 0x14, 0xf6, 0xf7, 0x7f, 0xc9, 0x5c, 0x5f, 0xe7, + 0xb9, 0xa4, 0x32, 0x9f, 0xd9, 0x54, 0x8c, 0x67, 0x0e, 0xbe, 0xea, 0x2f, 0x2d, 0x5c, 0x35, 0x1d, + 0xd9, 0x35, 0x6e, 0xf2, 0xdc, 0xd5, 0x2e, 0xb1, 0x37, 0xbd, 0x3a, 0x67, 0x65, 0x22, 0xf8, 0xcd, + 0x0f, 0xb7, 0x56, 0x07, 0x89, 0xad, 0x7b, 0x0e, 0x3c, 0xab, 0xa2, 0xe3, 0x7e, 0x6b, 0x41, 0x99, + 0xc6, 0x79, 0x3b, 0x33, 0x46, 0xed, 0x46, 0xcf, 0x74, 0x0a, 0x9f, 0xa1, 0xfe, 0xc4, 0x14, 0xdc, + 0x71, 0x5c, 0x41, 0x5c, 0x60, 0xe5, 0x75, 0x70, 0x3c, 0xe6, 0xa3, 0x4b, 0x70, 0xb5, 0x19, 0x1a, + 0xa6, 0xa6, 0x1a, 0x18, 0xfa, 0xff, 0x21, 0x6c, 0x68, 0x7a, 0xd8, 0xd1, 0x7e, 0x12, 0xa7, 0xe9, + 0x99, 0x15, 0xa6, 0x11, 0xbf, 0xc1, 0xa2, 0xbe, 0xfc, 0x15, 0xe6, 0xe9, 0x4d, 0x78, 0x46, 0x42, + 0xe6, 0x82, 0xfd, 0x17, 0x38, 0x2a, 0x34, 0x8c, 0x30, 0x10, 0x56, 0xb9, 0x40, 0xc9, 0x84, 0x72, + 0x00, 0x40, 0x8b, 0xec, 0x56, 0xc8, 0x1e, 0xa3, 0xd7, 0x21, 0x7a, 0xb8, 0xe8, 0x5a, 0x88, 0x71, + 0x53, 0x95, 0x89, 0x9c, 0x90, 0x58, 0x7f, 0x72, 0xe8, 0xdd, 0xd7, 0x4b, 0x26, 0xd8, 0xed, 0xc1, + 0xc7, 0xc8, 0x37, 0xd9, 0xf2, 0xeb, 0xbc, 0x26, 0x09, 0x62, 0x21, 0x90, 0x38, 0xb0, 0x56, 0x54, + 0xa6, 0x3a, 0x0b, 0x12, 0x99, 0x9b, 0x4a, 0x83, 0x06, 0xa3, 0xdd, 0xcc, 0x0e, 0x17, 0xc5, 0x3b, + 0xa8, 0xf9, 0xc8, 0x03, 0x63, 0xf7, 0x84, 0x13, 0x54, 0xd2, 0x91, 0xb4, 0xac, 0xe0, 0xc0, 0xf3, + 0x30, 0xc0, 0xfc, 0xd5, 0xaa, 0x9d, 0xee, 0xf9, 0x69, 0xae, 0x8a, 0xb2, 0xd9, 0x8d, 0xa8, 0x8e, + 0xbb, 0x6e, 0xa8, 0x0a, 0x3a, 0x11, 0xf0, 0x0e, // encrypted signature_verify + 0xa2, // encrypted data type (SignatureVerify) + 0x96, 0xa3, 0x23, 0x23, 0x67, 0xff, 0x07, 0x5e, // auth tag + 0x1c, 0x66, 0xdd, 0x9c, 0xbe, 0xdc, 0x47, 0x13, // auth tag + } ++ [_]u8{ + 0x17, // application data (lie for tls 1.2 compat) + 0x03, 0x03, // tls 1.2 + 0x00, 0x45, // application data len + 0x10, 0x61, 0xde, 0x27, 0xe5, 0x1c, 0x2c, 0x9f, 0x34, 0x29, 0x11, 0x80, 0x6f, 0x28, 0x2b, 0x71, + 0x0c, 0x10, 0x63, 0x2c, 0xa5, 0x00, 0x67, 0x55, 0x88, 0x0d, 0xbf, 0x70, 0x06, 0x00, 0x2d, 0x0e, + 0x84, 0xfe, 0xd9, 0xad, 0xf2, 0x7a, 0x43, 0xb5, 0x19, 0x23, 0x03, 0xe4, 0xdf, 0x5c, 0x28, 0x5d, + 0x58, 0xe3, 0xc7, 0x62, + 0x24, // encrypted data type (finished) + 0x07, 0x84, 0x40, 0xc0, 0x74, 0x23, 0x74, 0x74, // auth tag + 0x4a, 0xec, 0xf2, 0x8c, 0xf3, 0x18, 0x2f, 0xd0, // auth tag + })); try client.stream.expectFragment(.handshake, .server_hello); try client.recv_hello(key_pairs); @@ -2337,51 +2332,37 @@ test "tls client and server handshake, data, and close_notify" { } try client.send_finished(); - { - const buf = tmp_buf[0..inner_stream.buffer.len()]; - try inner_stream.peek(buf); - - const expected = [_]u8{ - 0x14, // ChangeCipherSpec - 0x03, 0x03, // tls 1.2 - 0x00, 0x01, // len - 0x01, // .change_cipher_spec - } - ++ [_]u8{ - 0x17, // app data (lie for TLS 1.2) - 0x03, 0x03, // tls 1.2 - 0x00, 0x45, // len - 0x9f, 0xf9, 0xb0, 0x63, 0x17, 0x51, 0x77, 0x32, 0x2a, 0x46, 0xdd, 0x98, 0x96, 0xf3, 0xc3, 0xbb, - 0x82, 0x0a, 0xb5, 0x17, 0x43, 0xeb, 0xc2, 0x5f, 0xda, 0xdd, 0x53, 0x45, 0x4b, 0x73, 0xde, 0xb5, - 0x4c, 0xc7, 0x24, 0x8d, 0x41, 0x1a, 0x18, 0xbc, 0xcf, 0x65, 0x7a, 0x96, 0x08, 0x24, 0xe9, 0xa1, - 0x93, 0x64, 0x83, 0x7c, // encrypted data - 0x35, // handshake - 0x0a, 0x69, 0xa8, 0x8d, 0x4b, 0xf6, 0x35, 0xc8, // auth tag - 0x5e, 0xb8, 0x74, 0xae, 0xbc, 0x9d, 0xfd, 0xe8, // auth tag - } - ; - try std.testing.expectEqualSlices(u8, &expected, buf); + try inner_stream.expect(&([_]u8{ + 0x14, // ChangeCipherSpec + 0x03, 0x03, // tls 1.2 + 0x00, 0x01, // len + 0x01, // .change_cipher_spec } + ++ [_]u8{ + 0x17, // app data (lie for TLS 1.2) + 0x03, 0x03, // tls 1.2 + 0x00, 0x45, // len + 0x9f, 0xf9, 0xb0, 0x63, 0x17, 0x51, 0x77, 0x32, 0x2a, 0x46, 0xdd, 0x98, 0x96, 0xf3, 0xc3, 0xbb, + 0x82, 0x0a, 0xb5, 0x17, 0x43, 0xeb, 0xc2, 0x5f, 0xda, 0xdd, 0x53, 0x45, 0x4b, 0x73, 0xde, 0xb5, + 0x4c, 0xc7, 0x24, 0x8d, 0x41, 0x1a, 0x18, 0xbc, 0xcf, 0x65, 0x7a, 0x96, 0x08, 0x24, 0xe9, 0xa1, + 0x93, 0x64, 0x83, 0x7c, // encrypted data + 0x35, // handshake + 0x0a, 0x69, 0xa8, 0x8d, 0x4b, 0xf6, 0x35, 0xc8, // auth tag + 0x5e, 0xb8, 0x74, 0xae, 0xbc, 0x9d, 0xfd, 0xe8, // auth tag + })); try server.recv_finished(); _ = try client.stream.writer().writeAll("ping"); try client.stream.flush(); - { - const buf = tmp_buf[0..inner_stream.buffer.len()]; - try inner_stream.peek(buf); - - const expected = [_]u8{ - 0x17, // app data (FOR REAL THIS TIME) - 0x03, 0x03, // tls 1.2 - 0x00, 0x15, // len - 0x82, 0x81, 0x39, 0xcb, // ping - 0x7b, // app data (exciting!) - 0x73, 0xaa, 0xab, 0xf5, 0xb8, 0x2f, 0xbf, 0x9a, // auth tag - 0x29, 0x61, 0xbc, 0xde, 0x10, 0x03, 0x8a, 0x32, // auth tag - } - ; - try std.testing.expectEqualSlices(u8, &expected, buf); - } + try inner_stream.expect(&([_]u8{ + 0x17, // app data (FOR REAL THIS TIME) + 0x03, 0x03, // tls 1.2 + 0x00, 0x15, // len + 0x82, 0x81, 0x39, 0xcb, // ping + 0x7b, // app data (exciting!) + 0x73, 0xaa, 0xab, 0xf5, 0xb8, 0x2f, 0xbf, 0x9a, // auth tag + 0x29, 0x61, 0xbc, 0xde, 0x10, 0x03, 0x8a, 0x32, // auth tag + })); var recv_ping: [4]u8 = undefined; _ = try server.stream.reader().readAll(&recv_ping); @@ -2389,22 +2370,15 @@ test "tls client and server handshake, data, and close_notify" { server.stream.close(); try std.testing.expect(server.stream.closed); - { - const buf = tmp_buf[0..inner_stream.buffer.len()]; - try inner_stream.peek(buf); - - const expected = [_]u8{ - 0x17, // app data (lie to encrypt) - 0x03, 0x03, // tls 1.2 - 0x00, 0x13, // len - 0x3e, 0x2d, // alert - 0x99, // encrypted message type - 0x26, 0xbb, 0xfe, 0x1f, 0x46, 0xfb, 0x4e, 0xe2, // auth tag - 0x75, 0x1e, 0x53, 0xbf, 0xfc, 0x7e, 0x65, 0x16, // auth tag - } - ; - try std.testing.expectEqualSlices(u8, &expected, buf); - } + try inner_stream.expect(&([_]u8{ + 0x17, // app data (lie to encrypt) + 0x03, 0x03, // tls 1.2 + 0x00, 0x13, // len + 0x3e, 0x2d, // alert + 0x99, // encrypted message type + 0x26, 0xbb, 0xfe, 0x1f, 0x46, 0xfb, 0x4e, 0xe2, // auth tag + 0x75, 0x1e, 0x53, 0xbf, 0xfc, 0x7e, 0x65, 0x16, // auth tag + })); _ = try client.stream.readFragment(); try std.testing.expect(client.stream.closed); From 17c8f1b9d0b4394f18d7e2898a910381b800f04f Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Wed, 13 Mar 2024 17:11:24 -0400 Subject: [PATCH 06/17] working client with demo site --- TODO | 10 +- lib/std/crypto/tls.zig | 1178 ++++++++++++++++++++++++--------- lib/std/crypto/tls/Client.zig | 260 +++++--- lib/std/crypto/tls/Server.zig | 33 +- lib/std/http/Client.zig | 15 +- 5 files changed, 1056 insertions(+), 440 deletions(-) diff --git a/TODO b/TODO index 68ee2dc306b1..4474c5eb46b7 100644 --- a/TODO +++ b/TODO @@ -3,12 +3,14 @@ [x] 3. Client recv_hello secp256r1 key share [x] 4. remove @panic's [x] 5. better errors than spammy TlsDecodeError. map new errors to TLS alerts. send alert on error. -6. verify certs and [x] sigs -7. KeyShare kyber read -8. StreamInterface `readv` instead of `readAll` +[x] 6. client state machine union + test +[x] 7. cert formats +[x] 8. verify certs and [x] sigs +9. KeyShare kyber read +10. StreamInterface `readv` instead of `readAll` 1. benchmark 2. store multiple fragments in buffer for less syscalls 3. streaming encode + decode -4. store handshake info (transcript_hash, handshake_type, handshake_cipher) somewhere temporary +4. store handshake_cipher somewhere temporary diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 9fdcb4ce98eb..34deb578fa6d 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -30,17 +30,30 @@ pub const ContentType = enum(u8) { _, }; -/// Also called a Record. pub const Plaintext = struct { type: ContentType, version: Version = .tls_1_0, - /// > The length MUST NOT exceed 2^14 bytes. - /// > An endpoint that receives a record that exceeds this length MUST terminate the connection - /// > with a "record_overflow" alert. - length: u16, - // `length` bytes follow which may contain a Message or a partial Message. + len: u16, + pub const size = @sizeOf(ContentType) + @sizeOf(Version) + @sizeOf(u16); pub const max_length = 1 << 14; + + const Self = @This(); + + pub fn init(bytes: [size]u8) Self { + var stream = std.io.fixedBufferStream(&bytes); + var reader = stream.reader(); + const ty = reader.readInt(u8, .big) catch unreachable; + const version = reader.readInt(u16, .big) catch unreachable; + const len = reader.readInt(u16, .big) catch unreachable; + return .{ .type = @enumFromInt(ty), .version = @enumFromInt(version), .len = len }; + } +}; + +const InnerPlaintext = struct { + type: ContentType, + handshake_type: HandshakeType, + len: u24, }; pub const HandshakeType = enum(u8) { @@ -134,7 +147,7 @@ pub const Handshake = union(HandshakeType) { }, else => |t| @compileError("implement writing " ++ @tagName(t)), } - } + }, } return res; } @@ -629,14 +642,14 @@ pub const KeyShare = union(NamedGroup) { // .x25519_kyber768d00 => { // const expected_len = if (stream.is_client) @TypeOf(k).bytes_length else X25519Kyber768Draft.Kyber768.ciphertext_length; // }, - inline .secp256r1, .secp384r1 => |k| { + inline .secp256r1, .secp384r1 => |k| { const T = NamedGroupT(k).PublicKey; var buf: [T.compressed_sec1_encoded_length]u8 = undefined; try reader.readNoEof(&buf); const val = T.fromSec1(&buf) catch return Error.TlsDecryptError; return @unionInit(Self, @tagName(k), val); }, - .x25519 => { + .x25519 => { var res = Self{ .x25519 = undefined }; if (res.x25519.len != len) return Error.TlsDecodeError; try reader.readNoEof(&res.x25519); @@ -898,6 +911,7 @@ pub const Extension = union(ExtensionType) { status_request: void, supported_groups: []const NamedGroup, ec_point_formats: []const EcPointFormat, + /// For signature_verify messages signature_algorithms: []const SignatureScheme, use_srtp: void, /// https://en.wikipedia.org/wiki/Heartbleed @@ -918,6 +932,9 @@ pub const Extension = union(ExtensionType) { certificate_authorities: void, oid_filters: void, post_handshake_auth: void, + /// For certificate signatures. + /// > Implementations which have the same policy in both cases MAY omit the + /// > "signature_algorithms_cert" extension. signature_algorithms_cert: void, key_share: []const KeyShare, none: void, @@ -927,17 +944,12 @@ pub const Extension = union(ExtensionType) { pub fn write(self: Self, stream: anytype) !usize { const PrefixLen = enum { zero, one, two }; const prefix_len: PrefixLen = if (stream.is_client) switch (self) { - .supported_versions, - .ec_point_formats, - .psk_key_exchange_modes => .one, - .server_name, - .supported_groups, - .signature_algorithms, - .key_share => .two, + .supported_versions, .ec_point_formats, .psk_key_exchange_modes => .one, + .server_name, .supported_groups, .signature_algorithms, .key_share => .two, else => .zero, } else .zero; - var res: usize = 0; + var res: usize = 0; res += try stream.write(ExtensionType, self); switch (self) { @@ -958,7 +970,7 @@ pub const Extension = union(ExtensionType) { const len = stream.arrayLength(PrefixT, info.child, items); res += try stream.write(u16, @intCast(len)); res += try stream.writeArray(PrefixT, info.child, items); - } + }, } }, else => |t| @compileError("unsupported type " ++ @typeName(T) ++ " for member " ++ @tagName(t)), @@ -1077,15 +1089,17 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { /// > Certificate, client CertificateVerify, client Finished. transcript_hash: MultiHash = .{}, /// Used for both reading and writing. Cannot be doing both at the same time. + /// Stores plaintext or ciphertext, but not Plaintext headers. buffer: [fragment_size]u8 = undefined, - /// Unread or unwritten view of `buffer`. + /// Unread or unwritten view of `buffer`. May contain multiple handshakes. view: []const u8 = "", - /// When sending this is the record type that will be sent. + /// When sending this is the record type that will be flushed. /// When receiving this is the next fragment's expected record type. content_type: ContentType = .handshake, - - /// When receiving fragments this is the next expected fragment type. + /// When sending this is the flushed version. + version: Version = .tls_1_0, + /// When receiving a handshake message will be expected with this type. handshake_type: ?HandshakeType = .client_hello, /// Used to encrypt and decrypt .application_data messages until application_cipher is not null. @@ -1096,19 +1110,15 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { /// True when we send or receive a close_notify alert. closed: bool = false, - /// Version to send out in record headers - version: Version = .tls_1_0, - /// True if we're being used as a client. This changes: /// * Certain shared struct formats (like Extension) - /// * Which keys are used for encoding/decoding handshake and application messages. + /// * Which ciphers are used for encoding/decoding handshake and application messages. is_client: bool, - /// When > 0 won't actually do anything with writes. - /// This is to discover prefix lengths for record level spec-adherant sequential writing. + /// When > 0 won't actually do anything with writes. Used to discover prefix lengths. nocommit: usize = 0, - const Self = @This(); + const Self = @This(); pub const ReadError = StreamType.ReadError || Error || error{EndOfStream}; pub const WriteError = StreamType.WriteError || error{ @@ -1155,7 +1165,7 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { var plaintext = Plaintext{ .type = self.content_type, .version = self.version, - .length = @intCast(self.view.len), + .len = @intCast(self.view.len), }; if (self.application_cipher == null) { @@ -1163,11 +1173,11 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { .change_cipher_spec, .alert => {}, else => { self.transcript_hash.update(self.view); - } + }, } } - var header: [fieldsLen(Plaintext)]u8 = undefined; + var header: [Plaintext.size]u8 = undefined; var aead: []const u8 = ""; switch (self.encryptionMethod()) { .none => { @@ -1175,32 +1185,32 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { }, .handshake => { plaintext.type = .application_data; - plaintext.length += @intCast(self.ciphertextOverhead()); + plaintext.len += @intCast(self.ciphertextOverhead()); header = Encoder.encode(Plaintext, plaintext); if (self.handshake_cipher) |*a| switch (a.*) { .empty_renegotiation_info_scsv => {}, inline else => |*c| { std.debug.assert(self.view.ptr == &self.buffer); self.buffer[self.view.len] = @intFromEnum(self.content_type); - self.view = self.buffer[0..self.view.len + 1]; + self.view = self.buffer[0 .. self.view.len + 1]; aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); }, }; }, .application => { plaintext.type = .application_data; - plaintext.length += @intCast(self.ciphertextOverhead()); + plaintext.len += @intCast(self.ciphertextOverhead()); header = Encoder.encode(Plaintext, plaintext); if (self.application_cipher) |*a| switch (a.*) { .empty_renegotiation_info_scsv => {}, inline else => |*c| { std.debug.assert(self.view.ptr == &self.buffer); self.buffer[self.view.len] = @intFromEnum(self.content_type); - self.view = self.buffer[0..self.view.len + 1]; + self.view = self.buffer[0 .. self.view.len + 1]; aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); }, }; - } + }, } var iovecs = [_]std.os.iovec_const{ @@ -1295,60 +1305,72 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { return self.write(T, value) catch unreachable; } - pub fn arrayLength(self: *Self, comptime PrefixT: type, comptime T: type, values: []const T,) usize { + pub fn arrayLength( + self: *Self, + comptime PrefixT: type, + comptime T: type, + values: []const T, + ) usize { var res: usize = if (PrefixT == void) 0 else @divExact(@typeInfo(PrefixT).Int.bits, 8); for (values) |v| res += self.length(T, v); return res; } /// Reads bytes from `view`, potentially reading more fragments from `stream`. + /// /// A return value of 0 indicates EOF. - pub fn readBytes(self: *Self, buf: []u8) ReadError!usize { + pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { // > Any data received after a closure alert has been received MUST be ignored. if (self.eof()) return 0; + if (self.view.len == 0) try self.expectInnerPlaintext(self.content_type, self.handshake_type); + var bytes_read: usize = 0; - while (bytes_read != buf.len) { - if (self.view.len == 0) try self.expectFragment(self.content_type, self.handshake_type); - const to_read = @min(buf.len, self.view.len); - @memcpy(buf[0..to_read], self.view[0..to_read]); + for (buffers) |b| { + var bytes_read_buffer: usize = 0; + while (bytes_read_buffer != b.iov_len) { + const to_read = @min(b.iov_len, self.view.len); + if (to_read == 0) return bytes_read; + + @memcpy(b.iov_base[0..to_read], self.view[0..to_read]); - self.view = self.view[to_read..]; - bytes_read += to_read; + self.view = self.view[to_read..]; + bytes_read_buffer += to_read; + bytes_read += bytes_read_buffer; + } } return bytes_read; } - /// Read fragment from `stream` into `buffer` and updates `self.view`. Returns message type. - pub fn readFragment(self: *Self) ReadError!ContentType { + /// Reads bytes from `view`, potentially reading more fragments from `stream`. + /// A return value of 0 indicates EOF. + pub fn readBytes(self: *Self, buf: []u8) ReadError!usize { + const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; + return try self.readv(&buffers); + } + + /// Reads plaintext from `stream` into `buffer` and updates `view`. + /// Skips non-fatal alert and change_cipher_spec messages. + /// Will decrypt according to `encryptionMethod` if receiving application_data message. + pub fn readPlaintext(self: *Self) ReadError!Plaintext { std.debug.assert(self.view.len == 0); // last read should have completed - var plaintext_header: [fieldsLen(Plaintext)]u8 = undefined; + var plaintext_header_bytes: [Plaintext.size]u8 = undefined; var n_read: usize = 0; - var res: ContentType = .invalid; - var len: u16 = 0; - while (true) { - n_read = try self.stream.readAll(&plaintext_header); - if (n_read != plaintext_header.len) return self.writeError(.decode_error); - - // Take advantage of our `read` parsing code by setting view outside `self.buffer`. - { - self.view = &plaintext_header; - errdefer self.view = self.buffer[0..0]; - res = try self.read(ContentType); - _ = try self.read(Version); - len = try self.read(u16); - if (len > self.maxFragmentSize()) return self.writeError(.record_overflow); - } + n_read = try self.stream.readAll(&plaintext_header_bytes); + if (n_read != plaintext_header_bytes.len) return self.writeError(.decode_error); + + var res = Plaintext.init(plaintext_header_bytes); + if (res.len > Plaintext.max_length) return self.writeError(.record_overflow); - self.view = self.buffer[0..len]; + self.view = self.buffer[0..res.len]; n_read = try self.stream.readAll(@constCast(self.view)); - if (n_read != len) return self.writeError(.decode_error); + if (n_read != res.len) return self.writeError(.decode_error); - const encryption_method = if (res == .application_data) self.encryptionMethod() else .none; + const encryption_method = if (res.type == .application_data) self.encryptionMethod() else .none; switch (encryption_method) { .none => {}, inline .handshake, .application => |t| { @@ -1358,14 +1380,14 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { const P = @TypeOf(p.*); const tag_len = P.AEAD.tag_length; - const ciphertext = self.view[0..self.view.len - tag_len]; - const tag = self.view[self.view.len - tag_len..][0..tag_len].*; + const ciphertext = self.view[0 .. self.view.len - tag_len]; + const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; const out: []u8 = @constCast(self.view[0..ciphertext.len]); - p.decrypt(ciphertext, &plaintext_header, tag, self.is_client, out) catch + p.decrypt(ciphertext, &plaintext_header_bytes, tag, self.is_client, out) catch return self.writeError(.bad_record_mac); const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); if (padding_start) |s| { - res = @enumFromInt(self.view[s]); + res.type = @enumFromInt(self.view[s]); self.view = self.view[0..s]; } else { return self.writeError(.decode_error); @@ -1375,7 +1397,7 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { }, } - switch (res) { + switch (res.type) { .alert => { const level = try self.read(Alert.Level); const description = try self.read(Alert.Description); @@ -1386,7 +1408,6 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { return res; } if (level == .fatal) return self.writeError(.unexpected_message); - continue; }, // > An implementation may receive an unencrypted record of type // > change_cipher_spec consisting of the single byte value 0x01 at any @@ -1394,39 +1415,58 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { // > and before the peer's Finished message has been received and MUST // > simply drop it without further processing. .change_cipher_spec => { - if (!std.mem.eql(u8, self.view, &[_]u8{1})) return self.writeError(.unexpected_message); - continue; + if (!std.mem.eql(u8, self.view, &[_]u8{1})) { + return self.writeError(.unexpected_message); + } + }, + else => { + return res; }, - else => {}, } + } + } - self.transcript_hash.update(self.view); + pub fn readInnerPlaintext(self: *Self) ReadError!InnerPlaintext { + var res: InnerPlaintext = .{ + .type = self.content_type, + .handshake_type = if (self.handshake_type) |h| h else undefined, + .len = undefined, + }; + if (self.view.len == 0) { + const plaintext = try self.readPlaintext(); + res.type = plaintext.type; + res.len = plaintext.len; - return res; + self.content_type = res.type; } + + if (res.type == .handshake) { + self.transcript_hash.update(self.view[0..@sizeOf(HandshakeType) + @bitSizeOf(u24) / 8]); + res.handshake_type = try self.read(HandshakeType); + res.len = try self.read(u24); + self.transcript_hash.update(self.view[0..res.len]); + + self.handshake_type = res.handshake_type; + } + + return res; } - pub fn expectFragment(self: *Self, expected_content: ContentType, expected_handshake: ?HandshakeType,) ReadError!void { - const actual_content = try self.readFragment(); - if (expected_content != actual_content) { - std.debug.print("expected {} got {}\n", .{ expected_content, actual_content }); + pub fn expectInnerPlaintext( + self: *Self, + expected_content: ContentType, + expected_handshake: ?HandshakeType, + ) ReadError!void { + const inner_plaintext = try self.readInnerPlaintext(); + if (expected_content != inner_plaintext.type) { + std.debug.print("expected {} got {}\n", .{ expected_content, inner_plaintext.type }); + return self.writeError(.unexpected_message); } if (expected_handshake) |expected| { - const actual_handshake = try self.read(HandshakeType); - if (actual_handshake != expected) return self.writeError(.decode_error); - const stated_len = try self.read(u24); - if (stated_len != self.view.len) return self.writeError(.decode_error); + if (expected != inner_plaintext.handshake_type) return self.writeError(.decode_error); } } - pub fn expectHandshake(self: *Self) ReadError!Handshake.Header { - try self.expectFragment(.handshake, null); - const ty = try self.read(HandshakeType); - const len = try self.read(u24); - if (self.view.len != len) return self.writeError(.decode_error); - return .{ .type = ty, .len = len }; - } - pub fn read(self: *Self, comptime T: type) ReadError!T { comptime std.debug.assert(@sizeOf(T) < fragment_size); switch (@typeInfo(T)) { @@ -1441,7 +1481,6 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { }, else => { return T.read(self) catch |err| switch (err) { - error.EndOfStream, error.Full, error.ReadLengthInvalid => return self.writeError(.decode_error), error.TlsUnexpectedMessage => return self.writeError(.unexpected_message), error.TlsBadRecordMac => return self.writeError(.bad_record_mac), error.TlsRecordOverflow => return self.writeError(.record_overflow), @@ -1471,6 +1510,7 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { self.close(); return e; }, + else => return self.writeError(.decode_error), }; }, } @@ -1798,20 +1838,10 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) /// Slice of stack allocated signature content from RFC 8446 S4.4.3 pub inline fn sigContent(digest: []const u8) []const u8 { const max_digest_len = MultiHash.max_digest_len; - var buf = [_]u8{0x20} ** 64 - ++ "TLS 1.3, server CertificateVerify\x00".* - ++ @as([max_digest_len]u8, undefined) - ; - @memcpy(buf[buf.len - max_digest_len..][0..digest.len], digest); + var buf = [_]u8{0x20} ** 64 ++ "TLS 1.3, server CertificateVerify\x00".* ++ @as([max_digest_len]u8, undefined); + @memcpy(buf[buf.len - max_digest_len ..][0..digest.len], digest); - return buf[0..buf.len - (max_digest_len - digest.len)]; -} - - -fn fieldsLen(comptime T: type) comptime_int { - var res: comptime_int = 0; - inline for (std.meta.fields(T)) |f| res += @sizeOf(f.type); - return res; + return buf[0 .. buf.len - (max_digest_len - digest.len)]; } /// Default suites used for client and server in descending order of preference. @@ -1858,9 +1888,10 @@ const TestStream = struct { buffer: Buffer, const Buffer = std.RingBuffer; + const Self = @This(); + pub const ReadError = Buffer.Error; pub const WriteError = Buffer.Error; - const Self = @This(); pub fn init(allocator: std.mem.Allocator) !Self { return Self{ .buffer = try Buffer.init(allocator, Plaintext.max_length) }; @@ -1938,11 +1969,9 @@ test "tls client and server handshake, data, and close_notify" { .options = .{ // force this to use https://tls13.xargs.org/ as unit test for "server hello" onwards .cipher_suites = &[_]CipherSuite{.aes_256_gcm_sha384}, - .certificate = .{ - .entries = &[_]Certificate.Entry{ - .{ .data = server_der }, - } - } + .certificate = .{ .entries = &[_]Certificate.Entry{ + .{ .data = server_der }, + } }, }, }; @@ -1972,10 +2001,10 @@ test "tls client and server handshake, data, and close_notify" { session_id, client_x25519_seed ++ client_x25519_seed, client_x25519_seed, - client_x25519_seed ++ [_]u8{0} ** (48-32), + client_x25519_seed ++ [_]u8{0} ** (48 - 32), client_x25519_seed, ); - { + var client_command = brk: { // To get the same `hello_hash` as https://tls13.xargs.org/ just for // this test we send a mostly falsified client hello. // It doesn't matter because the server will be TLS 1.3 and only support .x25519 @@ -2036,90 +2065,91 @@ test "tls client and server handshake, data, and close_notify" { _ = try client.stream.write(Handshake, .{ .client_hello = hello }); try client.stream.flush(); - } + + break :brk client_mod.Command{ .recv_hello = key_pairs }; + }; try inner_stream.expect([_]u8{ - 0x16, // handshake - 0x03, 0x01, // tls 1.0 (lie for compat) - 0x00, 0xf8, // handshake len - 0x01, // client hello - 0x00, 0x00, 0xf4, // client hello len - 0x03, 0x03, // tls 1.2 (lie for compat) - } ++ client_random ++ - [_]u8{session_id.len} ++ session_id ++ - [_]u8{ - 0x00, 0x08, // cipher suite len - 0x13, 0x02, // aes_256_gcm_sha384 - 0x13, 0x03, // chacha20_poly1305_sha256 - 0x13, 0x01, // aes_128_gcm_sha256 - 0x00, 0xff, // empty_renegotiation_info_scsv - 0x01, // compression methods len - 0x00, // none - 0x00, 0xa3, // extensions len - 0x00, 0x00, // server name ext - 0x00, 0x18, // server name len - 0x00, 0x16, // list entry len - 0x00, // dns hostname - } ++ - Encoder.encode(u16, @intCast(host.len)) ++ host ++ - [_]u8{ - 0x00, 0x0b, // ec point formats - 0x00, 0x04, // ext len - 0x03, // format type len - 0x00, // uncompresed - 0x01, // ansiX962_compressed_prime - 0x02, // ansiX962_compressed_char2 - 0x00, 0x0a, // supported groups - 0x00, 0x16, // ext len - 0x00, 0x14, // supported groups len - 0x00, 0x1d, // x25519 - 0x00, 0x17, // secp256r1 - 0x00, 0x1e, // x448 - 0x00, 0x19, // secp521r1 - 0x00, 0x18, // secp384r1 - 0x01, 0x00, // ffdhe2048 - 0x01, 0x01, // ffdhe3072 - 0x01, 0x02, // ffdhe4096 - 0x01, 0x03, // ffdhe6144 - 0x01, 0x04, // ffdhe8192 - 0x00, 0x23, // session ticket - 0x00, 0x00, // ext len - 0x00, 0x16, // encrypt then mac - 0x00, 0x00, // ext len - 0x00, 0x17, // extended master secrets - 0x00, 0x00, // ext len - 0x00, 0x0d, // signature algos - 0x00, 0x1e, // ext len - 0x00, 0x1c, // algos len - 0x04, 0x03, // ecdsa_secp256r1_sha256 - 0x05, 0x03, // ecdsa_secp384r1_sha384 - 0x06, 0x03, // ecdsa_secp521r1_sha512 - 0x08, 0x07, // ed25519 - 0x08, 0x08, // ed448 - 0x08, 0x09, // rsa_pss_pss_sha256 - 0x08, 0x0a, // rsa_pss_pss_sha384 - 0x08, 0x0b, // rsa_pss_pss_sha512 - 0x08, 0x04, // rsa_pss_rsae_sha256 - 0x08, 0x05, // rsa_pss_rsae_sha384 - 0x08, 0x06, // rsa_pss_rsae_sha512 - 0x04, 0x01, // rsa_pkcs1_sha256 - 0x05, 0x01, // rsa_pkcs1_sha384 - 0x06, 0x01, // rsa_pkcs1_sha512 - 0x00, 0x2b, // supported versions - 0x00, 0x03, // ext len - 0x02, // supported versions len - 0x03, 0x04, // tls 1.3 (not lying anymore!) - 0x00, 0x2d, // psk key exchange modes - 0x00, 0x02, // ext len - 0x01, // psk key exchange modes len - 0x01, // PSK with (EC)DHE key establishment - 0x00, 0x33, // key share - 0x00, 0x26, // ext len - 0x00, 0x24, // key shares len - 0x00, 0x1d, // curve 25519 - 0x00, 0x20, // key len - } ++ key_pairs.x25519.public_key - ); + 0x16, // handshake + 0x03, 0x01, // tls 1.0 (lie for compat) + 0x00, 0xf8, // handshake len + 0x01, // client hello + 0x00, 0x00, 0xf4, // client hello len + 0x03, 0x03, // tls 1.2 (lie for compat) + } ++ client_random ++ + [_]u8{session_id.len} ++ session_id ++ + [_]u8{ + 0x00, 0x08, // cipher suite len + 0x13, 0x02, // aes_256_gcm_sha384 + 0x13, 0x03, // chacha20_poly1305_sha256 + 0x13, 0x01, // aes_128_gcm_sha256 + 0x00, 0xff, // empty_renegotiation_info_scsv + 0x01, // compression methods len + 0x00, // none + 0x00, 0xa3, // extensions len + 0x00, 0x00, // server name ext + 0x00, 0x18, // server name len + 0x00, 0x16, // list entry len + 0x00, // dns hostname + } ++ + Encoder.encode(u16, @intCast(host.len)) ++ host ++ + [_]u8{ + 0x00, 0x0b, // ec point formats + 0x00, 0x04, // ext len + 0x03, // format type len + 0x00, // uncompresed + 0x01, // ansiX962_compressed_prime + 0x02, // ansiX962_compressed_char2 + 0x00, 0x0a, // supported groups + 0x00, 0x16, // ext len + 0x00, 0x14, // supported groups len + 0x00, 0x1d, // x25519 + 0x00, 0x17, // secp256r1 + 0x00, 0x1e, // x448 + 0x00, 0x19, // secp521r1 + 0x00, 0x18, // secp384r1 + 0x01, 0x00, // ffdhe2048 + 0x01, 0x01, // ffdhe3072 + 0x01, 0x02, // ffdhe4096 + 0x01, 0x03, // ffdhe6144 + 0x01, 0x04, // ffdhe8192 + 0x00, 0x23, // session ticket + 0x00, 0x00, // ext len + 0x00, 0x16, // encrypt then mac + 0x00, 0x00, // ext len + 0x00, 0x17, // extended master secrets + 0x00, 0x00, // ext len + 0x00, 0x0d, // signature algos + 0x00, 0x1e, // ext len + 0x00, 0x1c, // algos len + 0x04, 0x03, // ecdsa_secp256r1_sha256 + 0x05, 0x03, // ecdsa_secp384r1_sha384 + 0x06, 0x03, // ecdsa_secp521r1_sha512 + 0x08, 0x07, // ed25519 + 0x08, 0x08, // ed448 + 0x08, 0x09, // rsa_pss_pss_sha256 + 0x08, 0x0a, // rsa_pss_pss_sha384 + 0x08, 0x0b, // rsa_pss_pss_sha512 + 0x08, 0x04, // rsa_pss_rsae_sha256 + 0x08, 0x05, // rsa_pss_rsae_sha384 + 0x08, 0x06, // rsa_pss_rsae_sha512 + 0x04, 0x01, // rsa_pkcs1_sha256 + 0x05, 0x01, // rsa_pkcs1_sha384 + 0x06, 0x01, // rsa_pkcs1_sha512 + 0x00, 0x2b, // supported versions + 0x00, 0x03, // ext len + 0x02, // supported versions len + 0x03, 0x04, // tls 1.3 (not lying anymore!) + 0x00, 0x2d, // psk key exchange modes + 0x00, 0x02, // ext len + 0x01, // psk key exchange modes len + 0x01, // PSK with (EC)DHE key establishment + 0x00, 0x33, // key share + 0x00, 0x26, // ext len + 0x00, 0x24, // key shares len + 0x00, 0x1d, // curve 25519 + 0x00, 0x20, // key len + } ++ key_pairs.x25519.public_key); const client_hello = try server.recv_hello(); try std.testing.expectEqualSlices(u8, &client_random, &client_hello.random); @@ -2131,9 +2161,7 @@ test "tls client and server handshake, data, and close_notify" { try server.send_hello(client_hello, server_key_pair); // hack to match xargs, need to fix server - const signature_verify = [_]u8{ -0x5c, 0xbb, 0x24, 0xc0, 0x40, 0x93, 0x32, 0xda, 0xa9, 0x20, 0xbb, 0xab, 0xbd, 0xb9, 0xbd, 0x50, 0x17, 0x0b, 0xe4, 0x9c, 0xfb, 0xe0, 0xa4, 0x10, 0x7f, 0xca, 0x6f, 0xfb, 0x10, 0x68, 0xe6, 0x5f, 0x96, 0x9e, 0x6d, 0xe7, 0xd4, 0xf9, 0xe5, 0x60, 0x38, 0xd6, 0x7c, 0x69, 0xc0, 0x31, 0x40, 0x3a, 0x7a, 0x7c, 0x0b, 0xcc, 0x86, 0x83, 0xe6, 0x57, 0x21, 0xa0, 0xc7, 0x2c, 0xc6, 0x63, 0x40, 0x19, 0xad, 0x1d, 0x3a, 0xd2, 0x65, 0xa8, 0x12, 0x61, 0x5b, 0xa3, 0x63, 0x80, 0x37, 0x20, 0x84, 0xf5, 0xda, 0xec, 0x7e, 0x63, 0xd3, 0xf4, 0x93, 0x3f, 0x27, 0x22, 0x74, 0x19, 0xa6, 0x11, 0x03, 0x46, 0x44, 0xdc, 0xdb, 0xc7, 0xbe, 0x3e, 0x74, 0xff, 0xac, 0x47, 0x3f, 0xaa, 0xad, 0xde, 0x8c, 0x2f, 0xc6, 0x5f, 0x32, 0x65, 0x77, 0x3e, 0x7e, 0x62, 0xde, 0x33, 0x86, 0x1f, 0xa7, 0x05, 0xd1, 0x9c, 0x50, 0x6e, 0x89, 0x6c, 0x8d, 0x82, 0xf5, 0xbc, 0xf3, 0x5f, 0xec, 0xe2, 0x59, 0xb7, 0x15, 0x38, 0x11, 0x5e, 0x9c, 0x8c, 0xfb, 0xa6, 0x2e, 0x49, 0xbb, 0x84, 0x74, 0xf5, 0x85, 0x87, 0xb1, 0x1b, 0x8a, 0xe3, 0x17, 0xc6, 0x33, 0xe9, 0xc7, 0x6c, 0x79, 0x1d, 0x46, 0x62, 0x84, 0xad, 0x9c, 0x4f, 0xf7, 0x35, 0xa6, 0xd2, 0xe9, 0x63, 0xb5, 0x9b, 0xbc, 0xa4, 0x40, 0xa3, 0x07, 0x09, 0x1a, 0x1b, 0x4e, 0x46, 0xbc, 0xc7, 0xa2, 0xf9, 0xfb, 0x2f, 0x1c, 0x89, 0x8e, 0xcb, 0x19, 0x91, 0x8b, 0xe4, 0x12, 0x1d, 0x7e, 0x8e, 0xd0, 0x4c, 0xd5, 0x0c, 0x9a, 0x59, 0xe9, 0x87, 0x98, 0x01, 0x07, 0xbb, 0xbf, 0x29, 0x9c, 0x23, 0x2e, 0x7f, 0xdb, 0xe1, 0x0a, 0x4c, 0xfd, 0xae, 0x5c, 0x89, 0x1c, 0x96, 0xaf, 0xdf, 0xf9, 0x4b, 0x54, 0xcc, 0xd2, 0xbc, 0x19, 0xd3, 0xcd, 0xaa, 0x66, 0x44, 0x85, 0x9c - }; + const signature_verify = [_]u8{ 0x5c, 0xbb, 0x24, 0xc0, 0x40, 0x93, 0x32, 0xda, 0xa9, 0x20, 0xbb, 0xab, 0xbd, 0xb9, 0xbd, 0x50, 0x17, 0x0b, 0xe4, 0x9c, 0xfb, 0xe0, 0xa4, 0x10, 0x7f, 0xca, 0x6f, 0xfb, 0x10, 0x68, 0xe6, 0x5f, 0x96, 0x9e, 0x6d, 0xe7, 0xd4, 0xf9, 0xe5, 0x60, 0x38, 0xd6, 0x7c, 0x69, 0xc0, 0x31, 0x40, 0x3a, 0x7a, 0x7c, 0x0b, 0xcc, 0x86, 0x83, 0xe6, 0x57, 0x21, 0xa0, 0xc7, 0x2c, 0xc6, 0x63, 0x40, 0x19, 0xad, 0x1d, 0x3a, 0xd2, 0x65, 0xa8, 0x12, 0x61, 0x5b, 0xa3, 0x63, 0x80, 0x37, 0x20, 0x84, 0xf5, 0xda, 0xec, 0x7e, 0x63, 0xd3, 0xf4, 0x93, 0x3f, 0x27, 0x22, 0x74, 0x19, 0xa6, 0x11, 0x03, 0x46, 0x44, 0xdc, 0xdb, 0xc7, 0xbe, 0x3e, 0x74, 0xff, 0xac, 0x47, 0x3f, 0xaa, 0xad, 0xde, 0x8c, 0x2f, 0xc6, 0x5f, 0x32, 0x65, 0x77, 0x3e, 0x7e, 0x62, 0xde, 0x33, 0x86, 0x1f, 0xa7, 0x05, 0xd1, 0x9c, 0x50, 0x6e, 0x89, 0x6c, 0x8d, 0x82, 0xf5, 0xbc, 0xf3, 0x5f, 0xec, 0xe2, 0x59, 0xb7, 0x15, 0x38, 0x11, 0x5e, 0x9c, 0x8c, 0xfb, 0xa6, 0x2e, 0x49, 0xbb, 0x84, 0x74, 0xf5, 0x85, 0x87, 0xb1, 0x1b, 0x8a, 0xe3, 0x17, 0xc6, 0x33, 0xe9, 0xc7, 0x6c, 0x79, 0x1d, 0x46, 0x62, 0x84, 0xad, 0x9c, 0x4f, 0xf7, 0x35, 0xa6, 0xd2, 0xe9, 0x63, 0xb5, 0x9b, 0xbc, 0xa4, 0x40, 0xa3, 0x07, 0x09, 0x1a, 0x1b, 0x4e, 0x46, 0xbc, 0xc7, 0xa2, 0xf9, 0xfb, 0x2f, 0x1c, 0x89, 0x8e, 0xcb, 0x19, 0x91, 0x8b, 0xe4, 0x12, 0x1d, 0x7e, 0x8e, 0xd0, 0x4c, 0xd5, 0x0c, 0x9a, 0x59, 0xe9, 0x87, 0x98, 0x01, 0x07, 0xbb, 0xbf, 0x29, 0x9c, 0x23, 0x2e, 0x7f, 0xdb, 0xe1, 0x0a, 0x4c, 0xfd, 0xae, 0x5c, 0x89, 0x1c, 0x96, 0xaf, 0xdf, 0xf9, 0x4b, 0x54, 0xcc, 0xd2, 0xbc, 0x19, 0xd3, 0xcd, 0xaa, 0x66, 0x44, 0x85, 0x9c }; _ = try server.stream.write(Handshake, Handshake{ .certificate_verify = CertificateVerify{ .algorithm = .rsa_pss_rsae_sha256, .signature = &signature_verify, @@ -2186,65 +2214,422 @@ test "tls client and server handshake, data, and close_notify" { 0x17, // application data (lie for tls 1.2 compat) 0x03, 0x03, // tls 1.2 0x00, 0x17, // application data len - 0x6b, 0xe0, 0x2f, 0x9d, 0xa7, 0xc2, // encrypted data (empty EncryptedExtensions message) + 0x6b, 0xe0, 0x2f, 0x9d, 0xa7, 0xc2, // encrypted data (empty EncryptedExtensions message) 0xdc, // encrypted data type (handshake) - 0x9d, 0xde, 0xf5, 0x6f, 0x24, 0x68, 0xb9, 0x0a, // auth tag - 0xdf, 0xa2, 0x51, 0x01, 0xab, 0x03, 0x44, 0xae, // auth tag + 0x9d, 0xde, 0xf5, 0x6f, 0x24, 0x68, 0xb9, 0x0a, // auth tag + 0xdf, 0xa2, 0x51, 0x01, 0xab, 0x03, 0x44, 0xae, // auth tag } ++ [_]u8{ 0x17, // application data (lie for tls 1.2 compat) 0x03, 0x03, // tls 1.2 0x03, 0x43, // application data len - 0xba, 0xf0, 0x0a, 0x9b, 0xe5, 0x0f, 0x3f, 0x23, 0x07, 0xe7, 0x26, 0xed, 0xcb, 0xda, 0xcb, 0xe4, - 0xb1, 0x86, 0x16, 0x44, 0x9d, 0x46, 0xc6, 0x20, 0x7a, 0xf6, 0xe9, 0x95, 0x3e, 0xe5, 0xd2, 0x41, - 0x1b, 0xa6, 0x5d, 0x31, 0xfe, 0xaf, 0x4f, 0x78, 0x76, 0x4f, 0x2d, 0x69, 0x39, 0x87, 0x18, 0x6c, - 0xc0, 0x13, 0x29, 0xc1, 0x87, 0xa5, 0xe4, 0x60, 0x8e, 0x8d, 0x27, 0xb3, 0x18, 0xe9, 0x8d, 0xd9, - 0x47, 0x69, 0xf7, 0x73, 0x9c, 0xe6, 0x76, 0x83, 0x92, 0xca, 0xca, 0x8d, 0xcc, 0x59, 0x7d, 0x77, - 0xec, 0x0d, 0x12, 0x72, 0x23, 0x37, 0x85, 0xf6, 0xe6, 0x9d, 0x6f, 0x43, 0xef, 0xfa, 0x8e, 0x79, - 0x05, 0xed, 0xfd, 0xc4, 0x03, 0x7e, 0xee, 0x59, 0x33, 0xe9, 0x90, 0xa7, 0x97, 0x2f, 0x20, 0x69, - 0x13, 0xa3, 0x1e, 0x8d, 0x04, 0x93, 0x13, 0x66, 0xd3, 0xd8, 0xbc, 0xd6, 0xa4, 0xa4, 0xd6, 0x47, - 0xdd, 0x4b, 0xd8, 0x0b, 0x0f, 0xf8, 0x63, 0xce, 0x35, 0x54, 0x83, 0x3d, 0x74, 0x4c, 0xf0, 0xe0, - 0xb9, 0xc0, 0x7c, 0xae, 0x72, 0x6d, 0xd2, 0x3f, 0x99, 0x53, 0xdf, 0x1f, 0x1c, 0xe3, 0xac, 0xeb, - 0x3b, 0x72, 0x30, 0x87, 0x1e, 0x92, 0x31, 0x0c, 0xfb, 0x2b, 0x09, 0x84, 0x86, 0xf4, 0x35, 0x38, - 0xf8, 0xe8, 0x2d, 0x84, 0x04, 0xe5, 0xc6, 0xc2, 0x5f, 0x66, 0xa6, 0x2e, 0xbe, 0x3c, 0x5f, 0x26, - 0x23, 0x26, 0x40, 0xe2, 0x0a, 0x76, 0x91, 0x75, 0xef, 0x83, 0x48, 0x3c, 0xd8, 0x1e, 0x6c, 0xb1, - 0x6e, 0x78, 0xdf, 0xad, 0x4c, 0x1b, 0x71, 0x4b, 0x04, 0xb4, 0x5f, 0x6a, 0xc8, 0xd1, 0x06, 0x5a, - 0xd1, 0x8c, 0x13, 0x45, 0x1c, 0x90, 0x55, 0xc4, 0x7d, 0xa3, 0x00, 0xf9, 0x35, 0x36, 0xea, 0x56, - 0xf5, 0x31, 0x98, 0x6d, 0x64, 0x92, 0x77, 0x53, 0x93, 0xc4, 0xcc, 0xb0, 0x95, 0x46, 0x70, 0x92, - 0xa0, 0xec, 0x0b, 0x43, 0xed, 0x7a, 0x06, 0x87, 0xcb, 0x47, 0x0c, 0xe3, 0x50, 0x91, 0x7b, 0x0a, - 0xc3, 0x0c, 0x6e, 0x5c, 0x24, 0x72, 0x5a, 0x78, 0xc4, 0x5f, 0x9f, 0x5f, 0x29, 0xb6, 0x62, 0x68, - 0x67, 0xf6, 0xf7, 0x9c, 0xe0, 0x54, 0x27, 0x35, 0x47, 0xb3, 0x6d, 0xf0, 0x30, 0xbd, 0x24, 0xaf, - 0x10, 0xd6, 0x32, 0xdb, 0xa5, 0x4f, 0xc4, 0xe8, 0x90, 0xbd, 0x05, 0x86, 0x92, 0x8c, 0x02, 0x06, - 0xca, 0x2e, 0x28, 0xe4, 0x4e, 0x22, 0x7a, 0x2d, 0x50, 0x63, 0x19, 0x59, 0x35, 0xdf, 0x38, 0xda, - 0x89, 0x36, 0x09, 0x2e, 0xef, 0x01, 0xe8, 0x4c, 0xad, 0x2e, 0x49, 0xd6, 0x2e, 0x47, 0x0a, 0x6c, - 0x77, 0x45, 0xf6, 0x25, 0xec, 0x39, 0xe4, 0xfc, 0x23, 0x32, 0x9c, 0x79, 0xd1, 0x17, 0x28, 0x76, - 0x80, 0x7c, 0x36, 0xd7, 0x36, 0xba, 0x42, 0xbb, 0x69, 0xb0, 0x04, 0xff, 0x55, 0xf9, 0x38, 0x50, - 0xdc, 0x33, 0xc1, 0xf9, 0x8a, 0xbb, 0x92, 0x85, 0x83, 0x24, 0xc7, 0x6f, 0xf1, 0xeb, 0x08, 0x5d, - 0xb3, 0xc1, 0xfc, 0x50, 0xf7, 0x4e, 0xc0, 0x44, 0x42, 0xe6, 0x22, 0x97, 0x3e, 0xa7, 0x07, 0x43, - 0x41, 0x87, 0x94, 0xc3, 0x88, 0x14, 0x0b, 0xb4, 0x92, 0xd6, 0x29, 0x4a, 0x05, 0x40, 0xe5, 0xa5, - 0x9c, 0xfa, 0xe6, 0x0b, 0xa0, 0xf1, 0x48, 0x99, 0xfc, 0xa7, 0x13, 0x33, 0x31, 0x5e, 0xa0, 0x83, - 0xa6, 0x8e, 0x1d, 0x7c, 0x1e, 0x4c, 0xdc, 0x2f, 0x56, 0xbc, 0xd6, 0x11, 0x96, 0x81, 0xa4, 0xad, - 0xbc, 0x1b, 0xbf, 0x42, 0xaf, 0xd8, 0x06, 0xc3, 0xcb, 0xd4, 0x2a, 0x07, 0x6f, 0x54, 0x5d, 0xee, - 0x4e, 0x11, 0x8d, 0x0b, 0x39, 0x67, 0x54, 0xbe, 0x2b, 0x04, 0x2a, 0x68, 0x5d, 0xd4, 0x72, 0x7e, - 0x89, 0xc0, 0x38, 0x6a, 0x94, 0xd3, 0xcd, 0x6e, 0xcb, 0x98, 0x20, 0xe9, 0xd4, 0x9a, 0xfe, 0xed, - 0x66, 0xc4, 0x7e, 0x6f, 0xc2, 0x43, 0xea, 0xbe, 0xbb, 0xcb, 0x0b, 0x02, 0x45, 0x38, 0x77, 0xf5, - 0xac, 0x5d, 0xbf, 0xbd, 0xf8, 0xdb, 0x10, 0x52, 0xa3, 0xc9, 0x94, 0xb2, 0x24, 0xcd, 0x9a, 0xaa, - 0xf5, 0x6b, 0x02, 0x6b, 0xb9, 0xef, 0xa2, 0xe0, 0x13, 0x02, 0xb3, 0x64, 0x01, 0xab, 0x64, 0x94, - 0xe7, 0x01, 0x8d, 0x6e, 0x5b, 0x57, 0x3b, 0xd3, 0x8b, 0xce, 0xf0, 0x23, 0xb1, 0xfc, 0x92, 0x94, - 0x6b, 0xbc, 0xa0, 0x20, 0x9c, 0xa5, 0xfa, 0x92, 0x6b, 0x49, 0x70, 0xb1, 0x00, 0x91, 0x03, 0x64, - 0x5c, 0xb1, 0xfc, 0xfe, 0x55, 0x23, 0x11, 0xff, 0x73, 0x05, 0x58, 0x98, 0x43, 0x70, 0x03, 0x8f, - 0xd2, 0xcc, 0xe2, 0xa9, 0x1f, 0xc7, 0x4d, 0x6f, 0x3e, 0x3e, 0xa9, 0xf8, 0x43, 0xee, 0xd3, 0x56, - 0xf6, 0xf8, 0x2d, 0x35, 0xd0, 0x3b, 0xc2, 0x4b, 0x81, 0xb5, 0x8c, 0xeb, 0x1a, 0x43, 0xec, 0x94, - 0x37, 0xe6, 0xf1, 0xe5, 0x0e, 0xb6, 0xf5, 0x55, 0xe3, 0x21, 0xfd, 0x67, 0xc8, 0x33, 0x2e, 0xb1, - 0xb8, 0x32, 0xaa, 0x8d, 0x79, 0x5a, 0x27, 0xd4, 0x79, 0xc6, 0xe2, 0x7d, 0x5a, 0x61, 0x03, 0x46, - 0x83, 0x89, 0x19, 0x03, 0xf6, 0x64, 0x21, 0xd0, 0x94, 0xe1, 0xb0, 0x0a, 0x9a, 0x13, 0x8d, 0x86, - 0x1e, 0x6f, 0x78, 0xa2, 0x0a, 0xd3, 0xe1, 0x58, 0x00, 0x54, 0xd2, 0xe3, 0x05, 0x25, 0x3c, 0x71, - 0x3a, 0x02, 0xfe, 0x1e, 0x28, 0xde, 0xee, 0x73, 0x36, 0x24, 0x6f, 0x6a, 0xe3, 0x43, 0x31, 0x80, - 0x6b, 0x46, 0xb4, 0x7b, 0x83, 0x3c, 0x39, 0xb9, 0xd3, 0x1c, 0xd3, 0x00, 0xc2, 0xa6, 0xed, 0x83, - 0x13, 0x99, 0x77, 0x6d, 0x07, 0xf5, 0x70, 0xea, 0xf0, 0x05, 0x9a, 0x2c, 0x68, 0xa5, 0xf3, 0xae, - 0x16, 0xb6, 0x17, 0x40, 0x4a, 0xf7, 0xb7, 0x23, 0x1a, 0x4d, 0x94, 0x27, 0x58, 0xfc, 0x02, 0x0b, - 0x3f, 0x23, 0xee, 0x8c, 0x15, 0xe3, 0x60, 0x44, 0xcf, 0xd6, 0x7c, 0xd6, 0x40, 0x99, 0x3b, 0x16, - 0x20, 0x75, 0x97, 0xfb, 0xf3, 0x85, 0xea, 0x7a, 0x4d, 0x99, 0xe8, 0xd4, 0x56, 0xff, 0x83, 0xd4, - 0x1f, 0x7b, 0x8b, 0x4f, 0x06, 0x9b, 0x02, 0x8a, 0x2a, 0x63, 0xa9, 0x19, 0xa7, 0x0e, 0x3a, 0x10, + 0xba, 0xf0, + 0x0a, 0x9b, + 0xe5, 0x0f, + 0x3f, 0x23, + 0x07, 0xe7, + 0x26, 0xed, + 0xcb, 0xda, + 0xcb, 0xe4, + 0xb1, 0x86, + 0x16, 0x44, + 0x9d, 0x46, + 0xc6, 0x20, + 0x7a, 0xf6, + 0xe9, 0x95, + 0x3e, 0xe5, + 0xd2, 0x41, + 0x1b, 0xa6, + 0x5d, 0x31, + 0xfe, 0xaf, + 0x4f, 0x78, + 0x76, 0x4f, + 0x2d, 0x69, + 0x39, 0x87, + 0x18, 0x6c, + 0xc0, 0x13, + 0x29, 0xc1, + 0x87, 0xa5, + 0xe4, 0x60, + 0x8e, 0x8d, + 0x27, 0xb3, + 0x18, 0xe9, + 0x8d, 0xd9, + 0x47, 0x69, + 0xf7, 0x73, + 0x9c, 0xe6, + 0x76, 0x83, + 0x92, 0xca, + 0xca, 0x8d, + 0xcc, 0x59, + 0x7d, 0x77, + 0xec, 0x0d, + 0x12, 0x72, + 0x23, 0x37, + 0x85, 0xf6, + 0xe6, 0x9d, + 0x6f, 0x43, + 0xef, 0xfa, + 0x8e, 0x79, + 0x05, 0xed, + 0xfd, 0xc4, + 0x03, 0x7e, + 0xee, 0x59, + 0x33, 0xe9, + 0x90, 0xa7, + 0x97, 0x2f, + 0x20, 0x69, + 0x13, 0xa3, + 0x1e, 0x8d, + 0x04, 0x93, + 0x13, 0x66, + 0xd3, 0xd8, + 0xbc, 0xd6, + 0xa4, 0xa4, + 0xd6, 0x47, + 0xdd, 0x4b, + 0xd8, 0x0b, + 0x0f, 0xf8, + 0x63, 0xce, + 0x35, 0x54, + 0x83, 0x3d, + 0x74, 0x4c, + 0xf0, 0xe0, + 0xb9, 0xc0, + 0x7c, 0xae, + 0x72, 0x6d, + 0xd2, 0x3f, + 0x99, 0x53, + 0xdf, 0x1f, + 0x1c, 0xe3, + 0xac, 0xeb, + 0x3b, 0x72, + 0x30, 0x87, + 0x1e, 0x92, + 0x31, 0x0c, + 0xfb, 0x2b, + 0x09, 0x84, + 0x86, 0xf4, + 0x35, 0x38, + 0xf8, 0xe8, + 0x2d, 0x84, + 0x04, 0xe5, + 0xc6, 0xc2, + 0x5f, 0x66, + 0xa6, 0x2e, + 0xbe, 0x3c, + 0x5f, 0x26, + 0x23, 0x26, + 0x40, 0xe2, + 0x0a, 0x76, + 0x91, 0x75, + 0xef, 0x83, + 0x48, 0x3c, + 0xd8, 0x1e, + 0x6c, 0xb1, + 0x6e, 0x78, + 0xdf, 0xad, + 0x4c, 0x1b, + 0x71, 0x4b, + 0x04, 0xb4, + 0x5f, 0x6a, + 0xc8, 0xd1, + 0x06, 0x5a, + 0xd1, 0x8c, + 0x13, 0x45, + 0x1c, 0x90, + 0x55, 0xc4, + 0x7d, 0xa3, + 0x00, 0xf9, + 0x35, 0x36, + 0xea, 0x56, + 0xf5, 0x31, + 0x98, 0x6d, + 0x64, 0x92, + 0x77, 0x53, + 0x93, 0xc4, + 0xcc, 0xb0, + 0x95, 0x46, + 0x70, 0x92, + 0xa0, 0xec, + 0x0b, 0x43, + 0xed, 0x7a, + 0x06, 0x87, + 0xcb, 0x47, + 0x0c, 0xe3, + 0x50, 0x91, + 0x7b, 0x0a, + 0xc3, 0x0c, + 0x6e, 0x5c, + 0x24, 0x72, + 0x5a, 0x78, + 0xc4, 0x5f, + 0x9f, 0x5f, + 0x29, 0xb6, + 0x62, 0x68, + 0x67, 0xf6, + 0xf7, 0x9c, + 0xe0, 0x54, + 0x27, 0x35, + 0x47, 0xb3, + 0x6d, 0xf0, + 0x30, 0xbd, + 0x24, 0xaf, + 0x10, 0xd6, + 0x32, 0xdb, + 0xa5, 0x4f, + 0xc4, 0xe8, + 0x90, 0xbd, + 0x05, 0x86, + 0x92, 0x8c, + 0x02, 0x06, + 0xca, 0x2e, + 0x28, 0xe4, + 0x4e, 0x22, + 0x7a, 0x2d, + 0x50, 0x63, + 0x19, 0x59, + 0x35, 0xdf, + 0x38, 0xda, + 0x89, 0x36, + 0x09, 0x2e, + 0xef, 0x01, + 0xe8, 0x4c, + 0xad, 0x2e, + 0x49, 0xd6, + 0x2e, 0x47, + 0x0a, 0x6c, + 0x77, 0x45, + 0xf6, 0x25, + 0xec, 0x39, + 0xe4, 0xfc, + 0x23, 0x32, + 0x9c, 0x79, + 0xd1, 0x17, + 0x28, 0x76, + 0x80, 0x7c, + 0x36, 0xd7, + 0x36, 0xba, + 0x42, 0xbb, + 0x69, 0xb0, + 0x04, 0xff, + 0x55, 0xf9, + 0x38, 0x50, + 0xdc, 0x33, + 0xc1, 0xf9, + 0x8a, 0xbb, + 0x92, 0x85, + 0x83, 0x24, + 0xc7, 0x6f, + 0xf1, 0xeb, + 0x08, 0x5d, + 0xb3, 0xc1, + 0xfc, 0x50, + 0xf7, 0x4e, + 0xc0, 0x44, + 0x42, 0xe6, + 0x22, 0x97, + 0x3e, 0xa7, + 0x07, 0x43, + 0x41, 0x87, + 0x94, 0xc3, + 0x88, 0x14, + 0x0b, 0xb4, + 0x92, 0xd6, + 0x29, 0x4a, + 0x05, 0x40, + 0xe5, 0xa5, + 0x9c, 0xfa, + 0xe6, 0x0b, + 0xa0, 0xf1, + 0x48, 0x99, + 0xfc, 0xa7, + 0x13, 0x33, + 0x31, 0x5e, + 0xa0, 0x83, + 0xa6, 0x8e, + 0x1d, 0x7c, + 0x1e, 0x4c, + 0xdc, 0x2f, + 0x56, 0xbc, + 0xd6, 0x11, + 0x96, 0x81, + 0xa4, 0xad, + 0xbc, 0x1b, + 0xbf, 0x42, + 0xaf, 0xd8, + 0x06, 0xc3, + 0xcb, 0xd4, + 0x2a, 0x07, + 0x6f, 0x54, + 0x5d, 0xee, + 0x4e, 0x11, + 0x8d, 0x0b, + 0x39, 0x67, + 0x54, 0xbe, + 0x2b, 0x04, + 0x2a, 0x68, + 0x5d, 0xd4, + 0x72, 0x7e, + 0x89, 0xc0, + 0x38, 0x6a, + 0x94, 0xd3, + 0xcd, 0x6e, + 0xcb, 0x98, + 0x20, 0xe9, + 0xd4, 0x9a, + 0xfe, 0xed, + 0x66, 0xc4, + 0x7e, 0x6f, + 0xc2, 0x43, + 0xea, 0xbe, + 0xbb, 0xcb, + 0x0b, 0x02, + 0x45, 0x38, + 0x77, 0xf5, + 0xac, 0x5d, + 0xbf, 0xbd, + 0xf8, 0xdb, + 0x10, 0x52, + 0xa3, 0xc9, + 0x94, 0xb2, + 0x24, 0xcd, + 0x9a, 0xaa, + 0xf5, 0x6b, + 0x02, 0x6b, + 0xb9, 0xef, + 0xa2, 0xe0, + 0x13, 0x02, + 0xb3, 0x64, + 0x01, 0xab, + 0x64, 0x94, + 0xe7, 0x01, + 0x8d, 0x6e, + 0x5b, 0x57, + 0x3b, 0xd3, + 0x8b, 0xce, + 0xf0, 0x23, + 0xb1, 0xfc, + 0x92, 0x94, + 0x6b, 0xbc, + 0xa0, 0x20, + 0x9c, 0xa5, + 0xfa, 0x92, + 0x6b, 0x49, + 0x70, 0xb1, + 0x00, 0x91, + 0x03, 0x64, + 0x5c, 0xb1, + 0xfc, 0xfe, + 0x55, 0x23, + 0x11, 0xff, + 0x73, 0x05, + 0x58, 0x98, + 0x43, 0x70, + 0x03, 0x8f, + 0xd2, 0xcc, + 0xe2, 0xa9, + 0x1f, 0xc7, + 0x4d, 0x6f, + 0x3e, 0x3e, + 0xa9, 0xf8, + 0x43, 0xee, + 0xd3, 0x56, + 0xf6, 0xf8, + 0x2d, 0x35, + 0xd0, 0x3b, + 0xc2, 0x4b, + 0x81, 0xb5, + 0x8c, 0xeb, + 0x1a, 0x43, + 0xec, 0x94, + 0x37, 0xe6, + 0xf1, 0xe5, + 0x0e, 0xb6, + 0xf5, 0x55, + 0xe3, 0x21, + 0xfd, 0x67, + 0xc8, 0x33, + 0x2e, 0xb1, + 0xb8, 0x32, + 0xaa, 0x8d, + 0x79, 0x5a, + 0x27, 0xd4, + 0x79, 0xc6, + 0xe2, 0x7d, + 0x5a, 0x61, + 0x03, 0x46, + 0x83, 0x89, + 0x19, 0x03, + 0xf6, 0x64, + 0x21, 0xd0, + 0x94, 0xe1, + 0xb0, 0x0a, + 0x9a, 0x13, + 0x8d, 0x86, + 0x1e, 0x6f, + 0x78, 0xa2, + 0x0a, 0xd3, + 0xe1, 0x58, + 0x00, 0x54, + 0xd2, 0xe3, + 0x05, 0x25, + 0x3c, 0x71, + 0x3a, 0x02, + 0xfe, 0x1e, + 0x28, 0xde, + 0xee, 0x73, + 0x36, 0x24, + 0x6f, 0x6a, + 0xe3, 0x43, + 0x31, 0x80, + 0x6b, 0x46, + 0xb4, 0x7b, + 0x83, 0x3c, + 0x39, 0xb9, + 0xd3, 0x1c, + 0xd3, 0x00, + 0xc2, 0xa6, + 0xed, 0x83, + 0x13, 0x99, + 0x77, 0x6d, + 0x07, 0xf5, + 0x70, 0xea, + 0xf0, 0x05, + 0x9a, 0x2c, + 0x68, 0xa5, + 0xf3, 0xae, + 0x16, 0xb6, + 0x17, 0x40, + 0x4a, 0xf7, + 0xb7, 0x23, + 0x1a, 0x4d, + 0x94, 0x27, + 0x58, 0xfc, + 0x02, 0x0b, + 0x3f, 0x23, + 0xee, 0x8c, + 0x15, 0xe3, + 0x60, 0x44, + 0xcf, 0xd6, + 0x7c, 0xd6, + 0x40, 0x99, + 0x3b, 0x16, + 0x20, 0x75, + 0x97, 0xfb, + 0xf3, 0x85, + 0xea, 0x7a, + 0x4d, 0x99, + 0xe8, 0xd4, + 0x56, 0xff, + 0x83, 0xd4, + 0x1f, 0x7b, + 0x8b, 0x4f, + 0x06, 0x9b, + 0x02, 0x8a, + 0x2a, 0x63, + 0xa9, 0x19, + 0xa7, 0x0e, + 0x3a, 0x10, 0xe3, 0x08, // encrypted cert 0x41, // encrypted data type (Certificate) 0x58, 0xfa, 0xa5, 0xba, 0xfa, 0x30, 0x18, 0x6c, // auth tag @@ -2253,22 +2638,134 @@ test "tls client and server handshake, data, and close_notify" { 0x17, // application data (lie for tls 1.2 compat) 0x03, 0x03, // tls 1.2 0x01, 0x19, // application data len - 0x73, 0x71, 0x9f, 0xce, 0x07, 0xec, 0x2f, 0x6d, 0x3b, 0xba, 0x02, 0x92, 0xa0, 0xd4, 0x0b, 0x27, - 0x70, 0xc0, 0x6a, 0x27, 0x17, 0x99, 0xa5, 0x33, 0x14, 0xf6, 0xf7, 0x7f, 0xc9, 0x5c, 0x5f, 0xe7, - 0xb9, 0xa4, 0x32, 0x9f, 0xd9, 0x54, 0x8c, 0x67, 0x0e, 0xbe, 0xea, 0x2f, 0x2d, 0x5c, 0x35, 0x1d, - 0xd9, 0x35, 0x6e, 0xf2, 0xdc, 0xd5, 0x2e, 0xb1, 0x37, 0xbd, 0x3a, 0x67, 0x65, 0x22, 0xf8, 0xcd, - 0x0f, 0xb7, 0x56, 0x07, 0x89, 0xad, 0x7b, 0x0e, 0x3c, 0xab, 0xa2, 0xe3, 0x7e, 0x6b, 0x41, 0x99, - 0xc6, 0x79, 0x3b, 0x33, 0x46, 0xed, 0x46, 0xcf, 0x74, 0x0a, 0x9f, 0xa1, 0xfe, 0xc4, 0x14, 0xdc, - 0x71, 0x5c, 0x41, 0x5c, 0x60, 0xe5, 0x75, 0x70, 0x3c, 0xe6, 0xa3, 0x4b, 0x70, 0xb5, 0x19, 0x1a, - 0xa6, 0xa6, 0x1a, 0x18, 0xfa, 0xff, 0x21, 0x6c, 0x68, 0x7a, 0xd8, 0xd1, 0x7e, 0x12, 0xa7, 0xe9, - 0x99, 0x15, 0xa6, 0x11, 0xbf, 0xc1, 0xa2, 0xbe, 0xfc, 0x15, 0xe6, 0xe9, 0x4d, 0x78, 0x46, 0x42, - 0xe6, 0x82, 0xfd, 0x17, 0x38, 0x2a, 0x34, 0x8c, 0x30, 0x10, 0x56, 0xb9, 0x40, 0xc9, 0x84, 0x72, - 0x00, 0x40, 0x8b, 0xec, 0x56, 0xc8, 0x1e, 0xa3, 0xd7, 0x21, 0x7a, 0xb8, 0xe8, 0x5a, 0x88, 0x71, - 0x53, 0x95, 0x89, 0x9c, 0x90, 0x58, 0x7f, 0x72, 0xe8, 0xdd, 0xd7, 0x4b, 0x26, 0xd8, 0xed, 0xc1, - 0xc7, 0xc8, 0x37, 0xd9, 0xf2, 0xeb, 0xbc, 0x26, 0x09, 0x62, 0x21, 0x90, 0x38, 0xb0, 0x56, 0x54, - 0xa6, 0x3a, 0x0b, 0x12, 0x99, 0x9b, 0x4a, 0x83, 0x06, 0xa3, 0xdd, 0xcc, 0x0e, 0x17, 0xc5, 0x3b, - 0xa8, 0xf9, 0xc8, 0x03, 0x63, 0xf7, 0x84, 0x13, 0x54, 0xd2, 0x91, 0xb4, 0xac, 0xe0, 0xc0, 0xf3, - 0x30, 0xc0, 0xfc, 0xd5, 0xaa, 0x9d, 0xee, 0xf9, 0x69, 0xae, 0x8a, 0xb2, 0xd9, 0x8d, 0xa8, 0x8e, + 0x73, 0x71, + 0x9f, 0xce, + 0x07, 0xec, + 0x2f, 0x6d, + 0x3b, 0xba, + 0x02, 0x92, + 0xa0, 0xd4, + 0x0b, 0x27, + 0x70, 0xc0, + 0x6a, 0x27, + 0x17, 0x99, + 0xa5, 0x33, + 0x14, 0xf6, + 0xf7, 0x7f, + 0xc9, 0x5c, + 0x5f, 0xe7, + 0xb9, 0xa4, + 0x32, 0x9f, + 0xd9, 0x54, + 0x8c, 0x67, + 0x0e, 0xbe, + 0xea, 0x2f, + 0x2d, 0x5c, + 0x35, 0x1d, + 0xd9, 0x35, + 0x6e, 0xf2, + 0xdc, 0xd5, + 0x2e, 0xb1, + 0x37, 0xbd, + 0x3a, 0x67, + 0x65, 0x22, + 0xf8, 0xcd, + 0x0f, 0xb7, + 0x56, 0x07, + 0x89, 0xad, + 0x7b, 0x0e, + 0x3c, 0xab, + 0xa2, 0xe3, + 0x7e, 0x6b, + 0x41, 0x99, + 0xc6, 0x79, + 0x3b, 0x33, + 0x46, 0xed, + 0x46, 0xcf, + 0x74, 0x0a, + 0x9f, 0xa1, + 0xfe, 0xc4, + 0x14, 0xdc, + 0x71, 0x5c, + 0x41, 0x5c, + 0x60, 0xe5, + 0x75, 0x70, + 0x3c, 0xe6, + 0xa3, 0x4b, + 0x70, 0xb5, + 0x19, 0x1a, + 0xa6, 0xa6, + 0x1a, 0x18, + 0xfa, 0xff, + 0x21, 0x6c, + 0x68, 0x7a, + 0xd8, 0xd1, + 0x7e, 0x12, + 0xa7, 0xe9, + 0x99, 0x15, + 0xa6, 0x11, + 0xbf, 0xc1, + 0xa2, 0xbe, + 0xfc, 0x15, + 0xe6, 0xe9, + 0x4d, 0x78, + 0x46, 0x42, + 0xe6, 0x82, + 0xfd, 0x17, + 0x38, 0x2a, + 0x34, 0x8c, + 0x30, 0x10, + 0x56, 0xb9, + 0x40, 0xc9, + 0x84, 0x72, + 0x00, 0x40, + 0x8b, 0xec, + 0x56, 0xc8, + 0x1e, 0xa3, + 0xd7, 0x21, + 0x7a, 0xb8, + 0xe8, 0x5a, + 0x88, 0x71, + 0x53, 0x95, + 0x89, 0x9c, + 0x90, 0x58, + 0x7f, 0x72, + 0xe8, 0xdd, + 0xd7, 0x4b, + 0x26, 0xd8, + 0xed, 0xc1, + 0xc7, 0xc8, + 0x37, 0xd9, + 0xf2, 0xeb, + 0xbc, 0x26, + 0x09, 0x62, + 0x21, 0x90, + 0x38, 0xb0, + 0x56, 0x54, + 0xa6, 0x3a, + 0x0b, 0x12, + 0x99, 0x9b, + 0x4a, 0x83, + 0x06, 0xa3, + 0xdd, 0xcc, + 0x0e, 0x17, + 0xc5, 0x3b, + 0xa8, 0xf9, + 0xc8, 0x03, + 0x63, 0xf7, + 0x84, 0x13, + 0x54, 0xd2, + 0x91, 0xb4, + 0xac, 0xe0, + 0xc0, 0xf3, + 0x30, 0xc0, + 0xfc, 0xd5, + 0xaa, 0x9d, + 0xee, 0xf9, + 0x69, 0xae, + 0x8a, 0xb2, + 0xd9, 0x8d, + 0xa8, 0x8e, 0xbb, 0x6e, 0xa8, 0x0a, 0x3a, 0x11, 0xf0, 0x0e, // encrypted signature_verify 0xa2, // encrypted data type (SignatureVerify) 0x96, 0xa3, 0x23, 0x23, 0x67, 0xff, 0x07, 0x5e, // auth tag @@ -2277,17 +2774,39 @@ test "tls client and server handshake, data, and close_notify" { 0x17, // application data (lie for tls 1.2 compat) 0x03, 0x03, // tls 1.2 0x00, 0x45, // application data len - 0x10, 0x61, 0xde, 0x27, 0xe5, 0x1c, 0x2c, 0x9f, 0x34, 0x29, 0x11, 0x80, 0x6f, 0x28, 0x2b, 0x71, - 0x0c, 0x10, 0x63, 0x2c, 0xa5, 0x00, 0x67, 0x55, 0x88, 0x0d, 0xbf, 0x70, 0x06, 0x00, 0x2d, 0x0e, - 0x84, 0xfe, 0xd9, 0xad, 0xf2, 0x7a, 0x43, 0xb5, 0x19, 0x23, 0x03, 0xe4, 0xdf, 0x5c, 0x28, 0x5d, - 0x58, 0xe3, 0xc7, 0x62, + 0x10, 0x61, + 0xde, 0x27, + 0xe5, 0x1c, + 0x2c, 0x9f, + 0x34, 0x29, + 0x11, 0x80, + 0x6f, 0x28, + 0x2b, 0x71, + 0x0c, 0x10, + 0x63, 0x2c, + 0xa5, 0x00, + 0x67, 0x55, + 0x88, 0x0d, + 0xbf, 0x70, + 0x06, 0x00, + 0x2d, 0x0e, + 0x84, 0xfe, + 0xd9, 0xad, + 0xf2, 0x7a, + 0x43, 0xb5, + 0x19, 0x23, + 0x03, 0xe4, + 0xdf, 0x5c, + 0x28, 0x5d, + 0x58, 0xe3, + 0xc7, 0x62, 0x24, // encrypted data type (finished) 0x07, 0x84, 0x40, 0xc0, 0x74, 0x23, 0x74, 0x74, // auth tag 0x4a, 0xec, 0xf2, 0x8c, 0xf3, 0x18, 0x2f, 0xd0, // auth tag })); - try client.stream.expectFragment(.handshake, .server_hello); - try client.recv_hello(key_pairs); + client_command = try client.advance(client_command); + try std.testing.expect(client_command == .recv_encrypted_extensions); { const s = server.stream.handshake_cipher.?.aes_256_gcm_sha384; const c = client.stream.handshake_cipher.?.aes_256_gcm_sha384; @@ -2300,22 +2819,20 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); - const client_iv = [_]u8{ 0x42,0x56,0xd2,0xe0,0xe8,0x8b,0xab,0xdd,0x05,0xeb,0x2f,0x27 }; + const client_iv = [_]u8{ 0x42, 0x56, 0xd2, 0xe0, 0xe8, 0x8b, 0xab, 0xdd, 0x05, 0xeb, 0x2f, 0x27 }; try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } - try client.stream.expectFragment(.handshake, .encrypted_extensions); - try client.recv_encrypted_extensions(); - try client.stream.expectFragment(.handshake, .certificate); - const cert = try client.recv_certificate(); - defer allocator.free(cert.certificate.buffer); + client_command = try client.advance(client_command); + try std.testing.expect(client_command == .recv_certificate_or_finished); + + client_command = try client.advance(client_command); + try std.testing.expect(client_command == .recv_certificate_verify); - var digest = client.stream.transcript_hash.peek(); - try client.stream.expectFragment(.handshake, .certificate_verify); - try client.recv_certificate_verify(digest, cert); + client_command = try client.advance(client_command); + try std.testing.expect(client_command == .recv_finished); - digest = client.stream.transcript_hash.peek(); - try client.stream.expectFragment(.handshake, .finished); - try client.recv_finished(digest); + client_command = try client.advance(client_command); + try std.testing.expect(client_command == .send_finished); { const s = server.stream.application_cipher.?.aes_256_gcm_sha384; @@ -2327,42 +2844,77 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); - const client_iv = [_]u8{ 0xbb,0x00,0x79,0x56,0xf4,0x74,0xb2,0x5d,0xe9,0x02,0x43,0x2f, }; + const client_iv = [_]u8{ + 0xbb, + 0x00, + 0x79, + 0x56, + 0xf4, + 0x74, + 0xb2, + 0x5d, + 0xe9, + 0x02, + 0x43, + 0x2f, + }; try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } - try client.send_finished(); + client_command = try client.advance(client_command); + try std.testing.expect(client_command == .sent_finished); + try inner_stream.expect(&([_]u8{ 0x14, // ChangeCipherSpec 0x03, 0x03, // tls 1.2 0x00, 0x01, // len 0x01, // .change_cipher_spec - } - ++ [_]u8{ - 0x17, // app data (lie for TLS 1.2) - 0x03, 0x03, // tls 1.2 - 0x00, 0x45, // len - 0x9f, 0xf9, 0xb0, 0x63, 0x17, 0x51, 0x77, 0x32, 0x2a, 0x46, 0xdd, 0x98, 0x96, 0xf3, 0xc3, 0xbb, - 0x82, 0x0a, 0xb5, 0x17, 0x43, 0xeb, 0xc2, 0x5f, 0xda, 0xdd, 0x53, 0x45, 0x4b, 0x73, 0xde, 0xb5, - 0x4c, 0xc7, 0x24, 0x8d, 0x41, 0x1a, 0x18, 0xbc, 0xcf, 0x65, 0x7a, 0x96, 0x08, 0x24, 0xe9, 0xa1, - 0x93, 0x64, 0x83, 0x7c, // encrypted data - 0x35, // handshake - 0x0a, 0x69, 0xa8, 0x8d, 0x4b, 0xf6, 0x35, 0xc8, // auth tag - 0x5e, 0xb8, 0x74, 0xae, 0xbc, 0x9d, 0xfd, 0xe8, // auth tag - })); + } ++ [_]u8{ + 0x17, // app data (lie for TLS 1.2) + 0x03, 0x03, // tls 1.2 + 0x00, 0x45, // len + 0x9f, 0xf9, + 0xb0, 0x63, + 0x17, 0x51, + 0x77, 0x32, + 0x2a, 0x46, + 0xdd, 0x98, + 0x96, 0xf3, + 0xc3, 0xbb, + 0x82, 0x0a, + 0xb5, 0x17, + 0x43, 0xeb, + 0xc2, 0x5f, + 0xda, 0xdd, + 0x53, 0x45, + 0x4b, 0x73, + 0xde, 0xb5, + 0x4c, 0xc7, + 0x24, 0x8d, + 0x41, 0x1a, + 0x18, 0xbc, + 0xcf, 0x65, + 0x7a, 0x96, + 0x08, 0x24, + 0xe9, 0xa1, + 0x93, 0x64, 0x83, 0x7c, // encrypted data + 0x35, // handshake + 0x0a, 0x69, 0xa8, 0x8d, 0x4b, 0xf6, 0x35, 0xc8, // auth tag + 0x5e, 0xb8, 0x74, 0xae, 0xbc, 0x9d, 0xfd, 0xe8, // auth tag + })); try server.recv_finished(); _ = try client.stream.writer().writeAll("ping"); try client.stream.flush(); try inner_stream.expect(&([_]u8{ - 0x17, // app data (FOR REAL THIS TIME) - 0x03, 0x03, // tls 1.2 - 0x00, 0x15, // len + 0x17, // app data (FOR REAL THIS TIME) + 0x03, 0x03, // tls 1.2 + 0x00, 0x15, // len 0x82, 0x81, 0x39, 0xcb, // ping - 0x7b, // app data (exciting!) + 0x7b, // app data (exciting!) 0x73, 0xaa, 0xab, 0xf5, 0xb8, 0x2f, 0xbf, 0x9a, // auth tag 0x29, 0x61, 0xbc, 0xde, 0x10, 0x03, 0x8a, 0x32, // auth tag - })); + })); var recv_ping: [4]u8 = undefined; _ = try server.stream.reader().readAll(&recv_ping); @@ -2371,16 +2923,16 @@ test "tls client and server handshake, data, and close_notify" { server.stream.close(); try std.testing.expect(server.stream.closed); try inner_stream.expect(&([_]u8{ - 0x17, // app data (lie to encrypt) - 0x03, 0x03, // tls 1.2 - 0x00, 0x13, // len + 0x17, // app data (lie to encrypt) + 0x03, 0x03, // tls 1.2 + 0x00, 0x13, // len 0x3e, 0x2d, // alert 0x99, // encrypted message type 0x26, 0xbb, 0xfe, 0x1f, 0x46, 0xfb, 0x4e, 0xe2, // auth tag 0x75, 0x1e, 0x53, 0xbf, 0xfc, 0x7e, 0x65, 0x16, // auth tag - })); + })); - _ = try client.stream.readFragment(); + _ = try client.stream.readPlaintext(); try std.testing.expect(client.stream.closed); } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 526c0fe952b2..05d5cf2f7f9f 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -9,17 +9,10 @@ const Certificate = std.crypto.Certificate; /// `StreamType` must conform to `tls.StreamInterface`. pub fn Client(comptime StreamType: type) type { return struct { - stream: tls.Stream(tls.Plaintext.max_length, StreamType), + stream: Stream, options: Options, - state: State = .start, - - const State = enum { - start, - recv_encrypted_extensions, - recv_finished, - sent_finished, - }; + const Stream = tls.Stream(tls.Plaintext.max_length, StreamType); const Self = @This(); /// Initiates a TLS handshake and establishes a TLSv1.3 session @@ -30,53 +23,76 @@ pub fn Client(comptime StreamType: type) type { }; var res = Self{ .stream = stream_, .options = options }; - while (res.state != .sent_finished) try res.advance(); + var state = Command{ .send_hello = KeyPairs.init() }; + while (state != .sent_finished) state = try res.advance(state); return res; } - /// Advance to next handshake state. - pub fn advance(self: *Self) !void { + /// Execute command and return next one. + pub fn advance(self: *Self, command: Command) !Command { var stream = &self.stream; - switch (self.state) { - .start => { - const key_pairs = KeyPairs.init(); + switch (command) { + .send_hello => |key_pairs| { try self.send_hello(key_pairs); - try stream.expectFragment(.handshake, .server_hello); + return .{ .recv_hello = key_pairs }; + }, + .recv_hello => |key_pairs| { + try stream.expectInnerPlaintext(.handshake, .server_hello); try self.recv_hello(key_pairs); - try stream.expectFragment(.handshake, .encrypted_extensions); + return .{ .recv_encrypted_extensions = {} }; + }, + .recv_encrypted_extensions => { + try stream.expectInnerPlaintext(.handshake, .encrypted_extensions); try self.recv_encrypted_extensions(); - self.state = .recv_encrypted_extensions; + return .{ .recv_certificate_or_finished = {} }; }, - .recv_encrypted_extensions => { - var digest = stream.transcript_hash.owned(); - const header = try stream.expectHandshake(); - switch (header.type) { + .recv_certificate_or_finished => { + const digest = stream.transcript_hash.peek(); + const inner_plaintext = try stream.readInnerPlaintext(); + if (inner_plaintext.type != .handshake) return stream.writeError(.unexpected_message); + switch (inner_plaintext.handshake_type) { .certificate => { - const parsed = try self.recv_certificate(); - defer self.options.allocator.free(parsed.certificate.buffer); - try self.recv_certificate_verify(parsed); + const parsed = try self.recv_certificate(); - digest = stream.transcript_hash.owned(); - try stream.expectFragment(.handshake, .finished); - try self.recv_finished(digest); - self.state = .recv_finished; + return .{ .recv_certificate_verify = parsed }; }, .finished => { + if (self.options.ca_bundle != null) + return self.stream.writeError(.certificate_required); + try self.recv_finished(digest); - self.state = .recv_finished; + + return .{ .send_finished = {} }; }, else => return self.stream.writeError(.unexpected_message), } }, + .recv_certificate_verify => |parsed| { + defer self.options.allocator.free(parsed.certificate.buffer); + + const digest = stream.transcript_hash.peek(); + try stream.expectInnerPlaintext(.handshake, .certificate_verify); + try self.recv_certificate_verify(digest, parsed); + + return .{ .recv_finished = {} }; + }, .recv_finished => { + const digest = stream.transcript_hash.peek(); + try stream.expectInnerPlaintext(.handshake, .finished); + try self.recv_finished(digest); + + return .{ .send_finished = {} }; + }, + .send_finished => { try self.send_finished(); - self.state = .sent_finished; + + return .{ .sent_finished = {} }; }, - .sent_finished => {}, + .sent_finished => return .{ .sent_finished = {} }, } } @@ -93,7 +109,7 @@ pub fn Client(comptime StreamType: type) type { .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, .{ .key_share = &[_]tls.KeyShare{ .{ .x25519_kyber768d00 = .{ - .x25119 = key_pairs.x25519.public_key, + .x25519 = key_pairs.x25519.public_key, .kyber768d00 = key_pairs.kyber768d00.public_key, } }, .{ .secp256r1 = key_pairs.secp256r1.public_key }, @@ -102,7 +118,7 @@ pub fn Client(comptime StreamType: type) type { }, }; - try self.stream.write(tls.Handshake, .{ .client_hello = hello }); + _ = try self.stream.write(tls.Handshake, .{ .client_hello = hello }); try self.stream.flush(); } @@ -117,7 +133,7 @@ pub fn Client(comptime StreamType: type) type { if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { // We already offered all our supported options and we aren't changing them. return stream.writeError(.unexpected_message); - } + } var session_id_buf: [tls.ClientHello.session_id_max_len]u8 = undefined; const session_id_len = try stream.read(u8); @@ -130,7 +146,7 @@ pub fn Client(comptime StreamType: type) type { const cipher_suite = try stream.read(tls.CipherSuite); const compression_method = try stream.read(u8); - if (compression_method != 0) return stream.writeError(.illegal_parameter); + if (compression_method != 0) return stream.writeError(.illegal_parameter); var supported_version: ?tls.Version = null; var shared_key: ?[]const u8 = null; @@ -179,7 +195,7 @@ pub fn Client(comptime StreamType: type) type { }, inline .secp256r1, .secp384r1 => |t| { const T = tls.NamedGroupT(t); - const expected_len = T.PublicKey.compressed_sec1_encoded_length; + const expected_len = T.PublicKey.uncompressed_sec1_encoded_length; if (key_size != expected_len) return stream.writeError(.illegal_parameter); var server_ks: [expected_len]u8 = undefined; @@ -213,58 +229,88 @@ pub fn Client(comptime StreamType: type) type { stream.handshake_cipher = tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash) catch return stream.writeError(.illegal_parameter); } - /// Currently skipped. pub fn recv_encrypted_extensions(self: *Self) !void { var stream = &self.stream; var reader = stream.reader(); var iter = try stream.extensions(); - while (try iter.next()) |ext| { - try reader.skipBytes(ext.len, .{}); + while (try iter.next()) |ext| { + switch (ext.type) { + .server_name => { + try reader.skipBytes(ext.len, .{}); + }, + else => |t| { + std.debug.print("unsupported extension {}\n", .{t}); + return stream.writeError(.unsupported_extension); + }, + } } } - /// Allocates `server_cert`. - pub fn recv_certificate(self: *Self) !crypto.Certificate.Parsed { + /// Verifies trust chain if `options.ca_bundle` is specified. + /// + /// Caller owns allocated Certificate.Parsed.certificate. + pub fn recv_certificate(self: *Self) !Certificate.Parsed { var stream = &self.stream; var reader = stream.reader(); const allocator = self.options.allocator; + const ca_bundle = self.options.ca_bundle; + const verify = ca_bundle != null; var context: [tls.Certificate.max_context_len]u8 = undefined; const context_len = try stream.read(u8); if (context_len > tls.Certificate.max_context_len) return stream.writeError(.decode_error); try reader.readNoEof(context[0..context_len]); - var res: ?crypto.Certificate.Parsed = null; + var first: ?crypto.Certificate.Parsed = null; + var prev: Certificate.Parsed = undefined; + var verified = false; + const now_sec = std.time.timestamp(); var certs_iter = try stream.iterator(u24, u24); while (try certs_iter.next()) |cert_len| { - if (cert_len > tls.Certificate.Entry.max_data_len) - return stream.writeError(.decode_error); - const buf = allocator.alloc(u8, cert_len) catch - return stream.writeError(.internal_error); - errdefer allocator.free(buf); - try reader.readNoEof(buf); + const is_first = first == null; - const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; - res = cert.parse() catch - return stream.writeError(.bad_certificate); + if (!verified) { + if (cert_len > tls.Certificate.Entry.max_data_len) + return stream.writeError(.decode_error); + const buf = allocator.alloc(u8, cert_len) catch + return stream.writeError(.internal_error); + defer if (!is_first) allocator.free(buf); + errdefer allocator.free(buf); + try reader.readNoEof(buf); + + const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; + const cur = cert.parse() catch return stream.writeError(.bad_certificate); + if (first == null) { + if (verify) try cur.verifyHostName(self.options.host); + first = cur; + } else { + if (verify) try prev.verify(cur, now_sec); + } - var ext_iter = try stream.extensions(); - while (try ext_iter.next()) |ext| { - switch (ext.type) { - else => { - try reader.skipBytes(ext.len, .{}); - }, + if (ca_bundle) |b| { + if (b.verify(cur, now_sec)) |_| { + verified = true; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + error.CertificateExpired => return stream.writeError(.certificate_expired), + else => return stream.writeError(.bad_certificate), + } } + + prev = cur; } + + var ext_iter = try stream.extensions(); + while (try ext_iter.next()) |ext| try reader.skipBytes(ext.len, .{}); } + if (verify and !verified) return stream.writeError(.bad_certificate); - return if (res) |r| r else stream.writeError(.bad_certificate); + return if (first) |r| r else stream.writeError(.bad_certificate); } - /// Deallocates `server_cert` - pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: crypto.Certificate.Parsed,) !void { + pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: Certificate.Parsed) !void { var stream = &self.stream; var reader = stream.reader(); const allocator = self.options.allocator; @@ -287,9 +333,11 @@ pub fn Client(comptime StreamType: type) type { if (cert.pub_key_algo != .X9_62_id_ecPublicKey) return stream.writeError(.bad_certificate); const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = try Ecdsa.Signature.fromDer(sig_bytes); - const key = try Ecdsa.PublicKey.fromSec1(cert.pubKey()); - try sig.verify(sig_content, key); + const sig = Ecdsa.Signature.fromDer(sig_bytes) catch + return stream.writeError(.decode_error); + const key = Ecdsa.PublicKey.fromSec1(cert.pubKey()) catch + return stream.writeError(.decode_error); + sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); }, inline .rsa_pss_rsae_sha256, .rsa_pss_rsae_sha384, @@ -300,14 +348,17 @@ pub fn Client(comptime StreamType: type) type { const Hash = SchemeHash(comptime_scheme); const rsa = Certificate.rsa; - const components = try rsa.PublicKey.parseDer(cert.pubKey()); + const components = rsa.PublicKey.parseDer(cert.pubKey()) catch + return stream.writeError(.decode_error); const exponent = components.exponent; const modulus = components.modulus; switch (modulus.len) { inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus); + const key = rsa.PublicKey.fromBytes(exponent, modulus) catch + return stream.writeError(.bad_certificate); const sig = rsa.PSSSignature.fromBytes(modulus_len, sig_bytes); - try rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash); + rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash) catch + return stream.writeError(.decode_error); }, else => { return error.TlsBadRsaSignatureBitCount; @@ -323,8 +374,9 @@ pub fn Client(comptime StreamType: type) type { const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) return stream.writeError(.decode_error); - const key = try Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*); - try sig.verify(sig_content, key); + const key = Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*) catch + return stream.writeError(.bad_certificate); + sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); }, else => { return error.TlsBadSignatureScheme; @@ -337,20 +389,17 @@ pub fn Client(comptime StreamType: type) type { var reader = stream.reader(); const cipher = stream.handshake_cipher.?; - const expected = switch (cipher) { + switch (cipher) { .empty_renegotiation_info_scsv => return stream.writeError(.decode_error), - inline else => |p| brk: { + inline else => |p| { const P = @TypeOf(p); - break :brk &tls.hmac(P.Hmac, digest, p.server_finished_key); - } - }; - - // This message's stated length is in the handshake header, which `expectFragment` skips - // over. Cheat and rip it out of the view. - const actual = stream.view; - try reader.skipBytes(stream.view.len, .{}); + const expected = &tls.hmac(P.Hmac, digest, p.server_finished_key); - if (!mem.eql(u8, expected, actual)) return stream.writeError(.decode_error); + var actual: [expected.len]u8 = undefined; + try reader.readNoEof(&actual); + if (!mem.eql(u8, expected, &actual)) return stream.writeError(.decode_error); + }, + } stream.application_cipher = tls.ApplicationCipher.init( stream.handshake_cipher.?, @@ -367,19 +416,18 @@ pub fn Client(comptime StreamType: type) type { try stream.flush(); const verify_data = switch (stream.handshake_cipher.?) { - inline - .aes_128_gcm_sha256, - .aes_256_gcm_sha384, - .chacha20_poly1305_sha256, - .aegis_256_sha512, - .aegis_128l_sha256, - => |v| brk: { - const T = @TypeOf(v); - const secret = v.client_finished_key; - const transcript_hash = stream.transcript_hash.peek(); - - break :brk &tls.hmac(T.Hmac, transcript_hash, secret); - }, + inline .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, + => |v| brk: { + const T = @TypeOf(v); + const secret = v.client_finished_key; + const transcript_hash = stream.transcript_hash.peek(); + + break :brk &tls.hmac(T.Hmac, transcript_hash, secret); + }, else => return stream.writeError(.decrypt_error), }; stream.content_type = .handshake; @@ -397,7 +445,7 @@ pub const Options = struct { ca_bundle: ?Certificate.Bundle, /// Used to verify cerficate chain and for Server Name Indication. host: []const u8, - /// List of potential cipher suites in order of descending preference. + /// List of cipher suites to advertise in order of descending preference. cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, /// By default, reaching the end-of-stream when reading from the server will /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify @@ -434,11 +482,11 @@ pub const KeyPairs = struct { pub fn init() Self { var random_buffer: [ hello_rand_length + - session_id_length + - Kyber768.seed_length + - Secp256r1.seed_length + - Secp384r1.seed_length + - X25519.seed_length + session_id_length + + Kyber768.seed_length + + Secp256r1.seed_length + + Secp384r1.seed_length + + X25519.seed_length ]u8 = undefined; while (true) { @@ -448,7 +496,7 @@ pub const KeyPairs = struct { const split2 = split1 + session_id_length; const split3 = split2 + Kyber768.seed_length; const split4 = split3 + Secp256r1.seed_length; - const split5 = split3 + Secp384r1.seed_length; + const split5 = split4 + Secp384r1.seed_length; return initAdvanced( random_buffer[0..split1].*, @@ -509,3 +557,15 @@ fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { else => @compileError("bad scheme"), }; } + +/// A single `send` or `recv`. Allows for testing `advance`. +pub const Command = union(enum) { + send_hello: KeyPairs, + recv_hello: KeyPairs, + recv_encrypted_extensions: void, + recv_certificate_or_finished: void, + recv_certificate_verify: Certificate.Parsed, + recv_finished: void, + send_finished: void, + sent_finished: void, +}; diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index c0561363b32c..b4aa32a5e8bc 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -53,7 +53,7 @@ pub fn Server(comptime StreamType: type) type { var stream = &self.stream; var reader = stream.reader(); - try stream.expectFragment(.handshake, .client_hello); + try stream.expectInnerPlaintext(.handshake, .client_hello); _ = try stream.read(tls.Version); var client_random: [32]u8 = undefined; @@ -130,7 +130,7 @@ pub fn Server(comptime StreamType: type) type { } } - if (tls_version == null) return stream.writeError(.protocol_version); + if (tls_version == null) return stream.writeError(.protocol_version); if (key_share == null) return stream.writeError(.missing_extension); if (ec_point_format == null) return stream.writeError(.missing_extension); @@ -186,7 +186,7 @@ pub fn Server(comptime StreamType: type) type { const shared_point = tls.NamedGroupT(.x25519).scalarmult( key_pair.pair.x25519.secret_key, ks, - ) catch return stream.writeError(.decrypt_error); + ) catch return stream.writeError(.decrypt_error); break :brk &shared_point; }, .secp256r1 => |ks| brk: { @@ -196,7 +196,7 @@ pub fn Server(comptime StreamType: type) type { ) catch return stream.writeError(.decrypt_error); break :brk &mul.affineCoordinates().x.toBytes(.big); }, - else => return stream.writeError(.illegal_parameter), + else => return stream.writeError(.illegal_parameter), }; const hello_hash = stream.transcript_hash.peek(); @@ -213,18 +213,17 @@ pub fn Server(comptime StreamType: type) type { pub fn send_finished(self: *Self) !void { var stream = &self.stream; const verify_data = switch (stream.handshake_cipher.?) { - inline - .aes_256_gcm_sha384, - => |v| brk: { - const T = @TypeOf(v); - const secret = v.server_finished_key; - const transcript_hash = stream.transcript_hash.peek(); - - break :brk tls.hmac(T.Hmac, transcript_hash, secret); - }, - else => return stream.writeError(.illegal_parameter), + inline .aes_256_gcm_sha384, + => |v| brk: { + const T = @TypeOf(v); + const secret = v.server_finished_key; + const transcript_hash = stream.transcript_hash.peek(); + + break :brk tls.hmac(T.Hmac, transcript_hash, secret); + }, + else => return stream.writeError(.illegal_parameter), }; - _ = try stream.write(tls.Handshake, .{ .finished = &verify_data }); + _ = try stream.write(tls.Handshake, .{ .finished = &verify_data }); try stream.flush(); stream.application_cipher = tls.ApplicationCipher.init( @@ -244,10 +243,10 @@ pub fn Server(comptime StreamType: type) type { const P = @TypeOf(p); const digest = stream.transcript_hash.peek(); break :brk &tls.hmac(P.Hmac, digest, p.client_finished_key); - } + }, }; - try stream.expectFragment(.handshake, .finished); + try stream.expectInnerPlaintext(.handshake, .finished); const actual = stream.view; try reader.skipBytes(stream.view.len, .{}); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 31a15a0eb5b7..e1d80cfa3f38 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -222,12 +222,11 @@ pub const Connection = struct { pub const Protocol = enum { plain, tls }; pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - return conn.tls_client.readv(buffers) catch |err| { + return conn.tls_client.stream.readv(buffers) catch |err| { // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; + if (mem.startsWith(u8, @errorName(err), "Tls")) return error.TlsFailure; switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, error.ConnectionTimedOut => return error.ConnectionTimedOut, error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, else => return error.UnexpectedReadFailure, @@ -306,7 +305,6 @@ pub const Connection = struct { pub const ReadError = error{ TlsFailure, - TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure, @@ -320,7 +318,11 @@ pub const Connection = struct { } pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(buffer) catch |err| switch (err) { + conn.tls_client.stream.writer().writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; + conn.tls_client.stream.flush() catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; @@ -388,7 +390,7 @@ pub const Connection = struct { if (disable_tls) unreachable; // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd("", true) catch {}; + conn.tls_client.stream.close(); allocator.destroy(conn.tls_client); } @@ -1383,6 +1385,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. .allow_truncation_attacks = true, + .allocator = client.allocator, }) catch return error.TlsInitializationFailed; } From 94fdb397680ecc5b1639de2cb1b00920369fef83 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Wed, 13 Mar 2024 21:15:14 -0400 Subject: [PATCH 07/17] bugfixes for two real world servers --- TODO | 9 +- lib/std/crypto/tls.zig | 294 +++++++++++++++++----------------- lib/std/crypto/tls/Client.zig | 178 ++++++++++++++------ lib/std/crypto/tls/Server.zig | 39 ++--- lib/std/http/Client.zig | 10 +- 5 files changed, 303 insertions(+), 227 deletions(-) diff --git a/TODO b/TODO index 4474c5eb46b7..36a0d2fbb2a8 100644 --- a/TODO +++ b/TODO @@ -9,8 +9,9 @@ 9. KeyShare kyber read 10. StreamInterface `readv` instead of `readAll` -1. benchmark -2. store multiple fragments in buffer for less syscalls -3. streaming encode + decode -4. store handshake_cipher somewhere temporary +1. test top 100 sites +2. benchmark +3. store multiple fragments in buffer for less syscalls +4. streaming encode + decode +5. store handshake_cipher somewhere temporary diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 34deb578fa6d..1509b237453e 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -116,7 +116,7 @@ pub const Handshake = union(HandshakeType) { certificate_status: void, /// Deprecated. supplemental_data: void, - key_update: void, + key_update: KeyUpdate, message_hash: void, // If `HandshakeCipherT.encode` accepts iovecs for the message this can be moved @@ -145,6 +145,11 @@ pub const Handshake = union(HandshakeType) { res += try stream.write(u24, @intCast(len)); res += try stream.write(T, value); }, + .Enum => |info| { + len += @bitSizeOf(info.tag_type) / 8; + res += try stream.write(u24, @intCast(len)); + res += try stream.write(T, value); + }, else => |t| @compileError("implement writing " ++ @tagName(t)), } }, @@ -158,6 +163,12 @@ pub const Handshake = union(HandshakeType) { }; }; +pub const KeyUpdate = enum(u8) { + update_not_requested = 0, + update_requested = 1, + _, +}; + pub const Certificate = struct { context: []const u8 = "", entries: []const Entry, @@ -1080,14 +1091,6 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { return struct { stream: *StreamType, - /// > For concreteness, the transcript hash is always taken from the - /// > following sequence of handshake messages, starting at the first - /// > ClientHello and including only those messages that were sent: - /// > ClientHello, HelloRetryRequest, ClientHello, ServerHello, - /// > EncryptedExtensions, server CertificateRequest, server Certificate, - /// > server CertificateVerify, server Finished, EndOfEarlyData, client - /// > Certificate, client CertificateVerify, client Finished. - transcript_hash: MultiHash = .{}, /// Used for both reading and writing. Cannot be doing both at the same time. /// Stores plaintext or ciphertext, but not Plaintext headers. buffer: [fragment_size]u8 = undefined, @@ -1102,10 +1105,9 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { /// When receiving a handshake message will be expected with this type. handshake_type: ?HandshakeType = .client_hello, - /// Used to encrypt and decrypt .application_data messages until application_cipher is not null. - handshake_cipher: ?HandshakeCipher = null, - /// Used to encrypt and decrypt .application_data messages. - application_cipher: ?ApplicationCipher = null, + /// Used to decrypt .application_data messages. + /// Used to encrypt messages that aren't alert or change_cipher_spec. + cipher: Cipher = .none, /// True when we send or receive a close_notify alert. closed: bool = false, @@ -1118,27 +1120,29 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { /// When > 0 won't actually do anything with writes. Used to discover prefix lengths. nocommit: usize = 0, + /// Client and server implementations can set this. While set `readPlaintext` and `flush` + /// handshake messages will update the hash. + transcript_hash: ?*MultiHash, + const Self = @This(); - pub const ReadError = StreamType.ReadError || Error || error{EndOfStream}; - pub const WriteError = StreamType.WriteError || error{ - TlsEncodeError, + const Cipher = union(enum) { + none: void, + application: ApplicationCipher, + handshake: HandshakeCipher, }; + pub const ReadError = StreamType.ReadError || Error || error{EndOfStream}; + pub const WriteError = StreamType.WriteError || error{TlsEncodeError}; + fn ciphertextOverhead(self: Self) usize { - if (self.application_cipher) |a| { - switch (a) { - .empty_renegotiation_info_scsv => {}, - inline else => |c| return @TypeOf(c).AEAD.tag_length + @sizeOf(ContentType), - } - } - if (self.handshake_cipher) |a| { - switch (a) { - .empty_renegotiation_info_scsv => {}, - inline else => |c| return @TypeOf(c).AEAD.tag_length + @sizeOf(ContentType), - } - } - return 0; + return switch (self.cipher) { + inline .application, .handshake => |c| switch (c) { + .empty_renegotiation_info_scsv => 0, + inline else => |t| @TypeOf(t).AEAD.tag_length + @sizeOf(ContentType), + }, + else => 0, + }; } fn maxFragmentSize(self: Self) usize { @@ -1146,48 +1150,38 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { } const EncryptionMethod = enum { none, handshake, application }; - fn encryptionMethod(self: Self) EncryptionMethod { - switch (self.content_type) { - .change_cipher_spec => {}, - .handshake => { - if (self.handshake_cipher != null) return .handshake; - }, + fn encryptionMethod(self: Self, content_type: ContentType) EncryptionMethod { + switch (content_type) { + .alert, .change_cipher_spec => {}, else => { - if (self.application_cipher != null) return .application; + if (self.cipher == .application) return .application; + if (self.cipher == .handshake) return .handshake; }, } - - return .none; + return .none; } pub fn flush(self: *Self) WriteError!void { if (self.view.len == 0) return; + if (self.transcript_hash) |t| { + if (self.content_type == .handshake) t.update(self.view); + } + var plaintext = Plaintext{ .type = self.content_type, .version = self.version, .len = @intCast(self.view.len), }; - if (self.application_cipher == null) { - switch (self.content_type) { - .change_cipher_spec, .alert => {}, - else => { - self.transcript_hash.update(self.view); - }, - } - } - - var header: [Plaintext.size]u8 = undefined; + var header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); var aead: []const u8 = ""; - switch (self.encryptionMethod()) { - .none => { - header = Encoder.encode(Plaintext, plaintext); - }, - .handshake => { + switch (self.cipher) { + .none => {}, + inline .application, .handshake => |*cipher| { plaintext.type = .application_data; plaintext.len += @intCast(self.ciphertextOverhead()); header = Encoder.encode(Plaintext, plaintext); - if (self.handshake_cipher) |*a| switch (a.*) { + switch (cipher.*) { .empty_renegotiation_info_scsv => {}, inline else => |*c| { std.debug.assert(self.view.ptr == &self.buffer); @@ -1195,22 +1189,8 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { self.view = self.buffer[0 .. self.view.len + 1]; aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); }, - }; - }, - .application => { - plaintext.type = .application_data; - plaintext.len += @intCast(self.ciphertextOverhead()); - header = Encoder.encode(Plaintext, plaintext); - if (self.application_cipher) |*a| switch (a.*) { - .empty_renegotiation_info_scsv => {}, - inline else => |*c| { - std.debug.assert(self.view.ptr == &self.buffer); - self.buffer[self.view.len] = @intFromEnum(self.content_type); - self.view = self.buffer[0 .. self.view.len + 1]; - aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); - }, - }; - }, + } + } } var iovecs = [_]std.os.iovec_const{ @@ -1222,6 +1202,24 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { self.view = self.buffer[0..0]; } + /// Flush a change cipher spec message to the underlying stream. + pub fn changeCipherSpec(self: *Self) WriteError!void { + self.version = .tls_1_2; + + const plaintext = Plaintext{ + .type = .change_cipher_spec, + .version = self.version, + .len = 1, + }; + const msg = [_]u8{1}; + const header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); + var iovecs = [_]std.os.iovec_const{ + .{ .iov_base = &header, .iov_len = header.len }, + .{ .iov_base = &msg, .iov_len = msg.len }, + }; + try self.stream.writevAll(&iovecs); + } + /// Write an alert to stream and call `close_notify` after. Returns Zig error. pub fn writeError(self: *Self, err: Alert.Description) Error { const alert = Alert{ .level = .fatal, .description = err }; @@ -1348,7 +1346,7 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { /// A return value of 0 indicates EOF. pub fn readBytes(self: *Self, buf: []u8) ReadError!usize { const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; - return try self.readv(&buffers); + return try self.readv(&buffers); } /// Reads plaintext from `stream` into `buffer` and updates `view`. @@ -1356,43 +1354,50 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { /// Will decrypt according to `encryptionMethod` if receiving application_data message. pub fn readPlaintext(self: *Self) ReadError!Plaintext { std.debug.assert(self.view.len == 0); // last read should have completed - var plaintext_header_bytes: [Plaintext.size]u8 = undefined; + var plaintext_bytes: [Plaintext.size]u8 = undefined; var n_read: usize = 0; while (true) { - n_read = try self.stream.readAll(&plaintext_header_bytes); - if (n_read != plaintext_header_bytes.len) return self.writeError(.decode_error); + n_read = try self.stream.readAll(&plaintext_bytes); + if (n_read != plaintext_bytes.len) return self.writeError(.decode_error); - var res = Plaintext.init(plaintext_header_bytes); + var res = Plaintext.init(plaintext_bytes); if (res.len > Plaintext.max_length) return self.writeError(.record_overflow); self.view = self.buffer[0..res.len]; n_read = try self.stream.readAll(@constCast(self.view)); if (n_read != res.len) return self.writeError(.decode_error); - const encryption_method = if (res.type == .application_data) self.encryptionMethod() else .none; + const encryption_method = self.encryptionMethod(res.type); switch (encryption_method) { .none => {}, - inline .handshake, .application => |t| { - switch (if (comptime t == .handshake) self.handshake_cipher.? else self.application_cipher.?) { - .empty_renegotiation_info_scsv => {}, - inline else => |*p| { - const P = @TypeOf(p.*); - const tag_len = P.AEAD.tag_length; - - const ciphertext = self.view[0 .. self.view.len - tag_len]; - const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; - const out: []u8 = @constCast(self.view[0..ciphertext.len]); - p.decrypt(ciphertext, &plaintext_header_bytes, tag, self.is_client, out) catch - return self.writeError(.bad_record_mac); - const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); - if (padding_start) |s| { - res.type = @enumFromInt(self.view[s]); - self.view = self.view[0..s]; - } else { - return self.writeError(.decode_error); + .handshake, .application => { + if (res.len < self.ciphertextOverhead()) return self.writeError(.decode_error); + + switch (self.cipher) { + inline .application, .handshake => |*c| { + switch (c.*) { + .empty_renegotiation_info_scsv => {}, + inline else => |*p| { + const P = @TypeOf(p.*); + const tag_len = P.AEAD.tag_length; + + const ciphertext = self.view[0 .. self.view.len - tag_len]; + const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; + const out: []u8 = @constCast(self.view[0..ciphertext.len]); + p.decrypt(ciphertext, &plaintext_bytes, tag, self.is_client, out) catch + return self.writeError(.bad_record_mac); + const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); + if (padding_start) |s| { + res.type = @enumFromInt(self.view[s]); + self.view = self.view[0..s]; + } else { + return self.writeError(.decode_error); + } + }, } }, + else => unreachable, } }, } @@ -1441,10 +1446,10 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { } if (res.type == .handshake) { - self.transcript_hash.update(self.view[0..@sizeOf(HandshakeType) + @bitSizeOf(u24) / 8]); + if (self.transcript_hash) |t| t.update(self.view[0..4]); res.handshake_type = try self.read(HandshakeType); res.len = try self.read(u24); - self.transcript_hash.update(self.view[0..res.len]); + if (self.transcript_hash) |t| t.update(self.view[0..res.len]); self.handshake_type = res.handshake_type; } @@ -1459,7 +1464,7 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { ) ReadError!void { const inner_plaintext = try self.readInnerPlaintext(); if (expected_content != inner_plaintext.type) { - std.debug.print("expected {} got {}\n", .{ expected_content, inner_plaintext.type }); + std.debug.print("expected {} got {}\n", .{ expected_content, inner_plaintext }); return self.writeError(.unexpected_message); } if (expected_handshake) |expected| { @@ -1572,7 +1577,7 @@ pub const MultiHash = struct { sha384: sha2.Sha384 = sha2.Sha384.init(.{}), sha512: sha2.Sha512 = sha2.Sha512.init(.{}), /// Chosen during handshake. - active: enum { all, sha256, sha384, sha512 } = .all, + active: enum { all, sha256, sha384, sha512, none } = .all, const sha2 = crypto.hash.sha2; pub const max_digest_len = sha2.Sha512.digest_length; @@ -1588,6 +1593,7 @@ pub const MultiHash = struct { .sha256 => self.sha256.update(bytes), .sha384 => self.sha384.update(bytes), .sha512 => self.sha512.update(bytes), + .none => {}, } } @@ -1603,7 +1609,7 @@ pub const MultiHash = struct { pub inline fn peek(self: Self) []const u8 { return &switch (self.active) { - .all => [_]u8{}, + .all, .none => [_]u8{}, .sha256 => self.sha256.peek(), .sha384 => self.sha384.peek(), .sha512 => self.sha512.peek(), @@ -1952,19 +1958,23 @@ test "tls client and server handshake, data, and close_notify" { defer inner_stream.deinit(allocator); const host = "example.ulfheim.net"; + var client_transcript: MultiHash = .{}; var client = Client(@TypeOf(inner_stream)){ .stream = Stream(Plaintext.max_length, TestStream){ .stream = &inner_stream, .is_client = true, + .transcript_hash = &client_transcript, }, .options = .{ .host = host, .ca_bundle = null, .allocator = allocator }, }; const server_der = @embedFile("./testdata/server.der"); + var server_transcript: MultiHash = .{}; var server = Server(@TypeOf(inner_stream)){ .stream = Stream(Plaintext.max_length, TestStream){ .stream = &inner_stream, .is_client = false, + .transcript_hash = &server_transcript, }, .options = .{ // force this to use https://tls13.xargs.org/ as unit test for "server hello" onwards @@ -2066,7 +2076,7 @@ test "tls client and server handshake, data, and close_notify" { _ = try client.stream.write(Handshake, .{ .client_hello = hello }); try client.stream.flush(); - break :brk client_mod.Command{ .recv_hello = key_pairs }; + break :brk client_mod.State{ .recv_hello = key_pairs }; }; try inner_stream.expect([_]u8{ @@ -2805,63 +2815,36 @@ test "tls client and server handshake, data, and close_notify" { 0x4a, 0xec, 0xf2, 0x8c, 0xf3, 0x18, 0x2f, 0xd0, // auth tag })); - client_command = try client.advance(client_command); + client_command = try client.advance(client_command); // recv_hello try std.testing.expect(client_command == .recv_encrypted_extensions); - { - const s = server.stream.handshake_cipher.?.aes_256_gcm_sha384; - const c = client.stream.handshake_cipher.?.aes_256_gcm_sha384; - - try std.testing.expectEqualSlices(u8, &s.handshake_secret, &c.handshake_secret); - try std.testing.expectEqualSlices(u8, &s.master_secret, &c.master_secret); - try std.testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); - try std.testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); - try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); - try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); - try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); - try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); - const client_iv = [_]u8{ 0x42, 0x56, 0xd2, 0xe0, 0xe8, 0x8b, 0xab, 0xdd, 0x05, 0xeb, 0x2f, 0x27 }; - try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); - } - client_command = try client.advance(client_command); + // { + // const s = server.stream.cipher.handshake.aes_256_gcm_sha384; + // const c = client.stream.cipher.handshake.aes_256_gcm_sha384; + + // try std.testing.expectEqualSlices(u8, &s.handshake_secret, &c.handshake_secret); + // try std.testing.expectEqualSlices(u8, &s.master_secret, &c.master_secret); + // try std.testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); + // try std.testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); + // try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + // try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + // try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); + // try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + // const client_iv = [_]u8{ 0x42, 0x56, 0xd2, 0xe0, 0xe8, 0x8b, 0xab, 0xdd, 0x05, 0xeb, 0x2f, 0x27 }; + // try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); + // } + client_command = try client.advance(client_command); // recv_encrypted_extensions try std.testing.expect(client_command == .recv_certificate_or_finished); - client_command = try client.advance(client_command); + client_command = try client.advance(client_command); // recv_certificate_or_finished (certificate) try std.testing.expect(client_command == .recv_certificate_verify); - client_command = try client.advance(client_command); + client_command = try client.advance(client_command); // recv_certificate_verify try std.testing.expect(client_command == .recv_finished); - client_command = try client.advance(client_command); + client_command = try client.advance(client_command); // recv_finished try std.testing.expect(client_command == .send_finished); - { - const s = server.stream.application_cipher.?.aes_256_gcm_sha384; - const c = client.stream.application_cipher.?.aes_256_gcm_sha384; - - try std.testing.expectEqualSlices(u8, &s.client_secret, &c.client_secret); - try std.testing.expectEqualSlices(u8, &s.server_secret, &c.server_secret); - try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); - try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); - try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); - try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); - const client_iv = [_]u8{ - 0xbb, - 0x00, - 0x79, - 0x56, - 0xf4, - 0x74, - 0xb2, - 0x5d, - 0xe9, - 0x02, - 0x43, - 0x2f, - }; - try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); - } - - client_command = try client.advance(client_command); + client_command = try client.advance(client_command); // send_finished try std.testing.expect(client_command == .sent_finished); try inner_stream.expect(&([_]u8{ @@ -2904,6 +2887,21 @@ test "tls client and server handshake, data, and close_notify" { })); try server.recv_finished(); + { + const s = server.stream.cipher.application.aes_256_gcm_sha384; + const c = client.stream.cipher.application.aes_256_gcm_sha384; + + try std.testing.expectEqualSlices(u8, &s.client_secret, &c.client_secret); + try std.testing.expectEqualSlices(u8, &s.server_secret, &c.server_secret); + try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); + const client_iv = [_]u8{ 0xbb, 0x00, 0x79, 0x56, 0xf4, 0x74, 0xb2, 0x5d, 0xe9, 0x02, 0x43, 0x2f }; + try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); + } + + _ = try client.stream.writer().writeAll("ping"); try client.stream.flush(); try inner_stream.expect(&([_]u8{ diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 05d5cf2f7f9f..82d48d42e29a 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -12,25 +12,27 @@ pub fn Client(comptime StreamType: type) type { stream: Stream, options: Options, - const Stream = tls.Stream(tls.Plaintext.max_length, StreamType); + pub const Stream = tls.Stream(tls.Plaintext.max_length, StreamType); const Self = @This(); /// Initiates a TLS handshake and establishes a TLSv1.3 session pub fn init(stream: *StreamType, options: Options) !Self { + var transcript_hash: tls.MultiHash = .{}; const stream_ = tls.Stream(tls.Plaintext.max_length, StreamType){ .stream = stream, .is_client = true, + .transcript_hash = &transcript_hash, }; var res = Self{ .stream = stream_, .options = options }; - var state = Command{ .send_hello = KeyPairs.init() }; + var state = State{ .send_hello = KeyPairs.init() }; while (state != .sent_finished) state = try res.advance(state); return res; } /// Execute command and return next one. - pub fn advance(self: *Self, command: Command) !Command { + pub fn advance(self: *Self, command: State) !State { var stream = &self.stream; switch (command) { .send_hello => |key_pairs| { @@ -51,7 +53,7 @@ pub fn Client(comptime StreamType: type) type { return .{ .recv_certificate_or_finished = {} }; }, .recv_certificate_or_finished => { - const digest = stream.transcript_hash.peek(); + const digest = stream.transcript_hash.?.peek(); const inner_plaintext = try stream.readInnerPlaintext(); if (inner_plaintext.type != .handshake) return stream.writeError(.unexpected_message); switch (inner_plaintext.handshake_type) { @@ -74,14 +76,14 @@ pub fn Client(comptime StreamType: type) type { .recv_certificate_verify => |parsed| { defer self.options.allocator.free(parsed.certificate.buffer); - const digest = stream.transcript_hash.peek(); + const digest = stream.transcript_hash.?.peek(); try stream.expectInnerPlaintext(.handshake, .certificate_verify); try self.recv_certificate_verify(digest, parsed); return .{ .recv_finished = {} }; }, .recv_finished => { - const digest = stream.transcript_hash.peek(); + const digest = stream.transcript_hash.?.peek(); try stream.expectInnerPlaintext(.handshake, .finished); try self.recv_finished(digest); @@ -124,12 +126,12 @@ pub fn Client(comptime StreamType: type) type { pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { var stream = &self.stream; - var reader = stream.reader(); + var r = stream.reader(); // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. _ = try stream.read(tls.Version); var random: [32]u8 = undefined; - try reader.readNoEof(&random); + try r.readNoEof(&random); if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { // We already offered all our supported options and we aren't changing them. return stream.writeError(.unexpected_message); @@ -140,7 +142,7 @@ pub fn Client(comptime StreamType: type) type { if (session_id_len > tls.ClientHello.session_id_max_len) return stream.writeError(.illegal_parameter); const session_id: []u8 = session_id_buf[0..session_id_len]; - try reader.readNoEof(session_id); + try r.readNoEof(session_id); if (!mem.eql(u8, session_id, &key_pairs.session_id)) return stream.writeError(.illegal_parameter); @@ -169,7 +171,7 @@ pub fn Client(comptime StreamType: type) type { const expected_len = x25519_len + T.Kyber768.ciphertext_length; if (key_size != expected_len) return stream.writeError(.illegal_parameter); var server_ks: [expected_len]u8 = undefined; - try reader.readNoEof(&server_ks); + try r.readNoEof(&server_ks); const mult = T.X25519.scalarmult( key_pairs.x25519.secret_key, @@ -185,7 +187,7 @@ pub fn Client(comptime StreamType: type) type { const expected_len = T.public_length; if (key_size != expected_len) return stream.writeError(.illegal_parameter); var server_ks: [expected_len]u8 = undefined; - try reader.readNoEof(&server_ks); + try r.readNoEof(&server_ks); const mult = crypto.dh.X25519.scalarmult( key_pairs.x25519.secret_key, @@ -199,7 +201,7 @@ pub fn Client(comptime StreamType: type) type { if (key_size != expected_len) return stream.writeError(.illegal_parameter); var server_ks: [expected_len]u8 = undefined; - try reader.readNoEof(&server_ks); + try r.readNoEof(&server_ks); const pk = T.PublicKey.fromSec1(&server_ks) catch return stream.writeError(.illegal_parameter); @@ -211,12 +213,12 @@ pub fn Client(comptime StreamType: type) type { // Server sent us back unknown key. That's weird because we only request known ones, // but we can try for another. else => { - try reader.skipBytes(key_size, .{}); + try r.skipBytes(key_size, .{}); }, } }, else => { - try reader.skipBytes(ext.len, .{}); + try r.skipBytes(ext.len, .{}); }, } } @@ -224,26 +226,20 @@ pub fn Client(comptime StreamType: type) type { if (supported_version != tls.Version.tls_1_3) return stream.writeError(.protocol_version); if (shared_key == null) return stream.writeError(.missing_extension); - stream.transcript_hash.setActive(cipher_suite); - const hello_hash = stream.transcript_hash.peek(); - stream.handshake_cipher = tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash) catch return stream.writeError(.illegal_parameter); + stream.transcript_hash.?.setActive(cipher_suite); + const hello_hash = stream.transcript_hash.?.peek(); + + const handshake_cipher = tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash,) catch return stream.writeError(.illegal_parameter); + stream.cipher = .{ .handshake = handshake_cipher }; } pub fn recv_encrypted_extensions(self: *Self) !void { var stream = &self.stream; - var reader = stream.reader(); + var r = stream.reader(); var iter = try stream.extensions(); while (try iter.next()) |ext| { - switch (ext.type) { - .server_name => { - try reader.skipBytes(ext.len, .{}); - }, - else => |t| { - std.debug.print("unsupported extension {}\n", .{t}); - return stream.writeError(.unsupported_extension); - }, - } + try r.skipBytes(ext.len, .{}); } } @@ -252,7 +248,7 @@ pub fn Client(comptime StreamType: type) type { /// Caller owns allocated Certificate.Parsed.certificate. pub fn recv_certificate(self: *Self) !Certificate.Parsed { var stream = &self.stream; - var reader = stream.reader(); + var r = stream.reader(); const allocator = self.options.allocator; const ca_bundle = self.options.ca_bundle; const verify = ca_bundle != null; @@ -260,7 +256,7 @@ pub fn Client(comptime StreamType: type) type { var context: [tls.Certificate.max_context_len]u8 = undefined; const context_len = try stream.read(u8); if (context_len > tls.Certificate.max_context_len) return stream.writeError(.decode_error); - try reader.readNoEof(context[0..context_len]); + try r.readNoEof(context[0..context_len]); var first: ?crypto.Certificate.Parsed = null; var prev: Certificate.Parsed = undefined; @@ -271,14 +267,16 @@ pub fn Client(comptime StreamType: type) type { while (try certs_iter.next()) |cert_len| { const is_first = first == null; - if (!verified) { + if (verified) { + try r.skipBytes(cert_len, .{}); + } else { if (cert_len > tls.Certificate.Entry.max_data_len) return stream.writeError(.decode_error); const buf = allocator.alloc(u8, cert_len) catch return stream.writeError(.internal_error); defer if (!is_first) allocator.free(buf); errdefer allocator.free(buf); - try reader.readNoEof(buf); + try r.readNoEof(buf); const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; const cur = cert.parse() catch return stream.writeError(.bad_certificate); @@ -303,16 +301,16 @@ pub fn Client(comptime StreamType: type) type { } var ext_iter = try stream.extensions(); - while (try ext_iter.next()) |ext| try reader.skipBytes(ext.len, .{}); + while (try ext_iter.next()) |ext| try r.skipBytes(ext.len, .{}); } if (verify and !verified) return stream.writeError(.bad_certificate); - return if (first) |r| r else stream.writeError(.bad_certificate); + return if (first) |f| f else stream.writeError(.bad_certificate); } pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: Certificate.Parsed) !void { var stream = &self.stream; - var reader = stream.reader(); + var r = stream.reader(); const allocator = self.options.allocator; const sig_content = tls.sigContent(digest); @@ -324,7 +322,7 @@ pub fn Client(comptime StreamType: type) type { const sig_bytes = allocator.alloc(u8, len) catch return stream.writeError(.internal_error); defer allocator.free(sig_bytes); - try reader.readNoEof(sig_bytes); + try r.readNoEof(sig_bytes); switch (scheme) { inline .ecdsa_secp256r1_sha256, @@ -386,8 +384,8 @@ pub fn Client(comptime StreamType: type) type { pub fn recv_finished(self: *Self, digest: []const u8) !void { var stream = &self.stream; - var reader = stream.reader(); - const cipher = stream.handshake_cipher.?; + var r = stream.reader(); + const cipher = stream.cipher.handshake; switch (cipher) { .empty_renegotiation_info_scsv => return stream.writeError(.decode_error), @@ -396,26 +394,20 @@ pub fn Client(comptime StreamType: type) type { const expected = &tls.hmac(P.Hmac, digest, p.server_finished_key); var actual: [expected.len]u8 = undefined; - try reader.readNoEof(&actual); + try r.readNoEof(&actual); if (!mem.eql(u8, expected, &actual)) return stream.writeError(.decode_error); }, } - - stream.application_cipher = tls.ApplicationCipher.init( - stream.handshake_cipher.?, - stream.transcript_hash.peek(), - ); } pub fn send_finished(self: *Self) !void { var stream = &self.stream; - stream.version = .tls_1_2; - stream.content_type = .change_cipher_spec; - _ = try stream.write(tls.ChangeCipherSpec, .change_cipher_spec); - try stream.flush(); + const handshake_hash = stream.transcript_hash.?.peek(); - const verify_data = switch (stream.handshake_cipher.?) { + try stream.changeCipherSpec(); + + const verify_data = switch (stream.cipher.handshake) { inline .aes_128_gcm_sha256, .aes_256_gcm_sha384, .chacha20_poly1305_sha256, @@ -424,7 +416,7 @@ pub fn Client(comptime StreamType: type) type { => |v| brk: { const T = @TypeOf(v); const secret = v.client_finished_key; - const transcript_hash = stream.transcript_hash.peek(); + const transcript_hash = stream.transcript_hash.?.peek(); break :brk &tls.hmac(T.Hmac, transcript_hash, secret); }, @@ -434,7 +426,93 @@ pub fn Client(comptime StreamType: type) type { _ = try stream.write(tls.Handshake, .{ .finished = verify_data }); try stream.flush(); + const application_cipher = tls.ApplicationCipher.init(stream.cipher.handshake, handshake_hash); + stream.cipher = .{ .application = application_cipher }; stream.content_type = .application_data; + stream.transcript_hash = null; + } + + pub const ReadError = Stream.ReadError; + pub const WriteError = Stream.WriteError; + + /// Reads next application_data message. + pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { + var stream = &self.stream; + + if (stream.eof()) return 0; + + while (stream.view.len == 0) { + const inner_plaintext = try stream.readInnerPlaintext(); + switch (inner_plaintext.type) { + .handshake => { + switch (inner_plaintext.handshake_type) { + // A multithreaded client could use these. + .new_session_ticket => { + try stream.reader().skipBytes(inner_plaintext.len, .{}); + }, + .key_update => { + switch (stream.cipher.application) { + .empty_renegotiation_info_scsv => {}, + inline else => |*p| { + const P = @TypeOf(p.*); + const server_secret = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); + p.server_secret = server_secret; + p.server_key = tls.hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + p.server_iv = tls.hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + p.read_seq = 0; + }, + } + const update = try stream.read(tls.KeyUpdate); + if (update == .update_requested) { + switch (stream.cipher.application) { + .empty_renegotiation_info_scsv => {}, + inline else => |*p| { + const P = @TypeOf(p.*); + const client_secret = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); + p.client_secret = client_secret; + p.client_key = tls.hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + p.client_iv = tls.hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + p.write_seq = 0; + }, + } + } + }, + else => return stream.writeError(.unexpected_message), + } + }, + .application_data => {}, + else => return stream.writeError(.unexpected_message), + } + } + return try self.stream.readv(buffers); + } + + pub fn read(self: *Self, buf: []u8) ReadError!usize { + const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; + return try self.readv(&buffers); + } + + pub fn write(self: *Self, buf: []const u8) WriteError!usize { + if (self.stream.eof()) return 0; + + const res = try self.stream.writeBytes(buf); + try self.stream.flush(); + return res; + } + + pub fn close(self: *Self) void { + self.stream.close(); + } + + pub const Reader = std.io.Reader(*Self, ReadError, read); + pub const Writer = std.io.Writer(*Self, WriteError, write); + + pub fn reader(self: *Self) Reader { + return .{ .context = self }; + } + + pub fn writer(self: *Self) Writer { + return .{ .context = self }; } }; } @@ -559,7 +637,7 @@ fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { } /// A single `send` or `recv`. Allows for testing `advance`. -pub const Command = union(enum) { +pub const State = union(enum) { send_hello: KeyPairs, recv_hello: KeyPairs, recv_encrypted_extensions: void, diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index b4aa32a5e8bc..dba2a7eb8408 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -21,9 +21,11 @@ pub fn Server(comptime StreamType: type) type { /// Initiates a TLS handshake and establishes a TLSv1.3 session pub fn init(stream: *StreamType, options: Options) !Self { + var transcript_hash: tls.MultiHash = .{}; var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType){ .stream = stream, .is_client = false, + .transcript_hash = &transcript_hash, }; var res = Self{ .stream = stream_, .options = options }; const client_hello = try res.recv_hello(&stream_); @@ -76,7 +78,7 @@ pub fn Server(comptime StreamType: type) type { if (res == null) return stream.writeError(.illegal_parameter); break :brk res.?; }; - stream.transcript_hash.setActive(cipher_suite); + stream.transcript_hash.?.setActive(cipher_suite); { var compression_methods: [2]u8 = undefined; @@ -162,11 +164,7 @@ pub fn Server(comptime StreamType: type) type { try stream.flush(); // > if the client sends a non-empty session ID, the server MUST send the change_cipher_spec - if (hello.session_id.len > 0) { - stream.content_type = .change_cipher_spec; - _ = try stream.write(tls.ChangeCipherSpec, .change_cipher_spec); - try stream.flush(); - } + if (hello.session_id.len > 0) try stream.changeCipherSpec(); const shared_key = switch (client_hello.key_share) { .x25519_kyber768d00 => |ks| brk: { @@ -199,8 +197,10 @@ pub fn Server(comptime StreamType: type) type { else => return stream.writeError(.illegal_parameter), }; - const hello_hash = stream.transcript_hash.peek(); - stream.handshake_cipher = tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, hello_hash) catch return stream.writeError(.illegal_parameter); + const hello_hash = stream.transcript_hash.?.peek(); + const handshake_cipher = tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, hello_hash,) catch + return stream.writeError(.illegal_parameter); + stream.cipher = .{ .handshake = handshake_cipher }; stream.content_type = .handshake; _ = try stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); @@ -212,12 +212,12 @@ pub fn Server(comptime StreamType: type) type { pub fn send_finished(self: *Self) !void { var stream = &self.stream; - const verify_data = switch (stream.handshake_cipher.?) { + const verify_data = switch (stream.cipher.handshake) { inline .aes_256_gcm_sha384, => |v| brk: { const T = @TypeOf(v); const secret = v.server_finished_key; - const transcript_hash = stream.transcript_hash.peek(); + const transcript_hash = stream.transcript_hash.?.peek(); break :brk tls.hmac(T.Hmac, transcript_hash, secret); }, @@ -225,23 +225,24 @@ pub fn Server(comptime StreamType: type) type { }; _ = try stream.write(tls.Handshake, .{ .finished = &verify_data }); try stream.flush(); - - stream.application_cipher = tls.ApplicationCipher.init( - stream.handshake_cipher.?, - stream.transcript_hash.peek(), - ); } pub fn recv_finished(self: *Self) !void { var stream = &self.stream; var reader = stream.reader(); - const cipher = stream.handshake_cipher.?; - const expected = switch (cipher) { + const handshake_hash = stream.transcript_hash.?.peek(); + + const application_cipher = tls.ApplicationCipher.init( + stream.cipher.handshake, + handshake_hash, + ); + + const expected = switch (stream.cipher.handshake) { .empty_renegotiation_info_scsv => return stream.writeError(.decode_error), inline else => |p| brk: { const P = @TypeOf(p); - const digest = stream.transcript_hash.peek(); + const digest = stream.transcript_hash.?.peek(); break :brk &tls.hmac(P.Hmac, digest, p.client_finished_key); }, }; @@ -254,6 +255,8 @@ pub fn Server(comptime StreamType: type) type { stream.content_type = .application_data; stream.handshake_type = null; + stream.cipher = .{ .application = application_cipher }; + stream.transcript_hash = null; } }; } diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index e1d80cfa3f38..451832b16ed7 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -222,7 +222,7 @@ pub const Connection = struct { pub const Protocol = enum { plain, tls }; pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - return conn.tls_client.stream.readv(buffers) catch |err| { + return conn.tls_client.readv(buffers) catch |err| { // https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "Tls")) return error.TlsFailure; @@ -318,11 +318,7 @@ pub const Connection = struct { } pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - conn.tls_client.stream.writer().writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - conn.tls_client.stream.flush() catch |err| switch (err) { + conn.tls_client.writer().writeAll(buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; @@ -390,7 +386,7 @@ pub const Connection = struct { if (disable_tls) unreachable; // try to cleanly close the TLS connection, for any server that cares. - conn.tls_client.stream.close(); + conn.tls_client.close(); allocator.destroy(conn.tls_client); } From f2329784c5a8de13ec19a5660bdd16850755bc40 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Thu, 14 Mar 2024 21:11:53 -0400 Subject: [PATCH 08/17] start moving snapshots to files, start add server rsa key --- TODO | 17 - lib/std/crypto/Certificate.zig | 52 +- .../crypto/testdata/{server.der => cert.der} | Bin lib/std/crypto/testdata/cert.pem | 19 + lib/std/crypto/testdata/key.der | Bin 0 -> 294 bytes lib/std/crypto/testdata/key.pem | 27 + lib/std/crypto/tls.zig | 1034 ++--------------- lib/std/crypto/tls/Client.zig | 46 +- lib/std/crypto/tls/Server.zig | 147 ++- lib/std/http/Client.zig | 2 +- 10 files changed, 310 insertions(+), 1034 deletions(-) delete mode 100644 TODO rename lib/std/crypto/testdata/{server.der => cert.der} (100%) create mode 100644 lib/std/crypto/testdata/cert.pem create mode 100644 lib/std/crypto/testdata/key.der create mode 100644 lib/std/crypto/testdata/key.pem diff --git a/TODO b/TODO deleted file mode 100644 index 36a0d2fbb2a8..000000000000 --- a/TODO +++ /dev/null @@ -1,17 +0,0 @@ -[x] 1. single transcript hash type. move out of stream. -[x] 2. read backpressure, smaller stream buffer -[x] 3. Client recv_hello secp256r1 key share -[x] 4. remove @panic's -[x] 5. better errors than spammy TlsDecodeError. map new errors to TLS alerts. send alert on error. -[x] 6. client state machine union + test -[x] 7. cert formats -[x] 8. verify certs and [x] sigs -9. KeyShare kyber read -10. StreamInterface `readv` instead of `readAll` - -1. test top 100 sites -2. benchmark -3. store multiple fragments in buffer for less syscalls -4. streaming encode + decode -5. store handshake_cipher somewhere temporary - diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index a94f3cc3e4f6..853aa6bfb382 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -735,11 +735,10 @@ fn verifyRsa( pub_key: []const u8, ) !void { if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch; - const pk_components = try rsa.PublicKey.parseDer(pub_key); - const exponent = pk_components.exponent; - const modulus = pk_components.modulus; - if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; - if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; + + const public_key = rsa.PublicKey.fromDer(pub_key) catch return error.CertificateSignatureInvalid; + const modulus_len = public_key.n.bits() / 8; + if (sig.len != modulus_len) return error.CertificateSignatureInvalidLength; const hash_der = switch (Hash) { crypto.hash.Sha1 => [_]u8{ @@ -772,18 +771,17 @@ fn verifyRsa( var msg_hashed: [Hash.digest_length]u8 = undefined; Hash.hash(message, &msg_hashed, .{}); - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; - const em: [modulus_len]u8 = + switch (modulus_len) { + inline 128, 256, 512 => |mod_len| { + const ps_len = mod_len - (hash_der.len + msg_hashed.len) - 3; + const em: [mod_len]u8 = [2]u8{ 0, 1 } ++ ([1]u8{0xff} ** ps_len) ++ [1]u8{0} ++ hash_der ++ msg_hashed; - const public_key = rsa.PublicKey.fromBytes(exponent, modulus) catch return error.CertificateSignatureInvalid; - const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key) catch |err| switch (err) { + const em_dec = public_key.encrypt(mod_len, sig[0..mod_len].*) catch |err| switch (err) { error.MessageTooLong => unreachable, }; @@ -950,6 +948,7 @@ test { _ = Bundle; } +/// RFC8017 pub const rsa = struct { const max_modulus_bits = 4096; const Uint = std.crypto.ff.Uint(max_modulus_bits); @@ -965,7 +964,7 @@ pub const rsa = struct { pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type) !void { const mod_bits = public_key.n.bits(); - const em_dec = try encrypt(modulus_len, sig, public_key); + const em_dec = try public_key.encrypt(modulus_len, sig); try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash); } @@ -1107,7 +1106,9 @@ pub const rsa = struct { }; pub const PublicKey = struct { + /// the RSA modulus, a positive integer n: Modulus, + /// public exponent e: Fe, pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) !PublicKey { @@ -1135,7 +1136,8 @@ pub const rsa = struct { }; } - pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } { + // RFC8017 Appendix A.1.1 + pub fn fromDer(pub_key: []const u8) !PublicKey { const pub_key_seq = try der.Element.parse(pub_key, 0); if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); @@ -1147,20 +1149,20 @@ pub const rsa = struct { const modulus_offset = for (modulus_raw, 0..) |byte, i| { if (byte != 0) break i; } else modulus_raw.len; - return .{ - .modulus = modulus_raw[modulus_offset..], - .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end], - }; + const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end]; + const modulus = modulus_raw[modulus_offset..]; + + return try fromBytes(exponent, modulus); } - }; - fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) ![modulus_len]u8 { - const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong; - const e = public_key.n.powPublic(m, public_key.e) catch unreachable; - var res: [modulus_len]u8 = undefined; - e.toBytes(&res, .big) catch unreachable; - return res; - } + fn encrypt(public_key: PublicKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { + const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong; + const e = public_key.n.powPublic(m, public_key.e) catch unreachable; + var res: [modulus_len]u8 = undefined; + e.toBytes(&res, .big) catch unreachable; + return res; + } + }; }; const use_vectors = @import("builtin").zig_backend != .stage2_x86_64; diff --git a/lib/std/crypto/testdata/server.der b/lib/std/crypto/testdata/cert.der similarity index 100% rename from lib/std/crypto/testdata/server.der rename to lib/std/crypto/testdata/cert.der diff --git a/lib/std/crypto/testdata/cert.pem b/lib/std/crypto/testdata/cert.pem new file mode 100644 index 000000000000..45f0184c7b49 --- /dev/null +++ b/lib/std/crypto/testdata/cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDITCCAgmgAwIBAgIIFVqSrcIEj5AwDQYJKoZIhvcNAQELBQAwIjELMAkGA1UE +BhMCVVMxEzARBgNVBAoTCkV4YW1wbGUgQ0EwHhcNMTgxMDA1MDEzODE3WhcNMTkx +MDA1MDEzODE3WjArMQswCQYDVQQGEwJVUzEcMBoGA1UEAxMTZXhhbXBsZS51bGZo +ZWltLm5ldDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMSANga650dr +CJQE7Ke2kQQ/95K8Ge77fXTXqA0AHntLOkrmD+jAcfxz5wJMDbz0vdEdOWu6cEZK +E+lK+D3z4QlZVHvJVftBLaN2UhHh89x3bKpTN27KOuy+w6q3OzHVbLZSnICYvMng +KBjiC/f4oDr9FwRQns55vZ858epp7EeXLoMPtcqV3pWh5gQi1e6+UnlUoee/iob2 +Rm0NnxaVGkz3oEaSWVwTUvJUnlr7Tr/XejeVAUTkwCaHTGU+QH19IwdEAfSE/9CP +eh+gUhDR9PDVznlwKTLiyr5wH9+ta0u3EQH0S61mahETD+Lugp5NAp3JHN1nFtu5 +BhiG7cG6lCECAwEAAaNSMFAwDgYDVR0PAQH/BAQDAgWgMB0GA1UdJQQWMBQGCCsG +AQUFBwMCBggrBgEFBQcDATAfBgNVHSMEGDAWgBSJT95bzGniUs8+owDfsZe4HeHB +RjANBgkqhkiG9w0BAQsFAAOCAQEAWRZFppouN3nk9t0nGrocC/1s11WZtefDblM+ +/zZZCEMkyeelBAedOeDUKYf/4+vdCcHPHZFEVYcLVx3Rm98dJPi7mhH+gP1ZK6A5 +jN4R4mUeYYzlmPqW5Tcu7z0kiv3hdGPrv6u45NGrUCpU7ABk6S94GWYNPyfPIJ5m +f85a4uSsmcfJOBj4slEHIt/tl/MuPpNJ1MZsnqY5bXREYqBrQsbVumiOrDoBe938 +jiz8rSfLadPM3KKAQURl0640jODzSrL7nGGDcTErGRBBZBwjfxGl1lyETwQEhJk4 +cSuVntaFvFxd1kXtGZCUc0ApJty0DjRpoVlB6OLMqEu2CEY2oA== +-----END CERTIFICATE----- diff --git a/lib/std/crypto/testdata/key.der b/lib/std/crypto/testdata/key.der new file mode 100644 index 0000000000000000000000000000000000000000..ce271c001235f3f95d927ba04b65e9366442ba17 GIT binary patch literal 294 zcmV+>0ondAf&n5h4F(A+hDe6@4FLfG1potr0S^E$f&mHwf&l>l#DF#iy5~o02$Tfu zr?!yY`k_xPYX{TBpK zp3Zr_pE>dBY3xUrE`txX%9Y-gq2>f4)$YDhc~qh2zlw(TMr{qB7L^)I_n=0SSzHrR z@>HH$`%b^tdN-8;MC8CGhfHNYKz)592Sfq%g#XZwdLN)t5YhDT)y{cvDKg^9zHlGk st!qoS5dri|t!8Qw6A$9CXBOMJ1{j9z!Mc { - // const expected_len = if (stream.is_client) @TypeOf(k).bytes_length else X25519Kyber768Draft.Kyber768.ciphertext_length; - // }, + .x25519_kyber768d00 => { + const T = X25519Kyber768Draft.Kyber768.PublicKey; + var res = Self{ .x25519_kyber768d00 = undefined }; + + try reader.readNoEof(&res.x25519_kyber768d00.x25519); + + var buf: [T.bytes_length]u8 = undefined; + try reader.readNoEof(&buf); + res.x25519_kyber768d00.kyber768d00 = T.fromBytes(&buf) catch return Error.TlsDecryptError; + + return res; + }, inline .secp256r1, .secp384r1 => |k| { const T = NamedGroupT(k).PublicKey; - var buf: [T.compressed_sec1_encoded_length]u8 = undefined; + var buf: [T.uncompressed_sec1_encoded_length]u8 = undefined; try reader.readNoEof(&buf); const val = T.fromSec1(&buf) catch return Error.TlsDecryptError; return @unionInit(Self, @tagName(k), val); }, .x25519 => { var res = Self{ .x25519 = undefined }; - if (res.x25519.len != len) return Error.TlsDecodeError; try reader.readNoEof(&res.x25519); return res; }, - else => {}, + else => { + try reader.skipBytes(len, .{}); + }, } return .{ .invalid = {} }; } @@ -699,7 +712,6 @@ pub const CipherSuite = enum(u16) { chacha20_poly1305_sha256 = 0x1303, aegis_256_sha512 = 0x1306, aegis_128l_sha256 = 0x1307, - empty_renegotiation_info_scsv = 0x00ff, _, pub fn Hash(comptime self: @This()) type { @@ -731,7 +743,6 @@ pub const HandshakeCipher = union(CipherSuite) { chacha20_poly1305_sha256: HandshakeCipherT(.chacha20_poly1305_sha256), aegis_256_sha512: HandshakeCipherT(.aegis_256_sha512), aegis_128l_sha256: HandshakeCipherT(.aegis_128l_sha256), - empty_renegotiation_info_scsv: void, const Self = @This(); @@ -775,14 +786,12 @@ pub const HandshakeCipher = union(CipherSuite) { return res; }, - .empty_renegotiation_info_scsv => return .{ .empty_renegotiation_info_scsv = {} }, _ => return Error.TlsIllegalParameter, } } pub fn print(self: Self) void { switch (self) { - .empty_renegotiation_info_scsv => {}, inline else => |v| v.print(), } } @@ -794,7 +803,6 @@ pub const ApplicationCipher = union(CipherSuite) { chacha20_poly1305_sha256: ApplicationCipherT(.chacha20_poly1305_sha256), aegis_256_sha512: ApplicationCipherT(.aegis_256_sha512), aegis_128l_sha256: ApplicationCipherT(.aegis_128l_sha256), - empty_renegotiation_info_scsv: void, const Self = @This(); @@ -831,13 +839,11 @@ pub const ApplicationCipher = union(CipherSuite) { return res; }, - .empty_renegotiation_info_scsv => unreachable, } } pub fn print(self: Self) void { switch (self) { - .empty_renegotiation_info_scsv => {}, inline else => |v| v.print(), } } @@ -1138,7 +1144,6 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { fn ciphertextOverhead(self: Self) usize { return switch (self.cipher) { inline .application, .handshake => |c| switch (c) { - .empty_renegotiation_info_scsv => 0, inline else => |t| @TypeOf(t).AEAD.tag_length + @sizeOf(ContentType), }, else => 0, @@ -1182,7 +1187,6 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { plaintext.len += @intCast(self.ciphertextOverhead()); header = Encoder.encode(Plaintext, plaintext); switch (cipher.*) { - .empty_renegotiation_info_scsv => {}, inline else => |*c| { std.debug.assert(self.view.ptr == &self.buffer); self.buffer[self.view.len] = @intFromEnum(self.content_type); @@ -1369,37 +1373,33 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { if (n_read != res.len) return self.writeError(.decode_error); const encryption_method = self.encryptionMethod(res.type); - switch (encryption_method) { - .none => {}, - .handshake, .application => { - if (res.len < self.ciphertextOverhead()) return self.writeError(.decode_error); - - switch (self.cipher) { - inline .application, .handshake => |*c| { - switch (c.*) { - .empty_renegotiation_info_scsv => {}, - inline else => |*p| { - const P = @TypeOf(p.*); - const tag_len = P.AEAD.tag_length; - - const ciphertext = self.view[0 .. self.view.len - tag_len]; - const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; - const out: []u8 = @constCast(self.view[0..ciphertext.len]); - p.decrypt(ciphertext, &plaintext_bytes, tag, self.is_client, out) catch - return self.writeError(.bad_record_mac); - const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); - if (padding_start) |s| { - res.type = @enumFromInt(self.view[s]); - self.view = self.view[0..s]; - } else { - return self.writeError(.decode_error); - } - }, - } - }, - else => unreachable, - } - }, + if (encryption_method != .none) { + if (res.len < self.ciphertextOverhead()) return self.writeError(.decode_error); + + switch (self.cipher) { + inline .handshake, .application => |*cipher| { + switch (cipher.*) { + inline else => |*c| { + const C = @TypeOf(c.*); + const tag_len = C.AEAD.tag_length; + + const ciphertext = self.view[0 .. self.view.len - tag_len]; + const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; + const out: []u8 = @constCast(self.view[0..ciphertext.len]); + c.decrypt(ciphertext, &plaintext_bytes, tag, self.is_client, out) catch + return self.writeError(.bad_record_mac); + const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); + if (padding_start) |s| { + res.type = @enumFromInt(self.view[s]); + self.view = self.view[0..s]; + } else { + return self.writeError(.decode_error); + } + }, + } + }, + else => unreachable, + } } switch (res.type) { @@ -1571,7 +1571,7 @@ pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { /// `@sizeOf(MultiHash)` = 560 /// /// A nice benefit is decreased latency on hosts where one round trip takes longer than calling -/// `update` on each hashes. +/// `update` with `active == .all`. pub const MultiHash = struct { sha256: sha2.Sha256 = sha2.Sha256.init(.{}), sha384: sha2.Sha384 = sha2.Sha384.init(.{}), @@ -1602,7 +1602,6 @@ pub const MultiHash = struct { .aes_128_gcm_sha256, .chacha20_poly1305_sha256, .aegis_128l_sha256 => .sha256, .aes_256_gcm_sha384 => .sha384, .aegis_256_sha512 => .sha512, - .empty_renegotiation_info_scsv => .all, _ => .all, }; } @@ -1969,6 +1968,7 @@ test "tls client and server handshake, data, and close_notify" { }; const server_der = @embedFile("./testdata/server.der"); + const server_key = @embedFile("./testdata/server.key"); var server_transcript: MultiHash = .{}; var server = Server(@TypeOf(inner_stream)){ .stream = Stream(Plaintext.max_length, TestStream){ @@ -1982,6 +1982,7 @@ test "tls client and server handshake, data, and close_notify" { .certificate = .{ .entries = &[_]Certificate.Entry{ .{ .data = server_der }, } }, + .certificate_key = server_key, }, }; @@ -2014,879 +2015,60 @@ test "tls client and server handshake, data, and close_notify" { client_x25519_seed ++ [_]u8{0} ** (48 - 32), client_x25519_seed, ); - var client_command = brk: { - // To get the same `hello_hash` as https://tls13.xargs.org/ just for - // this test we send a mostly falsified client hello. - // It doesn't matter because the server will be TLS 1.3 and only support .x25519 - const hello = ClientHello{ - .random = key_pairs.hello_rand, - .session_id = &key_pairs.session_id, - .cipher_suites = &[_]CipherSuite{ - .aes_256_gcm_sha384, - .chacha20_poly1305_sha256, - .aes_128_gcm_sha256, - .empty_renegotiation_info_scsv, - }, - .extensions = &.{ - .{ .server_name = &[_]ServerName{.{ .host_name = client.options.host }} }, - .{ .ec_point_formats = &[_]EcPointFormat{ - .uncompressed, - .ansiX962_compressed_prime, - .ansiX962_compressed_char2, - } }, - .{ .supported_groups = &[_]NamedGroup{ - .x25519, - .secp256r1, - .x448, - .secp521r1, - .secp384r1, - .ffdhe2048, - .ffdhe3072, - .ffdhe4096, - .ffdhe6144, - .ffdhe8192, - } }, - .{ .session_ticket = {} }, - .{ .encrypt_then_mac = {} }, - .{ .extended_master_secret = {} }, - .{ .signature_algorithms = &[_]SignatureScheme{ - .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - .ecdsa_secp521r1_sha512, - .ed25519, - .ed448, - .rsa_pss_pss_sha256, - .rsa_pss_pss_sha384, - .rsa_pss_pss_sha512, - .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - .rsa_pkcs1_sha256, - .rsa_pkcs1_sha384, - .rsa_pkcs1_sha512, - } }, - .{ .supported_versions = &[_]Version{.tls_1_3} }, - .{ .psk_key_exchange_modes = &[_]PskKeyExchangeMode{.ke} }, - .{ .key_share = &[_]KeyShare{ - .{ .x25519 = key_pairs.x25519.public_key }, - } }, - }, - }; - - _ = try client.stream.write(Handshake, .{ .client_hello = hello }); - try client.stream.flush(); - - break :brk client_mod.State{ .recv_hello = key_pairs }; + var client_command = client_mod.Command{ .send_hello = key_pairs }; + client_command = try client.next(client_command); + try std.testing.expect(client_command == .recv_hello); + + var server_command = server_mod.Command{ .recv_hello = {} }; + server_command = try server.next(server_command); // recv_hello + try std.testing.expect(server_command == .send_hello); + server_command.send_hello.server_random = server_random; + server_command.send_hello.server_pair = .{ + .x25519 = crypto.dh.X25519.KeyPair.create(server_x25519_seed) catch unreachable, }; - try inner_stream.expect([_]u8{ - 0x16, // handshake - 0x03, 0x01, // tls 1.0 (lie for compat) - 0x00, 0xf8, // handshake len - 0x01, // client hello - 0x00, 0x00, 0xf4, // client hello len - 0x03, 0x03, // tls 1.2 (lie for compat) - } ++ client_random ++ - [_]u8{session_id.len} ++ session_id ++ - [_]u8{ - 0x00, 0x08, // cipher suite len - 0x13, 0x02, // aes_256_gcm_sha384 - 0x13, 0x03, // chacha20_poly1305_sha256 - 0x13, 0x01, // aes_128_gcm_sha256 - 0x00, 0xff, // empty_renegotiation_info_scsv - 0x01, // compression methods len - 0x00, // none - 0x00, 0xa3, // extensions len - 0x00, 0x00, // server name ext - 0x00, 0x18, // server name len - 0x00, 0x16, // list entry len - 0x00, // dns hostname - } ++ - Encoder.encode(u16, @intCast(host.len)) ++ host ++ - [_]u8{ - 0x00, 0x0b, // ec point formats - 0x00, 0x04, // ext len - 0x03, // format type len - 0x00, // uncompresed - 0x01, // ansiX962_compressed_prime - 0x02, // ansiX962_compressed_char2 - 0x00, 0x0a, // supported groups - 0x00, 0x16, // ext len - 0x00, 0x14, // supported groups len - 0x00, 0x1d, // x25519 - 0x00, 0x17, // secp256r1 - 0x00, 0x1e, // x448 - 0x00, 0x19, // secp521r1 - 0x00, 0x18, // secp384r1 - 0x01, 0x00, // ffdhe2048 - 0x01, 0x01, // ffdhe3072 - 0x01, 0x02, // ffdhe4096 - 0x01, 0x03, // ffdhe6144 - 0x01, 0x04, // ffdhe8192 - 0x00, 0x23, // session ticket - 0x00, 0x00, // ext len - 0x00, 0x16, // encrypt then mac - 0x00, 0x00, // ext len - 0x00, 0x17, // extended master secrets - 0x00, 0x00, // ext len - 0x00, 0x0d, // signature algos - 0x00, 0x1e, // ext len - 0x00, 0x1c, // algos len - 0x04, 0x03, // ecdsa_secp256r1_sha256 - 0x05, 0x03, // ecdsa_secp384r1_sha384 - 0x06, 0x03, // ecdsa_secp521r1_sha512 - 0x08, 0x07, // ed25519 - 0x08, 0x08, // ed448 - 0x08, 0x09, // rsa_pss_pss_sha256 - 0x08, 0x0a, // rsa_pss_pss_sha384 - 0x08, 0x0b, // rsa_pss_pss_sha512 - 0x08, 0x04, // rsa_pss_rsae_sha256 - 0x08, 0x05, // rsa_pss_rsae_sha384 - 0x08, 0x06, // rsa_pss_rsae_sha512 - 0x04, 0x01, // rsa_pkcs1_sha256 - 0x05, 0x01, // rsa_pkcs1_sha384 - 0x06, 0x01, // rsa_pkcs1_sha512 - 0x00, 0x2b, // supported versions - 0x00, 0x03, // ext len - 0x02, // supported versions len - 0x03, 0x04, // tls 1.3 (not lying anymore!) - 0x00, 0x2d, // psk key exchange modes - 0x00, 0x02, // ext len - 0x01, // psk key exchange modes len - 0x01, // PSK with (EC)DHE key establishment - 0x00, 0x33, // key share - 0x00, 0x26, // ext len - 0x00, 0x24, // key shares len - 0x00, 0x1d, // curve 25519 - 0x00, 0x20, // key len - } ++ key_pairs.x25519.public_key); - - const client_hello = try server.recv_hello(); - try std.testing.expectEqualSlices(u8, &client_random, &client_hello.random); - try std.testing.expectEqualSlices(u8, &session_id, &client_hello.session_id); - const server_key_pair = server_mod.KeyPair{ - .random = server_random, - .pair = .{ .x25519 = crypto.dh.X25519.KeyPair.create(server_x25519_seed) catch unreachable }, - }; + server_command = try server.next(server_command); // send_hello + try std.testing.expect(server_command == .send_change_cipher_spec); - try server.send_hello(client_hello, server_key_pair); - // hack to match xargs, need to fix server - const signature_verify = [_]u8{ 0x5c, 0xbb, 0x24, 0xc0, 0x40, 0x93, 0x32, 0xda, 0xa9, 0x20, 0xbb, 0xab, 0xbd, 0xb9, 0xbd, 0x50, 0x17, 0x0b, 0xe4, 0x9c, 0xfb, 0xe0, 0xa4, 0x10, 0x7f, 0xca, 0x6f, 0xfb, 0x10, 0x68, 0xe6, 0x5f, 0x96, 0x9e, 0x6d, 0xe7, 0xd4, 0xf9, 0xe5, 0x60, 0x38, 0xd6, 0x7c, 0x69, 0xc0, 0x31, 0x40, 0x3a, 0x7a, 0x7c, 0x0b, 0xcc, 0x86, 0x83, 0xe6, 0x57, 0x21, 0xa0, 0xc7, 0x2c, 0xc6, 0x63, 0x40, 0x19, 0xad, 0x1d, 0x3a, 0xd2, 0x65, 0xa8, 0x12, 0x61, 0x5b, 0xa3, 0x63, 0x80, 0x37, 0x20, 0x84, 0xf5, 0xda, 0xec, 0x7e, 0x63, 0xd3, 0xf4, 0x93, 0x3f, 0x27, 0x22, 0x74, 0x19, 0xa6, 0x11, 0x03, 0x46, 0x44, 0xdc, 0xdb, 0xc7, 0xbe, 0x3e, 0x74, 0xff, 0xac, 0x47, 0x3f, 0xaa, 0xad, 0xde, 0x8c, 0x2f, 0xc6, 0x5f, 0x32, 0x65, 0x77, 0x3e, 0x7e, 0x62, 0xde, 0x33, 0x86, 0x1f, 0xa7, 0x05, 0xd1, 0x9c, 0x50, 0x6e, 0x89, 0x6c, 0x8d, 0x82, 0xf5, 0xbc, 0xf3, 0x5f, 0xec, 0xe2, 0x59, 0xb7, 0x15, 0x38, 0x11, 0x5e, 0x9c, 0x8c, 0xfb, 0xa6, 0x2e, 0x49, 0xbb, 0x84, 0x74, 0xf5, 0x85, 0x87, 0xb1, 0x1b, 0x8a, 0xe3, 0x17, 0xc6, 0x33, 0xe9, 0xc7, 0x6c, 0x79, 0x1d, 0x46, 0x62, 0x84, 0xad, 0x9c, 0x4f, 0xf7, 0x35, 0xa6, 0xd2, 0xe9, 0x63, 0xb5, 0x9b, 0xbc, 0xa4, 0x40, 0xa3, 0x07, 0x09, 0x1a, 0x1b, 0x4e, 0x46, 0xbc, 0xc7, 0xa2, 0xf9, 0xfb, 0x2f, 0x1c, 0x89, 0x8e, 0xcb, 0x19, 0x91, 0x8b, 0xe4, 0x12, 0x1d, 0x7e, 0x8e, 0xd0, 0x4c, 0xd5, 0x0c, 0x9a, 0x59, 0xe9, 0x87, 0x98, 0x01, 0x07, 0xbb, 0xbf, 0x29, 0x9c, 0x23, 0x2e, 0x7f, 0xdb, 0xe1, 0x0a, 0x4c, 0xfd, 0xae, 0x5c, 0x89, 0x1c, 0x96, 0xaf, 0xdf, 0xf9, 0x4b, 0x54, 0xcc, 0xd2, 0xbc, 0x19, 0xd3, 0xcd, 0xaa, 0x66, 0x44, 0x85, 0x9c }; - _ = try server.stream.write(Handshake, Handshake{ .certificate_verify = CertificateVerify{ - .algorithm = .rsa_pss_rsae_sha256, - .signature = &signature_verify, - } }); - try server.stream.flush(); - - try server.send_finished(); - try inner_stream.expect(&([_]u8{ - 0x16, // handshake - 0x03, 0x03, // tls 1.2 - 0x00, 0x7a, // Handshake len - 0x02, // server hello - 0x00, 0x00, 0x76, // server hello len - 0x03, 0x03, // tls 1.2 - } ++ server_key_pair.random ++ [_]u8{session_id.len} ++ session_id ++ - [_]u8{ - 0x13, 0x02, // aes_256_gcm_sha384 - 0x00, // compression method - 0x00, 0x2e, // extensions len - 0x00, 0x2b, // supported versions - 0x00, 0x02, // ext len - 0x03, 0x04, // tls 1.3 - 0x00, 0x33, // key share - 0x00, 0x24, // ext len - 0x00, 0x1d, // x25519 - 0x00, 0x20, // key len - 0x9f, 0xd7, - 0xad, 0x6d, - 0xcf, 0xf4, - 0x29, 0x8d, - 0xd3, 0xf9, - 0x6d, 0x5b, - 0x1b, 0x2a, - 0xf9, 0x10, - 0xa0, 0x53, - 0x5b, 0x14, - 0x88, 0xd7, - 0xf8, 0xfa, - 0xbb, 0x34, - 0x9a, 0x98, - 0x28, 0x80, - 0xb6, 0x15, // key - } ++ - [_]u8{ - 0x14, // ChangeCipherSpec - 0x03, 0x03, // tls 1.2 - 0x00, 0x01, // len - 0x01, // .change_cipher_spec - } ++ [_]u8{ - 0x17, // application data (lie for tls 1.2 compat) - 0x03, 0x03, // tls 1.2 - 0x00, 0x17, // application data len - 0x6b, 0xe0, 0x2f, 0x9d, 0xa7, 0xc2, // encrypted data (empty EncryptedExtensions message) - 0xdc, // encrypted data type (handshake) - 0x9d, 0xde, 0xf5, 0x6f, 0x24, 0x68, 0xb9, 0x0a, // auth tag - 0xdf, 0xa2, 0x51, 0x01, 0xab, 0x03, 0x44, 0xae, // auth tag - } ++ [_]u8{ - 0x17, // application data (lie for tls 1.2 compat) - 0x03, 0x03, // tls 1.2 - 0x03, 0x43, // application data len - 0xba, 0xf0, - 0x0a, 0x9b, - 0xe5, 0x0f, - 0x3f, 0x23, - 0x07, 0xe7, - 0x26, 0xed, - 0xcb, 0xda, - 0xcb, 0xe4, - 0xb1, 0x86, - 0x16, 0x44, - 0x9d, 0x46, - 0xc6, 0x20, - 0x7a, 0xf6, - 0xe9, 0x95, - 0x3e, 0xe5, - 0xd2, 0x41, - 0x1b, 0xa6, - 0x5d, 0x31, - 0xfe, 0xaf, - 0x4f, 0x78, - 0x76, 0x4f, - 0x2d, 0x69, - 0x39, 0x87, - 0x18, 0x6c, - 0xc0, 0x13, - 0x29, 0xc1, - 0x87, 0xa5, - 0xe4, 0x60, - 0x8e, 0x8d, - 0x27, 0xb3, - 0x18, 0xe9, - 0x8d, 0xd9, - 0x47, 0x69, - 0xf7, 0x73, - 0x9c, 0xe6, - 0x76, 0x83, - 0x92, 0xca, - 0xca, 0x8d, - 0xcc, 0x59, - 0x7d, 0x77, - 0xec, 0x0d, - 0x12, 0x72, - 0x23, 0x37, - 0x85, 0xf6, - 0xe6, 0x9d, - 0x6f, 0x43, - 0xef, 0xfa, - 0x8e, 0x79, - 0x05, 0xed, - 0xfd, 0xc4, - 0x03, 0x7e, - 0xee, 0x59, - 0x33, 0xe9, - 0x90, 0xa7, - 0x97, 0x2f, - 0x20, 0x69, - 0x13, 0xa3, - 0x1e, 0x8d, - 0x04, 0x93, - 0x13, 0x66, - 0xd3, 0xd8, - 0xbc, 0xd6, - 0xa4, 0xa4, - 0xd6, 0x47, - 0xdd, 0x4b, - 0xd8, 0x0b, - 0x0f, 0xf8, - 0x63, 0xce, - 0x35, 0x54, - 0x83, 0x3d, - 0x74, 0x4c, - 0xf0, 0xe0, - 0xb9, 0xc0, - 0x7c, 0xae, - 0x72, 0x6d, - 0xd2, 0x3f, - 0x99, 0x53, - 0xdf, 0x1f, - 0x1c, 0xe3, - 0xac, 0xeb, - 0x3b, 0x72, - 0x30, 0x87, - 0x1e, 0x92, - 0x31, 0x0c, - 0xfb, 0x2b, - 0x09, 0x84, - 0x86, 0xf4, - 0x35, 0x38, - 0xf8, 0xe8, - 0x2d, 0x84, - 0x04, 0xe5, - 0xc6, 0xc2, - 0x5f, 0x66, - 0xa6, 0x2e, - 0xbe, 0x3c, - 0x5f, 0x26, - 0x23, 0x26, - 0x40, 0xe2, - 0x0a, 0x76, - 0x91, 0x75, - 0xef, 0x83, - 0x48, 0x3c, - 0xd8, 0x1e, - 0x6c, 0xb1, - 0x6e, 0x78, - 0xdf, 0xad, - 0x4c, 0x1b, - 0x71, 0x4b, - 0x04, 0xb4, - 0x5f, 0x6a, - 0xc8, 0xd1, - 0x06, 0x5a, - 0xd1, 0x8c, - 0x13, 0x45, - 0x1c, 0x90, - 0x55, 0xc4, - 0x7d, 0xa3, - 0x00, 0xf9, - 0x35, 0x36, - 0xea, 0x56, - 0xf5, 0x31, - 0x98, 0x6d, - 0x64, 0x92, - 0x77, 0x53, - 0x93, 0xc4, - 0xcc, 0xb0, - 0x95, 0x46, - 0x70, 0x92, - 0xa0, 0xec, - 0x0b, 0x43, - 0xed, 0x7a, - 0x06, 0x87, - 0xcb, 0x47, - 0x0c, 0xe3, - 0x50, 0x91, - 0x7b, 0x0a, - 0xc3, 0x0c, - 0x6e, 0x5c, - 0x24, 0x72, - 0x5a, 0x78, - 0xc4, 0x5f, - 0x9f, 0x5f, - 0x29, 0xb6, - 0x62, 0x68, - 0x67, 0xf6, - 0xf7, 0x9c, - 0xe0, 0x54, - 0x27, 0x35, - 0x47, 0xb3, - 0x6d, 0xf0, - 0x30, 0xbd, - 0x24, 0xaf, - 0x10, 0xd6, - 0x32, 0xdb, - 0xa5, 0x4f, - 0xc4, 0xe8, - 0x90, 0xbd, - 0x05, 0x86, - 0x92, 0x8c, - 0x02, 0x06, - 0xca, 0x2e, - 0x28, 0xe4, - 0x4e, 0x22, - 0x7a, 0x2d, - 0x50, 0x63, - 0x19, 0x59, - 0x35, 0xdf, - 0x38, 0xda, - 0x89, 0x36, - 0x09, 0x2e, - 0xef, 0x01, - 0xe8, 0x4c, - 0xad, 0x2e, - 0x49, 0xd6, - 0x2e, 0x47, - 0x0a, 0x6c, - 0x77, 0x45, - 0xf6, 0x25, - 0xec, 0x39, - 0xe4, 0xfc, - 0x23, 0x32, - 0x9c, 0x79, - 0xd1, 0x17, - 0x28, 0x76, - 0x80, 0x7c, - 0x36, 0xd7, - 0x36, 0xba, - 0x42, 0xbb, - 0x69, 0xb0, - 0x04, 0xff, - 0x55, 0xf9, - 0x38, 0x50, - 0xdc, 0x33, - 0xc1, 0xf9, - 0x8a, 0xbb, - 0x92, 0x85, - 0x83, 0x24, - 0xc7, 0x6f, - 0xf1, 0xeb, - 0x08, 0x5d, - 0xb3, 0xc1, - 0xfc, 0x50, - 0xf7, 0x4e, - 0xc0, 0x44, - 0x42, 0xe6, - 0x22, 0x97, - 0x3e, 0xa7, - 0x07, 0x43, - 0x41, 0x87, - 0x94, 0xc3, - 0x88, 0x14, - 0x0b, 0xb4, - 0x92, 0xd6, - 0x29, 0x4a, - 0x05, 0x40, - 0xe5, 0xa5, - 0x9c, 0xfa, - 0xe6, 0x0b, - 0xa0, 0xf1, - 0x48, 0x99, - 0xfc, 0xa7, - 0x13, 0x33, - 0x31, 0x5e, - 0xa0, 0x83, - 0xa6, 0x8e, - 0x1d, 0x7c, - 0x1e, 0x4c, - 0xdc, 0x2f, - 0x56, 0xbc, - 0xd6, 0x11, - 0x96, 0x81, - 0xa4, 0xad, - 0xbc, 0x1b, - 0xbf, 0x42, - 0xaf, 0xd8, - 0x06, 0xc3, - 0xcb, 0xd4, - 0x2a, 0x07, - 0x6f, 0x54, - 0x5d, 0xee, - 0x4e, 0x11, - 0x8d, 0x0b, - 0x39, 0x67, - 0x54, 0xbe, - 0x2b, 0x04, - 0x2a, 0x68, - 0x5d, 0xd4, - 0x72, 0x7e, - 0x89, 0xc0, - 0x38, 0x6a, - 0x94, 0xd3, - 0xcd, 0x6e, - 0xcb, 0x98, - 0x20, 0xe9, - 0xd4, 0x9a, - 0xfe, 0xed, - 0x66, 0xc4, - 0x7e, 0x6f, - 0xc2, 0x43, - 0xea, 0xbe, - 0xbb, 0xcb, - 0x0b, 0x02, - 0x45, 0x38, - 0x77, 0xf5, - 0xac, 0x5d, - 0xbf, 0xbd, - 0xf8, 0xdb, - 0x10, 0x52, - 0xa3, 0xc9, - 0x94, 0xb2, - 0x24, 0xcd, - 0x9a, 0xaa, - 0xf5, 0x6b, - 0x02, 0x6b, - 0xb9, 0xef, - 0xa2, 0xe0, - 0x13, 0x02, - 0xb3, 0x64, - 0x01, 0xab, - 0x64, 0x94, - 0xe7, 0x01, - 0x8d, 0x6e, - 0x5b, 0x57, - 0x3b, 0xd3, - 0x8b, 0xce, - 0xf0, 0x23, - 0xb1, 0xfc, - 0x92, 0x94, - 0x6b, 0xbc, - 0xa0, 0x20, - 0x9c, 0xa5, - 0xfa, 0x92, - 0x6b, 0x49, - 0x70, 0xb1, - 0x00, 0x91, - 0x03, 0x64, - 0x5c, 0xb1, - 0xfc, 0xfe, - 0x55, 0x23, - 0x11, 0xff, - 0x73, 0x05, - 0x58, 0x98, - 0x43, 0x70, - 0x03, 0x8f, - 0xd2, 0xcc, - 0xe2, 0xa9, - 0x1f, 0xc7, - 0x4d, 0x6f, - 0x3e, 0x3e, - 0xa9, 0xf8, - 0x43, 0xee, - 0xd3, 0x56, - 0xf6, 0xf8, - 0x2d, 0x35, - 0xd0, 0x3b, - 0xc2, 0x4b, - 0x81, 0xb5, - 0x8c, 0xeb, - 0x1a, 0x43, - 0xec, 0x94, - 0x37, 0xe6, - 0xf1, 0xe5, - 0x0e, 0xb6, - 0xf5, 0x55, - 0xe3, 0x21, - 0xfd, 0x67, - 0xc8, 0x33, - 0x2e, 0xb1, - 0xb8, 0x32, - 0xaa, 0x8d, - 0x79, 0x5a, - 0x27, 0xd4, - 0x79, 0xc6, - 0xe2, 0x7d, - 0x5a, 0x61, - 0x03, 0x46, - 0x83, 0x89, - 0x19, 0x03, - 0xf6, 0x64, - 0x21, 0xd0, - 0x94, 0xe1, - 0xb0, 0x0a, - 0x9a, 0x13, - 0x8d, 0x86, - 0x1e, 0x6f, - 0x78, 0xa2, - 0x0a, 0xd3, - 0xe1, 0x58, - 0x00, 0x54, - 0xd2, 0xe3, - 0x05, 0x25, - 0x3c, 0x71, - 0x3a, 0x02, - 0xfe, 0x1e, - 0x28, 0xde, - 0xee, 0x73, - 0x36, 0x24, - 0x6f, 0x6a, - 0xe3, 0x43, - 0x31, 0x80, - 0x6b, 0x46, - 0xb4, 0x7b, - 0x83, 0x3c, - 0x39, 0xb9, - 0xd3, 0x1c, - 0xd3, 0x00, - 0xc2, 0xa6, - 0xed, 0x83, - 0x13, 0x99, - 0x77, 0x6d, - 0x07, 0xf5, - 0x70, 0xea, - 0xf0, 0x05, - 0x9a, 0x2c, - 0x68, 0xa5, - 0xf3, 0xae, - 0x16, 0xb6, - 0x17, 0x40, - 0x4a, 0xf7, - 0xb7, 0x23, - 0x1a, 0x4d, - 0x94, 0x27, - 0x58, 0xfc, - 0x02, 0x0b, - 0x3f, 0x23, - 0xee, 0x8c, - 0x15, 0xe3, - 0x60, 0x44, - 0xcf, 0xd6, - 0x7c, 0xd6, - 0x40, 0x99, - 0x3b, 0x16, - 0x20, 0x75, - 0x97, 0xfb, - 0xf3, 0x85, - 0xea, 0x7a, - 0x4d, 0x99, - 0xe8, 0xd4, - 0x56, 0xff, - 0x83, 0xd4, - 0x1f, 0x7b, - 0x8b, 0x4f, - 0x06, 0x9b, - 0x02, 0x8a, - 0x2a, 0x63, - 0xa9, 0x19, - 0xa7, 0x0e, - 0x3a, 0x10, - 0xe3, 0x08, // encrypted cert - 0x41, // encrypted data type (Certificate) - 0x58, 0xfa, 0xa5, 0xba, 0xfa, 0x30, 0x18, 0x6c, // auth tag - 0x6b, 0x2f, 0x23, 0x8e, 0xb5, 0x30, 0xc7, 0x3e, // auth tag - } ++ [_]u8{ - 0x17, // application data (lie for tls 1.2 compat) - 0x03, 0x03, // tls 1.2 - 0x01, 0x19, // application data len - 0x73, 0x71, - 0x9f, 0xce, - 0x07, 0xec, - 0x2f, 0x6d, - 0x3b, 0xba, - 0x02, 0x92, - 0xa0, 0xd4, - 0x0b, 0x27, - 0x70, 0xc0, - 0x6a, 0x27, - 0x17, 0x99, - 0xa5, 0x33, - 0x14, 0xf6, - 0xf7, 0x7f, - 0xc9, 0x5c, - 0x5f, 0xe7, - 0xb9, 0xa4, - 0x32, 0x9f, - 0xd9, 0x54, - 0x8c, 0x67, - 0x0e, 0xbe, - 0xea, 0x2f, - 0x2d, 0x5c, - 0x35, 0x1d, - 0xd9, 0x35, - 0x6e, 0xf2, - 0xdc, 0xd5, - 0x2e, 0xb1, - 0x37, 0xbd, - 0x3a, 0x67, - 0x65, 0x22, - 0xf8, 0xcd, - 0x0f, 0xb7, - 0x56, 0x07, - 0x89, 0xad, - 0x7b, 0x0e, - 0x3c, 0xab, - 0xa2, 0xe3, - 0x7e, 0x6b, - 0x41, 0x99, - 0xc6, 0x79, - 0x3b, 0x33, - 0x46, 0xed, - 0x46, 0xcf, - 0x74, 0x0a, - 0x9f, 0xa1, - 0xfe, 0xc4, - 0x14, 0xdc, - 0x71, 0x5c, - 0x41, 0x5c, - 0x60, 0xe5, - 0x75, 0x70, - 0x3c, 0xe6, - 0xa3, 0x4b, - 0x70, 0xb5, - 0x19, 0x1a, - 0xa6, 0xa6, - 0x1a, 0x18, - 0xfa, 0xff, - 0x21, 0x6c, - 0x68, 0x7a, - 0xd8, 0xd1, - 0x7e, 0x12, - 0xa7, 0xe9, - 0x99, 0x15, - 0xa6, 0x11, - 0xbf, 0xc1, - 0xa2, 0xbe, - 0xfc, 0x15, - 0xe6, 0xe9, - 0x4d, 0x78, - 0x46, 0x42, - 0xe6, 0x82, - 0xfd, 0x17, - 0x38, 0x2a, - 0x34, 0x8c, - 0x30, 0x10, - 0x56, 0xb9, - 0x40, 0xc9, - 0x84, 0x72, - 0x00, 0x40, - 0x8b, 0xec, - 0x56, 0xc8, - 0x1e, 0xa3, - 0xd7, 0x21, - 0x7a, 0xb8, - 0xe8, 0x5a, - 0x88, 0x71, - 0x53, 0x95, - 0x89, 0x9c, - 0x90, 0x58, - 0x7f, 0x72, - 0xe8, 0xdd, - 0xd7, 0x4b, - 0x26, 0xd8, - 0xed, 0xc1, - 0xc7, 0xc8, - 0x37, 0xd9, - 0xf2, 0xeb, - 0xbc, 0x26, - 0x09, 0x62, - 0x21, 0x90, - 0x38, 0xb0, - 0x56, 0x54, - 0xa6, 0x3a, - 0x0b, 0x12, - 0x99, 0x9b, - 0x4a, 0x83, - 0x06, 0xa3, - 0xdd, 0xcc, - 0x0e, 0x17, - 0xc5, 0x3b, - 0xa8, 0xf9, - 0xc8, 0x03, - 0x63, 0xf7, - 0x84, 0x13, - 0x54, 0xd2, - 0x91, 0xb4, - 0xac, 0xe0, - 0xc0, 0xf3, - 0x30, 0xc0, - 0xfc, 0xd5, - 0xaa, 0x9d, - 0xee, 0xf9, - 0x69, 0xae, - 0x8a, 0xb2, - 0xd9, 0x8d, - 0xa8, 0x8e, - 0xbb, 0x6e, 0xa8, 0x0a, 0x3a, 0x11, 0xf0, 0x0e, // encrypted signature_verify - 0xa2, // encrypted data type (SignatureVerify) - 0x96, 0xa3, 0x23, 0x23, 0x67, 0xff, 0x07, 0x5e, // auth tag - 0x1c, 0x66, 0xdd, 0x9c, 0xbe, 0xdc, 0x47, 0x13, // auth tag - } ++ [_]u8{ - 0x17, // application data (lie for tls 1.2 compat) - 0x03, 0x03, // tls 1.2 - 0x00, 0x45, // application data len - 0x10, 0x61, - 0xde, 0x27, - 0xe5, 0x1c, - 0x2c, 0x9f, - 0x34, 0x29, - 0x11, 0x80, - 0x6f, 0x28, - 0x2b, 0x71, - 0x0c, 0x10, - 0x63, 0x2c, - 0xa5, 0x00, - 0x67, 0x55, - 0x88, 0x0d, - 0xbf, 0x70, - 0x06, 0x00, - 0x2d, 0x0e, - 0x84, 0xfe, - 0xd9, 0xad, - 0xf2, 0x7a, - 0x43, 0xb5, - 0x19, 0x23, - 0x03, 0xe4, - 0xdf, 0x5c, - 0x28, 0x5d, - 0x58, 0xe3, - 0xc7, 0x62, - 0x24, // encrypted data type (finished) - 0x07, 0x84, 0x40, 0xc0, 0x74, 0x23, 0x74, 0x74, // auth tag - 0x4a, 0xec, 0xf2, 0x8c, 0xf3, 0x18, 0x2f, 0xd0, // auth tag - })); - - client_command = try client.advance(client_command); // recv_hello + client_command = try client.next(client_command); // recv_hello try std.testing.expect(client_command == .recv_encrypted_extensions); - // { - // const s = server.stream.cipher.handshake.aes_256_gcm_sha384; - // const c = client.stream.cipher.handshake.aes_256_gcm_sha384; - - // try std.testing.expectEqualSlices(u8, &s.handshake_secret, &c.handshake_secret); - // try std.testing.expectEqualSlices(u8, &s.master_secret, &c.master_secret); - // try std.testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); - // try std.testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); - // try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); - // try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); - // try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); - // try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); - // const client_iv = [_]u8{ 0x42, 0x56, 0xd2, 0xe0, 0xe8, 0x8b, 0xab, 0xdd, 0x05, 0xeb, 0x2f, 0x27 }; - // try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); - // } - client_command = try client.advance(client_command); // recv_encrypted_extensions - try std.testing.expect(client_command == .recv_certificate_or_finished); + { + const s = server.stream.cipher.handshake.aes_256_gcm_sha384; + const c = client.stream.cipher.handshake.aes_256_gcm_sha384; - client_command = try client.advance(client_command); // recv_certificate_or_finished (certificate) - try std.testing.expect(client_command == .recv_certificate_verify); + try std.testing.expectEqualSlices(u8, &s.handshake_secret, &c.handshake_secret); + try std.testing.expectEqualSlices(u8, &s.master_secret, &c.master_secret); + try std.testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); + try std.testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); + try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); + try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + const client_iv = [_]u8{ 0xE1, 0x38, 0xB9, 0xBF, 0xD6, 0xB4, 0x2D, 0x91, 0x6D, 0x81, 0xA0, 0x2D }; + try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); + } - client_command = try client.advance(client_command); // recv_certificate_verify + server_command = try server.next(server_command); // send_change_cipher_spec + try std.testing.expect(server_command == .send_encrypted_extensions); + server_command = try server.next(server_command); // send_encrypted_extensions + try std.testing.expect(server_command == .send_certificate); + server_command = try server.next(server_command); // send_certificate + try std.testing.expect(server_command == .send_certificate_verify); + server_command = try server.next(server_command); // send_certificate_verify + try std.testing.expect(server_command == .send_finished); + server_command = try server.next(server_command); // send_finished + try std.testing.expect(server_command == .recv_finished); + + client_command = try client.next(client_command); // recv_encrypted_extensions + try std.testing.expect(client_command == .recv_certificate_or_finished); + client_command = try client.next(client_command); // recv_certificate_or_finished (certificate) + try std.testing.expect(client_command == .recv_certificate_verify); + client_command = try client.next(client_command); // recv_certificate_verify try std.testing.expect(client_command == .recv_finished); - - client_command = try client.advance(client_command); // recv_finished + client_command = try client.next(client_command); // recv_finished try std.testing.expect(client_command == .send_finished); - - client_command = try client.advance(client_command); // send_finished - try std.testing.expect(client_command == .sent_finished); - - try inner_stream.expect(&([_]u8{ - 0x14, // ChangeCipherSpec - 0x03, 0x03, // tls 1.2 - 0x00, 0x01, // len - 0x01, // .change_cipher_spec - } ++ [_]u8{ - 0x17, // app data (lie for TLS 1.2) - 0x03, 0x03, // tls 1.2 - 0x00, 0x45, // len - 0x9f, 0xf9, - 0xb0, 0x63, - 0x17, 0x51, - 0x77, 0x32, - 0x2a, 0x46, - 0xdd, 0x98, - 0x96, 0xf3, - 0xc3, 0xbb, - 0x82, 0x0a, - 0xb5, 0x17, - 0x43, 0xeb, - 0xc2, 0x5f, - 0xda, 0xdd, - 0x53, 0x45, - 0x4b, 0x73, - 0xde, 0xb5, - 0x4c, 0xc7, - 0x24, 0x8d, - 0x41, 0x1a, - 0x18, 0xbc, - 0xcf, 0x65, - 0x7a, 0x96, - 0x08, 0x24, - 0xe9, 0xa1, - 0x93, 0x64, 0x83, 0x7c, // encrypted data - 0x35, // handshake - 0x0a, 0x69, 0xa8, 0x8d, 0x4b, 0xf6, 0x35, 0xc8, // auth tag - 0x5e, 0xb8, 0x74, 0xae, 0xbc, 0x9d, 0xfd, 0xe8, // auth tag - })); - try server.recv_finished(); - + client_command = try client.next(client_command); // send_finished + try std.testing.expect(client_command == .none); { const s = server.stream.cipher.application.aes_256_gcm_sha384; const c = client.stream.cipher.application.aes_256_gcm_sha384; @@ -2900,19 +2082,10 @@ test "tls client and server handshake, data, and close_notify" { const client_iv = [_]u8{ 0xbb, 0x00, 0x79, 0x56, 0xf4, 0x74, 0xb2, 0x5d, 0xe9, 0x02, 0x43, 0x2f }; try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } + server_command = try server.next(server_command); // recv_finished + try std.testing.expect(server_command == .none); - - _ = try client.stream.writer().writeAll("ping"); - try client.stream.flush(); - try inner_stream.expect(&([_]u8{ - 0x17, // app data (FOR REAL THIS TIME) - 0x03, 0x03, // tls 1.2 - 0x00, 0x15, // len - 0x82, 0x81, 0x39, 0xcb, // ping - 0x7b, // app data (exciting!) - 0x73, 0xaa, 0xab, 0xf5, 0xb8, 0x2f, 0xbf, 0x9a, // auth tag - 0x29, 0x61, 0xbc, 0xde, 0x10, 0x03, 0x8a, 0x32, // auth tag - })); + try client.writer().writeAll("ping"); var recv_ping: [4]u8 = undefined; _ = try server.stream.reader().readAll(&recv_ping); @@ -2920,15 +2093,6 @@ test "tls client and server handshake, data, and close_notify" { server.stream.close(); try std.testing.expect(server.stream.closed); - try inner_stream.expect(&([_]u8{ - 0x17, // app data (lie to encrypt) - 0x03, 0x03, // tls 1.2 - 0x00, 0x13, // len - 0x3e, 0x2d, // alert - 0x99, // encrypted message type - 0x26, 0xbb, 0xfe, 0x1f, 0x46, 0xfb, 0x4e, 0xe2, // auth tag - 0x75, 0x1e, 0x53, 0xbf, 0xfc, 0x7e, 0x65, 0x16, // auth tag - })); _ = try client.stream.readPlaintext(); try std.testing.expect(client.stream.closed); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 82d48d42e29a..14a90f805abe 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -25,14 +25,14 @@ pub fn Client(comptime StreamType: type) type { }; var res = Self{ .stream = stream_, .options = options }; - var state = State{ .send_hello = KeyPairs.init() }; - while (state != .sent_finished) state = try res.advance(state); + var command = Command{ .send_hello = KeyPairs.init() }; + while (command != .none) command = try res.next(command); return res; } - /// Execute command and return next one. - pub fn advance(self: *Self, command: State) !State { + /// Executes handshake command and returns next one. + pub fn next(self: *Self, command: Command) !Command { var stream = &self.stream; switch (command) { .send_hello => |key_pairs| { @@ -87,14 +87,19 @@ pub fn Client(comptime StreamType: type) type { try stream.expectInnerPlaintext(.handshake, .finished); try self.recv_finished(digest); + return .{ .send_change_cipher_spec = {} }; + }, + .send_change_cipher_spec => { + try stream.changeCipherSpec(); + return .{ .send_finished = {} }; }, .send_finished => { try self.send_finished(); - return .{ .sent_finished = {} }; + return .{ .none = {} }; }, - .sent_finished => return .{ .sent_finished = {} }, + .none => return .{ .none = {} }, } } @@ -211,7 +216,7 @@ pub fn Client(comptime StreamType: type) type { shared_key = &mult.affineCoordinates().x.toBytes(.big); }, // Server sent us back unknown key. That's weird because we only request known ones, - // but we can try for another. + // but we can keep iterating for another. else => { try r.skipBytes(key_size, .{}); }, @@ -259,6 +264,7 @@ pub fn Client(comptime StreamType: type) type { try r.readNoEof(context[0..context_len]); var first: ?crypto.Certificate.Parsed = null; + errdefer if (first) |f| allocator.free(f.certificate.buffer); var prev: Certificate.Parsed = undefined; var verified = false; const now_sec = std.time.timestamp(); @@ -346,20 +352,16 @@ pub fn Client(comptime StreamType: type) type { const Hash = SchemeHash(comptime_scheme); const rsa = Certificate.rsa; - const components = rsa.PublicKey.parseDer(cert.pubKey()) catch - return stream.writeError(.decode_error); - const exponent = components.exponent; - const modulus = components.modulus; - switch (modulus.len) { + const key = rsa.PublicKey.fromDer(cert.pubKey()) catch + return stream.writeError(.bad_certificate); + switch (key.n.bits() / 8) { inline 128, 256, 512 => |modulus_len| { - const key = rsa.PublicKey.fromBytes(exponent, modulus) catch - return stream.writeError(.bad_certificate); const sig = rsa.PSSSignature.fromBytes(modulus_len, sig_bytes); rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash) catch return stream.writeError(.decode_error); }, else => { - return error.TlsBadRsaSignatureBitCount; + return stream.writeError(.bad_certificate); }, } }, @@ -377,7 +379,7 @@ pub fn Client(comptime StreamType: type) type { sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); }, else => { - return error.TlsBadSignatureScheme; + return stream.writeError(.bad_certificate); }, } } @@ -388,7 +390,6 @@ pub fn Client(comptime StreamType: type) type { const cipher = stream.cipher.handshake; switch (cipher) { - .empty_renegotiation_info_scsv => return stream.writeError(.decode_error), inline else => |p| { const P = @TypeOf(p); const expected = &tls.hmac(P.Hmac, digest, p.server_finished_key); @@ -405,8 +406,6 @@ pub fn Client(comptime StreamType: type) type { const handshake_hash = stream.transcript_hash.?.peek(); - try stream.changeCipherSpec(); - const verify_data = switch (stream.cipher.handshake) { inline .aes_128_gcm_sha256, .aes_256_gcm_sha384, @@ -452,7 +451,6 @@ pub fn Client(comptime StreamType: type) type { }, .key_update => { switch (stream.cipher.application) { - .empty_renegotiation_info_scsv => {}, inline else => |*p| { const P = @TypeOf(p.*); const server_secret = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); @@ -465,7 +463,6 @@ pub fn Client(comptime StreamType: type) type { const update = try stream.read(tls.KeyUpdate); if (update == .update_requested) { switch (stream.cipher.application) { - .empty_renegotiation_info_scsv => {}, inline else => |*p| { const P = @TypeOf(p.*); const client_secret = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); @@ -636,14 +633,15 @@ fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { }; } -/// A single `send` or `recv`. Allows for testing `advance`. -pub const State = union(enum) { +/// A command to send or receive a single message. Allows testing `advance` on a single thread. +pub const Command = union(enum) { send_hello: KeyPairs, recv_hello: KeyPairs, recv_encrypted_extensions: void, recv_certificate_or_finished: void, recv_certificate_verify: Certificate.Parsed, recv_finished: void, + send_change_cipher_spec: void, send_finished: void, - sent_finished: void, + none: void, }; diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index dba2a7eb8408..6690b88763e6 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -30,26 +30,62 @@ pub fn Server(comptime StreamType: type) type { var res = Self{ .stream = stream_, .options = options }; const client_hello = try res.recv_hello(&stream_); _ = client_hello; - // { - // var random_buffer: [32]u8 = undefined; - // crypto.random.bytes(&random_buffer); - // const key_pair = crypto.dh.X25519.KeyPair.create(random_buffer) catch |err| switch (err) { - // error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. - // }; - // try res.send_hello(key_pair); - // } + + var command = Command{ .recv_hello = {} }; + while (command != .none) command = try res.next(command); return res; } - const ClientHello = struct { - random: [32]u8, - session_id_len: u8, - session_id: [32]u8, - cipher_suite: tls.CipherSuite, - key_share: tls.KeyShare, - sig_scheme: ?tls.SignatureScheme, - }; + /// Executes handshake command and returns next one. + pub fn next(self: *Self, command: Command) !Command { + var stream = &self.stream; + + switch (command) { + .recv_hello => { + const client_hello = try self.recv_hello(); + + return .{ .send_hello = client_hello }; + }, + .send_hello => |client_hello| { + try self.send_hello(client_hello); + + // > if the client sends a non-empty session ID, + // > the server MUST send the change_cipher_spec + if (client_hello.session_id_len > 0) return .{ .send_change_cipher_spec = {} }; + + return .{ .send_encrypted_extensions = {} }; + }, + .send_change_cipher_spec => { + try stream.changeCipherSpec(); + + return .{ .send_encrypted_extensions = {} }; + }, + .send_encrypted_extensions => { + try self.send_encrypted_extensions(); + + return .{ .send_certificate = {} }; + }, + .send_certificate => { + try self.send_certificate(); + + return .{ .send_certificate_verify = {} }; + }, + .send_certificate_verify => { + try self.send_certificate_verify(); + return .{ .send_finished = {} }; + }, + .send_finished => { + try self.send_finished(); + return .{ .recv_finished = {} }; + }, + .recv_finished => { + try self.recv_finished(); + return .{ .none = {} }; + }, + .none => return .{ .none = {} }, + } + } pub fn recv_hello(self: *Self) !ClientHello { var stream = &self.stream; @@ -132,10 +168,17 @@ pub fn Server(comptime StreamType: type) type { } } - if (tls_version == null) return stream.writeError(.protocol_version); + if (tls_version != .tls_1_3) return stream.writeError(.protocol_version); if (key_share == null) return stream.writeError(.missing_extension); if (ec_point_format == null) return stream.writeError(.missing_extension); + var server_random: [32]u8 = undefined; + crypto.random.bytes(&server_random); + + const key_pair = .{ + .x25519 = crypto.dh.X25519.KeyPair.create(server_random) catch unreachable, + }; + return .{ .random = client_random, .session_id_len = session_id_len, @@ -143,33 +186,32 @@ pub fn Server(comptime StreamType: type) type { .cipher_suite = cipher_suite, .key_share = key_share.?, .sig_scheme = sig_scheme, + .server_random = server_random, + .server_pair = key_pair, }; } - /// `key_pair`'s active member MUST match `client_hello.key_share` - pub fn send_hello(self: *Self, client_hello: ClientHello, key_pair: KeyPair) !void { + pub fn send_hello(self: *Self, client_hello: ClientHello) !void { var stream = &self.stream; + const key_pair = client_hello.server_pair; const hello = tls.ServerHello{ - .random = key_pair.random, + .random = client_hello.server_random, .session_id = &client_hello.session_id, .cipher_suite = client_hello.cipher_suite, .extensions = &.{ .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, - .{ .key_share = &[_]tls.KeyShare{key_pair.pair.toKeyShare()} }, + .{ .key_share = &[_]tls.KeyShare{key_pair.toKeyShare()} }, }, }; stream.version = .tls_1_2; _ = try stream.write(tls.Handshake, .{ .server_hello = hello }); try stream.flush(); - // > if the client sends a non-empty session ID, the server MUST send the change_cipher_spec - if (hello.session_id.len > 0) try stream.changeCipherSpec(); - const shared_key = switch (client_hello.key_share) { .x25519_kyber768d00 => |ks| brk: { const T = tls.NamedGroupT(.x25519_kyber768d00); - const pair: tls.X25519Kyber768Draft.KeyPair = key_pair.pair.x25519_kyber768d00; + const pair: tls.X25519Kyber768Draft.KeyPair = key_pair.x25519_kyber768d00; const shared_point = T.X25519.scalarmult( ks.x25519, pair.x25519.secret_key, @@ -182,14 +224,14 @@ pub fn Server(comptime StreamType: type) type { }, .x25519 => |ks| brk: { const shared_point = tls.NamedGroupT(.x25519).scalarmult( - key_pair.pair.x25519.secret_key, + key_pair.x25519.secret_key, ks, ) catch return stream.writeError(.decrypt_error); break :brk &shared_point; }, .secp256r1 => |ks| brk: { const mul = ks.p.mulPublic( - key_pair.pair.secp256r1.secret_key.bytes, + key_pair.secp256r1.secret_key.bytes, .big, ) catch return stream.writeError(.decrypt_error); break :brk &mul.affineCoordinates().x.toBytes(.big); @@ -201,12 +243,33 @@ pub fn Server(comptime StreamType: type) type { const handshake_cipher = tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, hello_hash,) catch return stream.writeError(.illegal_parameter); stream.cipher = .{ .handshake = handshake_cipher }; + } - stream.content_type = .handshake; + pub fn send_encrypted_extensions(self: *Self) !void { + var stream = &self.stream; _ = try stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); try stream.flush(); + } + + pub fn send_certificate(self: *Self) !void { + var stream = &self.stream; + _ = try self.stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); + try stream.flush(); + } + + pub fn send_certificate_verify(self: *Self) !void { + var stream = &self.stream; - _ = try stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); + const digest = stream.transcript_hash.?.peek(); + const sig_content = tls.sigContent(digest); + + const signature = sig_content; + // const signature = rsa.encrypt(256, sig_content, parsed.pubKey()) catch return stream.writeError(.internal_error); + + _ = try self.stream.write(tls.Handshake, .{ .certificate_verify = tls.CertificateVerify{ + .algorithm = .rsa_pss_rsae_sha256, + .signature = signature, + }}); try stream.flush(); } @@ -239,7 +302,6 @@ pub fn Server(comptime StreamType: type) type { ); const expected = switch (stream.cipher.handshake) { - .empty_renegotiation_info_scsv => return stream.writeError(.decode_error), inline else => |p| brk: { const P = @TypeOf(p); const digest = stream.transcript_hash.?.peek(); @@ -265,9 +327,30 @@ pub const Options = struct { /// List of potential cipher suites in descending order of preference. cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, certificate: tls.Certificate, + certificate_key: []const u8, +}; + +/// A command to send or receive a single message. Allows testing `advance` on a single thread. +pub const Command = union(enum) { + recv_hello: void, + send_hello: ClientHello, + send_change_cipher_spec: void, + send_encrypted_extensions: void, + send_certificate: void, + send_certificate_verify: void, + send_finished: void, + recv_finished: void, + none: void, }; -pub const KeyPair = struct { +pub const ClientHello = struct { random: [32]u8, - pair: tls.KeyPair, + session_id_len: u8, + session_id: [32]u8, + cipher_suite: tls.CipherSuite, + key_share: tls.KeyShare, + sig_scheme: ?tls.SignatureScheme, + server_random: [32]u8, + /// active member MUST match `key_share` + server_pair: tls.KeyPair, }; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 451832b16ed7..00ee206c7757 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -216,7 +216,7 @@ pub const Connection = struct { read_buf: [buffer_size]u8 = undefined, write_buf: [buffer_size]u8 = undefined, - pub const buffer_size = @sizeOf(std.crypto.tls.Plaintext) + std.crypto.tls.Plaintext.max_length; + pub const buffer_size = std.crypto.tls.Plaintext.size + std.crypto.tls.Plaintext.max_length; const BufferSize = std.math.IntFittingRange(0, buffer_size); pub const Protocol = enum { plain, tls }; From e3185ddb86b27089658ca9e6e3fe30c778cdd90f Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Fri, 15 Mar 2024 21:49:07 -0400 Subject: [PATCH 09/17] add rsa sign functions, start on std.io.GenericStream --- lib/std/crypto/Certificate.zig | 203 ++++++- lib/std/crypto/testdata/key.der | Bin 294 -> 1191 bytes lib/std/crypto/tls.zig | 722 +++---------------------- lib/std/crypto/tls/Client.zig | 921 ++++++++++++++++---------------- lib/std/crypto/tls/Server.zig | 639 ++++++++++++---------- lib/std/crypto/tls/Stream.zig | 547 +++++++++++++++++++ lib/std/http/Client.zig | 72 +-- lib/std/http/Server.zig | 27 +- lib/std/io.zig | 48 ++ lib/std/io/Stream.zig | 36 ++ lib/std/net.zig | 47 +- 11 files changed, 1740 insertions(+), 1522 deletions(-) create mode 100644 lib/std/crypto/tls/Stream.zig create mode 100644 lib/std/io/Stream.zig diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 853aa6bfb382..1b483f28acaa 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -969,7 +969,28 @@ pub const rsa = struct { try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash); } - fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) !void { + pub fn sign( + comptime modulus_len: usize, + msg: []const u8, + comptime Hash: type, + private_key: PrivateKey, + salt: [Hash.digest_length]u8, + ) ![modulus_len]u8 { + const mod_bits = modulus_len * 8; + + var out: [modulus_len]u8 = undefined; + const em = try EMSA_PSS_ENCODE(&out, msg, mod_bits - 1, Hash, salt); + + return try private_key.decrypt(modulus_len, em[0..modulus_len].*); + } + + fn EMSA_PSS_VERIFY( + msg: []const u8, + em: []const u8, + emBit: usize, + sLen: usize, + comptime Hash: type, + ) !void { // 1. If the length of M is greater than the input limitation for // the hash function (2^61 - 1 octets for SHA-1), output // "inconsistent" and stop. @@ -1103,6 +1124,99 @@ pub const rsa = struct { return out[0..len]; } + + fn EMSA_PSS_ENCODE( + out: []u8, + msg: []const u8, + emBit: usize, + comptime Hash: type, + salt: [Hash.digest_length]u8, + ) ![]u8 { + // 1. If the length of M is greater than the input limitation for + // the hash function (2^61 - 1 octets for SHA-1), output + // "inconsistent" and stop. + // All the cryptographic hash functions in the standard library have a limit of >= 2^61 - 1. + // Even then, this check is only there for paranoia. In the context of TLS certifcates, emBit cannot exceed 4096. + if (emBit >= 1 << 61) return error.InvalidSignature; + // emLen = \c2yyeil(emBits/8) + const emLen = ((emBit - 1) / 8) + 1; + + // 2. Let mHash = Hash(M), an octet string of length hLen. + var mHash: [Hash.digest_length]u8 = undefined; + Hash.hash(msg, &mHash, .{}); + + // 3. If emLen < hLen + sLen + 2, output "encoding error" and stop. + if (emLen < Hash.digest_length + salt.len + 2) { + return error.EncodingError; + } + + // 4. Generate a random octet string salt of length sLen; if sLen = + // 0, then salt is the empty string. + // 5. Let + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; + // M' is an octet string of length 8 + hLen + sLen with eight + // initial zero octets. + var m_p: [8 + Hash.digest_length + salt.len]u8 = undefined; + + @memcpy(m_p[0..8], &([_]u8{0} ** 8)); + @memcpy(m_p[8 .. 8 + mHash.len], &mHash); + @memcpy(m_p[(8 + Hash.digest_length)..], &salt); + + // 6. Let H = Hash(M'), an octet string of length hLen. + var hash: [Hash.digest_length]u8 = undefined; + Hash.hash(&m_p, &hash, .{}); + + // 7. Generate an octet string PS consisting of emLen - sLen - hLen + // - 2 zero octets. The length of PS may be 0. + const ps_len = emLen - salt.len - Hash.digest_length - 2; + + // 8. Let DB = PS || 0x01 || salt; DB is an octet string of length + // emLen - hLen - 1. + const mgf_len = emLen - Hash.digest_length - 1; + var db_buf: [512]u8 = undefined; + var db = db_buf[0 .. emLen - Hash.digest_length - 1]; + var i: usize = 0; + while (i < ps_len) : (i += 1) { + db[i] = 0x00; + } + db[i] = 0x01; + i += 1; + @memcpy(db[i..], &salt); + + // 9. Let dbMask = MGF(H, emLen - hLen - 1). + var mgf_out_buf: [512]u8 = undefined; + const mgf_out = mgf_out_buf[0 .. ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length]; + const dbMask = try MGF1(Hash, mgf_out, &hash, mgf_len); + + // 10. Let maskedDB = DB \xor dbMask. + i = 0; + while (i < db.len) : (i += 1) { + db[i] = db[i] ^ dbMask[i]; + } + + // 11. Set the leftmost 8emLen - emBits bits of the leftmost octet + // in maskedDB to zero. + const zero_bits = emLen * 8 - emBit; + var mask: u8 = 0; + i = 0; + while (i < 8 - zero_bits) : (i += 1) { + mask = mask << 1; + mask += 1; + } + db[0] = db[0] & mask; + + // 12. Let EM = maskedDB || H || 0xbc. + i = 0; + @memcpy(out[0..db.len], db); + i += db.len; + @memcpy(out[i .. i + hash.len], &hash); + i += hash.len; + out[i] = 0xbc; + i += 1; + + // 13. Output EM. + return out[0..i]; + } }; pub const PublicKey = struct { @@ -1111,11 +1225,11 @@ pub const rsa = struct { /// public exponent e: Fe, - pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) !PublicKey { + pub fn fromBytes(mod: []const u8, exp: []const u8) !PublicKey { // Reject modulus below 512 bits. // 512-bit RSA was factored in 1999, so this limit barely means anything, // but establish some limit now to ratchet in what we can. - const _n = Modulus.fromBytes(modulus_bytes, .big) catch return error.CertificatePublicKeyInvalid; + const _n = Modulus.fromBytes(mod, .big) catch return error.CertificatePublicKeyInvalid; if (_n.bits() < 512) return error.CertificatePublicKeyInvalid; // Exponent must be odd and greater than 2. @@ -1124,45 +1238,88 @@ pub const rsa = struct { // unlikely that exponents larger than 32 bits are being used for anything // Windows commonly does. // [1] https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-rsapubkey - if (pub_bytes.len > 4) return error.CertificatePublicKeyInvalid; - const _e = Fe.fromBytes(_n, pub_bytes, .big) catch return error.CertificatePublicKeyInvalid; + if (exp.len > 4) return error.CertificatePublicKeyInvalid; + const _e = Fe.fromBytes(_n, exp, .big) catch return error.CertificatePublicKeyInvalid; if (!_e.isOdd()) return error.CertificatePublicKeyInvalid; const e_v = _e.toPrimitive(u32) catch return error.CertificatePublicKeyInvalid; if (e_v < 2) return error.CertificatePublicKeyInvalid; - return .{ - .n = _n, - .e = _e, - }; + return .{ .n = _n, .e = _e }; } // RFC8017 Appendix A.1.1 - pub fn fromDer(pub_key: []const u8) !PublicKey { - const pub_key_seq = try der.Element.parse(pub_key, 0); - if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; - const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); + pub fn fromDer(bytes: []const u8) !PublicKey { + const seq = try der.Element.parse(bytes, 0); + if (seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; + const modulus_elem = try der.Element.parse(bytes, seq.slice.start); if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; - const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end); + const exponent_elem = try der.Element.parse(bytes, modulus_elem.slice.end); if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType; // Skip over meaningless zeroes in the modulus. - const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end]; - const modulus_offset = for (modulus_raw, 0..) |byte, i| { - if (byte != 0) break i; - } else modulus_raw.len; - const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end]; + const modulus_raw = bytes[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = std.mem.indexOfNone(u8, modulus_raw, &[_]u8{0}) orelse modulus_raw.len; const modulus = modulus_raw[modulus_offset..]; + const exponent = bytes[exponent_elem.slice.start..exponent_elem.slice.end]; - return try fromBytes(exponent, modulus); + return try fromBytes(modulus, exponent); } - fn encrypt(public_key: PublicKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { - const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong; - const e = public_key.n.powPublic(m, public_key.e) catch unreachable; + fn encrypt(self: PublicKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { + const m = Fe.fromBytes(self.n, &msg, .big) catch return error.MessageTooLong; + const e = self.n.powPublic(m, self.e) catch unreachable; var res: [modulus_len]u8 = undefined; e.toBytes(&res, .big) catch unreachable; return res; } }; + + pub const PrivateKey = struct { + public: PublicKey, + /// private exponent + d: Fe, + + pub fn fromBytes(mod: []const u8, public: []const u8, private: []const u8) !PrivateKey { + const _public = try PublicKey.fromBytes(mod, public); + + const _d = Fe.fromBytes(_public.n, private, .big) catch return error.CertificatePrivateKeyInvalid; + if (!_d.isOdd()) return error.CertificatePrivateKeyInvalid; + + return .{ .public = _public, .d = _d }; + } + + // RFC8017 Appendix A.1.2 + pub fn fromDer(bytes: []const u8) !PrivateKey { + const seq = try der.Element.parse(bytes, 0); + if (seq.identifier.tag != .sequence) return error.PrivateKeyWrongDataType; + + // We're just interested in the first 3 fields which don't vary by version + const version_elem = try der.Element.parse(bytes, seq.slice.start); + if (version_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType; + + const modulus_elem = try der.Element.parse(bytes, version_elem.slice.end); + if (modulus_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType; + const pub_exponent_elem = try der.Element.parse(bytes, modulus_elem.slice.end); + if (pub_exponent_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType; + const priv_exponent_elem = try der.Element.parse(bytes, pub_exponent_elem.slice.end); + if (priv_exponent_elem.identifier.tag != .integer) return error.PrivateKeyFieldWrongDataType; + // Skip over meaningless zeroes in the modulus. + const modulus_raw = bytes[modulus_elem.slice.start..modulus_elem.slice.end]; + const modulus_offset = std.mem.indexOfNone(u8, modulus_raw, &[_]u8{0}) orelse modulus_raw.len; + const modulus = modulus_raw[modulus_offset..]; + const pub_exponent = bytes[pub_exponent_elem.slice.start..pub_exponent_elem.slice.end]; + const priv_exponent = bytes[priv_exponent_elem.slice.start..priv_exponent_elem.slice.end]; + + return try fromBytes(modulus, pub_exponent, priv_exponent); + } + + fn decrypt(self: PrivateKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { + const m = Fe.fromBytes(self.public.n, &msg, .big) catch return error.MessageTooLong; + const e = self.public.n.pow(m, self.d) catch unreachable; + var res: [modulus_len]u8 = undefined; + try e.toBytes(&res, .big); + return res; + } + }; }; const use_vectors = @import("builtin").zig_backend != .stage2_x86_64; diff --git a/lib/std/crypto/testdata/key.der b/lib/std/crypto/testdata/key.der index ce271c001235f3f95d927ba04b65e9366442ba17..9e4f1334d16264ca8accea1d9f7212da6a14554a 100644 GIT binary patch delta 942 zcmV;f15x~@0;dTFFoFc50s#QA90~z{0)hbmVS71J9cG^gdO7qra{M0j{KnM;d)X4| zDh$$5puK)<(|%N=kJIS5RL4bJb#zDd;>Vn{)X3fy;mX=(OjWQ=^8)r|>p{8`>I2BL z7G&f4TJ9uTTbuWvdjpoOPq1k)g+g{=Rb*1Z;Uaedc>>;gWfTSv__%efoKYfwvM1UP zsS)5?p=Q_v@2Cvhj6wIqWvV1jLVN~&No9z1#6`+uz+%`)9puvC8dRmKmh`QfdJu^8>tD?qf!FCKQ)Q*tIy&1ETk%X z-G^=e*5^?FQfW4j{`A>xwjBNfzOqrY0@H+&2KH@}!ARx(U@l~UYlR>iBKJ2H${N^s z-A13RJC!xL;Ogz)(r4rGM2{S7w$Y?VNNR+5>ZRF$oWcryh&arHHJi6s7g-`0Dx>>O0zV}G!2mkbOQ-yu{=4O|BV%p8Sr z{rQ_^#~t`>f?9e9bV^_}9}jMUVFxw}5)x&pGh02vD*60E<8!xvMzC9RyF2Wc*XeY$ zu?>g-0o>ji*#d!pK+K{bty(ecFmngOsE!*$QXIc*#4#;Dc-CvyKlBFC z9mnm+3!y6b<5 crypto.sign.ecdsa.EcdsaP256Sha256, .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, @@ -518,7 +513,7 @@ pub const SignatureScheme = enum(u16) { }; } - fn Hash(comptime self: @This()) type { + pub fn Hash(comptime self: @This()) type { return switch (self) { .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, @@ -527,7 +522,7 @@ pub const SignatureScheme = enum(u16) { }; } - fn Eddsa(comptime self: @This()) type { + pub fn Eddsa(comptime self: @This()) type { return switch (self) { .ed25519 => crypto.sign.Ed25519, else => @compileError("bad scheme"), @@ -653,7 +648,7 @@ pub const KeyShare = union(NamedGroup) { const group = try stream.read(NamedGroup); const len = try stream.read(u16); switch (group) { - .x25519_kyber768d00 => { + .x25519_kyber768d00 => { const T = X25519Kyber768Draft.Kyber768.PublicKey; var res = Self{ .x25519_kyber768d00 = undefined }; @@ -1046,524 +1041,6 @@ pub const EcPointFormat = enum(u8) { /// RFC 5246 S7.1 pub const ChangeCipherSpec = enum(u8) { change_cipher_spec = 1, _ }; -/// This is an example of the type that is needed by the read and write -/// functions. It can have any fields but it must at least have these -/// functions. -/// -/// Note that `std.net.Stream` conforms to this interface. -/// -/// This declaration serves as documentation only. -pub const StreamInterface = struct { - /// Can be any error set. - pub const ReadError = error{}; - - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. Reaching the end of - /// a stream is not an error condition. - pub fn readAll(this: @This(), buffer: []u8) ReadError!usize { - _ = .{ this, buffer }; - @panic("unimplemented"); - } - - /// Can be any error set. - pub const WriteError = error{}; - - /// Returns the number of bytes written, which may be less than the buffer space provided. - pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// The `iovecs` parameter is mutable in case this function needs to mutate - /// the fields in order to handle partial writes from the underlying layer. - pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!void { - // This can be implemented in terms of writev, or specialized if desired. - _ = .{ this, iovecs }; - @panic("unimplemented"); - } -}; - -/// Abstraction over TLS record layer (RFC 8446 S5). StreamType MUST satisfy `StreamInterface`. -/// Cannot read and write at the same time. -/// -/// Handles: -/// * Fragmentation -/// * Encryption and decryption of handshake and application data messages -/// * Reading and writing prefix length arrays -/// * Alerts -pub fn Stream(comptime fragment_size: usize, comptime StreamType: type) type { - // TODO: Support RFC 6066 MaxFragmentLength and give fragment_size option to Client+Server. - if (fragment_size > std.math.maxInt(u16)) @compileError("choose a smaller fragment_size"); - - return struct { - stream: *StreamType, - /// Used for both reading and writing. Cannot be doing both at the same time. - /// Stores plaintext or ciphertext, but not Plaintext headers. - buffer: [fragment_size]u8 = undefined, - /// Unread or unwritten view of `buffer`. May contain multiple handshakes. - view: []const u8 = "", - - /// When sending this is the record type that will be flushed. - /// When receiving this is the next fragment's expected record type. - content_type: ContentType = .handshake, - /// When sending this is the flushed version. - version: Version = .tls_1_0, - /// When receiving a handshake message will be expected with this type. - handshake_type: ?HandshakeType = .client_hello, - - /// Used to decrypt .application_data messages. - /// Used to encrypt messages that aren't alert or change_cipher_spec. - cipher: Cipher = .none, - - /// True when we send or receive a close_notify alert. - closed: bool = false, - - /// True if we're being used as a client. This changes: - /// * Certain shared struct formats (like Extension) - /// * Which ciphers are used for encoding/decoding handshake and application messages. - is_client: bool, - - /// When > 0 won't actually do anything with writes. Used to discover prefix lengths. - nocommit: usize = 0, - - /// Client and server implementations can set this. While set `readPlaintext` and `flush` - /// handshake messages will update the hash. - transcript_hash: ?*MultiHash, - - const Self = @This(); - - const Cipher = union(enum) { - none: void, - application: ApplicationCipher, - handshake: HandshakeCipher, - }; - - pub const ReadError = StreamType.ReadError || Error || error{EndOfStream}; - pub const WriteError = StreamType.WriteError || error{TlsEncodeError}; - - fn ciphertextOverhead(self: Self) usize { - return switch (self.cipher) { - inline .application, .handshake => |c| switch (c) { - inline else => |t| @TypeOf(t).AEAD.tag_length + @sizeOf(ContentType), - }, - else => 0, - }; - } - - fn maxFragmentSize(self: Self) usize { - return fragment_size - self.ciphertextOverhead(); - } - - const EncryptionMethod = enum { none, handshake, application }; - fn encryptionMethod(self: Self, content_type: ContentType) EncryptionMethod { - switch (content_type) { - .alert, .change_cipher_spec => {}, - else => { - if (self.cipher == .application) return .application; - if (self.cipher == .handshake) return .handshake; - }, - } - return .none; - } - - pub fn flush(self: *Self) WriteError!void { - if (self.view.len == 0) return; - if (self.transcript_hash) |t| { - if (self.content_type == .handshake) t.update(self.view); - } - - var plaintext = Plaintext{ - .type = self.content_type, - .version = self.version, - .len = @intCast(self.view.len), - }; - - var header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); - var aead: []const u8 = ""; - switch (self.cipher) { - .none => {}, - inline .application, .handshake => |*cipher| { - plaintext.type = .application_data; - plaintext.len += @intCast(self.ciphertextOverhead()); - header = Encoder.encode(Plaintext, plaintext); - switch (cipher.*) { - inline else => |*c| { - std.debug.assert(self.view.ptr == &self.buffer); - self.buffer[self.view.len] = @intFromEnum(self.content_type); - self.view = self.buffer[0 .. self.view.len + 1]; - aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); - }, - } - } - } - - var iovecs = [_]std.os.iovec_const{ - .{ .iov_base = &header, .iov_len = header.len }, - .{ .iov_base = self.view.ptr, .iov_len = self.view.len }, - .{ .iov_base = aead.ptr, .iov_len = aead.len }, - }; - try self.stream.writevAll(&iovecs); - self.view = self.buffer[0..0]; - } - - /// Flush a change cipher spec message to the underlying stream. - pub fn changeCipherSpec(self: *Self) WriteError!void { - self.version = .tls_1_2; - - const plaintext = Plaintext{ - .type = .change_cipher_spec, - .version = self.version, - .len = 1, - }; - const msg = [_]u8{1}; - const header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); - var iovecs = [_]std.os.iovec_const{ - .{ .iov_base = &header, .iov_len = header.len }, - .{ .iov_base = &msg, .iov_len = msg.len }, - }; - try self.stream.writevAll(&iovecs); - } - - /// Write an alert to stream and call `close_notify` after. Returns Zig error. - pub fn writeError(self: *Self, err: Alert.Description) Error { - const alert = Alert{ .level = .fatal, .description = err }; - - self.view = self.buffer[0..0]; - self.content_type = .alert; - _ = self.write(Alert, alert) catch {}; - self.flush() catch {}; - - self.close(); - @panic("writeError"); - // return err.toError(); - } - - pub fn close(self: *Self) void { - const alert = Alert{ .level = .fatal, .description = .close_notify }; - _ = self.write(Alert, alert) catch {}; - self.content_type = .alert; - self.flush() catch {}; - self.closed = true; - } - - /// Write bytes to `stream`, potentially flushing once `self.buffer` is full. - pub fn writeBytes(self: *Self, bytes: []const u8) WriteError!usize { - if (self.nocommit > 0) return bytes.len; - - const available = self.buffer.len - self.view.len; - const to_consume = bytes[0..@min(available, bytes.len)]; - - @memcpy(self.buffer[self.view.len..][0..bytes.len], to_consume); - self.view = self.buffer[0 .. self.view.len + to_consume.len]; - - if (self.view.len == self.buffer.len) try self.flush(); - - return to_consume.len; - } - - pub fn writeAll(self: *Self, bytes: []const u8) WriteError!usize { - var index: usize = 0; - while (index != bytes.len) { - index += try self.writeBytes(bytes[index..]); - } - return index; - } - - pub fn writeArray(self: *Self, comptime PrefixT: type, comptime T: type, values: []const T) WriteError!usize { - var res: usize = 0; - for (values) |v| res += self.length(T, v); - - if (PrefixT != void) { - if (res > std.math.maxInt(PrefixT)) { - self.close(); - return error.TlsEncodeError; // Prefix length overflow - } - res += try self.write(PrefixT, @intCast(res)); - } - - for (values) |v| _ = try self.write(T, v); - - return res; - } - - pub fn write(self: *Self, comptime T: type, value: T) WriteError!usize { - switch (@typeInfo(T)) { - .Int, .Enum => { - const encoded = Encoder.encode(T, value); - return try self.writeAll(&encoded); - }, - .Struct, .Union => { - return try T.write(value, self); - }, - .Void => return 0, - else => @compileError("cannot write " ++ @typeName(T)), - } - } - - pub fn length(self: *Self, comptime T: type, value: T) usize { - if (T == void) return 0; - self.nocommit += 1; - defer self.nocommit -= 1; - return self.write(T, value) catch unreachable; - } - - pub fn arrayLength( - self: *Self, - comptime PrefixT: type, - comptime T: type, - values: []const T, - ) usize { - var res: usize = if (PrefixT == void) 0 else @divExact(@typeInfo(PrefixT).Int.bits, 8); - for (values) |v| res += self.length(T, v); - return res; - } - - /// Reads bytes from `view`, potentially reading more fragments from `stream`. - /// - /// A return value of 0 indicates EOF. - pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { - // > Any data received after a closure alert has been received MUST be ignored. - if (self.eof()) return 0; - - if (self.view.len == 0) try self.expectInnerPlaintext(self.content_type, self.handshake_type); - - var bytes_read: usize = 0; - - for (buffers) |b| { - var bytes_read_buffer: usize = 0; - while (bytes_read_buffer != b.iov_len) { - const to_read = @min(b.iov_len, self.view.len); - if (to_read == 0) return bytes_read; - - @memcpy(b.iov_base[0..to_read], self.view[0..to_read]); - - self.view = self.view[to_read..]; - bytes_read_buffer += to_read; - bytes_read += bytes_read_buffer; - } - } - - return bytes_read; - } - - /// Reads bytes from `view`, potentially reading more fragments from `stream`. - /// A return value of 0 indicates EOF. - pub fn readBytes(self: *Self, buf: []u8) ReadError!usize { - const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; - return try self.readv(&buffers); - } - - /// Reads plaintext from `stream` into `buffer` and updates `view`. - /// Skips non-fatal alert and change_cipher_spec messages. - /// Will decrypt according to `encryptionMethod` if receiving application_data message. - pub fn readPlaintext(self: *Self) ReadError!Plaintext { - std.debug.assert(self.view.len == 0); // last read should have completed - var plaintext_bytes: [Plaintext.size]u8 = undefined; - var n_read: usize = 0; - - while (true) { - n_read = try self.stream.readAll(&plaintext_bytes); - if (n_read != plaintext_bytes.len) return self.writeError(.decode_error); - - var res = Plaintext.init(plaintext_bytes); - if (res.len > Plaintext.max_length) return self.writeError(.record_overflow); - - self.view = self.buffer[0..res.len]; - n_read = try self.stream.readAll(@constCast(self.view)); - if (n_read != res.len) return self.writeError(.decode_error); - - const encryption_method = self.encryptionMethod(res.type); - if (encryption_method != .none) { - if (res.len < self.ciphertextOverhead()) return self.writeError(.decode_error); - - switch (self.cipher) { - inline .handshake, .application => |*cipher| { - switch (cipher.*) { - inline else => |*c| { - const C = @TypeOf(c.*); - const tag_len = C.AEAD.tag_length; - - const ciphertext = self.view[0 .. self.view.len - tag_len]; - const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; - const out: []u8 = @constCast(self.view[0..ciphertext.len]); - c.decrypt(ciphertext, &plaintext_bytes, tag, self.is_client, out) catch - return self.writeError(.bad_record_mac); - const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); - if (padding_start) |s| { - res.type = @enumFromInt(self.view[s]); - self.view = self.view[0..s]; - } else { - return self.writeError(.decode_error); - } - }, - } - }, - else => unreachable, - } - } - - switch (res.type) { - .alert => { - const level = try self.read(Alert.Level); - const description = try self.read(Alert.Description); - std.log.debug("TLS alert {} {}", .{ level, description }); - - if (description == .close_notify) { - self.closed = true; - return res; - } - if (level == .fatal) return self.writeError(.unexpected_message); - }, - // > An implementation may receive an unencrypted record of type - // > change_cipher_spec consisting of the single byte value 0x01 at any - // > time after the first ClientHello message has been sent or received - // > and before the peer's Finished message has been received and MUST - // > simply drop it without further processing. - .change_cipher_spec => { - if (!std.mem.eql(u8, self.view, &[_]u8{1})) { - return self.writeError(.unexpected_message); - } - }, - else => { - return res; - }, - } - } - } - - pub fn readInnerPlaintext(self: *Self) ReadError!InnerPlaintext { - var res: InnerPlaintext = .{ - .type = self.content_type, - .handshake_type = if (self.handshake_type) |h| h else undefined, - .len = undefined, - }; - if (self.view.len == 0) { - const plaintext = try self.readPlaintext(); - res.type = plaintext.type; - res.len = plaintext.len; - - self.content_type = res.type; - } - - if (res.type == .handshake) { - if (self.transcript_hash) |t| t.update(self.view[0..4]); - res.handshake_type = try self.read(HandshakeType); - res.len = try self.read(u24); - if (self.transcript_hash) |t| t.update(self.view[0..res.len]); - - self.handshake_type = res.handshake_type; - } - - return res; - } - - pub fn expectInnerPlaintext( - self: *Self, - expected_content: ContentType, - expected_handshake: ?HandshakeType, - ) ReadError!void { - const inner_plaintext = try self.readInnerPlaintext(); - if (expected_content != inner_plaintext.type) { - std.debug.print("expected {} got {}\n", .{ expected_content, inner_plaintext }); - return self.writeError(.unexpected_message); - } - if (expected_handshake) |expected| { - if (expected != inner_plaintext.handshake_type) return self.writeError(.decode_error); - } - } - - pub fn read(self: *Self, comptime T: type) ReadError!T { - comptime std.debug.assert(@sizeOf(T) < fragment_size); - switch (@typeInfo(T)) { - .Int => return self.reader().readInt(T, .big) catch |err| switch (err) { - error.EndOfStream => return self.writeError(.decode_error), - else => |e| return e, - }, - .Enum => |info| { - if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - const int = try self.read(info.tag_type); - return @enumFromInt(int); - }, - else => { - return T.read(self) catch |err| switch (err) { - error.TlsUnexpectedMessage => return self.writeError(.unexpected_message), - error.TlsBadRecordMac => return self.writeError(.bad_record_mac), - error.TlsRecordOverflow => return self.writeError(.record_overflow), - error.TlsHandshakeFailure => return self.writeError(.handshake_failure), - error.TlsBadCertificate => return self.writeError(.bad_certificate), - error.TlsUnsupportedCertificate => return self.writeError(.unsupported_certificate), - error.TlsCertificateRevoked => return self.writeError(.certificate_revoked), - error.TlsCertificateExpired => return self.writeError(.certificate_expired), - error.TlsCertificateUnknown => return self.writeError(.certificate_unknown), - error.TlsIllegalParameter => return self.writeError(.illegal_parameter), - error.TlsUnknownCa => return self.writeError(.unknown_ca), - error.TlsAccessDenied => return self.writeError(.access_denied), - error.TlsDecodeError => return self.writeError(.decode_error), - error.TlsDecryptError => return self.writeError(.decrypt_error), - error.TlsProtocolVersion => return self.writeError(.protocol_version), - error.TlsInsufficientSecurity => return self.writeError(.insufficient_security), - error.TlsInternalError => return self.writeError(.internal_error), - error.TlsInappropriateFallback => return self.writeError(.inappropriate_fallback), - error.TlsMissingExtension => return self.writeError(.missing_extension), - error.TlsUnsupportedExtension => return self.writeError(.unsupported_extension), - error.TlsUnrecognizedName => return self.writeError(.unrecognized_name), - error.TlsBadCertificateStatusResponse => return self.writeError(.bad_certificate_status_response), - error.TlsUnknownPskIdentity => return self.writeError(.unknown_psk_identity), - error.TlsCertificateRequired => return self.writeError(.certificate_required), - error.TlsNoApplicationProtocol => return self.writeError(.no_application_protocol), - error.TlsUnknown => |e| { - self.close(); - return e; - }, - else => return self.writeError(.decode_error), - }; - }, - } - } - - fn Iterator(comptime T: type) type { - return struct { - stream: *Self, - end: usize, - - pub fn next(self: *@This()) ReadError!?T { - const cur_offset = self.stream.buffer.len - self.stream.view.len; - if (cur_offset > self.end) return null; - return try self.stream.read(T); - } - }; - } - - pub fn iterator(self: *Self, comptime Len: type, comptime Tag: type) ReadError!Iterator(Tag) { - const offset = self.buffer.len - self.view.len; - const len = try self.read(Len); - return Iterator(Tag){ - .stream = self, - .end = offset + len, - }; - } - - pub fn extensions(self: *Self) ReadError!Iterator(Extension.Header) { - return self.iterator(u16, Extension.Header); - } - - pub fn eof(self: Self) bool { - return self.closed and self.view.len == 0; - } - - pub const Reader = std.io.Reader(*Self, ReadError, readBytes); - pub const Writer = std.io.Writer(*Self, WriteError, writeBytes); - - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } - - pub fn writer(self: *Self) Writer { - return .{ .context = self }; - } - }; -} - /// One of these potential hashes will be selected after receiving the other party's hello. /// /// We init them before sending any messages to avoid having to store our first message until the @@ -1616,61 +1093,6 @@ pub const MultiHash = struct { } }; -const Encoder = struct { - fn RetType(comptime T: type) type { - switch (@typeInfo(T)) { - .Int => |info| switch (info.bits) { - 8 => return [1]u8, - 16 => return [2]u8, - 24 => return [3]u8, - else => @compileError("unsupported int type: " ++ @typeName(T)), - }, - .Enum => |info| { - if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - return RetType(info.tag_type); - }, - .Struct => |info| { - var len: usize = 0; - inline for (info.fields) |f| len += @typeInfo(RetType(f.type)).Array.len; - return [len]u8; - }, - else => @compileError("don't know how to encode " ++ @tagName(T)), - } - } - fn encode(comptime T: type, value: T) RetType(T) { - return switch (@typeInfo(T)) { - .Int => |info| switch (info.bits) { - 8 => .{value}, - 16 => .{ - @as(u8, @truncate(value >> 8)), - @as(u8, @truncate(value)), - }, - 24 => .{ - @as(u8, @truncate(value >> 16)), - @as(u8, @truncate(value >> 8)), - @as(u8, @truncate(value)), - }, - else => @compileError("unsupported int type: " ++ @typeName(T)), - }, - .Enum => |info| encode(info.tag_type, @intFromEnum(value)), - .Struct => |info| brk: { - const Ret = RetType(T); - - var offset: usize = 0; - var res: Ret = undefined; - inline for (info.fields) |f| { - const encoded = encode(f.type, @field(value, f.name)); - @memcpy(res[offset..][0..encoded.len], &encoded); - offset += encoded.len; - } - - break :brk res; - }, - else => @compileError("cannot encode type " ++ @typeName(T)), - }; - } -}; - fn HandshakeCipherT(comptime suite: CipherSuite) type { return struct { pub const AEAD = suite.Aead(); @@ -1691,7 +1113,7 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { const Self = @This(); - fn encrypt( + pub fn encrypt( self: *Self, data: []const u8, additional: []const u8, @@ -1707,7 +1129,7 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { return res; } - fn decrypt( + pub fn decrypt( self: *Self, data: []const u8, additional: []const u8, @@ -1746,7 +1168,7 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { const Self = @This(); - fn encrypt( + pub fn encrypt( self: *Self, data: []const u8, additional: []const u8, @@ -1762,7 +1184,7 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { return res; } - fn decrypt( + pub fn decrypt( self: *Self, data: []const u8, additional: []const u8, @@ -1906,31 +1328,26 @@ const TestStream = struct { self.buffer.deinit(allocator); } - pub fn readAll(self: *Self, buffer: []u8) ReadError!usize { + pub fn read(self: *Self, buffer: []u8) ReadError!usize { try self.buffer.readFirst(buffer, buffer.len); return buffer.len; } - pub fn writev(self: *Self, iovecs: []const std.os.iovec_const) WriteError!usize { - var res: usize = 0; - for (iovecs) |i| { - const slice = i.iov_base[0..i.iov_len]; - try self.buffer.writeSlice(slice); - res += i.iov_len; - } - return res; - } - - pub fn writevAll(self: *Self, iovecs: []std.os.iovec_const) WriteError!void { - _ = try self.writev(iovecs); + pub fn write(self: *Self, bytes: []const u8) WriteError!usize { + try self.buffer.writeSlice(bytes); + return bytes.len; } pub fn peek(self: *Self, out: []u8) ReadError!void { const read_index = self.buffer.read_index; - _ = try self.readAll(out); + _ = try self.read(out); self.buffer.read_index = read_index; } + pub fn close(self: *Self) void { + _ = self; + } + pub fn expect(self: *Self, expected: []const u8) !void { var tmp_buf: [Plaintext.max_length]u8 = undefined; const buf = tmp_buf[0..self.buffer.len()]; @@ -1938,15 +1355,11 @@ const TestStream = struct { try std.testing.expectEqualSlices(u8, expected, buf); } -}; -const TestHasher = struct { - fn update(self: *@This(), bytes: []const u8) void { - _ = .{ self, bytes }; - } - fn peek(self: @This()) []const u8 { - _ = .{self}; - return ""; + const GenericStream = std.io.GenericStream(*Self, ReadError, read, WriteError, write, close); + + pub fn stream(self: *Self) GenericStream { + return .{ .context = self }; } }; @@ -1955,59 +1368,46 @@ test "tls client and server handshake, data, and close_notify" { var inner_stream = try TestStream.init(allocator); defer inner_stream.deinit(allocator); + const stream = inner_stream.stream(); - const host = "example.ulfheim.net"; var client_transcript: MultiHash = .{}; - var client = Client(@TypeOf(inner_stream)){ - .stream = Stream(Plaintext.max_length, TestStream){ - .stream = &inner_stream, + var client = Client{ + .stream = Stream{ + .stream = stream.any(), .is_client = true, .transcript_hash = &client_transcript, }, - .options = .{ .host = host, .ca_bundle = null, .allocator = allocator }, + .options = .{ .host = "localhost", .ca_bundle = null, .allocator = allocator }, }; - const server_der = @embedFile("./testdata/server.der"); - const server_key = @embedFile("./testdata/server.key"); + const server_cert = @embedFile("./testdata/cert.der"); + const server_key = @embedFile("./testdata/key.der"); + const server_rsa = try crypto.Certificate.rsa.PrivateKey.fromDer(server_key); var server_transcript: MultiHash = .{}; - var server = Server(@TypeOf(inner_stream)){ - .stream = Stream(Plaintext.max_length, TestStream){ - .stream = &inner_stream, + var server = Server{ + .stream = Stream{ + .stream = stream.any(), .is_client = false, .transcript_hash = &server_transcript, }, .options = .{ - // force this to use https://tls13.xargs.org/ as unit test for "server hello" onwards .cipher_suites = &[_]CipherSuite{.aes_256_gcm_sha384}, .certificate = .{ .entries = &[_]Certificate.Entry{ - .{ .data = server_der }, + .{ .data = server_cert }, } }, - .certificate_key = server_key, + .certificate_key = server_rsa, }, }; - const session_id = [_]u8{ - 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, - 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, - }; - const client_random = [_]u8{ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, - 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - }; - const server_random = [_]u8{ - 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, - 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, - }; - const client_x25519_seed = [_]u8{ - 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, - 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, - }; - const server_x25519_seed = [_]u8{ - 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, - 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, - }; + // Use these seeded values for reproducible handshake and application ciphertext. + const session_id: [32]u8 = ("session_id012345" ** 2).*; + const client_random: [32]u8 = ("client_random012" ** 2).*; + const server_random: [32]u8 = ("server_random012" ** 2).*; + const client_x25519_seed: [32]u8 = ("client_seed01234" ** 2).*; + const server_x25519_seed: [32]u8 = ("server_seed01234" ** 2).*; + const server_sig_salt: [MultiHash.max_digest_len]u8 = ("server_sig_salt0" ** 4).*; - const key_pairs = try client_mod.KeyPairs.initAdvanced( + const key_pairs = try Client.KeyPairs.initAdvanced( client_random, session_id, client_x25519_seed ++ client_x25519_seed, @@ -2015,11 +1415,11 @@ test "tls client and server handshake, data, and close_notify" { client_x25519_seed ++ [_]u8{0} ** (48 - 32), client_x25519_seed, ); - var client_command = client_mod.Command{ .send_hello = key_pairs }; + var client_command = Client.Command{ .send_hello = key_pairs }; client_command = try client.next(client_command); try std.testing.expect(client_command == .recv_hello); - var server_command = server_mod.Command{ .recv_hello = {} }; + var server_command = Server.Command{ .recv_hello = {} }; server_command = try server.next(server_command); // recv_hello try std.testing.expect(server_command == .send_hello); server_command.send_hello.server_random = server_random; @@ -2044,8 +1444,10 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); - const client_iv = [_]u8{ 0xE1, 0x38, 0xB9, 0xBF, 0xD6, 0xB4, 0x2D, 0x91, 0x6D, 0x81, 0xA0, 0x2D }; - try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); + const client_iv = [_]u8{ +0x77, 0x02, 0x2F, 0x09, 0xB2, 0x93, 0x5A, 0x5E, 0x3F, 0x2B, 0xB0, 0x32 + }; + try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } server_command = try server.next(server_command); // send_change_cipher_spec @@ -2054,6 +1456,7 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expect(server_command == .send_certificate); server_command = try server.next(server_command); // send_certificate try std.testing.expect(server_command == .send_certificate_verify); + server_command.send_certificate_verify.salt = server_sig_salt; server_command = try server.next(server_command); // send_certificate_verify try std.testing.expect(server_command == .send_finished); server_command = try server.next(server_command); // send_finished @@ -2066,9 +1469,14 @@ test "tls client and server handshake, data, and close_notify" { client_command = try client.next(client_command); // recv_certificate_verify try std.testing.expect(client_command == .recv_finished); client_command = try client.next(client_command); // recv_finished + try std.testing.expect(client_command == .send_change_cipher_spec); + client_command = try client.next(client_command); // send_change_cipher_spec try std.testing.expect(client_command == .send_finished); client_command = try client.next(client_command); // send_finished try std.testing.expect(client_command == .none); + + server_command = try server.next(server_command); // recv_finished + try std.testing.expect(server_command == .none); { const s = server.stream.cipher.application.aes_256_gcm_sha384; const c = client.stream.cipher.application.aes_256_gcm_sha384; @@ -2079,11 +1487,11 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); - const client_iv = [_]u8{ 0xbb, 0x00, 0x79, 0x56, 0xf4, 0x74, 0xb2, 0x5d, 0xe9, 0x02, 0x43, 0x2f }; + const client_iv = [_]u8{ +0x54, 0xF3, 0x34, 0x20, 0xA8, 0x50, 0xF5, 0x3A, 0x22, 0x9A, 0xBB, 0x1B + }; try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } - server_command = try server.next(server_command); // recv_finished - try std.testing.expect(server_command == .none); try client.writer().writeAll("ping"); @@ -2098,10 +1506,6 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expect(client.stream.closed); } -test { - _ = StreamInterface; -} - pub fn debugPrint(name: []const u8, slice: anytype) void { std.debug.print("{s} ", .{name}); if (@typeInfo(@TypeOf(slice)) == .Int) { diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 14a90f805abe..413f2de4bfec 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,517 +1,514 @@ const std = @import("../../std.zig"); const tls = std.crypto.tls; -const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; -const Certificate = std.crypto.Certificate; - -/// `StreamType` must conform to `tls.StreamInterface`. -pub fn Client(comptime StreamType: type) type { - return struct { - stream: Stream, - options: Options, - - pub const Stream = tls.Stream(tls.Plaintext.max_length, StreamType); - const Self = @This(); - - /// Initiates a TLS handshake and establishes a TLSv1.3 session - pub fn init(stream: *StreamType, options: Options) !Self { - var transcript_hash: tls.MultiHash = .{}; - const stream_ = tls.Stream(tls.Plaintext.max_length, StreamType){ - .stream = stream, - .is_client = true, - .transcript_hash = &transcript_hash, - }; - var res = Self{ .stream = stream_, .options = options }; - - var command = Command{ .send_hello = KeyPairs.init() }; - while (command != .none) command = try res.next(command); - - return res; - } +const Certificate = crypto.Certificate; - /// Executes handshake command and returns next one. - pub fn next(self: *Self, command: Command) !Command { - var stream = &self.stream; - switch (command) { - .send_hello => |key_pairs| { - try self.send_hello(key_pairs); +stream: tls.Stream, +options: Options, - return .{ .recv_hello = key_pairs }; - }, - .recv_hello => |key_pairs| { - try stream.expectInnerPlaintext(.handshake, .server_hello); - try self.recv_hello(key_pairs); +const Self = @This(); - return .{ .recv_encrypted_extensions = {} }; - }, - .recv_encrypted_extensions => { - try stream.expectInnerPlaintext(.handshake, .encrypted_extensions); - try self.recv_encrypted_extensions(); +/// Initiates a TLS handshake and establishes a TLSv1.3 session +pub fn init(stream: std.io.AnyStream, options: Options) !Self { + var transcript_hash: tls.MultiHash = .{}; + const stream_ = tls.Stream{ + .stream = stream, + .is_client = true, + .transcript_hash = &transcript_hash, + }; + var res = Self{ .stream = stream_, .options = options }; - return .{ .recv_certificate_or_finished = {} }; - }, - .recv_certificate_or_finished => { - const digest = stream.transcript_hash.?.peek(); - const inner_plaintext = try stream.readInnerPlaintext(); - if (inner_plaintext.type != .handshake) return stream.writeError(.unexpected_message); - switch (inner_plaintext.handshake_type) { - .certificate => { - const parsed = try self.recv_certificate(); - - return .{ .recv_certificate_verify = parsed }; - }, - .finished => { - if (self.options.ca_bundle != null) - return self.stream.writeError(.certificate_required); - - try self.recv_finished(digest); - - return .{ .send_finished = {} }; - }, - else => return self.stream.writeError(.unexpected_message), - } - }, - .recv_certificate_verify => |parsed| { - defer self.options.allocator.free(parsed.certificate.buffer); + var command = Command{ .send_hello = KeyPairs.init() }; + while (command != .none) command = try res.next(command); - const digest = stream.transcript_hash.?.peek(); - try stream.expectInnerPlaintext(.handshake, .certificate_verify); - try self.recv_certificate_verify(digest, parsed); + return res; +} - return .{ .recv_finished = {} }; +/// Executes handshake command and returns next one. +pub fn next(self: *Self, command: Command) !Command { + var stream = &self.stream; + switch (command) { + .send_hello => |key_pairs| { + try self.send_hello(key_pairs); + + return .{ .recv_hello = key_pairs }; + }, + .recv_hello => |key_pairs| { + try stream.expectInnerPlaintext(.handshake, .server_hello); + try self.recv_hello(key_pairs); + + return .{ .recv_encrypted_extensions = {} }; + }, + .recv_encrypted_extensions => { + try stream.expectInnerPlaintext(.handshake, .encrypted_extensions); + try self.recv_encrypted_extensions(); + + return .{ .recv_certificate_or_finished = {} }; + }, + .recv_certificate_or_finished => { + const digest = stream.transcript_hash.?.peek(); + const inner_plaintext = try stream.readInnerPlaintext(); + if (inner_plaintext.type != .handshake) return stream.writeError(.unexpected_message); + switch (inner_plaintext.handshake_type) { + .certificate => { + const parsed = try self.recv_certificate(); + + return .{ .recv_certificate_verify = parsed }; }, - .recv_finished => { - const digest = stream.transcript_hash.?.peek(); - try stream.expectInnerPlaintext(.handshake, .finished); - try self.recv_finished(digest); + .finished => { + if (self.options.ca_bundle != null) + return self.stream.writeError(.certificate_required); - return .{ .send_change_cipher_spec = {} }; - }, - .send_change_cipher_spec => { - try stream.changeCipherSpec(); + try self.recv_finished(digest); return .{ .send_finished = {} }; }, - .send_finished => { - try self.send_finished(); - - return .{ .none = {} }; - }, - .none => return .{ .none = {} }, + else => return self.stream.writeError(.unexpected_message), } - } + }, + .recv_certificate_verify => |parsed| { + defer self.options.allocator.free(parsed.certificate.buffer); + + const digest = stream.transcript_hash.?.peek(); + try stream.expectInnerPlaintext(.handshake, .certificate_verify); + try self.recv_certificate_verify(digest, parsed); + + return .{ .recv_finished = {} }; + }, + .recv_finished => { + const digest = stream.transcript_hash.?.peek(); + try stream.expectInnerPlaintext(.handshake, .finished); + try self.recv_finished(digest); + + return .{ .send_change_cipher_spec = {} }; + }, + .send_change_cipher_spec => { + try stream.changeCipherSpec(); + + return .{ .send_finished = {} }; + }, + .send_finished => { + try self.send_finished(); + + return .{ .none = {} }; + }, + .none => return .{ .none = {} }, + } +} - pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { - const hello = tls.ClientHello{ - .random = key_pairs.hello_rand, - .session_id = &key_pairs.session_id, - .cipher_suites = self.options.cipher_suites, - .extensions = &.{ - .{ .server_name = &[_]tls.ServerName{.{ .host_name = self.options.host }} }, - .{ .ec_point_formats = &[_]tls.EcPointFormat{.uncompressed} }, - .{ .supported_groups = &tls.supported_groups }, - .{ .signature_algorithms = &tls.supported_signature_schemes }, - .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, - .{ .key_share = &[_]tls.KeyShare{ - .{ .x25519_kyber768d00 = .{ - .x25519 = key_pairs.x25519.public_key, - .kyber768d00 = key_pairs.kyber768d00.public_key, - } }, - .{ .secp256r1 = key_pairs.secp256r1.public_key }, - .{ .x25519 = key_pairs.x25519.public_key }, - } }, - }, - }; +pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { + const hello = tls.ClientHello{ + .random = key_pairs.hello_rand, + .session_id = &key_pairs.session_id, + .cipher_suites = self.options.cipher_suites, + .extensions = &.{ + .{ .server_name = &[_]tls.ServerName{.{ .host_name = self.options.host }} }, + .{ .ec_point_formats = &[_]tls.EcPointFormat{.uncompressed} }, + .{ .supported_groups = &tls.supported_groups }, + .{ .signature_algorithms = &tls.supported_signature_schemes }, + .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, + .{ .key_share = &[_]tls.KeyShare{ + .{ .x25519_kyber768d00 = .{ + .x25519 = key_pairs.x25519.public_key, + .kyber768d00 = key_pairs.kyber768d00.public_key, + } }, + .{ .secp256r1 = key_pairs.secp256r1.public_key }, + .{ .x25519 = key_pairs.x25519.public_key }, + } }, + }, + }; - _ = try self.stream.write(tls.Handshake, .{ .client_hello = hello }); - try self.stream.flush(); - } + _ = try self.stream.write(tls.Handshake, .{ .client_hello = hello }); + try self.stream.flush(); +} - pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { - var stream = &self.stream; - var r = stream.reader(); - - // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. - _ = try stream.read(tls.Version); - var random: [32]u8 = undefined; - try r.readNoEof(&random); - if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { - // We already offered all our supported options and we aren't changing them. - return stream.writeError(.unexpected_message); - } +pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { + var stream = &self.stream; + var r = stream.reader(); + + // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. + _ = try stream.read(tls.Version); + var random: [32]u8 = undefined; + try r.readNoEof(&random); + if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { + // We already offered all our supported options and we aren't changing them. + return stream.writeError(.unexpected_message); + } - var session_id_buf: [tls.ClientHello.session_id_max_len]u8 = undefined; - const session_id_len = try stream.read(u8); - if (session_id_len > tls.ClientHello.session_id_max_len) - return stream.writeError(.illegal_parameter); - const session_id: []u8 = session_id_buf[0..session_id_len]; - try r.readNoEof(session_id); - if (!mem.eql(u8, session_id, &key_pairs.session_id)) - return stream.writeError(.illegal_parameter); - - const cipher_suite = try stream.read(tls.CipherSuite); - const compression_method = try stream.read(u8); - if (compression_method != 0) return stream.writeError(.illegal_parameter); - - var supported_version: ?tls.Version = null; - var shared_key: ?[]const u8 = null; - - var iter = try stream.extensions(); - while (try iter.next()) |ext| { - switch (ext.type) { - .supported_versions => { - if (supported_version != null) return stream.writeError(.illegal_parameter); - supported_version = try stream.read(tls.Version); + var session_id_buf: [tls.ClientHello.session_id_max_len]u8 = undefined; + const session_id_len = try stream.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) + return stream.writeError(.illegal_parameter); + const session_id: []u8 = session_id_buf[0..session_id_len]; + try r.readNoEof(session_id); + if (!mem.eql(u8, session_id, &key_pairs.session_id)) + return stream.writeError(.illegal_parameter); + + const cipher_suite = try stream.read(tls.CipherSuite); + const compression_method = try stream.read(u8); + if (compression_method != 0) return stream.writeError(.illegal_parameter); + + var supported_version: ?tls.Version = null; + var shared_key: ?[]const u8 = null; + + var iter = try stream.extensions(); + while (try iter.next()) |ext| { + switch (ext.type) { + .supported_versions => { + if (supported_version != null) return stream.writeError(.illegal_parameter); + supported_version = try stream.read(tls.Version); + }, + .key_share => { + if (shared_key != null) return stream.writeError(.illegal_parameter); + const named_group = try stream.read(tls.NamedGroup); + const key_size = try stream.read(u16); + switch (named_group) { + .x25519_kyber768d00 => { + const T = tls.NamedGroupT(.x25519_kyber768d00); + const x25519_len = T.X25519.public_length; + const expected_len = x25519_len + T.Kyber768.ciphertext_length; + if (key_size != expected_len) return stream.writeError(.illegal_parameter); + var server_ks: [expected_len]u8 = undefined; + try r.readNoEof(&server_ks); + + const mult = T.X25519.scalarmult( + key_pairs.x25519.secret_key, + server_ks[0..x25519_len].*, + ) catch return stream.writeError(.decrypt_error); + const decaps = key_pairs.kyber768d00.secret_key.decaps( + server_ks[x25519_len..expected_len], + ) catch return stream.writeError(.decrypt_error); + shared_key = &(mult ++ decaps); }, - .key_share => { - if (shared_key != null) return stream.writeError(.illegal_parameter); - const named_group = try stream.read(tls.NamedGroup); - const key_size = try stream.read(u16); - switch (named_group) { - .x25519_kyber768d00 => { - const T = tls.NamedGroupT(.x25519_kyber768d00); - const x25519_len = T.X25519.public_length; - const expected_len = x25519_len + T.Kyber768.ciphertext_length; - if (key_size != expected_len) return stream.writeError(.illegal_parameter); - var server_ks: [expected_len]u8 = undefined; - try r.readNoEof(&server_ks); - - const mult = T.X25519.scalarmult( - key_pairs.x25519.secret_key, - server_ks[0..x25519_len].*, - ) catch return stream.writeError(.decrypt_error); - const decaps = key_pairs.kyber768d00.secret_key.decaps( - server_ks[x25519_len..expected_len], - ) catch return stream.writeError(.decrypt_error); - shared_key = &(mult ++ decaps); - }, - .x25519 => { - const T = tls.NamedGroupT(.x25519); - const expected_len = T.public_length; - if (key_size != expected_len) return stream.writeError(.illegal_parameter); - var server_ks: [expected_len]u8 = undefined; - try r.readNoEof(&server_ks); - - const mult = crypto.dh.X25519.scalarmult( - key_pairs.x25519.secret_key, - server_ks[0..expected_len].*, - ) catch return stream.writeError(.illegal_parameter); - shared_key = &mult; - }, - inline .secp256r1, .secp384r1 => |t| { - const T = tls.NamedGroupT(t); - const expected_len = T.PublicKey.uncompressed_sec1_encoded_length; - if (key_size != expected_len) return stream.writeError(.illegal_parameter); - - var server_ks: [expected_len]u8 = undefined; - try r.readNoEof(&server_ks); - - const pk = T.PublicKey.fromSec1(&server_ks) catch - return stream.writeError(.illegal_parameter); - const key_pair = @field(key_pairs, @tagName(t)); - const mult = pk.p.mulPublic(key_pair.secret_key.bytes, .big) catch - return stream.writeError(.illegal_parameter); - shared_key = &mult.affineCoordinates().x.toBytes(.big); - }, - // Server sent us back unknown key. That's weird because we only request known ones, - // but we can keep iterating for another. - else => { - try r.skipBytes(key_size, .{}); - }, - } + .x25519 => { + const T = tls.NamedGroupT(.x25519); + const expected_len = T.public_length; + if (key_size != expected_len) return stream.writeError(.illegal_parameter); + var server_ks: [expected_len]u8 = undefined; + try r.readNoEof(&server_ks); + + const mult = crypto.dh.X25519.scalarmult( + key_pairs.x25519.secret_key, + server_ks[0..expected_len].*, + ) catch return stream.writeError(.illegal_parameter); + shared_key = &mult; }, + inline .secp256r1, .secp384r1 => |t| { + const T = tls.NamedGroupT(t); + const expected_len = T.PublicKey.uncompressed_sec1_encoded_length; + if (key_size != expected_len) return stream.writeError(.illegal_parameter); + + var server_ks: [expected_len]u8 = undefined; + try r.readNoEof(&server_ks); + + const pk = T.PublicKey.fromSec1(&server_ks) catch + return stream.writeError(.illegal_parameter); + const key_pair = @field(key_pairs, @tagName(t)); + const mult = pk.p.mulPublic(key_pair.secret_key.bytes, .big) catch + return stream.writeError(.illegal_parameter); + shared_key = &mult.affineCoordinates().x.toBytes(.big); + }, + // Server sent us back unknown key. That's weird because we only request known ones, + // but we can keep iterating for another. else => { - try r.skipBytes(ext.len, .{}); + try r.skipBytes(key_size, .{}); }, } - } + }, + else => { + try r.skipBytes(ext.len, .{}); + }, + } + } - if (supported_version != tls.Version.tls_1_3) return stream.writeError(.protocol_version); - if (shared_key == null) return stream.writeError(.missing_extension); + if (supported_version != tls.Version.tls_1_3) return stream.writeError(.protocol_version); + if (shared_key == null) return stream.writeError(.missing_extension); - stream.transcript_hash.?.setActive(cipher_suite); - const hello_hash = stream.transcript_hash.?.peek(); + stream.transcript_hash.?.setActive(cipher_suite); + const hello_hash = stream.transcript_hash.?.peek(); - const handshake_cipher = tls.HandshakeCipher.init(cipher_suite, shared_key.?, hello_hash,) catch return stream.writeError(.illegal_parameter); - stream.cipher = .{ .handshake = handshake_cipher }; - } + const handshake_cipher = tls.HandshakeCipher.init( + cipher_suite, + shared_key.?, + hello_hash, + ) catch return stream.writeError(.illegal_parameter); + stream.cipher = .{ .handshake = handshake_cipher }; +} - pub fn recv_encrypted_extensions(self: *Self) !void { - var stream = &self.stream; - var r = stream.reader(); +pub fn recv_encrypted_extensions(self: *Self) !void { + var stream = &self.stream; + var r = stream.reader(); - var iter = try stream.extensions(); - while (try iter.next()) |ext| { - try r.skipBytes(ext.len, .{}); - } - } + var iter = try stream.extensions(); + while (try iter.next()) |ext| { + try r.skipBytes(ext.len, .{}); + } +} - /// Verifies trust chain if `options.ca_bundle` is specified. - /// - /// Caller owns allocated Certificate.Parsed.certificate. - pub fn recv_certificate(self: *Self) !Certificate.Parsed { - var stream = &self.stream; - var r = stream.reader(); - const allocator = self.options.allocator; - const ca_bundle = self.options.ca_bundle; - const verify = ca_bundle != null; - - var context: [tls.Certificate.max_context_len]u8 = undefined; - const context_len = try stream.read(u8); - if (context_len > tls.Certificate.max_context_len) return stream.writeError(.decode_error); - try r.readNoEof(context[0..context_len]); - - var first: ?crypto.Certificate.Parsed = null; - errdefer if (first) |f| allocator.free(f.certificate.buffer); - var prev: Certificate.Parsed = undefined; - var verified = false; - const now_sec = std.time.timestamp(); - - var certs_iter = try stream.iterator(u24, u24); - while (try certs_iter.next()) |cert_len| { - const is_first = first == null; - - if (verified) { - try r.skipBytes(cert_len, .{}); - } else { - if (cert_len > tls.Certificate.Entry.max_data_len) - return stream.writeError(.decode_error); - const buf = allocator.alloc(u8, cert_len) catch - return stream.writeError(.internal_error); - defer if (!is_first) allocator.free(buf); - errdefer allocator.free(buf); - try r.readNoEof(buf); - - const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; - const cur = cert.parse() catch return stream.writeError(.bad_certificate); - if (first == null) { - if (verify) try cur.verifyHostName(self.options.host); - first = cur; - } else { - if (verify) try prev.verify(cur, now_sec); - } - - if (ca_bundle) |b| { - if (b.verify(cur, now_sec)) |_| { - verified = true; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - error.CertificateExpired => return stream.writeError(.certificate_expired), - else => return stream.writeError(.bad_certificate), - } - } +/// Verifies trust chain if `options.ca_bundle` is specified. +/// +/// Caller owns allocated Certificate.Parsed.certificate. +pub fn recv_certificate(self: *Self) !Certificate.Parsed { + var stream = &self.stream; + var r = stream.reader(); + const allocator = self.options.allocator; + const ca_bundle = self.options.ca_bundle; + const verify = ca_bundle != null; + + var context: [tls.Certificate.max_context_len]u8 = undefined; + const context_len = try stream.read(u8); + if (context_len > tls.Certificate.max_context_len) return stream.writeError(.decode_error); + try r.readNoEof(context[0..context_len]); + + var first: ?crypto.Certificate.Parsed = null; + errdefer if (first) |f| allocator.free(f.certificate.buffer); + var prev: Certificate.Parsed = undefined; + var verified = false; + const now_sec = std.time.timestamp(); + + var certs_iter = try stream.iterator(u24, u24); + while (try certs_iter.next()) |cert_len| { + const is_first = first == null; + + if (verified) { + try r.skipBytes(cert_len, .{}); + } else { + if (cert_len > tls.Certificate.Entry.max_data_len) + return stream.writeError(.decode_error); + const buf = allocator.alloc(u8, cert_len) catch + return stream.writeError(.internal_error); + defer if (!is_first) allocator.free(buf); + errdefer allocator.free(buf); + try r.readNoEof(buf); + + const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; + const cur = cert.parse() catch return stream.writeError(.bad_certificate); + if (first == null) { + if (verify) try cur.verifyHostName(self.options.host); + first = cur; + } else { + if (verify) try prev.verify(cur, now_sec); + } - prev = cur; + if (ca_bundle) |b| { + if (b.verify(cur, now_sec)) |_| { + verified = true; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + error.CertificateExpired => return stream.writeError(.certificate_expired), + else => return stream.writeError(.bad_certificate), } - - var ext_iter = try stream.extensions(); - while (try ext_iter.next()) |ext| try r.skipBytes(ext.len, .{}); } - if (verify and !verified) return stream.writeError(.bad_certificate); - return if (first) |f| f else stream.writeError(.bad_certificate); + prev = cur; } - pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: Certificate.Parsed) !void { - var stream = &self.stream; - var r = stream.reader(); - const allocator = self.options.allocator; + var ext_iter = try stream.extensions(); + while (try ext_iter.next()) |ext| try r.skipBytes(ext.len, .{}); + } + if (verify and !verified) return stream.writeError(.bad_certificate); - const sig_content = tls.sigContent(digest); + return if (first) |f| f else stream.writeError(.bad_certificate); +} - const scheme = try stream.read(tls.SignatureScheme); - const len = try stream.read(u16); - if (len > tls.CertificateVerify.max_signature_length) +pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: Certificate.Parsed) !void { + var stream = &self.stream; + var r = stream.reader(); + const allocator = self.options.allocator; + + const sig_content = tls.sigContent(digest); + + const scheme = try stream.read(tls.SignatureScheme); + const len = try stream.read(u16); + if (len > tls.CertificateVerify.max_signature_length) + return stream.writeError(.decode_error); + const sig_bytes = allocator.alloc(u8, len) catch + return stream.writeError(.internal_error); + defer allocator.free(sig_bytes); + try r.readNoEof(sig_bytes); + + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (cert.pub_key_algo != .X9_62_id_ecPublicKey) + return stream.writeError(.bad_certificate); + const Ecdsa = comptime_scheme.Ecdsa(); + const sig = Ecdsa.Signature.fromDer(sig_bytes) catch return stream.writeError(.decode_error); - const sig_bytes = allocator.alloc(u8, len) catch - return stream.writeError(.internal_error); - defer allocator.free(sig_bytes); - try r.readNoEof(sig_bytes); - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (cert.pub_key_algo != .X9_62_id_ecPublicKey) - return stream.writeError(.bad_certificate); - const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = Ecdsa.Signature.fromDer(sig_bytes) catch - return stream.writeError(.decode_error); - const key = Ecdsa.PublicKey.fromSec1(cert.pubKey()) catch - return stream.writeError(.decode_error); - sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (cert.pub_key_algo != .rsaEncryption) - return stream.writeError(.bad_certificate); - - const Hash = SchemeHash(comptime_scheme); - const rsa = Certificate.rsa; - const key = rsa.PublicKey.fromDer(cert.pubKey()) catch - return stream.writeError(.bad_certificate); - switch (key.n.bits() / 8) { - inline 128, 256, 512 => |modulus_len| { - const sig = rsa.PSSSignature.fromBytes(modulus_len, sig_bytes); - rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash) catch - return stream.writeError(.decode_error); - }, - else => { - return stream.writeError(.bad_certificate); - }, - } - }, - inline .ed25519 => |comptime_scheme| { - if (cert.pub_key_algo != .curveEd25519) - return stream.writeError(.bad_certificate); - const Eddsa = SchemeEddsa(comptime_scheme); - if (sig_content.len != Eddsa.Signature.encoded_length) - return stream.writeError(.decode_error); - const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); - if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) + const key = Ecdsa.PublicKey.fromSec1(cert.pubKey()) catch + return stream.writeError(.decode_error); + sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + if (cert.pub_key_algo != .rsaEncryption) + return stream.writeError(.bad_certificate); + + const Hash = comptime_scheme.Hash(); + const rsa = Certificate.rsa; + const key = rsa.PublicKey.fromDer(cert.pubKey()) catch + return stream.writeError(.bad_certificate); + switch (key.n.bits() / 8) { + inline 128, 256, 512 => |modulus_len| { + const sig = rsa.PSSSignature.fromBytes(modulus_len, sig_bytes); + rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash) catch return stream.writeError(.decode_error); - const key = Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*) catch - return stream.writeError(.bad_certificate); - sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); }, else => { return stream.writeError(.bad_certificate); }, } - } - - pub fn recv_finished(self: *Self, digest: []const u8) !void { - var stream = &self.stream; - var r = stream.reader(); - const cipher = stream.cipher.handshake; - - switch (cipher) { - inline else => |p| { - const P = @TypeOf(p); - const expected = &tls.hmac(P.Hmac, digest, p.server_finished_key); - - var actual: [expected.len]u8 = undefined; - try r.readNoEof(&actual); - if (!mem.eql(u8, expected, &actual)) return stream.writeError(.decode_error); - }, - } - } + }, + inline .ed25519 => |comptime_scheme| { + if (cert.pub_key_algo != .curveEd25519) + return stream.writeError(.bad_certificate); + const Eddsa = comptime_scheme.Eddsa(); + if (sig_content.len != Eddsa.Signature.encoded_length) + return stream.writeError(.decode_error); + const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); + if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) + return stream.writeError(.decode_error); + const key = Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*) catch + return stream.writeError(.bad_certificate); + sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); + }, + else => { + return stream.writeError(.bad_certificate); + }, + } +} - pub fn send_finished(self: *Self) !void { - var stream = &self.stream; +pub fn recv_finished(self: *Self, digest: []const u8) !void { + var stream = &self.stream; + var r = stream.reader(); + const cipher = stream.cipher.handshake; - const handshake_hash = stream.transcript_hash.?.peek(); + switch (cipher) { + inline else => |p| { + const P = @TypeOf(p); + const expected = &tls.hmac(P.Hmac, digest, p.server_finished_key); - const verify_data = switch (stream.cipher.handshake) { - inline .aes_128_gcm_sha256, - .aes_256_gcm_sha384, - .chacha20_poly1305_sha256, - .aegis_256_sha512, - .aegis_128l_sha256, - => |v| brk: { - const T = @TypeOf(v); - const secret = v.client_finished_key; - const transcript_hash = stream.transcript_hash.?.peek(); + var actual: [expected.len]u8 = undefined; + try r.readNoEof(&actual); + if (!mem.eql(u8, expected, &actual)) return stream.writeError(.decode_error); + }, + } +} - break :brk &tls.hmac(T.Hmac, transcript_hash, secret); - }, - else => return stream.writeError(.decrypt_error), - }; - stream.content_type = .handshake; - _ = try stream.write(tls.Handshake, .{ .finished = verify_data }); - try stream.flush(); - - const application_cipher = tls.ApplicationCipher.init(stream.cipher.handshake, handshake_hash); - stream.cipher = .{ .application = application_cipher }; - stream.content_type = .application_data; - stream.transcript_hash = null; - } +pub fn send_finished(self: *Self) !void { + var stream = &self.stream; + + const handshake_hash = stream.transcript_hash.?.peek(); + + const verify_data = switch (stream.cipher.handshake) { + inline .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, + => |v| brk: { + const T = @TypeOf(v); + const secret = v.client_finished_key; + const transcript_hash = stream.transcript_hash.?.peek(); + + break :brk &tls.hmac(T.Hmac, transcript_hash, secret); + }, + else => return stream.writeError(.decrypt_error), + }; + stream.content_type = .handshake; + _ = try stream.write(tls.Handshake, .{ .finished = verify_data }); + try stream.flush(); + + const application_cipher = tls.ApplicationCipher.init(stream.cipher.handshake, handshake_hash); + stream.cipher = .{ .application = application_cipher }; + stream.content_type = .application_data; + stream.transcript_hash = null; +} - pub const ReadError = Stream.ReadError; - pub const WriteError = Stream.WriteError; +pub const ReadError = anyerror; +pub const WriteError = anyerror; - /// Reads next application_data message. - pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { - var stream = &self.stream; +/// Reads next application_data message. +pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { + var stream = &self.stream; - if (stream.eof()) return 0; + if (stream.eof()) return 0; - while (stream.view.len == 0) { - const inner_plaintext = try stream.readInnerPlaintext(); - switch (inner_plaintext.type) { - .handshake => { - switch (inner_plaintext.handshake_type) { - // A multithreaded client could use these. - .new_session_ticket => { - try stream.reader().skipBytes(inner_plaintext.len, .{}); - }, - .key_update => { - switch (stream.cipher.application) { - inline else => |*p| { - const P = @TypeOf(p.*); - const server_secret = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); - p.server_secret = server_secret; - p.server_key = tls.hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.server_iv = tls.hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - p.read_seq = 0; - }, - } - const update = try stream.read(tls.KeyUpdate); - if (update == .update_requested) { - switch (stream.cipher.application) { - inline else => |*p| { - const P = @TypeOf(p.*); - const client_secret = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); - p.client_secret = client_secret; - p.client_key = tls.hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.client_iv = tls.hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - p.write_seq = 0; - }, - } - } + while (stream.view.len == 0) { + const inner_plaintext = try stream.readInnerPlaintext(); + switch (inner_plaintext.type) { + .handshake => { + switch (inner_plaintext.handshake_type) { + // A multithreaded client could use these. + .new_session_ticket => { + try stream.reader().skipBytes(inner_plaintext.len, .{}); + }, + .key_update => { + switch (stream.cipher.application) { + inline else => |*p| { + const P = @TypeOf(p.*); + const server_secret = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); + p.server_secret = server_secret; + p.server_key = tls.hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + p.server_iv = tls.hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + p.read_seq = 0; }, - else => return stream.writeError(.unexpected_message), + } + const update = try stream.read(tls.KeyUpdate); + if (update == .update_requested) { + switch (stream.cipher.application) { + inline else => |*p| { + const P = @TypeOf(p.*); + const client_secret = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); + p.client_secret = client_secret; + p.client_key = tls.hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + p.client_iv = tls.hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + p.write_seq = 0; + }, + } } }, - .application_data => {}, else => return stream.writeError(.unexpected_message), } - } - return try self.stream.readv(buffers); + }, + .application_data => {}, + else => return stream.writeError(.unexpected_message), } + } + return try self.stream.readv(buffers); +} - pub fn read(self: *Self, buf: []u8) ReadError!usize { - const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; - return try self.readv(&buffers); - } +pub fn read(self: *Self, buf: []u8) ReadError!usize { + const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; + return try self.readv(&buffers); +} - pub fn write(self: *Self, buf: []const u8) WriteError!usize { - if (self.stream.eof()) return 0; +pub fn write(self: *Self, buf: []const u8) WriteError!usize { + if (self.stream.eof()) return 0; - const res = try self.stream.writeBytes(buf); - try self.stream.flush(); - return res; - } + const res = try self.stream.writeBytes(buf); + try self.stream.flush(); + return res; +} - pub fn close(self: *Self) void { - self.stream.close(); - } +pub fn close(self: *Self) void { + self.stream.close(); +} - pub const Reader = std.io.Reader(*Self, ReadError, read); - pub const Writer = std.io.Writer(*Self, WriteError, write); +pub const Reader = std.io.Reader(*Self, ReadError, read); +pub const Writer = std.io.Writer(*Self, WriteError, write); - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } +pub fn reader(self: *Self) Reader { + return .{ .context = self }; +} - pub fn writer(self: *Self) Writer { - return .{ .context = self }; - } - }; +pub fn writer(self: *Self) Writer { + return .{ .context = self }; } pub const Options = struct { @@ -545,8 +542,6 @@ pub const KeyPairs = struct { secp384r1: Secp384r1, x25519: X25519, - const Self = @This(); - const hello_rand_length = 32; const session_id_length = 32; const X25519 = tls.NamedGroupT(.x25519).KeyPair; @@ -554,7 +549,7 @@ pub const KeyPairs = struct { const Secp384r1 = tls.NamedGroupT(.secp384r1).KeyPair; const Kyber768 = tls.NamedGroupT(.x25519_kyber768d00).Kyber768.KeyPair; - pub fn init() Self { + pub fn init() @This() { var random_buffer: [ hello_rand_length + session_id_length + @@ -591,8 +586,8 @@ pub const KeyPairs = struct { secp256r1_seed: [Secp256r1.seed_length]u8, secp384r1_seed: [Secp384r1.seed_length]u8, x25519_seed: [X25519.seed_length]u8, - ) !Self { - return Self{ + ) !@This() { + return .{ .kyber768d00 = Kyber768.create(kyber_768_seed) catch {}, .secp256r1 = Secp256r1.create(secp256r1_seed) catch |err| switch (err) { error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. @@ -609,30 +604,6 @@ pub const KeyPairs = struct { } }; -fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, - .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, - else => @compileError("bad scheme"), - }; -} - -fn SchemeHash(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, - .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, - .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, - else => @compileError("bad scheme"), - }; -} - -fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .ed25519 => crypto.sign.Ed25519, - else => @compileError("bad scheme"), - }; -} - /// A command to send or receive a single message. Allows testing `advance` on a single thread. pub const Command = union(enum) { send_hello: KeyPairs, diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index 6690b88763e6..ee6afcaa7810 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -8,339 +8,402 @@ const assert = std.debug.assert; const Certificate = std.crypto.Certificate; const Allocator = std.mem.Allocator; -/// `StreamType` must conform to `tls.StreamInterface`. -pub fn Server(comptime StreamType: type) type { - return struct { - stream: Stream, - options: Options, - /// Only used during handshake for messages larger than tls.Plaintext.max_length. - // allocator: Allocator, - - const Stream = tls.Stream(tls.Plaintext.max_length, StreamType); - const Self = @This(); - - /// Initiates a TLS handshake and establishes a TLSv1.3 session - pub fn init(stream: *StreamType, options: Options) !Self { - var transcript_hash: tls.MultiHash = .{}; - var stream_ = tls.Stream(tls.Plaintext.max_length, StreamType){ - .stream = stream, - .is_client = false, - .transcript_hash = &transcript_hash, - }; - var res = Self{ .stream = stream_, .options = options }; - const client_hello = try res.recv_hello(&stream_); - _ = client_hello; - - var command = Command{ .recv_hello = {} }; - while (command != .none) command = try res.next(command); - - return res; - } +stream: tls.Stream, +options: Options, + +const Self = @This(); + +/// Initiates a TLS handshake and establishes a TLSv1.3 session +pub fn init(stream: std.io.AnyStream, options: Options) !Self { + var transcript_hash: tls.MultiHash = .{}; + var stream_ = tls.Stream{ + .stream = stream, + .is_client = false, + .transcript_hash = &transcript_hash, + }; + var res = Self{ .stream = stream_, .options = options }; + const client_hello = try res.recv_hello(&stream_); + _ = client_hello; - /// Executes handshake command and returns next one. - pub fn next(self: *Self, command: Command) !Command { - var stream = &self.stream; + var command = Command{ .recv_hello = {} }; + while (command != .none) command = try res.next(command); - switch (command) { - .recv_hello => { - const client_hello = try self.recv_hello(); + return res; +} - return .{ .send_hello = client_hello }; - }, - .send_hello => |client_hello| { - try self.send_hello(client_hello); +/// Executes handshake command and returns next one. +pub fn next(self: *Self, command: Command) !Command { + var stream = &self.stream; + + switch (command) { + .recv_hello => { + const client_hello = try self.recv_hello(); + + return .{ .send_hello = client_hello }; + }, + .send_hello => |client_hello| { + try self.send_hello(client_hello); + + const scheme = client_hello.sig_scheme; + // > if the client sends a non-empty session ID, + // > the server MUST send the change_cipher_spec + if (client_hello.session_id_len > 0) return .{ .send_change_cipher_spec = scheme }; + + return .{ .send_encrypted_extensions = scheme }; + }, + .send_change_cipher_spec => |scheme| { + try stream.changeCipherSpec(); + + return .{ .send_encrypted_extensions = scheme }; + }, + .send_encrypted_extensions => |scheme| { + try self.send_encrypted_extensions(); + + return .{ .send_certificate = scheme }; + }, + .send_certificate => |scheme| { + try self.send_certificate(); + + var cert_verify = Command.CertificateVerify{ .scheme = scheme, .salt = undefined }; + crypto.random.bytes(&cert_verify.salt); + + return .{ .send_certificate_verify = cert_verify }; + }, + .send_certificate_verify => |cert_verify| { + try self.send_certificate_verify(cert_verify); + return .{ .send_finished = {} }; + }, + .send_finished => { + try self.send_finished(); + return .{ .recv_finished = {} }; + }, + .recv_finished => { + try self.recv_finished(); + return .{ .none = {} }; + }, + .none => return .{ .none = {} }, + } +} - // > if the client sends a non-empty session ID, - // > the server MUST send the change_cipher_spec - if (client_hello.session_id_len > 0) return .{ .send_change_cipher_spec = {} }; +pub fn recv_hello(self: *Self) !ClientHello { + var stream = &self.stream; + var reader = stream.reader(); - return .{ .send_encrypted_extensions = {} }; - }, - .send_change_cipher_spec => { - try stream.changeCipherSpec(); + try stream.expectInnerPlaintext(.handshake, .client_hello); - return .{ .send_encrypted_extensions = {} }; - }, - .send_encrypted_extensions => { - try self.send_encrypted_extensions(); + _ = try stream.read(tls.Version); + var client_random: [32]u8 = undefined; + try reader.readNoEof(&client_random); - return .{ .send_certificate = {} }; - }, - .send_certificate => { - try self.send_certificate(); + var session_id: [tls.ClientHello.session_id_max_len]u8 = undefined; + const session_id_len = try stream.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) + return stream.writeError(.illegal_parameter); + try reader.readNoEof(session_id[0..session_id_len]); - return .{ .send_certificate_verify = {} }; - }, - .send_certificate_verify => { - try self.send_certificate_verify(); - return .{ .send_finished = {} }; - }, - .send_finished => { - try self.send_finished(); - return .{ .recv_finished = {} }; - }, - .recv_finished => { - try self.recv_finished(); - return .{ .none = {} }; - }, - .none => return .{ .none = {} }, + const cipher_suite: tls.CipherSuite = brk: { + var cipher_suite_iter = try stream.iterator(u16, tls.CipherSuite); + var res: ?tls.CipherSuite = null; + while (try cipher_suite_iter.next()) |suite| { + for (self.options.cipher_suites) |s| { + if (s == suite and res == null) res = s; } } - - pub fn recv_hello(self: *Self) !ClientHello { - var stream = &self.stream; - var reader = stream.reader(); - - try stream.expectInnerPlaintext(.handshake, .client_hello); - - _ = try stream.read(tls.Version); - var client_random: [32]u8 = undefined; - try reader.readNoEof(&client_random); - - var session_id: [tls.ClientHello.session_id_max_len]u8 = undefined; - const session_id_len = try stream.read(u8); - if (session_id_len > tls.ClientHello.session_id_max_len) - return stream.writeError(.illegal_parameter); - try reader.readNoEof(session_id[0..session_id_len]); - - const cipher_suite: tls.CipherSuite = brk: { - var cipher_suite_iter = try stream.iterator(u16, tls.CipherSuite); - var res: ?tls.CipherSuite = null; - while (try cipher_suite_iter.next()) |suite| { - for (self.options.cipher_suites) |s| { - if (s == suite and res == null) res = s; + if (res == null) return stream.writeError(.illegal_parameter); + break :brk res.?; + }; + stream.transcript_hash.?.setActive(cipher_suite); + + { + var compression_methods: [2]u8 = undefined; + try reader.readNoEof(&compression_methods); + if (!std.mem.eql(u8, &compression_methods, &[_]u8{ 1, 0 })) + return stream.writeError(.illegal_parameter); + } + + var tls_version: ?tls.Version = null; + var key_share: ?tls.KeyShare = null; + var ec_point_format: ?tls.EcPointFormat = null; + var sig_scheme: ?tls.SignatureScheme = null; + + var extension_iter = try stream.extensions(); + while (try extension_iter.next()) |ext| { + switch (ext.type) { + .supported_versions => { + if (tls_version != null) return stream.writeError(.illegal_parameter); + var versions_iter = try stream.iterator(u8, tls.Version); + while (try versions_iter.next()) |v| { + if (v == .tls_1_3) tls_version = v; + } + }, + // TODO: use supported_groups instead + .key_share => { + if (key_share != null) return stream.writeError(.illegal_parameter); + + var key_share_iter = try stream.iterator(u16, tls.KeyShare); + while (try key_share_iter.next()) |ks| { + switch (ks) { + .x25519 => key_share = ks, + else => {}, } } - if (res == null) return stream.writeError(.illegal_parameter); - break :brk res.?; - }; - stream.transcript_hash.?.setActive(cipher_suite); - - { - var compression_methods: [2]u8 = undefined; - try reader.readNoEof(&compression_methods); - if (!std.mem.eql(u8, &compression_methods, &[_]u8{ 1, 0 })) - return stream.writeError(.illegal_parameter); - } - - var tls_version: ?tls.Version = null; - var key_share: ?tls.KeyShare = null; - var ec_point_format: ?tls.EcPointFormat = null; - var sig_scheme: ?tls.SignatureScheme = null; - - var extension_iter = try stream.extensions(); - while (try extension_iter.next()) |ext| { - switch (ext.type) { - .supported_versions => { - if (tls_version != null) return stream.writeError(.illegal_parameter); - var versions_iter = try stream.iterator(u8, tls.Version); - while (try versions_iter.next()) |v| { - if (v == .tls_1_3) tls_version = v; - } - }, - // TODO: use supported_groups instead - .key_share => { - if (key_share != null) return stream.writeError(.illegal_parameter); - - var key_share_iter = try stream.iterator(u16, tls.KeyShare); - while (try key_share_iter.next()) |ks| { - switch (ks) { - .x25519 => key_share = ks, - else => {}, - } - } - }, - .ec_point_formats => { - var format_iter = try stream.iterator(u8, tls.EcPointFormat); - while (try format_iter.next()) |f| { - if (f == .uncompressed) ec_point_format = .uncompressed; - } - }, - .signature_algorithms => { - var algos_iter = try stream.iterator(u16, tls.SignatureScheme); - while (try algos_iter.next()) |algo| { - if (algo == .rsa_pss_rsae_sha256) sig_scheme = algo; - } - }, - else => { - try reader.skipBytes(ext.len, .{}); - }, + }, + .ec_point_formats => { + var format_iter = try stream.iterator(u8, tls.EcPointFormat); + while (try format_iter.next()) |f| { + if (f == .uncompressed) ec_point_format = .uncompressed; } - } - - if (tls_version != .tls_1_3) return stream.writeError(.protocol_version); - if (key_share == null) return stream.writeError(.missing_extension); - if (ec_point_format == null) return stream.writeError(.missing_extension); - - var server_random: [32]u8 = undefined; - crypto.random.bytes(&server_random); - - const key_pair = .{ - .x25519 = crypto.dh.X25519.KeyPair.create(server_random) catch unreachable, - }; - - return .{ - .random = client_random, - .session_id_len = session_id_len, - .session_id = session_id, - .cipher_suite = cipher_suite, - .key_share = key_share.?, - .sig_scheme = sig_scheme, - .server_random = server_random, - .server_pair = key_pair, - }; + }, + .signature_algorithms => { + var algos_iter = try stream.iterator(u16, tls.SignatureScheme); + while (try algos_iter.next()) |algo| { + if (algo == .rsa_pss_rsae_sha256) sig_scheme = algo; + } + }, + else => { + try reader.skipBytes(ext.len, .{}); + }, } + } - pub fn send_hello(self: *Self, client_hello: ClientHello) !void { - var stream = &self.stream; - const key_pair = client_hello.server_pair; - - const hello = tls.ServerHello{ - .random = client_hello.server_random, - .session_id = &client_hello.session_id, - .cipher_suite = client_hello.cipher_suite, - .extensions = &.{ - .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, - .{ .key_share = &[_]tls.KeyShare{key_pair.toKeyShare()} }, - }, - }; - stream.version = .tls_1_2; - _ = try stream.write(tls.Handshake, .{ .server_hello = hello }); - try stream.flush(); - - const shared_key = switch (client_hello.key_share) { - .x25519_kyber768d00 => |ks| brk: { - const T = tls.NamedGroupT(.x25519_kyber768d00); - const pair: tls.X25519Kyber768Draft.KeyPair = key_pair.x25519_kyber768d00; - const shared_point = T.X25519.scalarmult( - ks.x25519, - pair.x25519.secret_key, - ) catch return stream.writeError(.decrypt_error); - // pair.kyber768d00.secret_key - // ks.kyber768d00 - const encaps = ks.kyber768d00.encaps(null).ciphertext; - - break :brk &(shared_point ++ encaps); - }, - .x25519 => |ks| brk: { - const shared_point = tls.NamedGroupT(.x25519).scalarmult( - key_pair.x25519.secret_key, - ks, - ) catch return stream.writeError(.decrypt_error); - break :brk &shared_point; - }, - .secp256r1 => |ks| brk: { - const mul = ks.p.mulPublic( - key_pair.secp256r1.secret_key.bytes, - .big, - ) catch return stream.writeError(.decrypt_error); - break :brk &mul.affineCoordinates().x.toBytes(.big); - }, - else => return stream.writeError(.illegal_parameter), - }; + if (tls_version != .tls_1_3) return stream.writeError(.protocol_version); + if (key_share == null) return stream.writeError(.missing_extension); + if (ec_point_format == null) return stream.writeError(.missing_extension); + if (sig_scheme == null) return stream.writeError(.missing_extension); - const hello_hash = stream.transcript_hash.?.peek(); - const handshake_cipher = tls.HandshakeCipher.init(client_hello.cipher_suite, shared_key, hello_hash,) catch - return stream.writeError(.illegal_parameter); - stream.cipher = .{ .handshake = handshake_cipher }; - } + var server_random: [32]u8 = undefined; + crypto.random.bytes(&server_random); - pub fn send_encrypted_extensions(self: *Self) !void { - var stream = &self.stream; - _ = try stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); - try stream.flush(); - } - - pub fn send_certificate(self: *Self) !void { - var stream = &self.stream; - _ = try self.stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); - try stream.flush(); - } + const key_pair = .{ + .x25519 = crypto.dh.X25519.KeyPair.create(server_random) catch unreachable, + }; - pub fn send_certificate_verify(self: *Self) !void { - var stream = &self.stream; + return .{ + .random = client_random, + .session_id_len = session_id_len, + .session_id = session_id, + .cipher_suite = cipher_suite, + .key_share = key_share.?, + .sig_scheme = sig_scheme.?, + .server_random = server_random, + .server_pair = key_pair, + }; +} - const digest = stream.transcript_hash.?.peek(); - const sig_content = tls.sigContent(digest); +pub fn send_hello(self: *Self, client_hello: ClientHello) !void { + var stream = &self.stream; + const key_pair = client_hello.server_pair; + + const hello = tls.ServerHello{ + .random = client_hello.server_random, + .session_id = &client_hello.session_id, + .cipher_suite = client_hello.cipher_suite, + .extensions = &.{ + .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, + .{ .key_share = &[_]tls.KeyShare{key_pair.toKeyShare()} }, + }, + }; + stream.version = .tls_1_2; + _ = try stream.write(tls.Handshake, .{ .server_hello = hello }); + try stream.flush(); + + const shared_key = switch (client_hello.key_share) { + .x25519_kyber768d00 => |ks| brk: { + const T = tls.NamedGroupT(.x25519_kyber768d00); + const pair: tls.X25519Kyber768Draft.KeyPair = key_pair.x25519_kyber768d00; + const shared_point = T.X25519.scalarmult( + ks.x25519, + pair.x25519.secret_key, + ) catch return stream.writeError(.decrypt_error); + // pair.kyber768d00.secret_key + // ks.kyber768d00 + const encaps = ks.kyber768d00.encaps(null).ciphertext; + + break :brk &(shared_point ++ encaps); + }, + .x25519 => |ks| brk: { + const shared_point = tls.NamedGroupT(.x25519).scalarmult( + key_pair.x25519.secret_key, + ks, + ) catch return stream.writeError(.decrypt_error); + break :brk &shared_point; + }, + .secp256r1 => |ks| brk: { + const mul = ks.p.mulPublic( + key_pair.secp256r1.secret_key.bytes, + .big, + ) catch return stream.writeError(.decrypt_error); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + else => return stream.writeError(.illegal_parameter), + }; - const signature = sig_content; - // const signature = rsa.encrypt(256, sig_content, parsed.pubKey()) catch return stream.writeError(.internal_error); + const hello_hash = stream.transcript_hash.?.peek(); + const handshake_cipher = tls.HandshakeCipher.init( + client_hello.cipher_suite, + shared_key, + hello_hash, + ) catch + return stream.writeError(.illegal_parameter); + stream.cipher = .{ .handshake = handshake_cipher }; +} - _ = try self.stream.write(tls.Handshake, .{ .certificate_verify = tls.CertificateVerify{ - .algorithm = .rsa_pss_rsae_sha256, - .signature = signature, - }}); - try stream.flush(); - } +pub fn send_encrypted_extensions(self: *Self) !void { + var stream = &self.stream; + _ = try stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); + try stream.flush(); +} - pub fn send_finished(self: *Self) !void { - var stream = &self.stream; - const verify_data = switch (stream.cipher.handshake) { - inline .aes_256_gcm_sha384, - => |v| brk: { - const T = @TypeOf(v); - const secret = v.server_finished_key; - const transcript_hash = stream.transcript_hash.?.peek(); +pub fn send_certificate(self: *Self) !void { + var stream = &self.stream; + _ = try self.stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); + try stream.flush(); +} - break :brk tls.hmac(T.Hmac, transcript_hash, secret); +pub fn send_certificate_verify(self: *Self, verify: Command.CertificateVerify) !void { + var stream = &self.stream; + + const digest = stream.transcript_hash.?.peek(); + const sig_content = tls.sigContent(digest); + + const key = self.options.certificate_key; + const cert_buf = Certificate{ .buffer = self.options.certificate.entries[0].data, .index = 0 }; + const cert = cert_buf.parse() catch return stream.writeError(.bad_certificate); + + const signature: []const u8 = switch (verify.scheme) { + // inline .ecdsa_secp256r1_sha256, + // .ecdsa_secp384r1_sha384, + // => |comptime_scheme| { + // if (cert.pub_key_algo != .X9_62_id_ecPublicKey) + // return stream.writeError(.bad_certificate); + // const Ecdsa = comptime_scheme.Ecdsa(); + // const sig = Ecdsa.Signature.fromDer(sig_bytes) catch + // return stream.writeError(.decode_error); + // const key = Ecdsa.PublicKey.fromSec1(cert.pubKey()) catch + // return stream.writeError(.decode_error); + // sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); + // }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| brk: { + if (cert.pub_key_algo != .rsaEncryption) + return stream.writeError(.bad_certificate); + + const Hash = comptime_scheme.Hash(); + const rsa = Certificate.rsa; + // if (!std.mem.eql(u8, cert.pubKey(), key.public)) + // return stream.writeError(.bad_certificate); + + switch (key.public.n.bits() / 8) { + inline 128, 256, 512 => |modulus_length| { + break :brk &(rsa.PSSSignature.sign( + modulus_length, + sig_content, + Hash, + key, + verify.salt[0..Hash.digest_length].*, + ) catch return stream.writeError(.bad_certificate)); }, - else => return stream.writeError(.illegal_parameter), - }; - _ = try stream.write(tls.Handshake, .{ .finished = &verify_data }); - try stream.flush(); - } - - pub fn recv_finished(self: *Self) !void { - var stream = &self.stream; - var reader = stream.reader(); + else => return stream.writeError(.bad_certificate), + } + }, + // inline .ed25519 => |comptime_scheme| { + // if (cert.pub_key_algo != .curveEd25519) + // return stream.writeError(.bad_certificate); + // const Eddsa = comptime_scheme.Eddsa(); + // if (sig_content.len != Eddsa.Signature.encoded_length) + // return stream.writeError(.decode_error); + // const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); + // if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) + // return stream.writeError(.decode_error); + // const key = Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*) catch + // return stream.writeError(.bad_certificate); + // sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); + // }, + else => { + return stream.writeError(.bad_certificate); + }, + }; - const handshake_hash = stream.transcript_hash.?.peek(); + _ = try self.stream.write(tls.Handshake, .{ .certificate_verify = tls.CertificateVerify{ + .algorithm = .rsa_pss_rsae_sha256, + .signature = signature, + } }); + try stream.flush(); +} - const application_cipher = tls.ApplicationCipher.init( - stream.cipher.handshake, - handshake_hash, - ); +pub fn send_finished(self: *Self) !void { + var stream = &self.stream; + const verify_data = switch (stream.cipher.handshake) { + inline .aes_256_gcm_sha384, + => |v| brk: { + const T = @TypeOf(v); + const secret = v.server_finished_key; + const transcript_hash = stream.transcript_hash.?.peek(); + + break :brk tls.hmac(T.Hmac, transcript_hash, secret); + }, + else => return stream.writeError(.illegal_parameter), + }; + _ = try stream.write(tls.Handshake, .{ .finished = &verify_data }); + try stream.flush(); +} - const expected = switch (stream.cipher.handshake) { - inline else => |p| brk: { - const P = @TypeOf(p); - const digest = stream.transcript_hash.?.peek(); - break :brk &tls.hmac(P.Hmac, digest, p.client_finished_key); - }, - }; +pub fn recv_finished(self: *Self) !void { + var stream = &self.stream; + var reader = stream.reader(); - try stream.expectInnerPlaintext(.handshake, .finished); - const actual = stream.view; - try reader.skipBytes(stream.view.len, .{}); + const handshake_hash = stream.transcript_hash.?.peek(); - if (!mem.eql(u8, expected, actual)) return stream.writeError(.decode_error); + const application_cipher = tls.ApplicationCipher.init( + stream.cipher.handshake, + handshake_hash, + ); - stream.content_type = .application_data; - stream.handshake_type = null; - stream.cipher = .{ .application = application_cipher }; - stream.transcript_hash = null; - } + const expected = switch (stream.cipher.handshake) { + inline else => |p| brk: { + const P = @TypeOf(p); + const digest = stream.transcript_hash.?.peek(); + break :brk &tls.hmac(P.Hmac, digest, p.client_finished_key); + }, }; + + try stream.expectInnerPlaintext(.handshake, .finished); + const actual = stream.view; + try reader.skipBytes(stream.view.len, .{}); + + if (!mem.eql(u8, expected, actual)) return stream.writeError(.decode_error); + + stream.content_type = .application_data; + stream.handshake_type = null; + stream.cipher = .{ .application = application_cipher }; + stream.transcript_hash = null; } pub const Options = struct { /// List of potential cipher suites in descending order of preference. cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, certificate: tls.Certificate, - certificate_key: []const u8, + certificate_key: Certificate.rsa.PrivateKey, }; /// A command to send or receive a single message. Allows testing `advance` on a single thread. pub const Command = union(enum) { recv_hello: void, send_hello: ClientHello, - send_change_cipher_spec: void, - send_encrypted_extensions: void, - send_certificate: void, - send_certificate_verify: void, + send_change_cipher_spec: tls.SignatureScheme, + send_encrypted_extensions: tls.SignatureScheme, + send_certificate: tls.SignatureScheme, + send_certificate_verify: CertificateVerify, send_finished: void, recv_finished: void, none: void, + + pub const CertificateVerify = struct { + scheme: tls.SignatureScheme, + salt: [tls.MultiHash.max_digest_len]u8, + }; }; pub const ClientHello = struct { @@ -349,7 +412,7 @@ pub const ClientHello = struct { session_id: [32]u8, cipher_suite: tls.CipherSuite, key_share: tls.KeyShare, - sig_scheme: ?tls.SignatureScheme, + sig_scheme: tls.SignatureScheme, server_random: [32]u8, /// active member MUST match `key_share` server_pair: tls.KeyPair, diff --git a/lib/std/crypto/tls/Stream.zig b/lib/std/crypto/tls/Stream.zig new file mode 100644 index 000000000000..f31d1d8458a6 --- /dev/null +++ b/lib/std/crypto/tls/Stream.zig @@ -0,0 +1,547 @@ +//! Abstraction over TLS record layer (RFC 8446 S5). +//! +//! After writing must call `flush` before reading or contents will not be written. +//! +//! Handles: +//! * Fragmentation +//! * Encryption and decryption of handshake and application data messages +//! * Reading and writing prefix length arrays +//! * Reading and writing TLS types +//! * Alerts +const std = @import("../../std.zig"); +const tls = std.crypto.tls; + +stream: std.io.AnyStream, +/// Used for both reading and writing. +/// Stores plaintext or briefly ciphertext, but not Plaintext headers. +buffer: [fragment_size]u8 = undefined, +/// Unread or unwritten view of `buffer`. May contain multiple handshakes. +view: []const u8 = "", + +/// When sending this is the record type that will be flushed. +/// When receiving this is the next fragment's expected record type. +content_type: ContentType = .handshake, +/// When sending this is the flushed version. +version: Version = .tls_1_0, +/// When receiving a handshake message will be expected with this type. +handshake_type: ?HandshakeType = .client_hello, + +/// Used to decrypt .application_data messages. +/// Used to encrypt messages that aren't alert or change_cipher_spec. +cipher: Cipher = .none, + +/// True when we send or receive a close_notify alert. +closed: bool = false, + +/// True if we're being used as a client. This changes: +/// * Certain shared struct formats (like Extension) +/// * Which ciphers are used for encoding/decoding handshake and application messages. +is_client: bool, + +/// When > 0 won't actually do anything with writes. Used to discover prefix lengths. +nocommit: usize = 0, + +/// Client and server implementations can set this. While set sent or received handshake messages +/// will update the hash. +transcript_hash: ?*MultiHash, + +const Self = @This(); +const ContentType = tls.ContentType; +const Version = tls.Version; +const HandshakeType = tls.HandshakeType; +const MultiHash = tls.MultiHash; +const Plaintext = tls.Plaintext; +const HandshakeCipher = tls.HandshakeCipher; +const ApplicationCipher = tls.ApplicationCipher; +const Alert = tls.Alert; +const Extension = tls.Extension; + +const fragment_size = Plaintext.max_length; + +const Cipher = union(enum) { + none: void, + application: ApplicationCipher, + handshake: HandshakeCipher, +}; + +pub const ReadError = anyerror || tls.Error || error{EndOfStream}; +pub const WriteError = anyerror || error{TlsEncodeError}; + +fn ciphertextOverhead(self: Self) usize { + return switch (self.cipher) { + inline .application, .handshake => |c| switch (c) { + inline else => |t| @TypeOf(t).AEAD.tag_length + @sizeOf(ContentType), + }, + else => 0, + }; +} + +fn maxFragmentSize(self: Self) usize { + return self.buffer.len - self.ciphertextOverhead(); +} + +const EncryptionMethod = enum { none, handshake, application }; +fn encryptionMethod(self: Self, content_type: ContentType) EncryptionMethod { + switch (content_type) { + .alert, .change_cipher_spec => {}, + else => { + if (self.cipher == .application) return .application; + if (self.cipher == .handshake) return .handshake; + }, + } + return .none; +} + +pub fn flush(self: *Self) WriteError!void { + if (self.view.len == 0) return; + if (self.transcript_hash) |t| { + if (self.content_type == .handshake) t.update(self.view); + } + + var plaintext = Plaintext{ + .type = self.content_type, + .version = self.version, + .len = @intCast(self.view.len), + }; + + var header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); + var aead: []const u8 = ""; + switch (self.cipher) { + .none => {}, + inline .application, .handshake => |*cipher| { + plaintext.type = .application_data; + plaintext.len += @intCast(self.ciphertextOverhead()); + header = Encoder.encode(Plaintext, plaintext); + switch (cipher.*) { + inline else => |*c| { + std.debug.assert(self.view.ptr == &self.buffer); + self.buffer[self.view.len] = @intFromEnum(self.content_type); + self.view = self.buffer[0 .. self.view.len + 1]; + aead = &c.encrypt(self.view, &header, self.is_client, @constCast(self.view)); + }, + } + }, + } + + // TODO: contiguous buffer management + try self.stream.writer().writeAll(&header); + try self.stream.writer().writeAll(self.view); + try self.stream.writer().writeAll(aead); + self.view = self.buffer[0..0]; +} + +/// Flush a change cipher spec message to the underlying stream. +pub fn changeCipherSpec(self: *Self) WriteError!void { + self.version = .tls_1_2; + + const plaintext = Plaintext{ + .type = .change_cipher_spec, + .version = self.version, + .len = 1, + }; + const msg = [_]u8{1}; + const header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); + // TODO: contiguous buffer management + try self.stream.writer().writeAll(&header); + try self.stream.writer().writeAll(&msg); +} + +/// Write an alert to stream and call `close_notify` after. Returns Zig error. +pub fn writeError(self: *Self, err: Alert.Description) tls.Error { + const alert = Alert{ .level = .fatal, .description = err }; + + self.view = self.buffer[0..0]; + self.content_type = .alert; + _ = self.write(Alert, alert) catch {}; + self.flush() catch {}; + + self.close(); + @panic("TODO: fixme"); + // return err.toError(); +} + +pub fn close(self: *Self) void { + const alert = Alert{ .level = .fatal, .description = .close_notify }; + _ = self.write(Alert, alert) catch {}; + self.content_type = .alert; + self.flush() catch {}; + self.closed = true; +} + +/// Write bytes to `stream`, potentially flushing once `self.buffer` is full. +pub fn writeBytes(self: *Self, bytes: []const u8) WriteError!usize { + if (self.nocommit > 0) return bytes.len; + + const available = self.buffer.len - self.view.len; + const to_consume = bytes[0..@min(available, bytes.len)]; + + @memcpy(self.buffer[self.view.len..][0..bytes.len], to_consume); + self.view = self.buffer[0 .. self.view.len + to_consume.len]; + + if (self.view.len == self.buffer.len) try self.flush(); + + return to_consume.len; +} + +pub fn writeAll(self: *Self, bytes: []const u8) WriteError!usize { + var index: usize = 0; + while (index != bytes.len) { + index += try self.writeBytes(bytes[index..]); + } + return index; +} + +pub fn writeArray(self: *Self, comptime PrefixT: type, comptime T: type, values: []const T) WriteError!usize { + var res: usize = 0; + for (values) |v| res += self.length(T, v); + + if (PrefixT != void) { + if (res > std.math.maxInt(PrefixT)) { + self.close(); + return error.TlsEncodeError; // Prefix length overflow + } + res += try self.write(PrefixT, @intCast(res)); + } + + for (values) |v| _ = try self.write(T, v); + + return res; +} + +pub fn write(self: *Self, comptime T: type, value: T) WriteError!usize { + switch (@typeInfo(T)) { + .Int, .Enum => { + const encoded = Encoder.encode(T, value); + return try self.writeAll(&encoded); + }, + .Struct, .Union => { + return try T.write(value, self); + }, + .Void => return 0, + else => @compileError("cannot write " ++ @typeName(T)), + } +} + +pub fn length(self: *Self, comptime T: type, value: T) usize { + if (T == void) return 0; + self.nocommit += 1; + defer self.nocommit -= 1; + return self.write(T, value) catch unreachable; +} + +pub fn arrayLength( + self: *Self, + comptime PrefixT: type, + comptime T: type, + values: []const T, +) usize { + var res: usize = if (PrefixT == void) 0 else @divExact(@typeInfo(PrefixT).Int.bits, 8); + for (values) |v| res += self.length(T, v); + return res; +} + +/// Reads bytes from `view`, potentially reading more fragments from `stream`. +/// +/// A return value of 0 indicates EOF. +pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { + // > Any data received after a closure alert has been received MUST be ignored. + if (self.eof()) return 0; + + if (self.view.len == 0) try self.expectInnerPlaintext(self.content_type, self.handshake_type); + + var bytes_read: usize = 0; + + for (buffers) |b| { + var bytes_read_buffer: usize = 0; + while (bytes_read_buffer != b.iov_len) { + const to_read = @min(b.iov_len, self.view.len); + if (to_read == 0) return bytes_read; + + @memcpy(b.iov_base[0..to_read], self.view[0..to_read]); + + self.view = self.view[to_read..]; + bytes_read_buffer += to_read; + bytes_read += bytes_read_buffer; + } + } + + return bytes_read; +} + +/// Reads bytes from `view`, potentially reading more fragments from `stream`. +/// A return value of 0 indicates EOF. +pub fn readBytes(self: *Self, buf: []u8) ReadError!usize { + const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; + return try self.readv(&buffers); +} + +/// Reads plaintext from `stream` into `buffer` and updates `view`. +/// Skips non-fatal alert and change_cipher_spec messages. +/// Will decrypt according to `encryptionMethod` if receiving application_data message. +pub fn readPlaintext(self: *Self) ReadError!Plaintext { + std.debug.assert(self.view.len == 0); // last read should have completed + var plaintext_bytes: [Plaintext.size]u8 = undefined; + var n_read: usize = 0; + + while (true) { + n_read = try self.stream.reader().readAll(&plaintext_bytes); + if (n_read != plaintext_bytes.len) return self.writeError(.decode_error); + + var res = Plaintext.init(plaintext_bytes); + if (res.len > Plaintext.max_length) return self.writeError(.record_overflow); + + self.view = self.buffer[0..res.len]; + n_read = try self.stream.reader().readAll(@constCast(self.view)); + if (n_read != res.len) return self.writeError(.decode_error); + + const encryption_method = self.encryptionMethod(res.type); + if (encryption_method != .none) { + if (res.len < self.ciphertextOverhead()) return self.writeError(.decode_error); + + switch (self.cipher) { + inline .handshake, .application => |*cipher| { + switch (cipher.*) { + inline else => |*c| { + const C = @TypeOf(c.*); + const tag_len = C.AEAD.tag_length; + + const ciphertext = self.view[0 .. self.view.len - tag_len]; + const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; + const out: []u8 = @constCast(self.view[0..ciphertext.len]); + c.decrypt(ciphertext, &plaintext_bytes, tag, self.is_client, out) catch + return self.writeError(.bad_record_mac); + const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); + if (padding_start) |s| { + res.type = @enumFromInt(self.view[s]); + self.view = self.view[0..s]; + } else { + return self.writeError(.decode_error); + } + }, + } + }, + else => unreachable, + } + } + + switch (res.type) { + .alert => { + const level = try self.read(Alert.Level); + const description = try self.read(Alert.Description); + std.log.debug("TLS alert {} {}", .{ level, description }); + + if (description == .close_notify) { + self.closed = true; + return res; + } + if (level == .fatal) return self.writeError(.unexpected_message); + }, + // > An implementation may receive an unencrypted record of type + // > change_cipher_spec consisting of the single byte value 0x01 at any + // > time after the first ClientHello message has been sent or received + // > and before the peer's Finished message has been received and MUST + // > simply drop it without further processing. + .change_cipher_spec => { + if (!std.mem.eql(u8, self.view, &[_]u8{1})) { + return self.writeError(.unexpected_message); + } + }, + else => { + return res; + }, + } + } +} + +pub fn readInnerPlaintext(self: *Self) ReadError!InnerPlaintext { + var res: InnerPlaintext = .{ + .type = self.content_type, + .handshake_type = if (self.handshake_type) |h| h else undefined, + .len = undefined, + }; + if (self.view.len == 0) { + const plaintext = try self.readPlaintext(); + res.type = plaintext.type; + res.len = plaintext.len; + + self.content_type = res.type; + } + + if (res.type == .handshake) { + if (self.transcript_hash) |t| t.update(self.view[0..4]); + res.handshake_type = try self.read(HandshakeType); + res.len = try self.read(u24); + if (self.transcript_hash) |t| t.update(self.view[0..res.len]); + + self.handshake_type = res.handshake_type; + } + + return res; +} + +pub fn expectInnerPlaintext( + self: *Self, + expected_content: ContentType, + expected_handshake: ?HandshakeType, +) ReadError!void { + const inner_plaintext = try self.readInnerPlaintext(); + if (expected_content != inner_plaintext.type) { + std.debug.print("expected {} got {}\n", .{ expected_content, inner_plaintext }); + return self.writeError(.unexpected_message); + } + if (expected_handshake) |expected| { + if (expected != inner_plaintext.handshake_type) return self.writeError(.decode_error); + } +} + +pub fn read(self: *Self, comptime T: type) ReadError!T { + comptime std.debug.assert(@sizeOf(T) < fragment_size); + switch (@typeInfo(T)) { + .Int => return self.reader().readInt(T, .big) catch |err| switch (err) { + error.EndOfStream => return self.writeError(.decode_error), + else => |e| return e, + }, + .Enum => |info| { + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + const int = try self.read(info.tag_type); + return @enumFromInt(int); + }, + else => { + return T.read(self) catch |err| switch (err) { + error.TlsUnexpectedMessage => return self.writeError(.unexpected_message), + error.TlsBadRecordMac => return self.writeError(.bad_record_mac), + error.TlsRecordOverflow => return self.writeError(.record_overflow), + error.TlsHandshakeFailure => return self.writeError(.handshake_failure), + error.TlsBadCertificate => return self.writeError(.bad_certificate), + error.TlsUnsupportedCertificate => return self.writeError(.unsupported_certificate), + error.TlsCertificateRevoked => return self.writeError(.certificate_revoked), + error.TlsCertificateExpired => return self.writeError(.certificate_expired), + error.TlsCertificateUnknown => return self.writeError(.certificate_unknown), + error.TlsIllegalParameter => return self.writeError(.illegal_parameter), + error.TlsUnknownCa => return self.writeError(.unknown_ca), + error.TlsAccessDenied => return self.writeError(.access_denied), + error.TlsDecodeError => return self.writeError(.decode_error), + error.TlsDecryptError => return self.writeError(.decrypt_error), + error.TlsProtocolVersion => return self.writeError(.protocol_version), + error.TlsInsufficientSecurity => return self.writeError(.insufficient_security), + error.TlsInternalError => return self.writeError(.internal_error), + error.TlsInappropriateFallback => return self.writeError(.inappropriate_fallback), + error.TlsMissingExtension => return self.writeError(.missing_extension), + error.TlsUnsupportedExtension => return self.writeError(.unsupported_extension), + error.TlsUnrecognizedName => return self.writeError(.unrecognized_name), + error.TlsBadCertificateStatusResponse => return self.writeError(.bad_certificate_status_response), + error.TlsUnknownPskIdentity => return self.writeError(.unknown_psk_identity), + error.TlsCertificateRequired => return self.writeError(.certificate_required), + error.TlsNoApplicationProtocol => return self.writeError(.no_application_protocol), + error.TlsUnknown => |e| { + self.close(); + return e; + }, + else => return self.writeError(.decode_error), + }; + }, + } +} + +fn Iterator(comptime T: type) type { + return struct { + stream: *Self, + end: usize, + + pub fn next(self: *@This()) ReadError!?T { + const cur_offset = self.stream.buffer.len - self.stream.view.len; + if (cur_offset > self.end) return null; + return try self.stream.read(T); + } + }; +} + +pub fn iterator(self: *Self, comptime Len: type, comptime Tag: type) ReadError!Iterator(Tag) { + const offset = self.buffer.len - self.view.len; + const len = try self.read(Len); + return Iterator(Tag){ + .stream = self, + .end = offset + len, + }; +} + +pub fn extensions(self: *Self) ReadError!Iterator(Extension.Header) { + return self.iterator(u16, Extension.Header); +} + +pub fn eof(self: Self) bool { + return self.closed and self.view.len == 0; +} + +pub const Reader = std.io.Reader(*Self, ReadError, readBytes); +pub const Writer = std.io.Writer(*Self, WriteError, writeBytes); + +pub fn reader(self: *Self) Reader { + return .{ .context = self }; +} + +pub fn writer(self: *Self) Writer { + return .{ .context = self }; +} + +const Encoder = struct { + fn RetType(comptime T: type) type { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => return [1]u8, + 16 => return [2]u8, + 24 => return [3]u8, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return RetType(info.tag_type); + }, + .Struct => |info| { + var len: usize = 0; + inline for (info.fields) |f| len += @typeInfo(RetType(f.type)).Array.len; + return [len]u8; + }, + else => @compileError("don't know how to encode " ++ @tagName(T)), + } + } + fn encode(comptime T: type, value: T) RetType(T) { + return switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => .{value}, + 16 => .{ + @as(u8, @truncate(value >> 8)), + @as(u8, @truncate(value)), + }, + 24 => .{ + @as(u8, @truncate(value >> 16)), + @as(u8, @truncate(value >> 8)), + @as(u8, @truncate(value)), + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| encode(info.tag_type, @intFromEnum(value)), + .Struct => |info| brk: { + const Ret = RetType(T); + + var offset: usize = 0; + var res: Ret = undefined; + inline for (info.fields) |f| { + const encoded = encode(f.type, @field(value, f.name)); + @memcpy(res[offset..][0..encoded.len], &encoded); + offset += encoded.len; + } + + break :brk res; + }, + else => @compileError("cannot encode type " ++ @typeName(T)), + }; + } +}; + +const InnerPlaintext = struct { + type: ContentType, + handshake_type: HandshakeType, + len: u24, +}; + diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 00ee206c7757..5f45bbaae390 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1,4 +1,4 @@ -//! HTTP(S) Client implementation. +//! Blocking HTTP(S) client //! //! Connections are opened in a thread-safe manner, but individual Requests are not. //! @@ -17,19 +17,14 @@ const use_vectors = builtin.zig_backend != .stage2_x86_64; const Client = @This(); const proto = @import("protocol.zig"); - -pub const disable_tls = std.options.http_disable_tls; -const TlsClient = if (disable_tls) void else std.crypto.tls.Client(net.Stream); +const disable_tls = std.options.http_disable_tls; +const tls = std.crypto.tls; +const TlsClient = if (disable_tls) void else std.crypto.tls.Client; /// Used for all client allocations. Must be thread-safe. allocator: Allocator, -ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, -ca_bundle_mutex: std.Thread.Mutex = .{}, - -/// When this is `true`, the next time this client performs an HTTPS request, -/// it will first rescan the system for root certificates. -next_https_rescan_certs: bool = true, +tls_options: (if (disable_tls) void else TlsOptions) = if (disable_tls) {} else .{}, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, @@ -43,6 +38,17 @@ http_proxy: ?*Proxy = null, /// Pointer to externally-owned memory. https_proxy: ?*Proxy = null, +/// tls.ClientOptions minus ones that we set +pub const TlsOptions = struct { + /// Client takes ownership of this field. If empty, will rescan on init. + /// + /// Trusted certificate authority bundle used to authenticate server certificates. + /// When null, server certificate and certificate_verify messages will be skipped (INSECURE). + ca_bundle: ?std.crypto.Certificate.Bundle = .{}, + /// List of cipher suites to advertise in order of descending preference. + cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, +}; + /// A set of linked lists of connections that can be reused. pub const ConnectionPool = struct { mutex: std.Thread.Mutex = .{}, @@ -150,8 +156,6 @@ pub const ConnectionPool = struct { pool.mutex.lock(); defer pool.mutex.unlock(); - const next = pool.free.first; - _ = next; while (pool.free_len > new_size) { const popped = pool.free.popFirst() orelse unreachable; pool.free_len -= 1; @@ -191,7 +195,7 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { - stream: net.Stream, + stream: std.io.AnyStream, /// undefined unless protocol is tls. tls_client: *TlsClient, @@ -331,7 +335,7 @@ pub const Connection = struct { return conn.writeAllDirectTls(buffer); } - return conn.stream.writeAll(buffer) catch |err| switch (err) { + return conn.stream.writer().writeAll(buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; @@ -1216,6 +1220,19 @@ pub const Proxy = struct { supports_connect: bool, }; +/// Initializes ca_bundle if it's not null and empty. +pub fn init(client: Client) !Client { + var copy = client; + + if (!disable_tls) { + if (copy.tls_options.ca_bundle) |*bundle| { + if (bundle.bytes.items.len == 0) try bundle.rescan(copy.allocator); + } + } + + return copy; +} + /// Release all associated resources with the client. /// /// All pending requests must be de-initialized and all active connections released @@ -1225,8 +1242,9 @@ pub fn deinit(client: *Client) void { client.connection_pool.deinit(client.allocator); - if (!disable_tls) - client.ca_bundle.deinit(client.allocator); + if (!disable_tls) { + if (client.tls_options.ca_bundle) |*bundle| bundle.deinit(client.allocator); + } client.* = undefined; } @@ -1346,7 +1364,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; - const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { + const tcp_stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { error.ConnectionRefused => return error.ConnectionRefused, error.NetworkUnreachable => return error.NetworkUnreachable, error.ConnectionTimedOut => return error.ConnectionTimedOut, @@ -1357,10 +1375,10 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses, else => return error.UnexpectedConnectFailure, }; - errdefer stream.close(); + errdefer tcp_stream.close(); conn.data = .{ - .stream = stream, + .stream = tcp_stream, .tls_client = undefined, .protocol = protocol, @@ -1376,7 +1394,8 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec errdefer client.allocator.destroy(conn.data.tls_client); conn.data.tls_client.* = TlsClient.init(&conn.data.stream, .{ - .ca_bundle = client.ca_bundle, + .ca_bundle = client.tls_options.ca_bundle, + .cipher_suites = client.tls_options.cipher_suites, .host = host, // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. @@ -1554,7 +1573,6 @@ pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendE UnsupportedUrlScheme, UriMissingHost, - CertificateBundleLoadFailure, UnsupportedTransferEncoding, }; @@ -1645,18 +1663,6 @@ pub fn open( const host = uri.host orelse return error.UriMissingHost; - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { - if (disable_tls) unreachable; - - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); - } - } - const conn = options.connection orelse try client.connect(host, port, protocol); var req: Request = .{ diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 5290241b6e04..2911d8591b2c 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,5 +1,21 @@ -//! Blocking HTTP server implementation. +//! Blocking HTTP(s) server +//! //! Handles a single connection's lifecycle. +//! +//! TLS support may be disabled via `std.options.http_disable_tls`. + +const std = @import("../std.zig"); +const http = std.http; +const mem = std.mem; +const net = std.net; +const Uri = std.Uri; +const assert = std.debug.assert; +const testing = std.testing; + +const Server = @This(); + +pub const disable_tls = std.options.http_disable_tls; +const TlsServer = if (disable_tls) void else std.crypto.tls.Server(net.Stream); connection: net.Server.Connection, /// Keeps track of whether the Server is ready to accept a new request on the @@ -1137,12 +1153,3 @@ fn rebase(s: *Server, index: usize) void { s.read_buffer_len = index + leftover.len; } -const std = @import("../std.zig"); -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const assert = std.debug.assert; -const testing = std.testing; - -const Server = @This(); diff --git a/lib/std/io.zig b/lib/std/io.zig index 64560c6fe3bf..b402b8185e4f 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -356,6 +356,53 @@ pub fn GenericWriter( }; } +pub fn GenericStream( + comptime Context: type, + comptime ReadError: type, + /// Returns the number of bytes read. It may be less than buffer.len. + /// If the number of bytes read is 0, it means end of stream. + /// End of stream is not an error condition. + comptime readFn: fn (context: Context, buffer: []u8) ReadError!usize, + comptime WriteError: type, + comptime writeFn: fn (context: Context, bytes: []const u8) WriteError!usize, + comptime closeFn: fn (context: Context) void, +) type { + return struct { + context: Context, + + const ReaderType = GenericReader(Context, ReadError, readFn); + const WriterType = GenericWriter(Context, WriteError, writeFn); + + const Self = @This(); + + pub inline fn reader(self: *const Self) ReaderType { + return .{ .context = self.context }; + } + + pub inline fn writer(self: *const Self) WriterType { + return .{ .context = self.context }; + } + + pub inline fn close(self: *const Self) void { + closeFn(self.context); + } + + pub inline fn any(self: *const Self) AnyStream { + return .{ + .context = @ptrCast(&self.context), + .readFn = self.reader().any().readFn, + .writeFn = self.writer().any().writeFn, + .closeFn = typeErasedCloseFn, + }; + } + + fn typeErasedCloseFn(context: *const anyopaque) void { + const ptr: *const Context = @alignCast(@ptrCast(context)); + return closeFn(ptr.*); + } + }; +} + /// Deprecated; consider switching to `AnyReader` or use `GenericReader` /// to use previous API. pub const Reader = GenericReader; @@ -365,6 +412,7 @@ pub const Writer = GenericWriter; pub const AnyReader = @import("io/Reader.zig"); pub const AnyWriter = @import("io/Writer.zig"); +pub const AnyStream = @import("io/Stream.zig"); pub const SeekableStream = @import("io/seekable_stream.zig").SeekableStream; diff --git a/lib/std/io/Stream.zig b/lib/std/io/Stream.zig new file mode 100644 index 000000000000..37be7ae0b1ab --- /dev/null +++ b/lib/std/io/Stream.zig @@ -0,0 +1,36 @@ +const std = @import("../std.zig"); +const assert = std.debug.assert; +const mem = std.mem; +const os = std.os; + +context: *const anyopaque, +writeFn: *const fn (context: *const anyopaque, bytes: []const u8) anyerror!usize, +readFn: *const fn (context: *const anyopaque, buffer: []u8) anyerror!usize, +closeFn: *const fn (context: *const anyopaque) void, + +const Self = @This(); +pub const Error = anyerror; + +pub fn write(self: Self, bytes: []const u8) anyerror!usize { + return self.writeFn(self.context, bytes); +} + +/// Returns the number of bytes read. It may be less than buffer.len. +/// If the number of bytes read is 0, it means end of stream. +/// End of stream is not an error condition. +pub fn read(self: Self, buffer: []u8) anyerror!usize { + return self.readFn(self.context, buffer); +} + +pub fn close(self: Self) void { + return self.closeFn(self.context); +} + +pub fn reader(self: Self) std.io.AnyReader { + return .{ .context = self.context, .readFn = self.readFn }; +} + +pub fn writer(self: Self) std.io.AnyWriter { + return .{ .context = self.context, .writeFn = self.writeFn }; +} + diff --git a/lib/std/net.zig b/lib/std/net.zig index e68adc4207ff..8b6a805aa1b4 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -264,6 +264,10 @@ pub const Address = extern union { try posix.getsockname(sockfd, &s.listen_address.any, &socklen); return s; } + + // The returned `Server` has an open `stream`. + // pub fn listenTls(address: Address, options: ListenOptions) ListenError!Server { + // } }; pub const Ip4Address = extern struct { @@ -1801,8 +1805,9 @@ pub const Stream = struct { pub const ReadError = os.ReadError; pub const WriteError = os.WriteError; - pub const Reader = io.Reader(Stream, ReadError, read); - pub const Writer = io.Writer(Stream, WriteError, write); + pub const Reader = io.GenericReader(Stream, ReadError, read); + pub const Writer = io.GenericWriter(Stream, WriteError, write); + pub const GenericStream = io.GenericStream(Stream, ReadError, read, WriteError, write, close); pub fn reader(self: Stream) Reader { return .{ .context = self }; @@ -1812,6 +1817,10 @@ pub const Stream = struct { return .{ .context = self }; } + pub fn stream(self: Stream) GenericStream { + return .{ .context = self }; + } + pub fn read(self: Stream, buffer: []u8) ReadError!usize { if (builtin.os.tag == .windows) { return os.windows.ReadFile(self.handle, buffer, null); @@ -1831,29 +1840,6 @@ pub const Stream = struct { return os.readv(s.handle, iovecs); } - /// Returns the number of bytes read. If the number read is smaller than - /// `buffer.len`, it means the stream reached the end. Reaching the end of - /// a stream is not an error condition. - pub fn readAll(s: Stream, buffer: []u8) ReadError!usize { - return readAtLeast(s, buffer, buffer.len); - } - - /// Returns the number of bytes read, calling the underlying read function - /// the minimal number of times until the buffer has at least `len` bytes - /// filled. If the number read is less than `len` it means the stream - /// reached the end. Reaching the end of the stream is not an error - /// condition. - pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize { - assert(len <= buffer.len); - var index: usize = 0; - while (index < len) { - const amt = try s.read(buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's /// file system thread instead of non-blocking. It needs to be reworked to properly /// use non-blocking I/O. @@ -1865,13 +1851,6 @@ pub const Stream = struct { return os.write(self.handle, buffer); } - pub fn writeAll(self: Stream, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try self.write(bytes[index..]); - } - } - /// See https://github.com/ziglang/zig/issues/7699 /// See equivalent function: `std.fs.File.writev`. pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize { @@ -1901,10 +1880,10 @@ pub const Stream = struct { pub const Server = struct { listen_address: Address, - stream: std.net.Stream, + stream: net.Stream, pub const Connection = struct { - stream: std.net.Stream, + stream: net.Stream, address: Address, }; From 24891eb8ac7e2a1df63d13a7f01d20d267bb45d8 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Mon, 18 Mar 2024 18:32:28 -0400 Subject: [PATCH 10/17] replace Reader `read` and Writer `write` with `readv` and `writev`, respectively. add std.io.AnyStream. add tls server signature suites. temporarily remove std.options.http_disable_tls --- lib/std/array_list.zig | 45 ++++-- lib/std/bounded_array.zig | 13 +- lib/std/compress.zig | 30 ++-- lib/std/compress/flate/deflate.zig | 20 ++- lib/std/compress/flate/inflate.zig | 17 ++- lib/std/compress/lzma.zig | 35 +++-- lib/std/compress/xz.zig | 6 +- lib/std/compress/zstandard.zig | 6 +- lib/std/compress/zstandard/readers.zig | 7 +- lib/std/crypto/25519/ed25519.zig | 6 +- lib/std/crypto/Certificate.zig | 14 +- lib/std/crypto/ecdsa.zig | 4 +- lib/std/crypto/sha2.zig | 12 +- lib/std/crypto/tls.zig | 82 +++++----- lib/std/crypto/tls/Client.zig | 38 ++--- lib/std/crypto/tls/Server.zig | 201 +++++++++++++++++------- lib/std/crypto/tls/Stream.zig | 132 ++++++++-------- lib/std/fifo.zig | 25 ++- lib/std/fs/File.zig | 4 +- lib/std/http/Client.zig | 203 ++++++++++--------------- lib/std/http/Server.zig | 77 ++++------ lib/std/http/protocol.zig | 4 +- lib/std/http/test.zig | 73 ++++----- lib/std/io.zig | 61 +++++--- lib/std/io/Reader.zig | 32 ++-- lib/std/io/Stream.zig | 26 ++-- lib/std/io/Writer.zig | 30 +++- lib/std/io/buffered_reader.zig | 40 ++--- lib/std/io/buffered_tee.zig | 52 ++++--- lib/std/io/buffered_writer.zig | 27 ++-- lib/std/io/counting_reader.zig | 10 +- lib/std/io/counting_writer.zig | 6 +- lib/std/io/fixed_buffer_stream.zig | 44 +++--- lib/std/io/limited_reader.zig | 13 +- lib/std/io/peek_stream.zig | 20 ++- lib/std/io/stream_source.zig | 18 +-- lib/std/json/stringify_test.zig | 6 +- lib/std/net.zig | 93 ++++------- lib/std/net/test.zig | 22 +-- lib/std/os.zig | 2 +- lib/std/std.zig | 7 - lib/std/tar.zig | 19 ++- lib/std/zig/render.zig | 15 +- 43 files changed, 873 insertions(+), 724 deletions(-) diff --git a/lib/std/array_list.zig b/lib/std/array_list.zig index ff2307e8124d..ce2df2098f81 100644 --- a/lib/std/array_list.zig +++ b/lib/std/array_list.zig @@ -344,7 +344,7 @@ pub fn ArrayListAligned(comptime T: type, comptime alignment: ?u29) type { @compileError("The Writer interface is only defined for ArrayList(u8) " ++ "but the given type is ArrayList(" ++ @typeName(T) ++ ")") else - std.io.Writer(*Self, Allocator.Error, appendWrite); + std.io.Writer(*Self, Allocator.Error, appendWritev); /// Initializes a Writer which will append to the list. pub fn writer(self: *Self) Writer { @@ -354,9 +354,13 @@ pub fn ArrayListAligned(comptime T: type, comptime alignment: ?u29) type { /// Same as `append` except it returns the number of bytes written, which is always the same /// as `m.len`. The purpose of this function existing is to match `std.io.Writer` API. /// Invalidates element pointers if additional memory is needed. - fn appendWrite(self: *Self, m: []const u8) Allocator.Error!usize { - try self.appendSlice(m); - return m.len; + fn appendWritev(self: *Self, iov: []std.os.iovec_const) Allocator.Error!usize { + var written: usize = 0; + for (iov) |v| { + try self.appendSlice(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } /// Append a value to the list `n` times. @@ -930,7 +934,7 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ @compileError("The Writer interface is only defined for ArrayList(u8) " ++ "but the given type is ArrayList(" ++ @typeName(T) ++ ")") else - std.io.Writer(WriterContext, Allocator.Error, appendWrite); + std.io.Writer(WriterContext, Allocator.Error, appendWritev); /// Initializes a Writer which will append to the list. pub fn writer(self: *Self, allocator: Allocator) Writer { @@ -941,12 +945,16 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ /// which is always the same as `m.len`. The purpose of this function /// existing is to match `std.io.Writer` API. /// Invalidates element pointers if additional memory is needed. - fn appendWrite(context: WriterContext, m: []const u8) Allocator.Error!usize { - try context.self.appendSlice(context.allocator, m); - return m.len; + fn appendWritev(context: WriterContext, iov: []std.os.iovec_const) Allocator.Error!usize { + var written: usize = 0; + for (iov) |v| { + try context.self.appendSlice(context.allocator, v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } - pub const FixedWriter = std.io.Writer(*Self, Allocator.Error, appendWriteFixed); + pub const FixedWriter = std.io.Writer(*Self, Allocator.Error, appendWritevFixed); /// Initializes a Writer which will append to the list but will return /// `error.OutOfMemory` rather than increasing capacity. @@ -955,13 +963,18 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ } /// The purpose of this function existing is to match `std.io.Writer` API. - fn appendWriteFixed(self: *Self, m: []const u8) error{OutOfMemory}!usize { - const available_capacity = self.capacity - self.items.len; - if (m.len > available_capacity) - return error.OutOfMemory; - - self.appendSliceAssumeCapacity(m); - return m.len; + fn appendWritevFixed(self: *Self, iov: []std.os.iovec_const) error{OutOfMemory}!usize { + var written: usize = 0; + for (iov) |v| { + const m = v.iov_base[0..v.iov_len]; + const available_capacity = self.capacity - self.items.len; + if (m.len > available_capacity) + return error.OutOfMemory; + + self.appendSliceAssumeCapacity(m); + written += m.len; + } + return written; } /// Append a value to the list `n` times. diff --git a/lib/std/bounded_array.zig b/lib/std/bounded_array.zig index 9867754dcd0a..b5d04cac9d33 100644 --- a/lib/std/bounded_array.zig +++ b/lib/std/bounded_array.zig @@ -271,7 +271,7 @@ pub fn BoundedArrayAligned( @compileError("The Writer interface is only defined for BoundedArray(u8, ...) " ++ "but the given type is BoundedArray(" ++ @typeName(T) ++ ", ...)") else - std.io.Writer(*Self, error{Overflow}, appendWrite); + std.io.Writer(*Self, error{Overflow}, appendWritev); /// Initializes a writer which will write into the array. pub fn writer(self: *Self) Writer { @@ -280,9 +280,14 @@ pub fn BoundedArrayAligned( /// Same as `appendSlice` except it returns the number of bytes written, which is always the same /// as `m.len`. The purpose of this function existing is to match `std.io.Writer` API. - fn appendWrite(self: *Self, m: []const u8) error{Overflow}!usize { - try self.appendSlice(m); - return m.len; + fn appendWritev(self: *Self, iov: []std.os.iovec_const) error{Overflow}!usize { + var written: usize = 0; + for (iov) |v| { + const m = v.iov_base[0..v.iov_len]; + try self.appendSlice(m); + written += m.len; + } + return written; } }; } diff --git a/lib/std/compress.zig b/lib/std/compress.zig index 200489c18ab5..0ac8211f6907 100644 --- a/lib/std/compress.zig +++ b/lib/std/compress.zig @@ -19,12 +19,18 @@ pub fn HashedReader( hasher: HasherType, pub const Error = ReaderType.Error; - pub const Reader = std.io.Reader(*@This(), Error, read); + pub const Reader = std.io.Reader(*@This(), Error, readv); - pub fn read(self: *@This(), buf: []u8) Error!usize { - const amt = try self.child_reader.read(buf); - self.hasher.update(buf[0..amt]); - return amt; + pub fn readv(self: *@This(), iov: []std.os.iovec) Error!usize { + const n_read = try self.child_reader.readv(iov); + var hashed_amt: usize = 0; + for (iov) |v| { + const to_hash = @min(n_read - hashed_amt, v.iov_len); + if (to_hash == 0) break; + self.hasher.update(v.iov_base[0..to_hash]); + hashed_amt += to_hash; + } + return n_read; } pub fn reader(self: *@This()) Reader { @@ -51,10 +57,16 @@ pub fn HashedWriter( pub const Error = WriterType.Error; pub const Writer = std.io.Writer(*@This(), Error, write); - pub fn write(self: *@This(), buf: []const u8) Error!usize { - const amt = try self.child_writer.write(buf); - self.hasher.update(buf[0..amt]); - return amt; + pub fn write(self: *@This(), iov: []std.os.iovec_const) Error!usize { + const n_written = try self.child_writer.writev(iov); + var hashed_amt: usize = 0; + for (iov) |v| { + const to_hash = @min(n_written - hashed_amt, v.iov_len); + if (to_hash == 0) break; + self.hasher.update(v.iov_base[0..to_hash]); + hashed_amt += to_hash; + } + return n_written; } pub fn writer(self: *@This()) Writer { diff --git a/lib/std/compress/flate/deflate.zig b/lib/std/compress/flate/deflate.zig index 794ab02247c3..6f1c744b3fd8 100644 --- a/lib/std/compress/flate/deflate.zig +++ b/lib/std/compress/flate/deflate.zig @@ -354,16 +354,20 @@ fn Deflate(comptime container: Container, comptime WriterType: type, comptime Bl } // Writer interface - - pub const Writer = io.Writer(*Self, Error, write); + pub const Writer = io.Writer(*Self, Error, writev); pub const Error = BlockWriterType.Error; /// Write `input` of uncompressed data. /// See compress. - pub fn write(self: *Self, input: []const u8) !usize { - var fbs = io.fixedBufferStream(input); - try self.compress(fbs.reader()); - return input.len; + pub fn writev(self: *Self, iov: []std.os.iovec_const) !usize { + var written: usize = 0; + for (iov) |v| { + const input = v.iov_base[0..v.iov_len]; + var fbs = io.fixedBufferStream(input); + try self.compress(fbs.reader()); + written += input.len; + } + return written; } pub fn writer(self: *Self) Writer { @@ -558,7 +562,7 @@ test "tokenization" { const cww = cw.writer(); var df = try Deflate(container, @TypeOf(cww), TestTokenWriter).init(cww, .{}); - _ = try df.write(c.data); + _ = try df.writer().write(c.data); try df.flush(); // df.token_writer.show(); @@ -579,6 +583,8 @@ const TestTokenWriter = struct { pos: usize = 0, actual: [128]Token = undefined, + pub const Error = error{}; + pub fn init(_: anytype) Self { return .{}; } diff --git a/lib/std/compress/flate/inflate.zig b/lib/std/compress/flate/inflate.zig index cf23961b2132..93f3a30cea0d 100644 --- a/lib/std/compress/flate/inflate.zig +++ b/lib/std/compress/flate/inflate.zig @@ -325,7 +325,7 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type, comp /// Returns decompressed data from internal sliding window buffer. /// Returned buffer can be any length between 0 and `limit` bytes. 0 /// returned bytes means end of stream reached. With limit=0 returns as - /// much data it can. It newer will be more than 65536 bytes, which is + /// much data it can. It never will be more than 65536 bytes, which is /// size of internal buffer. pub fn get(self: *Self, limit: usize) Error![]const u8 { while (true) { @@ -340,16 +340,19 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type, comp } // Reader interface - - pub const Reader = std.io.Reader(*Self, Error, read); + pub const Reader = std.io.Reader(*Self, Error, readv); /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. - pub fn read(self: *Self, buffer: []u8) Error!usize { - const out = try self.get(buffer.len); - @memcpy(buffer[0..out.len], out); - return out.len; + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + var read: usize = 0; + for (iov) |v| { + const out = try self.get(v.iov_len); + @memcpy(v.iov_base[0..out.len], out); + read += out.len; + } + return read; } pub fn reader(self: *Self) Reader { diff --git a/lib/std/compress/lzma.zig b/lib/std/compress/lzma.zig index ff05bc1c8beb..2bd31b636ebe 100644 --- a/lib/std/compress/lzma.zig +++ b/lib/std/compress/lzma.zig @@ -30,7 +30,7 @@ pub fn Decompress(comptime ReaderType: type) type { Allocator.Error || error{ CorruptInput, EndOfStream, Overflow }; - pub const Reader = std.io.Reader(*Self, Error, read); + pub const Reader = std.io.Reader(*Self, Error, readv); allocator: Allocator, in_reader: ReaderType, @@ -63,23 +63,28 @@ pub fn Decompress(comptime ReaderType: type) type { self.* = undefined; } - pub fn read(self: *Self, output: []u8) Error!usize { + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { const writer = self.to_read.writer(self.allocator); - while (self.to_read.items.len < output.len) { - switch (try self.state.process(self.allocator, self.in_reader, writer, &self.buffer, &self.decoder)) { - .continue_ => {}, - .finished => { - try self.buffer.finish(writer); - break; - }, + var n_read: usize = 0; + for (iov) |v| { + const output = v.iov_base[0..v.iov_len]; + while (self.to_read.items.len < output.len) { + switch (try self.state.process(self.allocator, self.in_reader, writer, &self.buffer, &self.decoder)) { + .continue_ => {}, + .finished => { + try self.buffer.finish(writer); + break; + }, + } } + const input = self.to_read.items; + const n = @min(input.len, output.len); + @memcpy(output[0..n], input[0..n]); + @memcpy(input[0 .. input.len - n], input[n..]); + self.to_read.shrinkRetainingCapacity(input.len - n); + n_read += n; } - const input = self.to_read.items; - const n = @min(input.len, output.len); - @memcpy(output[0..n], input[0..n]); - @memcpy(input[0 .. input.len - n], input[n..]); - self.to_read.shrinkRetainingCapacity(input.len - n); - return n; + return n_read; } }; } diff --git a/lib/std/compress/xz.zig b/lib/std/compress/xz.zig index e844c234ffc8..6458a7e7ca2a 100644 --- a/lib/std/compress/xz.zig +++ b/lib/std/compress/xz.zig @@ -34,7 +34,7 @@ pub fn Decompress(comptime ReaderType: type) type { const Self = @This(); pub const Error = ReaderType.Error || block.Decoder(ReaderType).Error; - pub const Reader = std.io.Reader(*Self, Error, read); + pub const Reader = std.io.Reader(*Self, Error, readv); allocator: Allocator, block_decoder: block.Decoder(ReaderType), @@ -71,7 +71,9 @@ pub fn Decompress(comptime ReaderType: type) type { return .{ .context = self }; } - pub fn read(self: *Self, buffer: []u8) Error!usize { + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; if (buffer.len == 0) return 0; diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index 9092a2d13083..b1b283e5afd4 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -50,7 +50,7 @@ pub fn Decompressor(comptime ReaderType: type) type { OutOfMemory, }; - pub const Reader = std.io.Reader(*Self, Error, read); + pub const Reader = std.io.Reader(*Self, Error, readv); pub fn init(source: ReaderType, options: DecompressorOptions) Self { return .{ @@ -105,7 +105,9 @@ pub fn Decompressor(comptime ReaderType: type) type { return .{ .context = self }; } - pub fn read(self: *Self, buffer: []u8) Error!usize { + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; if (buffer.len == 0) return 0; var size: usize = 0; diff --git a/lib/std/compress/zstandard/readers.zig b/lib/std/compress/zstandard/readers.zig index f95573f77bbf..27590cef9300 100644 --- a/lib/std/compress/zstandard/readers.zig +++ b/lib/std/compress/zstandard/readers.zig @@ -4,7 +4,7 @@ pub const ReversedByteReader = struct { remaining_bytes: usize, bytes: []const u8, - const Reader = std.io.Reader(*ReversedByteReader, error{}, readFn); + const Reader = std.io.Reader(*ReversedByteReader, error{}, readvFn); pub fn init(bytes: []const u8) ReversedByteReader { return .{ @@ -17,7 +17,10 @@ pub const ReversedByteReader = struct { return .{ .context = self }; } - fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize { + fn readvFn(ctx: *ReversedByteReader, iov: []std.os.iovec) !usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; + std.debug.assert(buffer.len > 0); if (ctx.remaining_bytes == 0) return 0; const byte_index = ctx.remaining_bytes - 1; buffer[0] = ctx.bytes[byte_index]; diff --git a/lib/std/crypto/25519/ed25519.zig b/lib/std/crypto/25519/ed25519.zig index d7b51271d29d..025a980a8e1d 100644 --- a/lib/std/crypto/25519/ed25519.zig +++ b/lib/std/crypto/25519/ed25519.zig @@ -21,8 +21,8 @@ pub const Ed25519 = struct { /// Length (in bytes) of optional random bytes, for non-deterministic signatures. pub const noise_length = 32; - const CompressedScalar = Curve.scalar.CompressedScalar; - const Scalar = Curve.scalar.Scalar; + pub const CompressedScalar = Curve.scalar.CompressedScalar; + pub const Scalar = Curve.scalar.Scalar; /// An Ed25519 secret key. pub const SecretKey = struct { @@ -73,7 +73,7 @@ pub const Ed25519 = struct { nonce: CompressedScalar, r_bytes: [Curve.encoded_length]u8, - fn init(scalar: CompressedScalar, nonce: CompressedScalar, public_key: PublicKey) (IdentityElementError || KeyMismatchError || NonCanonicalError || WeakPublicKeyError)!Signer { + pub fn init(scalar: CompressedScalar, nonce: CompressedScalar, public_key: PublicKey) (IdentityElementError || WeakPublicKeyError)!Signer { const r = try Curve.basePoint.mul(nonce); const r_bytes = r.toBytes(); diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 1b483f28acaa..80aa355fce55 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -220,10 +220,6 @@ pub const Parsed = struct { return p.slice(p.pub_key_slice); } - pub fn pubKeySigAlgo(p: Parsed) []const u8 { - return p.slice(p.pub_key_signature_algorithm_slice); - } - pub fn message(p: Parsed) []const u8 { return p.slice(p.message_slice); } @@ -973,7 +969,7 @@ pub const rsa = struct { comptime modulus_len: usize, msg: []const u8, comptime Hash: type, - private_key: PrivateKey, + private_key: SecretKey, salt: [Hash.digest_length]u8, ) ![modulus_len]u8 { const mod_bits = modulus_len * 8; @@ -1273,12 +1269,12 @@ pub const rsa = struct { } }; - pub const PrivateKey = struct { + pub const SecretKey = struct { public: PublicKey, /// private exponent d: Fe, - pub fn fromBytes(mod: []const u8, public: []const u8, private: []const u8) !PrivateKey { + pub fn fromBytes(mod: []const u8, public: []const u8, private: []const u8) !SecretKey { const _public = try PublicKey.fromBytes(mod, public); const _d = Fe.fromBytes(_public.n, private, .big) catch return error.CertificatePrivateKeyInvalid; @@ -1288,7 +1284,7 @@ pub const rsa = struct { } // RFC8017 Appendix A.1.2 - pub fn fromDer(bytes: []const u8) !PrivateKey { + pub fn fromDer(bytes: []const u8) !SecretKey { const seq = try der.Element.parse(bytes, 0); if (seq.identifier.tag != .sequence) return error.PrivateKeyWrongDataType; @@ -1312,7 +1308,7 @@ pub const rsa = struct { return try fromBytes(modulus, pub_exponent, priv_exponent); } - fn decrypt(self: PrivateKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { + fn decrypt(self: SecretKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { const m = Fe.fromBytes(self.public.n, &msg, .big) catch return error.MessageTooLong; const e = self.public.n.pow(m, self.d) catch unreachable; var res: [modulus_len]u8 = undefined; diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig index 70362470c3ef..16ac99bba9cf 100644 --- a/lib/std/crypto/ecdsa.zig +++ b/lib/std/crypto/ecdsa.zig @@ -17,7 +17,7 @@ pub const EcdsaP256Sha256 = Ecdsa(crypto.ecc.P256, crypto.hash.sha2.Sha256); pub const EcdsaP256Sha3_256 = Ecdsa(crypto.ecc.P256, crypto.hash.sha3.Sha3_256); /// ECDSA over P-384 with SHA-384. pub const EcdsaP384Sha384 = Ecdsa(crypto.ecc.P384, crypto.hash.sha2.Sha384); -/// ECDSA over P-384 with SHA3-384. +/// ECDSA over P-256 with SHA3-384. pub const EcdsaP256Sha3_384 = Ecdsa(crypto.ecc.P384, crypto.hash.sha3.Sha3_384); /// ECDSA over Secp256k1 with SHA-256. pub const EcdsaSecp256k1Sha256 = Ecdsa(crypto.ecc.Secp256k1, crypto.hash.sha2.Sha256); @@ -183,7 +183,7 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { secret_key: SecretKey, noise: ?[noise_length]u8, - fn init(secret_key: SecretKey, noise: ?[noise_length]u8) !Signer { + pub fn init(secret_key: SecretKey, noise: ?[noise_length]u8) Signer { return Signer{ .h = Hash.init(.{}), .secret_key = secret_key, diff --git a/lib/std/crypto/sha2.zig b/lib/std/crypto/sha2.zig index 31884c73818a..0287de90ed6c 100644 --- a/lib/std/crypto/sha2.zig +++ b/lib/std/crypto/sha2.zig @@ -392,11 +392,15 @@ fn Sha2x32(comptime params: Sha2Params32) type { } pub const Error = error{}; - pub const Writer = std.io.Writer(*Self, Error, write); + pub const Writer = std.io.Writer(*Self, Error, writev); - fn write(self: *Self, bytes: []const u8) Error!usize { - self.update(bytes); - return bytes.len; + fn writev(self: *Self, iov: []std.os.iovec_const) Error!usize { + var written: usize = 0; + for (iov) |v| { + self.update(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } pub fn writer(self: *Self) Writer { diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 334e69eab07d..9d450132fa17 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -116,7 +116,7 @@ pub const Handshake = union(HandshakeType) { // If `HandshakeCipherT.encode` accepts iovecs for the message this can be moved // to `Stream.writeFragment` and this type can be deleted. - pub fn write(self: @This(), stream: anytype) !usize { + pub fn write(self: @This(), stream: *Stream) !usize { var res: usize = 0; res += try stream.write(HandshakeType, self); switch (self) { @@ -164,6 +164,7 @@ pub const KeyUpdate = enum(u8) { _, }; +/// A DER encoded certificate chain with the first entry being for this domain. pub const Certificate = struct { context: []const u8 = "", entries: []const Entry, @@ -171,13 +172,13 @@ pub const Certificate = struct { pub const max_context_len = 255; pub const Entry = struct { - /// Either ASN1_subjectPublicKeyInfo or cert_data based on CertificateType. + /// DER encoded data: []const u8, extensions: []const Extension = &.{}, pub const max_data_len = 1 << 24 - 1; - pub fn write(self: @This(), stream: anytype) !usize { + pub fn write(self: @This(), stream: *Stream) !usize { var res: usize = 0; res += try stream.writeArray(u24, u8, self.data); res += try stream.writeArray(u16, Extension, self.extensions); @@ -187,7 +188,7 @@ pub const Certificate = struct { const Self = @This(); - pub fn write(self: Self, stream: anytype) !usize { + pub fn write(self: Self, stream: *Stream) !usize { var res: usize = 0; res += try stream.writeArray(u8, u8, self.context); res += try stream.writeArray(u24, Entry, self.entries); @@ -201,7 +202,7 @@ pub const CertificateVerify = struct { pub const max_signature_length = 1 << 16 - 1; - pub fn write(self: @This(), stream: anytype) !usize { + pub fn write(self: @This(), stream: *Stream) !usize { var res: usize = 0; res += try stream.write(SignatureScheme, self.algorithm); res += try stream.writeArray(u16, u8, self.signature); @@ -456,13 +457,13 @@ pub const Alert = struct { const Self = @This(); - pub fn read(stream: anytype) Self { + pub fn read(stream: *Stream) Self { const level = try stream.read(Level); const description = try stream.read(Description); return .{ .level = level, .description = description }; } - pub fn write(self: Self, stream: anytype) !usize { + pub fn write(self: Self, stream: *Stream) !usize { var res: usize = 0; res += try stream.write(Level, self.level); res += try stream.write(Description, self.description); @@ -475,26 +476,21 @@ pub const Alert = struct { /// Note: This enum is named `SignatureScheme` because there is already a /// `SignatureAlgorithm` type in TLS 1.2, which this replaces. pub const SignatureScheme = enum(u16) { - // RSASSA-PKCS1-v1_5 algorithms rsa_pkcs1_sha256 = 0x0401, rsa_pkcs1_sha384 = 0x0501, rsa_pkcs1_sha512 = 0x0601, - // ECDSA algorithms ecdsa_secp256r1_sha256 = 0x0403, ecdsa_secp384r1_sha384 = 0x0503, ecdsa_secp521r1_sha512 = 0x0603, - // RSASSA-PSS algorithms with public key OID rsaEncryption rsa_pss_rsae_sha256 = 0x0804, rsa_pss_rsae_sha384 = 0x0805, rsa_pss_rsae_sha512 = 0x0806, - // EdDSA algorithms ed25519 = 0x0807, ed448 = 0x0808, - // RSASSA-PSS algorithms with public key OID RSASSA-PSS rsa_pss_pss_sha256 = 0x0809, rsa_pss_pss_sha384 = 0x080a, rsa_pss_pss_sha512 = 0x080b, @@ -515,9 +511,9 @@ pub const SignatureScheme = enum(u16) { pub fn Hash(comptime self: @This()) type { return switch (self) { - .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, - .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, - .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, + .ecdsa_secp256r1_sha256, .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, + .ecdsa_secp384r1_sha384, .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, + .ecdsa_secp521r1_sha512, .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, else => @compileError("bad scheme"), }; } @@ -576,6 +572,19 @@ pub const X25519Kyber768Draft = struct { pub const KeyPair = struct { x25519: X25519.KeyPair, kyber768d00: Kyber768.KeyPair, + + pub const seed_length = X25519.KeyPair.seed_length + Kyber768.KeyPair.seed_length; + + pub fn create(seed: ?[seed_length]u8) !@This() { + var seed_: [seed_length]u8 = seed orelse undefined; + if (seed == null) { + crypto.random.bytes(&seed_); + } + return .{ + .x25519 = try X25519.KeyPair.create(seed_[0..X25519.KeyPair.seed_length].*), + .kyber768d00 = try Kyber768.KeyPair.create(seed_[X25519.KeyPair.seed_length..].*), + }; + } }; pub const PublicKey = struct { x25519: X25519.PublicKey, @@ -641,10 +650,10 @@ pub const KeyShare = union(NamedGroup) { const Self = @This(); - pub fn read(stream: anytype) !Self { + pub fn read(stream: *Stream) !Self { std.debug.assert(!stream.is_client); - var reader = stream.reader(); + var reader = stream.any().reader(); const group = try stream.read(NamedGroup); const len = try stream.read(u16); switch (group) { @@ -679,7 +688,7 @@ pub const KeyShare = union(NamedGroup) { return .{ .invalid = {} }; } - pub fn write(self: Self, stream: anytype) !usize { + pub fn write(self: Self, stream: *Stream) !usize { var res: usize = 0; res += try stream.write(NamedGroup, self); const public = switch (self) { @@ -862,7 +871,7 @@ pub const ClientHello = struct { const Self = @This(); - pub fn write(self: Self, stream: anytype) !usize { + pub fn write(self: Self, stream: *Stream) !usize { var res: usize = 0; res += try stream.write(Version, self.version); res += try stream.writeAll(&self.random); @@ -894,7 +903,7 @@ pub const ServerHello = struct { const Self = @This(); - pub fn write(self: Self, stream: anytype) !usize { + pub fn write(self: Self, stream: *Stream) !usize { var res: usize = 0; res += try stream.write(Version, self.version); res += try stream.writeAll(&self.random); @@ -911,7 +920,7 @@ pub const EncryptedExtensions = struct { const Self = @This(); - pub fn write(self: Self, stream: anytype) !usize { + pub fn write(self: Self, stream: *Stream) !usize { return try stream.writeArray(u16, Extension, self.extensions); } }; @@ -953,7 +962,7 @@ pub const Extension = union(ExtensionType) { const Self = @This(); - pub fn write(self: Self, stream: anytype) !usize { + pub fn write(self: Self, stream: *Stream) !usize { const PrefixLen = enum { zero, one, two }; const prefix_len: PrefixLen = if (stream.is_client) switch (self) { .supported_versions, .ec_point_formats, .psk_key_exchange_modes => .one, @@ -996,7 +1005,7 @@ pub const Extension = union(ExtensionType) { type: ExtensionType, len: u16, - pub fn read(stream: anytype) @TypeOf(stream.*).ReadError!@This() { + pub fn read(stream: *Stream) @TypeOf(stream.*).ReadError!@This() { const ty = try stream.read(ExtensionType); const length = try stream.read(u16); return .{ .type = ty, .len = length }; @@ -1023,7 +1032,7 @@ pub const ServerName = struct { pub const NameType = enum(u8) { host_name = 0, _ }; - pub fn write(self: @This(), stream: anytype) !usize { + pub fn write(self: @This(), stream: *Stream) !usize { var res: usize = 0; res += try stream.write(NameType, self.type); res += try stream.writeArray(u16, u8, self.host_name); @@ -1328,14 +1337,19 @@ const TestStream = struct { self.buffer.deinit(allocator); } - pub fn read(self: *Self, buffer: []u8) ReadError!usize { - try self.buffer.readFirst(buffer, buffer.len); - return buffer.len; + pub fn readv(self: *Self, iov: []std.os.iovec) ReadError!usize { + const first = iov[0]; + try self.buffer.readFirst(first.iov_base[0..first.iov_len], first.iov_len); + return first.iov_len; } - pub fn write(self: *Self, bytes: []const u8) WriteError!usize { - try self.buffer.writeSlice(bytes); - return bytes.len; + pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { + var written: usize = 0; + for (iov) |v| { + try self.buffer.writeSlice(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } pub fn peek(self: *Self, out: []u8) ReadError!void { @@ -1356,7 +1370,7 @@ const TestStream = struct { try std.testing.expectEqualSlices(u8, expected, buf); } - const GenericStream = std.io.GenericStream(*Self, ReadError, read, WriteError, write, close); + const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); pub fn stream(self: *Self) GenericStream { return .{ .context = self }; @@ -1382,7 +1396,7 @@ test "tls client and server handshake, data, and close_notify" { const server_cert = @embedFile("./testdata/cert.der"); const server_key = @embedFile("./testdata/key.der"); - const server_rsa = try crypto.Certificate.rsa.PrivateKey.fromDer(server_key); + const server_rsa = try crypto.Certificate.rsa.SecretKey.fromDer(server_key); var server_transcript: MultiHash = .{}; var server = Server{ .stream = Stream{ @@ -1493,10 +1507,10 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } - try client.writer().writeAll("ping"); + try client.any().writer().writeAll("ping"); var recv_ping: [4]u8 = undefined; - _ = try server.stream.reader().readAll(&recv_ping); + _ = try server.stream.any().reader().readAll(&recv_ping); try std.testing.expectEqualStrings("ping", &recv_ping); server.stream.close(); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 413f2de4bfec..3dc730ecc083 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -126,7 +126,7 @@ pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { var stream = &self.stream; - var r = stream.reader(); + var r = stream.any().reader(); // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. _ = try stream.read(tls.Version); @@ -239,7 +239,7 @@ pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { pub fn recv_encrypted_extensions(self: *Self) !void { var stream = &self.stream; - var r = stream.reader(); + var r = stream.any().reader(); var iter = try stream.extensions(); while (try iter.next()) |ext| { @@ -252,7 +252,7 @@ pub fn recv_encrypted_extensions(self: *Self) !void { /// Caller owns allocated Certificate.Parsed.certificate. pub fn recv_certificate(self: *Self) !Certificate.Parsed { var stream = &self.stream; - var r = stream.reader(); + var r = stream.any().reader(); const allocator = self.options.allocator; const ca_bundle = self.options.ca_bundle; const verify = ca_bundle != null; @@ -315,7 +315,7 @@ pub fn recv_certificate(self: *Self) !Certificate.Parsed { pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: Certificate.Parsed) !void { var stream = &self.stream; - var r = stream.reader(); + var r = stream.any().reader(); const allocator = self.options.allocator; const sig_content = tls.sigContent(digest); @@ -385,7 +385,7 @@ pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: Certificat pub fn recv_finished(self: *Self, digest: []const u8) !void { var stream = &self.stream; - var r = stream.reader(); + var r = stream.any().reader(); const cipher = stream.cipher.handshake; switch (cipher) { @@ -434,7 +434,7 @@ pub const ReadError = anyerror; pub const WriteError = anyerror; /// Reads next application_data message. -pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { +pub fn readv(self: *Self, buffers: []std.os.iovec) ReadError!usize { var stream = &self.stream; if (stream.eof()) return 0; @@ -446,7 +446,7 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { switch (inner_plaintext.handshake_type) { // A multithreaded client could use these. .new_session_ticket => { - try stream.reader().skipBytes(inner_plaintext.len, .{}); + try stream.any().reader().skipBytes(inner_plaintext.len, .{}); }, .key_update => { switch (stream.cipher.application) { @@ -476,22 +476,18 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { else => return stream.writeError(.unexpected_message), } }, + .alert => {}, .application_data => {}, else => return stream.writeError(.unexpected_message), } } - return try self.stream.readv(buffers); + return try stream.readv(buffers); } -pub fn read(self: *Self, buf: []u8) ReadError!usize { - const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; - return try self.readv(&buffers); -} - -pub fn write(self: *Self, buf: []const u8) WriteError!usize { +pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { if (self.stream.eof()) return 0; - const res = try self.stream.writeBytes(buf); + const res = try self.stream.writev(iov); try self.stream.flush(); return res; } @@ -500,14 +496,9 @@ pub fn close(self: *Self) void { self.stream.close(); } -pub const Reader = std.io.Reader(*Self, ReadError, read); -pub const Writer = std.io.Writer(*Self, WriteError, write); - -pub fn reader(self: *Self) Reader { - return .{ .context = self }; -} +pub const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); -pub fn writer(self: *Self) Writer { +pub fn any(self: *Self) GenericStream { return .{ .context = self }; } @@ -604,7 +595,8 @@ pub const KeyPairs = struct { } }; -/// A command to send or receive a single message. Allows testing `advance` on a single thread. +/// A command to send or receive a single message. Allows deterministically +/// testing `advance` on a single thread. pub const Command = union(enum) { send_hello: KeyPairs, recv_hello: KeyPairs, diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index ee6afcaa7810..6994e14898f3 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -16,17 +16,37 @@ const Self = @This(); /// Initiates a TLS handshake and establishes a TLSv1.3 session pub fn init(stream: std.io.AnyStream, options: Options) !Self { var transcript_hash: tls.MultiHash = .{}; - var stream_ = tls.Stream{ + const stream_ = tls.Stream{ .stream = stream, .is_client = false, .transcript_hash = &transcript_hash, }; var res = Self{ .stream = stream_, .options = options }; - const client_hello = try res.recv_hello(&stream_); - _ = client_hello; - var command = Command{ .recv_hello = {} }; - while (command != .none) command = try res.next(command); + // Verify that the certificate key matches the certificate. + const cert_buf = Certificate{ .buffer = options.certificate.entries[0].data, .index = 0 }; + // TODO: don't reparse cert in send_certificate_verify + const cert = try cert_buf.parse(); + const expected: std.meta.Tag(Options.CertificateKey) = switch (cert.pub_key_algo) { + .rsaEncryption => .rsa, + .X9_62_id_ecPublicKey => |curve| switch (curve) { + .X9_62_prime256v1 => .ecdsa256, + .secp384r1 => .ecdsa384, + else => return error.UnsupportedCertificateSignature, + }, + .curveEd25519 => .ed25519, + }; + if (expected != options.certificate_key) return error.CertificateKeyMismatch; + // TODO: verify private key corresponds to public key + + const cmd_init = Command{ .recv_hello = {} }; + var command = cmd_init; + while (command != .none) { + command = res.next(command) catch |err| switch (err) { + error.ConnectionResetByPeer => cmd_init, + else => return err, + }; + } return res; } @@ -87,7 +107,7 @@ pub fn next(self: *Self, command: Command) !Command { pub fn recv_hello(self: *Self) !ClientHello { var stream = &self.stream; - var reader = stream.reader(); + var reader = stream.any().reader(); try stream.expectInnerPlaintext(.handshake, .client_hello); @@ -142,9 +162,8 @@ pub fn recv_hello(self: *Self) !ClientHello { var key_share_iter = try stream.iterator(u16, tls.KeyShare); while (try key_share_iter.next()) |ks| { - switch (ks) { - .x25519 => key_share = ks, - else => {}, + for (self.options.key_shares) |s| { + if (ks == s and key_share == null) key_share = ks; } } }, @@ -155,9 +174,20 @@ pub fn recv_hello(self: *Self) !ClientHello { } }, .signature_algorithms => { + const acceptable = switch (self.options.certificate_key) { + .rsa => &[_]tls.SignatureScheme{ + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha256, + }, + .ecdsa256 => &[_]tls.SignatureScheme{ .ecdsa_secp256r1_sha256 }, + .ecdsa384 => &[_]tls.SignatureScheme{ .ecdsa_secp384r1_sha384 }, + .ed25519 => &[_]tls.SignatureScheme{ .ed25519 }, + }; var algos_iter = try stream.iterator(u16, tls.SignatureScheme); while (try algos_iter.next()) |algo| { - if (algo == .rsa_pss_rsae_sha256) sig_scheme = algo; + for (acceptable) |a| { + if (algo == a and sig_scheme == null) sig_scheme = algo; + } } }, else => { @@ -174,8 +204,17 @@ pub fn recv_hello(self: *Self) !ClientHello { var server_random: [32]u8 = undefined; crypto.random.bytes(&server_random); - const key_pair = .{ - .x25519 = crypto.dh.X25519.KeyPair.create(server_random) catch unreachable, + const key_pair: tls.KeyPair = switch (key_share.?) { + inline + .secp256r1, + .secp384r1, + .x25519, + .x25519_kyber768d00, + => |_, tag| brk: { + const pair = tls.NamedGroupT(tag).KeyPair.create(null) catch unreachable; + break :brk @unionInit(tls.KeyPair, @tagName(tag), pair); + }, + else => return stream.writeError(.decode_error), }; return .{ @@ -266,68 +305,66 @@ pub fn send_certificate_verify(self: *Self, verify: Command.CertificateVerify) ! const digest = stream.transcript_hash.?.peek(); const sig_content = tls.sigContent(digest); - const key = self.options.certificate_key; - const cert_buf = Certificate{ .buffer = self.options.certificate.entries[0].data, .index = 0 }; - const cert = cert_buf.parse() catch return stream.writeError(.bad_certificate); - const signature: []const u8 = switch (verify.scheme) { - // inline .ecdsa_secp256r1_sha256, - // .ecdsa_secp384r1_sha384, - // => |comptime_scheme| { - // if (cert.pub_key_algo != .X9_62_id_ecPublicKey) - // return stream.writeError(.bad_certificate); - // const Ecdsa = comptime_scheme.Ecdsa(); - // const sig = Ecdsa.Signature.fromDer(sig_bytes) catch - // return stream.writeError(.decode_error); - // const key = Ecdsa.PublicKey.fromSec1(cert.pubKey()) catch - // return stream.writeError(.decode_error); - // sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); - // }, + inline .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384 => |comptime_scheme| brk: { + const Ecdsa = comptime_scheme.Ecdsa(); + const key = switch (comptime_scheme) { + .ecdsa_secp256r1_sha256 => self.options.certificate_key.ecdsa256, + .ecdsa_secp384r1_sha384 => self.options.certificate_key.ecdsa384, + else => unreachable, + }; + + var signer = Ecdsa.Signer.init(key, verify.salt[0..Ecdsa.noise_length].*); + signer.update(sig_content); + const sig = signer.finalize() catch return stream.writeError(.internal_error); + break :brk &sig.toBytes(); + }, inline .rsa_pss_rsae_sha256, .rsa_pss_rsae_sha384, .rsa_pss_rsae_sha512, => |comptime_scheme| brk: { - if (cert.pub_key_algo != .rsaEncryption) - return stream.writeError(.bad_certificate); - const Hash = comptime_scheme.Hash(); - const rsa = Certificate.rsa; - // if (!std.mem.eql(u8, cert.pubKey(), key.public)) - // return stream.writeError(.bad_certificate); + const key = self.options.certificate_key.rsa; switch (key.public.n.bits() / 8) { inline 128, 256, 512 => |modulus_length| { - break :brk &(rsa.PSSSignature.sign( + const sig = Certificate.rsa.PSSSignature.sign( modulus_length, sig_content, Hash, key, verify.salt[0..Hash.digest_length].*, - ) catch return stream.writeError(.bad_certificate)); + ) catch return stream.writeError(.bad_certificate); + break :brk &sig; }, else => return stream.writeError(.bad_certificate), } }, - // inline .ed25519 => |comptime_scheme| { - // if (cert.pub_key_algo != .curveEd25519) - // return stream.writeError(.bad_certificate); - // const Eddsa = comptime_scheme.Eddsa(); - // if (sig_content.len != Eddsa.Signature.encoded_length) - // return stream.writeError(.decode_error); - // const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); - // if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) - // return stream.writeError(.decode_error); - // const key = Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*) catch - // return stream.writeError(.bad_certificate); - // sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); - // }, + .ed25519 => brk: { + const Ed25519 = crypto.sign.Ed25519; + const key = self.options.certificate_key.ed25519; + + const pub_key = brk2: { + const cert_buf = Certificate{ .buffer = self.options.certificate.entries[0].data, .index = 0 }; + const cert = try cert_buf.parse(); + const expected_len = Ed25519.PublicKey.encoded_length; + if (cert.pubKey().len != expected_len) return stream.writeError(.bad_certificate); + break :brk2 Ed25519.PublicKey.fromBytes(cert.pubKey()[0..expected_len].*) catch + return stream.writeError(.bad_certificate); + }; + const nonce: Ed25519.CompressedScalar = verify.salt[0..Ed25519.noise_length].*; + + const key_pair = Ed25519.KeyPair{ .public_key = pub_key, .secret_key = key }; + const sig = key_pair.sign(sig_content, nonce) catch return stream.writeError(.internal_error); + break :brk &sig.toBytes(); + }, else => { return stream.writeError(.bad_certificate); }, }; _ = try self.stream.write(tls.Handshake, .{ .certificate_verify = tls.CertificateVerify{ - .algorithm = .rsa_pss_rsae_sha256, + .algorithm = verify.scheme, .signature = signature, } }); try stream.flush(); @@ -336,23 +373,21 @@ pub fn send_certificate_verify(self: *Self, verify: Command.CertificateVerify) ! pub fn send_finished(self: *Self) !void { var stream = &self.stream; const verify_data = switch (stream.cipher.handshake) { - inline .aes_256_gcm_sha384, - => |v| brk: { + inline else => |v| brk: { const T = @TypeOf(v); const secret = v.server_finished_key; const transcript_hash = stream.transcript_hash.?.peek(); - break :brk tls.hmac(T.Hmac, transcript_hash, secret); + break :brk &tls.hmac(T.Hmac, transcript_hash, secret); }, - else => return stream.writeError(.illegal_parameter), }; - _ = try stream.write(tls.Handshake, .{ .finished = &verify_data }); + _ = try stream.write(tls.Handshake, .{ .finished = verify_data }); try stream.flush(); } pub fn recv_finished(self: *Self) !void { var stream = &self.stream; - var reader = stream.reader(); + var reader = stream.any().reader(); const handshake_hash = stream.transcript_hash.?.peek(); @@ -381,14 +416,64 @@ pub fn recv_finished(self: *Self) !void { stream.transcript_hash = null; } +pub const ReadError = anyerror; +pub const WriteError = anyerror; + +/// Reads next application_data message. +pub fn readv(self: *Self, buffers: []std.os.iovec) ReadError!usize { + var stream = &self.stream; + + if (stream.eof()) return 0; + + while (stream.view.len == 0) { + const inner_plaintext = try stream.readInnerPlaintext(); + switch (inner_plaintext.type) { + .application_data => {}, + .alert => {}, + else => return stream.writeError(.unexpected_message), + } + } + return try self.stream.readv(buffers); +} + +pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { + if (self.stream.eof()) return 0; + + const res = try self.stream.writev(iov); + try self.stream.flush(); + return res; +} + +pub fn close(self: *Self) void { + self.stream.close(); +} + +pub const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); + +pub fn any(self: *Self) GenericStream { + return .{ .context = self }; +} + pub const Options = struct { /// List of potential cipher suites in descending order of preference. cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, + /// Types of shared keys to accept from client. + key_shares: []const tls.NamedGroup = &tls.supported_groups, + /// Certificate(s) to send in `send_certificate` messages. certificate: tls.Certificate, - certificate_key: Certificate.rsa.PrivateKey, + /// Key to use in `send_certificate_verify`. Must match `certificate.parse().pub_key_algo`. + certificate_key: CertificateKey, + + pub const CertificateKey = union(enum) { + rsa: crypto.Certificate.rsa.SecretKey, + ecdsa256: tls.NamedGroupT(.secp256r1).SecretKey, + ecdsa384: tls.NamedGroupT(.secp384r1).SecretKey, + ed25519: crypto.sign.Ed25519.SecretKey, + }; }; -/// A command to send or receive a single message. Allows testing `advance` on a single thread. +/// A command to send or receive a single message. Allows deterministically +/// testing `advance` on a single thread. pub const Command = union(enum) { recv_hello: void, send_hello: ClientHello, diff --git a/lib/std/crypto/tls/Stream.zig b/lib/std/crypto/tls/Stream.zig index f31d1d8458a6..8648ecde98d9 100644 --- a/lib/std/crypto/tls/Stream.zig +++ b/lib/std/crypto/tls/Stream.zig @@ -64,6 +64,7 @@ const Cipher = union(enum) { handshake: HandshakeCipher, }; +// Useful mostly as reference or until std.io.Any* types don't type erase errors. pub const ReadError = anyerror || tls.Error || error{EndOfStream}; pub const WriteError = anyerror || error{TlsEncodeError}; @@ -131,7 +132,7 @@ pub fn flush(self: *Self) WriteError!void { } /// Flush a change cipher spec message to the underlying stream. -pub fn changeCipherSpec(self: *Self) WriteError!void { +pub fn changeCipherSpec(self: *Self) !void { self.version = .tls_1_2; const plaintext = Plaintext{ @@ -156,8 +157,7 @@ pub fn writeError(self: *Self, err: Alert.Description) tls.Error { self.flush() catch {}; self.close(); - @panic("TODO: fixme"); - // return err.toError(); + return err.toError(); } pub fn close(self: *Self) void { @@ -169,13 +169,15 @@ pub fn close(self: *Self) void { } /// Write bytes to `stream`, potentially flushing once `self.buffer` is full. -pub fn writeBytes(self: *Self, bytes: []const u8) WriteError!usize { +pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { + const first = iov[0]; + const bytes = first.iov_base[0..first.iov_len]; if (self.nocommit > 0) return bytes.len; const available = self.buffer.len - self.view.len; const to_consume = bytes[0..@min(available, bytes.len)]; - @memcpy(self.buffer[self.view.len..][0..bytes.len], to_consume); + @memcpy(self.buffer[self.view.len..][0..to_consume.len], to_consume); self.view = self.buffer[0 .. self.view.len + to_consume.len]; if (self.view.len == self.buffer.len) try self.flush(); @@ -183,15 +185,7 @@ pub fn writeBytes(self: *Self, bytes: []const u8) WriteError!usize { return to_consume.len; } -pub fn writeAll(self: *Self, bytes: []const u8) WriteError!usize { - var index: usize = 0; - while (index != bytes.len) { - index += try self.writeBytes(bytes[index..]); - } - return index; -} - -pub fn writeArray(self: *Self, comptime PrefixT: type, comptime T: type, values: []const T) WriteError!usize { +pub fn writeArray(self: *Self, comptime PrefixT: type, comptime T: type, values: []const T) !usize { var res: usize = 0; for (values) |v| res += self.length(T, v); @@ -208,11 +202,18 @@ pub fn writeArray(self: *Self, comptime PrefixT: type, comptime T: type, values: return res; } -pub fn write(self: *Self, comptime T: type, value: T) WriteError!usize { +/// Returns number of bytes written. Convienent for encoding struct types in tls.zig . +pub fn writeAll(self: *Self, bytes: []const u8) !usize { + try self.any().writer().writeAll(bytes); + return bytes.len; +} + +pub fn write(self: *Self, comptime T: type, value: T) !usize { switch (@typeInfo(T)) { .Int, .Enum => { const encoded = Encoder.encode(T, value); - return try self.writeAll(&encoded); + try self.any().writer().writeAll(&encoded); + return encoded.len; }, .Struct, .Union => { return try T.write(value, self); @@ -240,10 +241,10 @@ pub fn arrayLength( return res; } -/// Reads bytes from `view`, potentially reading more fragments from `stream`. +/// Reads bytes from `view`, potentially reading more fragments from underlying `stream`. /// /// A return value of 0 indicates EOF. -pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { +pub fn readv(self: *Self, iov: []std.os.iovec) ReadError!usize { // > Any data received after a closure alert has been received MUST be ignored. if (self.eof()) return 0; @@ -251,7 +252,7 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { var bytes_read: usize = 0; - for (buffers) |b| { + for (iov) |b| { var bytes_read_buffer: usize = 0; while (bytes_read_buffer != b.iov_len) { const to_read = @min(b.iov_len, self.view.len); @@ -268,17 +269,10 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { return bytes_read; } -/// Reads bytes from `view`, potentially reading more fragments from `stream`. -/// A return value of 0 indicates EOF. -pub fn readBytes(self: *Self, buf: []u8) ReadError!usize { - const buffers = [_]std.os.iovec{.{ .iov_base = buf.ptr, .iov_len = buf.len }}; - return try self.readv(&buffers); -} - /// Reads plaintext from `stream` into `buffer` and updates `view`. /// Skips non-fatal alert and change_cipher_spec messages. /// Will decrypt according to `encryptionMethod` if receiving application_data message. -pub fn readPlaintext(self: *Self) ReadError!Plaintext { +pub fn readPlaintext(self: *Self) !Plaintext { std.debug.assert(self.view.len == 0); // last read should have completed var plaintext_bytes: [Plaintext.size]u8 = undefined; var n_read: usize = 0; @@ -299,26 +293,25 @@ pub fn readPlaintext(self: *Self) ReadError!Plaintext { if (res.len < self.ciphertextOverhead()) return self.writeError(.decode_error); switch (self.cipher) { - inline .handshake, .application => |*cipher| { - switch (cipher.*) { - inline else => |*c| { - const C = @TypeOf(c.*); - const tag_len = C.AEAD.tag_length; - - const ciphertext = self.view[0 .. self.view.len - tag_len]; - const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; - const out: []u8 = @constCast(self.view[0..ciphertext.len]); - c.decrypt(ciphertext, &plaintext_bytes, tag, self.is_client, out) catch - return self.writeError(.bad_record_mac); - const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); - if (padding_start) |s| { - res.type = @enumFromInt(self.view[s]); - self.view = self.view[0..s]; - } else { - return self.writeError(.decode_error); - } - }, - } + inline .handshake, .application => |*cipher| switch (cipher.*) { + inline else => |*c| { + const C = @TypeOf(c.*); + const tag_len = C.AEAD.tag_length; + + const ciphertext = self.view[0 .. self.view.len - tag_len]; + const tag = self.view[self.view.len - tag_len ..][0..tag_len].*; + const out: []u8 = @constCast(self.view[0..ciphertext.len]); + c.decrypt(ciphertext, &plaintext_bytes, tag, self.is_client, out) catch + return self.writeError(.bad_record_mac); + + const padding_start = std.mem.lastIndexOfNone(u8, out, &[_]u8{0}); + if (padding_start) |s| { + res.type = @enumFromInt(self.view[s]); + self.view = self.view[0..s]; + } else { + return self.writeError(.decode_error); + } + }, }, else => unreachable, } @@ -330,11 +323,19 @@ pub fn readPlaintext(self: *Self) ReadError!Plaintext { const description = try self.read(Alert.Description); std.log.debug("TLS alert {} {}", .{ level, description }); - if (description == .close_notify) { - self.closed = true; - return res; + switch (description) { + .close_notify => { + self.closed = true; + return res; + }, + .certificate_revoked, + .certificate_unknown, + .certificate_expired, + .certificate_required => {}, + else => { + return self.writeError(.unexpected_message); + } } - if (level == .fatal) return self.writeError(.unexpected_message); }, // > An implementation may receive an unencrypted record of type // > change_cipher_spec consisting of the single byte value 0x01 at any @@ -353,14 +354,18 @@ pub fn readPlaintext(self: *Self) ReadError!Plaintext { } } -pub fn readInnerPlaintext(self: *Self) ReadError!InnerPlaintext { +pub fn readInnerPlaintext(self: *Self) !InnerPlaintext { var res: InnerPlaintext = .{ .type = self.content_type, .handshake_type = if (self.handshake_type) |h| h else undefined, - .len = undefined, + .len = 0, }; + if (self.closed) return res; + if (self.view.len == 0) { const plaintext = try self.readPlaintext(); + if (self.closed) return res; + res.type = plaintext.type; res.len = plaintext.len; @@ -383,7 +388,7 @@ pub fn expectInnerPlaintext( self: *Self, expected_content: ContentType, expected_handshake: ?HandshakeType, -) ReadError!void { +) !void { const inner_plaintext = try self.readInnerPlaintext(); if (expected_content != inner_plaintext.type) { std.debug.print("expected {} got {}\n", .{ expected_content, inner_plaintext }); @@ -394,10 +399,10 @@ pub fn expectInnerPlaintext( } } -pub fn read(self: *Self, comptime T: type) ReadError!T { +pub fn read(self: *Self, comptime T: type) !T { comptime std.debug.assert(@sizeOf(T) < fragment_size); switch (@typeInfo(T)) { - .Int => return self.reader().readInt(T, .big) catch |err| switch (err) { + .Int => return self.any().reader().readInt(T, .big) catch |err| switch (err) { error.EndOfStream => return self.writeError(.decode_error), else => |e| return e, }, @@ -448,7 +453,7 @@ fn Iterator(comptime T: type) type { stream: *Self, end: usize, - pub fn next(self: *@This()) ReadError!?T { + pub fn next(self: *@This()) !?T { const cur_offset = self.stream.buffer.len - self.stream.view.len; if (cur_offset > self.end) return null; return try self.stream.read(T); @@ -456,7 +461,7 @@ fn Iterator(comptime T: type) type { }; } -pub fn iterator(self: *Self, comptime Len: type, comptime Tag: type) ReadError!Iterator(Tag) { +pub fn iterator(self: *Self, comptime Len: type, comptime Tag: type) !Iterator(Tag) { const offset = self.buffer.len - self.view.len; const len = try self.read(Len); return Iterator(Tag){ @@ -465,7 +470,7 @@ pub fn iterator(self: *Self, comptime Len: type, comptime Tag: type) ReadError!I }; } -pub fn extensions(self: *Self) ReadError!Iterator(Extension.Header) { +pub fn extensions(self: *Self) !Iterator(Extension.Header) { return self.iterator(u16, Extension.Header); } @@ -473,14 +478,9 @@ pub fn eof(self: Self) bool { return self.closed and self.view.len == 0; } -pub const Reader = std.io.Reader(*Self, ReadError, readBytes); -pub const Writer = std.io.Writer(*Self, WriteError, writeBytes); - -pub fn reader(self: *Self) Reader { - return .{ .context = self }; -} +pub const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); -pub fn writer(self: *Self) Writer { +pub fn any(self: *Self) GenericStream { return .{ .context = self }; } diff --git a/lib/std/fifo.zig b/lib/std/fifo.zig index a26086700258..e02833a8cc63 100644 --- a/lib/std/fifo.zig +++ b/lib/std/fifo.zig @@ -38,8 +38,8 @@ pub fn LinearFifo( count: usize, const Self = @This(); - pub const Reader = std.io.Reader(*Self, error{}, readFn); - pub const Writer = std.io.Writer(*Self, error{OutOfMemory}, appendWrite); + pub const Reader = std.io.Reader(*Self, error{}, readvFn); + pub const Writer = std.io.Writer(*Self, error{OutOfMemory}, appendWritev); // Type of Self argument for slice operations. // If buffer is inline (Static) then we need to ensure we haven't @@ -232,8 +232,15 @@ pub fn LinearFifo( /// Same as `read` except it returns an error union /// The purpose of this function existing is to match `std.io.Reader` API. - fn readFn(self: *Self, dest: []u8) error{}!usize { - return self.read(dest); + fn readvFn(self: *Self, iov: []std.os.iovec) error{}!usize { + var n_read: usize = 0; + for (iov) |v| { + const n = self.read(v.iov_base[0..v.iov_len]); + if (n == 0) return n_read; + + n_read += n; + } + return n_read; } pub fn reader(self: *Self) Reader { @@ -321,9 +328,13 @@ pub fn LinearFifo( /// Same as `write` except it returns the number of bytes written, which is always the same /// as `bytes.len`. The purpose of this function existing is to match `std.io.Writer` API. - fn appendWrite(self: *Self, bytes: []const u8) error{OutOfMemory}!usize { - try self.write(bytes); - return bytes.len; + fn appendWritev(self: *Self, iov: []std.os.iovec_const) error{OutOfMemory}!usize { + var written: usize = 0; + for (iov) |v| { + try self.write(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } pub fn writer(self: *Self) Writer { diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index 669f1b72e33f..61fa536e36a0 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1441,13 +1441,13 @@ fn writeFileAllSendfile(self: File, in_file: File, args: WriteFileOptions) posix } } -pub const Reader = io.Reader(File, ReadError, read); +pub const Reader = io.Reader(File, ReadError, readv); pub fn reader(file: File) Reader { return .{ .context = file }; } -pub const Writer = io.Writer(File, WriteError, write); +pub const Writer = io.Writer(File, WriteError, writev); pub fn writer(file: File) Writer { return .{ .context = file }; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 5f45bbaae390..dc911a65ddc0 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1,8 +1,6 @@ //! Blocking HTTP(S) client //! //! Connections are opened in a thread-safe manner, but individual Requests are not. -//! -//! TLS support may be disabled via `std.options.http_disable_tls`. const std = @import("../std.zig"); const builtin = @import("builtin"); @@ -17,14 +15,12 @@ const use_vectors = builtin.zig_backend != .stage2_x86_64; const Client = @This(); const proto = @import("protocol.zig"); -const disable_tls = std.options.http_disable_tls; const tls = std.crypto.tls; -const TlsClient = if (disable_tls) void else std.crypto.tls.Client; /// Used for all client allocations. Must be thread-safe. allocator: Allocator, -tls_options: (if (disable_tls) void else TlsOptions) = if (disable_tls) {} else .{}, +tls_options: TlsOptions = .{}, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, @@ -38,7 +34,7 @@ http_proxy: ?*Proxy = null, /// Pointer to externally-owned memory. https_proxy: ?*Proxy = null, -/// tls.ClientOptions minus ones that we set +/// tls.Client.Options minus ones that we set pub const TlsOptions = struct { /// Client takes ownership of this field. If empty, will rescan on init. /// @@ -195,9 +191,10 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { - stream: std.io.AnyStream, - /// undefined unless protocol is tls. - tls_client: *TlsClient, + /// Underlying socket + socket: net.Stream, + /// TLS client. + tls: tls.Client, /// The protocol that this connection is using. protocol: Protocol, @@ -220,32 +217,21 @@ pub const Connection = struct { read_buf: [buffer_size]u8 = undefined, write_buf: [buffer_size]u8 = undefined, - pub const buffer_size = std.crypto.tls.Plaintext.size + std.crypto.tls.Plaintext.max_length; + /// Want to be greater than max HTTP headers length. + pub const buffer_size = 4096; const BufferSize = std.math.IntFittingRange(0, buffer_size); pub const Protocol = enum { plain, tls }; - pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - return conn.tls_client.readv(buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "Tls")) return error.TlsFailure; - - switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } + pub inline fn stream(conn: *Connection) std.io.AnyStream { + return switch (conn.protocol) { + .plain => conn.socket.any().any(), + .tls => conn.tls.any().any(), }; } pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); - } - - return conn.stream.readv(buffers) catch |err| switch (err) { + return conn.stream().readv(buffers) catch |err| switch (err) { error.ConnectionTimedOut => return error.ConnectionTimedOut, error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, else => return error.UnexpectedReadFailure, @@ -296,7 +282,7 @@ pub const Connection = struct { .{ .iov_base = buffer.ptr, .iov_len = buffer.len }, .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len }, }; - const nread = try conn.readvDirect(&iovecs); + const nread = try conn.readvDirect(&iovecs); if (nread > buffer.len) { conn.read_start = 0; @@ -307,6 +293,12 @@ pub const Connection = struct { return nread; } + pub fn readv(conn: *Connection, iov: []std.os.iovec) ReadError!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; + return try conn.read(buffer); + } + pub const ReadError = error{ TlsFailure, ConnectionTimedOut, @@ -315,39 +307,22 @@ pub const Connection = struct { EndOfStream, }; - pub const Reader = std.io.Reader(*Connection, ReadError, read); + pub const Reader = std.io.Reader(*Connection, ReadError, readv); pub fn reader(conn: *Connection) Reader { return Reader{ .context = conn }; } - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - conn.tls_client.writer().writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); - } - - return conn.stream.writer().writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, - }; - } - /// Writes the given buffer to the connection. pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { if (conn.write_buf.len - conn.write_end < buffer.len) { try conn.flush(); if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); + conn.stream().writer().writeAll(buffer) catch |err| switch (err) { + error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, + else => return error.UnexpectedWriteFailure, + }; return buffer.len; } } @@ -358,6 +333,12 @@ pub const Connection = struct { return buffer.len; } + pub fn writev(conn: *Connection, iov: []std.os.iovec_const) WriteError!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; + return try conn.write(buffer); + } + /// Returns a buffer to be filled with exactly len bytes to write to the connection. pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { if (conn.write_buf.len - conn.write_end < len) try conn.flush(); @@ -369,32 +350,26 @@ pub const Connection = struct { pub fn flush(conn: *Connection) WriteError!void { if (conn.write_end == 0) return; - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); + try conn.stream().writer().writeAll(conn.write_buf[0..conn.write_end]); conn.write_end = 0; + + if (conn.protocol == .tls) try conn.tls.stream.flush(); } - pub const WriteError = error{ + pub const WriteError = anyerror || error{ ConnectionResetByPeer, UnexpectedWriteFailure, }; - pub const Writer = std.io.Writer(*Connection, WriteError, write); + pub const Writer = std.io.Writer(*Connection, WriteError, writev); pub fn writer(conn: *Connection) Writer { return Writer{ .context = conn }; } - /// Closes the connection. + /// Closes the connection and deinitializes members. pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - // try to cleanly close the TLS connection, for any server that cares. - conn.tls_client.close(); - allocator.destroy(conn.tls_client); - } - - conn.stream.close(); + conn.stream().close(); allocator.free(conn.host); } }; @@ -922,14 +897,16 @@ pub const Request = struct { const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError; - const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead); + const TransferReader = std.io.Reader(*Request, TransferReadError, transferReadv); fn transferReader(req: *Request) TransferReader { return .{ .context = req }; } - fn transferRead(req: *Request, buf: []u8) TransferReadError!usize { + fn transferReadv(req: *Request, iov: []std.os.iovec) TransferReadError!usize { if (req.response.parser.done) return 0; + const first = iov[0]; + const buf = first.iov_base[0..first.iov_len]; var index: usize = 0; while (index == 0) { @@ -1034,7 +1011,7 @@ pub const Request = struct { // skip the body of the redirect response, this will at least // leave the connection in a known good state. req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary + assert(try req.transferReadv(&.{}) == 0); // we're skipping, no buffer is necessary if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; @@ -1118,20 +1095,20 @@ pub const Request = struct { pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers }; - pub const Reader = std.io.Reader(*Request, ReadError, read); + pub const Reader = std.io.Reader(*Request, ReadError, readv); pub fn reader(req: *Request) Reader { return .{ .context = req }; } /// Reads data from the response body. Must be called after `wait`. - pub fn read(req: *Request, buffer: []u8) ReadError!usize { + pub fn readv(req: *Request, iov: []std.os.iovec) ReadError!usize { const out_index = switch (req.response.compression) { - .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure, - .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, + .deflate => |*deflate| deflate.readv(iov) catch return error.DecompressionFailure, + .gzip => |*gzip| gzip.readv(iov) catch return error.DecompressionFailure, // https://github.com/ziglang/zig/issues/18937 //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), + else => try req.transferReadv(iov), }; if (out_index > 0) return out_index; @@ -1145,20 +1122,9 @@ pub const Request = struct { return 0; } - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; - } - return index; - } - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - pub const Writer = std.io.Writer(*Request, WriteError, write); + pub const Writer = std.io.Writer(*Request, WriteError, writev); pub fn writer(req: *Request) Writer { return .{ .context = req }; @@ -1166,21 +1132,29 @@ pub const Request = struct { /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. /// Must be called after `send` and before `finish`. - pub fn write(req: *Request, bytes: []const u8) WriteError!usize { + pub fn writev(req: *Request, iov: []std.os.iovec_const) WriteError!usize { + var iov_len: usize = 0; + for (iov) |v| iov_len += v.iov_len; + switch (req.transfer_encoding) { .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); + var w = req.connection.?.writer(); + + + if (iov_len > 0) { + try w.print("{x}\r\n", .{iov_len}); + try w.writevAll(iov); + try w.writeAll("\r\n"); } - return bytes.len; + return iov_len; }, .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; + const cwriter = req.connection.?.writer(); + + if (len.* < iov_len) return error.MessageTooLong; - const amt = try req.connection.?.write(bytes); + const amt = try cwriter.writev(iov); len.* -= amt; return amt; }, @@ -1188,15 +1162,6 @@ pub const Request = struct { } } - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); - } - } - pub const FinishError = WriteError || error{MessageNotCompleted}; /// Finish the body of a request. This notifies the server that you have no more data to send. @@ -1224,10 +1189,8 @@ pub const Proxy = struct { pub fn init(client: Client) !Client { var copy = client; - if (!disable_tls) { - if (copy.tls_options.ca_bundle) |*bundle| { - if (bundle.bytes.items.len == 0) try bundle.rescan(copy.allocator); - } + if (copy.tls_options.ca_bundle) |*bundle| { + if (bundle.bytes.items.len == 0) try bundle.rescan(copy.allocator); } return copy; @@ -1242,9 +1205,7 @@ pub fn deinit(client: *Client) void { client.connection_pool.deinit(client.allocator); - if (!disable_tls) { - if (client.tls_options.ca_bundle) |*bundle| bundle.deinit(client.allocator); - } + if (client.tls_options.ca_bundle) |*bundle| bundle.deinit(client.allocator); client.* = undefined; } @@ -1357,14 +1318,11 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec })) |node| return node; - if (disable_tls and protocol == .tls) - return error.TlsInitializationFailed; - const conn = try client.allocator.create(ConnectionPool.Node); errdefer client.allocator.destroy(conn); conn.* = .{ .data = undefined }; - const tcp_stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { + const socket = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) { error.ConnectionRefused => return error.ConnectionRefused, error.NetworkUnreachable => return error.NetworkUnreachable, error.ConnectionTimedOut => return error.ConnectionTimedOut, @@ -1375,12 +1333,11 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses, else => return error.UnexpectedConnectFailure, }; - errdefer tcp_stream.close(); - - conn.data = .{ - .stream = tcp_stream, - .tls_client = undefined, + errdefer socket.close(); + conn.data = Connection{ + .socket = socket, + .tls = undefined, .protocol = protocol, .host = try client.allocator.dupe(u8, host), .port = port, @@ -1388,12 +1345,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec errdefer client.allocator.free(conn.data.host); if (protocol == .tls) { - if (disable_tls) unreachable; - - conn.data.tls_client = try client.allocator.create(TlsClient); - errdefer client.allocator.destroy(conn.data.tls_client); - - conn.data.tls_client.* = TlsClient.init(&conn.data.stream, .{ + conn.data.tls = tls.Client.init(conn.data.socket.any().any(), .{ .ca_bundle = client.tls_options.ca_bundle, .cipher_suites = client.tls_options.cipher_suites, .host = host, @@ -1430,8 +1382,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti errdefer stream.close(); conn.data = .{ - .stream = stream, - .tls_client = undefined, + .stream = stream.any(), .protocol = .plain, .host = try client.allocator.dupe(u8, path), @@ -1762,7 +1713,7 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { try req.send(.{ .raw_uri = options.raw_uri }); - if (options.payload) |payload| try req.writeAll(payload); + if (options.payload) |payload| try req.writer().writeAll(payload); try req.finish(); try req.wait(); @@ -1772,7 +1723,7 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { // Take advantage of request internals to discard the response body // and make the connection available for another request. req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. + assert(try req.transferReadv(&.{}) == 0); // No buffer is necessary when skipping. }, .dynamic => |list| { const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 2911d8591b2c..6aa2fef4ff67 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,8 +1,6 @@ //! Blocking HTTP(s) server //! //! Handles a single connection's lifecycle. -//! -//! TLS support may be disabled via `std.options.http_disable_tls`. const std = @import("../std.zig"); const http = std.http; @@ -14,9 +12,6 @@ const testing = std.testing; const Server = @This(); -pub const disable_tls = std.options.http_disable_tls; -const TlsServer = if (disable_tls) void else std.crypto.tls.Server(net.Stream); - connection: net.Server.Connection, /// Keeps track of whether the Server is ready to accept a new request on the /// same connection, and makes invalid API usage cause assertion failures @@ -105,7 +100,7 @@ pub fn receiveHead(s: *Server) ReceiveHeadError!Request { const buf = s.read_buffer[s.read_buffer_len..]; if (buf.len == 0) return error.HttpHeadersOversize; - const read_n = s.connection.stream.read(buf) catch + const read_n = s.connection.stream().reader().read(buf) catch return error.HttpHeadersUnreadable; if (read_n == 0) { if (s.read_buffer_len > 0) { @@ -434,7 +429,7 @@ pub const Request = struct { h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - try request.server.connection.stream.writeAll(h.items); + try request.server.connection.stream().writer().writeAll(h.items); return; } h.fixedWriter().print("{s} {d} {s}\r\n", .{ @@ -540,7 +535,7 @@ pub const Request = struct { } } - try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); + try request.server.connection.stream().writer().writevAll(iovecs[0..iovecs_len]); } pub const RespondStreamingOptions = struct { @@ -620,7 +615,7 @@ pub const Request = struct { }; return .{ - .stream = request.server.connection.stream, + .stream = request.server.connection.stream(), .send_buffer = options.send_buffer, .send_buffer_start = 0, .send_buffer_end = h.items.len, @@ -635,12 +630,14 @@ pub const Request = struct { }; } - pub const ReadError = net.Stream.ReadError || error{ + pub const ReadError = anyerror || net.Stream.ReadError || error{ HttpChunkInvalid, HttpHeadersOversize, }; - fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { + fn read_cl(context: *const anyopaque, iov: []std.os.iovec) ReadError!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; const request: *Request = @constCast(@alignCast(@ptrCast(context))); const s = request.server; @@ -664,11 +661,13 @@ pub const Request = struct { const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; if (available.len > 0) return available; s.next_request_start = head_end; - s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); + s.read_buffer_len = head_end + try s.connection.stream().reader().read(s.read_buffer[head_end..]); return s.read_buffer[head_end..s.read_buffer_len]; } - fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { + fn readv_chunked(context: *const anyopaque, iov: []std.os.iovec) ReadError!usize { + const first = iov[0]; + const buffer = first.iov_base[0..first.iov_len]; const request: *Request = @constCast(@alignCast(@ptrCast(context))); const s = request.server; @@ -726,7 +725,7 @@ pub const Request = struct { const buf = s.read_buffer[s.read_buffer_len..]; if (buf.len == 0) return error.HttpHeadersOversize; - const read_n = try s.connection.stream.read(buf); + const read_n = try s.connection.stream().reader().read(buf); s.read_buffer_len += read_n; const bytes = buf[0..read_n]; const end = hp.feed(bytes); @@ -776,7 +775,7 @@ pub const Request = struct { if (request.head.expect) |expect| { if (mem.eql(u8, expect, "100-continue")) { - try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + try request.server.connection.stream().writer().writeAll("HTTP/1.1 100 Continue\r\n\r\n"); request.head.expect = null; } else { return error.HttpExpectationFailed; @@ -787,7 +786,7 @@ pub const Request = struct { .chunked => { request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; return .{ - .readFn = read_chunked, + .readvFn = readv_chunked, .context = request, }; }, @@ -796,7 +795,7 @@ pub const Request = struct { .remaining_content_length = request.head.content_length orelse 0, }; return .{ - .readFn = read_cl, + .readvFn = read_cl, .context = request, }; }, @@ -837,7 +836,7 @@ pub const Request = struct { }; pub const Response = struct { - stream: net.Stream, + stream: std.io.AnyStream, send_buffer: []u8, /// Index of the first byte in `send_buffer`. /// This is 0 unless a short write happens in `write`. @@ -899,20 +898,12 @@ pub const Response = struct { r.* = undefined; } - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - /// May return 0, which does not indicate end of stream. The caller decides - /// when the end of stream occurs by calling `end`. - pub fn write(r: *Response, bytes: []const u8) WriteError!usize { - switch (r.transfer_encoding) { - .content_length, .none => return write_cl(r, bytes), - .chunked => return write_chunked(r, bytes), - } - } - - fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { + fn write_cl(context: *const anyopaque, iov: []std.os.iovec_const) WriteError!usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); + const first = iov[0]; + const bytes = first.iov_base[0..first.iov_len]; + var trash: u64 = std.math.maxInt(u64); const len = switch (r.transfer_encoding) { .content_length => |*len| len, @@ -926,7 +917,7 @@ pub const Response = struct { if (bytes.len + r.send_buffer_end > r.send_buffer.len) { const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - var iovecs: [2]std.posix.iovec_const = .{ + var iovecs: [2]std.os.iovec_const = .{ .{ .iov_base = r.send_buffer.ptr + r.send_buffer_start, .iov_len = send_buffer_len, @@ -960,10 +951,13 @@ pub const Response = struct { return bytes.len; } - fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize { + fn write_chunked(context: *const anyopaque, iov: []std.os.iovec_const) WriteError!usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); assert(r.transfer_encoding == .chunked); + const first = iov[0]; + const bytes = first.iov_base[0..first.iov_len]; + if (r.elide_body) return bytes.len; @@ -997,7 +991,7 @@ pub const Response = struct { }; // TODO make this writev instead of writevAll, which involves // complicating the logic of this function. - try r.stream.writevAll(&iovecs); + try r.stream.writer().writevAll(&iovecs); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1011,15 +1005,6 @@ pub const Response = struct { return bytes.len; } - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(r, bytes[index..]); - } - } - /// Sends all buffered data to the client. /// This is redundant after calling `end`. /// Respects the value of `elide_body` to omit all data after the headers. @@ -1031,7 +1016,7 @@ pub const Response = struct { } fn flush_cl(r: *Response) WriteError!void { - try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); + try r.stream.writer().writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); r.send_buffer_start = 0; r.send_buffer_end = 0; } @@ -1044,7 +1029,7 @@ pub const Response = struct { const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; if (r.elide_body) { - try r.stream.writeAll(http_headers); + try r.stream.writer().writeAll(http_headers); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1125,7 +1110,7 @@ pub const Response = struct { iovecs_len += 1; } - try r.stream.writevAll(iovecs[0..iovecs_len]); + try r.stream.writer().writevAll(iovecs[0..iovecs_len]); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1133,7 +1118,7 @@ pub const Response = struct { pub fn writer(r: *Response) std.io.AnyWriter { return .{ - .writeFn = switch (r.transfer_encoding) { + .writevFn = switch (r.transfer_encoding) { .none, .content_length => write_cl, .chunked => write_chunked, }, diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 78511f435d67..e1f70b82ab93 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -281,7 +281,7 @@ const MockBufferedConnection = struct { pub fn fill(conn: *MockBufferedConnection) ReadError!void { if (conn.end != conn.start) return; - const nread = try conn.conn.read(conn.buf[0..]); + const nread = try conn.conn.reader().read(conn.buf[0..]); if (nread == 0) return error.EndOfStream; conn.start = 0; conn.end = @as(u16, @truncate(nread)); @@ -313,7 +313,7 @@ const MockBufferedConnection = struct { if (left > conn.buf.len) { // skip the buffer if the output is large enough - return conn.conn.read(buffer[out_index..]); + return conn.conn.reader().read(buffer[out_index..]); } try conn.fill(); diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index e2aa810d580d..a54f8bdb2077 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -14,8 +14,8 @@ test "trailers" { var header_buffer: [1024]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); @@ -33,9 +33,9 @@ test "trailers" { var response = request.respondStreaming(.{ .send_buffer = &send_buffer, }); - try response.writeAll("Hello, "); + try response.writer().writeAll("Hello, "); try response.flush(); - try response.writeAll("World!\n"); + try response.writer().writeAll("World!\n"); try response.flush(); try response.endChunked(.{ .trailers = &.{ @@ -96,8 +96,8 @@ test "HTTP server handles a chunked transfer coding request" { const test_server = try createTestServer(struct { fn run(net_server: *std.net.Server) !void { var header_buffer: [8192]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); var request = try server.receiveHead(); @@ -133,9 +133,10 @@ test "HTTP server handles a chunked transfer coding request" { "\r\n"; const gpa = std.testing.allocator; - const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const tcp_stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = tcp_stream.any(); defer stream.close(); - try stream.writeAll(request_bytes); + try stream.writer().writeAll(request_bytes); const expected_response = "HTTP/1.1 200 OK\r\n" ++ @@ -155,8 +156,8 @@ test "echo content server" { var read_buffer: [1024]u8 = undefined; accept: while (true) { - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var http_server = http.Server.init(conn, &read_buffer); @@ -237,8 +238,8 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { var header_buffer: [1000]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); @@ -256,7 +257,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { for (0..500) |i| { var buf: [30]u8 = undefined; const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); - try response.writeAll(line); + try response.writer().writeAll(line); total += line.len; } try expectEqual(7390, total); @@ -269,9 +270,10 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { const request_bytes = "GET /foo HTTP/1.1\r\n\r\n"; const gpa = std.testing.allocator; - const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const tcp_stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = tcp_stream.any(); defer stream.close(); - try stream.writeAll(request_bytes); + try stream.writer().writeAll(request_bytes); const response = try stream.reader().readAllAlloc(gpa, 8192); defer gpa.free(response); @@ -301,8 +303,8 @@ test "receiving arbitrary http headers from the client" { var read_buffer: [666]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &read_buffer); try expectEqual(.ready, server.state); @@ -332,9 +334,10 @@ test "receiving arbitrary http headers from the client" { "aoeu: asdf \r\n" ++ "\r\n"; const gpa = std.testing.allocator; - const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const tcp_stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = tcp_stream.any(); defer stream.close(); - try stream.writeAll(request_bytes); + try stream.writer().writeAll(request_bytes); const response = try stream.reader().readAllAlloc(gpa, 8192); defer gpa.free(response); @@ -361,8 +364,8 @@ test "general client/server API coverage" { fn run(net_server: *std.net.Server) anyerror!void { var client_header_buffer: [1024]u8 = undefined; outer: while (global.handle_new_requests) { - var connection = try net_server.accept(); - defer connection.stream.close(); + var connection = try net_server.accept(null); + defer connection.stream().close(); var http_server = http.Server.init(connection, &client_header_buffer); @@ -877,8 +880,8 @@ test "Server streams both reading and writing" { const test_server = try createTestServer(struct { fn run(net_server: *std.net.Server) anyerror!void { var header_buffer: [1024]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); var request = try server.receiveHead(); @@ -925,8 +928,8 @@ test "Server streams both reading and writing" { try req.send(.{}); try req.wait(); - try req.writeAll("one "); - try req.writeAll("fish"); + try req.writer().writeAll("one "); + try req.writer().writeAll("fish"); try req.finish(); @@ -957,8 +960,8 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .{ .content_length = 14 }; try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + try req.writer().writeAll("Hello, "); + try req.writer().writeAll("World!\n"); try req.finish(); try req.wait(); @@ -991,8 +994,8 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + try req.writer().writeAll("Hello, "); + try req.writer().writeAll("World!\n"); try req.finish(); try req.wait(); @@ -1045,8 +1048,8 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; try req.send(.{}); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); + try req.writer().writeAll("Hello, "); + try req.writer().writeAll("World!\n"); try req.finish(); try req.wait(); @@ -1121,8 +1124,8 @@ test "redirect to different connection" { fn run(net_server: *std.net.Server) anyerror!void { var header_buffer: [888]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); var server = http.Server.init(conn, &header_buffer); var request = try server.receiveHead(); @@ -1142,8 +1145,8 @@ test "redirect to different connection" { var header_buffer: [999]u8 = undefined; var send_buffer: [100]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + var conn = try net_server.accept(null); + defer conn.stream().close(); const new_loc = try std.fmt.bufPrint(&send_buffer, "http://127.0.0.1:{d}/ok", .{ global.other_port.?, diff --git a/lib/std/io.zig b/lib/std/io.zig index b402b8185e4f..c886b1d47687 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -11,6 +11,8 @@ const mem = std.mem; const meta = std.meta; const File = std.fs.File; const Allocator = std.mem.Allocator; +const iovec = std.os.iovec; +const iovec_const = std.os.iovec_const; fn getStdOutHandle() os.fd_t { if (builtin.os.tag == .windows) { @@ -78,7 +80,7 @@ pub fn GenericReader( /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. - comptime readFn: fn (context: Context, buffer: []u8) ReadError!usize, + comptime readvFn: fn (context: Context, iov: []iovec) ReadError!usize, ) type { return struct { context: Context, @@ -88,8 +90,12 @@ pub fn GenericReader( EndOfStream, }; + pub inline fn readv(self: Self, iov: []iovec) Error!usize { + return readvFn(self.context, iov); + } + pub inline fn read(self: Self, buffer: []u8) Error!usize { - return readFn(self.context, buffer); + return @errorCast(self.any().read(buffer)); } pub inline fn readAll(self: Self, buffer: []u8) Error!usize { @@ -283,18 +289,24 @@ pub fn GenericReader( return @errorCast(self.any().readEnum(Enum, endian)); } + /// Reads the stream until the end, ignoring all the data. + /// Returns the number of bytes discarded. + pub inline fn discard(self: Self) anyerror!u64 { + return @errorCast(self.any().discard()); + } + pub inline fn any(self: *const Self) AnyReader { return .{ .context = @ptrCast(&self.context), - .readFn = typeErasedReadFn, + .readvFn = typeErasedReadvFn, }; } const Self = @This(); - fn typeErasedReadFn(context: *const anyopaque, buffer: []u8) anyerror!usize { + fn typeErasedReadvFn(context: *const anyopaque, iov: []iovec) anyerror!usize { const ptr: *const Context = @alignCast(@ptrCast(context)); - return readFn(ptr.*, buffer); + return readvFn(ptr.*, iov); } }; } @@ -302,7 +314,7 @@ pub fn GenericReader( pub fn GenericWriter( comptime Context: type, comptime WriteError: type, - comptime writeFn: fn (context: Context, bytes: []const u8) WriteError!usize, + comptime writevFn: fn (context: Context, iov: []iovec_const) WriteError!usize, ) type { return struct { context: Context, @@ -310,8 +322,16 @@ pub fn GenericWriter( const Self = @This(); pub const Error = WriteError; + pub inline fn writev(self: Self, iov: []iovec_const) Error!usize { + return writevFn(self.context, iov); + } + + pub inline fn writevAll(self: Self, iov: []iovec_const) Error!void { + return @errorCast(self.any().writevAll(iov)); + } + pub inline fn write(self: Self, bytes: []const u8) Error!usize { - return writeFn(self.context, bytes); + return @errorCast(self.any().write(bytes)); } pub inline fn writeAll(self: Self, bytes: []const u8) Error!void { @@ -345,13 +365,13 @@ pub fn GenericWriter( pub inline fn any(self: *const Self) AnyWriter { return .{ .context = @ptrCast(&self.context), - .writeFn = typeErasedWriteFn, + .writevFn = typeErasedWritevFn, }; } - fn typeErasedWriteFn(context: *const anyopaque, bytes: []const u8) anyerror!usize { + fn typeErasedWritevFn(context: *const anyopaque, iov: []iovec_const) anyerror!usize { const ptr: *const Context = @alignCast(@ptrCast(context)); - return writeFn(ptr.*, bytes); + return writevFn(ptr.*, iov); } }; } @@ -362,16 +382,16 @@ pub fn GenericStream( /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. - comptime readFn: fn (context: Context, buffer: []u8) ReadError!usize, + comptime readvFn: fn (context: Context, iov: []iovec) ReadError!usize, comptime WriteError: type, - comptime writeFn: fn (context: Context, bytes: []const u8) WriteError!usize, + comptime writevFn: fn (context: Context, iov: []iovec_const) WriteError!usize, comptime closeFn: fn (context: Context) void, ) type { return struct { context: Context, - const ReaderType = GenericReader(Context, ReadError, readFn); - const WriterType = GenericWriter(Context, WriteError, writeFn); + const ReaderType = GenericReader(Context, ReadError, readvFn); + const WriterType = GenericWriter(Context, WriteError, writevFn); const Self = @This(); @@ -390,8 +410,8 @@ pub fn GenericStream( pub inline fn any(self: *const Self) AnyStream { return .{ .context = @ptrCast(&self.context), - .readFn = self.reader().any().readFn, - .writeFn = self.writer().any().writeFn, + .readvFn = self.reader().any().readvFn, + .writevFn = self.writer().any().writevFn, .closeFn = typeErasedCloseFn, }; } @@ -467,10 +487,12 @@ pub const tty = @import("io/tty.zig"); /// A Writer that doesn't write to anything. pub const null_writer = @as(NullWriter, .{ .context = {} }); -const NullWriter = Writer(void, error{}, dummyWrite); -fn dummyWrite(context: void, data: []const u8) error{}!usize { +const NullWriter = Writer(void, error{}, dummyWritev); +fn dummyWritev(context: void, iov: []std.os.iovec_const) error{}!usize { _ = context; - return data.len; + var written: usize = 0; + for (iov) |v| written += v.iov_len; + return written; } test "null_writer" { @@ -746,6 +768,7 @@ pub fn PollFiles(comptime StreamEnum: type) type { test { _ = AnyReader; _ = AnyWriter; + _ = AnyStream; _ = @import("io/bit_reader.zig"); _ = @import("io/bit_writer.zig"); _ = @import("io/buffered_atomic_file.zig"); diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index a769fe4c0421..e8b76533d887 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -1,20 +1,38 @@ +const std = @import("../std.zig"); +const Self = @This(); +const math = std.math; +const assert = std.debug.assert; +const mem = std.mem; +const testing = std.testing; +const native_endian = @import("builtin").target.cpu.arch.endian(); +const iovec = std.os.iovec; + context: *const anyopaque, -readFn: *const fn (context: *const anyopaque, buffer: []u8) anyerror!usize, +readvFn: *const fn (context: *const anyopaque, iov: []iovec) anyerror!usize, pub const Error = anyerror; + +/// Returns the number of bytes read. It may be less than buffer.len. +/// If the number of bytes read is 0, it means end of stream. +/// End of stream is not an error condition. +pub fn readv(self: Self, iov: []iovec) anyerror!usize { + return self.readvFn(self.context, iov); +} + /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. pub fn read(self: Self, buffer: []u8) anyerror!usize { - return self.readFn(self.context, buffer); + var iov = [_]iovec{ .{ .iov_base = buffer.ptr, .iov_len = buffer.len } }; + return self.readv(&iov); } /// Returns the number of bytes read. If the number read is smaller than `buffer.len`, it /// means the stream reached the end. Reaching the end of a stream is not an error /// condition. pub fn readAll(self: Self, buffer: []u8) anyerror!usize { - return readAtLeast(self, buffer, buffer.len); + return self.readAtLeast(buffer, buffer.len); } /// Returns the number of bytes read, calling the underlying read @@ -372,14 +390,6 @@ pub fn discard(self: Self) anyerror!u64 { } } -const std = @import("../std.zig"); -const Self = @This(); -const math = std.math; -const assert = std.debug.assert; -const mem = std.mem; -const testing = std.testing; -const native_endian = @import("builtin").target.cpu.arch.endian(); - test { _ = @import("Reader/test.zig"); } diff --git a/lib/std/io/Stream.zig b/lib/std/io/Stream.zig index 37be7ae0b1ab..f7941c70405f 100644 --- a/lib/std/io/Stream.zig +++ b/lib/std/io/Stream.zig @@ -2,35 +2,37 @@ const std = @import("../std.zig"); const assert = std.debug.assert; const mem = std.mem; const os = std.os; +const iovec = os.iovec; +const iovec_const = os.iovec_const; context: *const anyopaque, -writeFn: *const fn (context: *const anyopaque, bytes: []const u8) anyerror!usize, -readFn: *const fn (context: *const anyopaque, buffer: []u8) anyerror!usize, +readvFn: *const fn (context: *const anyopaque, iov: []iovec) anyerror!usize, +writevFn: *const fn (context: *const anyopaque, iov: []iovec_const) anyerror!usize, closeFn: *const fn (context: *const anyopaque) void, const Self = @This(); pub const Error = anyerror; -pub fn write(self: Self, bytes: []const u8) anyerror!usize { - return self.writeFn(self.context, bytes); +pub fn writev(self: Self, iov: []iovec_const) anyerror!usize { + return self.writevFn(self.context, iov); } /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. -pub fn read(self: Self, buffer: []u8) anyerror!usize { - return self.readFn(self.context, buffer); -} - -pub fn close(self: Self) void { - return self.closeFn(self.context); +pub fn readv(self: Self, iov: []iovec) anyerror!usize { + return self.readvFn(self.context, iov); } pub fn reader(self: Self) std.io.AnyReader { - return .{ .context = self.context, .readFn = self.readFn }; + return .{ .context = self.context, .readvFn = self.readvFn }; } pub fn writer(self: Self) std.io.AnyWriter { - return .{ .context = self.context, .writeFn = self.writeFn }; + return .{ .context = self.context, .writevFn = self.writevFn }; +} + +pub fn close(self: Self) void { + return self.closeFn(self.context); } diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index dfcae48b1eb8..c5e92bf8da1e 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -1,15 +1,41 @@ const std = @import("../std.zig"); const assert = std.debug.assert; const mem = std.mem; +const iovec_const = std.os.iovec_const; context: *const anyopaque, -writeFn: *const fn (context: *const anyopaque, bytes: []const u8) anyerror!usize, +writevFn: *const fn (context: *const anyopaque, iov: []iovec_const) anyerror!usize, const Self = @This(); pub const Error = anyerror; +pub fn writev(self: Self, iov: []iovec_const) anyerror!usize { + return self.writevFn(self.context, iov); +} + +/// The `iovecs` parameter is mutable because this function needs to mutate the fields in +/// order to handle partial writes from the underlying OS layer. +/// See https://github.com/ziglang/zig/issues/7699 +/// See equivalent function: `std.fs.File.writevAll`. +pub fn writevAll(self: Self, iovecs: []iovec_const) anyerror!void { + if (iovecs.len == 0) return; + + var i: usize = 0; + while (true) { + var amt = try self.writev(iovecs[i..]); + while (amt >= iovecs[i].iov_len) { + amt -= iovecs[i].iov_len; + i += 1; + if (i >= iovecs.len) return; + } + iovecs[i].iov_base += amt; + iovecs[i].iov_len -= amt; + } +} + pub fn write(self: Self, bytes: []const u8) anyerror!usize { - return self.writeFn(self.context, bytes); + var iov = [_]iovec_const{ .{ .iov_base = bytes.ptr, .iov_len = bytes.len } }; + return self.writev(&iov); } pub fn writeAll(self: Self, bytes: []const u8) anyerror!void { diff --git a/lib/std/io/buffered_reader.zig b/lib/std/io/buffered_reader.zig index ca132202a7df..a8e5193f7b78 100644 --- a/lib/std/io/buffered_reader.zig +++ b/lib/std/io/buffered_reader.zig @@ -12,11 +12,13 @@ pub fn BufferedReader(comptime buffer_size: usize, comptime ReaderType: type) ty end: usize = 0, pub const Error = ReaderType.Error; - pub const Reader = io.Reader(*Self, Error, read); + pub const Reader = io.Reader(*Self, Error, readv); const Self = @This(); - pub fn read(self: *Self, dest: []u8) Error!usize { + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + const first = iov[0]; + const dest = first.iov_base[0..first.iov_len]; var dest_index: usize = 0; while (dest_index < dest.len) { @@ -60,7 +62,7 @@ test "OneByte" { const Error = error{NoError}; const Self = @This(); - const Reader = io.Reader(*Self, Error, read); + const Reader = io.Reader(*Self, Error, readv); fn init(str: []const u8) Self { return Self{ @@ -69,7 +71,9 @@ test "OneByte" { }; } - fn read(self: *Self, dest: []u8) Error!usize { + fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + const first = iov[0]; + const dest = first.iov_base[0..first.iov_len]; if (self.str.len <= self.curr or dest.len == 0) return 0; @@ -135,11 +139,11 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [4]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out < block @@ -148,13 +152,13 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [3]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "012"); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "301"); - const n = try test_buf_reader.read(&out_buf); + const n = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, out_buf[0..n], "23"); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out > block @@ -163,11 +167,11 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [5]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "01230"); - const n = try test_buf_reader.read(&out_buf); + const n = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, out_buf[0..n], "123"); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out == 0 @@ -176,7 +180,7 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [0]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, ""); } @@ -186,10 +190,10 @@ test "Block" { .unbuffered_reader = BlockReader.init(block, 2), }; var out_buf: [4]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } } diff --git a/lib/std/io/buffered_tee.zig b/lib/std/io/buffered_tee.zig index d5748c3a5276..bc817ddb69b5 100644 --- a/lib/std/io/buffered_tee.zig +++ b/lib/std/io/buffered_tee.zig @@ -36,11 +36,13 @@ pub fn BufferedTee( wp: usize = 0, // writer pointer; data is sent to the output up to this position pub const Error = InputReaderType.Error || OutputWriterType.Error; - pub const Reader = io.Reader(*Self, Error, read); + pub const Reader = io.Reader(*Self, Error, readv); const Self = @This(); - pub fn read(self: *Self, dest: []u8) Error!usize { + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + const first = iov[0]; + const dest = first.iov_base[0..first.iov_len]; var dest_index: usize = 0; while (dest_index < dest.len) { @@ -153,7 +155,7 @@ test "OneByte" { const Error = error{NoError}; const Self = @This(); - const Reader = io.Reader(*Self, Error, read); + const Reader = io.Reader(*Self, Error, readv); fn init(str: []const u8) Self { return Self{ @@ -162,7 +164,9 @@ test "OneByte" { }; } - fn read(self: *Self, dest: []u8) Error!usize { + fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + const first = iov[0]; + const dest = first.iov_base[0..first.iov_len]; if (self.str.len <= self.curr or dest.len == 0) return 0; @@ -226,11 +230,11 @@ test "Block" { .output = io.null_writer, }; var out_buf: [4]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out < block @@ -240,13 +244,13 @@ test "Block" { .output = io.null_writer, }; var out_buf: [3]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "012"); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "301"); - const n = try test_buf_reader.read(&out_buf); + const n = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, out_buf[0..n], "23"); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out > block @@ -256,11 +260,11 @@ test "Block" { .output = io.null_writer, }; var out_buf: [5]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, "01230"); - const n = try test_buf_reader.read(&out_buf); + const n = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, out_buf[0..n], "123"); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } // len out == 0 @@ -270,7 +274,7 @@ test "Block" { .output = io.null_writer, }; var out_buf: [0]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, ""); } @@ -281,11 +285,11 @@ test "Block" { .output = io.null_writer, }; var out_buf: [4]u8 = undefined; - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - _ = try test_buf_reader.read(&out_buf); + _ = try test_buf_reader.reader().read(&out_buf); try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try test_buf_reader.read(&out_buf), 0); + try testing.expectEqual(try test_buf_reader.reader().read(&out_buf), 0); } } @@ -301,7 +305,7 @@ test "with zero lookahead" { var buf: [16]u8 = undefined; var read_len: usize = 0; for (0..buf.len) |i| { - const n = try bt.read(buf[0..i]); + const n = try bt.reader().read(buf[0..i]); try testing.expectEqual(i, n); read_len += i; try testing.expectEqual(read_len, out.items.len); @@ -321,7 +325,7 @@ test "with lookahead" { var read_len: usize = 0; for (1..buf.len) |i| { - const n = try bt.read(buf[0..i]); + const n = try bt.reader().read(buf[0..i]); try testing.expectEqual(i, n); read_len += i; const out_len = if (read_len < lookahead) 0 else read_len - lookahead; @@ -342,14 +346,14 @@ test "internal state" { var bt = bufferedTee(8, 4, in.reader(), out.writer()); var buf: [16]u8 = undefined; - var n = try bt.read(buf[0..3]); + var n = try bt.reader().read(buf[0..3]); try testing.expectEqual(3, n); try testing.expectEqualSlices(u8, data[0..3], buf[0..n]); try testing.expectEqual(8, bt.tail); try testing.expectEqual(3, bt.rp); try testing.expectEqual(0, out.items.len); - n = try bt.read(buf[0..6]); + n = try bt.reader().read(buf[0..6]); try testing.expectEqual(6, n); try testing.expectEqualSlices(u8, data[3..9], buf[0..n]); try testing.expectEqual(8, bt.tail); @@ -357,7 +361,7 @@ test "internal state" { try testing.expectEqualSlices(u8, data[4..12], &bt.buf); try testing.expectEqual(5, out.items.len); - n = try bt.read(buf[0..9]); + n = try bt.reader().read(buf[0..9]); try testing.expectEqual(9, n); try testing.expectEqualSlices(u8, data[9..18], buf[0..n]); try testing.expectEqual(8, bt.tail); @@ -369,7 +373,7 @@ test "internal state" { try testing.expectEqual(18, out.items.len); bt.putBack(4); - n = try bt.read(buf[0..4]); + n = try bt.reader().read(buf[0..4]); try testing.expectEqual(4, n); try testing.expectEqualSlices(u8, data[14..18], buf[0..n]); diff --git a/lib/std/io/buffered_writer.zig b/lib/std/io/buffered_writer.zig index 906d6cce4926..b3518d69f553 100644 --- a/lib/std/io/buffered_writer.zig +++ b/lib/std/io/buffered_writer.zig @@ -10,7 +10,7 @@ pub fn BufferedWriter(comptime buffer_size: usize, comptime WriterType: type) ty end: usize = 0, pub const Error = WriterType.Error; - pub const Writer = io.Writer(*Self, Error, write); + pub const Writer = io.Writer(*Self, Error, writev); const Self = @This(); @@ -23,17 +23,22 @@ pub fn BufferedWriter(comptime buffer_size: usize, comptime WriterType: type) ty return .{ .context = self }; } - pub fn write(self: *Self, bytes: []const u8) Error!usize { - if (self.end + bytes.len > self.buf.len) { - try self.flush(); - if (bytes.len > self.buf.len) - return self.unbuffered_writer.write(bytes); + pub fn writev(self: *Self, iov: []std.os.iovec_const) Error!usize { + var written: usize = 0; + for (iov) |v| { + const bytes = v.iov_base[0..v.iov_len]; + if (self.end + bytes.len > self.buf.len) { + try self.flush(); + if (bytes.len > self.buf.len) + return self.unbuffered_writer.write(bytes); + } + + const new_end = self.end + bytes.len; + @memcpy(self.buf[self.end..new_end], bytes); + self.end = new_end; + written += bytes.len; } - - const new_end = self.end + bytes.len; - @memcpy(self.buf[self.end..new_end], bytes); - self.end = new_end; - return bytes.len; + return written; } }; } diff --git a/lib/std/io/counting_reader.zig b/lib/std/io/counting_reader.zig index 2ff9b8a08fe3..72cd10f1ea3e 100644 --- a/lib/std/io/counting_reader.zig +++ b/lib/std/io/counting_reader.zig @@ -9,15 +9,17 @@ pub fn CountingReader(comptime ReaderType: anytype) type { bytes_read: u64 = 0, pub const Error = ReaderType.Error; - pub const Reader = io.Reader(*@This(), Error, read); + pub const Reader = io.Reader(*@This(), Error, readv); - pub fn read(self: *@This(), buf: []u8) Error!usize { - const amt = try self.child_reader.read(buf); + const Self = @This(); + + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + const amt = try self.child_reader.readv(iov); self.bytes_read += amt; return amt; } - pub fn reader(self: *@This()) Reader { + pub fn reader(self: *Self) Reader { return .{ .context = self }; } }; diff --git a/lib/std/io/counting_writer.zig b/lib/std/io/counting_writer.zig index 9043e1a47c17..989fbcbab614 100644 --- a/lib/std/io/counting_writer.zig +++ b/lib/std/io/counting_writer.zig @@ -9,12 +9,12 @@ pub fn CountingWriter(comptime WriterType: type) type { child_stream: WriterType, pub const Error = WriterType.Error; - pub const Writer = io.Writer(*Self, Error, write); + pub const Writer = io.Writer(*Self, Error, writev); const Self = @This(); - pub fn write(self: *Self, bytes: []const u8) Error!usize { - const amt = try self.child_stream.write(bytes); + pub fn writev(self: *Self, iov: []std.os.iovec_const) Error!usize { + const amt = try self.child_stream.writev(iov); self.bytes_written += amt; return amt; } diff --git a/lib/std/io/fixed_buffer_stream.zig b/lib/std/io/fixed_buffer_stream.zig index 14e5e5de43eb..e455e591c454 100644 --- a/lib/std/io/fixed_buffer_stream.zig +++ b/lib/std/io/fixed_buffer_stream.zig @@ -17,8 +17,8 @@ pub fn FixedBufferStream(comptime Buffer: type) type { pub const SeekError = error{}; pub const GetSeekPosError = error{}; - pub const Reader = io.Reader(*Self, ReadError, read); - pub const Writer = io.Writer(*Self, WriteError, write); + pub const Reader = io.Reader(*Self, ReadError, readv); + pub const Writer = io.Writer(*Self, WriteError, writev); pub const SeekableStream = io.SeekableStream( *Self, @@ -44,31 +44,39 @@ pub fn FixedBufferStream(comptime Buffer: type) type { return .{ .context = self }; } - pub fn read(self: *Self, dest: []u8) ReadError!usize { - const size = @min(dest.len, self.buffer.len - self.pos); - const end = self.pos + size; + pub fn readv(self: *Self, iov: []std.os.iovec) ReadError!usize { + var read: usize = 0; + for (iov) |v| { + const size = @min(v.iov_len, self.buffer.len - self.pos); + const end = self.pos + size; - @memcpy(dest[0..size], self.buffer[self.pos..end]); - self.pos = end; + @memcpy(v.iov_base[0..size], self.buffer[self.pos..end]); + self.pos = end; + read += size; + } - return size; + return read; } /// If the returned number of bytes written is less than requested, the /// buffer is full. Returns `error.NoSpaceLeft` when no bytes would be written. /// Note: `error.NoSpaceLeft` matches the corresponding error from /// `std.fs.File.WriteError`. - pub fn write(self: *Self, bytes: []const u8) WriteError!usize { - if (bytes.len == 0) return 0; - if (self.pos >= self.buffer.len) return error.NoSpaceLeft; - - const n = @min(self.buffer.len - self.pos, bytes.len); - @memcpy(self.buffer[self.pos..][0..n], bytes[0..n]); - self.pos += n; - - if (n == 0) return error.NoSpaceLeft; + pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { + var written: usize = 0; + for (iov) |v| { + if (v.iov_len == 0) continue; + if (self.pos >= self.buffer.len) return error.NoSpaceLeft; + + const n = @min(self.buffer.len - self.pos, v.iov_len); + @memcpy(self.buffer[self.pos..][0..n], v.iov_base[0..n]); + self.pos += n; + + if (n == 0) return error.NoSpaceLeft; + written += n; + } - return n; + return written; } pub fn seekTo(self: *Self, pos: u64) SeekError!void { diff --git a/lib/std/io/limited_reader.zig b/lib/std/io/limited_reader.zig index d7e250388139..b20018e6e717 100644 --- a/lib/std/io/limited_reader.zig +++ b/lib/std/io/limited_reader.zig @@ -9,15 +9,16 @@ pub fn LimitedReader(comptime ReaderType: type) type { bytes_left: u64, pub const Error = ReaderType.Error; - pub const Reader = io.Reader(*Self, Error, read); + pub const Reader = io.Reader(*Self, Error, readv); const Self = @This(); - pub fn read(self: *Self, dest: []u8) Error!usize { - const max_read = @min(self.bytes_left, dest.len); - const n = try self.inner_reader.read(dest[0..max_read]); - self.bytes_left -= n; - return n; + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + for (iov) |*v| { + v.iov_len = @min(self.bytes_left, v.iov_len); + self.bytes_left -= v.iov_len; + } + return try self.inner_reader.readv(iov); } pub fn reader(self: *Self) Reader { diff --git a/lib/std/io/peek_stream.zig b/lib/std/io/peek_stream.zig index 9c28a80cef9d..7bfd3b1e8962 100644 --- a/lib/std/io/peek_stream.zig +++ b/lib/std/io/peek_stream.zig @@ -16,7 +16,7 @@ pub fn PeekStream( fifo: FifoType, pub const Error = ReaderType.Error; - pub const Reader = io.Reader(*Self, Error, read); + pub const Reader = io.Reader(*Self, Error, readv); const Self = @This(); const FifoType = std.fifo.LinearFifo(u8, buffer_type); @@ -59,13 +59,17 @@ pub fn PeekStream( try self.fifo.unget(bytes); } - pub fn read(self: *Self, dest: []u8) Error!usize { - // copy over anything putBack()'d - var dest_index = self.fifo.read(dest); - if (dest_index == dest.len) return dest_index; - - // ask the backing stream for more - dest_index += try self.unbuffered_reader.read(dest[dest_index..]); + pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + var dest_index: usize = 0; + for (iov) |v| { + const dest = v.iov_base[0..v.iov_len]; + // copy over anything putBack()'d + dest_index = self.fifo.read(dest); + if (dest_index == dest.len) return dest_index; + + // ask the backing stream for more + dest_index += try self.unbuffered_reader.read(dest[dest_index..]); + } return dest_index; } diff --git a/lib/std/io/stream_source.zig b/lib/std/io/stream_source.zig index 6e06af8204e0..f4fcc5ad31c2 100644 --- a/lib/std/io/stream_source.zig +++ b/lib/std/io/stream_source.zig @@ -26,8 +26,8 @@ pub const StreamSource = union(enum) { pub const SeekError = io.FixedBufferStream([]u8).SeekError || (if (has_file) std.fs.File.SeekError else error{}); pub const GetSeekPosError = io.FixedBufferStream([]u8).GetSeekPosError || (if (has_file) std.fs.File.GetSeekPosError else error{}); - pub const Reader = io.Reader(*StreamSource, ReadError, read); - pub const Writer = io.Writer(*StreamSource, WriteError, write); + pub const Reader = io.Reader(*StreamSource, ReadError, readv); + pub const Writer = io.Writer(*StreamSource, WriteError, writev); pub const SeekableStream = io.SeekableStream( *StreamSource, SeekError, @@ -38,19 +38,19 @@ pub const StreamSource = union(enum) { getEndPos, ); - pub fn read(self: *StreamSource, dest: []u8) ReadError!usize { + pub fn readv(self: *StreamSource, iov: []std.os.iovec) ReadError!usize { switch (self.*) { - .buffer => |*x| return x.read(dest), - .const_buffer => |*x| return x.read(dest), - .file => |x| if (!has_file) unreachable else return x.read(dest), + .buffer => |*x| return x.readv(iov), + .const_buffer => |*x| return x.readv(iov), + .file => |x| if (!has_file) unreachable else return x.readv(iov), } } - pub fn write(self: *StreamSource, bytes: []const u8) WriteError!usize { + pub fn writev(self: *StreamSource, iov: []std.os.iovec_const) WriteError!usize { switch (self.*) { - .buffer => |*x| return x.write(bytes), + .buffer => |*x| return x.writev(iov), .const_buffer => return error.AccessDenied, - .file => |x| if (!has_file) unreachable else return x.write(bytes), + .file => |x| if (!has_file) unreachable else return x.writev(iov), } } diff --git a/lib/std/json/stringify_test.zig b/lib/std/json/stringify_test.zig index 9baeae9389c4..11ae4b3ea1e8 100644 --- a/lib/std/json/stringify_test.zig +++ b/lib/std/json/stringify_test.zig @@ -298,7 +298,7 @@ test "stringify tuple" { fn testStringify(expected: []const u8, value: anytype, options: StringifyOptions) !void { const ValidationWriter = struct { const Self = @This(); - pub const Writer = std.io.Writer(*Self, Error, write); + pub const Writer = std.io.Writer(*Self, Error, writev); pub const Error = error{ TooMuchData, DifferentData, @@ -314,7 +314,9 @@ fn testStringify(expected: []const u8, value: anytype, options: StringifyOptions return .{ .context = self }; } - fn write(self: *Self, bytes: []const u8) Error!usize { + fn writev(self: *Self, iov: []std.os.iovec_const) Error!usize { + const first = iov[0]; + const bytes = first.iov_base[0..first.iov_len]; if (self.expected_remaining.len < bytes.len) { std.debug.print( \\====== expected this output: ========= diff --git a/lib/std/net.zig b/lib/std/net.zig index 8b6a805aa1b4..31c12b0a07ba 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -10,6 +10,7 @@ const posix = std.posix; const fs = std.fs; const io = std.io; const native_endian = builtin.target.cpu.arch.endian(); +const tls = std.crypto.tls; // Windows 10 added support for unix sockets in build 17063, redstone 4 is the // first release to support them. @@ -264,10 +265,6 @@ pub const Address = extern union { try posix.getsockname(sockfd, &s.listen_address.any, &socklen); return s; } - - // The returned `Server` has an open `stream`. - // pub fn listenTls(address: Address, options: ListenOptions) ListenError!Server { - // } }; pub const Ip4Address = extern struct { @@ -1804,31 +1801,12 @@ pub const Stream = struct { pub const ReadError = os.ReadError; pub const WriteError = os.WriteError; + pub const GenericStream = io.GenericStream(Stream, ReadError, readv, WriteError, writev, close); - pub const Reader = io.GenericReader(Stream, ReadError, read); - pub const Writer = io.GenericWriter(Stream, WriteError, write); - pub const GenericStream = io.GenericStream(Stream, ReadError, read, WriteError, write, close); - - pub fn reader(self: Stream) Reader { - return .{ .context = self }; - } - - pub fn writer(self: Stream) Writer { - return .{ .context = self }; - } - - pub fn stream(self: Stream) GenericStream { + pub fn any(self: Stream) GenericStream { return .{ .context = self }; } - pub fn read(self: Stream, buffer: []u8) ReadError!usize { - if (builtin.os.tag == .windows) { - return os.windows.ReadFile(self.handle, buffer, null); - } - - return os.read(self.handle, buffer); - } - pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize { if (builtin.os.tag == .windows) { // TODO improve this to use ReadFileScatter @@ -1840,51 +1818,31 @@ pub const Stream = struct { return os.readv(s.handle, iovecs); } - /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's - /// file system thread instead of non-blocking. It needs to be reworked to properly - /// use non-blocking I/O. - pub fn write(self: Stream, buffer: []const u8) WriteError!usize { - if (builtin.os.tag == .windows) { - return os.windows.WriteFile(self.handle, buffer, null); - } - - return os.write(self.handle, buffer); - } - /// See https://github.com/ziglang/zig/issues/7699 /// See equivalent function: `std.fs.File.writev`. pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize { return os.writev(self.handle, iovecs); } - - /// The `iovecs` parameter is mutable because this function needs to mutate the fields in - /// order to handle partial writes from the underlying OS layer. - /// See https://github.com/ziglang/zig/issues/7699 - /// See equivalent function: `std.fs.File.writevAll`. - pub fn writevAll(self: Stream, iovecs: []os.iovec_const) WriteError!void { - if (iovecs.len == 0) return; - - var i: usize = 0; - while (true) { - var amt = try self.writev(iovecs[i..]); - while (amt >= iovecs[i].iov_len) { - amt -= iovecs[i].iov_len; - i += 1; - if (i >= iovecs.len) return; - } - iovecs[i].iov_base += amt; - iovecs[i].iov_len -= amt; - } - } }; pub const Server = struct { listen_address: Address, - stream: net.Stream, + stream: Stream, pub const Connection = struct { - stream: net.Stream, address: Address, + protocol: Protocol, + socket: Stream, + tls: tls.Server, + + pub const Protocol = enum { plain, tls }; + + pub inline fn stream(conn: *Connection) std.io.AnyStream { + return switch (conn.protocol) { + .plain => conn.socket.any().any(), + .tls => conn.tls.any().any(), + }; + } }; pub fn deinit(s: *Server) void { @@ -1892,17 +1850,24 @@ pub const Server = struct { s.* = undefined; } - pub const AcceptError = posix.AcceptError; - - /// Blocks until a client connects to the server. The returned `Connection` has - /// an open stream. - pub fn accept(s: *Server) AcceptError!Connection { + /// Blocks until a client connects to the server. + /// If tls_options are supplied, will await a client handshake. + /// The returned `Connection` has an open stream. + pub fn accept(s: *Server, tls_options: ?tls.Server.Options) !Connection { var accepted_addr: Address = undefined; var addr_len: posix.socklen_t = @sizeOf(Address); const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC); + const socket = Stream{ .handle = fd }; + const protocol: Connection.Protocol = if (tls_options == null) .plain else .tls; + const _tls: tls.Server = if (tls_options) |options| + try tls.Server.init(socket.any().any(), options) + else + undefined; return .{ - .stream = .{ .handle = fd }, .address = accepted_addr, + .protocol = protocol, + .socket = socket, + .tls = _tls, }; } }; diff --git a/lib/std/net/test.zig b/lib/std/net/test.zig index 3e316c545643..ccbde4e93627 100644 --- a/lib/std/net/test.zig +++ b/lib/std/net/test.zig @@ -189,17 +189,17 @@ test "listen on a port, send bytes, receive bytes" { const socket = try net.tcpConnectToAddress(server_address); defer socket.close(); - _ = try socket.writer().writeAll("Hello world!"); + _ = try socket.any().writer().writeAll("Hello world!"); } }; const t = try std.Thread.spawn(.{}, S.clientFn, .{server.listen_address}); defer t.join(); - var client = try server.accept(); - defer client.stream.close(); + var client = try server.accept(null); + defer client.stream().close(); var buf: [16]u8 = undefined; - const n = try client.stream.reader().read(&buf); + const n = try client.stream().reader().read(&buf); try testing.expectEqual(@as(usize, 12), n); try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); @@ -247,7 +247,7 @@ fn testClient(addr: net.Address) anyerror!void { fn testServer(server: *net.Server) anyerror!void { if (builtin.os.tag == .wasi) return error.SkipZigTest; - var client = try server.accept(); + var client = try server.accept(null); const stream = client.stream.writer(); try stream.print("hello from server\n", .{}); @@ -280,17 +280,17 @@ test "listen on a unix socket, send bytes, receive bytes" { const socket = try net.connectUnixSocket(path); defer socket.close(); - _ = try socket.writer().writeAll("Hello world!"); + _ = try socket.any().writer().writeAll("Hello world!"); } }; const t = try std.Thread.spawn(.{}, S.clientFn, .{socket_path}); defer t.join(); - var client = try server.accept(); - defer client.stream.close(); + var client = try server.accept(null); + defer client.stream().close(); var buf: [16]u8 = undefined; - const n = try client.stream.reader().read(&buf); + const n = try client.stream().reader().read(&buf); try testing.expectEqual(@as(usize, 12), n); try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]); @@ -317,13 +317,13 @@ test "non-blocking tcp server" { var server = localhost.listen(.{ .force_nonblocking = true }); defer server.deinit(); - const accept_err = server.accept(); + const accept_err = server.accept(null); try testing.expectError(error.WouldBlock, accept_err); const socket_file = try net.tcpConnectToAddress(server.listen_address); defer socket_file.close(); - var client = try server.accept(); + var client = try server.accept(null); defer client.stream.close(); const stream = client.stream.writer(); try stream.print("hello from server\n", .{}); diff --git a/lib/std/os.zig b/lib/std/os.zig index 417e5711455f..ebef44848af9 100644 --- a/lib/std/os.zig +++ b/lib/std/os.zig @@ -1195,7 +1195,7 @@ pub fn preadv(fd: fd_t, iov: []const iovec, offset: u64) PReadError!usize { } } -pub const WriteError = error{ +pub const WriteError = anyerror || error{ DiskQuota, FileTooBig, InputOutput, diff --git a/lib/std/std.zig b/lib/std/std.zig index 557b320c244e..27920085177a 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -151,13 +151,6 @@ pub const Options = struct { /// it like any other error. keep_sigpipe: bool = false, - /// By default, std.http.Client will support HTTPS connections. Set this option to `true` to - /// disable TLS support. - /// - /// This will likely reduce the size of the binary, but it will also make it impossible to - /// make a HTTPS connection. - http_disable_tls: bool = false, - side_channels_mitigations: crypto.SideChannelsMitigations = crypto.default_side_channels_mitigations, }; diff --git a/lib/std/tar.zig b/lib/std/tar.zig index 13da27ca846d..2782fc1303aa 100644 --- a/lib/std/tar.zig +++ b/lib/std/tar.zig @@ -269,7 +269,7 @@ pub const FileKind = enum { file, }; -/// Iteartor over entries in the tar file represented by reader. +/// Iterator over tar entries pub fn Iterator(comptime ReaderType: type) type { return struct { reader: ReaderType, @@ -295,17 +295,22 @@ pub fn Iterator(comptime ReaderType: type) type { unread_bytes: *u64, parent_reader: ReaderType, - pub const Reader = std.io.Reader(File, ReaderType.Error, File.read); + pub const Reader = std.io.Reader(File, ReaderType.Error, File.readv); pub fn reader(self: File) Reader { return .{ .context = self }; } - pub fn read(self: File, dest: []u8) ReaderType.Error!usize { - const buf = dest[0..@min(dest.len, self.unread_bytes.*)]; - const n = try self.parent_reader.read(buf); - self.unread_bytes.* -= n; - return n; + pub fn readv(self: File, iov: []std.os.iovec) ReaderType.Error!usize { + var n_read: usize = 0; + for (iov) |v| { + const dest = v.iov_base[0..v.iov_len]; + const buf = dest[0..@min(dest.len, self.unread_bytes.*)]; + const n = try self.parent_reader.read(buf); + self.unread_bytes.* -= n; + n_read += n; + } + return n_read; } // Writes file content to writer. diff --git a/lib/std/zig/render.zig b/lib/std/zig/render.zig index a87f28964242..8b8c028bdbeb 100644 --- a/lib/std/zig/render.zig +++ b/lib/std/zig/render.zig @@ -3335,7 +3335,7 @@ fn AutoIndentingStream(comptime UnderlyingWriter: type) type { return struct { const Self = @This(); pub const WriteError = UnderlyingWriter.Error; - pub const Writer = std.io.Writer(*Self, WriteError, write); + pub const Writer = std.io.Writer(*Self, WriteError, writev); underlying_writer: UnderlyingWriter, @@ -3361,12 +3361,15 @@ fn AutoIndentingStream(comptime UnderlyingWriter: type) type { return .{ .context = self }; } - pub fn write(self: *Self, bytes: []const u8) WriteError!usize { - if (bytes.len == 0) - return @as(usize, 0); + pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { + var n_written: usize = 0; + for (iov) |v| { + if (v.iov_len == 0) return n_written; - try self.applyIndent(); - return self.writeNoIndent(bytes); + try self.applyIndent(); + n_written += try self.writeNoIndent(v.iov_base[0..v.iov_len]); + } + return n_written; } // Change the indent delta without changing the final indentation level From de67fd94443fa4842c740b9e290ab5019e4eda91 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Mon, 18 Mar 2024 18:42:33 -0400 Subject: [PATCH 11/17] zig fmt --- lib/compiler/fmt.zig | 5 ++-- lib/std/crypto/tls.zig | 10 +++---- lib/std/crypto/tls/Server.zig | 50 ++++++++++++++++------------------- lib/std/crypto/tls/Stream.zig | 8 ++---- lib/std/http/Client.zig | 3 +-- lib/std/http/Server.zig | 1 - lib/std/io/Reader.zig | 3 +-- lib/std/io/Stream.zig | 1 - lib/std/io/Writer.zig | 2 +- lib/std/net.zig | 2 +- lib/std/os.zig | 2 +- 11 files changed, 36 insertions(+), 51 deletions(-) diff --git a/lib/compiler/fmt.zig b/lib/compiler/fmt.zig index 2fc04b7935a7..7f58a133a3da 100644 --- a/lib/compiler/fmt.zig +++ b/lib/compiler/fmt.zig @@ -322,8 +322,9 @@ fn fmtPathFile( return; if (check_mode) { - const stdout = std.io.getStdOut().writer(); - try stdout.print("{s}\n", .{file_path}); + const stdout = std.io.getStdOut(); + try stdout.writeAll(file_path); + try stdout.writeAll("\n"); fmt.any_error = true; } else { var af = try dir.atomicFile(sub_path, .{ .mode = stat.mode }); diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 9d450132fa17..5570dbf38f02 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -1458,10 +1458,8 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); - const client_iv = [_]u8{ -0x77, 0x02, 0x2F, 0x09, 0xB2, 0x93, 0x5A, 0x5E, 0x3F, 0x2B, 0xB0, 0x32 - }; - try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); + const client_iv = [_]u8{ 0x77, 0x02, 0x2F, 0x09, 0xB2, 0x93, 0x5A, 0x5E, 0x3F, 0x2B, 0xB0, 0x32 }; + try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } server_command = try server.next(server_command); // send_change_cipher_spec @@ -1501,9 +1499,7 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); - const client_iv = [_]u8{ -0x54, 0xF3, 0x34, 0x20, 0xA8, 0x50, 0xF5, 0x3A, 0x22, 0x9A, 0xBB, 0x1B - }; + const client_iv = [_]u8{ 0x54, 0xF3, 0x34, 0x20, 0xA8, 0x50, 0xF5, 0x3A, 0x22, 0x9A, 0xBB, 0x1B }; try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index 6994e14898f3..8eba2bd730ae 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -179,9 +179,9 @@ pub fn recv_hello(self: *Self) !ClientHello { .rsa_pss_rsae_sha384, .rsa_pss_rsae_sha256, }, - .ecdsa256 => &[_]tls.SignatureScheme{ .ecdsa_secp256r1_sha256 }, - .ecdsa384 => &[_]tls.SignatureScheme{ .ecdsa_secp384r1_sha384 }, - .ed25519 => &[_]tls.SignatureScheme{ .ed25519 }, + .ecdsa256 => &[_]tls.SignatureScheme{.ecdsa_secp256r1_sha256}, + .ecdsa384 => &[_]tls.SignatureScheme{.ecdsa_secp384r1_sha384}, + .ed25519 => &[_]tls.SignatureScheme{.ed25519}, }; var algos_iter = try stream.iterator(u16, tls.SignatureScheme); while (try algos_iter.next()) |algo| { @@ -205,12 +205,11 @@ pub fn recv_hello(self: *Self) !ClientHello { crypto.random.bytes(&server_random); const key_pair: tls.KeyPair = switch (key_share.?) { - inline - .secp256r1, - .secp384r1, - .x25519, - .x25519_kyber768d00, - => |_, tag| brk: { + inline .secp256r1, + .secp384r1, + .x25519, + .x25519_kyber768d00, + => |_, tag| brk: { const pair = tls.NamedGroupT(tag).KeyPair.create(null) catch unreachable; break :brk @unionInit(tls.KeyPair, @tagName(tag), pair); }, @@ -307,17 +306,17 @@ pub fn send_certificate_verify(self: *Self, verify: Command.CertificateVerify) ! const signature: []const u8 = switch (verify.scheme) { inline .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384 => |comptime_scheme| brk: { - const Ecdsa = comptime_scheme.Ecdsa(); - const key = switch (comptime_scheme) { - .ecdsa_secp256r1_sha256 => self.options.certificate_key.ecdsa256, - .ecdsa_secp384r1_sha384 => self.options.certificate_key.ecdsa384, - else => unreachable, - }; - - var signer = Ecdsa.Signer.init(key, verify.salt[0..Ecdsa.noise_length].*); - signer.update(sig_content); - const sig = signer.finalize() catch return stream.writeError(.internal_error); - break :brk &sig.toBytes(); + const Ecdsa = comptime_scheme.Ecdsa(); + const key = switch (comptime_scheme) { + .ecdsa_secp256r1_sha256 => self.options.certificate_key.ecdsa256, + .ecdsa_secp384r1_sha384 => self.options.certificate_key.ecdsa384, + else => unreachable, + }; + + var signer = Ecdsa.Signer.init(key, verify.salt[0..Ecdsa.noise_length].*); + signer.update(sig_content); + const sig = signer.finalize() catch return stream.writeError(.internal_error); + break :brk &sig.toBytes(); }, inline .rsa_pss_rsae_sha256, .rsa_pss_rsae_sha384, @@ -328,12 +327,12 @@ pub fn send_certificate_verify(self: *Self, verify: Command.CertificateVerify) ! switch (key.public.n.bits() / 8) { inline 128, 256, 512 => |modulus_length| { - const sig = Certificate.rsa.PSSSignature.sign( + const sig = Certificate.rsa.PSSSignature.sign( modulus_length, sig_content, Hash, key, - verify.salt[0..Hash.digest_length].*, + verify.salt[0..Hash.digest_length].*, ) catch return stream.writeError(.bad_certificate); break :brk &sig; }, @@ -391,10 +390,7 @@ pub fn recv_finished(self: *Self) !void { const handshake_hash = stream.transcript_hash.?.peek(); - const application_cipher = tls.ApplicationCipher.init( - stream.cipher.handshake, - handshake_hash, - ); + const application_cipher = tls.ApplicationCipher.init(stream.cipher.handshake, handshake_hash); const expected = switch (stream.cipher.handshake) { inline else => |p| brk: { @@ -461,7 +457,7 @@ pub const Options = struct { key_shares: []const tls.NamedGroup = &tls.supported_groups, /// Certificate(s) to send in `send_certificate` messages. certificate: tls.Certificate, - /// Key to use in `send_certificate_verify`. Must match `certificate.parse().pub_key_algo`. + /// Key to use in `send_certificate_verify`. Must match `certificate.parse().pub_key_algo`. certificate_key: CertificateKey, pub const CertificateKey = union(enum) { diff --git a/lib/std/crypto/tls/Stream.zig b/lib/std/crypto/tls/Stream.zig index 8648ecde98d9..f006d5b3a40b 100644 --- a/lib/std/crypto/tls/Stream.zig +++ b/lib/std/crypto/tls/Stream.zig @@ -328,13 +328,10 @@ pub fn readPlaintext(self: *Self) !Plaintext { self.closed = true; return res; }, - .certificate_revoked, - .certificate_unknown, - .certificate_expired, - .certificate_required => {}, + .certificate_revoked, .certificate_unknown, .certificate_expired, .certificate_required => {}, else => { return self.writeError(.unexpected_message); - } + }, } }, // > An implementation may receive an unencrypted record of type @@ -544,4 +541,3 @@ const InnerPlaintext = struct { handshake_type: HandshakeType, len: u24, }; - diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index dc911a65ddc0..a239ddefbfd7 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -282,7 +282,7 @@ pub const Connection = struct { .{ .iov_base = buffer.ptr, .iov_len = buffer.len }, .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len }, }; - const nread = try conn.readvDirect(&iovecs); + const nread = try conn.readvDirect(&iovecs); if (nread > buffer.len) { conn.read_start = 0; @@ -1140,7 +1140,6 @@ pub const Request = struct { .chunked => { var w = req.connection.?.writer(); - if (iov_len > 0) { try w.print("{x}\r\n", .{iov_len}); try w.writevAll(iov); diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 6aa2fef4ff67..cbd7596f74ca 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1137,4 +1137,3 @@ fn rebase(s: *Server, index: usize) void { } s.read_buffer_len = index + leftover.len; } - diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index e8b76533d887..f06a7be4a256 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -12,7 +12,6 @@ readvFn: *const fn (context: *const anyopaque, iov: []iovec) anyerror!usize, pub const Error = anyerror; - /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. @@ -24,7 +23,7 @@ pub fn readv(self: Self, iov: []iovec) anyerror!usize { /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. pub fn read(self: Self, buffer: []u8) anyerror!usize { - var iov = [_]iovec{ .{ .iov_base = buffer.ptr, .iov_len = buffer.len } }; + var iov = [_]iovec{.{ .iov_base = buffer.ptr, .iov_len = buffer.len }}; return self.readv(&iov); } diff --git a/lib/std/io/Stream.zig b/lib/std/io/Stream.zig index f7941c70405f..a8658fd7d7dd 100644 --- a/lib/std/io/Stream.zig +++ b/lib/std/io/Stream.zig @@ -35,4 +35,3 @@ pub fn writer(self: Self) std.io.AnyWriter { pub fn close(self: Self) void { return self.closeFn(self.context); } - diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index c5e92bf8da1e..383120c4a16b 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -34,7 +34,7 @@ pub fn writevAll(self: Self, iovecs: []iovec_const) anyerror!void { } pub fn write(self: Self, bytes: []const u8) anyerror!usize { - var iov = [_]iovec_const{ .{ .iov_base = bytes.ptr, .iov_len = bytes.len } }; + var iov = [_]iovec_const{.{ .iov_base = bytes.ptr, .iov_len = bytes.len }}; return self.writev(&iov); } diff --git a/lib/std/net.zig b/lib/std/net.zig index 31c12b0a07ba..c209c5ac4096 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1840,7 +1840,7 @@ pub const Server = struct { pub inline fn stream(conn: *Connection) std.io.AnyStream { return switch (conn.protocol) { .plain => conn.socket.any().any(), - .tls => conn.tls.any().any(), + .tls => conn.tls.any().any(), }; } }; diff --git a/lib/std/os.zig b/lib/std/os.zig index ebef44848af9..417e5711455f 100644 --- a/lib/std/os.zig +++ b/lib/std/os.zig @@ -1195,7 +1195,7 @@ pub fn preadv(fd: fd_t, iov: []const iovec, offset: u64) PReadError!usize { } } -pub const WriteError = anyerror || error{ +pub const WriteError = error{ DiskQuota, FileTooBig, InputOutput, From d39cdcf751f25f75a07e7105cd0aea1b34a78d0f Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Mon, 18 Mar 2024 19:01:26 -0400 Subject: [PATCH 12/17] fix server reader/writer anyerror --- lib/std/http/Server.zig | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index cbd7596f74ca..a540274ad7ba 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -402,7 +402,7 @@ pub const Request = struct { request: *Request, content: []const u8, options: RespondOptions, - ) Response.WriteError!void { + ) !void { const max_extra_headers = 25; assert(options.status != .@"continue"); assert(options.extra_headers.len <= max_extra_headers); @@ -767,7 +767,7 @@ pub const Request = struct { /// request's expect field to `null`. /// /// Asserts that this function is only called once. - pub fn reader(request: *Request) ReaderError!std.io.AnyReader { + pub fn reader(request: *Request) !std.io.AnyReader { const s = request.server; assert(s.state == .received_head); s.state = .receiving_body; From 8c5422b2e36f3760eded16e7b46113be7fc1025d Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Tue, 19 Mar 2024 15:47:36 -0400 Subject: [PATCH 13/17] Fixup new writev and readv --- lib/std/array_list.zig | 6 +- lib/std/bounded_array.zig | 2 +- lib/std/compress.zig | 6 +- lib/std/compress/flate/deflate.zig | 2 +- lib/std/compress/flate/inflate.zig | 2 +- lib/std/compress/lzma.zig | 2 +- lib/std/compress/xz.zig | 2 +- lib/std/compress/zstandard.zig | 2 +- lib/std/compress/zstandard/readers.zig | 2 +- lib/std/crypto/sha2.zig | 2 +- lib/std/crypto/siphash.zig | 12 ++- lib/std/crypto/tls.zig | 118 ++++--------------------- lib/std/crypto/tls/Client.zig | 39 ++------ lib/std/crypto/tls/Server.zig | 71 +++++++-------- lib/std/crypto/tls/Stream.zig | 4 +- lib/std/fifo.zig | 4 +- lib/std/http/Client.zig | 20 ++--- lib/std/http/Server.zig | 36 ++++---- lib/std/http/test.zig | 13 +-- lib/std/io.zig | 22 +++-- lib/std/io/Reader.zig | 4 +- lib/std/io/Stream.zig | 8 +- lib/std/io/Writer.zig | 4 +- lib/std/io/buffered_reader.zig | 4 +- lib/std/io/buffered_tee.zig | 4 +- lib/std/io/buffered_writer.zig | 2 +- lib/std/io/counting_reader.zig | 2 +- lib/std/io/counting_writer.zig | 2 +- lib/std/io/fixed_buffer_stream.zig | 4 +- lib/std/io/limited_reader.zig | 17 ++-- lib/std/io/multi_writer.zig | 10 +-- lib/std/io/stream_source.zig | 4 +- lib/std/json/stringify_test.zig | 2 +- lib/std/net.zig | 44 ++++----- lib/std/tar.zig | 4 +- lib/std/zig/render.zig | 2 +- src/Package/Fetch.zig | 16 ++-- src/Package/Fetch/git.zig | 48 +++------- src/codegen/c.zig | 39 ++++---- src/main.zig | 8 +- 40 files changed, 237 insertions(+), 358 deletions(-) diff --git a/lib/std/array_list.zig b/lib/std/array_list.zig index ce2df2098f81..05779c3a8f5a 100644 --- a/lib/std/array_list.zig +++ b/lib/std/array_list.zig @@ -354,7 +354,7 @@ pub fn ArrayListAligned(comptime T: type, comptime alignment: ?u29) type { /// Same as `append` except it returns the number of bytes written, which is always the same /// as `m.len`. The purpose of this function existing is to match `std.io.Writer` API. /// Invalidates element pointers if additional memory is needed. - fn appendWritev(self: *Self, iov: []std.os.iovec_const) Allocator.Error!usize { + fn appendWritev(self: *Self, iov: []const std.os.iovec_const) Allocator.Error!usize { var written: usize = 0; for (iov) |v| { try self.appendSlice(v.iov_base[0..v.iov_len]); @@ -945,7 +945,7 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ /// which is always the same as `m.len`. The purpose of this function /// existing is to match `std.io.Writer` API. /// Invalidates element pointers if additional memory is needed. - fn appendWritev(context: WriterContext, iov: []std.os.iovec_const) Allocator.Error!usize { + fn appendWritev(context: WriterContext, iov: []const std.os.iovec_const) Allocator.Error!usize { var written: usize = 0; for (iov) |v| { try context.self.appendSlice(context.allocator, v.iov_base[0..v.iov_len]); @@ -963,7 +963,7 @@ pub fn ArrayListAlignedUnmanaged(comptime T: type, comptime alignment: ?u29) typ } /// The purpose of this function existing is to match `std.io.Writer` API. - fn appendWritevFixed(self: *Self, iov: []std.os.iovec_const) error{OutOfMemory}!usize { + fn appendWritevFixed(self: *Self, iov: []const std.os.iovec_const) error{OutOfMemory}!usize { var written: usize = 0; for (iov) |v| { const m = v.iov_base[0..v.iov_len]; diff --git a/lib/std/bounded_array.zig b/lib/std/bounded_array.zig index b5d04cac9d33..ac8cea8576dd 100644 --- a/lib/std/bounded_array.zig +++ b/lib/std/bounded_array.zig @@ -280,7 +280,7 @@ pub fn BoundedArrayAligned( /// Same as `appendSlice` except it returns the number of bytes written, which is always the same /// as `m.len`. The purpose of this function existing is to match `std.io.Writer` API. - fn appendWritev(self: *Self, iov: []std.os.iovec_const) error{Overflow}!usize { + fn appendWritev(self: *Self, iov: []const std.os.iovec_const) error{Overflow}!usize { var written: usize = 0; for (iov) |v| { const m = v.iov_base[0..v.iov_len]; diff --git a/lib/std/compress.zig b/lib/std/compress.zig index 0ac8211f6907..242eb8aff9b9 100644 --- a/lib/std/compress.zig +++ b/lib/std/compress.zig @@ -21,7 +21,7 @@ pub fn HashedReader( pub const Error = ReaderType.Error; pub const Reader = std.io.Reader(*@This(), Error, readv); - pub fn readv(self: *@This(), iov: []std.os.iovec) Error!usize { + pub fn readv(self: *@This(), iov: []const std.os.iovec) Error!usize { const n_read = try self.child_reader.readv(iov); var hashed_amt: usize = 0; for (iov) |v| { @@ -55,9 +55,9 @@ pub fn HashedWriter( hasher: HasherType, pub const Error = WriterType.Error; - pub const Writer = std.io.Writer(*@This(), Error, write); + pub const Writer = std.io.Writer(*@This(), Error, writev); - pub fn write(self: *@This(), iov: []std.os.iovec_const) Error!usize { + pub fn writev(self: *@This(), iov: []const std.os.iovec_const) Error!usize { const n_written = try self.child_writer.writev(iov); var hashed_amt: usize = 0; for (iov) |v| { diff --git a/lib/std/compress/flate/deflate.zig b/lib/std/compress/flate/deflate.zig index 6f1c744b3fd8..094dc847fc4c 100644 --- a/lib/std/compress/flate/deflate.zig +++ b/lib/std/compress/flate/deflate.zig @@ -359,7 +359,7 @@ fn Deflate(comptime container: Container, comptime WriterType: type, comptime Bl /// Write `input` of uncompressed data. /// See compress. - pub fn writev(self: *Self, iov: []std.os.iovec_const) !usize { + pub fn writev(self: *Self, iov: []const std.os.iovec_const) !usize { var written: usize = 0; for (iov) |v| { const input = v.iov_base[0..v.iov_len]; diff --git a/lib/std/compress/flate/inflate.zig b/lib/std/compress/flate/inflate.zig index 93f3a30cea0d..15752a4c9599 100644 --- a/lib/std/compress/flate/inflate.zig +++ b/lib/std/compress/flate/inflate.zig @@ -345,7 +345,7 @@ pub fn Inflate(comptime container: Container, comptime LookaheadType: type, comp /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. - pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { var read: usize = 0; for (iov) |v| { const out = try self.get(v.iov_len); diff --git a/lib/std/compress/lzma.zig b/lib/std/compress/lzma.zig index 2bd31b636ebe..34eddb99f4ff 100644 --- a/lib/std/compress/lzma.zig +++ b/lib/std/compress/lzma.zig @@ -63,7 +63,7 @@ pub fn Decompress(comptime ReaderType: type) type { self.* = undefined; } - pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const writer = self.to_read.writer(self.allocator); var n_read: usize = 0; for (iov) |v| { diff --git a/lib/std/compress/xz.zig b/lib/std/compress/xz.zig index 6458a7e7ca2a..a0b1c7c33993 100644 --- a/lib/std/compress/xz.zig +++ b/lib/std/compress/xz.zig @@ -71,7 +71,7 @@ pub fn Decompress(comptime ReaderType: type) type { return .{ .context = self }; } - pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const first = iov[0]; const buffer = first.iov_base[0..first.iov_len]; if (buffer.len == 0) diff --git a/lib/std/compress/zstandard.zig b/lib/std/compress/zstandard.zig index b1b283e5afd4..0733250bc22d 100644 --- a/lib/std/compress/zstandard.zig +++ b/lib/std/compress/zstandard.zig @@ -105,7 +105,7 @@ pub fn Decompressor(comptime ReaderType: type) type { return .{ .context = self }; } - pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const first = iov[0]; const buffer = first.iov_base[0..first.iov_len]; if (buffer.len == 0) return 0; diff --git a/lib/std/compress/zstandard/readers.zig b/lib/std/compress/zstandard/readers.zig index 27590cef9300..2235b46bf2de 100644 --- a/lib/std/compress/zstandard/readers.zig +++ b/lib/std/compress/zstandard/readers.zig @@ -17,7 +17,7 @@ pub const ReversedByteReader = struct { return .{ .context = self }; } - fn readvFn(ctx: *ReversedByteReader, iov: []std.os.iovec) !usize { + fn readvFn(ctx: *ReversedByteReader, iov: []const std.os.iovec) !usize { const first = iov[0]; const buffer = first.iov_base[0..first.iov_len]; std.debug.assert(buffer.len > 0); diff --git a/lib/std/crypto/sha2.zig b/lib/std/crypto/sha2.zig index 0287de90ed6c..c5f151771c4f 100644 --- a/lib/std/crypto/sha2.zig +++ b/lib/std/crypto/sha2.zig @@ -394,7 +394,7 @@ fn Sha2x32(comptime params: Sha2Params32) type { pub const Error = error{}; pub const Writer = std.io.Writer(*Self, Error, writev); - fn writev(self: *Self, iov: []std.os.iovec_const) Error!usize { + fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { var written: usize = 0; for (iov) |v| { self.update(v.iov_base[0..v.iov_len]); diff --git a/lib/std/crypto/siphash.zig b/lib/std/crypto/siphash.zig index 5d1ac4f87469..1272a4f98b63 100644 --- a/lib/std/crypto/siphash.zig +++ b/lib/std/crypto/siphash.zig @@ -240,11 +240,15 @@ fn SipHash(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize) } pub const Error = error{}; - pub const Writer = std.io.Writer(*Self, Error, write); + pub const Writer = std.io.Writer(*Self, Error, writev); - fn write(self: *Self, bytes: []const u8) Error!usize { - self.update(bytes); - return bytes.len; + fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + var written: usize = 0; + for (iov) |v| { + self.update(v.iov_base[0..v.iov_len]); + written += v.iov_len; + } + return written; } pub fn writer(self: *Self) Writer { diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 5570dbf38f02..1d4291ba6d6a 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -535,21 +535,16 @@ pub const supported_signature_schemes = [_]SignatureScheme{ }; /// Key exchange formats +/// +/// https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml pub const NamedGroup = enum(u16) { + // Use reserved value for invalid. invalid = 0x0000, // Elliptic Curve Groups (ECDHE) secp256r1 = 0x0017, secp384r1 = 0x0018, secp521r1 = 0x0019, x25519 = 0x001D, - x448 = 0x001E, - - // Finite Field Groups (DHE) - ffdhe2048 = 0x0100, - ffdhe3072 = 0x0101, - ffdhe4096 = 0x0102, - ffdhe6144 = 0x0103, - ffdhe8192 = 0x0104, // Hybrid post-quantum key agreements. Still in draft. x25519_kyber768d00 = 0x6399, @@ -561,69 +556,19 @@ pub fn NamedGroupT(comptime named_group: NamedGroup) type { .secp256r1 => crypto.sign.ecdsa.EcdsaP256Sha256, .secp384r1 => crypto.sign.ecdsa.EcdsaP384Sha384, .x25519 => crypto.dh.X25519, - .x25519_kyber768d00 => X25519Kyber768Draft, else => |t| @compileError("unsupported named group " ++ @tagName(t)), }; } -// Hybrid share, see https://www.ietf.org/archive/id/draft-ietf-tls-hybrid-design-05.html -pub const X25519Kyber768Draft = struct { - pub const X25519 = NamedGroupT(.x25519); - pub const Kyber768 = crypto.kem.kyber_d00.Kyber768; - pub const KeyPair = struct { - x25519: X25519.KeyPair, - kyber768d00: Kyber768.KeyPair, - - pub const seed_length = X25519.KeyPair.seed_length + Kyber768.KeyPair.seed_length; - - pub fn create(seed: ?[seed_length]u8) !@This() { - var seed_: [seed_length]u8 = seed orelse undefined; - if (seed == null) { - crypto.random.bytes(&seed_); - } - return .{ - .x25519 = try X25519.KeyPair.create(seed_[0..X25519.KeyPair.seed_length].*), - .kyber768d00 = try Kyber768.KeyPair.create(seed_[X25519.KeyPair.seed_length..].*), - }; - } - }; - pub const PublicKey = struct { - x25519: X25519.PublicKey, - kyber768d00: Kyber768.PublicKey, - - pub const bytes_length = X25519.public_length + Kyber768.PublicKey.bytes_length; - pub const ciphertext_length = X25519.public_length + Kyber768.ciphertext_length; - - pub fn toBytes(self: @This()) [bytes_length]u8 { - return self.x25519 ++ self.kyber768d00.toBytes(); - } - - pub fn ciphertext(self: @This()) [ciphertext_length]u8 { - return self.x25519 ++ self.kyber768d00.encaps(null).ciphertext; - } - }; -}; pub const KeyPair = union(NamedGroup) { invalid: void, secp256r1: NamedGroupT(.secp256r1).KeyPair, secp384r1: NamedGroupT(.secp384r1).KeyPair, secp521r1: void, x25519: NamedGroupT(.x25519).KeyPair, - x448: void, - - ffdhe2048: void, - ffdhe3072: void, - ffdhe4096: void, - ffdhe6144: void, - ffdhe8192: void, - - x25519_kyber768d00: NamedGroupT(.x25519_kyber768d00).KeyPair, + x25519_kyber768d00: void, pub fn toKeyShare(self: @This()) KeyShare { return switch (self) { - .x25519_kyber768d00 => |k| .{ .x25519_kyber768d00 = X25519Kyber768Draft.PublicKey{ - .x25519 = k.x25519.public_key, - .kyber768d00 = k.kyber768d00.public_key, - } }, .secp256r1 => |k| .{ .secp256r1 = k.public_key }, .secp384r1 => |k| .{ .secp384r1 = k.public_key }, .x25519 => |k| .{ .x25519 = k.public_key }, @@ -638,15 +583,7 @@ pub const KeyShare = union(NamedGroup) { secp384r1: NamedGroupT(.secp384r1).PublicKey, secp521r1: void, x25519: NamedGroupT(.x25519).PublicKey, - x448: void, - - ffdhe2048: void, - ffdhe3072: void, - ffdhe4096: void, - ffdhe6144: void, - ffdhe8192: void, - - x25519_kyber768d00: NamedGroupT(.x25519_kyber768d00).PublicKey, + x25519_kyber768d00: void, const Self = @This(); @@ -657,18 +594,6 @@ pub const KeyShare = union(NamedGroup) { const group = try stream.read(NamedGroup); const len = try stream.read(u16); switch (group) { - .x25519_kyber768d00 => { - const T = X25519Kyber768Draft.Kyber768.PublicKey; - var res = Self{ .x25519_kyber768d00 = undefined }; - - try reader.readNoEof(&res.x25519_kyber768d00.x25519); - - var buf: [T.bytes_length]u8 = undefined; - try reader.readNoEof(&buf); - res.x25519_kyber768d00.kyber768d00 = T.fromBytes(&buf) catch return Error.TlsDecryptError; - - return res; - }, inline .secp256r1, .secp384r1 => |k| { const T = NamedGroupT(k).PublicKey; var buf: [T.uncompressed_sec1_encoded_length]u8 = undefined; @@ -692,7 +617,6 @@ pub const KeyShare = union(NamedGroup) { var res: usize = 0; res += try stream.write(NamedGroup, self); const public = switch (self) { - .x25519_kyber768d00 => |k| if (stream.is_client) &k.toBytes() else &k.ciphertext(), .secp256r1 => |k| &k.toUncompressedSec1(), .secp384r1 => |k| &k.toUncompressedSec1(), .x25519 => |k| &k, @@ -704,7 +628,6 @@ pub const KeyShare = union(NamedGroup) { }; /// In descending order of preference pub const supported_groups = [_]NamedGroup{ - .x25519_kyber768d00, .secp256r1, .secp384r1, .x25519, @@ -1337,13 +1260,13 @@ const TestStream = struct { self.buffer.deinit(allocator); } - pub fn readv(self: *Self, iov: []std.os.iovec) ReadError!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) ReadError!usize { const first = iov[0]; try self.buffer.readFirst(first.iov_base[0..first.iov_len], first.iov_len); return first.iov_len; } - pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { + pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { var written: usize = 0; for (iov) |v| { try self.buffer.writeSlice(v.iov_base[0..v.iov_len]); @@ -1406,10 +1329,11 @@ test "tls client and server handshake, data, and close_notify" { }, .options = .{ .cipher_suites = &[_]CipherSuite{.aes_256_gcm_sha384}, + .key_shares = &[_]NamedGroup{.x25519}, .certificate = .{ .entries = &[_]Certificate.Entry{ .{ .data = server_cert }, } }, - .certificate_key = server_rsa, + .certificate_key = .{ .rsa = server_rsa }, }, }; @@ -1417,29 +1341,27 @@ test "tls client and server handshake, data, and close_notify" { const session_id: [32]u8 = ("session_id012345" ** 2).*; const client_random: [32]u8 = ("client_random012" ** 2).*; const server_random: [32]u8 = ("server_random012" ** 2).*; - const client_x25519_seed: [32]u8 = ("client_seed01234" ** 2).*; - const server_x25519_seed: [32]u8 = ("server_seed01234" ** 2).*; + const client_key_seed: [32]u8 = ("client_seed01234" ** 2).*; + const server_keygen_seed: [48]u8 = ("server_seed01234" ** 3).*; const server_sig_salt: [MultiHash.max_digest_len]u8 = ("server_sig_salt0" ** 4).*; const key_pairs = try Client.KeyPairs.initAdvanced( client_random, session_id, - client_x25519_seed ++ client_x25519_seed, - client_x25519_seed, - client_x25519_seed ++ [_]u8{0} ** (48 - 32), - client_x25519_seed, + client_key_seed, + client_key_seed ++ [_]u8{0} ** (48 - 32), + client_key_seed, ); var client_command = Client.Command{ .send_hello = key_pairs }; client_command = try client.next(client_command); try std.testing.expect(client_command == .recv_hello); - var server_command = Server.Command{ .recv_hello = {} }; + var server_command = Server.Command{ .recv_hello = .{ + .server_random = server_random, + .keygen_seed = server_keygen_seed, + } }; server_command = try server.next(server_command); // recv_hello try std.testing.expect(server_command == .send_hello); - server_command.send_hello.server_random = server_random; - server_command.send_hello.server_pair = .{ - .x25519 = crypto.dh.X25519.KeyPair.create(server_x25519_seed) catch unreachable, - }; server_command = try server.next(server_command); // send_hello try std.testing.expect(server_command == .send_change_cipher_spec); @@ -1458,8 +1380,6 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); - const client_iv = [_]u8{ 0x77, 0x02, 0x2F, 0x09, 0xB2, 0x93, 0x5A, 0x5E, 0x3F, 0x2B, 0xB0, 0x32 }; - try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } server_command = try server.next(server_command); // send_change_cipher_spec @@ -1499,8 +1419,6 @@ test "tls client and server handshake, data, and close_notify" { try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); - const client_iv = [_]u8{ 0x54, 0xF3, 0x34, 0x20, 0xA8, 0x50, 0xF5, 0x3A, 0x22, 0x9A, 0xBB, 0x1B }; - try std.testing.expectEqualSlices(u8, &client_iv, &c.client_iv); } try client.any().writer().writeAll("ping"); diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 3dc730ecc083..0a15f1212a57 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -110,11 +110,8 @@ pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { .{ .signature_algorithms = &tls.supported_signature_schemes }, .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, .{ .key_share = &[_]tls.KeyShare{ - .{ .x25519_kyber768d00 = .{ - .x25519 = key_pairs.x25519.public_key, - .kyber768d00 = key_pairs.kyber768d00.public_key, - } }, .{ .secp256r1 = key_pairs.secp256r1.public_key }, + .{ .secp384r1 = key_pairs.secp384r1.public_key }, .{ .x25519 = key_pairs.x25519.public_key }, } }, }, @@ -165,23 +162,6 @@ pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { const named_group = try stream.read(tls.NamedGroup); const key_size = try stream.read(u16); switch (named_group) { - .x25519_kyber768d00 => { - const T = tls.NamedGroupT(.x25519_kyber768d00); - const x25519_len = T.X25519.public_length; - const expected_len = x25519_len + T.Kyber768.ciphertext_length; - if (key_size != expected_len) return stream.writeError(.illegal_parameter); - var server_ks: [expected_len]u8 = undefined; - try r.readNoEof(&server_ks); - - const mult = T.X25519.scalarmult( - key_pairs.x25519.secret_key, - server_ks[0..x25519_len].*, - ) catch return stream.writeError(.decrypt_error); - const decaps = key_pairs.kyber768d00.secret_key.decaps( - server_ks[x25519_len..expected_len], - ) catch return stream.writeError(.decrypt_error); - shared_key = &(mult ++ decaps); - }, .x25519 => { const T = tls.NamedGroupT(.x25519); const expected_len = T.public_length; @@ -434,7 +414,7 @@ pub const ReadError = anyerror; pub const WriteError = anyerror; /// Reads next application_data message. -pub fn readv(self: *Self, buffers: []std.os.iovec) ReadError!usize { +pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { var stream = &self.stream; if (stream.eof()) return 0; @@ -484,7 +464,7 @@ pub fn readv(self: *Self, buffers: []std.os.iovec) ReadError!usize { return try stream.readv(buffers); } -pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { +pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { if (self.stream.eof()) return 0; const res = try self.stream.writev(iov); @@ -528,7 +508,6 @@ pub const Options = struct { pub const KeyPairs = struct { hello_rand: [hello_rand_length]u8, session_id: [session_id_length]u8, - kyber768d00: Kyber768, secp256r1: Secp256r1, secp384r1: Secp384r1, x25519: X25519, @@ -538,13 +517,11 @@ pub const KeyPairs = struct { const X25519 = tls.NamedGroupT(.x25519).KeyPair; const Secp256r1 = tls.NamedGroupT(.secp256r1).KeyPair; const Secp384r1 = tls.NamedGroupT(.secp384r1).KeyPair; - const Kyber768 = tls.NamedGroupT(.x25519_kyber768d00).Kyber768.KeyPair; pub fn init() @This() { var random_buffer: [ hello_rand_length + session_id_length + - Kyber768.seed_length + Secp256r1.seed_length + Secp384r1.seed_length + X25519.seed_length @@ -555,17 +532,15 @@ pub const KeyPairs = struct { const split1 = hello_rand_length; const split2 = split1 + session_id_length; - const split3 = split2 + Kyber768.seed_length; - const split4 = split3 + Secp256r1.seed_length; - const split5 = split4 + Secp384r1.seed_length; + const split3 = split2 + Secp256r1.seed_length; + const split4 = split3 + Secp384r1.seed_length; return initAdvanced( random_buffer[0..split1].*, random_buffer[split1..split2].*, random_buffer[split2..split3].*, random_buffer[split3..split4].*, - random_buffer[split4..split5].*, - random_buffer[split5..].*, + random_buffer[split4..].*, ) catch continue; } } @@ -573,13 +548,11 @@ pub const KeyPairs = struct { pub fn initAdvanced( hello_rand: [hello_rand_length]u8, session_id: [session_id_length]u8, - kyber_768_seed: [Kyber768.seed_length]u8, secp256r1_seed: [Secp256r1.seed_length]u8, secp384r1_seed: [Secp384r1.seed_length]u8, x25519_seed: [X25519.seed_length]u8, ) !@This() { return .{ - .kyber768d00 = Kyber768.create(kyber_768_seed) catch {}, .secp256r1 = Secp256r1.create(secp256r1_seed) catch |err| switch (err) { error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. }, diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index 8eba2bd730ae..759f9db7c9d3 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -39,11 +39,11 @@ pub fn init(stream: std.io.AnyStream, options: Options) !Self { if (expected != options.certificate_key) return error.CertificateKeyMismatch; // TODO: verify private key corresponds to public key - const cmd_init = Command{ .recv_hello = {} }; - var command = cmd_init; + var command = initial_command(); while (command != .none) { command = res.next(command) catch |err| switch (err) { - error.ConnectionResetByPeer => cmd_init, + // Prevent replay attacks in later handshake stages. + error.ConnectionResetByPeer => initial_command(), else => return err, }; } @@ -51,13 +51,21 @@ pub fn init(stream: std.io.AnyStream, options: Options) !Self { return res; } +inline fn initial_command() Command { + var res = Command{ .recv_hello = undefined }; + crypto.random.bytes(&res.recv_hello.server_random); + crypto.random.bytes(&res.recv_hello.keygen_seed); + + return res; +} + /// Executes handshake command and returns next one. pub fn next(self: *Self, command: Command) !Command { var stream = &self.stream; switch (command) { - .recv_hello => { - const client_hello = try self.recv_hello(); + .recv_hello => |random| { + const client_hello = try self.recv_hello(random); return .{ .send_hello = client_hello }; }, @@ -105,7 +113,7 @@ pub fn next(self: *Self, command: Command) !Command { } } -pub fn recv_hello(self: *Self) !ClientHello { +pub fn recv_hello(self: *Self, random: Command.Random) !ClientHello { var stream = &self.stream; var reader = stream.any().reader(); @@ -166,12 +174,14 @@ pub fn recv_hello(self: *Self) !ClientHello { if (ks == s and key_share == null) key_share = ks; } } + if (key_share == null) return stream.writeError(.decode_error); }, .ec_point_formats => { var format_iter = try stream.iterator(u8, tls.EcPointFormat); while (try format_iter.next()) |f| { if (f == .uncompressed) ec_point_format = .uncompressed; } + if (ec_point_format == null) return stream.writeError(.decode_error); }, .signature_algorithms => { const acceptable = switch (self.options.certificate_key) { @@ -189,6 +199,7 @@ pub fn recv_hello(self: *Self) !ClientHello { if (algo == a and sig_scheme == null) sig_scheme = algo; } } + if (sig_scheme == null) return stream.writeError(.decode_error); }, else => { try reader.skipBytes(ext.len, .{}); @@ -201,16 +212,13 @@ pub fn recv_hello(self: *Self) !ClientHello { if (ec_point_format == null) return stream.writeError(.missing_extension); if (sig_scheme == null) return stream.writeError(.missing_extension); - var server_random: [32]u8 = undefined; - crypto.random.bytes(&server_random); - - const key_pair: tls.KeyPair = switch (key_share.?) { + const key_pair = switch (key_share.?) { inline .secp256r1, .secp384r1, .x25519, - .x25519_kyber768d00, => |_, tag| brk: { - const pair = tls.NamedGroupT(tag).KeyPair.create(null) catch unreachable; + const T = tls.NamedGroupT(tag).KeyPair; + const pair = T.create(random.keygen_seed[0..T.seed_length].*) catch unreachable; break :brk @unionInit(tls.KeyPair, @tagName(tag), pair); }, else => return stream.writeError(.decode_error), @@ -223,7 +231,7 @@ pub fn recv_hello(self: *Self) !ClientHello { .cipher_suite = cipher_suite, .key_share = key_share.?, .sig_scheme = sig_scheme.?, - .server_random = server_random, + .server_random = random.server_random, .server_pair = key_pair, }; } @@ -246,19 +254,6 @@ pub fn send_hello(self: *Self, client_hello: ClientHello) !void { try stream.flush(); const shared_key = switch (client_hello.key_share) { - .x25519_kyber768d00 => |ks| brk: { - const T = tls.NamedGroupT(.x25519_kyber768d00); - const pair: tls.X25519Kyber768Draft.KeyPair = key_pair.x25519_kyber768d00; - const shared_point = T.X25519.scalarmult( - ks.x25519, - pair.x25519.secret_key, - ) catch return stream.writeError(.decrypt_error); - // pair.kyber768d00.secret_key - // ks.kyber768d00 - const encaps = ks.kyber768d00.encaps(null).ciphertext; - - break :brk &(shared_point ++ encaps); - }, .x25519 => |ks| brk: { const shared_point = tls.NamedGroupT(.x25519).scalarmult( key_pair.x25519.secret_key, @@ -266,11 +261,10 @@ pub fn send_hello(self: *Self, client_hello: ClientHello) !void { ) catch return stream.writeError(.decrypt_error); break :brk &shared_point; }, - .secp256r1 => |ks| brk: { - const mul = ks.p.mulPublic( - key_pair.secp256r1.secret_key.bytes, - .big, - ) catch return stream.writeError(.decrypt_error); + inline .secp256r1, .secp384r1 => |ks, tag| brk: { + const key = @field(key_pair, @tagName(tag)); + const mul = ks.p.mulPublic(key.secret_key.bytes, .big) catch + return stream.writeError(.decrypt_error); break :brk &mul.affineCoordinates().x.toBytes(.big); }, else => return stream.writeError(.illegal_parameter), @@ -416,7 +410,7 @@ pub const ReadError = anyerror; pub const WriteError = anyerror; /// Reads next application_data message. -pub fn readv(self: *Self, buffers: []std.os.iovec) ReadError!usize { +pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { var stream = &self.stream; if (stream.eof()) return 0; @@ -432,7 +426,7 @@ pub fn readv(self: *Self, buffers: []std.os.iovec) ReadError!usize { return try self.stream.readv(buffers); } -pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { +pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { if (self.stream.eof()) return 0; const res = try self.stream.writev(iov); @@ -471,7 +465,7 @@ pub const Options = struct { /// A command to send or receive a single message. Allows deterministically /// testing `advance` on a single thread. pub const Command = union(enum) { - recv_hello: void, + recv_hello: Random, send_hello: ClientHello, send_change_cipher_spec: tls.SignatureScheme, send_encrypted_extensions: tls.SignatureScheme, @@ -481,6 +475,11 @@ pub const Command = union(enum) { recv_finished: void, none: void, + pub const Random = struct { + server_random: [32]u8, + keygen_seed: [tls.NamedGroupT(.secp384r1).KeyPair.seed_length]u8, + }; + pub const CertificateVerify = struct { scheme: tls.SignatureScheme, salt: [tls.MultiHash.max_digest_len]u8, @@ -495,6 +494,8 @@ pub const ClientHello = struct { key_share: tls.KeyShare, sig_scheme: tls.SignatureScheme, server_random: [32]u8, - /// active member MUST match `key_share` + /// Everything needed to generate a shared secret and send ciphertext to the client + /// so it can do the same. + /// Active member MUST match `key_share`. server_pair: tls.KeyPair, }; diff --git a/lib/std/crypto/tls/Stream.zig b/lib/std/crypto/tls/Stream.zig index f006d5b3a40b..ad46d253977c 100644 --- a/lib/std/crypto/tls/Stream.zig +++ b/lib/std/crypto/tls/Stream.zig @@ -169,7 +169,7 @@ pub fn close(self: *Self) void { } /// Write bytes to `stream`, potentially flushing once `self.buffer` is full. -pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { +pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { const first = iov[0]; const bytes = first.iov_base[0..first.iov_len]; if (self.nocommit > 0) return bytes.len; @@ -244,7 +244,7 @@ pub fn arrayLength( /// Reads bytes from `view`, potentially reading more fragments from underlying `stream`. /// /// A return value of 0 indicates EOF. -pub fn readv(self: *Self, iov: []std.os.iovec) ReadError!usize { +pub fn readv(self: *Self, iov: []const std.os.iovec) ReadError!usize { // > Any data received after a closure alert has been received MUST be ignored. if (self.eof()) return 0; diff --git a/lib/std/fifo.zig b/lib/std/fifo.zig index e02833a8cc63..6fe588bcfeea 100644 --- a/lib/std/fifo.zig +++ b/lib/std/fifo.zig @@ -232,7 +232,7 @@ pub fn LinearFifo( /// Same as `read` except it returns an error union /// The purpose of this function existing is to match `std.io.Reader` API. - fn readvFn(self: *Self, iov: []std.os.iovec) error{}!usize { + fn readvFn(self: *Self, iov: []const std.os.iovec) error{}!usize { var n_read: usize = 0; for (iov) |v| { const n = self.read(v.iov_base[0..v.iov_len]); @@ -328,7 +328,7 @@ pub fn LinearFifo( /// Same as `write` except it returns the number of bytes written, which is always the same /// as `bytes.len`. The purpose of this function existing is to match `std.io.Writer` API. - fn appendWritev(self: *Self, iov: []std.os.iovec_const) error{OutOfMemory}!usize { + fn appendWritev(self: *Self, iov: []const std.os.iovec_const) error{OutOfMemory}!usize { var written: usize = 0; for (iov) |v| { try self.write(v.iov_base[0..v.iov_len]); diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index a239ddefbfd7..58e957e00163 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -192,7 +192,7 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { /// Underlying socket - socket: net.Stream, + socket: net.Socket, /// TLS client. tls: tls.Client, @@ -225,7 +225,7 @@ pub const Connection = struct { pub inline fn stream(conn: *Connection) std.io.AnyStream { return switch (conn.protocol) { - .plain => conn.socket.any().any(), + .plain => conn.socket.stream().any(), .tls => conn.tls.any().any(), }; } @@ -293,7 +293,7 @@ pub const Connection = struct { return nread; } - pub fn readv(conn: *Connection, iov: []std.os.iovec) ReadError!usize { + pub fn readv(conn: *Connection, iov: []const std.os.iovec) ReadError!usize { const first = iov[0]; const buffer = first.iov_base[0..first.iov_len]; return try conn.read(buffer); @@ -333,7 +333,7 @@ pub const Connection = struct { return buffer.len; } - pub fn writev(conn: *Connection, iov: []std.os.iovec_const) WriteError!usize { + pub fn writev(conn: *Connection, iov: []const std.os.iovec_const) WriteError!usize { const first = iov[0]; const buffer = first.iov_base[0..first.iov_len]; return try conn.write(buffer); @@ -903,7 +903,7 @@ pub const Request = struct { return .{ .context = req }; } - fn transferReadv(req: *Request, iov: []std.os.iovec) TransferReadError!usize { + fn transferReadv(req: *Request, iov: []const std.os.iovec) TransferReadError!usize { if (req.response.parser.done) return 0; const first = iov[0]; const buf = first.iov_base[0..first.iov_len]; @@ -1102,7 +1102,7 @@ pub const Request = struct { } /// Reads data from the response body. Must be called after `wait`. - pub fn readv(req: *Request, iov: []std.os.iovec) ReadError!usize { + pub fn readv(req: *Request, iov: []const std.os.iovec) ReadError!usize { const out_index = switch (req.response.compression) { .deflate => |*deflate| deflate.readv(iov) catch return error.DecompressionFailure, .gzip => |*gzip| gzip.readv(iov) catch return error.DecompressionFailure, @@ -1132,7 +1132,7 @@ pub const Request = struct { /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. /// Must be called after `send` and before `finish`. - pub fn writev(req: *Request, iov: []std.os.iovec_const) WriteError!usize { + pub fn writev(req: *Request, iov: []const std.os.iovec_const) WriteError!usize { var iov_len: usize = 0; for (iov) |v| iov_len += v.iov_len; @@ -1142,7 +1142,7 @@ pub const Request = struct { if (iov_len > 0) { try w.print("{x}\r\n", .{iov_len}); - try w.writevAll(iov); + for (iov) |v| try w.writeAll(v.iov_base[0..v.iov_len]); try w.writeAll("\r\n"); } @@ -1344,7 +1344,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec errdefer client.allocator.free(conn.data.host); if (protocol == .tls) { - conn.data.tls = tls.Client.init(conn.data.socket.any().any(), .{ + conn.data.tls = tls.Client.init(conn.data.socket.stream().any(), .{ .ca_bundle = client.tls_options.ca_bundle, .cipher_suites = client.tls_options.cipher_suites, .host = host, @@ -1381,7 +1381,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti errdefer stream.close(); conn.data = .{ - .stream = stream.any(), + .stream = stream.stream(), .protocol = .plain, .host = try client.allocator.dupe(u8, path), diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index a540274ad7ba..0deb5a6d3370 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -615,7 +615,7 @@ pub const Request = struct { }; return .{ - .stream = request.server.connection.stream(), + .connection = request.server.connection, .send_buffer = options.send_buffer, .send_buffer_start = 0, .send_buffer_end = h.items.len, @@ -630,12 +630,12 @@ pub const Request = struct { }; } - pub const ReadError = anyerror || net.Stream.ReadError || error{ + pub const ReadError = anyerror || net.Socket.ReadError || error{ HttpChunkInvalid, HttpHeadersOversize, }; - fn read_cl(context: *const anyopaque, iov: []std.os.iovec) ReadError!usize { + fn read_cl(context: *const anyopaque, iov: []const std.os.iovec) ReadError!usize { const first = iov[0]; const buffer = first.iov_base[0..first.iov_len]; const request: *Request = @constCast(@alignCast(@ptrCast(context))); @@ -665,7 +665,7 @@ pub const Request = struct { return s.read_buffer[head_end..s.read_buffer_len]; } - fn readv_chunked(context: *const anyopaque, iov: []std.os.iovec) ReadError!usize { + fn readv_chunked(context: *const anyopaque, iov: []const std.os.iovec) ReadError!usize { const first = iov[0]; const buffer = first.iov_base[0..first.iov_len]; const request: *Request = @constCast(@alignCast(@ptrCast(context))); @@ -836,7 +836,7 @@ pub const Request = struct { }; pub const Response = struct { - stream: std.io.AnyStream, + connection: net.Server.Connection, send_buffer: []u8, /// Index of the first byte in `send_buffer`. /// This is 0 unless a short write happens in `write`. @@ -860,14 +860,12 @@ pub const Response = struct { chunked, }; - pub const WriteError = net.Stream.WriteError; - /// When using content-length, asserts that the amount of data sent matches /// the value sent in the header, then calls `flush`. /// Otherwise, transfer-encoding: chunked is being used, and it writes the /// end-of-stream message, then flushes the stream to the system. /// Respects the value of `elide_body` to omit all data after the headers. - pub fn end(r: *Response) WriteError!void { + pub fn end(r: *Response) !void { switch (r.transfer_encoding) { .content_length => |len| { assert(len == 0); // Trips when end() called before all bytes written. @@ -892,13 +890,13 @@ pub const Response = struct { /// flushes the stream to the system. /// Respects the value of `elide_body` to omit all data after the headers. /// Asserts there are at most 25 trailers. - pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { + pub fn endChunked(r: *Response, options: EndChunkedOptions) !void { assert(r.transfer_encoding == .chunked); try flush_chunked(r, options.trailers); r.* = undefined; } - fn write_cl(context: *const anyopaque, iov: []std.os.iovec_const) WriteError!usize { + fn write_cl(context: *const anyopaque, iov: []const std.os.iovec_const) !usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); const first = iov[0]; @@ -927,7 +925,7 @@ pub const Response = struct { .iov_len = bytes.len, }, }; - const n = try r.stream.writev(&iovecs); + const n = try r.connection.stream().writev(&iovecs); if (n >= send_buffer_len) { // It was enough to reset the buffer. @@ -951,7 +949,7 @@ pub const Response = struct { return bytes.len; } - fn write_chunked(context: *const anyopaque, iov: []std.os.iovec_const) WriteError!usize { + fn write_chunked(context: *const anyopaque, iov: []const std.os.iovec_const) !usize { const r: *Response = @constCast(@alignCast(@ptrCast(context))); assert(r.transfer_encoding == .chunked); @@ -991,7 +989,7 @@ pub const Response = struct { }; // TODO make this writev instead of writevAll, which involves // complicating the logic of this function. - try r.stream.writer().writevAll(&iovecs); + try r.connection.stream().writer().writevAll(&iovecs); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1008,20 +1006,20 @@ pub const Response = struct { /// Sends all buffered data to the client. /// This is redundant after calling `end`. /// Respects the value of `elide_body` to omit all data after the headers. - pub fn flush(r: *Response) WriteError!void { + pub fn flush(r: *Response) !void { switch (r.transfer_encoding) { .none, .content_length => return flush_cl(r), .chunked => return flush_chunked(r, null), } } - fn flush_cl(r: *Response) WriteError!void { - try r.stream.writer().writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); + fn flush_cl(r: *Response) !void { + try r.connection.stream().writer().writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); r.send_buffer_start = 0; r.send_buffer_end = 0; } - fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { + fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) !void { const max_trailers = 25; if (end_trailers) |trailers| assert(trailers.len <= max_trailers); assert(r.transfer_encoding == .chunked); @@ -1029,7 +1027,7 @@ pub const Response = struct { const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len]; if (r.elide_body) { - try r.stream.writer().writeAll(http_headers); + try r.connection.stream().writer().writeAll(http_headers); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; @@ -1110,7 +1108,7 @@ pub const Response = struct { iovecs_len += 1; } - try r.stream.writer().writevAll(iovecs[0..iovecs_len]); + try r.connection.stream().writer().writevAll(iovecs[0..iovecs_len]); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index a54f8bdb2077..88cb343aa912 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -133,8 +133,8 @@ test "HTTP server handles a chunked transfer coding request" { "\r\n"; const gpa = std.testing.allocator; - const tcp_stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); - const stream = tcp_stream.any(); + const socket = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = socket.stream(); defer stream.close(); try stream.writer().writeAll(request_bytes); @@ -270,10 +270,11 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { const request_bytes = "GET /foo HTTP/1.1\r\n\r\n"; const gpa = std.testing.allocator; - const tcp_stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); - const stream = tcp_stream.any(); + const socket = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = socket.stream(); defer stream.close(); try stream.writer().writeAll(request_bytes); + std.debug.print("requested\n", .{}); const response = try stream.reader().readAllAlloc(gpa, 8192); defer gpa.free(response); @@ -334,8 +335,8 @@ test "receiving arbitrary http headers from the client" { "aoeu: asdf \r\n" ++ "\r\n"; const gpa = std.testing.allocator; - const tcp_stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); - const stream = tcp_stream.any(); + const socket = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port()); + const stream = socket.stream(); defer stream.close(); try stream.writer().writeAll(request_bytes); diff --git a/lib/std/io.zig b/lib/std/io.zig index b77c385d651f..cd86a3c5ee69 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -80,17 +80,15 @@ pub fn GenericReader( /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. - comptime readvFn: fn (context: Context, iov: []iovec) ReadError!usize, + comptime readvFn: fn (context: Context, iov: []const iovec) ReadError!usize, ) type { return struct { context: Context, pub const Error = ReadError; - pub const NoEofError = ReadError || error{ - EndOfStream, - }; + pub const NoEofError = ReadError || error{ EndOfStream }; - pub inline fn readv(self: Self, iov: []iovec) Error!usize { + pub inline fn readv(self: Self, iov: []const iovec) Error!usize { return readvFn(self.context, iov); } @@ -304,7 +302,7 @@ pub fn GenericReader( const Self = @This(); - fn typeErasedReadvFn(context: *const anyopaque, iov: []iovec) anyerror!usize { + fn typeErasedReadvFn(context: *const anyopaque, iov: []const iovec) anyerror!usize { const ptr: *const Context = @alignCast(@ptrCast(context)); return readvFn(ptr.*, iov); } @@ -314,7 +312,7 @@ pub fn GenericReader( pub fn GenericWriter( comptime Context: type, comptime WriteError: type, - comptime writevFn: fn (context: Context, iov: []iovec_const) WriteError!usize, + comptime writevFn: fn (context: Context, iov: []const iovec_const) WriteError!usize, ) type { return struct { context: Context, @@ -322,7 +320,7 @@ pub fn GenericWriter( const Self = @This(); pub const Error = WriteError; - pub inline fn writev(self: Self, iov: []iovec_const) Error!usize { + pub inline fn writev(self: Self, iov: []const iovec_const) Error!usize { return writevFn(self.context, iov); } @@ -369,7 +367,7 @@ pub fn GenericWriter( }; } - fn typeErasedWritevFn(context: *const anyopaque, iov: []iovec_const) anyerror!usize { + fn typeErasedWritevFn(context: *const anyopaque, iov: []const iovec_const) anyerror!usize { const ptr: *const Context = @alignCast(@ptrCast(context)); return writevFn(ptr.*, iov); } @@ -382,9 +380,9 @@ pub fn GenericStream( /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. - comptime readvFn: fn (context: Context, iov: []iovec) ReadError!usize, + comptime readvFn: fn (context: Context, iov: []const iovec) ReadError!usize, comptime WriteError: type, - comptime writevFn: fn (context: Context, iov: []iovec_const) WriteError!usize, + comptime writevFn: fn (context: Context, iov: []const iovec_const) WriteError!usize, comptime closeFn: fn (context: Context) void, ) type { return struct { @@ -485,7 +483,7 @@ pub const tty = @import("io/tty.zig"); pub const null_writer = @as(NullWriter, .{ .context = {} }); const NullWriter = Writer(void, error{}, dummyWritev); -fn dummyWritev(context: void, iov: []std.os.iovec_const) error{}!usize { +fn dummyWritev(context: void, iov: []const std.os.iovec_const) error{}!usize { _ = context; var written: usize = 0; for (iov) |v| written += v.iov_len; diff --git a/lib/std/io/Reader.zig b/lib/std/io/Reader.zig index f06a7be4a256..92fd16758094 100644 --- a/lib/std/io/Reader.zig +++ b/lib/std/io/Reader.zig @@ -8,14 +8,14 @@ const native_endian = @import("builtin").target.cpu.arch.endian(); const iovec = std.os.iovec; context: *const anyopaque, -readvFn: *const fn (context: *const anyopaque, iov: []iovec) anyerror!usize, +readvFn: *const fn (context: *const anyopaque, iov: []const iovec) anyerror!usize, pub const Error = anyerror; /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. -pub fn readv(self: Self, iov: []iovec) anyerror!usize { +pub fn readv(self: Self, iov: []const iovec) anyerror!usize { return self.readvFn(self.context, iov); } diff --git a/lib/std/io/Stream.zig b/lib/std/io/Stream.zig index a8658fd7d7dd..cafcda5bb64d 100644 --- a/lib/std/io/Stream.zig +++ b/lib/std/io/Stream.zig @@ -6,21 +6,21 @@ const iovec = os.iovec; const iovec_const = os.iovec_const; context: *const anyopaque, -readvFn: *const fn (context: *const anyopaque, iov: []iovec) anyerror!usize, -writevFn: *const fn (context: *const anyopaque, iov: []iovec_const) anyerror!usize, +readvFn: *const fn (context: *const anyopaque, iov: []const iovec) anyerror!usize, +writevFn: *const fn (context: *const anyopaque, iov: []const iovec_const) anyerror!usize, closeFn: *const fn (context: *const anyopaque) void, const Self = @This(); pub const Error = anyerror; -pub fn writev(self: Self, iov: []iovec_const) anyerror!usize { +pub fn writev(self: Self, iov: []const iovec_const) anyerror!usize { return self.writevFn(self.context, iov); } /// Returns the number of bytes read. It may be less than buffer.len. /// If the number of bytes read is 0, it means end of stream. /// End of stream is not an error condition. -pub fn readv(self: Self, iov: []iovec) anyerror!usize { +pub fn readv(self: Self, iov: []const iovec) anyerror!usize { return self.readvFn(self.context, iov); } diff --git a/lib/std/io/Writer.zig b/lib/std/io/Writer.zig index 383120c4a16b..a8015b59ddfa 100644 --- a/lib/std/io/Writer.zig +++ b/lib/std/io/Writer.zig @@ -4,12 +4,12 @@ const mem = std.mem; const iovec_const = std.os.iovec_const; context: *const anyopaque, -writevFn: *const fn (context: *const anyopaque, iov: []iovec_const) anyerror!usize, +writevFn: *const fn (context: *const anyopaque, iov: []const iovec_const) anyerror!usize, const Self = @This(); pub const Error = anyerror; -pub fn writev(self: Self, iov: []iovec_const) anyerror!usize { +pub fn writev(self: Self, iov: []const iovec_const) anyerror!usize { return self.writevFn(self.context, iov); } diff --git a/lib/std/io/buffered_reader.zig b/lib/std/io/buffered_reader.zig index a8e5193f7b78..1e0b78b7597b 100644 --- a/lib/std/io/buffered_reader.zig +++ b/lib/std/io/buffered_reader.zig @@ -16,7 +16,7 @@ pub fn BufferedReader(comptime buffer_size: usize, comptime ReaderType: type) ty const Self = @This(); - pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const first = iov[0]; const dest = first.iov_base[0..first.iov_len]; var dest_index: usize = 0; @@ -71,7 +71,7 @@ test "OneByte" { }; } - fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const first = iov[0]; const dest = first.iov_base[0..first.iov_len]; if (self.str.len <= self.curr or dest.len == 0) diff --git a/lib/std/io/buffered_tee.zig b/lib/std/io/buffered_tee.zig index bc817ddb69b5..7fbb64f34575 100644 --- a/lib/std/io/buffered_tee.zig +++ b/lib/std/io/buffered_tee.zig @@ -40,7 +40,7 @@ pub fn BufferedTee( const Self = @This(); - pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const first = iov[0]; const dest = first.iov_base[0..first.iov_len]; var dest_index: usize = 0; @@ -164,7 +164,7 @@ test "OneByte" { }; } - fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const first = iov[0]; const dest = first.iov_base[0..first.iov_len]; if (self.str.len <= self.curr or dest.len == 0) diff --git a/lib/std/io/buffered_writer.zig b/lib/std/io/buffered_writer.zig index b3518d69f553..6daaac8427e0 100644 --- a/lib/std/io/buffered_writer.zig +++ b/lib/std/io/buffered_writer.zig @@ -23,7 +23,7 @@ pub fn BufferedWriter(comptime buffer_size: usize, comptime WriterType: type) ty return .{ .context = self }; } - pub fn writev(self: *Self, iov: []std.os.iovec_const) Error!usize { + pub fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { var written: usize = 0; for (iov) |v| { const bytes = v.iov_base[0..v.iov_len]; diff --git a/lib/std/io/counting_reader.zig b/lib/std/io/counting_reader.zig index 72cd10f1ea3e..30a99bac9dcf 100644 --- a/lib/std/io/counting_reader.zig +++ b/lib/std/io/counting_reader.zig @@ -13,7 +13,7 @@ pub fn CountingReader(comptime ReaderType: anytype) type { const Self = @This(); - pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { const amt = try self.child_reader.readv(iov); self.bytes_read += amt; return amt; diff --git a/lib/std/io/counting_writer.zig b/lib/std/io/counting_writer.zig index 989fbcbab614..a663423c3348 100644 --- a/lib/std/io/counting_writer.zig +++ b/lib/std/io/counting_writer.zig @@ -13,7 +13,7 @@ pub fn CountingWriter(comptime WriterType: type) type { const Self = @This(); - pub fn writev(self: *Self, iov: []std.os.iovec_const) Error!usize { + pub fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { const amt = try self.child_stream.writev(iov); self.bytes_written += amt; return amt; diff --git a/lib/std/io/fixed_buffer_stream.zig b/lib/std/io/fixed_buffer_stream.zig index e455e591c454..6a48ba32df98 100644 --- a/lib/std/io/fixed_buffer_stream.zig +++ b/lib/std/io/fixed_buffer_stream.zig @@ -44,7 +44,7 @@ pub fn FixedBufferStream(comptime Buffer: type) type { return .{ .context = self }; } - pub fn readv(self: *Self, iov: []std.os.iovec) ReadError!usize { + pub fn readv(self: *Self, iov: []const std.os.iovec) ReadError!usize { var read: usize = 0; for (iov) |v| { const size = @min(v.iov_len, self.buffer.len - self.pos); @@ -62,7 +62,7 @@ pub fn FixedBufferStream(comptime Buffer: type) type { /// buffer is full. Returns `error.NoSpaceLeft` when no bytes would be written. /// Note: `error.NoSpaceLeft` matches the corresponding error from /// `std.fs.File.WriteError`. - pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { + pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { var written: usize = 0; for (iov) |v| { if (v.iov_len == 0) continue; diff --git a/lib/std/io/limited_reader.zig b/lib/std/io/limited_reader.zig index b20018e6e717..7b14557284bf 100644 --- a/lib/std/io/limited_reader.zig +++ b/lib/std/io/limited_reader.zig @@ -13,12 +13,17 @@ pub fn LimitedReader(comptime ReaderType: type) type { const Self = @This(); - pub fn readv(self: *Self, iov: []std.os.iovec) Error!usize { - for (iov) |*v| { - v.iov_len = @min(self.bytes_left, v.iov_len); - self.bytes_left -= v.iov_len; - } - return try self.inner_reader.readv(iov); + pub fn read(self: *Self, dest: []u8) Error!usize { + const max_read = @min(self.bytes_left, dest.len); + const n = try self.inner_reader.read(dest[0..max_read]); + self.bytes_left -= n; + return n; + } + + pub fn readv(self: *Self, iov: []const std.os.iovec) Error!usize { + const first = iov[0]; + const buf = first.iov_base[0..first.iov_len]; + return try self.read(buf); } pub fn reader(self: *Self) Reader { diff --git a/lib/std/io/multi_writer.zig b/lib/std/io/multi_writer.zig index 9cd4600e634c..3ae7c98931c2 100644 --- a/lib/std/io/multi_writer.zig +++ b/lib/std/io/multi_writer.zig @@ -15,16 +15,16 @@ pub fn MultiWriter(comptime Writers: type) type { streams: Writers, pub const Error = ErrSet; - pub const Writer = io.Writer(*Self, Error, write); + pub const Writer = io.Writer(*Self, Error, writev); pub fn writer(self: *Self) Writer { return .{ .context = self }; } - pub fn write(self: *Self, bytes: []const u8) Error!usize { - inline for (self.streams) |stream| - try stream.writeAll(bytes); - return bytes.len; + pub fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + var written: usize = 0; + inline for (self.streams) |stream| written = try stream.writev(iov); + return written; } }; } diff --git a/lib/std/io/stream_source.zig b/lib/std/io/stream_source.zig index f4fcc5ad31c2..5234908f0a7c 100644 --- a/lib/std/io/stream_source.zig +++ b/lib/std/io/stream_source.zig @@ -38,7 +38,7 @@ pub const StreamSource = union(enum) { getEndPos, ); - pub fn readv(self: *StreamSource, iov: []std.os.iovec) ReadError!usize { + pub fn readv(self: *StreamSource, iov: []const std.os.iovec) ReadError!usize { switch (self.*) { .buffer => |*x| return x.readv(iov), .const_buffer => |*x| return x.readv(iov), @@ -46,7 +46,7 @@ pub const StreamSource = union(enum) { } } - pub fn writev(self: *StreamSource, iov: []std.os.iovec_const) WriteError!usize { + pub fn writev(self: *StreamSource, iov: []const std.os.iovec_const) WriteError!usize { switch (self.*) { .buffer => |*x| return x.writev(iov), .const_buffer => return error.AccessDenied, diff --git a/lib/std/json/stringify_test.zig b/lib/std/json/stringify_test.zig index 437ca2ddbca9..a00f6c114669 100644 --- a/lib/std/json/stringify_test.zig +++ b/lib/std/json/stringify_test.zig @@ -314,7 +314,7 @@ fn testStringify(expected: []const u8, value: anytype, options: StringifyOptions return .{ .context = self }; } - fn writev(self: *Self, iov: []std.os.iovec_const) Error!usize { + fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { const first = iov[0]; const bytes = first.iov_base[0..first.iov_len]; if (self.expected_remaining.len < bytes.len) { diff --git a/lib/std/net.zig b/lib/std/net.zig index c209c5ac4096..23f832ba062e 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -709,26 +709,26 @@ pub const Ip6Address = extern struct { } }; -pub fn connectUnixSocket(path: []const u8) !Stream { +pub fn connectUnixSocket(path: []const u8) !Socket { const opt_non_block = 0; const sockfd = try os.socket( os.AF.UNIX, os.SOCK.STREAM | os.SOCK.CLOEXEC | opt_non_block, 0, ); - errdefer Stream.close(.{ .handle = sockfd }); + errdefer Socket.close(.{ .handle = sockfd }); var addr = try std.net.Address.initUnix(path); try os.connect(sockfd, &addr.any, addr.getOsSockLen()); - return Stream{ .handle = sockfd }; + return Socket{ .handle = sockfd }; } fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 { if (builtin.target.os.tag == .linux) { var ifr: os.ifreq = undefined; const sockfd = try os.socket(os.AF.UNIX, os.SOCK.DGRAM | os.SOCK.CLOEXEC, 0); - defer Stream.close(.{ .handle = sockfd }); + defer Socket.close(.{ .handle = sockfd }); @memcpy(ifr.ifrn.name[0..name.len], name); ifr.ifrn.name[name.len] = 0; @@ -773,7 +773,7 @@ pub const AddressList = struct { pub const TcpConnectToHostError = GetAddressListError || TcpConnectToAddressError; /// All memory allocated with `allocator` will be freed before this function returns. -pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Stream { +pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) TcpConnectToHostError!Socket { const list = try getAddressList(allocator, name, port); defer list.deinit(); @@ -792,16 +792,16 @@ pub fn tcpConnectToHost(allocator: mem.Allocator, name: []const u8, port: u16) T pub const TcpConnectToAddressError = std.os.SocketError || std.os.ConnectError; -pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream { +pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Socket { const nonblock = 0; const sock_flags = os.SOCK.STREAM | nonblock | (if (builtin.target.os.tag == .windows) 0 else os.SOCK.CLOEXEC); const sockfd = try os.socket(address.any.family, sock_flags, os.IPPROTO.TCP); - errdefer Stream.close(.{ .handle = sockfd }); + errdefer Socket.close(.{ .handle = sockfd }); try os.connect(sockfd, &address.any, address.getOsSockLen()); - return Stream{ .handle = sockfd }; + return Socket{ .handle = sockfd }; } const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || std.os.SetSockOptError || error{ @@ -1127,7 +1127,7 @@ fn linuxLookupName( var prefixlen: i32 = 0; const sock_flags = os.SOCK.DGRAM | os.SOCK.CLOEXEC; if (os.socket(addr.addr.any.family, sock_flags, os.IPPROTO.UDP)) |fd| syscalls: { - defer Stream.close(.{ .handle = fd }); + defer Socket.close(.{ .handle = fd }); os.connect(fd, da, dalen) catch break :syscalls; key |= DAS_USABLE; os.getsockname(fd, sa, &salen) catch break :syscalls; @@ -1612,7 +1612,7 @@ fn resMSendRc( }, else => |e| return e, }; - defer Stream.close(.{ .handle = fd }); + defer Socket.close(.{ .handle = fd }); // Past this point, there are no errors. Each individual query will // yield either no reply (indicated by zero length) or an answer @@ -1787,12 +1787,12 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) } } -pub const Stream = struct { +pub const Socket = struct { /// Underlying platform-defined type which may or may not be /// interchangeable with a file system file descriptor. handle: posix.socket_t, - pub fn close(s: Stream) void { + pub fn close(s: Socket) void { switch (builtin.os.tag) { .windows => std.os.windows.closesocket(s.handle) catch unreachable, else => posix.close(s.handle), @@ -1801,13 +1801,13 @@ pub const Stream = struct { pub const ReadError = os.ReadError; pub const WriteError = os.WriteError; - pub const GenericStream = io.GenericStream(Stream, ReadError, readv, WriteError, writev, close); + pub const GenericStream = io.GenericStream(Socket, ReadError, readv, WriteError, writev, close); - pub fn any(self: Stream) GenericStream { + pub fn stream(self: Socket) GenericStream { return .{ .context = self }; } - pub fn readv(s: Stream, iovecs: []const os.iovec) ReadError!usize { + pub fn readv(s: Socket, iovecs: []const os.iovec) ReadError!usize { if (builtin.os.tag == .windows) { // TODO improve this to use ReadFileScatter if (iovecs.len == 0) return @as(usize, 0); @@ -1820,26 +1820,26 @@ pub const Stream = struct { /// See https://github.com/ziglang/zig/issues/7699 /// See equivalent function: `std.fs.File.writev`. - pub fn writev(self: Stream, iovecs: []const os.iovec_const) WriteError!usize { + pub fn writev(self: Socket, iovecs: []const os.iovec_const) WriteError!usize { return os.writev(self.handle, iovecs); } }; pub const Server = struct { listen_address: Address, - stream: Stream, + stream: Socket, pub const Connection = struct { address: Address, protocol: Protocol, - socket: Stream, + socket: Socket, tls: tls.Server, pub const Protocol = enum { plain, tls }; pub inline fn stream(conn: *Connection) std.io.AnyStream { return switch (conn.protocol) { - .plain => conn.socket.any().any(), + .plain => conn.socket.stream().any(), .tls => conn.tls.any().any(), }; } @@ -1857,10 +1857,10 @@ pub const Server = struct { var accepted_addr: Address = undefined; var addr_len: posix.socklen_t = @sizeOf(Address); const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC); - const socket = Stream{ .handle = fd }; + const socket = Socket{ .handle = fd }; const protocol: Connection.Protocol = if (tls_options == null) .plain else .tls; const _tls: tls.Server = if (tls_options) |options| - try tls.Server.init(socket.any().any(), options) + try tls.Server.init(socket.stream().any(), options) else undefined; return .{ @@ -1875,6 +1875,6 @@ pub const Server = struct { test { _ = @import("net/test.zig"); _ = Server; - _ = Stream; + _ = Socket; _ = Address; } diff --git a/lib/std/tar.zig b/lib/std/tar.zig index 2782fc1303aa..70bdfaad1cb3 100644 --- a/lib/std/tar.zig +++ b/lib/std/tar.zig @@ -295,13 +295,13 @@ pub fn Iterator(comptime ReaderType: type) type { unread_bytes: *u64, parent_reader: ReaderType, - pub const Reader = std.io.Reader(File, ReaderType.Error, File.readv); + pub const Reader = std.io.Reader(File, ReaderType.Error, readv); pub fn reader(self: File) Reader { return .{ .context = self }; } - pub fn readv(self: File, iov: []std.os.iovec) ReaderType.Error!usize { + pub fn readv(self: File, iov: []const std.os.iovec) ReaderType.Error!usize { var n_read: usize = 0; for (iov) |v| { const dest = v.iov_base[0..v.iov_len]; diff --git a/lib/std/zig/render.zig b/lib/std/zig/render.zig index b8f9f98d4cef..98b2b2fc2d3e 100644 --- a/lib/std/zig/render.zig +++ b/lib/std/zig/render.zig @@ -3355,7 +3355,7 @@ fn AutoIndentingStream(comptime UnderlyingWriter: type) type { return .{ .context = self }; } - pub fn writev(self: *Self, iov: []std.os.iovec_const) WriteError!usize { + pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { var n_written: usize = 0; for (iov) |v| { if (v.iov_len == 0) return n_written; diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig index 93a6868603c6..70ba9164a924 100644 --- a/src/Package/Fetch.zig +++ b/src/Package/Fetch.zig @@ -807,16 +807,16 @@ const Resource = union(enum) { fn reader(resource: *Resource) std.io.AnyReader { return .{ .context = resource, - .readFn = read, + .readvFn = readv, }; } - fn read(context: *const anyopaque, buffer: []u8) anyerror!usize { + fn readv(context: *const anyopaque, iov: []const std.os.iovec) anyerror!usize { const resource: *Resource = @constCast(@ptrCast(@alignCast(context))); switch (resource.*) { - .file => |*f| return f.read(buffer), - .http_request => |*r| return r.read(buffer), - .git => |*g| return g.fetch_stream.read(buffer), + .file => |*f| return f.readv(iov), + .http_request => |*r| return r.readv(iov), + .git => |*g| return g.fetch_stream.reader().readv(iov), .dir => unreachable, } } @@ -1105,14 +1105,14 @@ fn unpackResource( .tar => try unpackTarball(f, tmp_directory.handle, resource.reader()), .@"tar.gz" => { const reader = resource.reader(); - var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader); + var br = std.io.bufferedReaderSize(4096, reader); var dcp = std.compress.gzip.decompressor(br.reader()); try unpackTarball(f, tmp_directory.handle, dcp.reader()); }, .@"tar.xz" => { const gpa = f.arena.child_allocator; const reader = resource.reader(); - var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader); + var br = std.io.bufferedReaderSize(4096, reader); var dcp = std.compress.xz.decompress(gpa, br.reader()) catch |err| { return f.fail(f.location_tok, try eb.printString( "unable to decompress tarball: {s}", @@ -1126,7 +1126,7 @@ fn unpackResource( const window_size = std.compress.zstd.DecompressorOptions.default_window_buffer_len; const window_buffer = try f.arena.allocator().create([window_size]u8); const reader = resource.reader(); - var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader); + var br = std.io.bufferedReaderSize(4096, reader); var dcp = std.compress.zstd.decompressor(br.reader(), .{ .window_buffer = window_buffer, }); diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig index 36652bd88c55..7e97a35c6c65 100644 --- a/src/Package/Fetch/git.zig +++ b/src/Package/Fetch/git.zig @@ -10,6 +10,8 @@ const testing = std.testing; const Allocator = mem.Allocator; const Sha1 = std.crypto.hash.Sha1; const assert = std.debug.assert; +const hashedWriter = std.compress.hashedWriter; +const hashedReader = std.compress.hashedReader; pub const oid_length = Sha1.digest_length; pub const fmt_oid_length = 2 * oid_length; @@ -667,7 +669,7 @@ pub const Session = struct { errdefer request.deinit(); request.transfer_encoding = .{ .content_length = body.items.len }; try request.send(.{}); - try request.writeAll(body.items); + try request.writer().writeAll(body.items); try request.finish(); try request.wait(); @@ -772,7 +774,7 @@ pub const Session = struct { errdefer request.deinit(); request.transfer_encoding = .{ .content_length = body.items.len }; try request.send(.{}); - try request.writeAll(body.items); + try request.writer().writeAll(body.items); try request.finish(); try request.wait(); @@ -819,7 +821,7 @@ pub const Session = struct { ProtocolError, UnexpectedPacket, }; - pub const Reader = std.io.Reader(*FetchStream, ReadError, read); + pub const Reader = std.io.Reader(*FetchStream, ReadError, readv); const StreamCode = enum(u8) { pack_data = 1, @@ -857,6 +859,12 @@ pub const Session = struct { return size; } }; + + pub fn readv(stream: *FetchStream, iov: []const std.os.iovec) !usize { + const first = iov[0]; + const buf = first.iov_base[0..first.iov_len]; + return try stream.read(buf); + } }; const PackHeader = struct { @@ -1113,7 +1121,7 @@ fn indexPackFirstPass( ) ![Sha1.digest_length]u8 { var pack_buffered_reader = std.io.bufferedReader(pack.reader()); var pack_counting_reader = std.io.countingReader(pack_buffered_reader.reader()); - var pack_hashed_reader = std.compress.hashedReader(pack_counting_reader.reader(), Sha1.init(.{})); + var pack_hashed_reader = hashedReader(pack_counting_reader.reader(), Sha1.init(.{})); const pack_reader = pack_hashed_reader.reader(); const pack_header = try PackHeader.read(pack_reader); @@ -1121,7 +1129,7 @@ fn indexPackFirstPass( var current_entry: u32 = 0; while (current_entry < pack_header.total_objects) : (current_entry += 1) { const entry_offset = pack_counting_reader.bytes_read; - var entry_crc32_reader = std.compress.hashedReader(pack_reader, std.hash.Crc32.init()); + var entry_crc32_reader = hashedReader(pack_reader, std.hash.Crc32.init()); const entry_header = try EntryHeader.read(entry_crc32_reader.reader()); switch (entry_header) { .commit, .tree, .blob, .tag => |object| { @@ -1325,36 +1333,6 @@ fn expandDelta(base_object: anytype, delta_reader: anytype, writer: anytype) !vo } } -fn HashedWriter( - comptime WriterType: anytype, - comptime HasherType: anytype, -) type { - return struct { - child_writer: WriterType, - hasher: HasherType, - - const Error = WriterType.Error; - const Writer = std.io.Writer(*@This(), Error, write); - - fn write(hashed_writer: *@This(), buf: []const u8) Error!usize { - const amt = try hashed_writer.child_writer.write(buf); - hashed_writer.hasher.update(buf); - return amt; - } - - fn writer(hashed_writer: *@This()) Writer { - return .{ .context = hashed_writer }; - } - }; -} - -fn hashedWriter( - writer: anytype, - hasher: anytype, -) HashedWriter(@TypeOf(writer), @TypeOf(hasher)) { - return .{ .child_writer = writer, .hasher = hasher }; -} - test "packfile indexing and checkout" { // To verify the contents of this packfile without using the code in this // file: diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 10795529f976..66a37a46e03c 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -7500,12 +7500,12 @@ const ArrayListWriter = ErrorOnlyGenericWriter(std.ArrayList(u8).Writer.Error); fn arrayListWriter(list: *std.ArrayList(u8)) ArrayListWriter { return .{ .context = .{ .context = list, - .writeFn = struct { - fn write(context: *const anyopaque, bytes: []const u8) anyerror!usize { + .writevFn = struct { + fn writev(context: *const anyopaque, iov: []const std.os.iovec_const) anyerror!usize { const l: *std.ArrayList(u8) = @alignCast(@constCast(@ptrCast(context))); - return l.writer().write(bytes); + return l.writer().writev(iov); } - }.write, + }.writev, } }; } @@ -7524,25 +7524,28 @@ fn IndentWriter(comptime UnderlyingWriter: type) type { pub fn writer(self: *Self) Writer { return .{ .context = .{ .context = self, - .writeFn = writeAny, + .writevFn = writevAny, } }; } - pub fn write(self: *Self, bytes: []const u8) Error!usize { - if (bytes.len == 0) return @as(usize, 0); - - const current_indent = self.indent_count * Self.indent_delta; - if (self.current_line_empty and current_indent > 0) { - try self.underlying_writer.writeByteNTimes(' ', current_indent); + pub fn writev(self: *Self, iov: []const std.os.iovec_const) Error!usize { + var written: usize = 0; + for (iov) |v| { + const bytes = v.iov_base[0..v.iov_len]; + const current_indent = self.indent_count * Self.indent_delta; + if (self.current_line_empty and current_indent > 0) { + try self.underlying_writer.writeByteNTimes(' ', current_indent); + } + self.current_line_empty = false; + written += try self.writeNoIndent(bytes); } - self.current_line_empty = false; - return self.writeNoIndent(bytes); + return written; } - fn writeAny(context: *const anyopaque, bytes: []const u8) anyerror!usize { + fn writevAny(context: *const anyopaque, iov: []const std.os.iovec_const) anyerror!usize { const self: *Self = @alignCast(@constCast(@ptrCast(context))); - return self.write(bytes); + return self.writev(iov); } pub fn insertNewline(self: *Self) Error!void { @@ -7576,10 +7579,10 @@ fn IndentWriter(comptime UnderlyingWriter: type) type { /// maintaining ease of error handling. fn ErrorOnlyGenericWriter(comptime Error: type) type { return std.io.GenericWriter(std.io.AnyWriter, Error, struct { - fn write(context: std.io.AnyWriter, bytes: []const u8) Error!usize { - return @errorCast(context.write(bytes)); + fn writev(context: std.io.AnyWriter, iov: []const std.os.iovec_const) Error!usize { + return @errorCast(context.writev(iov)); } - }.write); + }.writev); } fn toCIntBits(zig_bits: u32) ?u32 { diff --git a/src/main.zig b/src/main.zig index db76f7605ca6..04a4efe1e528 100644 --- a/src/main.zig +++ b/src/main.zig @@ -3373,13 +3373,13 @@ fn buildOutputType( }); defer server.deinit(); - const conn = try server.accept(); - defer conn.stream.close(); + var conn = try server.accept(null); + defer conn.stream().close(); try serve( comp, - .{ .handle = conn.stream.handle }, - .{ .handle = conn.stream.handle }, + .{ .handle = conn.socket.handle }, + .{ .handle = conn.socket.handle }, test_exec_args.items, self_exe_path, arg_mode, From 4a5dc020528a7a1f286d1890c78cfe84564820de Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Tue, 19 Mar 2024 16:03:09 -0400 Subject: [PATCH 14/17] fix http client wait --- lib/std/http/Client.zig | 6 +++++- lib/std/http/test.zig | 2 +- lib/std/net/test.zig | 4 ++-- lib/std/os/test.zig | 2 +- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 58e957e00163..91a21940c029 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1011,7 +1011,11 @@ pub const Request = struct { // skip the body of the redirect response, this will at least // leave the connection in a known good state. req.response.skip = true; - assert(try req.transferReadv(&.{}) == 0); // we're skipping, no buffer is necessary + var buf: [0]u8 = undefined; + // we're skipping, no buffer is necessary + const iovecs = [_]std.os.iovec{ .{ .iov_base = &buf, .iov_len = 0 } }; + const n_read = try req.transferReadv(&iovecs); + assert(n_read == 0); if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 88cb343aa912..69022a5e9e86 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -835,7 +835,7 @@ test "general client/server API coverage" { // connection has been kept alive try expect(client.http_proxy != null or client.connection_pool.free_len == 1); - { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** + { const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port}); defer gpa.free(location); const uri = try std.Uri.parse(location); diff --git a/lib/std/net/test.zig b/lib/std/net/test.zig index ccbde4e93627..1f98adc38de3 100644 --- a/lib/std/net/test.zig +++ b/lib/std/net/test.zig @@ -189,7 +189,7 @@ test "listen on a port, send bytes, receive bytes" { const socket = try net.tcpConnectToAddress(server_address); defer socket.close(); - _ = try socket.any().writer().writeAll("Hello world!"); + _ = try socket.stream().writer().writeAll("Hello world!"); } }; @@ -280,7 +280,7 @@ test "listen on a unix socket, send bytes, receive bytes" { const socket = try net.connectUnixSocket(path); defer socket.close(); - _ = try socket.any().writer().writeAll("Hello world!"); + _ = try socket.stream().writer().writeAll("Hello world!"); } }; diff --git a/lib/std/os/test.zig b/lib/std/os/test.zig index d9f05d94ef93..b93a5d6d9439 100644 --- a/lib/std/os/test.zig +++ b/lib/std/os/test.zig @@ -819,7 +819,7 @@ test "shutdown socket" { error.SocketNotConnected => {}, else => |e| return e, }; - std.net.Stream.close(.{ .handle = sock }); + std.net.Socket.close(.{ .handle = sock }); } test "sigaction" { From f3f8a5208d844d17bb0adcdf482af5679d348569 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Tue, 19 Mar 2024 17:09:54 -0400 Subject: [PATCH 15/17] add untested ssl key log file --- lib/std/crypto/tls.zig | 94 ++++++++++++++++++++++++++--------- lib/std/crypto/tls/Client.zig | 60 +++++++++++++--------- lib/std/crypto/tls/Server.zig | 13 ++++- lib/std/crypto/tls/Stream.zig | 2 +- lib/std/http/Client.zig | 4 +- lib/std/http/Server.zig | 2 +- lib/std/http/test.zig | 2 +- lib/std/io.zig | 2 +- 8 files changed, 124 insertions(+), 55 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 1d4291ba6d6a..79846bd43079 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -673,7 +673,13 @@ pub const HandshakeCipher = union(CipherSuite) { const Self = @This(); - pub fn init(suite: CipherSuite, shared_key: []const u8, hello_hash: []const u8) Error!Self { + pub fn init( + suite: CipherSuite, + shared_key: []const u8, + hello_hash: []const u8, + logger: std.io.AnyWriter, + client_random: []const u8, + ) Error!Self { switch (suite) { inline .aes_128_gcm_sha256, .aes_256_gcm_sha384, @@ -683,11 +689,10 @@ pub const HandshakeCipher = union(CipherSuite) { => |tag| { var res = @unionInit(Self, @tagName(tag), .{ .handshake_secret = undefined, - .master_secret = undefined, - .client_finished_key = undefined, - .server_finished_key = undefined, .client_key = undefined, .server_key = undefined, + .client_finished_key = undefined, + .server_finished_key = undefined, .client_iv = undefined, .server_iv = undefined, }); @@ -700,10 +705,13 @@ pub const HandshakeCipher = union(CipherSuite) { const derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); p.handshake_secret = P.Hkdf.extract(&derived_secret, shared_key); - const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", hello_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", hello_hash, P.Hash.digest_length); + + // Not being able to log our secrets shouldn't prevent the handshake from continuing. + writeKeyLogEntry(logger, "CLIENT_HANDSHAKE_TRAFFIC_SECRET", client_random, &client_secret) catch {}; + writeKeyLogEntry(logger, "SERVER_HANDSHAKE_TRAFFIC_SECRET", client_random, &server_secret) catch {}; + p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); @@ -733,7 +741,12 @@ pub const ApplicationCipher = union(CipherSuite) { const Self = @This(); - pub fn init(handshake_cipher: HandshakeCipher, handshake_hash: []const u8) Self { + pub fn init( + handshake_cipher: HandshakeCipher, + handshake_hash: []const u8, + logger: std.io.AnyWriter, + client_random: []const u8, + ) Self { switch (handshake_cipher) { inline .aes_128_gcm_sha256, .aes_256_gcm_sha384, @@ -759,6 +772,11 @@ pub const ApplicationCipher = union(CipherSuite) { const master_secret = P.Hkdf.extract(&derived_secret, &zeroes); p.client_secret = hkdfExpandLabel(P.Hkdf, master_secret, "c ap traffic", handshake_hash, P.Hash.digest_length); p.server_secret = hkdfExpandLabel(P.Hkdf, master_secret, "s ap traffic", handshake_hash, P.Hash.digest_length); + + // Not being able to log our secrets shouldn't prevent the handshake from continuing. + writeKeyLogEntry(logger, "CLIENT_TRAFFIC_SECRET_0", client_random, &p.client_secret) catch {}; + writeKeyLogEntry(logger, "SERVER_TRAFFIC_SECRET_0", client_random, &p.server_secret) catch {}; + p.client_key = hkdfExpandLabel(P.Hkdf, p.client_secret, "key", "", P.AEAD.key_length); p.server_key = hkdfExpandLabel(P.Hkdf, p.server_secret, "key", "", P.AEAD.key_length); p.client_iv = hkdfExpandLabel(P.Hkdf, p.client_secret, "iv", "", P.AEAD.nonce_length); @@ -1032,14 +1050,20 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + // Later used in ApplicationCipher.init handshake_secret: [Hkdf.prk_length]u8, - master_secret: [Hkdf.prk_length]u8, - client_finished_key: [Hmac.key_length]u8, - server_finished_key: [Hmac.key_length]u8, + // For encrypting/decrypting handshake messages client_key: [AEAD.key_length]u8, server_key: [AEAD.key_length]u8, + // For generating handshake finished messages + client_finished_key: [Hmac.key_length]u8, + server_finished_key: [Hmac.key_length]u8, + // Used as a nonce for encrypting/decrypting handshake messages + // iv = initialization vector client_iv: [AEAD.nonce_length]u8, server_iv: [AEAD.nonce_length]u8, + + // m0aR s3cUr1tY! read_seq: usize = 0, write_seq: usize = 0, @@ -1089,12 +1113,18 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + // Used to derive new keys and iv's in key_update messages client_secret: [Hash.digest_length]u8, server_secret: [Hash.digest_length]u8, + // For encrypting/decrypting application data messages client_key: [AEAD.key_length]u8, server_key: [AEAD.key_length]u8, + // Used as a nonce for encrypting/decrypting application data messages + // iv = initialization vector client_iv: [AEAD.nonce_length]u8, server_iv: [AEAD.nonce_length]u8, + + // m0aR s3cUr1tY! read_seq: usize = 0, write_seq: usize = 0, @@ -1307,14 +1337,29 @@ test "tls client and server handshake, data, and close_notify" { defer inner_stream.deinit(allocator); const stream = inner_stream.stream(); + // Use these seeded values for reproducible handshake and application ciphertext. + const session_id: [32]u8 = ("session_id012345" ** 2).*; + const client_random: [32]u8 = ("client_random012" ** 2).*; + const server_random: [32]u8 = ("server_random012" ** 2).*; + const client_key_seed: [32]u8 = ("client_seed01234" ** 2).*; + const server_keygen_seed: [48]u8 = ("server_seed01234" ** 3).*; + const server_sig_salt: [MultiHash.max_digest_len]u8 = ("server_sig_salt0" ** 4).*; + + const stdout = std.io.getStdOut(); var client_transcript: MultiHash = .{}; var client = Client{ + .random = client_random, .stream = Stream{ .stream = stream.any(), .is_client = true, .transcript_hash = &client_transcript, }, - .options = .{ .host = "localhost", .ca_bundle = null, .allocator = allocator }, + .options = .{ + .host = "localhost", + .ca_bundle = null, + .allocator = allocator, + .key_log = stdout.writer().any(), + }, }; const server_cert = @embedFile("./testdata/cert.der"); @@ -1337,16 +1382,7 @@ test "tls client and server handshake, data, and close_notify" { }, }; - // Use these seeded values for reproducible handshake and application ciphertext. - const session_id: [32]u8 = ("session_id012345" ** 2).*; - const client_random: [32]u8 = ("client_random012" ** 2).*; - const server_random: [32]u8 = ("server_random012" ** 2).*; - const client_key_seed: [32]u8 = ("client_seed01234" ** 2).*; - const server_keygen_seed: [48]u8 = ("server_seed01234" ** 3).*; - const server_sig_salt: [MultiHash.max_digest_len]u8 = ("server_sig_salt0" ** 4).*; - const key_pairs = try Client.KeyPairs.initAdvanced( - client_random, session_id, client_key_seed, client_key_seed ++ [_]u8{0} ** (48 - 32), @@ -1372,8 +1408,6 @@ test "tls client and server handshake, data, and close_notify" { const s = server.stream.cipher.handshake.aes_256_gcm_sha384; const c = client.stream.cipher.handshake.aes_256_gcm_sha384; - try std.testing.expectEqualSlices(u8, &s.handshake_secret, &c.handshake_secret); - try std.testing.expectEqualSlices(u8, &s.master_secret, &c.master_secret); try std.testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); try std.testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); @@ -1413,8 +1447,6 @@ test "tls client and server handshake, data, and close_notify" { const s = server.stream.cipher.application.aes_256_gcm_sha384; const c = client.stream.cipher.application.aes_256_gcm_sha384; - try std.testing.expectEqualSlices(u8, &s.client_secret, &c.client_secret); - try std.testing.expectEqualSlices(u8, &s.server_secret, &c.server_secret); try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); @@ -1443,3 +1475,17 @@ pub fn debugPrint(name: []const u8, slice: anytype) void { } std.debug.print("\n", .{}); } + +pub fn writeKeyLogEntry( + writer: std.io.AnyWriter, + label: []const u8, + client_random: []const u8, + secret: []const u8, +) !void { + try writer.writeAll(label); + try writer.writeByte(' '); + for (client_random) |b| writer.print("{x:0>2}", .{b}) catch {}; + try writer.writeByte(' '); + for (secret) |b| writer.print("{x:0>2}", .{b}) catch {}; + try writer.writeByte('\n'); +} diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 0a15f1212a57..35c641b32857 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -6,6 +6,13 @@ const assert = std.debug.assert; const Certificate = crypto.Certificate; stream: tls.Stream, +/// The value sent in our `ClientHello` message. +/// +/// Used as a session identifier by `options.key_log`. +/// Since the server may renegotiate (without a new random) +/// after the initial handshake in a `key_update` message, +/// save it here instead of in `Command`. +random: [32]u8 = undefined, options: Options, const Self = @This(); @@ -18,7 +25,10 @@ pub fn init(stream: std.io.AnyStream, options: Options) !Self { .is_client = true, .transcript_hash = &transcript_hash, }; - var res = Self{ .stream = stream_, .options = options }; + var random: [32]u8 = undefined; + crypto.random.bytes(&random); + + var res = Self{ .stream = stream_, .random = random, .options = options }; var command = Command{ .send_hello = KeyPairs.init() }; while (command != .none) command = try res.next(command); @@ -100,7 +110,7 @@ pub fn next(self: *Self, command: Command) !Command { pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { const hello = tls.ClientHello{ - .random = key_pairs.hello_rand, + .random = self.random, .session_id = &key_pairs.session_id, .cipher_suites = self.options.cipher_suites, .extensions = &.{ @@ -213,6 +223,8 @@ pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { cipher_suite, shared_key.?, hello_hash, + self.options.key_log, + &self.random, ) catch return stream.writeError(.illegal_parameter); stream.cipher = .{ .handshake = handshake_cipher }; } @@ -404,7 +416,12 @@ pub fn send_finished(self: *Self) !void { _ = try stream.write(tls.Handshake, .{ .finished = verify_data }); try stream.flush(); - const application_cipher = tls.ApplicationCipher.init(stream.cipher.handshake, handshake_hash); + const application_cipher = tls.ApplicationCipher.init( + stream.cipher.handshake, + handshake_hash, + self.options.key_log, + &self.random, + ); stream.cipher = .{ .application = application_cipher }; stream.content_type = .application_data; stream.transcript_hash = null; @@ -432,10 +449,9 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { switch (stream.cipher.application) { inline else => |*p| { const P = @TypeOf(p.*); - const server_secret = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); - p.server_secret = server_secret; - p.server_key = tls.hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.server_iv = tls.hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + p.server_secret = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); + p.server_key = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "key", "", P.AEAD.key_length); + p.server_iv = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "iv", "", P.AEAD.nonce_length); p.read_seq = 0; }, } @@ -444,10 +460,9 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { switch (stream.cipher.application) { inline else => |*p| { const P = @TypeOf(p.*); - const client_secret = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); - p.client_secret = client_secret; - p.client_key = tls.hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.client_iv = tls.hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + p.client_secret = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); + p.client_key = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "key", "", P.AEAD.key_length); + p.client_iv = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "iv", "", P.AEAD.nonce_length); p.write_seq = 0; }, } @@ -502,17 +517,19 @@ pub const Options = struct { /// Certificate verify messages may be up to 2^16-1 bytes long. /// This is the allocator used for them. allocator: std.mem.Allocator, + /// Writer to log shared secrets for traffic decryption. + /// + /// See https://www.ietf.org/archive/id/draft-thomson-tls-keylogfile-01.html + key_log: std.io.AnyWriter = std.io.null_writer.any(), }; /// One of these potential key pairs will be selected during the handshake. pub const KeyPairs = struct { - hello_rand: [hello_rand_length]u8, session_id: [session_id_length]u8, secp256r1: Secp256r1, secp384r1: Secp384r1, x25519: X25519, - const hello_rand_length = 32; const session_id_length = 32; const X25519 = tls.NamedGroupT(.x25519).KeyPair; const Secp256r1 = tls.NamedGroupT(.secp256r1).KeyPair; @@ -520,8 +537,7 @@ pub const KeyPairs = struct { pub fn init() @This() { var random_buffer: [ - hello_rand_length + - session_id_length + + session_id_length + Secp256r1.seed_length + Secp384r1.seed_length + X25519.seed_length @@ -530,29 +546,27 @@ pub const KeyPairs = struct { while (true) { crypto.random.bytes(&random_buffer); - const split1 = hello_rand_length; - const split2 = split1 + session_id_length; - const split3 = split2 + Secp256r1.seed_length; - const split4 = split3 + Secp384r1.seed_length; + const split1 = session_id_length; + const split2 = split1 + Secp256r1.seed_length; + const split3 = split2 + Secp384r1.seed_length; return initAdvanced( random_buffer[0..split1].*, random_buffer[split1..split2].*, random_buffer[split2..split3].*, - random_buffer[split3..split4].*, - random_buffer[split4..].*, + random_buffer[split3..].*, ) catch continue; } } pub fn initAdvanced( - hello_rand: [hello_rand_length]u8, session_id: [session_id_length]u8, secp256r1_seed: [Secp256r1.seed_length]u8, secp384r1_seed: [Secp384r1.seed_length]u8, x25519_seed: [X25519.seed_length]u8, ) !@This() { return .{ + .session_id = session_id, .secp256r1 = Secp256r1.create(secp256r1_seed) catch |err| switch (err) { error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. }, @@ -562,8 +576,6 @@ pub const KeyPairs = struct { .x25519 = X25519.create(x25519_seed) catch |err| switch (err) { error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. }, - .hello_rand = hello_rand, - .session_id = session_id, }; } }; diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index 759f9db7c9d3..537a961a3df0 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -275,6 +275,8 @@ pub fn send_hello(self: *Self, client_hello: ClientHello) !void { client_hello.cipher_suite, shared_key, hello_hash, + self.options.key_log, + &client_hello.random, ) catch return stream.writeError(.illegal_parameter); stream.cipher = .{ .handshake = handshake_cipher }; @@ -384,7 +386,12 @@ pub fn recv_finished(self: *Self) !void { const handshake_hash = stream.transcript_hash.?.peek(); - const application_cipher = tls.ApplicationCipher.init(stream.cipher.handshake, handshake_hash); + const application_cipher = tls.ApplicationCipher.init( + stream.cipher.handshake, + handshake_hash, + self.options.key_log, + "idk", + ); const expected = switch (stream.cipher.handshake) { inline else => |p| brk: { @@ -453,6 +460,10 @@ pub const Options = struct { certificate: tls.Certificate, /// Key to use in `send_certificate_verify`. Must match `certificate.parse().pub_key_algo`. certificate_key: CertificateKey, + /// Writer to log shared secrets for traffic decryption. + /// + /// See https://www.ietf.org/archive/id/draft-thomson-tls-keylogfile-01.html + key_log: std.io.AnyWriter = std.io.null_writer.any(), pub const CertificateKey = union(enum) { rsa: crypto.Certificate.rsa.SecretKey, diff --git a/lib/std/crypto/tls/Stream.zig b/lib/std/crypto/tls/Stream.zig index ad46d253977c..de30a714d4ba 100644 --- a/lib/std/crypto/tls/Stream.zig +++ b/lib/std/crypto/tls/Stream.zig @@ -1,6 +1,6 @@ //! Abstraction over TLS record layer (RFC 8446 S5). //! -//! After writing must call `flush` before reading or contents will not be written. +//! After writing must `flush` before reading. //! //! Handles: //! * Fragmentation diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 91a21940c029..971c8d1ec742 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1012,8 +1012,8 @@ pub const Request = struct { // leave the connection in a known good state. req.response.skip = true; var buf: [0]u8 = undefined; - // we're skipping, no buffer is necessary - const iovecs = [_]std.os.iovec{ .{ .iov_base = &buf, .iov_len = 0 } }; + // we're skipping, no buffer is necessary + const iovecs = [_]std.os.iovec{.{ .iov_base = &buf, .iov_len = 0 }}; const n_read = try req.transferReadv(&iovecs); assert(n_read == 0); diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 0deb5a6d3370..8fb150e3adf0 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -989,7 +989,7 @@ pub const Response = struct { }; // TODO make this writev instead of writevAll, which involves // complicating the logic of this function. - try r.connection.stream().writer().writevAll(&iovecs); + try r.connection.stream().writer().writevAll(&iovecs); r.send_buffer_start = 0; r.send_buffer_end = 0; r.chunk_len = 0; diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 69022a5e9e86..917c08c63e51 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -274,7 +274,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { const stream = socket.stream(); defer stream.close(); try stream.writer().writeAll(request_bytes); - std.debug.print("requested\n", .{}); + std.debug.print("requested\n", .{}); const response = try stream.reader().readAllAlloc(gpa, 8192); defer gpa.free(response); diff --git a/lib/std/io.zig b/lib/std/io.zig index cd86a3c5ee69..9225f40f3ba3 100644 --- a/lib/std/io.zig +++ b/lib/std/io.zig @@ -86,7 +86,7 @@ pub fn GenericReader( context: Context, pub const Error = ReadError; - pub const NoEofError = ReadError || error{ EndOfStream }; + pub const NoEofError = ReadError || error{EndOfStream}; pub inline fn readv(self: Self, iov: []const iovec) Error!usize { return readvFn(self.context, iov); From eff31f2794801445f9731f1ff7bb3d7506445341 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Tue, 19 Mar 2024 19:02:40 -0400 Subject: [PATCH 16/17] remove unused HandshakeCipherT member --- lib/std/crypto/tls.zig | 136 ++++++++++++++++------------------ lib/std/crypto/tls/Client.zig | 11 +++ lib/std/crypto/tls/Stream.zig | 3 +- lib/std/http/Client.zig | 16 +++- 4 files changed, 90 insertions(+), 76 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 79846bd43079..868cee575577 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -687,49 +687,13 @@ pub const HandshakeCipher = union(CipherSuite) { .aegis_256_sha512, .aegis_128l_sha256, => |tag| { - var res = @unionInit(Self, @tagName(tag), .{ - .handshake_secret = undefined, - .client_key = undefined, - .server_key = undefined, - .client_finished_key = undefined, - .server_finished_key = undefined, - .client_iv = undefined, - .server_iv = undefined, - }); - const P = std.meta.TagPayloadByName(Self, @tagName(tag)); - const p = &@field(res, @tagName(tag)); - - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = emptyHash(P.Hash); - - const derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); - p.handshake_secret = P.Hkdf.extract(&derived_secret, shared_key); - const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", hello_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", hello_hash, P.Hash.digest_length); - - // Not being able to log our secrets shouldn't prevent the handshake from continuing. - writeKeyLogEntry(logger, "CLIENT_HANDSHAKE_TRAFFIC_SECRET", client_random, &client_secret) catch {}; - writeKeyLogEntry(logger, "SERVER_HANDSHAKE_TRAFFIC_SECRET", client_random, &server_secret) catch {}; - - p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); - p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); - p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - - return res; + const T = std.meta.TagPayloadByName(Self, @tagName(tag)); + const cipher = T.init(shared_key, hello_hash, logger, client_random); + return @unionInit(Self, @tagName(tag), cipher); }, _ => return Error.TlsIllegalParameter, } } - - pub fn print(self: Self) void { - switch (self) { - inline else => |v| v.print(), - } - } }; pub const ApplicationCipher = union(CipherSuite) { @@ -754,35 +718,9 @@ pub const ApplicationCipher = union(CipherSuite) { .aegis_256_sha512, .aegis_128l_sha256, => |c, tag| { - var res = @unionInit(Self, @tagName(tag), .{ - .client_secret = undefined, - .server_secret = undefined, - .client_key = undefined, - .server_key = undefined, - .client_iv = undefined, - .server_iv = undefined, - }); - const P = std.meta.TagPayloadByName(Self, @tagName(tag)); - const p = &@field(res, @tagName(tag)); - - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const empty_hash = emptyHash(P.Hash); - - const derived_secret = hkdfExpandLabel(P.Hkdf, c.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - const master_secret = P.Hkdf.extract(&derived_secret, &zeroes); - p.client_secret = hkdfExpandLabel(P.Hkdf, master_secret, "c ap traffic", handshake_hash, P.Hash.digest_length); - p.server_secret = hkdfExpandLabel(P.Hkdf, master_secret, "s ap traffic", handshake_hash, P.Hash.digest_length); - - // Not being able to log our secrets shouldn't prevent the handshake from continuing. - writeKeyLogEntry(logger, "CLIENT_TRAFFIC_SECRET_0", client_random, &p.client_secret) catch {}; - writeKeyLogEntry(logger, "SERVER_TRAFFIC_SECRET_0", client_random, &p.server_secret) catch {}; - - p.client_key = hkdfExpandLabel(P.Hkdf, p.client_secret, "key", "", P.AEAD.key_length); - p.server_key = hkdfExpandLabel(P.Hkdf, p.server_secret, "key", "", P.AEAD.key_length); - p.client_iv = hkdfExpandLabel(P.Hkdf, p.client_secret, "iv", "", P.AEAD.nonce_length); - p.server_iv = hkdfExpandLabel(P.Hkdf, p.server_secret, "iv", "", P.AEAD.nonce_length); - - return res; + const T = std.meta.TagPayloadByName(Self, @tagName(tag)); + const cipher = T.init(c.handshake_secret, handshake_hash, logger, client_random); + return @unionInit(Self, @tagName(tag), cipher); }, } } @@ -1069,6 +1007,36 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { const Self = @This(); + pub fn init( + shared_key: []const u8, + hello_hash: []const u8, + logger: std.io.AnyWriter, + client_random: []const u8, + ) Self { + const zeroes = [1]u8{0} ** Hash.digest_length; + const early = Hkdf.extract(&[1]u8{0}, &zeroes); + const empty = emptyHash(Hash); + + const derived = hkdfExpandLabel(Hkdf, early, "derived", &empty, Hash.digest_length); + const handshake = Hkdf.extract(&derived, shared_key); + const client = hkdfExpandLabel(Hkdf, handshake, "c hs traffic", hello_hash, Hash.digest_length); + const server = hkdfExpandLabel(Hkdf, handshake, "s hs traffic", hello_hash, Hash.digest_length); + + // Not being able to log our secrets shouldn't prevent the handshake from continuing. + writeKeyLogEntry(logger, "CLIENT_HANDSHAKE_TRAFFIC_SECRET", client_random, &client) catch {}; + writeKeyLogEntry(logger, "SERVER_HANDSHAKE_TRAFFIC_SECRET", client_random, &server) catch {}; + + return .{ + .handshake_secret = handshake, + .client_finished_key = hkdfExpandLabel(Hkdf, client, "finished", "", Hmac.key_length), + .server_finished_key = hkdfExpandLabel(Hkdf, server, "finished", "", Hmac.key_length), + .client_key = hkdfExpandLabel(Hkdf, client, "key", "", AEAD.key_length), + .server_key = hkdfExpandLabel(Hkdf, server, "key", "", AEAD.key_length), + .client_iv = hkdfExpandLabel(Hkdf, client, "iv", "", AEAD.nonce_length), + .server_iv = hkdfExpandLabel(Hkdf, server, "iv", "", AEAD.nonce_length), + }; + } + pub fn encrypt( self: *Self, data: []const u8, @@ -1099,10 +1067,6 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { AEAD.decrypt(out, data, tag, additional, nonce, key) catch return Error.TlsBadRecordMac; self.read_seq += 1; } - - pub fn print(self: Self) void { - inline for (std.meta.fields(Self)) |f| debugPrint(f.name, @field(self, f.name)); - } }; } @@ -1130,6 +1094,34 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { const Self = @This(); + pub fn init( + handshake_secret: [Hkdf.prk_length]u8, + handshake_hash: []const u8, + logger: std.io.AnyWriter, + client_random: []const u8, + ) Self { + const zeroes = [1]u8{0} ** Hash.digest_length; + const empty_hash = emptyHash(Hash); + + const derived = hkdfExpandLabel(Hkdf, handshake_secret, "derived", &empty_hash, Hash.digest_length); + const master = Hkdf.extract(&derived, &zeroes); + const client = hkdfExpandLabel(Hkdf, master, "c ap traffic", handshake_hash, Hash.digest_length); + const server = hkdfExpandLabel(Hkdf, master, "s ap traffic", handshake_hash, Hash.digest_length); + + // Not being able to log our secrets shouldn't prevent the handshake from continuing. + writeKeyLogEntry(logger, "CLIENT_TRAFFIC_SECRET_0", client_random, &client) catch {}; + writeKeyLogEntry(logger, "SERVER_TRAFFIC_SECRET_0", client_random, &server) catch {}; + + return .{ + .client_secret = client, + .server_secret = server, + .client_key = hkdfExpandLabel(Hkdf, client, "key", "", AEAD.key_length), + .server_key = hkdfExpandLabel(Hkdf, server, "key", "", AEAD.key_length), + .client_iv = hkdfExpandLabel(Hkdf, client, "iv", "", AEAD.nonce_length), + .server_iv = hkdfExpandLabel(Hkdf, server, "iv", "", AEAD.nonce_length), + }; + } + pub fn encrypt( self: *Self, data: []const u8, diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 35c641b32857..c017d0508ca3 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -14,6 +14,9 @@ stream: tls.Stream, /// save it here instead of in `Command`. random: [32]u8 = undefined, options: Options, +// For logging after `key_update` messages +server_update_n: usize = 0, +client_update_n: usize = 0, const Self = @This(); @@ -453,6 +456,10 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { p.server_key = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "key", "", P.AEAD.key_length); p.server_iv = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "iv", "", P.AEAD.nonce_length); p.read_seq = 0; + + self.server_update_n += 1; + self.options.key_log.print("SERVER_TRAFFIC_SECRET_{d}", .{self.server_update_n}) catch {}; + tls.writeKeyLogEntry(self.options.key_log, "", &self.random, &p.server_secret) catch {}; }, } const update = try stream.read(tls.KeyUpdate); @@ -464,6 +471,10 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { p.client_key = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "key", "", P.AEAD.key_length); p.client_iv = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "iv", "", P.AEAD.nonce_length); p.write_seq = 0; + + self.client_update_n += 1; + self.options.key_log.print("CLIENT_TRAFFIC_SECRET_{d}", .{self.client_update_n}) catch {}; + tls.writeKeyLogEntry(self.options.key_log, "", &self.random, &p.client_secret) catch {}; }, } } diff --git a/lib/std/crypto/tls/Stream.zig b/lib/std/crypto/tls/Stream.zig index de30a714d4ba..f513fd414c14 100644 --- a/lib/std/crypto/tls/Stream.zig +++ b/lib/std/crypto/tls/Stream.zig @@ -157,7 +157,8 @@ pub fn writeError(self: *Self, err: Alert.Description) tls.Error { self.flush() catch {}; self.close(); - return err.toError(); + @panic("ohnooo"); + // return err.toError(); } pub fn close(self: *Self) void { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 971c8d1ec742..95f3302be420 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1308,7 +1308,17 @@ pub const basic_authorization = struct { } }; -pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed }; +pub const ConnectTcpError = Allocator.Error || tls.Client.ReadError || tls.Client.WriteError || error{ + ConnectionRefused, + NetworkUnreachable, + ConnectionTimedOut, + ConnectionResetByPeer, + TemporaryNameServerFailure, + NameServerFailure, + UnknownHostName, + HostLacksNetworkAddresses, + UnexpectedConnectFailure, +}; /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open. /// @@ -1348,7 +1358,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec errdefer client.allocator.free(conn.data.host); if (protocol == .tls) { - conn.data.tls = tls.Client.init(conn.data.socket.stream().any(), .{ + conn.data.tls = try tls.Client.init(conn.data.socket.stream().any(), .{ .ca_bundle = client.tls_options.ca_bundle, .cipher_suites = client.tls_options.cipher_suites, .host = host, @@ -1356,7 +1366,7 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec // the content length which is used to detect truncation attacks. .allow_truncation_attacks = true, .allocator = client.allocator, - }) catch return error.TlsInitializationFailed; + }); } client.connection_pool.addUsed(conn); From 551d7142d7ae683b05bd7c84934e444924a0e8f9 Mon Sep 17 00:00:00 2001 From: clickingbuttons Date: Wed, 20 Mar 2024 18:59:26 -0400 Subject: [PATCH 17/17] handshake struct types for key logging tls.Client and tls.Server --- lib/std/crypto/Certificate.zig | 4 +- lib/std/crypto/tls.zig | 267 ++++---- lib/std/crypto/tls/Client.zig | 1103 +++++++++++++++++--------------- lib/std/crypto/tls/Server.zig | 935 ++++++++++++++------------- lib/std/crypto/tls/Stream.zig | 49 +- 5 files changed, 1227 insertions(+), 1131 deletions(-) diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 80aa355fce55..bf2c74da1723 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -1260,7 +1260,7 @@ pub const rsa = struct { return try fromBytes(modulus, exponent); } - fn encrypt(self: PublicKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { + pub fn encrypt(self: PublicKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { const m = Fe.fromBytes(self.n, &msg, .big) catch return error.MessageTooLong; const e = self.n.powPublic(m, self.e) catch unreachable; var res: [modulus_len]u8 = undefined; @@ -1308,7 +1308,7 @@ pub const rsa = struct { return try fromBytes(modulus, pub_exponent, priv_exponent); } - fn decrypt(self: SecretKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { + pub fn decrypt(self: SecretKey, comptime modulus_len: usize, msg: [modulus_len]u8) ![modulus_len]u8 { const m = Fe.fromBytes(self.public.n, &msg, .big) catch return error.MessageTooLong; const e = self.public.n.pow(m, self.d) catch unreachable; var res: [modulus_len]u8 = undefined; diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 868cee575577..ca14eaf97db8 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -9,9 +9,11 @@ const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; +const testing = std.testing; const native_endian = builtin.cpu.arch.endian(); pub const ServerOptions = Server.Options; pub const ClientOptions = Client.Options; +const Allocator = std.mem.Allocator; pub const Version = enum(u16) { tls_1_0 = 0x0301, @@ -167,7 +169,7 @@ pub const KeyUpdate = enum(u8) { /// A DER encoded certificate chain with the first entry being for this domain. pub const Certificate = struct { context: []const u8 = "", - entries: []const Entry, + entries: []const Entry = &.{}, pub const max_context_len = 255; @@ -590,7 +592,7 @@ pub const KeyShare = union(NamedGroup) { pub fn read(stream: *Stream) !Self { std.debug.assert(!stream.is_client); - var reader = stream.any().reader(); + var reader = stream.stream().reader(); const group = try stream.read(NamedGroup); const len = try stream.read(u16); switch (group) { @@ -677,8 +679,7 @@ pub const HandshakeCipher = union(CipherSuite) { suite: CipherSuite, shared_key: []const u8, hello_hash: []const u8, - logger: std.io.AnyWriter, - client_random: []const u8, + logger: KeyLogger, ) Error!Self { switch (suite) { inline .aes_128_gcm_sha256, @@ -688,7 +689,7 @@ pub const HandshakeCipher = union(CipherSuite) { .aegis_128l_sha256, => |tag| { const T = std.meta.TagPayloadByName(Self, @tagName(tag)); - const cipher = T.init(shared_key, hello_hash, logger, client_random); + const cipher = T.init(shared_key, hello_hash, logger); return @unionInit(Self, @tagName(tag), cipher); }, _ => return Error.TlsIllegalParameter, @@ -708,8 +709,7 @@ pub const ApplicationCipher = union(CipherSuite) { pub fn init( handshake_cipher: HandshakeCipher, handshake_hash: []const u8, - logger: std.io.AnyWriter, - client_random: []const u8, + logger: KeyLogger, ) Self { switch (handshake_cipher) { inline .aes_128_gcm_sha256, @@ -719,7 +719,7 @@ pub const ApplicationCipher = union(CipherSuite) { .aegis_128l_sha256, => |c, tag| { const T = std.meta.TagPayloadByName(Self, @tagName(tag)); - const cipher = T.init(c.handshake_secret, handshake_hash, logger, client_random); + const cipher = T.init(c.handshake_secret, handshake_hash, logger); return @unionInit(Self, @tagName(tag), cipher); }, } @@ -1010,8 +1010,7 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { pub fn init( shared_key: []const u8, hello_hash: []const u8, - logger: std.io.AnyWriter, - client_random: []const u8, + logger: KeyLogger, ) Self { const zeroes = [1]u8{0} ** Hash.digest_length; const early = Hkdf.extract(&[1]u8{0}, &zeroes); @@ -1023,8 +1022,8 @@ fn HandshakeCipherT(comptime suite: CipherSuite) type { const server = hkdfExpandLabel(Hkdf, handshake, "s hs traffic", hello_hash, Hash.digest_length); // Not being able to log our secrets shouldn't prevent the handshake from continuing. - writeKeyLogEntry(logger, "CLIENT_HANDSHAKE_TRAFFIC_SECRET", client_random, &client) catch {}; - writeKeyLogEntry(logger, "SERVER_HANDSHAKE_TRAFFIC_SECRET", client_random, &server) catch {}; + logger.writeLine("CLIENT_HANDSHAKE_TRAFFIC_SECRET", &client) catch {}; + logger.writeLine("SERVER_HANDSHAKE_TRAFFIC_SECRET", &server) catch {}; return .{ .handshake_secret = handshake, @@ -1097,8 +1096,7 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { pub fn init( handshake_secret: [Hkdf.prk_length]u8, handshake_hash: []const u8, - logger: std.io.AnyWriter, - client_random: []const u8, + logger: KeyLogger, ) Self { const zeroes = [1]u8{0} ** Hash.digest_length; const empty_hash = emptyHash(Hash); @@ -1109,8 +1107,8 @@ fn ApplicationCipherT(comptime suite: CipherSuite) type { const server = hkdfExpandLabel(Hkdf, master, "s ap traffic", handshake_hash, Hash.digest_length); // Not being able to log our secrets shouldn't prevent the handshake from continuing. - writeKeyLogEntry(logger, "CLIENT_TRAFFIC_SECRET_0", client_random, &client) catch {}; - writeKeyLogEntry(logger, "SERVER_TRAFFIC_SECRET_0", client_random, &server) catch {}; + logger.writeLine("CLIENT_TRAFFIC_SECRET_0", &client) catch {}; + logger.writeLine("SERVER_TRAFFIC_SECRET_0", &server) catch {}; return .{ .client_secret = client, @@ -1322,48 +1320,44 @@ const TestStream = struct { } }; -test "tls client and server handshake, data, and close_notify" { - const allocator = std.testing.allocator; - - var inner_stream = try TestStream.init(allocator); - defer inner_stream.deinit(allocator); - const stream = inner_stream.stream(); - - // Use these seeded values for reproducible handshake and application ciphertext. - const session_id: [32]u8 = ("session_id012345" ** 2).*; +fn seededClientHandshake(allocator: Allocator, stream: std.io.AnyStream) !Client.Handshake { const client_random: [32]u8 = ("client_random012" ** 2).*; - const server_random: [32]u8 = ("server_random012" ** 2).*; - const client_key_seed: [32]u8 = ("client_seed01234" ** 2).*; - const server_keygen_seed: [48]u8 = ("server_seed01234" ** 3).*; - const server_sig_salt: [MultiHash.max_digest_len]u8 = ("server_sig_salt0" ** 4).*; + const session_id: [32]u8 = ("session_id012345" ** 2).*; + const client_key_seed: [16]u8 = "client_seed01234".*; + const key_pairs = try Client.Handshake.KeyPairs.initAdvanced( + client_key_seed ** 2, + client_key_seed ** 3, + client_key_seed ** 2, + ); - const stdout = std.io.getStdOut(); - var client_transcript: MultiHash = .{}; - var client = Client{ - .random = client_random, - .stream = Stream{ - .stream = stream.any(), - .is_client = true, - .transcript_hash = &client_transcript, - }, + return Client.Handshake{ + .tls_stream = .{ .inner_stream = stream, .is_client = true }, .options = .{ .host = "localhost", .ca_bundle = null, .allocator = allocator, - .key_log = stdout.writer().any(), }, + .client_random = client_random, + .session_id = session_id, + .key_pairs = key_pairs, }; +} + +fn seededServerHandshake(stream: std.io.AnyStream) !Server.Handshake { + const server_random: [32]u8 = ("server_random012" ** 2).*; + const server_keygen_seed: [48]u8 = ("server_seed01234" ** 3).*; + const server_sig_salt: [MultiHash.max_digest_len]u8 = ("server_sig_salt0" ** 4).*; const server_cert = @embedFile("./testdata/cert.der"); const server_key = @embedFile("./testdata/key.der"); const server_rsa = try crypto.Certificate.rsa.SecretKey.fromDer(server_key); - var server_transcript: MultiHash = .{}; - var server = Server{ - .stream = Stream{ - .stream = stream.any(), - .is_client = false, - .transcript_hash = &server_transcript, - }, + + // For debugging + const stdout = std.io.getStdOut(); + const key_log = stdout.writer().any(); + + return Server.Handshake{ + .tls_stream = .{ .inner_stream = stream, .is_client = false }, .options = .{ .cipher_suites = &[_]CipherSuite{.aes_256_gcm_sha384}, .key_shares = &[_]NamedGroup{.x25519}, @@ -1371,91 +1365,97 @@ test "tls client and server handshake, data, and close_notify" { .{ .data = server_cert }, } }, .certificate_key = .{ .rsa = server_rsa }, + .key_log = key_log, }, + .server_random = server_random, + .keygen_seed = server_keygen_seed, + .certificate_verify_salt = server_sig_salt, }; +} - const key_pairs = try Client.KeyPairs.initAdvanced( - session_id, - client_key_seed, - client_key_seed ++ [_]u8{0} ** (48 - 32), - client_key_seed, - ); - var client_command = Client.Command{ .send_hello = key_pairs }; - client_command = try client.next(client_command); - try std.testing.expect(client_command == .recv_hello); +test "tls client and server handshake, data, and close_notify" { + const allocator = testing.allocator; - var server_command = Server.Command{ .recv_hello = .{ - .server_random = server_random, - .keygen_seed = server_keygen_seed, - } }; - server_command = try server.next(server_command); // recv_hello - try std.testing.expect(server_command == .send_hello); + var inner_stream = try TestStream.init(allocator); + defer inner_stream.deinit(allocator); + const stream = inner_stream.stream(); + + var c_hs = try seededClientHandshake(allocator, stream.any()); + var s_hs = try seededServerHandshake(stream.any()); + + try c_hs.next(); + try testing.expectEqual(.recv_hello, c_hs.command); + + try s_hs.next(); // recv_hello + try testing.expectEqual(.send_hello, s_hs.command); - server_command = try server.next(server_command); // send_hello - try std.testing.expect(server_command == .send_change_cipher_spec); + try s_hs.next(); // send_hello + try testing.expectEqual(.send_change_cipher_spec, s_hs.command); - client_command = try client.next(client_command); // recv_hello - try std.testing.expect(client_command == .recv_encrypted_extensions); + try c_hs.next(); // recv_hello + try testing.expectEqual(.recv_encrypted_extensions, c_hs.command); { - const s = server.stream.cipher.handshake.aes_256_gcm_sha384; - const c = client.stream.cipher.handshake.aes_256_gcm_sha384; - - try std.testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); - try std.testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); - try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); - try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); - try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); - try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + const s = s_hs.tls_stream.cipher.handshake.aes_256_gcm_sha384; + const c = c_hs.tls_stream.cipher.handshake.aes_256_gcm_sha384; + + try testing.expectEqualSlices(u8, &s.server_finished_key, &c.server_finished_key); + try testing.expectEqualSlices(u8, &s.client_finished_key, &c.client_finished_key); + try testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + try testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + try testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); + try testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); } - server_command = try server.next(server_command); // send_change_cipher_spec - try std.testing.expect(server_command == .send_encrypted_extensions); - server_command = try server.next(server_command); // send_encrypted_extensions - try std.testing.expect(server_command == .send_certificate); - server_command = try server.next(server_command); // send_certificate - try std.testing.expect(server_command == .send_certificate_verify); - server_command.send_certificate_verify.salt = server_sig_salt; - server_command = try server.next(server_command); // send_certificate_verify - try std.testing.expect(server_command == .send_finished); - server_command = try server.next(server_command); // send_finished - try std.testing.expect(server_command == .recv_finished); - - client_command = try client.next(client_command); // recv_encrypted_extensions - try std.testing.expect(client_command == .recv_certificate_or_finished); - client_command = try client.next(client_command); // recv_certificate_or_finished (certificate) - try std.testing.expect(client_command == .recv_certificate_verify); - client_command = try client.next(client_command); // recv_certificate_verify - try std.testing.expect(client_command == .recv_finished); - client_command = try client.next(client_command); // recv_finished - try std.testing.expect(client_command == .send_change_cipher_spec); - client_command = try client.next(client_command); // send_change_cipher_spec - try std.testing.expect(client_command == .send_finished); - client_command = try client.next(client_command); // send_finished - try std.testing.expect(client_command == .none); - - server_command = try server.next(server_command); // recv_finished - try std.testing.expect(server_command == .none); + try s_hs.next(); // send_change_cipher_spec + try testing.expectEqual(.send_encrypted_extensions, s_hs.command); + try s_hs.next(); // send_encrypted_extensions + try testing.expectEqual(.send_certificate, s_hs.command); + try s_hs.next(); // send_certificate + try testing.expectEqual(.send_certificate_verify, s_hs.command); + try s_hs.next(); // send_certificate_verify + try testing.expectEqual(.send_finished, s_hs.command); + try s_hs.next(); // send_finished + try testing.expectEqual(.recv_finished, s_hs.command); + + try c_hs.next(); // recv_encrypted_extensions + try testing.expectEqual(.recv_certificate_or_finished, c_hs.command); + try c_hs.next(); // recv_certificate_or_finished (certificate) + try testing.expectEqual(.recv_certificate_verify, c_hs.command); + try c_hs.next(); // recv_certificate_verify + try testing.expectEqual(.recv_finished, c_hs.command); + try c_hs.next(); // recv_finished + try testing.expectEqual(.send_change_cipher_spec, c_hs.command); + try c_hs.next(); // send_change_cipher_spec + try testing.expectEqual(.send_finished, c_hs.command); + try c_hs.next(); // send_finished + try testing.expectEqual(.none, c_hs.command); + + try s_hs.next(); // recv_finished + try testing.expectEqual(.none, s_hs.command); { - const s = server.stream.cipher.application.aes_256_gcm_sha384; - const c = client.stream.cipher.application.aes_256_gcm_sha384; + const s = s_hs.tls_stream.cipher.application.aes_256_gcm_sha384; + const c = c_hs.tls_stream.cipher.application.aes_256_gcm_sha384; - try std.testing.expectEqualSlices(u8, &s.client_key, &c.client_key); - try std.testing.expectEqualSlices(u8, &s.server_key, &c.server_key); - try std.testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); - try std.testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); + try testing.expectEqualSlices(u8, &s.client_key, &c.client_key); + try testing.expectEqualSlices(u8, &s.server_key, &c.server_key); + try testing.expectEqualSlices(u8, &s.client_iv, &c.client_iv); + try testing.expectEqualSlices(u8, &s.server_iv, &c.server_iv); } - try client.any().writer().writeAll("ping"); + var client = try c_hs.handshake(); + var server = try s_hs.handshake(); + + try client.stream().writer().writeAll("ping"); var recv_ping: [4]u8 = undefined; - _ = try server.stream.any().reader().readAll(&recv_ping); - try std.testing.expectEqualStrings("ping", &recv_ping); + _ = try server.tls_stream.stream().reader().readAll(&recv_ping); + try testing.expectEqualStrings("ping", &recv_ping); - server.stream.close(); - try std.testing.expect(server.stream.closed); + server.tls_stream.close(); + try testing.expect(server.tls_stream.closed); - _ = try client.stream.readPlaintext(); - try std.testing.expect(client.stream.closed); + _ = try client.stream().reader().discard(); + try testing.expect(client.tls_stream.closed); } pub fn debugPrint(name: []const u8, slice: anytype) void { @@ -1468,16 +1468,29 @@ pub fn debugPrint(name: []const u8, slice: anytype) void { std.debug.print("\n", .{}); } -pub fn writeKeyLogEntry( - writer: std.io.AnyWriter, - label: []const u8, - client_random: []const u8, - secret: []const u8, -) !void { - try writer.writeAll(label); - try writer.writeByte(' '); - for (client_random) |b| writer.print("{x:0>2}", .{b}) catch {}; - try writer.writeByte(' '); - for (secret) |b| writer.print("{x:0>2}", .{b}) catch {}; - try writer.writeByte('\n'); -} +/// https://www.ietf.org/archive/id/draft-thomson-tls-keylogfile-01.html +pub const KeyLogger = struct { + /// Copy of `options.key_log`. Needed for `key_update` messages. + writer: std.io.AnyWriter = std.io.null_writer.any(), + /// The value received in our `ClientHello` message. + /// Used as a session identifier in messages sent to `key_log`. + client_random: [32]u8 = undefined, + // For logging after `key_update` messages. + server_update_n: u32 = 0, + client_update_n: u32 = 0, + + pub fn writeLine( + self: @This(), + label: []const u8, + secret: []const u8, + ) !void { + var w = self.writer; + + try w.writeAll(label); + try w.writeByte(' '); + for (self.client_random) |b| w.print("{x:0>2}", .{b}) catch {}; + try w.writeByte(' '); + for (secret) |b| w.print("{x:0>2}", .{b}) catch {}; + try w.writeByte('\n'); + } +}; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index c017d0508ca3..3efb37e7faf5 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -4,452 +4,63 @@ const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; const Certificate = crypto.Certificate; +const Allocator = std.mem.Allocator; -stream: tls.Stream, -/// The value sent in our `ClientHello` message. -/// -/// Used as a session identifier by `options.key_log`. -/// Since the server may renegotiate (without a new random) -/// after the initial handshake in a `key_update` message, -/// save it here instead of in `Command`. -random: [32]u8 = undefined, -options: Options, -// For logging after `key_update` messages -server_update_n: usize = 0, -client_update_n: usize = 0, - -const Self = @This(); - -/// Initiates a TLS handshake and establishes a TLSv1.3 session -pub fn init(stream: std.io.AnyStream, options: Options) !Self { - var transcript_hash: tls.MultiHash = .{}; - const stream_ = tls.Stream{ - .stream = stream, - .is_client = true, - .transcript_hash = &transcript_hash, - }; - var random: [32]u8 = undefined; - crypto.random.bytes(&random); - - var res = Self{ .stream = stream_, .random = random, .options = options }; - - var command = Command{ .send_hello = KeyPairs.init() }; - while (command != .none) command = try res.next(command); - - return res; -} - -/// Executes handshake command and returns next one. -pub fn next(self: *Self, command: Command) !Command { - var stream = &self.stream; - switch (command) { - .send_hello => |key_pairs| { - try self.send_hello(key_pairs); - - return .{ .recv_hello = key_pairs }; - }, - .recv_hello => |key_pairs| { - try stream.expectInnerPlaintext(.handshake, .server_hello); - try self.recv_hello(key_pairs); - - return .{ .recv_encrypted_extensions = {} }; - }, - .recv_encrypted_extensions => { - try stream.expectInnerPlaintext(.handshake, .encrypted_extensions); - try self.recv_encrypted_extensions(); - - return .{ .recv_certificate_or_finished = {} }; - }, - .recv_certificate_or_finished => { - const digest = stream.transcript_hash.?.peek(); - const inner_plaintext = try stream.readInnerPlaintext(); - if (inner_plaintext.type != .handshake) return stream.writeError(.unexpected_message); - switch (inner_plaintext.handshake_type) { - .certificate => { - const parsed = try self.recv_certificate(); - - return .{ .recv_certificate_verify = parsed }; - }, - .finished => { - if (self.options.ca_bundle != null) - return self.stream.writeError(.certificate_required); - - try self.recv_finished(digest); - - return .{ .send_finished = {} }; - }, - else => return self.stream.writeError(.unexpected_message), - } - }, - .recv_certificate_verify => |parsed| { - defer self.options.allocator.free(parsed.certificate.buffer); - - const digest = stream.transcript_hash.?.peek(); - try stream.expectInnerPlaintext(.handshake, .certificate_verify); - try self.recv_certificate_verify(digest, parsed); - - return .{ .recv_finished = {} }; - }, - .recv_finished => { - const digest = stream.transcript_hash.?.peek(); - try stream.expectInnerPlaintext(.handshake, .finished); - try self.recv_finished(digest); - - return .{ .send_change_cipher_spec = {} }; - }, - .send_change_cipher_spec => { - try stream.changeCipherSpec(); - - return .{ .send_finished = {} }; - }, - .send_finished => { - try self.send_finished(); - - return .{ .none = {} }; - }, - .none => return .{ .none = {} }, - } -} - -pub fn send_hello(self: *Self, key_pairs: KeyPairs) !void { - const hello = tls.ClientHello{ - .random = self.random, - .session_id = &key_pairs.session_id, - .cipher_suites = self.options.cipher_suites, - .extensions = &.{ - .{ .server_name = &[_]tls.ServerName{.{ .host_name = self.options.host }} }, - .{ .ec_point_formats = &[_]tls.EcPointFormat{.uncompressed} }, - .{ .supported_groups = &tls.supported_groups }, - .{ .signature_algorithms = &tls.supported_signature_schemes }, - .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, - .{ .key_share = &[_]tls.KeyShare{ - .{ .secp256r1 = key_pairs.secp256r1.public_key }, - .{ .secp384r1 = key_pairs.secp384r1.public_key }, - .{ .x25519 = key_pairs.x25519.public_key }, - } }, - }, - }; - - _ = try self.stream.write(tls.Handshake, .{ .client_hello = hello }); - try self.stream.flush(); -} - -pub fn recv_hello(self: *Self, key_pairs: KeyPairs) !void { - var stream = &self.stream; - var r = stream.any().reader(); - - // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. - _ = try stream.read(tls.Version); - var random: [32]u8 = undefined; - try r.readNoEof(&random); - if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { - // We already offered all our supported options and we aren't changing them. - return stream.writeError(.unexpected_message); - } - - var session_id_buf: [tls.ClientHello.session_id_max_len]u8 = undefined; - const session_id_len = try stream.read(u8); - if (session_id_len > tls.ClientHello.session_id_max_len) - return stream.writeError(.illegal_parameter); - const session_id: []u8 = session_id_buf[0..session_id_len]; - try r.readNoEof(session_id); - if (!mem.eql(u8, session_id, &key_pairs.session_id)) - return stream.writeError(.illegal_parameter); - - const cipher_suite = try stream.read(tls.CipherSuite); - const compression_method = try stream.read(u8); - if (compression_method != 0) return stream.writeError(.illegal_parameter); - - var supported_version: ?tls.Version = null; - var shared_key: ?[]const u8 = null; - - var iter = try stream.extensions(); - while (try iter.next()) |ext| { - switch (ext.type) { - .supported_versions => { - if (supported_version != null) return stream.writeError(.illegal_parameter); - supported_version = try stream.read(tls.Version); - }, - .key_share => { - if (shared_key != null) return stream.writeError(.illegal_parameter); - const named_group = try stream.read(tls.NamedGroup); - const key_size = try stream.read(u16); - switch (named_group) { - .x25519 => { - const T = tls.NamedGroupT(.x25519); - const expected_len = T.public_length; - if (key_size != expected_len) return stream.writeError(.illegal_parameter); - var server_ks: [expected_len]u8 = undefined; - try r.readNoEof(&server_ks); - - const mult = crypto.dh.X25519.scalarmult( - key_pairs.x25519.secret_key, - server_ks[0..expected_len].*, - ) catch return stream.writeError(.illegal_parameter); - shared_key = &mult; - }, - inline .secp256r1, .secp384r1 => |t| { - const T = tls.NamedGroupT(t); - const expected_len = T.PublicKey.uncompressed_sec1_encoded_length; - if (key_size != expected_len) return stream.writeError(.illegal_parameter); - - var server_ks: [expected_len]u8 = undefined; - try r.readNoEof(&server_ks); - - const pk = T.PublicKey.fromSec1(&server_ks) catch - return stream.writeError(.illegal_parameter); - const key_pair = @field(key_pairs, @tagName(t)); - const mult = pk.p.mulPublic(key_pair.secret_key.bytes, .big) catch - return stream.writeError(.illegal_parameter); - shared_key = &mult.affineCoordinates().x.toBytes(.big); - }, - // Server sent us back unknown key. That's weird because we only request known ones, - // but we can keep iterating for another. - else => { - try r.skipBytes(key_size, .{}); - }, - } - }, - else => { - try r.skipBytes(ext.len, .{}); - }, - } - } - - if (supported_version != tls.Version.tls_1_3) return stream.writeError(.protocol_version); - if (shared_key == null) return stream.writeError(.missing_extension); - - stream.transcript_hash.?.setActive(cipher_suite); - const hello_hash = stream.transcript_hash.?.peek(); - - const handshake_cipher = tls.HandshakeCipher.init( - cipher_suite, - shared_key.?, - hello_hash, - self.options.key_log, - &self.random, - ) catch return stream.writeError(.illegal_parameter); - stream.cipher = .{ .handshake = handshake_cipher }; -} - -pub fn recv_encrypted_extensions(self: *Self) !void { - var stream = &self.stream; - var r = stream.any().reader(); - - var iter = try stream.extensions(); - while (try iter.next()) |ext| { - try r.skipBytes(ext.len, .{}); - } -} - -/// Verifies trust chain if `options.ca_bundle` is specified. -/// -/// Caller owns allocated Certificate.Parsed.certificate. -pub fn recv_certificate(self: *Self) !Certificate.Parsed { - var stream = &self.stream; - var r = stream.any().reader(); - const allocator = self.options.allocator; - const ca_bundle = self.options.ca_bundle; - const verify = ca_bundle != null; - - var context: [tls.Certificate.max_context_len]u8 = undefined; - const context_len = try stream.read(u8); - if (context_len > tls.Certificate.max_context_len) return stream.writeError(.decode_error); - try r.readNoEof(context[0..context_len]); - - var first: ?crypto.Certificate.Parsed = null; - errdefer if (first) |f| allocator.free(f.certificate.buffer); - var prev: Certificate.Parsed = undefined; - var verified = false; - const now_sec = std.time.timestamp(); - - var certs_iter = try stream.iterator(u24, u24); - while (try certs_iter.next()) |cert_len| { - const is_first = first == null; - - if (verified) { - try r.skipBytes(cert_len, .{}); - } else { - if (cert_len > tls.Certificate.Entry.max_data_len) - return stream.writeError(.decode_error); - const buf = allocator.alloc(u8, cert_len) catch - return stream.writeError(.internal_error); - defer if (!is_first) allocator.free(buf); - errdefer allocator.free(buf); - try r.readNoEof(buf); - - const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; - const cur = cert.parse() catch return stream.writeError(.bad_certificate); - if (first == null) { - if (verify) try cur.verifyHostName(self.options.host); - first = cur; - } else { - if (verify) try prev.verify(cur, now_sec); - } - - if (ca_bundle) |b| { - if (b.verify(cur, now_sec)) |_| { - verified = true; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - error.CertificateExpired => return stream.writeError(.certificate_expired), - else => return stream.writeError(.bad_certificate), - } - } - - prev = cur; - } - - var ext_iter = try stream.extensions(); - while (try ext_iter.next()) |ext| try r.skipBytes(ext.len, .{}); - } - if (verify and !verified) return stream.writeError(.bad_certificate); - - return if (first) |f| f else stream.writeError(.bad_certificate); -} +tls_stream: tls.Stream, +key_logger: tls.KeyLogger, -pub fn recv_certificate_verify(self: *Self, digest: []const u8, cert: Certificate.Parsed) !void { - var stream = &self.stream; - var r = stream.any().reader(); - const allocator = self.options.allocator; - - const sig_content = tls.sigContent(digest); - - const scheme = try stream.read(tls.SignatureScheme); - const len = try stream.read(u16); - if (len > tls.CertificateVerify.max_signature_length) - return stream.writeError(.decode_error); - const sig_bytes = allocator.alloc(u8, len) catch - return stream.writeError(.internal_error); - defer allocator.free(sig_bytes); - try r.readNoEof(sig_bytes); - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (cert.pub_key_algo != .X9_62_id_ecPublicKey) - return stream.writeError(.bad_certificate); - const Ecdsa = comptime_scheme.Ecdsa(); - const sig = Ecdsa.Signature.fromDer(sig_bytes) catch - return stream.writeError(.decode_error); - const key = Ecdsa.PublicKey.fromSec1(cert.pubKey()) catch - return stream.writeError(.decode_error); - sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (cert.pub_key_algo != .rsaEncryption) - return stream.writeError(.bad_certificate); - - const Hash = comptime_scheme.Hash(); - const rsa = Certificate.rsa; - const key = rsa.PublicKey.fromDer(cert.pubKey()) catch - return stream.writeError(.bad_certificate); - switch (key.n.bits() / 8) { - inline 128, 256, 512 => |modulus_len| { - const sig = rsa.PSSSignature.fromBytes(modulus_len, sig_bytes); - rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash) catch - return stream.writeError(.decode_error); - }, - else => { - return stream.writeError(.bad_certificate); - }, - } - }, - inline .ed25519 => |comptime_scheme| { - if (cert.pub_key_algo != .curveEd25519) - return stream.writeError(.bad_certificate); - const Eddsa = comptime_scheme.Eddsa(); - if (sig_content.len != Eddsa.Signature.encoded_length) - return stream.writeError(.decode_error); - const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); - if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) - return stream.writeError(.decode_error); - const key = Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*) catch - return stream.writeError(.bad_certificate); - sig.verify(sig_content, key) catch return stream.writeError(.bad_certificate); - }, - else => { - return stream.writeError(.bad_certificate); - }, - } -} - -pub fn recv_finished(self: *Self, digest: []const u8) !void { - var stream = &self.stream; - var r = stream.any().reader(); - const cipher = stream.cipher.handshake; - - switch (cipher) { - inline else => |p| { - const P = @TypeOf(p); - const expected = &tls.hmac(P.Hmac, digest, p.server_finished_key); +pub const Options = struct { + /// Certificate messages may be up to 2^24-1 bytes long. + /// Certificate verify messages may be up to 2^16-1 bytes long. + /// This is the allocator to use for them. + allocator: Allocator, + /// Trusted certificate authority bundle used to authenticate server certificates. + /// When null, server certificate and certificate_verify messages will be skipped. + ca_bundle: ?Certificate.Bundle, + /// Used to verify cerficate chain and for Server Name Indication. + host: []const u8, + /// List of cipher suites to advertise in order of descending preference. + cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, + /// Minimum version to support. + min_version: tls.Version = .tls_1_3, + /// By default, reaching the end-of-stream when reading from the server will + /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify + /// message has been received. By setting this flag to `true`, instead, the + /// end-of-stream will be forwarded to the application layer above TLS. + /// This makes the application vulnerable to truncation attacks unless the + /// application layer itself verifies that the amount of data received equals + /// the amount of data expected, such as HTTP with the Content-Length header. + allow_truncation_attacks: bool = false, + /// Writer to log shared secrets for traffic decryption in SSLKEYLOGFILE format. + key_log: std.io.AnyWriter = std.io.null_writer.any(), +}; - var actual: [expected.len]u8 = undefined; - try r.readNoEof(&actual); - if (!mem.eql(u8, expected, &actual)) return stream.writeError(.decode_error); - }, - } -} +const Client = @This(); -pub fn send_finished(self: *Self) !void { - var stream = &self.stream; - - const handshake_hash = stream.transcript_hash.?.peek(); - - const verify_data = switch (stream.cipher.handshake) { - inline .aes_128_gcm_sha256, - .aes_256_gcm_sha384, - .chacha20_poly1305_sha256, - .aegis_256_sha512, - .aegis_128l_sha256, - => |v| brk: { - const T = @TypeOf(v); - const secret = v.client_finished_key; - const transcript_hash = stream.transcript_hash.?.peek(); - - break :brk &tls.hmac(T.Hmac, transcript_hash, secret); - }, - else => return stream.writeError(.decrypt_error), - }; - stream.content_type = .handshake; - _ = try stream.write(tls.Handshake, .{ .finished = verify_data }); - try stream.flush(); - - const application_cipher = tls.ApplicationCipher.init( - stream.cipher.handshake, - handshake_hash, - self.options.key_log, - &self.random, - ); - stream.cipher = .{ .application = application_cipher }; - stream.content_type = .application_data; - stream.transcript_hash = null; +/// Executes a TLSv1.3 handshake on `any_stream`. +pub fn init(any_stream: std.io.AnyStream, options: Options) !Client { + const hs = Handshake.init(any_stream, options); + return try hs.handshake(); } pub const ReadError = anyerror; pub const WriteError = anyerror; /// Reads next application_data message. -pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { - var stream = &self.stream; - - if (stream.eof()) return 0; +pub fn readv(self: *Client, buffers: []const std.os.iovec) ReadError!usize { + var s = &self.tls_stream; - while (stream.view.len == 0) { - const inner_plaintext = try stream.readInnerPlaintext(); + while (s.view.len == 0 and !s.eof()) { + const inner_plaintext = try s.readInnerPlaintext(); switch (inner_plaintext.type) { .handshake => { switch (inner_plaintext.handshake_type) { // A multithreaded client could use these. .new_session_ticket => { - try stream.any().reader().skipBytes(inner_plaintext.len, .{}); + try s.stream().reader().skipBytes(inner_plaintext.len, .{}); }, .key_update => { - switch (stream.cipher.application) { + switch (s.cipher.application) { inline else => |*p| { const P = @TypeOf(p.*); p.server_secret = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); @@ -457,14 +68,15 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { p.server_iv = tls.hkdfExpandLabel(P.Hkdf, p.server_secret, "iv", "", P.AEAD.nonce_length); p.read_seq = 0; - self.server_update_n += 1; - self.options.key_log.print("SERVER_TRAFFIC_SECRET_{d}", .{self.server_update_n}) catch {}; - tls.writeKeyLogEntry(self.options.key_log, "", &self.random, &p.server_secret) catch {}; + var logger = &self.key_logger; + logger.server_update_n += 1; + logger.writer.print("SERVER_TRAFFIC_SECRET_{d}", .{logger.server_update_n}) catch {}; + logger.writeLine("", &p.server_secret) catch {}; }, } - const update = try stream.read(tls.KeyUpdate); + const update = try s.read(tls.KeyUpdate); if (update == .update_requested) { - switch (stream.cipher.application) { + switch (s.cipher.application) { inline else => |*p| { const P = @TypeOf(p.*); p.client_secret = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); @@ -472,135 +84,564 @@ pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { p.client_iv = tls.hkdfExpandLabel(P.Hkdf, p.client_secret, "iv", "", P.AEAD.nonce_length); p.write_seq = 0; - self.client_update_n += 1; - self.options.key_log.print("CLIENT_TRAFFIC_SECRET_{d}", .{self.client_update_n}) catch {}; - tls.writeKeyLogEntry(self.options.key_log, "", &self.random, &p.client_secret) catch {}; + var logger = &self.key_logger; + logger.client_update_n += 1; + logger.writer.print("CLIENT_TRAFFIC_SECRET_{d}", .{logger.client_update_n}) catch {}; + logger.writeLine("", &p.client_secret) catch {}; }, } } }, - else => return stream.writeError(.unexpected_message), + else => return s.writeError(.unexpected_message), } }, .alert => {}, .application_data => {}, - else => return stream.writeError(.unexpected_message), + else => return s.writeError(.unexpected_message), } } - return try stream.readv(buffers); + return try s.readv(buffers); } -pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { - if (self.stream.eof()) return 0; +/// Writes application_data message and flushes stream. +pub fn writev(self: *Client, iov: []const std.os.iovec_const) WriteError!usize { + if (self.tls_stream.eof()) return 0; - const res = try self.stream.writev(iov); - try self.stream.flush(); + const res = try self.tls_stream.writev(iov); + try self.tls_stream.flush(); return res; } -pub fn close(self: *Self) void { - self.stream.close(); +pub fn close(self: *Client) void { + self.tls_stream.close(); } -pub const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); +pub const GenericStream = std.io.GenericStream(*Client, ReadError, readv, WriteError, writev, close); -pub fn any(self: *Self) GenericStream { +pub fn stream(self: *Client) GenericStream { return .{ .context = self }; } -pub const Options = struct { - /// Trusted certificate authority bundle used to authenticate server certificates. - /// When null, server certificate and certificate_verify messages will be skipped. - ca_bundle: ?Certificate.Bundle, - /// Used to verify cerficate chain and for Server Name Indication. - host: []const u8, - /// List of cipher suites to advertise in order of descending preference. - cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, - /// By default, reaching the end-of-stream when reading from the server will - /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify - /// message has been received. By setting this flag to `true`, instead, the - /// end-of-stream will be forwarded to the application layer above TLS. - /// This makes the application vulnerable to truncation attacks unless the - /// application layer itself verifies that the amount of data received equals - /// the amount of data expected, such as HTTP with the Content-Length header. - allow_truncation_attacks: bool = false, - /// Certificate messages may be up to 2^24-1 bytes long. - /// Certificate verify messages may be up to 2^16-1 bytes long. - /// This is the allocator used for them. - allocator: std.mem.Allocator, - /// Writer to log shared secrets for traffic decryption. +pub const Handshake = struct { + tls_stream: tls.Stream, + /// Running hash of handshake messages for cryptographic functions + transcript_hash: tls.MultiHash = .{}, + options: Options, + + client_random: [32]u8, + /// Used in TLSv1.2. Always set for middlebox compatibility. + session_id: [32]u8, + /// One of these potential key pairs will be selected in `recv_hello` + /// to establish a shared secret for encryption. + key_pairs: KeyPairs, + + /// Next command to execute + command: Command = .send_hello, + + /// Server certificate to later verify. + /// `cert.certificate.buffer` is allocated + cert: Certificate.Parsed = undefined, + + pub const KeyPairs = struct { + secp256r1: Secp256r1, + secp384r1: Secp384r1, + x25519: X25519, + + const X25519 = tls.NamedGroupT(.x25519).KeyPair; + const Secp256r1 = tls.NamedGroupT(.secp256r1).KeyPair; + const Secp384r1 = tls.NamedGroupT(.secp384r1).KeyPair; + + pub fn init() @This() { + var random_buffer: [ + Secp256r1.seed_length + + Secp384r1.seed_length + + X25519.seed_length + ]u8 = undefined; + + while (true) { + crypto.random.bytes(&random_buffer); + + const split1 = Secp256r1.seed_length; + const split2 = split1 + Secp384r1.seed_length; + + return initAdvanced( + random_buffer[0..split1].*, + random_buffer[split1..split2].*, + random_buffer[split2..].*, + ) catch continue; + } + } + + pub fn initAdvanced( + secp256r1_seed: [Secp256r1.seed_length]u8, + secp384r1_seed: [Secp384r1.seed_length]u8, + x25519_seed: [X25519.seed_length]u8, + ) !@This() { + return .{ + .secp256r1 = Secp256r1.create(secp256r1_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, + .secp384r1 = Secp384r1.create(secp384r1_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, + .x25519 = X25519.create(x25519_seed) catch |err| switch (err) { + error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + }, + }; + } + }; + + /// A command to send or receive a single message. + pub const Command = enum { + send_hello, + recv_hello, + recv_encrypted_extensions, + recv_certificate_or_finished, + recv_certificate_verify, + recv_finished, + send_change_cipher_spec, + send_finished, + none, + }; + + /// Initializes members. Does NOT send any messages to `any_stream`. + pub fn init(any_stream: std.io.AnyStream, options: Options) Handshake { + const tls_stream = tls.Stream{ .stream = any_stream, .is_client = true }; + var res = Handshake{ + .tls_stream = tls_stream, + .options = options, + .random = undefined, + .session_id = undefined, + .key_pairs = undefined, + }; + res.init_random(); + + return res; + } + + inline fn init_random(self: *Handshake) void { + self.key_pairs = KeyPairs.init(); + crypto.random.bytes(&self.client_random); + crypto.random.bytes(&self.session_id); + } + + /// Establishes a TLS connection on `tls_stream` and returns a Client. + pub fn handshake(self: *Handshake) !Client { + while (self.command != .none) self.next() catch |err| switch (err) { + error.ConnectionResetByPeer => { + // Prevent reply attacks + self.command = .send_hello; + self.init_random(); + }, + else => return err, + }; + + return Client{ + .tls_stream = self.tls_stream, + .key_logger = .{ + .writer = self.options.key_log, + .client_random = self.client_random, + }, + }; + } + + /// Sends or receives exactly ONE handshake message on `tls_stream`. + /// Sets `self.command` to next expected message. + pub fn next(self: *Handshake) !void { + var s = &self.tls_stream; + s.transcript_hash = &self.transcript_hash; + + self.command = switch (self.command) { + .send_hello => brk: { + try self.send_hello(); + + break :brk .recv_hello; + }, + .recv_hello => brk: { + try s.expectInnerPlaintext(.handshake, .server_hello); + try self.recv_hello(); + + break :brk .recv_encrypted_extensions; + }, + .recv_encrypted_extensions => brk: { + try s.expectInnerPlaintext(.handshake, .encrypted_extensions); + try self.recv_encrypted_extensions(); + + break :brk .recv_certificate_or_finished; + }, + .recv_certificate_or_finished => brk: { + const digest = s.transcript_hash.?.peek(); + const inner_plaintext = try s.readInnerPlaintext(); + if (inner_plaintext.type != .handshake) return s.writeError(.unexpected_message); + switch (inner_plaintext.handshake_type) { + .certificate => { + self.cert = try self.recv_certificate(); + + break :brk .recv_certificate_verify; + }, + .finished => { + if (self.options.ca_bundle != null) + return self.tls_stream.writeError(.certificate_required); + + try self.recv_finished(digest); + + break :brk .send_finished; + }, + else => return self.tls_stream.writeError(.unexpected_message), + } + }, + .recv_certificate_verify => brk: { + defer self.options.allocator.free(self.cert.certificate.buffer); + + const digest = s.transcript_hash.?.peek(); + try s.expectInnerPlaintext(.handshake, .certificate_verify); + try self.recv_certificate_verify(digest); + + break :brk .recv_finished; + }, + .recv_finished => brk: { + const digest = s.transcript_hash.?.peek(); + try s.expectInnerPlaintext(.handshake, .finished); + try self.recv_finished(digest); + + break :brk .send_change_cipher_spec; + }, + .send_change_cipher_spec => brk: { + try s.changeCipherSpec(); + + break :brk .send_finished; + }, + .send_finished => brk: { + try self.send_finished(); + + break :brk .none; + }, + .none => .none, + }; + } + + pub fn send_hello(self: *Handshake) !void { + const hello = tls.ClientHello{ + .random = self.client_random, + .session_id = &self.session_id, + .cipher_suites = self.options.cipher_suites, + .extensions = &.{ + .{ .server_name = &[_]tls.ServerName{.{ .host_name = self.options.host }} }, + .{ .ec_point_formats = &[_]tls.EcPointFormat{.uncompressed} }, + .{ .supported_groups = &tls.supported_groups }, + .{ .signature_algorithms = &tls.supported_signature_schemes }, + .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, + .{ .key_share = &[_]tls.KeyShare{ + .{ .secp256r1 = self.key_pairs.secp256r1.public_key }, + .{ .secp384r1 = self.key_pairs.secp384r1.public_key }, + .{ .x25519 = self.key_pairs.x25519.public_key }, + } }, + }, + }; + + _ = try self.tls_stream.write(tls.Handshake, .{ .client_hello = hello }); + try self.tls_stream.flush(); + } + + pub fn recv_hello(self: *Handshake) !void { + var s = &self.tls_stream; + var r = s.stream().reader(); + + // > The value of TLSPlaintext.legacy_record_version MUST be ignored by all implementations. + _ = try s.read(tls.Version); + var random: [32]u8 = undefined; + try r.readNoEof(&random); + if (mem.eql(u8, &random, &tls.ServerHello.hello_retry_request)) { + // We already offered all our supported options and we aren't changing them. + return s.writeError(.unexpected_message); + } + + var session_id_buf: [tls.ClientHello.session_id_max_len]u8 = undefined; + const session_id_len = try s.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) + return s.writeError(.illegal_parameter); + const session_id: []u8 = session_id_buf[0..session_id_len]; + try r.readNoEof(session_id); + if (!mem.eql(u8, session_id, &self.session_id)) + return s.writeError(.illegal_parameter); + + const cipher_suite = try s.read(tls.CipherSuite); + const compression_method = try s.read(u8); + if (compression_method != 0) return s.writeError(.illegal_parameter); + + var supported_version: ?tls.Version = null; + var shared_key: ?[]const u8 = null; + + var iter = try s.extensions(); + while (try iter.next()) |ext| { + switch (ext.type) { + .supported_versions => { + if (supported_version != null) return s.writeError(.illegal_parameter); + supported_version = try s.read(tls.Version); + }, + .key_share => { + if (shared_key != null) return s.writeError(.illegal_parameter); + const named_group = try s.read(tls.NamedGroup); + const key_size = try s.read(u16); + switch (named_group) { + .x25519 => { + const T = tls.NamedGroupT(.x25519); + const expected_len = T.public_length; + if (key_size != expected_len) return s.writeError(.illegal_parameter); + var server_ks: [expected_len]u8 = undefined; + try r.readNoEof(&server_ks); + + const mult = crypto.dh.X25519.scalarmult( + self.key_pairs.x25519.secret_key, + server_ks[0..expected_len].*, + ) catch return s.writeError(.illegal_parameter); + shared_key = &mult; + }, + inline .secp256r1, .secp384r1 => |t| { + const T = tls.NamedGroupT(t); + const expected_len = T.PublicKey.uncompressed_sec1_encoded_length; + if (key_size != expected_len) return s.writeError(.illegal_parameter); + + var server_ks: [expected_len]u8 = undefined; + try r.readNoEof(&server_ks); + + const pk = T.PublicKey.fromSec1(&server_ks) catch + return s.writeError(.illegal_parameter); + const key_pair = @field(self.key_pairs, @tagName(t)); + const mult = pk.p.mulPublic(key_pair.secret_key.bytes, .big) catch + return s.writeError(.illegal_parameter); + shared_key = &mult.affineCoordinates().x.toBytes(.big); + }, + // Server sent us back unknown key. That's weird because we only request known ones, + // but we can keep iterating for another. + else => { + try r.skipBytes(key_size, .{}); + }, + } + }, + else => { + try r.skipBytes(ext.len, .{}); + }, + } + } + + if (supported_version != tls.Version.tls_1_3) return s.writeError(.protocol_version); + if (shared_key == null) return s.writeError(.missing_extension); + + s.transcript_hash.?.setActive(cipher_suite); + const hello_hash = s.transcript_hash.?.peek(); + + const handshake_cipher = tls.HandshakeCipher.init( + cipher_suite, + shared_key.?, + hello_hash, + self.logger(), + ) catch return s.writeError(.illegal_parameter); + s.cipher = .{ .handshake = handshake_cipher }; + } + + pub fn recv_encrypted_extensions(self: *Handshake) !void { + var s = &self.tls_stream; + var r = s.stream().reader(); + + var iter = try s.extensions(); + while (try iter.next()) |ext| { + try r.skipBytes(ext.len, .{}); + } + } + + /// Verifies trust chain if `options.ca_bundle` is specified. /// - /// See https://www.ietf.org/archive/id/draft-thomson-tls-keylogfile-01.html - key_log: std.io.AnyWriter = std.io.null_writer.any(), -}; + /// Caller owns allocated Certificate.Parsed.certificate. + pub fn recv_certificate(self: *Handshake) !Certificate.Parsed { + var s = &self.tls_stream; + var r = s.stream().reader(); + const allocator = self.options.allocator; + const ca_bundle = self.options.ca_bundle; + const verify = ca_bundle != null; + + var context: [tls.Certificate.max_context_len]u8 = undefined; + const context_len = try s.read(u8); + if (context_len > tls.Certificate.max_context_len) return s.writeError(.decode_error); + try r.readNoEof(context[0..context_len]); + + var first: ?crypto.Certificate.Parsed = null; + errdefer if (first) |f| allocator.free(f.certificate.buffer); + var prev: Certificate.Parsed = undefined; + var verified = false; + const now_sec = std.time.timestamp(); + + var certs_iter = try s.iterator(u24, u24); + while (try certs_iter.next()) |cert_len| { + const is_first = first == null; + + if (verified) { + try r.skipBytes(cert_len, .{}); + } else { + if (cert_len > tls.Certificate.Entry.max_data_len) + return s.writeError(.decode_error); + const buf = allocator.alloc(u8, cert_len) catch + return s.writeError(.internal_error); + defer if (!is_first) allocator.free(buf); + errdefer allocator.free(buf); + try r.readNoEof(buf); + + const cert = crypto.Certificate{ .buffer = buf, .index = 0 }; + const cur = cert.parse() catch return s.writeError(.bad_certificate); + if (first == null) { + if (verify) try cur.verifyHostName(self.options.host); + first = cur; + } else { + if (verify) try prev.verify(cur, now_sec); + } + + if (ca_bundle) |b| { + if (b.verify(cur, now_sec)) |_| { + verified = true; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + error.CertificateExpired => return s.writeError(.certificate_expired), + else => return s.writeError(.bad_certificate), + } + } -/// One of these potential key pairs will be selected during the handshake. -pub const KeyPairs = struct { - session_id: [session_id_length]u8, - secp256r1: Secp256r1, - secp384r1: Secp384r1, - x25519: X25519, - - const session_id_length = 32; - const X25519 = tls.NamedGroupT(.x25519).KeyPair; - const Secp256r1 = tls.NamedGroupT(.secp256r1).KeyPair; - const Secp384r1 = tls.NamedGroupT(.secp384r1).KeyPair; - - pub fn init() @This() { - var random_buffer: [ - session_id_length + - Secp256r1.seed_length + - Secp384r1.seed_length + - X25519.seed_length - ]u8 = undefined; - - while (true) { - crypto.random.bytes(&random_buffer); - - const split1 = session_id_length; - const split2 = split1 + Secp256r1.seed_length; - const split3 = split2 + Secp384r1.seed_length; - - return initAdvanced( - random_buffer[0..split1].*, - random_buffer[split1..split2].*, - random_buffer[split2..split3].*, - random_buffer[split3..].*, - ) catch continue; + prev = cur; + } + + var ext_iter = try s.extensions(); + while (try ext_iter.next()) |ext| try r.skipBytes(ext.len, .{}); } + if (verify and !verified) return s.writeError(.bad_certificate); + + return if (first) |f| f else s.writeError(.bad_certificate); } - pub fn initAdvanced( - session_id: [session_id_length]u8, - secp256r1_seed: [Secp256r1.seed_length]u8, - secp384r1_seed: [Secp384r1.seed_length]u8, - x25519_seed: [X25519.seed_length]u8, - ) !@This() { - return .{ - .session_id = session_id, - .secp256r1 = Secp256r1.create(secp256r1_seed) catch |err| switch (err) { - error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + pub fn recv_certificate_verify(self: *Handshake, digest: []const u8) !void { + var s = &self.tls_stream; + var r = s.stream().reader(); + const allocator = self.options.allocator; + const cert = self.cert; + + const sig_content = tls.sigContent(digest); + + const scheme = try s.read(tls.SignatureScheme); + const len = try s.read(u16); + if (len > tls.CertificateVerify.max_signature_length) + return s.writeError(.decode_error); + const sig_bytes = allocator.alloc(u8, len) catch + return s.writeError(.internal_error); + defer allocator.free(sig_bytes); + try r.readNoEof(sig_bytes); + + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (cert.pub_key_algo != .X9_62_id_ecPublicKey) + return s.writeError(.bad_certificate); + const Ecdsa = comptime_scheme.Ecdsa(); + const sig = Ecdsa.Signature.fromDer(sig_bytes) catch + return s.writeError(.decode_error); + const key = Ecdsa.PublicKey.fromSec1(cert.pubKey()) catch + return s.writeError(.decode_error); + sig.verify(sig_content, key) catch return s.writeError(.bad_certificate); }, - .secp384r1 = Secp384r1.create(secp384r1_seed) catch |err| switch (err) { - error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + if (cert.pub_key_algo != .rsaEncryption) + return s.writeError(.bad_certificate); + + const Hash = comptime_scheme.Hash(); + const rsa = Certificate.rsa; + const key = rsa.PublicKey.fromDer(cert.pubKey()) catch + return s.writeError(.bad_certificate); + switch (key.n.bits() / 8) { + inline 128, 256, 512 => |modulus_len| { + const sig = rsa.PSSSignature.fromBytes(modulus_len, sig_bytes); + rsa.PSSSignature.verify(modulus_len, sig, sig_content, key, Hash) catch + return s.writeError(.decode_error); + }, + else => { + return s.writeError(.bad_certificate); + }, + } }, - .x25519 = X25519.create(x25519_seed) catch |err| switch (err) { - error.IdentityElement => return error.InsufficientEntropy, // Private key is all zeroes. + inline .ed25519 => |comptime_scheme| { + if (cert.pub_key_algo != .curveEd25519) + return s.writeError(.bad_certificate); + const Eddsa = comptime_scheme.Eddsa(); + if (sig_content.len != Eddsa.Signature.encoded_length) + return s.writeError(.decode_error); + const sig = Eddsa.Signature.fromBytes(sig_bytes[0..Eddsa.Signature.encoded_length].*); + if (cert.pubKey().len != Eddsa.PublicKey.encoded_length) + return s.writeError(.decode_error); + const key = Eddsa.PublicKey.fromBytes(cert.pubKey()[0..Eddsa.PublicKey.encoded_length].*) catch + return s.writeError(.bad_certificate); + sig.verify(sig_content, key) catch return s.writeError(.bad_certificate); }, + else => { + return s.writeError(.bad_certificate); + }, + } + } + + pub fn recv_finished(self: *Handshake, digest: []const u8) !void { + var s = &self.tls_stream; + var r = s.stream().reader(); + const cipher = s.cipher.handshake; + + switch (cipher) { + inline else => |p| { + const P = @TypeOf(p); + const expected = &tls.hmac(P.Hmac, digest, p.server_finished_key); + + var actual: [expected.len]u8 = undefined; + try r.readNoEof(&actual); + if (!mem.eql(u8, expected, &actual)) return s.writeError(.decode_error); + }, + } + } + + pub fn send_finished(self: *Handshake) !void { + var s = &self.tls_stream; + + const handshake_hash = s.transcript_hash.?.peek(); + + const verify_data = switch (s.cipher.handshake) { + inline .aes_128_gcm_sha256, + .aes_256_gcm_sha384, + .chacha20_poly1305_sha256, + .aegis_256_sha512, + .aegis_128l_sha256, + => |v| brk: { + const T = @TypeOf(v); + const secret = v.client_finished_key; + const transcript_hash = s.transcript_hash.?.peek(); + + break :brk &tls.hmac(T.Hmac, transcript_hash, secret); + }, + else => return s.writeError(.decrypt_error), }; + s.content_type = .handshake; + _ = try s.write(tls.Handshake, .{ .finished = verify_data }); + try s.flush(); + + const application_cipher = tls.ApplicationCipher.init( + s.cipher.handshake, + handshake_hash, + self.logger(), + ); + s.cipher = .{ .application = application_cipher }; + s.content_type = .application_data; + s.transcript_hash = null; } -}; -/// A command to send or receive a single message. Allows deterministically -/// testing `advance` on a single thread. -pub const Command = union(enum) { - send_hello: KeyPairs, - recv_hello: KeyPairs, - recv_encrypted_extensions: void, - recv_certificate_or_finished: void, - recv_certificate_verify: Certificate.Parsed, - recv_finished: void, - send_change_cipher_spec: void, - send_finished: void, - none: void, + fn logger(self: *Handshake) tls.KeyLogger { + return tls.KeyLogger{ + .client_random = self.client_random, + .writer = self.options.key_log, + }; + } }; diff --git a/lib/std/crypto/tls/Server.zig b/lib/std/crypto/tls/Server.zig index 537a961a3df0..4935e368071b 100644 --- a/lib/std/crypto/tls/Server.zig +++ b/lib/std/crypto/tls/Server.zig @@ -8,505 +8,552 @@ const assert = std.debug.assert; const Certificate = std.crypto.Certificate; const Allocator = std.mem.Allocator; -stream: tls.Stream, -options: Options, +tls_stream: tls.Stream, +key_logger: tls.KeyLogger, -const Self = @This(); +pub const Options = struct { + /// List of potential cipher suites in descending order of preference. + cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, + /// Types of shared keys to accept from client. + key_shares: []const tls.NamedGroup = &tls.supported_groups, + /// Certificate(s) to send in `send_certificate` messages. + /// The first entry will be used for verification. + certificate: tls.Certificate = .{}, + /// Secret key corresponding to `certificate.entries[0]` to use for certificate verification. + certificate_key: CertificateKey = .none, + /// Writer to log shared secrets for traffic decryption in SSLKEYLOGFILE format. + key_log: std.io.AnyWriter = std.io.null_writer.any(), -/// Initiates a TLS handshake and establishes a TLSv1.3 session -pub fn init(stream: std.io.AnyStream, options: Options) !Self { - var transcript_hash: tls.MultiHash = .{}; - const stream_ = tls.Stream{ - .stream = stream, - .is_client = false, - .transcript_hash = &transcript_hash, - }; - var res = Self{ .stream = stream_, .options = options }; - - // Verify that the certificate key matches the certificate. - const cert_buf = Certificate{ .buffer = options.certificate.entries[0].data, .index = 0 }; - // TODO: don't reparse cert in send_certificate_verify - const cert = try cert_buf.parse(); - const expected: std.meta.Tag(Options.CertificateKey) = switch (cert.pub_key_algo) { - .rsaEncryption => .rsa, - .X9_62_id_ecPublicKey => |curve| switch (curve) { - .X9_62_prime256v1 => .ecdsa256, - .secp384r1 => .ecdsa384, - else => return error.UnsupportedCertificateSignature, - }, - .curveEd25519 => .ed25519, + pub const CertificateKey = union(enum) { + none: void, + rsa: crypto.Certificate.rsa.SecretKey, + ecdsa256: tls.NamedGroupT(.secp256r1).SecretKey, + ecdsa384: tls.NamedGroupT(.secp384r1).SecretKey, + ed25519: crypto.sign.Ed25519.SecretKey, }; - if (expected != options.certificate_key) return error.CertificateKeyMismatch; - // TODO: verify private key corresponds to public key - - var command = initial_command(); - while (command != .none) { - command = res.next(command) catch |err| switch (err) { - // Prevent replay attacks in later handshake stages. - error.ConnectionResetByPeer => initial_command(), - else => return err, - }; - } +}; - return res; +const Server = @This(); + +/// Initiates a TLS handshake and establishes a TLSv1.3 session +pub fn init(any_stream: std.io.AnyStream, options: Options) !Server { + const hs = Handshake.init(any_stream, options); + return try hs.hanshake(); } -inline fn initial_command() Command { - var res = Command{ .recv_hello = undefined }; - crypto.random.bytes(&res.recv_hello.server_random); - crypto.random.bytes(&res.recv_hello.keygen_seed); +pub const ReadError = anyerror; +pub const WriteError = anyerror; + +/// Reads next application_data message. +pub fn readv(self: *Server, buffers: []const std.os.iovec) ReadError!usize { + var s = &self.tls_stream; + + if (s.eof()) return 0; + + while (s.view.len == 0) { + const inner_plaintext = try s.readInnerPlaintext(); + switch (inner_plaintext.type) { + .application_data => {}, + .alert => {}, + else => return s.writeError(.unexpected_message), + } + } + return try self.tls_stream.readv(buffers); +} +pub fn writev(self: *Server, iov: []const std.os.iovec_const) WriteError!usize { + if (self.tls_stream.eof()) return 0; + + const res = try self.tls_stream.writev(iov); + try self.tls_stream.flush(); return res; } -/// Executes handshake command and returns next one. -pub fn next(self: *Self, command: Command) !Command { - var stream = &self.stream; - - switch (command) { - .recv_hello => |random| { - const client_hello = try self.recv_hello(random); - - return .{ .send_hello = client_hello }; - }, - .send_hello => |client_hello| { - try self.send_hello(client_hello); - - const scheme = client_hello.sig_scheme; - // > if the client sends a non-empty session ID, - // > the server MUST send the change_cipher_spec - if (client_hello.session_id_len > 0) return .{ .send_change_cipher_spec = scheme }; - - return .{ .send_encrypted_extensions = scheme }; - }, - .send_change_cipher_spec => |scheme| { - try stream.changeCipherSpec(); - - return .{ .send_encrypted_extensions = scheme }; - }, - .send_encrypted_extensions => |scheme| { - try self.send_encrypted_extensions(); - - return .{ .send_certificate = scheme }; - }, - .send_certificate => |scheme| { - try self.send_certificate(); - - var cert_verify = Command.CertificateVerify{ .scheme = scheme, .salt = undefined }; - crypto.random.bytes(&cert_verify.salt); - - return .{ .send_certificate_verify = cert_verify }; - }, - .send_certificate_verify => |cert_verify| { - try self.send_certificate_verify(cert_verify); - return .{ .send_finished = {} }; - }, - .send_finished => { - try self.send_finished(); - return .{ .recv_finished = {} }; - }, - .recv_finished => { - try self.recv_finished(); - return .{ .none = {} }; - }, - .none => return .{ .none = {} }, - } +pub fn close(self: *Server) void { + self.tls_stream.close(); } -pub fn recv_hello(self: *Self, random: Command.Random) !ClientHello { - var stream = &self.stream; - var reader = stream.any().reader(); +pub const GenericStream = std.io.GenericStream(*Server, ReadError, readv, WriteError, writev, close); - try stream.expectInnerPlaintext(.handshake, .client_hello); +pub fn stream(self: *Server) GenericStream { + return .{ .context = self }; +} - _ = try stream.read(tls.Version); - var client_random: [32]u8 = undefined; - try reader.readNoEof(&client_random); +pub const Handshake = struct { + tls_stream: tls.Stream, + /// Running hash of handshake messages for cryptographic functions + transcript_hash: tls.MultiHash = .{}, + options: Options, + cert: ?Certificate.Parsed = null, - var session_id: [tls.ClientHello.session_id_max_len]u8 = undefined; - const session_id_len = try stream.read(u8); - if (session_id_len > tls.ClientHello.session_id_max_len) - return stream.writeError(.illegal_parameter); - try reader.readNoEof(session_id[0..session_id_len]); + server_random: [32]u8, + keygen_seed: [tls.NamedGroupT(.secp384r1).KeyPair.seed_length]u8, + certificate_verify_salt: [tls.MultiHash.max_digest_len]u8, + + /// Next command to execute + command: Command = .recv_hello, + + /// Defined after `recv_hello` + client_hello: ClientHello = undefined, + /// Used to establish a shared secret. Defined after `recv_hello` + key_pair: tls.KeyPair = undefined, + + pub const ClientHello = struct { + random: [32]u8, + session_id_len: u8, + session_id: [32]u8, + cipher_suite: tls.CipherSuite, + key_share: tls.KeyShare, + sig_scheme: tls.SignatureScheme, + }; - const cipher_suite: tls.CipherSuite = brk: { - var cipher_suite_iter = try stream.iterator(u16, tls.CipherSuite); - var res: ?tls.CipherSuite = null; - while (try cipher_suite_iter.next()) |suite| { - for (self.options.cipher_suites) |s| { - if (s == suite and res == null) res = s; - } - } - if (res == null) return stream.writeError(.illegal_parameter); - break :brk res.?; + /// A command to send or receive a single message. + pub const Command = enum { + recv_hello, + send_hello, + send_change_cipher_spec, + send_encrypted_extensions, + send_certificate, + send_certificate_verify, + send_finished, + recv_finished, + none, }; - stream.transcript_hash.?.setActive(cipher_suite); - { - var compression_methods: [2]u8 = undefined; - try reader.readNoEof(&compression_methods); - if (!std.mem.eql(u8, &compression_methods, &[_]u8{ 1, 0 })) - return stream.writeError(.illegal_parameter); + /// Initializes members. Does NOT send any messages to `any_stream`. + pub fn init(any_stream: std.io.AnyStream, options: Options) Handshake { + const tls_stream = tls.Stream{ .stream = any_stream, .is_client = false }; + var res = Handshake{ + .tls_stream = tls_stream, + .options = options, + .random = undefined, + .session_id = undefined, + .key_pairs = undefined, + }; + res.init_random(); + + const certificates = options.certificate.entries; + // Verify that the certificate key matches the root certificate. + // This allows failing fast now instead of every client failing signature verification later. + if (certificates.len > 0) { + const cert_buf = Certificate{ .buffer = options.certificate.entries[0].data, .index = 0 }; + res.cert = try cert_buf.parse(); + const expected: std.meta.Tag(Options.CertificateKey) = switch (res.cert.pub_key_algo) { + .rsaEncryption => .rsa, + .X9_62_id_ecPublicKey => |curve| switch (curve) { + .X9_62_prime256v1 => .ecdsa256, + .secp384r1 => .ecdsa384, + else => return error.UnsupportedCertificateSignature, + }, + .curveEd25519 => .ed25519, + }; + if (expected != options.certificate_key) return error.CertificateKeyMismatch; + + // TODO: test private key matches cert public key + // const test_msg = "hello"; + // switch (options.certificate_key) { + // .rsa => |key| { + // switch (key.n.bits() / 8) { + // inline 128, 256, 512 => |modulus_len| { + // const enc = key.public.encrypt(modulus_len, test_msg); + // const dec = key.decrypt(modulus_len, enc) catch return error.CertificateKeyMismatch; + // if (!std.mem.eql(u8, test_msg, dec)) return error.CertificateKeyMismatch; + // }, + // else => return error.CertificateKeyMismatch, + // } + // }, + // inline .ecdsa256, .ecdsa384 => |comptime_scheme| { + // }, + // .ed25519: crypto.sign.Ed25519.SecretKey, + // } + } + + return res; } - var tls_version: ?tls.Version = null; - var key_share: ?tls.KeyShare = null; - var ec_point_format: ?tls.EcPointFormat = null; - var sig_scheme: ?tls.SignatureScheme = null; - - var extension_iter = try stream.extensions(); - while (try extension_iter.next()) |ext| { - switch (ext.type) { - .supported_versions => { - if (tls_version != null) return stream.writeError(.illegal_parameter); - var versions_iter = try stream.iterator(u8, tls.Version); - while (try versions_iter.next()) |v| { - if (v == .tls_1_3) tls_version = v; - } + inline fn init_random(self: *Handshake) void { + crypto.random.bytes(&self.server_random); + crypto.random.bytes(&self.keygen_seed); + crypto.random.bytes(&self.certificate_verify_salt); + } + + /// Executes handshake command and returns next one. + pub fn next(self: *Handshake) !void { + var s = &self.tls_stream; + s.transcript_hash = &self.transcript_hash; + + self.command = switch (self.command) { + .recv_hello => brk: { + self.client_hello = try self.recv_hello(); + + break: brk .send_hello; }, - // TODO: use supported_groups instead - .key_share => { - if (key_share != null) return stream.writeError(.illegal_parameter); - - var key_share_iter = try stream.iterator(u16, tls.KeyShare); - while (try key_share_iter.next()) |ks| { - for (self.options.key_shares) |s| { - if (ks == s and key_share == null) key_share = ks; - } - } - if (key_share == null) return stream.writeError(.decode_error); + .send_hello => brk: { + try self.send_hello(); + + // > if the client sends a non-empty session ID, + // > the server MUST send the change_cipher_spec + if (self.client_hello.session_id_len > 0) break :brk .send_change_cipher_spec; + + break :brk .send_encrypted_extensions; }, - .ec_point_formats => { - var format_iter = try stream.iterator(u8, tls.EcPointFormat); - while (try format_iter.next()) |f| { - if (f == .uncompressed) ec_point_format = .uncompressed; - } - if (ec_point_format == null) return stream.writeError(.decode_error); + .send_change_cipher_spec => brk: { + try s.changeCipherSpec(); + + break :brk .send_encrypted_extensions; }, - .signature_algorithms => { - const acceptable = switch (self.options.certificate_key) { - .rsa => &[_]tls.SignatureScheme{ - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha256, - }, - .ecdsa256 => &[_]tls.SignatureScheme{.ecdsa_secp256r1_sha256}, - .ecdsa384 => &[_]tls.SignatureScheme{.ecdsa_secp384r1_sha384}, - .ed25519 => &[_]tls.SignatureScheme{.ed25519}, - }; - var algos_iter = try stream.iterator(u16, tls.SignatureScheme); - while (try algos_iter.next()) |algo| { - for (acceptable) |a| { - if (algo == a and sig_scheme == null) sig_scheme = algo; - } - } - if (sig_scheme == null) return stream.writeError(.decode_error); + .send_encrypted_extensions => brk: { + try self.send_encrypted_extensions(); + + break :brk .send_certificate; }, - else => { - try reader.skipBytes(ext.len, .{}); + .send_certificate => brk: { + try self.send_certificate(); + + break :brk .send_certificate_verify; }, - } + .send_certificate_verify => brk: { + try self.send_certificate_verify(); + break :brk .send_finished; + }, + .send_finished => brk: { + try self.send_finished(); + break :brk .recv_finished; + }, + .recv_finished => brk: { + try self.recv_finished(); + break :brk .none; + }, + .none => .none, + }; } - if (tls_version != .tls_1_3) return stream.writeError(.protocol_version); - if (key_share == null) return stream.writeError(.missing_extension); - if (ec_point_format == null) return stream.writeError(.missing_extension); - if (sig_scheme == null) return stream.writeError(.missing_extension); - - const key_pair = switch (key_share.?) { - inline .secp256r1, - .secp384r1, - .x25519, - => |_, tag| brk: { - const T = tls.NamedGroupT(tag).KeyPair; - const pair = T.create(random.keygen_seed[0..T.seed_length].*) catch unreachable; - break :brk @unionInit(tls.KeyPair, @tagName(tag), pair); - }, - else => return stream.writeError(.decode_error), - }; + pub fn recv_hello(self: *Handshake) !ClientHello { + var s = &self.tls_stream; + var reader = s.stream().reader(); - return .{ - .random = client_random, - .session_id_len = session_id_len, - .session_id = session_id, - .cipher_suite = cipher_suite, - .key_share = key_share.?, - .sig_scheme = sig_scheme.?, - .server_random = random.server_random, - .server_pair = key_pair, - }; -} + try s.expectInnerPlaintext(.handshake, .client_hello); -pub fn send_hello(self: *Self, client_hello: ClientHello) !void { - var stream = &self.stream; - const key_pair = client_hello.server_pair; - - const hello = tls.ServerHello{ - .random = client_hello.server_random, - .session_id = &client_hello.session_id, - .cipher_suite = client_hello.cipher_suite, - .extensions = &.{ - .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, - .{ .key_share = &[_]tls.KeyShare{key_pair.toKeyShare()} }, - }, - }; - stream.version = .tls_1_2; - _ = try stream.write(tls.Handshake, .{ .server_hello = hello }); - try stream.flush(); - - const shared_key = switch (client_hello.key_share) { - .x25519 => |ks| brk: { - const shared_point = tls.NamedGroupT(.x25519).scalarmult( - key_pair.x25519.secret_key, - ks, - ) catch return stream.writeError(.decrypt_error); - break :brk &shared_point; - }, - inline .secp256r1, .secp384r1 => |ks, tag| brk: { - const key = @field(key_pair, @tagName(tag)); - const mul = ks.p.mulPublic(key.secret_key.bytes, .big) catch - return stream.writeError(.decrypt_error); - break :brk &mul.affineCoordinates().x.toBytes(.big); - }, - else => return stream.writeError(.illegal_parameter), - }; + _ = try s.read(tls.Version); + var client_random: [32]u8 = undefined; + try reader.readNoEof(&client_random); - const hello_hash = stream.transcript_hash.?.peek(); - const handshake_cipher = tls.HandshakeCipher.init( - client_hello.cipher_suite, - shared_key, - hello_hash, - self.options.key_log, - &client_hello.random, - ) catch - return stream.writeError(.illegal_parameter); - stream.cipher = .{ .handshake = handshake_cipher }; -} + var session_id: [tls.ClientHello.session_id_max_len]u8 = undefined; + const session_id_len = try s.read(u8); + if (session_id_len > tls.ClientHello.session_id_max_len) + return s.writeError(.illegal_parameter); + try reader.readNoEof(session_id[0..session_id_len]); -pub fn send_encrypted_extensions(self: *Self) !void { - var stream = &self.stream; - _ = try stream.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); - try stream.flush(); -} - -pub fn send_certificate(self: *Self) !void { - var stream = &self.stream; - _ = try self.stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); - try stream.flush(); -} - -pub fn send_certificate_verify(self: *Self, verify: Command.CertificateVerify) !void { - var stream = &self.stream; - - const digest = stream.transcript_hash.?.peek(); - const sig_content = tls.sigContent(digest); + const cipher_suite: tls.CipherSuite = brk: { + var cipher_suite_iter = try s.iterator(u16, tls.CipherSuite); + var res: ?tls.CipherSuite = null; + while (try cipher_suite_iter.next()) |suite| { + for (self.options.cipher_suites) |cs| { + if (cs == suite and res == null) res = cs; + } + } + if (res == null) return s.writeError(.illegal_parameter); + break :brk res.?; + }; + s.transcript_hash.?.setActive(cipher_suite); - const signature: []const u8 = switch (verify.scheme) { - inline .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384 => |comptime_scheme| brk: { - const Ecdsa = comptime_scheme.Ecdsa(); - const key = switch (comptime_scheme) { - .ecdsa_secp256r1_sha256 => self.options.certificate_key.ecdsa256, - .ecdsa_secp384r1_sha384 => self.options.certificate_key.ecdsa384, - else => unreachable, - }; + { + var compression_methods: [2]u8 = undefined; + try reader.readNoEof(&compression_methods); + if (!std.mem.eql(u8, &compression_methods, &[_]u8{ 1, 0 })) + return s.writeError(.illegal_parameter); + } - var signer = Ecdsa.Signer.init(key, verify.salt[0..Ecdsa.noise_length].*); - signer.update(sig_content); - const sig = signer.finalize() catch return stream.writeError(.internal_error); - break :brk &sig.toBytes(); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| brk: { - const Hash = comptime_scheme.Hash(); - const key = self.options.certificate_key.rsa; - - switch (key.public.n.bits() / 8) { - inline 128, 256, 512 => |modulus_length| { - const sig = Certificate.rsa.PSSSignature.sign( - modulus_length, - sig_content, - Hash, - key, - verify.salt[0..Hash.digest_length].*, - ) catch return stream.writeError(.bad_certificate); - break :brk &sig; + var tls_version: ?tls.Version = null; + var key_share: ?tls.KeyShare = null; + var ec_point_format: ?tls.EcPointFormat = null; + var sig_scheme: ?tls.SignatureScheme = null; + + var extension_iter = try s.extensions(); + while (try extension_iter.next()) |ext| { + switch (ext.type) { + .supported_versions => { + if (tls_version != null) return s.writeError(.illegal_parameter); + var versions_iter = try s.iterator(u8, tls.Version); + while (try versions_iter.next()) |v| { + if (v == .tls_1_3) tls_version = v; + } + }, + // TODO: use supported_groups instead + .key_share => { + if (key_share != null) return s.writeError(.illegal_parameter); + + var key_share_iter = try s.iterator(u16, tls.KeyShare); + while (try key_share_iter.next()) |ks| { + for (self.options.key_shares) |k| { + if (ks == k and key_share == null) key_share = ks; + } + } + if (key_share == null) return s.writeError(.decode_error); + }, + .ec_point_formats => { + var format_iter = try s.iterator(u8, tls.EcPointFormat); + while (try format_iter.next()) |f| { + if (f == .uncompressed) ec_point_format = .uncompressed; + } + if (ec_point_format == null) return s.writeError(.decode_error); + }, + .signature_algorithms => { + const acceptable = switch (self.options.certificate_key) { + .none => &[_]tls.SignatureScheme{}, // should be all of them + .rsa => &[_]tls.SignatureScheme{ + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha256, + }, + .ecdsa256 => &[_]tls.SignatureScheme{.ecdsa_secp256r1_sha256}, + .ecdsa384 => &[_]tls.SignatureScheme{.ecdsa_secp384r1_sha384}, + .ed25519 => &[_]tls.SignatureScheme{.ed25519}, + }; + var algos_iter = try s.iterator(u16, tls.SignatureScheme); + while (try algos_iter.next()) |algo| { + if (self.options.certificate_key == .none) sig_scheme = algo; + for (acceptable) |a| { + if (algo == a and sig_scheme == null) sig_scheme = algo; + } + } + if (sig_scheme == null) return s.writeError(.decode_error); + }, + else => { + try reader.skipBytes(ext.len, .{}); }, - else => return stream.writeError(.bad_certificate), } - }, - .ed25519 => brk: { - const Ed25519 = crypto.sign.Ed25519; - const key = self.options.certificate_key.ed25519; - - const pub_key = brk2: { - const cert_buf = Certificate{ .buffer = self.options.certificate.entries[0].data, .index = 0 }; - const cert = try cert_buf.parse(); - const expected_len = Ed25519.PublicKey.encoded_length; - if (cert.pubKey().len != expected_len) return stream.writeError(.bad_certificate); - break :brk2 Ed25519.PublicKey.fromBytes(cert.pubKey()[0..expected_len].*) catch - return stream.writeError(.bad_certificate); - }; - const nonce: Ed25519.CompressedScalar = verify.salt[0..Ed25519.noise_length].*; - - const key_pair = Ed25519.KeyPair{ .public_key = pub_key, .secret_key = key }; - const sig = key_pair.sign(sig_content, nonce) catch return stream.writeError(.internal_error); - break :brk &sig.toBytes(); - }, - else => { - return stream.writeError(.bad_certificate); - }, - }; + } - _ = try self.stream.write(tls.Handshake, .{ .certificate_verify = tls.CertificateVerify{ - .algorithm = verify.scheme, - .signature = signature, - } }); - try stream.flush(); -} + if (tls_version != .tls_1_3) return s.writeError(.protocol_version); + if (key_share == null) return s.writeError(.missing_extension); + if (ec_point_format == null) return s.writeError(.missing_extension); + if (sig_scheme == null) return s.writeError(.missing_extension); + + self.key_pair = switch (key_share.?) { + inline .secp256r1, + .secp384r1, + .x25519, + => |_, tag| brk: { + const T = tls.NamedGroupT(tag).KeyPair; + const pair = T.create(self.keygen_seed[0..T.seed_length].*) catch unreachable; + break :brk @unionInit(tls.KeyPair, @tagName(tag), pair); + }, + else => return s.writeError(.decode_error), + }; -pub fn send_finished(self: *Self) !void { - var stream = &self.stream; - const verify_data = switch (stream.cipher.handshake) { - inline else => |v| brk: { - const T = @TypeOf(v); - const secret = v.server_finished_key; - const transcript_hash = stream.transcript_hash.?.peek(); + return .{ + .random = client_random, + .session_id_len = session_id_len, + .session_id = session_id, + .cipher_suite = cipher_suite, + .key_share = key_share.?, + .sig_scheme = sig_scheme.?, + }; + } - break :brk &tls.hmac(T.Hmac, transcript_hash, secret); - }, - }; - _ = try stream.write(tls.Handshake, .{ .finished = verify_data }); - try stream.flush(); -} + pub fn send_hello(self: *Handshake) !void { + var s = &self.tls_stream; + const key_pair = self.key_pair; + const client_hello = self.client_hello; + + const hello = tls.ServerHello{ + .random = self.server_random, + .session_id = &client_hello.session_id, + .cipher_suite = client_hello.cipher_suite, + .extensions = &.{ + .{ .supported_versions = &[_]tls.Version{.tls_1_3} }, + .{ .key_share = &[_]tls.KeyShare{key_pair.toKeyShare()} }, + }, + }; + s.version = .tls_1_2; + _ = try s.write(tls.Handshake, .{ .server_hello = hello }); + try s.flush(); + + const shared_key = switch (client_hello.key_share) { + .x25519 => |ks| brk: { + const shared_point = tls.NamedGroupT(.x25519).scalarmult( + key_pair.x25519.secret_key, + ks, + ) catch return s.writeError(.decrypt_error); + break :brk &shared_point; + }, + inline .secp256r1, .secp384r1 => |ks, tag| brk: { + const key = @field(key_pair, @tagName(tag)); + const mul = ks.p.mulPublic(key.secret_key.bytes, .big) catch + return s.writeError(.decrypt_error); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + else => return s.writeError(.illegal_parameter), + }; -pub fn recv_finished(self: *Self) !void { - var stream = &self.stream; - var reader = stream.any().reader(); - - const handshake_hash = stream.transcript_hash.?.peek(); - - const application_cipher = tls.ApplicationCipher.init( - stream.cipher.handshake, - handshake_hash, - self.options.key_log, - "idk", - ); - - const expected = switch (stream.cipher.handshake) { - inline else => |p| brk: { - const P = @TypeOf(p); - const digest = stream.transcript_hash.?.peek(); - break :brk &tls.hmac(P.Hmac, digest, p.client_finished_key); - }, - }; + const hello_hash = s.transcript_hash.?.peek(); + const handshake_cipher = tls.HandshakeCipher.init( + client_hello.cipher_suite, + shared_key, + hello_hash, + self.logger(), + ) catch + return s.writeError(.illegal_parameter); + s.cipher = .{ .handshake = handshake_cipher }; + } - try stream.expectInnerPlaintext(.handshake, .finished); - const actual = stream.view; - try reader.skipBytes(stream.view.len, .{}); + pub fn send_encrypted_extensions(self: *Handshake) !void { + var s = &self.tls_stream; + _ = try s.write(tls.Handshake, .{ .encrypted_extensions = &.{} }); + try s.flush(); + } - if (!mem.eql(u8, expected, actual)) return stream.writeError(.decode_error); + pub fn send_certificate(self: *Handshake) !void { + var s = &self.tls_stream; + _ = try self.tls_stream.write(tls.Handshake, .{ .certificate = self.options.certificate }); + try s.flush(); + } - stream.content_type = .application_data; - stream.handshake_type = null; - stream.cipher = .{ .application = application_cipher }; - stream.transcript_hash = null; -} + pub fn send_certificate_verify(self: *Handshake) !void { + var s = &self.tls_stream; + const salt = self.certificate_verify_salt; + const scheme = self.client_hello.sig_scheme; + + const digest = s.transcript_hash.?.peek(); + const sig_content = tls.sigContent(digest); + + const signature: []const u8 = switch (scheme) { + inline .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384 => |comptime_scheme| brk: { + const Ecdsa = comptime_scheme.Ecdsa(); + const key = switch (comptime_scheme) { + .ecdsa_secp256r1_sha256 => self.options.certificate_key.ecdsa256, + .ecdsa_secp384r1_sha384 => self.options.certificate_key.ecdsa384, + else => unreachable, + }; -pub const ReadError = anyerror; -pub const WriteError = anyerror; + var signer = Ecdsa.Signer.init(key, salt[0..Ecdsa.noise_length].*); + signer.update(sig_content); + const sig = signer.finalize() catch return s.writeError(.internal_error); + break :brk &sig.toBytes(); + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| brk: { + const Hash = comptime_scheme.Hash(); + const key = self.options.certificate_key.rsa; + + switch (key.public.n.bits() / 8) { + inline 128, 256, 512 => |modulus_length| { + const sig = Certificate.rsa.PSSSignature.sign( + modulus_length, + sig_content, + Hash, + key, + salt[0..Hash.digest_length].*, + ) catch return s.writeError(.bad_certificate); + break :brk &sig; + }, + else => return s.writeError(.bad_certificate), + } + }, + .ed25519 => brk: { + const Ed25519 = crypto.sign.Ed25519; + const key = self.options.certificate_key.ed25519; + + const pub_key = brk2: { + const cert_buf = Certificate{ .buffer = self.options.certificate.entries[0].data, .index = 0 }; + const cert = try cert_buf.parse(); + const expected_len = Ed25519.PublicKey.encoded_length; + if (cert.pubKey().len != expected_len) return s.writeError(.bad_certificate); + break :brk2 Ed25519.PublicKey.fromBytes(cert.pubKey()[0..expected_len].*) catch + return s.writeError(.bad_certificate); + }; + const nonce: Ed25519.CompressedScalar = salt[0..Ed25519.noise_length].*; -/// Reads next application_data message. -pub fn readv(self: *Self, buffers: []const std.os.iovec) ReadError!usize { - var stream = &self.stream; + const key_pair = Ed25519.KeyPair{ .public_key = pub_key, .secret_key = key }; + const sig = key_pair.sign(sig_content, nonce) catch return s.writeError(.internal_error); + break :brk &sig.toBytes(); + }, + else => { + return s.writeError(.bad_certificate); + }, + }; - if (stream.eof()) return 0; + _ = try self.tls_stream.write(tls.Handshake, .{ .certificate_verify = tls.CertificateVerify{ + .algorithm = scheme, + .signature = signature, + } }); + try s.flush(); + } - while (stream.view.len == 0) { - const inner_plaintext = try stream.readInnerPlaintext(); - switch (inner_plaintext.type) { - .application_data => {}, - .alert => {}, - else => return stream.writeError(.unexpected_message), - } + pub fn send_finished(self: *Handshake) !void { + var s = &self.tls_stream; + const verify_data = switch (s.cipher.handshake) { + inline else => |v| brk: { + const T = @TypeOf(v); + const secret = v.server_finished_key; + const transcript_hash = s.transcript_hash.?.peek(); + + break :brk &tls.hmac(T.Hmac, transcript_hash, secret); + }, + }; + _ = try s.write(tls.Handshake, .{ .finished = verify_data }); + try s.flush(); } - return try self.stream.readv(buffers); -} -pub fn writev(self: *Self, iov: []const std.os.iovec_const) WriteError!usize { - if (self.stream.eof()) return 0; + pub fn recv_finished(self: *Handshake) !void { + var s = &self.tls_stream; + var reader = s.stream().reader(); - const res = try self.stream.writev(iov); - try self.stream.flush(); - return res; -} + const handshake_hash = s.transcript_hash.?.peek(); -pub fn close(self: *Self) void { - self.stream.close(); -} + const application_cipher = tls.ApplicationCipher.init( + s.cipher.handshake, + handshake_hash, + self.logger(), + ); -pub const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); + const expected = switch (s.cipher.handshake) { + inline else => |p| brk: { + const P = @TypeOf(p); + const digest = s.transcript_hash.?.peek(); + break :brk &tls.hmac(P.Hmac, digest, p.client_finished_key); + }, + }; -pub fn any(self: *Self) GenericStream { - return .{ .context = self }; -} + try s.expectInnerPlaintext(.handshake, .finished); + const actual = s.view; + try reader.skipBytes(s.view.len, .{}); -pub const Options = struct { - /// List of potential cipher suites in descending order of preference. - cipher_suites: []const tls.CipherSuite = &tls.default_cipher_suites, - /// Types of shared keys to accept from client. - key_shares: []const tls.NamedGroup = &tls.supported_groups, - /// Certificate(s) to send in `send_certificate` messages. - certificate: tls.Certificate, - /// Key to use in `send_certificate_verify`. Must match `certificate.parse().pub_key_algo`. - certificate_key: CertificateKey, - /// Writer to log shared secrets for traffic decryption. - /// - /// See https://www.ietf.org/archive/id/draft-thomson-tls-keylogfile-01.html - key_log: std.io.AnyWriter = std.io.null_writer.any(), + if (!mem.eql(u8, expected, actual)) return s.writeError(.decode_error); - pub const CertificateKey = union(enum) { - rsa: crypto.Certificate.rsa.SecretKey, - ecdsa256: tls.NamedGroupT(.secp256r1).SecretKey, - ecdsa384: tls.NamedGroupT(.secp384r1).SecretKey, - ed25519: crypto.sign.Ed25519.SecretKey, - }; -}; + s.content_type = .application_data; + s.handshake_type = null; + s.cipher = .{ .application = application_cipher }; + s.transcript_hash = null; + } -/// A command to send or receive a single message. Allows deterministically -/// testing `advance` on a single thread. -pub const Command = union(enum) { - recv_hello: Random, - send_hello: ClientHello, - send_change_cipher_spec: tls.SignatureScheme, - send_encrypted_extensions: tls.SignatureScheme, - send_certificate: tls.SignatureScheme, - send_certificate_verify: CertificateVerify, - send_finished: void, - recv_finished: void, - none: void, - - pub const Random = struct { - server_random: [32]u8, - keygen_seed: [tls.NamedGroupT(.secp384r1).KeyPair.seed_length]u8, - }; + /// Establishes a TLS connection on `tls_stream` and returns a Client. + pub fn handshake(self: *Handshake) !Server { + while (self.command != .none) self.next() catch |err| switch (err) { + error.ConnectionResetByPeer => { + // Prevent reply attacks + self.command = .send_hello; + self.init_random(); + }, + else => return err, + }; - pub const CertificateVerify = struct { - scheme: tls.SignatureScheme, - salt: [tls.MultiHash.max_digest_len]u8, - }; -}; + return Server{ + .tls_stream = self.tls_stream, + .key_logger = .{ + .writer = self.options.key_log, + .client_random = self.client_hello.random, + }, + }; + } -pub const ClientHello = struct { - random: [32]u8, - session_id_len: u8, - session_id: [32]u8, - cipher_suite: tls.CipherSuite, - key_share: tls.KeyShare, - sig_scheme: tls.SignatureScheme, - server_random: [32]u8, - /// Everything needed to generate a shared secret and send ciphertext to the client - /// so it can do the same. - /// Active member MUST match `key_share`. - server_pair: tls.KeyPair, + fn logger(self: *Handshake) tls.KeyLogger { + return tls.KeyLogger{ + .client_random = self.client_hello.random, + .writer = self.options.key_log, + }; + } }; + diff --git a/lib/std/crypto/tls/Stream.zig b/lib/std/crypto/tls/Stream.zig index f513fd414c14..140e1802cfac 100644 --- a/lib/std/crypto/tls/Stream.zig +++ b/lib/std/crypto/tls/Stream.zig @@ -11,7 +11,7 @@ const std = @import("../../std.zig"); const tls = std.crypto.tls; -stream: std.io.AnyStream, +inner_stream: std.io.AnyStream, /// Used for both reading and writing. /// Stores plaintext or briefly ciphertext, but not Plaintext headers. buffer: [fragment_size]u8 = undefined, @@ -23,11 +23,10 @@ view: []const u8 = "", content_type: ContentType = .handshake, /// When sending this is the flushed version. version: Version = .tls_1_0, -/// When receiving a handshake message will be expected with this type. +/// When receiving a handshake message this is its expected type. handshake_type: ?HandshakeType = .client_hello, -/// Used to decrypt .application_data messages. -/// Used to encrypt messages that aren't alert or change_cipher_spec. +/// Used to encrypt and decrypt messages. cipher: Cipher = .none, /// True when we send or receive a close_notify alert. @@ -35,15 +34,15 @@ closed: bool = false, /// True if we're being used as a client. This changes: /// * Certain shared struct formats (like Extension) -/// * Which ciphers are used for encoding/decoding handshake and application messages. +/// * Which cipher members are used for encryption/decryption is_client: bool, -/// When > 0 won't actually do anything with writes. Used to discover prefix lengths. -nocommit: usize = 0, +/// When > 0 will discard writes. Used to discover prefix lengths. +nocommit: u32 = 0, -/// Client and server implementations can set this. While set sent or received handshake messages -/// will update the hash. -transcript_hash: ?*MultiHash, +/// Client and server implementations can set this to cause +/// sent and received handshake messages to update the hash. +transcript_hash: ?*MultiHash = null, const Self = @This(); const ContentType = tls.ContentType; @@ -64,7 +63,6 @@ const Cipher = union(enum) { handshake: HandshakeCipher, }; -// Useful mostly as reference or until std.io.Any* types don't type erase errors. pub const ReadError = anyerror || tls.Error || error{EndOfStream}; pub const WriteError = anyerror || error{TlsEncodeError}; @@ -125,9 +123,9 @@ pub fn flush(self: *Self) WriteError!void { } // TODO: contiguous buffer management - try self.stream.writer().writeAll(&header); - try self.stream.writer().writeAll(self.view); - try self.stream.writer().writeAll(aead); + try self.inner_stream.writer().writeAll(&header); + try self.inner_stream.writer().writeAll(self.view); + try self.inner_stream.writer().writeAll(aead); self.view = self.buffer[0..0]; } @@ -143,8 +141,8 @@ pub fn changeCipherSpec(self: *Self) !void { const msg = [_]u8{1}; const header: [Plaintext.size]u8 = Encoder.encode(Plaintext, plaintext); // TODO: contiguous buffer management - try self.stream.writer().writeAll(&header); - try self.stream.writer().writeAll(&msg); + try self.inner_stream.writer().writeAll(&header); + try self.inner_stream.writer().writeAll(&msg); } /// Write an alert to stream and call `close_notify` after. Returns Zig error. @@ -205,7 +203,7 @@ pub fn writeArray(self: *Self, comptime PrefixT: type, comptime T: type, values: /// Returns number of bytes written. Convienent for encoding struct types in tls.zig . pub fn writeAll(self: *Self, bytes: []const u8) !usize { - try self.any().writer().writeAll(bytes); + try self.stream().writer().writeAll(bytes); return bytes.len; } @@ -213,7 +211,7 @@ pub fn write(self: *Self, comptime T: type, value: T) !usize { switch (@typeInfo(T)) { .Int, .Enum => { const encoded = Encoder.encode(T, value); - try self.any().writer().writeAll(&encoded); + try self.stream().writer().writeAll(&encoded); return encoded.len; }, .Struct, .Union => { @@ -255,7 +253,7 @@ pub fn readv(self: *Self, iov: []const std.os.iovec) ReadError!usize { for (iov) |b| { var bytes_read_buffer: usize = 0; - while (bytes_read_buffer != b.iov_len) { + while (bytes_read_buffer != b.iov_len and !self.eof()) { const to_read = @min(b.iov_len, self.view.len); if (to_read == 0) return bytes_read; @@ -279,14 +277,14 @@ pub fn readPlaintext(self: *Self) !Plaintext { var n_read: usize = 0; while (true) { - n_read = try self.stream.reader().readAll(&plaintext_bytes); + n_read = try self.inner_stream.reader().readAll(&plaintext_bytes); if (n_read != plaintext_bytes.len) return self.writeError(.decode_error); var res = Plaintext.init(plaintext_bytes); if (res.len > Plaintext.max_length) return self.writeError(.record_overflow); self.view = self.buffer[0..res.len]; - n_read = try self.stream.reader().readAll(@constCast(self.view)); + n_read = try self.inner_stream.reader().readAll(@constCast(self.view)); if (n_read != res.len) return self.writeError(.decode_error); const encryption_method = self.encryptionMethod(res.type); @@ -388,10 +386,7 @@ pub fn expectInnerPlaintext( expected_handshake: ?HandshakeType, ) !void { const inner_plaintext = try self.readInnerPlaintext(); - if (expected_content != inner_plaintext.type) { - std.debug.print("expected {} got {}\n", .{ expected_content, inner_plaintext }); - return self.writeError(.unexpected_message); - } + if (expected_content != inner_plaintext.type) return self.writeError(.unexpected_message); if (expected_handshake) |expected| { if (expected != inner_plaintext.handshake_type) return self.writeError(.decode_error); } @@ -400,7 +395,7 @@ pub fn expectInnerPlaintext( pub fn read(self: *Self, comptime T: type) !T { comptime std.debug.assert(@sizeOf(T) < fragment_size); switch (@typeInfo(T)) { - .Int => return self.any().reader().readInt(T, .big) catch |err| switch (err) { + .Int => return self.stream().reader().readInt(T, .big) catch |err| switch (err) { error.EndOfStream => return self.writeError(.decode_error), else => |e| return e, }, @@ -478,7 +473,7 @@ pub fn eof(self: Self) bool { pub const GenericStream = std.io.GenericStream(*Self, ReadError, readv, WriteError, writev, close); -pub fn any(self: *Self) GenericStream { +pub fn stream(self: *Self) GenericStream { return .{ .context = self }; }